Skip to main content

symjit_bridge/
runners.rs

1use crate::{compile, compile_string};
2use anyhow::Result;
3use symbolica::evaluate::ExpressionEvaluator;
4use symjit::Storage;
5pub use symjit::{Application, Complex, Config, Defuns};
6
/// Reinterprets a slice of `T` as the flat slice of `f64` values it is made of.
///
/// In this file it is used to hand `Complex<f64>` buffers to the bytecode
/// evaluator, which works on plain `f64` data (each complex value is assumed to
/// be laid out as two consecutive `f64`s — TODO confirm against symjit's `Complex`).
///
/// `T` must be a plain aggregate of `f64`s with no padding, e.g. `f64`,
/// `[f64; N]`, or a `#[repr(C)]` struct whose fields are all `f64`. The
/// size/alignment half of that contract is checked at run time. The
/// "only `f64`s, no padding" half cannot be checked here and is the caller's
/// responsibility.
///
/// NOTE(review): because that part of the contract is unchecked, this would be
/// more honest as an `unsafe fn` or bounded by a marker trait.
///
/// # Panics
///
/// Panics if `size_of::<T>()` is not a multiple of `size_of::<f64>()`, or if `T`
/// is less strictly aligned than `f64`. Either case would otherwise truncate the
/// result or create a misaligned `&[f64]` (undefined behaviour).
fn flatten_vec<T>(v: &[T]) -> &[f64] {
    let elem_size = std::mem::size_of::<T>();
    let f64_size = std::mem::size_of::<f64>();
    assert!(
        elem_size % f64_size == 0,
        "flatten_vec: element size {} is not a multiple of size_of::<f64>() = {}",
        elem_size,
        f64_size
    );
    assert!(
        std::mem::align_of::<T>() >= std::mem::align_of::<f64>(),
        "flatten_vec: element type is less strictly aligned than f64"
    );

    // Cannot overflow: this is the slice's byte length divided by size_of::<f64>().
    let len = v.len() * (elem_size / f64_size);

    // SAFETY:
    // - The pointer comes from a live slice spanning `v.len() * elem_size` bytes,
    //   which is exactly `len * size_of::<f64>()` bytes.
    // - It is suitably aligned for `f64`: alignments are powers of two and
    //   `align_of::<T>() >= align_of::<f64>()` was asserted above.
    // - The returned slice borrows from `v`, so it cannot outlive the data.
    // - Every bit pattern is a valid `f64`, so the contents are valid provided
    //   `T` has no padding bytes (documented caller contract).
    unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<f64>(), len) }
}
15
/// Mutable counterpart of `flatten_vec`: reinterprets `&mut [T]` as the flat
/// `&mut [f64]` it is made of, so writes through the result update `v` in place.
///
/// In this file it is used to let the bytecode evaluator write results directly
/// into a `Complex<f64>` output buffer (assumed to be two consecutive `f64`s per
/// element — TODO confirm against symjit's `Complex`).
///
/// `T` must be a plain aggregate of `f64`s with no padding, e.g. `f64`,
/// `[f64; N]`, or a `#[repr(C)]` struct whose fields are all `f64`. Size and
/// alignment are checked at run time. The absence of padding and non-`f64`
/// fields is the caller's responsibility.
///
/// NOTE(review): because that part of the contract is unchecked, this would be
/// more honest as an `unsafe fn` or bounded by a marker trait.
///
/// # Panics
///
/// Panics if `size_of::<T>()` is not a multiple of `size_of::<f64>()`, or if `T`
/// is less strictly aligned than `f64`.
fn flatten_vec_mut<T>(v: &mut [T]) -> &mut [f64] {
    let elem_size = std::mem::size_of::<T>();
    let f64_size = std::mem::size_of::<f64>();
    assert!(
        elem_size % f64_size == 0,
        "flatten_vec_mut: element size {} is not a multiple of size_of::<f64>() = {}",
        elem_size,
        f64_size
    );
    assert!(
        std::mem::align_of::<T>() >= std::mem::align_of::<f64>(),
        "flatten_vec_mut: element type is less strictly aligned than f64"
    );

    // Cannot overflow: this is the slice's byte length divided by size_of::<f64>().
    let len = v.len() * (elem_size / f64_size);

    // SAFETY:
    // - The pointer comes from a live, exclusively borrowed slice spanning
    //   `v.len() * elem_size` bytes, which is `len * size_of::<f64>()` bytes.
    // - It is aligned for `f64` (asserted above; alignments are powers of two).
    // - The result reborrows `v` mutably, so no other access can alias it.
    // - Any `f64` written is a valid bit pattern for the `f64` fields of `T`,
    //   provided `T` has no padding or non-`f64` fields (documented contract).
    unsafe { std::slice::from_raw_parts_mut(v.as_mut_ptr().cast::<f64>(), len) }
}
24
25/********************* CompiledRealRunner ************************/
26
27pub struct CompiledRealRunner {
28    app: Application,
29}
30
31impl CompiledRealRunner {
32    pub fn compile(ev: &ExpressionEvaluator<f64>, config: Config) -> Result<Self> {
33        Self::compile_with_funcs(ev, config, &Defuns::new(), 0)
34    }
35
36    pub fn compile_with_funcs(
37        ev: &ExpressionEvaluator<f64>,
38        mut config: Config,
39        df: &Defuns,
40        num_params: usize,
41    ) -> Result<Self> {
42        config.set_complex(false);
43        config.set_simd(true);
44        let app = compile(&ev, config, df, num_params)?;
45        Ok(Self { app })
46    }
47
48    pub fn compile_string(model: String, config: Config) -> Result<Self> {
49        Self::compile_string_with_funcs(model, config, &Defuns::new(), 0)
50    }
51
52    pub fn compile_string_with_funcs(
53        model: String,
54        mut config: Config,
55        df: &Defuns,
56        num_params: usize,
57    ) -> Result<Self> {
58        config.set_complex(false);
59        config.set_simd(true);
60        let app = compile_string(model, config, df, num_params)?;
61        Ok(Self { app })
62    }
63
64    pub fn evaluate(&mut self, args: &[f64], outs: &mut [f64]) {
65        let n = args.len() / self.app.count_params;
66        assert!(outs.len() / self.app.count_obs >= n);
67        self.app.evaluate_matrix(args, outs, n);
68    }
69
70    pub fn save(&self, file: &str) -> Result<()> {
71        let mut fs = std::fs::File::create(file)?;
72        self.app.save(&mut fs)
73    }
74
75    pub fn load(file: &str) -> Result<Self> {
76        let mut fs = std::fs::File::open(file)?;
77        let app = Application::load(&mut fs)?;
78        Ok(Self { app })
79    }
80}
81
82/************************ CompiledComplexRunner ***************************/
83
84pub struct CompiledComplexRunner {
85    app: Application,
86}
87
88impl CompiledComplexRunner {
89    pub fn compile(ev: &ExpressionEvaluator<Complex<f64>>, config: Config) -> Result<Self> {
90        Self::compile_with_funcs(ev, config, &Defuns::new(), 0)
91    }
92
93    pub fn compile_with_funcs(
94        ev: &ExpressionEvaluator<Complex<f64>>,
95        mut config: Config,
96        df: &Defuns,
97        num_params: usize,
98    ) -> Result<Self> {
99        config.set_complex(true);
100        config.set_simd(true);
101        let app = compile(&ev, config, df, num_params)?;
102        Ok(CompiledComplexRunner { app })
103    }
104
105    pub fn compile_string(model: String, config: Config) -> Result<Self> {
106        Self::compile_string_with_funcs(model, config, &Defuns::new(), 0)
107    }
108
109    pub fn compile_string_with_funcs(
110        model: String,
111        mut config: Config,
112        df: &Defuns,
113        num_params: usize,
114    ) -> Result<Self> {
115        config.set_complex(true);
116        config.set_simd(true);
117        let app = compile_string(model, config, df, num_params)?;
118        Ok(CompiledComplexRunner { app })
119    }
120
121    pub fn evaluate(&mut self, args: &[Complex<f64>], outs: &mut [Complex<f64>]) {
122        let n = (2 * args.len()) / self.app.count_params;
123        assert!(2 * outs.len() / self.app.count_obs >= n);
124        self.app.evaluate_complex_matrix(args, outs, n);
125    }
126
127    pub fn save(&self, file: &str) -> Result<()> {
128        let mut fs = std::fs::File::create(file)?;
129        self.app.save(&mut fs)
130    }
131
132    pub fn load(file: &str) -> Result<Self> {
133        let mut fs = std::fs::File::open(file)?;
134        let app = Application::load(&mut fs)?;
135        Ok(Self { app })
136    }
137}
138
139/********************* InterpretedRealRunner ************************/
140
141pub struct InterpretedRealRunner {
142    app: Application,
143}
144
145impl InterpretedRealRunner {
146    pub fn compile(ev: &ExpressionEvaluator<f64>, config: Config) -> Result<Self> {
147        Self::compile_with_funcs(ev, config, &Defuns::new(), 0)
148    }
149
150    pub fn compile_with_funcs(
151        ev: &ExpressionEvaluator<f64>,
152        config: Config,
153        df: &Defuns,
154        num_params: usize,
155    ) -> Result<Self> {
156        let mut c = Config::from_name("bytecode", config.opt)?;
157        c.set_complex(false);
158        c.set_simd(false);
159        let app = compile(&ev, c, df, num_params)?;
160        Ok(Self { app })
161    }
162
163    pub fn compile_string(model: String, config: Config) -> Result<Self> {
164        Self::compile_string_with_funcs(model, config, &Defuns::new(), 0)
165    }
166
167    pub fn compile_string_with_funcs(
168        model: String,
169        config: Config,
170        df: &Defuns,
171        num_params: usize,
172    ) -> Result<Self> {
173        let mut c = Config::from_name("bytecode", config.opt)?;
174        c.set_complex(false);
175        c.set_simd(false);
176        let app = compile_string(model, c, df, num_params)?;
177        Ok(Self { app })
178    }
179
180    pub fn evaluate(&mut self, args: &[f64], outs: &mut [f64]) {
181        let n = args.len() / self.app.count_params;
182        assert!(outs.len() / self.app.count_obs >= n);
183        self.app.evaluate_matrix_bytecode(args, outs, n);
184    }
185
186    pub fn save(&self, file: &str) -> Result<()> {
187        let mut fs = std::fs::File::create(file)?;
188        self.app.save(&mut fs)
189    }
190
191    pub fn load(file: &str) -> Result<Self> {
192        let mut fs = std::fs::File::open(file)?;
193        let app = Application::load(&mut fs)?;
194        Ok(Self { app })
195    }
196}
197
198/********************* InterpretedComplexRunner ************************/
199
200pub struct InterpretedComplexRunner {
201    app: Application,
202}
203
204impl InterpretedComplexRunner {
205    pub fn compile(ev: &ExpressionEvaluator<Complex<f64>>, config: Config) -> Result<Self> {
206        Self::compile_with_funcs(ev, config, &Defuns::new(), 0)
207    }
208
209    pub fn compile_with_funcs(
210        ev: &ExpressionEvaluator<Complex<f64>>,
211        config: Config,
212        df: &Defuns,
213        num_params: usize,
214    ) -> Result<Self> {
215        let mut c = Config::from_name("bytecode", config.opt)?;
216        c.set_complex(true);
217        c.set_simd(false);
218        let app = compile(&ev, c, df, num_params)?;
219        Ok(Self { app })
220    }
221
222    pub fn compile_string(model: String, config: Config) -> Result<Self> {
223        Self::compile_string_with_funcs(model, config, &Defuns::new(), 0)
224    }
225
226    pub fn compile_string_with_funcs(
227        model: String,
228        config: Config,
229        df: &Defuns,
230        num_params: usize,
231    ) -> Result<Self> {
232        let mut c = Config::from_name("bytecode", config.opt)?;
233        c.set_complex(true);
234        c.set_simd(false);
235        let app = compile_string(model, c, df, num_params)?;
236        Ok(Self { app })
237    }
238
239    pub fn evaluate(&mut self, args: &[Complex<f64>], outs: &mut [Complex<f64>]) {
240        let n = (2 * args.len()) / self.app.count_params;
241        assert!((2 * outs.len()) / self.app.count_obs >= n);
242
243        let args = flatten_vec(args);
244        let outs = flatten_vec_mut(outs);
245
246        self.app.evaluate_matrix_bytecode(args, outs, n);
247    }
248
249    pub fn save(&self, file: &str) -> Result<()> {
250        let mut fs = std::fs::File::create(file)?;
251        self.app.save(&mut fs)
252    }
253
254    pub fn load(file: &str) -> Result<Self> {
255        let mut fs = std::fs::File::open(file)?;
256        let app = Application::load(&mut fs)?;
257        Ok(Self { app })
258    }
259}