symjit/lib.rs
1#![allow(uncommon_codepoints)]
2
3//! Symjit (<https://github.com/siravan/symjit>) is a lightweight just-in-time (JIT)
//! optimizing compiler for mathematical expressions, written in Rust. It was originally
5//! designed to compile SymPy (Python’s symbolic algebra package) expressions
6//! into machine code and to serve as a bridge between SymPy and numerical routines
7//! provided by NumPy and SciPy libraries.
8//!
//! The symjit crate is the core compiler coupled with a Rust interface that exposes the
//! JIT functionality to the Rust ecosystem and allows Rust applications to
11//! generate code dynamically. Considering its origin, symjit is geared toward
12//! compiling mathematical expressions instead of being a general-purpose JIT
//! compiler. Therefore, the only supported types for variables are `f64`
//! (and its SIMD forms f64x4 and f64x2), `Complex<f64>`, and, implicitly, `bool` and `i32`.
15//!
16//! Symjit emits AMD64 (x86-64), ARM64 (aarch64), and 64-bit RISC-V (riscv64) machine
//! code on Linux, Windows, and macOS platforms. SIMD is supported on x86-64
18//! and ARM64.
19//!
//! In Rust, there are two ways to construct expressions to pass to Symjit: using
21//! Symbolica or using Symjit standalone expression builder.
22//!
23//! # Symbolica
24//!
25//! Symbolica (<https://symbolica.io/>) is a fast Rust-based Computer Algebra System.
//! Symbolica usually generates fast code using external compilers (e.g., using gcc to
//! compile generated C++ code). Symjit accepts Symbolica expressions and can act as
28//! an optional code-generator for Symbolica. The link between the two is through
29//! Symbolica's `export_instructions` function that exports an optimized intermediate
30//! representation. Using serde, it is possible to convert the output of `export_instructions`
31//! into JSON, which is then passed to the `translate` function of Symjit `Compiler`
32//! structure. If successful, `translate` returns an `Application` object, which wraps
33//! the compiled code and can be run using one of the six `evaluate` functions:
34//!
35//! * `evaluate(&mut self, args: &[T], outs: &mut [T])`.
36//! * `evaluate_single(&mut self, args: &[T]) -> T`.
37//! * `evaluate_matrix(&mut self, args: &[T], outs: &mut [T], nrows: usize)`.
38//! * `evaluate_simd(&mut self, args: &[S], outs: &mut [S])`.
39//! * `evaluate_simd_single(&mut self, args: &[S]) -> S`.
40//! * `evaluate_simd_matrix(&mut self, args: &[S], outs: &mut [S], nrows: usize)`.
41//!
//! where `T` is either `f64` or `Complex<f64>` and `S` is `f64x4` on x86-64 or `f64x2`
//! on aarch64, or their complex counterparts.
44//!
//! For example:
46//!
47//! ```rust
48//! use anyhow::Result;
49//! use symjit::{Compiler, Config};
50//! use symbolica::{atom::AtomCore, parse, symbol};
51//! use symbolica::evaluate::{FunctionMap, OptimizationSettings};
52//!
53//! fn test1() -> Result<()> {
54//! let params = vec![parse!("x"), parse!("y")];
55//! let eval = parse!("x + y^2")
56//! .evaluator(
57//! &FunctionMap::new(),
//! &params,
59//! OptimizationSettings::default(),
60//! )
61//! .unwrap();
62//!
63//! let json = serde_json::to_string(&eval.export_instructions())?;
64//! let mut comp = Compiler::new();
65//! let mut app = comp.translate(&json)?;
66//! assert!(app.evaluate_single(&[2.0, 3.0]) == 11.0);
67//! Ok(())
68//! }
69//! ```
70//!
//! Note that Symbolica must be added as a dependency with `features = ["serde"]` to allow for
72//! applying `serde_json::to_string` to the output of `export_instructions`.
73//!
74//! To change compilation options, one passes a `Config` struct to the `Compiler`
//! constructor. The following example shows how to compile for complex numbers.
76//!
77//! ```rust
78//! use anyhow::Result;
79//! use num_complex::Complex;
80//! use symjit::{Compiler, Config};
81//! use symbolica::{atom::AtomCore, parse, symbol};
82//! use symbolica::evaluate::{FunctionMap, OptimizationSettings};
83//!
84//! fn test2() -> Result<()> {
85//! let params = vec![parse!("x"), parse!("y")];
86//! let eval = parse!("x + y^2")
87//! .evaluator(
88//! &FunctionMap::new(),
89//! ¶ms,
90//! OptimizationSettings::default(),
91//! )
92//! .unwrap();
93//!
94//! let json = serde_json::to_string(&eval.export_instructions())?;
95//! let mut config = Config::default();
96//! config.set_complex(true);
97//! let mut comp = Compiler::with_config(config);
98//! let mut app = comp.translate(&json)?;
99//! let v = vec![Complex::new(2.0, 1.0), Complex::new(-1.0, 3.0)];
100//! assert!(app.evaluate_single(&v) == Complex::new(-6.0, -5.0));
101//! Ok(())
102//! }
103//! ```
104//!
105//! Currently, Symjit supports most of Symbolica's expressions with the exception of
106//! external user-defined functions. However, it is possible to link to Symjit
107//! numerical functions (see below) by defining their name using `add_external_function`.
108//! The following example shows how to link to `sinh` function:
109//!
110//!
111//! ```rust
112//! use anyhow::Result;
113//! use symjit::{Compiler, Config};
114//! use symbolica::{atom::AtomCore, parse, symbol};
115//! use symbolica::evaluate::{FunctionMap, OptimizationSettings};
116//!
117//! fn test3() -> Result<()> {
118//! let params = vec![parse!("x")];
119//!
120//! let mut f = FunctionMap::new();
121//! f.add_external_function(symbol!("sinh"), "sinh".to_string())
122//! .unwrap();
123//!
124//! let eval = parse!("sinh(x)")
//! .evaluator(&f, &params, OptimizationSettings::default())
126//! .unwrap();
127//!
128//! let json = serde_json::to_string(&eval.export_instructions())?;
129//! let mut comp = Compiler::new();
130//! let mut app = comp.translate(&json)?;
131//! assert!(app.evaluate_single(&[1.5]) == f64::sinh(1.5));
132//! Ok(())
133//! }
134//! ```
135//!
136//! # Standalone Expression Builder
137//!
138//! A second way to use Symjit is by using its standalone expression builder. Compared to
//! Symbolica, the expression builder is limited but is useful in situations where the goal
140//! is to compile an expression without extensive symbolic manipulations.
141//!
142//! The workflow to create, compile, and run expressions is:
143//!
144//! 1. Create terminals (variables and constants) and compose expressions using `Expr` methods:
145//! * Constructors: `var`, `from`, `unary`, `binary`, ...
146//! * Standard algebraic operations: `add`, `mul`, ...
147//! * Standard operators `+`, `-`, `*`, `/`, `%`, `&`, `|`, `^`, `!`.
148//! * Unary functions such as `sin`, `exp`, and other standard mathematical functions.
149//! * Binary functions such as `pow`, `min`, ...
150//! * IfElse operation `ifelse(cond, true_val, false_val)`.
151//! * Heaviside function: `heaviside(x)`, which returns 1 if `x >= 0`; otherwise 0.
152//! * Comparison methods `eq`, `ne`, `lt`, `le`, `gt`, and `ge`.
153//! * Looping constructs `sum` and `prod`.
154//! 2. Create a new `Compiler` object (say, `comp`) using one of its constructors.
155//! 3. Define user-defined functions by calling `comp.def_unary` and `comp.def_binary`
156//! (optional).
157//! 4. Compile by calling `comp.compile` or `comp.compile_params`. The result is of
158//! type `Application` (say, `app`).
159//! 5. Execute the compiled code using one of the `app`'s `call` functions:
160//! * `call(&[f64])`: scalar call.
161//! * `call_params(&[f64], &[f64])`: scalar call with parameters.
162//! * `call_simd(&[__m256d])`: simd call.
163//! * `call_simd_params(&[__m256d], &[f64])`: simd call with parameters.
164//! 6. Optionally, generate a standalone fast function to execute.
165//!
166//! Note that you can use the helper functions `var(&str) -> Expr`, `int(i32) -> Expr`,
167//! `double(f64) -> Expr`, and `boolean(bool) -> f64` to reduce clutter.
168//!
169//! # Examples
170//!
171//! ```rust
172//! use anyhow::Result;
173//! use symjit::{Compiler, Expr};
174//!
175//! pub fn test_scalar() -> Result<()> {
176//! let x = Expr::var("x");
177//! let y = Expr::var("y");
178//! let u = &x + &y;
179//! let v = &x * &y;
180//!
181//! let mut comp = Compiler::new();
182//! let mut app = comp.compile(&[x, y], &[u, v])?;
183//! let res = app.call(&[3.0, 5.0]);
184//! println!("{:?}", &res); // prints [8.0, 15.0]
185//!
186//! Ok(())
187//! }
188//! ```
189//!
190//! `test_scalar` is similar to the following basic example in Python/SymPy:
191//!
192//! ```python
193//! from symjit import compile_func
194//! from sympy import symbols
195//!
196//! x, y = symbols('x y')
197//! f = compile_func([x, y], [x+y, x*y])
198//! print(f(3.0, 5.0)) # prints [8.0, 15.0]
199//! ```
200//!
201//! A more elaborate example, showcasing having a parameter, changing the
202//! optimization level, and using SIMD:
203//!
204//! ```rust
205//! use anyhow::Result;
206//! use symjit::{var, Compiler, Expr};
207//!
208//! pub fn test_simd() -> Result<()> {
209//! use std::arch::x86_64::_mm256_loadu_pd;
210//!
211//! let x = var("x"); // note var instead of Expr::var
212//! let p = var("p"); // the parameter
213//!
214//! let u = &x.square() * &p; // x^2 * p
215//! let mut comp = Compiler::new();
216//! comp.opt_level(2); // optional (opt_level 0 to 2; default 1)
217//! let mut app = comp.compile_params(&[x], &[u], &[p])?;
218//!
219//! let a = &[1.0, 2.0, 3.0, 4.0];
220//! let a = unsafe { _mm256_loadu_pd(a.as_ptr()) };
221//! let res = app.call_simd_params(&[a], &[5.0])?;
222//! println!("{:?}", &res); // prints [__m256d(5.0, 20.0, 45.0, 80.0)]
223//! Ok(())
224//! }
225//! ```
226//!
227//! # Conditional Expression and Loops
228//!
229//! Many mathematical formulas need conditional expressions (`ifelse`) and loops.
230//! Following SymPy, Symjit uses reduction loops such as `sum` and `prod`. The following
//! example computes the exponential function:
232//!
233//! ```rust
//! use anyhow::Result;
//! use symjit::{int, var, Compiler};
235//!
236//! fn test_exp() -> Result<()> {
237//! let x = var("x");
238//! let i = var("i"); // loop variable
239//! let j = var("j"); // loop variable
240//!
//! // u = sum of x^j / factorial(j) for j in 0..=50
242//! let u = x
243//! .pow(&j)
244//! .div(&i.prod(&i, &int(1), &j))
245//! .sum(&j, &int(0), &int(50));
246//!
247//! let mut app = Compiler::new().compile(&[x], &[u])?;
//! println!("{:?}", app.call(&[2.0])[0]); // prints exp(2.0) = 7.38905...
249//! Ok(())
250//! }
251//! ```
252//!
253//! An example showing how to calculate pi using the Leibniz formula:
254//!
255//! ```rust
//! use anyhow::Result;
//! use symjit::{int, var, Compiler};
257//!
258//! fn test_pi() -> Result<()> {
259//! let n = var("n");
260//! let i = var("i"); // loop variable
261//! let j = var("j"); // loop variable
262//!
263//! // numer = if j % 2 == 0 { 4 } else { -4 }
264//! let numer = j.rem(&int(2)).eq(&int(0)).ifelse(&int(4), &int(-4));
265//! // denom = j * 2 + 1
266//! let denom = j.mul(&int(2)).add(&int(1));
267//! // v = numer / denom for j in 0..=n
//! let v = (&numer / &denom).sum(&j, &int(0), &n);
269//!
//! let mut app = Compiler::new().compile(&[n], &[v])?;
//! println!("{:?}", app.call(&[100000000.0])[0]); // prints an approximation of pi
272//! Ok(())
273//! }
274//! ```
275//!
276//! Note that here we are using explicit functions (`add`, `mul`, ...) instead of
277//! the overloaded operators for clarity.
278//!
279//! # Fast Functions
280//!
281//! `Application`'s call functions need to copy the input slice into the function
282//! memory area and then copy the output to a `Vec`. This process is acceptable
283//! for large and complex functions but incurs a penalty for small ones.
284//! Therefore, for a certain subset of applications, Symjit can compile to a
285//! *fast function* and return a function pointer. Examples:
286//!
287//! ```rust
288//! use anyhow::Result;
289//! use symjit::{int, var, Compiler, FastFunc};
290//!
291//! fn test_fast() -> Result<()> {
292//! let x = var("x");
293//! let y = var("y");
294//! let z = var("z");
295//! let u = &x * &(&y - &z).pow(&int(2)); // x * (y - z)^2
296//!
297//! let mut comp = Compiler::new();
298//! let mut app = comp.compile(&[x, y, z], &[u])?;
299//! let f = app.fast_func()?;
300//!
301//! if let FastFunc::F3(f, _) = f {
302//! // f is of type extern "C" fn(f64, f64, f64) -> f64
303//! let res = f(3.0, 5.0, 9.0);
304//! println!("fast\t{:?}", &res);
305//! }
306//!
307//! Ok(())
308//! }
309//! ```
310//!
311//! The conditions for a fast function are:
312//!
313//! * A fast function can have 1 to 8 arguments.
314//! * No SIMD and no parameters.
315//! * It returns only a single value.
316//!
317//! If these conditions are met, you can generate a fast function by calling
318//! `app.fast_func()`, which returns a `Result<FastFunc>`. `FastFunc` is an
319//! enum with eight variants `F1`, `F2`, ..., `F8`, corresponding to functions
320//! with 1 to 8 arguments.
321//!
322//! # User-Defined Functions
323//!
324//! Symjit functions can call into user-defined Rust functions. Currently,
325//! only the following function signatures are accepted:
326//!
327//! ```rust
328//! pub type UnaryFunc = extern "C" fn(f64) -> f64;
329//! pub type BinaryFunc = extern "C" fn(f64, f64) -> f64;
330//! ```
331//!
332//! For example:
333//!
334//! ```rust
335//! extern "C" fn f(x: f64) -> f64 {
336//! x.exp()
337//! }
338//!
339//! extern "C" fn g(x: f64, y: f64) -> f64 {
340//! x.ln() * y
341//! }
342//!
343//! fn test_external() -> Result<()> {
344//! let x = Expr::var("x");
345//! let u = Expr::unary("f_", &x);
346//! let v = &x * &Expr::binary("g_", &u, &x);
347//!
348//! // v(x) = x * (ln(exp(x)) * x) = x ^ 3
349//!
350//! let mut comp = Compiler::new();
351//! comp.def_unary("f_", f);
352//! comp.def_binary("g_", g);
353//! let mut app = comp.compile(&[x], &[v])?;
354//! println!("{:?}", app.call(&[5.0])[0]);
355//!
356//! Ok(())
357//! }
358//! ```
359//!
360//! # Dynamic Expressions
361//!
362//! All the examples up to this point use static expressions. Of course, it
363//! would have been easier just to use Rust expressions for these examples!
364//! The main utility of Symjit for Rust is for dynamic code generation. Here,
365//! we provide a simple example to calculate pi using Viete's formula
366//! (<https://en.wikipedia.org/wiki/Vi%C3%A8te%27s_formula>):
367//!
368//! ```rust
369//! fn test_pi_viete(silent: bool) -> Result<()> {
370//! let x = var("x");
371//! let mut u = int(1);
372//!
373//! for i in 0..50 {
374//! let mut t = x.clone();
375//!
376//! for _ in 0..i {
377//! t = &x + &(&x * &t.sqrt());
378//! }
379//!
380//! u = &u * &t.sqrt();
381//! }
382//!
383//! // u has 1275 = 50 * 51 / 2 sqrt operations
384//! let mut app = Compiler::new().compile(&[x], &[&int(2) / &u])?;
385//! println!("pi = \t{:?}", app.call(&[0.5])[0]);
386//! Ok(())
387//! }
388//! ```
389//!
390//! # C-Interface
391//!
392//! In addition to `Compiler`, this crate provides a C-style interface
393//! used by the Python (<https://github.com/siravan/symjit>) and Julia
394//! (<https://github.com/siravan/Symjit.jl>) packages. This interface
395//! is composed of crate functions like `compile`, `execute`, and
396//! `ptr_states`,..., and is not needed by the Rust interface but can be
397//! used to link symjit to other programming languages.
398//!
399
400use std::collections::HashSet;
401use std::ffi::{c_char, CStr, CString};
402use std::fmt::Debug;
403use std::str::FromStr;
404
405mod allocator;
406mod amd;
407mod applet;
408mod arm;
409mod assembler;
410mod block;
411mod builder;
412mod code;
413mod compactor;
414pub mod compiler;
415mod complexify;
416mod composer;
417mod config;
418mod defuns;
419pub mod expr;
420mod generator;
421pub mod instruction;
422mod machine;
423mod matrix;
424mod memory;
425mod mir;
426mod model;
427mod node;
428mod parser;
429mod runnable;
430mod statement;
431mod symbol;
432mod types;
433mod utils;
434
435#[allow(non_upper_case_globals)]
436mod riscv64;
437
438use matrix::Matrix;
439use model::{CellModel, Program};
440
441pub use applet::Applet;
442pub use compiler::{Compiler, FastFunc, Translator};
443pub use composer::{Composer, Transliterator};
444pub use config::Config;
445pub use defuns::Defuns;
446pub use expr::{double, int, var, Expr};
447pub use instruction::{BuiltinSymbol, Instruction, Slot, SymbolicaModel};
448pub use num_complex::{Complex, ComplexFloat};
449pub use runnable::{Application, CompilerType};
450pub use types::{ElemType, Element};
451pub use utils::{Compiled, Storage};
452
/// Outcome of a C-interface request (`compile`, `translate`, `load`), stored
/// in a `CompilerResult`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompilerStatus {
    /// The request succeeded and the result holds an application.
    Ok,
    /// Initial value; a request that fails without a more specific status
    /// (e.g., a file that cannot be opened in `load`) leaves it in place.
    Incomplete,
    /// An input C string was not valid UTF-8.
    InvalidUtf8,
    /// The model, JSON, or saved file could not be parsed.
    ParseError,
    /// The requested compiler type/options were rejected by `Config::from_name`.
    InvalidCompiler,
    /// The input was accepted but code generation failed.
    CompilationError,
}
461}
462
463pub struct CompilerResult {
464 app: Option<Application>,
465 status: CompilerStatus,
466 msg: CString,
467}
468
/// Builds the NUL-terminated status message stored in a `CompilerResult`.
///
/// The message has the form `"<msg>": <err>`, with both parts rendered via
/// `Debug` (so `msg` appears quoted).
///
/// The result is handed to C callers, so this function must never panic:
/// interior NUL bytes in the formatted text (possible with custom `Debug`
/// impls) would make `CString` construction fail, so they are dropped.
fn error_message<E: Debug>(msg: &str, err: E) -> CString {
    let s = format!("{:?}: {:?}", msg, err);
    CString::new(s).unwrap_or_else(|nul_err| {
        let mut bytes = nul_err.into_vec();
        bytes.retain(|&b| b != 0);
        CString::new(bytes).expect("all interior NUL bytes were removed")
    })
}
473
474/// Compiles a model.
475///
476/// * `model` is a json string encoding the model.
477/// * `ty` is the requested arch (amd, arm, native, or bytecode).
478/// * `opt`: compilation options.
479/// * `df`: user-defined functions.
480///
481/// # Safety
482/// * both model and ty are pointers to null-terminated strings.
483/// * The output is a raw pointer to a CompilerResults.
484///
485#[no_mangle]
486pub unsafe extern "C" fn compile(
487 model: *const c_char,
488 ty: *const c_char,
489 opt: u32,
490 df: *const Defuns,
491) -> *const CompilerResult {
492 let mut res = CompilerResult {
493 app: None,
494 status: CompilerStatus::Incomplete,
495 msg: CString::from_str("Success").unwrap(),
496 };
497
498 let model = unsafe {
499 match CStr::from_ptr(model).to_str() {
500 Ok(model) => model,
501 Err(msg) => {
502 res.status = CompilerStatus::InvalidUtf8;
503 res.msg = error_message("Invalid encoding", msg);
504 return Box::into_raw(Box::new(res)) as *const _;
505 }
506 }
507 };
508
509 let ty = unsafe {
510 match CStr::from_ptr(ty).to_str() {
511 Ok(ty) => ty,
512 Err(msg) => {
513 res.status = CompilerStatus::InvalidUtf8;
514 res.msg = error_message("Invalid compiler type", msg);
515 return Box::into_raw(Box::new(res)) as *const _;
516 }
517 }
518 };
519
520 let ml = match CellModel::load(model) {
521 Ok(ml) => ml,
522 Err(msg) => {
523 res.status = CompilerStatus::ParseError;
524 res.msg = error_message("Cannot parse JSON", msg);
525 return Box::into_raw(Box::new(res)) as *const _;
526 }
527 };
528
529 if let Ok(mut config) = Config::from_name(ty, opt) {
530 let df: Defuns = unsafe {
531 if df.is_null() {
532 Defuns::new()
533 } else {
534 (&*df).clone()
535 }
536 };
537
538 config.set_defuns(df);
539
540 let prog = match Program::new(&ml, config) {
541 Ok(prog) => prog,
542 Err(msg) => {
543 res.status = CompilerStatus::CompilationError;
544 res.msg = error_message("Compilation error (prog)", msg);
545 return Box::into_raw(Box::new(res)) as *const _;
546 }
547 };
548
549 let app = Application::new(prog, HashSet::new());
550
551 match app {
552 Ok(app) => {
553 res.status = CompilerStatus::Ok;
554 res.app = Some(app);
555 }
556 Err(msg) => {
557 res.status = CompilerStatus::CompilationError;
558 res.msg = error_message("Compilation error (app)", &msg);
559 }
560 }
561 } else {
562 res.status = CompilerStatus::InvalidCompiler;
563 res.msg = error_message("Config error", opt);
564 }
565
566 Box::into_raw(Box::new(res)) as *const _
567}
568
569/// Translates a Symbolica model.
570///
571/// * `json` is a json string encoding the output of `export_instructions`.
572/// * `ty` is the requested arch (amd, arm, native, or bytecode).
573/// * `opt`: compilation options.
574/// * `df`: user-defined functions (currently ignored).
575///
576/// # Safety
577/// * both model and ty are pointers to null-terminated strings.
578/// * The output is a raw pointer to a CompilerResults.
579///
580#[no_mangle]
581pub unsafe extern "C" fn translate(
582 json: *const c_char,
583 ty: *const c_char,
584 opt: u32,
585 df: *mut Defuns,
586 num_params: usize,
587) -> *const CompilerResult {
588 let mut res = CompilerResult {
589 app: None,
590 status: CompilerStatus::Incomplete,
591 msg: CString::from_str("Success").unwrap(),
592 };
593
594 let json = unsafe {
595 match CStr::from_ptr(json).to_str() {
596 Ok(json) => json,
597 Err(msg) => {
598 res.status = CompilerStatus::InvalidUtf8;
599 res.msg = error_message("Invalid encoding", msg);
600 return Box::into_raw(Box::new(res)) as *const _;
601 }
602 }
603 };
604
605 let ty = unsafe {
606 match CStr::from_ptr(ty).to_str() {
607 Ok(ty) => ty,
608 Err(msg) => {
609 res.status = CompilerStatus::InvalidUtf8;
610 res.msg = error_message("Invalid compiler type", msg);
611 return Box::into_raw(Box::new(res)) as *const _;
612 }
613 }
614 };
615
616 if let Ok(config) = Config::from_name(ty, opt) {
617 let df: Defuns = unsafe {
618 if df.is_null() {
619 Defuns::new()
620 } else {
621 (&*df).clone()
622 }
623 };
624
625 let mut comp = Compiler::with_config(config);
626 let app = comp.translate(json.to_string(), df, num_params);
627
628 match app {
629 Ok(app) => {
630 res.app = Some(app);
631 res.status = CompilerStatus::Ok;
632 }
633 Err(msg) => {
634 res.status = CompilerStatus::InvalidCompiler;
635 res.msg = error_message("Compilation error", msg);
636 }
637 }
638 } else {
639 res.status = CompilerStatus::InvalidCompiler;
640 res.msg = error_message("Config error", opt);
641 }
642
643 Box::into_raw(Box::new(res)) as *const _
644}
645
646/// Checks the status of a `CompilerResult`.
647///
648/// Returns a null-terminated string representing the status message.
649///
650/// # Safety
651/// it is the responsibility of the calling function to ensure
652/// that q points to a valid CompilerResult.
653///
654#[no_mangle]
655pub unsafe extern "C" fn check_status(q: *const CompilerResult) -> *const c_char {
656 let q: &CompilerResult = unsafe { &*q };
657 q.msg.as_ptr() as *const _
658}
659
660/// Checks the status of a `CompilerResult`.
661///
662/// Returns a null-terminated string representing the status message.
663///
664/// # Safety
665/// it is the responsibility of the calling function to ensure
666/// that q points to a valid CompilerResult.
667///
668#[no_mangle]
669pub unsafe extern "C" fn save(q: *const CompilerResult, file: *const c_char) -> bool {
670 let q: &CompilerResult = unsafe { &*q };
671 let file = unsafe {
672 match CStr::from_ptr(file).to_str() {
673 Ok(file) => file,
674 Err(_) => return false,
675 }
676 };
677
678 if let Some(app) = &q.app {
679 if let Ok(mut fs) = std::fs::File::create(file) {
680 app.save(&mut fs).is_ok()
681 } else {
682 false
683 }
684 } else {
685 false
686 }
687}
688
689/// Checks the status of a `CompilerResult`.
690///
691/// Returns a null-terminated string representing the status message.
692///
693/// # Safety
694/// it is the responsibility of the calling function to ensure
695/// that q points to a valid CompilerResult.
696///
697#[no_mangle]
698pub unsafe extern "C" fn load(file: *const c_char) -> *const CompilerResult {
699 let mut res = CompilerResult {
700 app: None,
701 status: CompilerStatus::Incomplete,
702 msg: CString::from_str("Success").unwrap(),
703 };
704
705 let file = unsafe {
706 match CStr::from_ptr(file).to_str() {
707 Ok(file) => file,
708 Err(_) => return Box::into_raw(Box::new(res)) as *const _,
709 }
710 };
711
712 let fs = std::fs::File::open(file);
713
714 match fs {
715 Ok(mut fs) => match Application::load(&mut fs) {
716 Ok(app) => {
717 res.app = Some(app);
718 res.status = CompilerStatus::Ok;
719 }
720 Err(err) => {
721 res.status = CompilerStatus::ParseError;
722 res.msg = error_message("File parse error", &err);
723 }
724 },
725 Err(err) => {
726 res.msg = error_message("File I/O error", &err);
727 }
728 }
729
730 Box::into_raw(Box::new(res)) as *const _
731}
732
733/// Checks the status of a `CompilerResult`.
734///
735/// Returns a null-terminated string representing the status message.
736///
737/// # Safety
738/// it is the responsibility of the calling function to ensure
739/// that q points to a valid CompilerResult.
740///
741#[no_mangle]
742pub unsafe extern "C" fn get_config(q: *const CompilerResult) -> usize {
743 let q: &CompilerResult = unsafe { &*q };
744
745 match &q.app {
746 Some(app) => {
747 let config = app.prog.config();
748
749 let ty: usize = match config.ty {
750 CompilerType::Native => 0,
751 CompilerType::Amd => 1,
752 CompilerType::AmdAVX => 2,
753 CompilerType::AmdSSE => 3,
754 CompilerType::Arm => 4,
755 CompilerType::RiscV => 5,
756 CompilerType::ByteCode => 6,
757 CompilerType::Debug => 7,
758 };
759
760 (config.opt as usize) | (ty << 32)
761 }
762 None => 0,
763 }
764}
765
766/// Returns the number of state variables.
767///
768/// # Safety
769/// it is the responsibility of the calling function to ensure
770/// that q points to a valid CompilerResult.
771///
772#[no_mangle]
773pub unsafe extern "C" fn count_states(q: *const CompilerResult) -> usize {
774 let q: &CompilerResult = unsafe { &*q };
775 if let Some(app) = &q.app {
776 app.count_states
777 } else {
778 0
779 }
780}
781
782/// Returns the number of parameters.
783///
784/// # Safety
785/// it is the responsibility of the calling function to ensure
786/// that q points to a valid CompilerResult.
787///
788#[no_mangle]
789pub unsafe extern "C" fn count_params(q: *const CompilerResult) -> usize {
790 let q: &CompilerResult = unsafe { &*q };
791 if let Some(app) = &q.app {
792 app.count_params
793 } else {
794 0
795 }
796}
797
798/// Returns the number of observables (output).
799///
800/// # Safety
801/// it is the responsibility of the calling function to ensure
802/// that q points to a valid CompilerResult.
803///
804#[no_mangle]
805pub unsafe extern "C" fn count_obs(q: *const CompilerResult) -> usize {
806 let q: &CompilerResult = unsafe { &*q };
807 if let Some(app) = &q.app {
808 app.count_obs
809 } else {
810 0
811 }
812}
813
814/// Returns the number of differential equations.
815///
816/// Generally, it should be the same as the number of states.
817///
818/// # Safety
819/// it is the responsibility of the calling function to ensure
820/// that q points to a valid CompilerResult.
821///
822#[no_mangle]
823pub unsafe extern "C" fn count_diffs(q: *const CompilerResult) -> usize {
824 let q: &CompilerResult = unsafe { &*q };
825 if let Some(app) = &q.app {
826 app.count_diffs
827 } else {
828 0
829 }
830}
831
832/// Deprecated. Previously used for interfacing to DifferentialEquation.jl. It is
833/// replaced with <https://github.com/siravan/SymJit.jl>.
834///
835/// # Safety
836///
837/// Deprecated. No effects.
838#[no_mangle]
839pub unsafe extern "C" fn run(
840 _q: *mut CompilerResult,
841 _du: *mut f64,
842 _u: *const f64,
843 _ns: usize,
844 _p: *const f64,
845 _np: usize,
846 _t: f64,
847) -> bool {
848 // let q: &mut CompilerResult = unsafe { &mut *q };
849
850 // if let Some(app) = &mut q.app {
851 // if app.count_states != ns || app.count_params != np {
852 // return false;
853 // }
854
855 // let du: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(du, ns) };
856 // let u: &[f64] = unsafe { std::slice::from_raw_parts(u, ns) };
857 // let p: &[f64] = unsafe { std::slice::from_raw_parts(p, np) };
858 // app.call(du, u, p, t);
859 // true
860 // } else {
861 // false
862 // }
863 false
864}
865
866/// Executes the compiled function.
867///
868/// The calling routine should fill the states and parameters before
869/// calling `execute`. The result populates obs or diffs (as defined in
870/// model passed to `compile`).
871///
872/// # Safety
873/// it is the responsibility of the calling function to ensure
874/// that q points to a valid CompilerResult.
875///
876#[no_mangle]
877pub unsafe extern "C" fn execute(q: *mut CompilerResult) -> bool {
878 let q: &mut CompilerResult = unsafe { &mut *q };
879
880 if let Some(app) = &mut q.app {
881 app.exec();
882 true
883 } else {
884 false
885 }
886}
887
888/// Executes the compiled function `n` times (vectorized).
889///
890/// The calling function provides `buf`, which is a k x n matrix of doubles,
891/// where k is equal to the `maximum(count_states, count_obs)`. The calling
892/// funciton fills the first `count_states` rows of buf. The result is returned
893/// in the first count_obs rows of buf.
894///
895/// # Safety
896/// it is the responsibility of the calling function to ensure
897/// that q points to a valid CompilerResult.
898///
899/// In addition, buf should points to a valid matrix of correct size.
900///
901#[no_mangle]
902pub unsafe extern "C" fn execute_vectorized(
903 q: *mut CompilerResult,
904 buf: *mut f64,
905 n: usize,
906) -> bool {
907 let q: &mut CompilerResult = unsafe { &mut *q };
908
909 if let Some(app) = &mut q.app {
910 let h = usize::max(app.count_states, app.count_obs);
911 let p: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(buf, h * n) };
912 let mut states = Matrix::from_buf(p, h, n);
913 let p: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(buf, h * n) };
914 let mut obs = Matrix::from_buf(p, h, n);
915 app.exec_vectorized(&mut states, &mut obs);
916 true
917 } else {
918 false
919 }
920}
921
922/// Evaluates the compiled function. This is for Symbolica compatibility.
923///
924/// # Safety
925/// it is the responsibility of the calling function to ensure
926/// that q points to a valid CompilerResult.
927///
928#[no_mangle]
929pub unsafe extern "C" fn evaluate(
930 q: *mut CompilerResult,
931 args: *const f64,
932 nargs: usize,
933 outs: *mut f64,
934 nouts: usize,
935) -> bool {
936 let q: &mut CompilerResult = unsafe { &mut *q };
937
938 if let Some(app) = &mut q.app {
939 let args: &[f64] = unsafe { std::slice::from_raw_parts(args, nargs) };
940 let outs: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(outs, nouts) };
941 app.evaluate(args, outs);
942 true
943 } else {
944 false
945 }
946}
947
948/// Evaluates the compiled function. This is for Symbolica compatibility.
949///
950/// # Safety
951/// it is the responsibility of the calling function to ensure
952/// that q points to a valid CompilerResult.
953///
954#[no_mangle]
955pub unsafe extern "C" fn evaluate_matrix(
956 q: *mut CompilerResult,
957 args: *const f64,
958 nargs: usize,
959 outs: *mut f64,
960 nouts: usize,
961) -> bool {
962 let q: &mut CompilerResult = unsafe { &mut *q };
963
964 if let Some(app) = &mut q.app {
965 if app.count_params == 0 {
966 return false;
967 }
968
969 let args: &[f64] = unsafe { std::slice::from_raw_parts(args, nargs) };
970 let outs: &mut [f64] = unsafe { std::slice::from_raw_parts_mut(outs, nouts) };
971 let n = nargs / app.count_params;
972 app.evaluate_matrix(args, outs, n);
973 true
974 } else {
975 false
976 }
977}
978
979/// Returns a pointer to the state variables (`count_states` doubles).
980///
981/// The function calling `execute` should write the state variables in this area.
982///
983/// # Safety
984/// it is the responsibility of the calling function to ensure
985/// that q points to a valid CompilerResult.
986///
987#[no_mangle]
988pub unsafe extern "C" fn ptr_states(q: *mut CompilerResult) -> *mut f64 {
989 let q: &mut CompilerResult = unsafe { &mut *q };
990 if let Some(app) = &mut q.app {
991 if let Some(f) = &mut app.compiled {
992 &mut f.mem_mut()[app.first_state] as *mut f64
993 } else {
994 &mut app.bytecode.mem_mut()[app.first_state] as *mut f64
995 }
996 } else {
997 std::ptr::null_mut()
998 }
999}
1000
1001/// Returns a pointer to the parameters (`count_params` doubles).
1002///
1003/// The function calling `execute` should write the parameters in this area.
1004///
1005/// # Safety
1006/// it is the responsibility of the calling function to ensure
1007/// that q points to a valid CompilerResult.
1008///
1009#[no_mangle]
1010pub unsafe extern "C" fn ptr_params(q: *mut CompilerResult) -> *mut f64 {
1011 let q: &mut CompilerResult = unsafe { &mut *q };
1012 if let Some(app) = &mut q.app {
1013 //&mut app.compiled.mem_mut()[app.first_param] as *mut f64
1014 &mut app.params[app.first_param] as *mut f64
1015 } else {
1016 std::ptr::null_mut()
1017 }
1018}
1019
1020/// Returns a pointer to the observables (`count_obs` doubles).
1021///
1022/// The function calling `execute` reads the observables from this area.
1023///
1024/// # Safety
1025/// it is the responsibility of the calling function to ensure
1026/// that q points to a valid CompilerResult.
1027///
1028#[no_mangle]
1029pub unsafe extern "C" fn ptr_obs(q: *mut CompilerResult) -> *const f64 {
1030 let q: &CompilerResult = unsafe { &*q };
1031 if let Some(app) = &q.app {
1032 if let Some(f) = &app.compiled {
1033 &f.mem()[app.first_obs] as *const f64
1034 } else {
1035 &app.bytecode.mem()[app.first_obs] as *const f64
1036 }
1037 } else {
1038 std::ptr::null()
1039 }
1040}
1041
1042/// Returns a pointer to the differentials (`count_diffs` doubles).
1043///
1044/// The function calling `execute` reads the differentials from this area.
1045///
1046/// Note: whether the output is returned as observables or differentials is
1047/// defined in the model.
1048///
1049/// # Safety
1050/// it is the responsibility of the calling function to ensure
1051/// that q points to a valid CompilerResult.
1052///
1053#[no_mangle]
1054pub unsafe extern "C" fn ptr_diffs(q: *mut CompilerResult) -> *const f64 {
1055 let q: &CompilerResult = unsafe { &*q };
1056 if let Some(app) = &q.app {
1057 if let Some(f) = &app.compiled {
1058 &f.mem()[app.first_diff] as *const f64
1059 } else {
1060 &app.bytecode.mem()[app.first_diff] as *const f64
1061 }
1062 } else {
1063 std::ptr::null()
1064 }
1065}
1066
1067/// Dumps the compiled binary code to a file (`name`).
1068///
1069/// This function is useful for debugging but is not necessary for
1070/// normal operations.
1071///
1072/// # Safety
1073/// it is the responsibility of the calling function to ensure
1074/// that q points to a valid CompilerResult.
1075///
1076#[no_mangle]
1077pub unsafe extern "C" fn dump(
1078 q: *mut CompilerResult,
1079 name: *const c_char,
1080 what: *const c_char,
1081) -> bool {
1082 let q: &mut CompilerResult = unsafe { &mut *q };
1083 if let Some(app) = &mut q.app {
1084 let name = unsafe { CStr::from_ptr(name).to_str().unwrap() };
1085 let what = unsafe { CStr::from_ptr(what).to_str().unwrap() };
1086 app.dump(name, what)
1087 } else {
1088 false
1089 }
1090}
1091
1092/// Deallocates the CompilerResult pointed by `q`.
1093///
1094/// # Safety
1095/// it is the responsibility of the calling function to ensure
1096/// that q points to a valid CompilerResult and that after
1097/// calling this function, q is invalid and should not
1098/// be used anymore.
1099///
1100#[no_mangle]
1101pub unsafe extern "C" fn finalize(q: *mut CompilerResult) {
1102 if !q.is_null() {
1103 let _ = unsafe { Box::from_raw(q) };
1104 }
1105}
1106
1107/// Returns a null-terminated string representing the version.
1108///
1109/// Used for debugging.
1110///
1111/// # Safety
1112/// the return value is a null-terminated string that should not
1113/// be freed.
1114///
1115#[no_mangle]
1116pub unsafe extern "C" fn info() -> *const c_char {
1117 // let msg = c"symjit 1.3.3";
1118 let msg = CString::new(env!("CARGO_PKG_VERSION")).unwrap();
1119 msg.into_raw() as *const _
1120}
1121
1122/// Returns a pointer to the fast function if one can be compiled.
1123///
1124/// # Safety
1125/// 1. If the model cannot be compiled to a fast function, NULL is returned.
1126/// 2. A fast function code memory is leaked and is not deallocated.
1127///
1128#[no_mangle]
1129pub unsafe extern "C" fn fast_func(q: *mut CompilerResult) -> *const usize {
1130 let q: &mut CompilerResult = unsafe { &mut *q };
1131 if let Some(app) = &mut q.app {
1132 match app.get_fast() {
1133 Some(f) => f as *const usize,
1134 None => std::ptr::null(),
1135 }
1136 } else {
1137 std::ptr::null()
1138 }
1139}
1140
1141/// Interface for Sympy's LowLevelCallable.
1142///
1143/// # Safety
1144/// 1. If the model cannot be compiled to a fast function, NULL is returned.
1145/// 2. The resulting function lives as long as q does and should not be stored
1146/// separately.
1147///
1148#[no_mangle]
1149pub unsafe extern "C" fn callable_quad(n: usize, xx: *const f64, q: *mut CompilerResult) -> f64 {
1150 let q: &mut CompilerResult = unsafe { &mut *q };
1151 let xx: &[f64] = unsafe { std::slice::from_raw_parts(xx, n) };
1152
1153 if let Some(app) = &mut q.app {
1154 app.exec_callable(xx)
1155 } else {
1156 f64::NAN
1157 }
1158}
1159
/// Interface for SciPy's `LowLevelCallable` (the `scipy.integrate.quad`
/// form `double f(n, double *xx, void *user_data)`), where the user data is
/// a fast-function pointer such as the one returned by `fast_func`.
///
/// Calls `f` with the `n` values in `xx` as individual `f64` arguments and
/// returns its result. Returns NaN if `f` is NULL or if `n` exceeds the
/// supported maximum of 7 arguments. Panicking here would abort the host,
/// because a panic cannot unwind out of an `extern "C"` function.
///
/// # Safety
/// 1. `f`, when not NULL, must point to a function using the C calling
///    convention that takes exactly `n` `f64` arguments and returns `f64`.
/// 2. `xx` must point to at least `n` valid doubles; it may be NULL when
///    `n == 0`.
/// 3. The function behind `f` must remain valid for the duration of the call.
///
#[no_mangle]
pub unsafe extern "C" fn callable_quad_fast(n: usize, xx: *const f64, f: *const usize) -> f64 {
    // Function pointers are non-nullable in Rust; transmuting NULL would be UB.
    if f.is_null() {
        return f64::NAN;
    }

    // `from_raw_parts` requires a non-null pointer even for an empty slice.
    let xx: &[f64] = if n == 0 {
        &[]
    } else {
        unsafe { std::slice::from_raw_parts(xx, n) }
    };

    // JIT-emitted machine code follows the platform C calling convention, so
    // it must be invoked through `extern "C"` function pointers. The Rust ABI
    // is unspecified and is not guaranteed to match.
    match *xx {
        [] => {
            let f: extern "C" fn() -> f64 = unsafe { std::mem::transmute(f) };
            f()
        }
        [a] => {
            let f: extern "C" fn(f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(a)
        }
        [a, b] => {
            let f: extern "C" fn(f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(a, b)
        }
        [a, b, c] => {
            let f: extern "C" fn(f64, f64, f64) -> f64 = unsafe { std::mem::transmute(f) };
            f(a, b, c)
        }
        [a, b, c, d] => {
            let f: extern "C" fn(f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(a, b, c, d)
        }
        [a, b, c, d, e] => {
            let f: extern "C" fn(f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(a, b, c, d, e)
        }
        [a, b, c, d, e, g] => {
            let f: extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(a, b, c, d, e, g)
        }
        [a, b, c, d, e, g, h] => {
            let f: extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(f) };
            f(a, b, c, d, e, g, h)
        }
        // Too many parameters for a fast func.
        _ => f64::NAN,
    }
}
1209
1210/// Interface for Sympy's LowLevelCallable (image filtering).
1211///
1212/// # Safety
1213/// 1. If the model cannot be compiled to a fast function, NULL is returned.
1214/// 2. The resulting function lives as long as q does and should not be stored
1215/// separately.
1216///
1217#[no_mangle]
1218pub unsafe extern "C" fn callable_filter(
1219 buffer: *const f64,
1220 filter_size: usize,
1221 return_value: *mut f64,
1222 q: *mut CompilerResult,
1223) -> i64 {
1224 let q: &mut CompilerResult = unsafe { &mut *q };
1225 let xx: &[f64] = unsafe { std::slice::from_raw_parts(buffer, filter_size) };
1226
1227 if let Some(app) = &mut q.app {
1228 let p: &mut f64 = unsafe { &mut *return_value };
1229 *p = app.exec_callable(xx);
1230 1
1231 } else {
1232 0
1233 }
1234}
1235
1236/************************************************/
1237
1238/// Creates an empty Matrix (a 2d array).
1239///
1240/// # Safety
1241/// It returns a pointer to the allocated Matrix, which needs to be
1242/// deallocated eventually.
1243///
1244#[no_mangle]
1245pub unsafe extern "C" fn create_matrix<'a>() -> *const Matrix<'a> {
1246 let mat = Matrix::new();
1247 Box::into_raw(Box::new(mat)) as *const Matrix
1248}
1249
1250/// Finalizes (deallocates) the Matrix.
1251///
1252/// # Safety
1253/// 1, mat should point to a valid Matrix object created by create_matrix.
1254/// 2. After finalize_matrix is called, mat is invalid.
1255///
1256#[no_mangle]
1257pub unsafe extern "C" fn finalize_matrix(mat: *mut Matrix) {
1258 if !mat.is_null() {
1259 let _ = unsafe { Box::from_raw(mat) };
1260 }
1261}
1262
1263/// Adds a row to the Matrix.
1264///
1265/// # Safety
1266/// 1, mat should point to a valid Matrix object created by create_matrix.
1267/// 2. v should point to a valid array of doubles of length at least n.
1268/// 3. v should remains valid for the lifespan of mat.
1269///
1270#[no_mangle]
1271pub unsafe extern "C" fn add_row(mat: *mut Matrix, v: *mut f64, n: usize) {
1272 let mat: &mut Matrix = unsafe { &mut *mat };
1273 mat.add_row(v, n);
1274}
1275
1276/// Executes (runs) the matrix model encoded by `q`.
1277///
1278/// # Safety
1279/// 1, q should point to a valid CompilerResult object.
1280/// 2. states should point to a valid Matrix of at least count_states rows.
1281/// 3. obs should point to a valid Matrix of at least count_obs rows.
1282///
1283#[no_mangle]
1284pub unsafe extern "C" fn execute_matrix(
1285 q: *mut CompilerResult,
1286 states: *mut Matrix,
1287 obs: *mut Matrix,
1288) -> bool {
1289 let q: &mut CompilerResult = unsafe { &mut *q };
1290 let states: &mut Matrix = unsafe { &mut *states };
1291 let obs: &mut Matrix = unsafe { &mut *obs };
1292
1293 if let Some(app) = &mut q.app {
1294 app.exec_vectorized(states, obs);
1295 true
1296 } else {
1297 false
1298 }
1299}
1300
1301/************************************************/
1302
1303/// Creates an empty `Defun` (a list of user-defined functions).
1304///
1305/// `Defuns` are used to pass user-defined functions (either Python
1306/// functions or symjit-compiled functions).
1307///
1308/// # Safety
1309/// It returns a pointer to the allocated Defun, which needs to be
1310/// deallocated eventually.
1311///
1312#[no_mangle]
1313pub unsafe extern "C" fn create_defuns() -> *const Defuns {
1314 let df = Defuns::new();
1315 Box::into_raw(Box::new(df)) as *const Defuns
1316}
1317
1318/// Finalizes (deallocates) a `Defun`.
1319///
1320/// # Safety
1321/// 1, df should point to a valid Defun object created by create_defuns.
1322/// 2. After finalize_defun is called, df is invalid.
1323///
1324#[no_mangle]
1325pub unsafe extern "C" fn finalize_defuns(df: *mut Defuns) {
1326 // if !df.is_null() {
1327 // let _ = unsafe { Box::from_raw(df) };
1328 // }
1329}
1330
1331/// Adds a new function to a `Defun`.
1332///
1333/// # Safety
1334/// 1, df should point to a valid Defun object created by create_defun.
1335/// 2. name should be a valid utf8 string.
1336/// 3. p should point to a valid C-styple function pointer that accepts
1337/// num_args double arguments.
1338///
1339#[no_mangle]
1340pub unsafe extern "C" fn add_func(
1341 df: *mut Defuns,
1342 name: *const c_char,
1343 p: *const usize,
1344 num_args: usize,
1345) {
1346 let df: &mut Defuns = unsafe { &mut *df };
1347 let name = unsafe { CStr::from_ptr(name).to_str().unwrap() };
1348 df.add_func(name, p, num_args);
1349}