Skip to main content

symjit/
lib.rs

1#![allow(uncommon_codepoints)]
2
3//! Symjit (<https://github.com/siravan/symjit>) is a lightweight just-in-time (JIT)
4//! optimizer compiler for mathematical expressions written in Rust. It was originally
5//! designed to compile SymPy (Python’s symbolic algebra package) expressions
6//! into machine code and to serve as a bridge between SymPy and numerical routines
7//! provided by NumPy and SciPy libraries.
8//!
9//! Symjit crate is the core compiler coupled to a Rust interface to expose the
10//! JIT functionality to the Rust ecosystem and allow Rust applications to
11//! generate code dynamically. Considering its origin, symjit is geared toward
12//! compiling mathematical expressions instead of being a general-purpose JIT
//! compiler. Therefore, the only supported types for variables are `f64`
//! (including SIMD f64x4 and f64x2), `Complex<f64>`, and implicitly, `bool` and `i32`.
15//!
16//! Symjit emits AMD64 (x86-64), ARM64 (aarch64), and 64-bit RISC-V (riscv64) machine
//! code on Linux, Windows, and macOS platforms. SIMD is supported on x86-64
18//! and ARM64.
19//!
//! In Rust, there are two ways to construct expressions to pass to Symjit: using
//! Symbolica or using Symjit's standalone expression builder.
22//!
23//! # Symbolica
24//!
25//! Symbolica (<https://symbolica.io/>) is a fast Rust-based Computer Algebra System.
//! Symbolica usually generates fast code using external compilers (e.g., using gcc to
//! compile generated C++ code). Symjit accepts Symbolica expressions and can act as
28//! an optional code-generator for Symbolica. The link between the two is through
29//! Symbolica's `export_instructions` function that exports an optimized intermediate
30//! representation. Using serde, it is possible to convert the output of `export_instructions`
31//! into JSON, which is then passed to the `translate` function of Symjit `Compiler`
32//! structure. If successful, `translate` returns an `Application` object, which wraps
33//! the compiled code and can be run using one of the six `evaluate` functions:
34//!
35//! * `evaluate(&mut self, args: &[T], outs: &mut [T])`.
36//! * `evaluate_single(&mut self, args: &[T]) -> T`.
37//! * `evaluate_matrix(&mut self, args: &[T], outs: &mut [T], nrows: usize)`.
38//! * `evaluate_simd(&mut self, args: &[S], outs: &mut [S])`.
39//! * `evaluate_simd_single(&mut self, args: &[S]) -> S`.
40//! * `evaluate_simd_matrix(&mut self, args: &[S], outs: &mut [S], nrows: usize)`.
41//!
//! where `T` is either `f64` or `Complex<f64>` and `S` is `f64x4` on x86-64 or `f64x2`
43//! on aarch64, or the complex version of them.
44//!
//! Examples:
46//!
47//! ```rust
48//! use anyhow::Result;
49//! use symjit::{Compiler, Config};
50//! use symbolica::{atom::AtomCore, parse, symbol};
51//! use symbolica::evaluate::{FunctionMap, OptimizationSettings};
52//!
53//! fn test1() -> Result<()> {
54//!     let params = vec![parse!("x"), parse!("y")];
55//!     let eval = parse!("x + y^2")
56//!         .evaluator(
57//!             &FunctionMap::new(),
58//!             &params,
59//!             OptimizationSettings::default(),
60//!         )
61//!         .unwrap();
62//!
63//!     let json = serde_json::to_string(&eval.export_instructions())?;
64//!     let mut comp = Compiler::new();
65//!     let mut app = comp.translate(&json)?;
66//!     assert!(app.evaluate_single(&[2.0, 3.0]) == 11.0);
67//!     Ok(())
68//! }
69//! ```
70//!
71//! Note that Symbolica needs to be imported by `features = ["serde"]` to allow for
72//! applying `serde_json::to_string` to the output of `export_instructions`.
73//!
74//! To change compilation options, one passes a `Config` struct to the `Compiler`
//! constructor. The following example shows how to compile for complex numbers.
76//!
77//! ```rust
78//! use anyhow::Result;
79//! use num_complex::Complex;
80//! use symjit::{Compiler, Config};
81//! use symbolica::{atom::AtomCore, parse, symbol};
82//! use symbolica::evaluate::{FunctionMap, OptimizationSettings};
83//!
84//! fn test2() -> Result<()> {
85//!     let params = vec![parse!("x"), parse!("y")];
86//!     let eval = parse!("x + y^2")
87//!         .evaluator(
88//!             &FunctionMap::new(),
89//!             &params,
90//!             OptimizationSettings::default(),
91//!         )
92//!         .unwrap();
93//!
94//!     let json = serde_json::to_string(&eval.export_instructions())?;
95//!     let mut config = Config::default();
96//!     config.set_complex(true);
97//!     let mut comp = Compiler::with_config(config);
98//!     let mut app = comp.translate(&json)?;
99//!     let v = vec![Complex::new(2.0, 1.0), Complex::new(-1.0, 3.0)];
100//!     assert!(app.evaluate_single(&v) == Complex::new(-6.0, -5.0));
101//!     Ok(())
102//! }
103//! ```
104//!
105//! Currently, Symjit supports most of Symbolica's expressions with the exception of
106//! external user-defined functions. However, it is possible to link to Symjit
107//! numerical functions (see below) by defining their name using `add_external_function`.
108//! The following example shows how to link to `sinh` function:
109//!
110//!
111//! ```rust
112//! use anyhow::Result;
113//! use symjit::{Compiler, Config};
114//! use symbolica::{atom::AtomCore, parse, symbol};
115//! use symbolica::evaluate::{FunctionMap, OptimizationSettings};
116//!
117//! fn test3() -> Result<()> {
118//!     let params = vec![parse!("x")];
119//!
120//!     let mut f = FunctionMap::new();
121//!     f.add_external_function(symbol!("sinh"), "sinh".to_string())
122//!         .unwrap();
123//!
124//!     let eval = parse!("sinh(x)")
125//!         .evaluator(&f, &params, OptimizationSettings::default())
126//!         .unwrap();
127//!
128//!     let json = serde_json::to_string(&eval.export_instructions())?;
129//!     let mut comp = Compiler::new();
130//!     let mut app = comp.translate(&json)?;
131//!     assert!(app.evaluate_single(&[1.5]) == f64::sinh(1.5));
132//!     Ok(())
133//! }
134//! ```
135//!
136//! # Standalone Expression Builder
137//!
138//! A second way to use Symjit is by using its standalone expression builder. Compared to
//! Symbolica, the expression builder is limited but is useful in situations where the goal
140//! is to compile an expression without extensive symbolic manipulations.
141//!
142//! The workflow to create, compile, and run expressions is:
143//!
144//! 1. Create terminals (variables and constants) and compose expressions using `Expr` methods:
145//!    * Constructors: `var`, `from`, `unary`, `binary`, ...
146//!    * Standard algebraic operations: `add`, `mul`, ...
147//!    * Standard operators `+`, `-`, `*`, `/`, `%`, `&`, `|`, `^`, `!`.
148//!    * Unary functions such as `sin`, `exp`, and other standard mathematical functions.
149//!    * Binary functions such as `pow`, `min`, ...
150//!    * IfElse operation `ifelse(cond, true_val, false_val)`.
151//!    * Heaviside function: `heaviside(x)`, which returns 1 if `x >= 0`; otherwise 0.
152//!    * Comparison methods `eq`, `ne`, `lt`, `le`, `gt`, and `ge`.
153//!    * Looping constructs `sum` and `prod`.
154//! 2. Create a new `Compiler` object (say, `comp`) using one of its constructors.
155//! 3. Define user-defined functions by calling `comp.def_unary` and `comp.def_binary`
156//!    (optional).
157//! 4. Compile by calling `comp.compile` or `comp.compile_params`. The result is of
158//!    type `Application` (say, `app`).
159//! 5. Execute the compiled code using one of the `app`'s `call` functions:
160//!    * `call(&[f64])`: scalar call.
161//!    * `call_params(&[f64], &[f64])`: scalar call with parameters.
162//!    * `call_simd(&[__m256d])`: simd call.
163//!    * `call_simd_params(&[__m256d], &[f64])`: simd call with parameters.
164//! 6. Optionally, generate a standalone fast function to execute.
165//!
166//! Note that you can use the helper functions `var(&str) -> Expr`, `int(i32) -> Expr`,
167//! `double(f64) -> Expr`, and `boolean(bool) -> f64` to reduce clutter.
168//!
169//! # Examples
170//!
171//! ```rust
172//! use anyhow::Result;
173//! use symjit::{Compiler, Expr};
174//!
175//! pub fn test_scalar() -> Result<()> {
176//!     let x = Expr::var("x");
177//!     let y = Expr::var("y");
178//!     let u = &x + &y;
179//!     let v = &x * &y;
180//!
181//!     let mut comp = Compiler::new();
182//!     let mut app = comp.compile(&[x, y], &[u, v])?;
183//!     let res = app.call(&[3.0, 5.0]);
184//!     println!("{:?}", &res);   // prints [8.0, 15.0]
185//!
186//!     Ok(())
187//! }
188//! ```
189//!
190//! `test_scalar` is similar to the following basic example in Python/SymPy:
191//!
192//! ```python
193//! from symjit import compile_func
194//! from sympy import symbols
195//!
196//! x, y = symbols('x y')
197//! f = compile_func([x, y], [x+y, x*y])
198//! print(f(3.0, 5.0))  # prints [8.0, 15.0]
199//! ```
200//!
201//! A more elaborate example, showcasing having a parameter, changing the
202//! optimization level, and using SIMD:
203//!
204//! ```rust
205//! use anyhow::Result;
206//! use symjit::{var, Compiler, Expr};
207//!
208//! pub fn test_simd() -> Result<()> {
209//!     use std::arch::x86_64::_mm256_loadu_pd;
210//!
211//!     let x = var("x");   // note var instead of Expr::var
212//!     let p = var("p");   // the parameter
213//!
214//!     let u = &x.square() * &p;    // x^2 * p
215//!     let mut comp = Compiler::new();
216//!     comp.opt_level(2);  // optional (opt_level 0 to 2; default 1)
217//!     let mut app = comp.compile_params(&[x], &[u], &[p])?;
218//!
219//!     let a = &[1.0, 2.0, 3.0, 4.0];
220//!     let a = unsafe { _mm256_loadu_pd(a.as_ptr()) };
221//!     let res = app.call_simd_params(&[a], &[5.0])?;
222//!     println!("{:?}", &res);   // prints [__m256d(5.0, 20.0, 45.0, 80.0)]
223//!     Ok(())
224//! }
225//! ```
226//!
//! # Conditional Expressions and Loops
228//!
229//! Many mathematical formulas need conditional expressions (`ifelse`) and loops.
230//! Following SymPy, Symjit uses reduction loops such as `sum` and `prod`. The following
//! example computes the exponential function:
232//!
233//! ```rust
//! use anyhow::Result;
//! use symjit::{int, var, Compiler};
//!
//! fn test_exp() -> Result<()> {
//!     let x = var("x");
//!     let i = var("i");   // loop variable
//!     let j = var("j");   // loop variable
//!
//!     // u = x^j / factorial(j) for j in 0..=50
//!     let u = x
//!         .pow(&j)
//!         .div(&i.prod(&i, &int(1), &j))
//!         .sum(&j, &int(0), &int(50));
//!
//!     let mut app = Compiler::new().compile(&[x], &[u])?;
//!     println!("{:?}", app.call(&[2.0])[0]); // prints exp(2.0) = 7.38905...
249//!     Ok(())
250//! }
251//! ```
252//!
253//! An example showing how to calculate pi using the Leibniz formula:
254//!
255//! ```rust
//! use anyhow::Result;
//! use symjit::{int, var, Compiler};
//!
//! fn test_pi() -> Result<()> {
//!     let n = var("n");   // upper bound of the sum (the input)
//!     let j = var("j");   // loop variable
//!
//!     // numer = if j % 2 == 0 { 4 } else { -4 }
//!     let numer = j.rem(&int(2)).eq(&int(0)).ifelse(&int(4), &int(-4));
//!     // denom = j * 2 + 1
//!     let denom = j.mul(&int(2)).add(&int(1));
//!     // v = numer / denom for j in 0..=n
//!     let v = (&numer / &denom).sum(&j, &int(0), &n);
//!
//!     let mut app = Compiler::new().compile(&[n], &[v])?;
//!     println!("{:?}", app.call(&[100000000.0])[0]); // prints an approximation of pi
272//!     Ok(())
273//! }
274//! ```
275//!
276//! Note that here we are using explicit functions (`add`, `mul`, ...) instead of
277//! the overloaded operators for clarity.
278//!
279//! # Fast Functions
280//!
281//! `Application`'s call functions need to copy the input slice into the function
282//! memory area and then copy the output to a `Vec`. This process is acceptable
283//! for large and complex functions but incurs a penalty for small ones.
284//! Therefore, for a certain subset of applications, Symjit can compile to a
285//! *fast function* and return a function pointer. Examples:
286//!
287//! ```rust
288//! use anyhow::Result;
289//! use symjit::{int, var, Compiler, FastFunc};
290//!
291//! fn test_fast() -> Result<()> {
292//!     let x = var("x");
293//!     let y = var("y");
294//!     let z = var("z");
295//!     let u = &x * &(&y - &z).pow(&int(2));    // x * (y - z)^2
296//!
297//!     let mut comp = Compiler::new();
298//!     let mut app = comp.compile(&[x, y, z], &[u])?;
299//!     let f = app.fast_func()?;
300//!
301//!     if let FastFunc::F3(f, _) = f {
302//!         // f is of type extern "C" fn(f64, f64, f64) -> f64
303//!         let res = f(3.0, 5.0, 9.0);
304//!         println!("fast\t{:?}", &res);
305//!     }
306//!
307//!     Ok(())
308//! }
309//! ```
310//!
311//! The conditions for a fast function are:
312//!
313//! * A fast function can have 1 to 8 arguments.
314//! * No SIMD and no parameters.
315//! * It returns only a single value.
316//!
317//! If these conditions are met, you can generate a fast function by calling
318//! `app.fast_func()`, which returns a `Result<FastFunc>`. `FastFunc` is an
319//! enum with eight variants `F1`, `F2`, ..., `F8`, corresponding to functions
320//! with 1 to 8 arguments.
321//!
322//! # User-Defined Functions
323//!
324//! Symjit functions can call into user-defined Rust functions. Currently,
325//! only the following function signatures are accepted:
326//!
327//! ```rust
328//! pub type UnaryFunc = extern "C" fn(f64) -> f64;
329//! pub type BinaryFunc = extern "C" fn(f64, f64) -> f64;
330//! ```
331//!
332//! For example:
333//!
334//! ```rust
335//! extern "C" fn f(x: f64) -> f64 {
336//!     x.exp()
337//! }
338//!
339//! extern "C" fn g(x: f64, y: f64) -> f64 {
340//!     x.ln() * y
341//! }
342//!
343//! fn test_external() -> Result<()> {
344//!     let x = Expr::var("x");
345//!     let u = Expr::unary("f_", &x);
346//!     let v = &x * &Expr::binary("g_", &u, &x);
347//!
348//!     // v(x) = x * (ln(exp(x)) * x) = x ^ 3
349//!
350//!     let mut comp = Compiler::new();
351//!     comp.def_unary("f_", f);
352//!     comp.def_binary("g_", g);
353//!     let mut app = comp.compile(&[x], &[v])?;
354//!     println!("{:?}", app.call(&[5.0])[0]);
355//!
356//!     Ok(())
357//! }
358//! ```
359//!
360//! # Dynamic Expressions
361//!
362//! All the examples up to this point use static expressions. Of course, it
363//! would have been easier just to use Rust expressions for these examples!
364//! The main utility of Symjit for Rust is for dynamic code generation. Here,
365//! we provide a simple example to calculate pi using Viete's formula
366//! (<https://en.wikipedia.org/wiki/Vi%C3%A8te%27s_formula>):
367//!
368//! ```rust
369//! fn test_pi_viete(silent: bool) -> Result<()> {
370//!     let x = var("x");
371//!     let mut u = int(1);
372//!
373//!     for i in 0..50 {
374//!         let mut t = x.clone();
375//!
376//!         for _ in 0..i {
377//!             t = &x + &(&x * &t.sqrt());
378//!         }
379//!
380//!         u = &u * &t.sqrt();
381//!     }
382//!
383//!     // u has 1275 = 50 * 51 / 2 sqrt operations
384//!     let mut app = Compiler::new().compile(&[x], &[&int(2) / &u])?;
385//!     println!("pi = \t{:?}", app.call(&[0.5])[0]);
386//!     Ok(())
387//! }
388//! ```
389//!
390//! # C-Interface
391//!
392//! In addition to `Compiler`, this crate provides a C-style interface
393//! used by the Python (<https://github.com/siravan/symjit>) and Julia
394//! (<https://github.com/siravan/Symjit.jl>) packages. This interface
395//! is composed of crate functions like `compile`, `execute`, and
396//! `ptr_states`,..., and is not needed by the Rust interface but can be
397//! used to link symjit to other programming languages.
398//!
399
400use std::collections::HashSet;
401use std::ffi::{c_char, CStr, CString};
402use std::fmt::Debug;
403use std::str::FromStr;
404
405mod allocator;
406mod amd;
407mod applet;
408mod arm;
409mod assembler;
410mod block;
411mod builder;
412mod code;
413mod compactor;
414pub mod compiler;
415mod complexify;
416mod composer;
417mod config;
418mod defuns;
419pub mod expr;
420mod generator;
421pub mod instruction;
422mod machine;
423mod matrix;
424mod memory;
425mod mir;
426mod model;
427mod node;
428mod parser;
429mod runnable;
430mod statement;
431mod symbol;
432mod types;
433mod utils;
434
435#[allow(non_upper_case_globals)]
436mod riscv64;
437
438use matrix::Matrix;
439use model::{CellModel, Program};
440
441pub use applet::Applet;
442pub use compiler::{Compiler, FastFunc, Translator};
443pub use composer::{Composer, Transliterator};
444pub use config::Config;
445pub use defuns::Defuns;
446pub use expr::{double, int, var, Expr};
447pub use instruction::{BuiltinSymbol, Instruction, Slot, SymbolicaModel};
448pub use num_complex::{Complex, ComplexFloat};
449pub use runnable::{Application, CompilerType};
450pub use types::{ElemType, Element};
451pub use utils::{Compiled, Storage};
452
/// Outcome of a C-interface call that produces a [`CompilerResult`].
///
/// The status is stored next to a human-readable message, which C callers
/// retrieve with `check_status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompilerStatus {
    /// The operation succeeded and the result holds an `Application`.
    Ok,
    /// Initial value before an operation records a more specific outcome.
    Incomplete,
    /// An input C string could not be decoded as UTF-8.
    InvalidUtf8,
    /// The input model, JSON, or saved file could not be parsed.
    ParseError,
    /// The requested compiler type / configuration was rejected.
    InvalidCompiler,
    /// Building the program or application failed.
    CompilationError,
}
462
463pub struct CompilerResult {
464    app: Option<Application>,
465    status: CompilerStatus,
466    msg: CString,
467}
468
/// Builds a C-compatible message of the form `"<msg>: <err:?>"`.
///
/// `msg` is written verbatim (it used to be Debug-formatted, which wrapped it
/// in quotes), and `err` is rendered with its `Debug` implementation.
///
/// Interior NUL bytes are stripped so the conversion to `CString` cannot
/// fail. This helper is called from `extern "C"` entry points, where a panic
/// must not escape.
fn error_message<E: Debug>(msg: &str, err: E) -> CString {
    let s = format!("{}: {:?}", msg, err);
    CString::new(s).unwrap_or_else(|e| {
        let mut bytes = e.into_vec();
        bytes.retain(|&b| b != 0);
        CString::new(bytes).expect("interior NUL bytes were removed")
    })
}
473
474/// Compiles a model.
475///
476/// * `model` is a json string encoding the model.
477/// * `ty` is the requested arch (amd, arm, native, or bytecode).
478/// * `opt`: compilation options.
479/// * `df`: user-defined functions.
480///
481/// # Safety
482///     * both model and ty are pointers to null-terminated strings.
483///     * The output is a raw pointer to a CompilerResults.
484///
485#[no_mangle]
486pub unsafe extern "C" fn compile(
487    model: *const c_char,
488    ty: *const c_char,
489    opt: u32,
490    df: *const Defuns,
491) -> *const CompilerResult {
492    let mut res = CompilerResult {
493        app: None,
494        status: CompilerStatus::Incomplete,
495        msg: CString::from_str("Success").unwrap(),
496    };
497
498    let model = unsafe {
499        match CStr::from_ptr(model).to_str() {
500            Ok(model) => model,
501            Err(msg) => {
502                res.status = CompilerStatus::InvalidUtf8;
503                res.msg = error_message("Invalid encoding", msg);
504                return Box::into_raw(Box::new(res)) as *const _;
505            }
506        }
507    };
508
509    let ty = unsafe {
510        match CStr::from_ptr(ty).to_str() {
511            Ok(ty) => ty,
512            Err(msg) => {
513                res.status = CompilerStatus::InvalidUtf8;
514                res.msg = error_message("Invalid compiler type", msg);
515                return Box::into_raw(Box::new(res)) as *const _;
516            }
517        }
518    };
519
520    let ml = match CellModel::load(model) {
521        Ok(ml) => ml,
522        Err(msg) => {
523            res.status = CompilerStatus::ParseError;
524            res.msg = error_message("Cannot parse JSON", msg);
525            return Box::into_raw(Box::new(res)) as *const _;
526        }
527    };
528
529    if let Ok(mut config) = Config::from_name(ty, opt) {
530        let df: Defuns = unsafe {
531            if df.is_null() {
532                Defuns::new()
533            } else {
534                (&*df).clone()
535            }
536        };
537
538        config.set_defuns(df);
539
540        let prog = match Program::new(&ml, config) {
541            Ok(prog) => prog,
542            Err(msg) => {
543                res.status = CompilerStatus::CompilationError;
544                res.msg = error_message("Compilation error (prog)", msg);
545                return Box::into_raw(Box::new(res)) as *const _;
546            }
547        };
548
549        let app = Application::new(prog, HashSet::new());
550
551        match app {
552            Ok(app) => {
553                res.status = CompilerStatus::Ok;
554                res.app = Some(app);
555            }
556            Err(msg) => {
557                res.status = CompilerStatus::CompilationError;
558                res.msg = error_message("Compilation error (app)", &msg);
559            }
560        }
561    } else {
562        res.status = CompilerStatus::InvalidCompiler;
563        res.msg = error_message("Config error", opt);
564    }
565
566    Box::into_raw(Box::new(res)) as *const _
567}
568
569/// Translates a Symbolica model.
570///
571/// * `json` is a json string encoding the output of `export_instructions`.
572/// * `ty` is the requested arch (amd, arm, native, or bytecode).
573/// * `opt`: compilation options.
574/// * `df`: user-defined functions (currently ignored).
575///
576/// # Safety
577///     * both model and ty are pointers to null-terminated strings.
578///     * The output is a raw pointer to a CompilerResults.
579///
580#[no_mangle]
581pub unsafe extern "C" fn translate(
582    json: *const c_char,
583    ty: *const c_char,
584    opt: u32,
585    df: *mut Defuns,
586    num_params: usize,
587) -> *const CompilerResult {
588    let mut res = CompilerResult {
589        app: None,
590        status: CompilerStatus::Incomplete,
591        msg: CString::from_str("Success").unwrap(),
592    };
593
594    let json = unsafe {
595        match CStr::from_ptr(json).to_str() {
596            Ok(json) => json,
597            Err(msg) => {
598                res.status = CompilerStatus::InvalidUtf8;
599                res.msg = error_message("Invalid encoding", msg);
600                return Box::into_raw(Box::new(res)) as *const _;
601            }
602        }
603    };
604
605    let ty = unsafe {
606        match CStr::from_ptr(ty).to_str() {
607            Ok(ty) => ty,
608            Err(msg) => {
609                res.status = CompilerStatus::InvalidUtf8;
610                res.msg = error_message("Invalid compiler type", msg);
611                return Box::into_raw(Box::new(res)) as *const _;
612            }
613        }
614    };
615
616    if let Ok(config) = Config::from_name(ty, opt) {
617        let df: Defuns = unsafe {
618            if df.is_null() {
619                Defuns::new()
620            } else {
621                (&*df).clone()
622            }
623        };
624
625        let mut comp = Compiler::with_config(config);
626        let app = comp.translate(json.to_string(), df, num_params);
627
628        match app {
629            Ok(app) => {
630                res.app = Some(app);
631                res.status = CompilerStatus::Ok;
632            }
633            Err(msg) => {
634                res.status = CompilerStatus::InvalidCompiler;
635                res.msg = error_message("Compilation error", msg);
636            }
637        }
638    } else {
639        res.status = CompilerStatus::InvalidCompiler;
640        res.msg = error_message("Config error", opt);
641    }
642
643    Box::into_raw(Box::new(res)) as *const _
644}
645
646/// Checks the status of a `CompilerResult`.
647///
648/// Returns a null-terminated string representing the status message.
649///
650/// # Safety
651///     it is the responsibility of the calling function to ensure
652///     that q points to a valid CompilerResult.
653///
654#[no_mangle]
655pub unsafe extern "C" fn check_status(q: *const CompilerResult) -> *const c_char {
656    let q: &CompilerResult = unsafe { &*q };
657    q.msg.as_ptr() as *const _
658}
659
660/// Checks the status of a `CompilerResult`.
661///
662/// Returns a null-terminated string representing the status message.
663///
664/// # Safety
665///     it is the responsibility of the calling function to ensure
666///     that q points to a valid CompilerResult.
667///
668#[no_mangle]
669pub unsafe extern "C" fn save(q: *const CompilerResult, file: *const c_char) -> bool {
670    let q: &CompilerResult = unsafe { &*q };
671    let file = unsafe {
672        match CStr::from_ptr(file).to_str() {
673            Ok(file) => file,
674            Err(_) => return false,
675        }
676    };
677
678    if let Some(app) = &q.app {
679        if let Ok(mut fs) = std::fs::File::create(file) {
680            app.save(&mut fs).is_ok()
681        } else {
682            false
683        }
684    } else {
685        false
686    }
687}
688
689/// Checks the status of a `CompilerResult`.
690///
691/// Returns a null-terminated string representing the status message.
692///
693/// # Safety
694///     it is the responsibility of the calling function to ensure
695///     that q points to a valid CompilerResult.
696///
697#[no_mangle]
698pub unsafe extern "C" fn load(file: *const c_char) -> *const CompilerResult {
699    let mut res = CompilerResult {
700        app: None,
701        status: CompilerStatus::Incomplete,
702        msg: CString::from_str("Success").unwrap(),
703    };
704
705    let file = unsafe {
706        match CStr::from_ptr(file).to_str() {
707            Ok(file) => file,
708            Err(_) => return Box::into_raw(Box::new(res)) as *const _,
709        }
710    };
711
712    let fs = std::fs::File::open(file);
713
714    match fs {
715        Ok(mut fs) => match Application::load(&mut fs) {
716            Ok(app) => {
717                res.app = Some(app);
718                res.status = CompilerStatus::Ok;
719            }
720            Err(err) => {
721                res.status = CompilerStatus::ParseError;
722                res.msg = error_message("File parse error", &err);
723            }
724        },
725        Err(err) => {
726            res.msg = error_message("File I/O error", &err);
727        }
728    }
729
730    Box::into_raw(Box::new(res)) as *const _
731}
732
733/// Checks the status of a `CompilerResult`.
734///
735/// Returns a null-terminated string representing the status message.
736///
737/// # Safety
738///     it is the responsibility of the calling function to ensure
739///     that q points to a valid CompilerResult.
740///
741#[no_mangle]
742pub unsafe extern "C" fn get_config(q: *const CompilerResult) -> usize {
743    let q: &CompilerResult = unsafe { &*q };
744
745    match &q.app {
746        Some(app) => {
747            let config = app.prog.config();
748
749            let ty: usize = match config.ty {
750                CompilerType::Native => 0,
751                CompilerType::Amd => 1,
752                CompilerType::AmdAVX => 2,
753                CompilerType::AmdSSE => 3,
754                CompilerType::Arm => 4,
755                CompilerType::RiscV => 5,
756                CompilerType::ByteCode => 6,
757                CompilerType::Debug => 7,
758            };
759
760            (config.opt as usize) | (ty << 32)
761        }
762        None => 0,
763    }
764}
765
766/// Returns the number of state variables.
767///
768/// # Safety
769///     it is the responsibility of the calling function to ensure
770///     that q points to a valid CompilerResult.
771///
772#[no_mangle]
773pub unsafe extern "C" fn count_states(q: *const CompilerResult) -> usize {
774    let q: &CompilerResult = unsafe { &*q };
775    if let Some(app) = &q.app {
776        app.count_states
777    } else {
778        0
779    }
780}
781
782/// Returns the number of parameters.
783///
784/// # Safety
785///     it is the responsibility of the calling function to ensure
786///     that q points to a valid CompilerResult.
787///
788#[no_mangle]
789pub unsafe extern "C" fn count_params(q: *const CompilerResult) -> usize {
790    let q: &CompilerResult = unsafe { &*q };
791    if let Some(app) = &q.app {
792        app.count_params
793    } else {
794        0
795    }
796}
797
798/// Returns the number of observables (output).
799///
800/// # Safety
801///     it is the responsibility of the calling function to ensure
802///     that q points to a valid CompilerResult.
803///
804#[no_mangle]
805pub unsafe extern "C" fn count_obs(q: *const CompilerResult) -> usize {
806    let q: &CompilerResult = unsafe { &*q };
807    if let Some(app) = &q.app {
808        app.count_obs
809    } else {
810        0
811    }
812}
813
814/// Returns the number of differential equations.
815///
816/// Generally, it should be the same as the number of states.
817///
818/// # Safety
819///     it is the responsibility of the calling function to ensure
820///     that q points to a valid CompilerResult.
821///
822#[no_mangle]
823pub unsafe extern "C" fn count_diffs(q: *const CompilerResult) -> usize {
824    let q: &CompilerResult = unsafe { &*q };
825    if let Some(app) = &q.app {
826        app.count_diffs
827    } else {
828        0
829    }
830}
831
832/// Deprecated. Previously used for interfacing to DifferentialEquation.jl. It is
833/// replaced with <https://github.com/siravan/SymJit.jl>.
834///
835/// # Safety
836///
837/// Deprecated. No effects.
838#[no_mangle]
839pub unsafe extern "C" fn run(
840    _q: *mut CompilerResult,
841    _du: *mut f64,
842    _u: *const f64,
843    _ns: usize,
844    _p: *const f64,
845    _np: usize,
846    _t: f64,
847) -> bool {
848    // let q: &mut CompilerResult = unsafe { &mut *q };
849
850    // if let Some(app) = &mut q.app {
851    //     if app.count_states != ns || app.count_params != np {
852    //         return false;
853    //     }
854
855    //     let du: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(du, ns) };
856    //     let u: &[f64] = unsafe { std::slice::from_raw_parts(u, ns) };
857    //     let p: &[f64] = unsafe { std::slice::from_raw_parts(p, np) };
858    //     app.call(du, u, p, t);
859    //     true
860    // } else {
861    //     false
862    // }
863    false
864}
865
866/// Executes the compiled function.
867///
868/// The calling routine should fill the states and parameters before
869/// calling `execute`. The result populates obs or diffs (as defined in
870/// model passed to `compile`).
871///
872/// # Safety
873///     it is the responsibility of the calling function to ensure
874///     that q points to a valid CompilerResult.
875///
876#[no_mangle]
877pub unsafe extern "C" fn execute(q: *mut CompilerResult) -> bool {
878    let q: &mut CompilerResult = unsafe { &mut *q };
879
880    if let Some(app) = &mut q.app {
881        app.exec();
882        true
883    } else {
884        false
885    }
886}
887
888/// Executes the compiled function `n` times (vectorized).
889///
890/// The calling function provides `buf`, which is a k x n matrix of doubles,
891/// where k is equal to the `maximum(count_states, count_obs)`. The calling
892/// funciton fills the first `count_states` rows of buf. The result is returned
893/// in the first count_obs rows of buf.
894///
895/// # Safety
896///     it is the responsibility of the calling function to ensure
897///     that q points to a valid CompilerResult.
898///
899///     In addition, buf should points to a valid matrix of correct size.
900///
901#[no_mangle]
902pub unsafe extern "C" fn execute_vectorized(
903    q: *mut CompilerResult,
904    buf: *mut f64,
905    n: usize,
906) -> bool {
907    let q: &mut CompilerResult = unsafe { &mut *q };
908
909    if let Some(app) = &mut q.app {
910        let h = usize::max(app.count_states, app.count_obs);
911        let p: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(buf, h * n) };
912        let mut states = Matrix::from_buf(p, h, n);
913        let p: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(buf, h * n) };
914        let mut obs = Matrix::from_buf(p, h, n);
915        app.exec_vectorized(&mut states, &mut obs);
916        true
917    } else {
918        false
919    }
920}
921
922/// Evaluates the compiled function. This is for Symbolica compatibility.
923///
924/// # Safety
925///     it is the responsibility of the calling function to ensure
926///     that q points to a valid CompilerResult.
927///
928#[no_mangle]
929pub unsafe extern "C" fn evaluate(
930    q: *mut CompilerResult,
931    args: *const f64,
932    nargs: usize,
933    outs: *mut f64,
934    nouts: usize,
935) -> bool {
936    let q: &mut CompilerResult = unsafe { &mut *q };
937
938    if let Some(app) = &mut q.app {
939        let args: &[f64] = unsafe { std::slice::from_raw_parts(args, nargs) };
940        let outs: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(outs, nouts) };
941        app.evaluate(args, outs);
942        true
943    } else {
944        false
945    }
946}
947
948/// Evaluates the compiled function. This is for Symbolica compatibility.
949///
950/// # Safety
951///     it is the responsibility of the calling function to ensure
952///     that q points to a valid CompilerResult.
953///
954#[no_mangle]
955pub unsafe extern "C" fn evaluate_matrix(
956    q: *mut CompilerResult,
957    args: *const f64,
958    nargs: usize,
959    outs: *mut f64,
960    nouts: usize,
961) -> bool {
962    let q: &mut CompilerResult = unsafe { &mut *q };
963
964    if let Some(app) = &mut q.app {
965        if app.count_params == 0 {
966            return false;
967        }
968
969        let args: &[f64] = unsafe { std::slice::from_raw_parts(args, nargs) };
970        let outs: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(outs, nouts) };
971        let n = nargs / app.count_params;
972        app.evaluate_matrix(args, outs, n);
973        true
974    } else {
975        false
976    }
977}
978
979/// Returns a pointer to the state variables (`count_states` doubles).
980///
981/// The function calling `execute` should write the state variables in this area.
982///
983/// # Safety
984///     it is the responsibility of the calling function to ensure
985///     that q points to a valid CompilerResult.
986///
987#[no_mangle]
988pub unsafe extern "C" fn ptr_states(q: *mut CompilerResult) -> *mut f64 {
989    let q: &mut CompilerResult = unsafe { &mut *q };
990    if let Some(app) = &mut q.app {
991        if let Some(f) = &mut app.compiled {
992            &mut f.mem_mut()[app.first_state] as *mut f64
993        } else {
994            &mut app.bytecode.mem_mut()[app.first_state] as *mut f64
995        }
996    } else {
997        std::ptr::null_mut()
998    }
999}
1000
1001/// Returns a pointer to the parameters (`count_params` doubles).
1002///
1003/// The function calling `execute` should write the parameters in this area.
1004///
1005/// # Safety
1006///     it is the responsibility of the calling function to ensure
1007///     that q points to a valid CompilerResult.
1008///
1009#[no_mangle]
1010pub unsafe extern "C" fn ptr_params(q: *mut CompilerResult) -> *mut f64 {
1011    let q: &mut CompilerResult = unsafe { &mut *q };
1012    if let Some(app) = &mut q.app {
1013        //&mut app.compiled.mem_mut()[app.first_param] as *mut f64
1014        &mut app.params[app.first_param] as *mut f64
1015    } else {
1016        std::ptr::null_mut()
1017    }
1018}
1019
1020/// Returns a pointer to the observables (`count_obs` doubles).
1021///
1022/// The function calling `execute` reads the observables from this area.
1023///
1024/// # Safety
1025///     it is the responsibility of the calling function to ensure
1026///     that q points to a valid CompilerResult.
1027///
1028#[no_mangle]
1029pub unsafe extern "C" fn ptr_obs(q: *mut CompilerResult) -> *const f64 {
1030    let q: &CompilerResult = unsafe { &*q };
1031    if let Some(app) = &q.app {
1032        if let Some(f) = &app.compiled {
1033            &f.mem()[app.first_obs] as *const f64
1034        } else {
1035            &app.bytecode.mem()[app.first_obs] as *const f64
1036        }
1037    } else {
1038        std::ptr::null()
1039    }
1040}
1041
1042/// Returns a pointer to the differentials (`count_diffs` doubles).
1043///
1044/// The function calling `execute` reads the differentials from this area.
1045///
1046/// Note: whether the output is returned as observables or differentials is
1047/// defined in the model.
1048///
1049/// # Safety
1050///     it is the responsibility of the calling function to ensure
1051///     that q points to a valid CompilerResult.
1052///
1053#[no_mangle]
1054pub unsafe extern "C" fn ptr_diffs(q: *mut CompilerResult) -> *const f64 {
1055    let q: &CompilerResult = unsafe { &*q };
1056    if let Some(app) = &q.app {
1057        if let Some(f) = &app.compiled {
1058            &f.mem()[app.first_diff] as *const f64
1059        } else {
1060            &app.bytecode.mem()[app.first_diff] as *const f64
1061        }
1062    } else {
1063        std::ptr::null()
1064    }
1065}
1066
1067/// Dumps the compiled binary code to a file (`name`).
1068///
1069/// This function is useful for debugging but is not necessary for
1070/// normal operations.
1071///
1072/// # Safety
1073///     it is the responsibility of the calling function to ensure
1074///     that q points to a valid CompilerResult.
1075///
1076#[no_mangle]
1077pub unsafe extern "C" fn dump(
1078    q: *mut CompilerResult,
1079    name: *const c_char,
1080    what: *const c_char,
1081) -> bool {
1082    let q: &mut CompilerResult = unsafe { &mut *q };
1083    if let Some(app) = &mut q.app {
1084        let name = unsafe { CStr::from_ptr(name).to_str().unwrap() };
1085        let what = unsafe { CStr::from_ptr(what).to_str().unwrap() };
1086        app.dump(name, what)
1087    } else {
1088        false
1089    }
1090}
1091
1092/// Deallocates the CompilerResult pointed by `q`.
1093///
1094/// # Safety
1095///     it is the responsibility of the calling function to ensure
1096///     that q points to a valid CompilerResult and that after
1097///     calling this function, q is invalid and should not
1098///     be used anymore.
1099///
1100#[no_mangle]
1101pub unsafe extern "C" fn finalize(q: *mut CompilerResult) {
1102    if !q.is_null() {
1103        let _ = unsafe { Box::from_raw(q) };
1104    }
1105}
1106
1107/// Returns a null-terminated string representing the version.
1108///
1109/// Used for debugging.
1110///
1111/// # Safety
1112///     the return value is a null-terminated string that should not
1113///     be freed.
1114///
1115#[no_mangle]
1116pub unsafe extern "C" fn info() -> *const c_char {
1117    // let msg = c"symjit 1.3.3";
1118    let msg = CString::new(env!("CARGO_PKG_VERSION")).unwrap();
1119    msg.into_raw() as *const _
1120}
1121
1122/// Returns a pointer to the fast function if one can be compiled.
1123///
1124/// # Safety
1125///     1. If the model cannot be compiled to a fast function, NULL is returned.
1126///     2. A fast function code memory is leaked and is not deallocated.
1127///
1128#[no_mangle]
1129pub unsafe extern "C" fn fast_func(q: *mut CompilerResult) -> *const usize {
1130    let q: &mut CompilerResult = unsafe { &mut *q };
1131    if let Some(app) = &mut q.app {
1132        match app.get_fast() {
1133            Some(f) => f as *const usize,
1134            None => std::ptr::null(),
1135        }
1136    } else {
1137        std::ptr::null()
1138    }
1139}
1140
1141/// Interface for Sympy's LowLevelCallable.
1142///
1143/// # Safety
1144///     1. If the model cannot be compiled to a fast function, NULL is returned.
1145///     2. The resulting function lives as long as q does and should not be stored
1146///         separately.
1147///
1148#[no_mangle]
1149pub unsafe extern "C" fn callable_quad(n: usize, xx: *const f64, q: *mut CompilerResult) -> f64 {
1150    let q: &mut CompilerResult = unsafe { &mut *q };
1151    let xx: &[f64] = unsafe { std::slice::from_raw_parts(xx, n) };
1152
1153    if let Some(app) = &mut q.app {
1154        app.exec_callable(xx)
1155    } else {
1156        f64::NAN
1157    }
1158}
1159
/// Interface for SciPy's `LowLevelCallable` that calls a fast function directly.
///
/// `f` is a fast-function pointer, as returned by `fast_func`. It is called with the
/// `n` doubles at `xx` passed as individual arguments, and its result is returned.
/// Up to seven arguments are supported. NaN is returned for a larger `n` or a NULL
/// `f` (as `callable_quad` does when there is no application) instead of panicking,
/// because a panic cannot unwind out of an `extern "C"` function.
///
/// NOTE(review): SciPy documents the quad callback count as C `int`; here `n` is a
/// `usize`. Confirm how callers pass it.
///
/// # Safety
///
/// 1. `f` must be NULL or point to live code using the C calling convention that
///    takes exactly `n` doubles and returns a double.
/// 2. `xx` must point to at least `n` readable doubles; it may be NULL when `n == 0`.
///
#[no_mangle]
pub unsafe extern "C" fn callable_quad_fast(n: usize, xx: *const f64, f: *const usize) -> f64 {
    if f.is_null() {
        return f64::NAN;
    }

    // `slice::from_raw_parts` requires a non-null pointer even for an empty slice.
    let xx: &[f64] = if n == 0 {
        &[]
    } else {
        // SAFETY: `xx` points to `n` readable doubles per item 2 of the contract.
        unsafe { std::slice::from_raw_parts(xx, n) }
    };

    // Fast functions are handed out to foreign callers through `fast_func`, so they
    // are called through `extern "C"` function types. Calling them through a plain
    // Rust `fn` type, whose ABI is unspecified, would be undefined behavior.
    // SAFETY (every arm): `f` has the signature required by item 1 of the contract.
    match n {
        0 => {
            let f: extern "C" fn() -> f64 = unsafe { std::mem::transmute(f) };
            f()
        }
        1 => {
            let f: extern "C" fn(f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0])
        }
        2 => {
            let f: extern "C" fn(f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1])
        }
        3 => {
            let f: extern "C" fn(f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2])
        }
        4 => {
            let f: extern "C" fn(f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3])
        }
        5 => {
            let f: extern "C" fn(f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3], xx[4])
        }
        6 => {
            let f: extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3], xx[4], xx[5])
        }
        7 => {
            let f: extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(xx[0], xx[1], xx[2], xx[3], xx[4], xx[5], xx[6])
        }
        // Too many parameters for a fast function.
        _ => f64::NAN,
    }
}
1209
1210/// Interface for Sympy's LowLevelCallable (image filtering).
1211///
1212/// # Safety
1213///     1. If the model cannot be compiled to a fast function, NULL is returned.
1214///     2. The resulting function lives as long as q does and should not be stored
1215///         separately.
1216///
1217#[no_mangle]
1218pub unsafe extern "C" fn callable_filter(
1219    buffer: *const f64,
1220    filter_size: usize,
1221    return_value: *mut f64,
1222    q: *mut CompilerResult,
1223) -> i64 {
1224    let q: &mut CompilerResult = unsafe { &mut *q };
1225    let xx: &[f64] = unsafe { std::slice::from_raw_parts(buffer, filter_size) };
1226
1227    if let Some(app) = &mut q.app {
1228        let p: &mut f64 = unsafe { &mut *return_value };
1229        *p = app.exec_callable(xx);
1230        1
1231    } else {
1232        0
1233    }
1234}
1235
1236/************************************************/
1237
1238/// Creates an empty Matrix (a 2d array).
1239///
1240/// # Safety
1241///     It returns a pointer to the allocated Matrix, which needs to be
1242///     deallocated eventually.
1243///
1244#[no_mangle]
1245pub unsafe extern "C" fn create_matrix<'a>() -> *const Matrix<'a> {
1246    let mat = Matrix::new();
1247    Box::into_raw(Box::new(mat)) as *const Matrix
1248}
1249
1250/// Finalizes (deallocates) the Matrix.
1251///
1252/// # Safety
1253///     1, mat should point to a valid Matrix object created by create_matrix.
1254///     2. After finalize_matrix is called, mat is invalid.
1255///
1256#[no_mangle]
1257pub unsafe extern "C" fn finalize_matrix(mat: *mut Matrix) {
1258    if !mat.is_null() {
1259        let _ = unsafe { Box::from_raw(mat) };
1260    }
1261}
1262
1263/// Adds a row to the Matrix.
1264///
1265/// # Safety
1266///     1, mat should point to a valid Matrix object created by create_matrix.
1267///     2. v should point to a valid array of doubles of length at least n.
1268///     3. v should remains valid for the lifespan of mat.
1269///
1270#[no_mangle]
1271pub unsafe extern "C" fn add_row(mat: *mut Matrix, v: *mut f64, n: usize) {
1272    let mat: &mut Matrix = unsafe { &mut *mat };
1273    mat.add_row(v, n);
1274}
1275
1276/// Executes (runs) the matrix model encoded by `q`.
1277///
1278/// # Safety
1279///     1, q should point to a valid CompilerResult object.
1280///     2. states should point to a valid Matrix of at least count_states rows.
1281///     3. obs should point to a valid Matrix of at least count_obs rows.
1282///
1283#[no_mangle]
1284pub unsafe extern "C" fn execute_matrix(
1285    q: *mut CompilerResult,
1286    states: *mut Matrix,
1287    obs: *mut Matrix,
1288) -> bool {
1289    let q: &mut CompilerResult = unsafe { &mut *q };
1290    let states: &mut Matrix = unsafe { &mut *states };
1291    let obs: &mut Matrix = unsafe { &mut *obs };
1292
1293    if let Some(app) = &mut q.app {
1294        app.exec_vectorized(states, obs);
1295        true
1296    } else {
1297        false
1298    }
1299}
1300
1301/************************************************/
1302
1303/// Creates an empty `Defun` (a list of user-defined functions).
1304///
1305/// `Defuns` are used to pass user-defined functions (either Python
1306/// functions or symjit-compiled functions).
1307///
1308/// # Safety
1309///     It returns a pointer to the allocated Defun, which needs to be
1310///     deallocated eventually.
1311///
1312#[no_mangle]
1313pub unsafe extern "C" fn create_defuns() -> *const Defuns {
1314    let df = Defuns::new();
1315    Box::into_raw(Box::new(df)) as *const Defuns
1316}
1317
/// Finalizes a `Defuns` object created by `create_defuns`.
///
/// NOTE(review): deallocation is currently disabled (the body below is commented
/// out), so this function is a no-op and every `Defuns` returned by `create_defuns`
/// is leaked. Presumably this was done because something may still reference the
/// `Defuns` after this call. Confirm that before re-enabling the `Box::from_raw`.
///
/// # Safety
///
/// 1. `df` should point to a valid `Defuns` object created by `create_defuns`.
/// 2. After `finalize_defuns` is called, `df` should be treated as invalid.
///
#[no_mangle]
pub unsafe extern "C" fn finalize_defuns(df: *mut Defuns) {
    // Intentionally disabled, see the NOTE above; `df` is therefore unused.
    // if !df.is_null() {
    //     let _ = unsafe { Box::from_raw(df) };
    // }
}
1330
1331/// Adds a new function to a `Defun`.
1332///
1333/// # Safety
1334///     1, df should point to a valid Defun object created by create_defun.
1335///     2. name should be a valid utf8 string.
1336///     3. p should point to a valid C-styple function pointer that accepts
1337///         num_args double arguments.
1338///
1339#[no_mangle]
1340pub unsafe extern "C" fn add_func(
1341    df: *mut Defuns,
1342    name: *const c_char,
1343    p: *const usize,
1344    num_args: usize,
1345) {
1346    let df: &mut Defuns = unsafe { &mut *df };
1347    let name = unsafe { CStr::from_ptr(name).to_str().unwrap() };
1348    df.add_func(name, p, num_args);
1349}