Skip to main content

symjit_bridge/
runners.rs

1use crate::{compile, compile_string};
2use anyhow::Result;
3use num_complex::Complex;
4use symbolica::evaluate::ExpressionEvaluator;
5use symjit::Storage;
6pub use symjit::{Application, Config, Defuns};
7
8fn flatten_vec<T>(v: &[T]) -> &[f64] {
9    let n = v.len();
10    let p: *const f64 = unsafe { std::mem::transmute(v.as_ptr()) };
11    let q: &[f64] = unsafe {
12        std::slice::from_raw_parts(p, n * std::mem::size_of::<T>() / std::mem::size_of::<f64>())
13    };
14    q
15}
16
17fn flatten_vec_mut<T>(v: &mut [T]) -> &mut [f64] {
18    let n = v.len();
19    let p: *mut f64 = unsafe { std::mem::transmute(v.as_mut_ptr()) };
20    let q: &mut [f64] = unsafe {
21        std::slice::from_raw_parts_mut(p, n * std::mem::size_of::<T>() / std::mem::size_of::<f64>())
22    };
23    q
24}
25
26/********************* CompiledRealRunner ************************/
27
28pub struct CompiledRealRunner {
29    app: Application,
30}
31
32impl CompiledRealRunner {
33    pub fn compile(ev: &ExpressionEvaluator<f64>, config: Config) -> Result<Self> {
34        Self::compile_with_funcs(ev, config, &Defuns::new(), 0)
35    }
36
37    pub fn compile_with_funcs(
38        ev: &ExpressionEvaluator<f64>,
39        mut config: Config,
40        df: &Defuns,
41        num_params: usize,
42    ) -> Result<Self> {
43        config.set_complex(false);
44        config.set_simd(true);
45        let app = compile(&ev, config, df, num_params)?;
46        Ok(Self { app })
47    }
48
49    pub fn compile_string(model: String, config: Config) -> Result<Self> {
50        Self::compile_string_with_funcs(model, config, &Defuns::new(), 0)
51    }
52
53    pub fn compile_string_with_funcs(
54        model: String,
55        mut config: Config,
56        df: &Defuns,
57        num_params: usize,
58    ) -> Result<Self> {
59        config.set_complex(false);
60        config.set_simd(true);
61        let app = compile_string(model, config, df, num_params)?;
62        Ok(Self { app })
63    }
64
65    pub fn evaluate(&mut self, args: &[f64], outs: &mut [f64]) {
66        let n = args.len() / self.app.count_params;
67        assert!(outs.len() / self.app.count_obs >= n);
68        self.app.evaluate_matrix(args, outs, n);
69    }
70
71    pub fn save(&self, file: &str) -> Result<()> {
72        let mut fs = std::fs::File::create(file)?;
73        self.app.save(&mut fs)
74    }
75
76    pub fn load(file: &str) -> Result<Self> {
77        let mut fs = std::fs::File::open(file)?;
78        let app = Application::load(&mut fs)?;
79        Ok(Self { app })
80    }
81}
82
83/************************ CompiledComplexRunner ***************************/
84
85pub struct CompiledComplexRunner {
86    app: Application,
87}
88
89impl CompiledComplexRunner {
90    pub fn compile(ev: &ExpressionEvaluator<Complex<f64>>, config: Config) -> Result<Self> {
91        Self::compile_with_funcs(ev, config, &Defuns::new(), 0)
92    }
93
94    pub fn compile_with_funcs(
95        ev: &ExpressionEvaluator<Complex<f64>>,
96        mut config: Config,
97        df: &Defuns,
98        num_params: usize,
99    ) -> Result<Self> {
100        config.set_complex(true);
101        config.set_simd(true);
102        let app = compile(&ev, config, df, num_params)?;
103        Ok(CompiledComplexRunner { app })
104    }
105
106    pub fn compile_string(model: String, config: Config) -> Result<Self> {
107        Self::compile_string_with_funcs(model, config, &Defuns::new(), 0)
108    }
109
110    pub fn compile_string_with_funcs(
111        model: String,
112        mut config: Config,
113        df: &Defuns,
114        num_params: usize,
115    ) -> Result<Self> {
116        config.set_complex(true);
117        config.set_simd(true);
118        let app = compile_string(model, config, df, num_params)?;
119        Ok(CompiledComplexRunner { app })
120    }
121
122    pub fn evaluate(&mut self, args: &[Complex<f64>], outs: &mut [Complex<f64>]) {
123        let n = (2 * args.len()) / self.app.count_params;
124        assert!(2 * outs.len() / self.app.count_obs >= n);
125        self.app.evaluate_complex_matrix(args, outs, n);
126    }
127
128    pub fn save(&self, file: &str) -> Result<()> {
129        let mut fs = std::fs::File::create(file)?;
130        self.app.save(&mut fs)
131    }
132
133    pub fn load(file: &str) -> Result<Self> {
134        let mut fs = std::fs::File::open(file)?;
135        let app = Application::load(&mut fs)?;
136        Ok(Self { app })
137    }
138}
139
140/********************* InterpretedRealRunner ************************/
141
142pub struct InterpretedRealRunner {
143    app: Application,
144}
145
146impl InterpretedRealRunner {
147    pub fn compile(ev: &ExpressionEvaluator<f64>, config: Config) -> Result<Self> {
148        Self::compile_with_funcs(ev, config, &Defuns::new(), 0)
149    }
150
151    pub fn compile_with_funcs(
152        ev: &ExpressionEvaluator<f64>,
153        config: Config,
154        df: &Defuns,
155        num_params: usize,
156    ) -> Result<Self> {
157        let mut c = Config::from_name("bytecode", config.opt)?;
158        c.set_complex(false);
159        c.set_simd(false);
160        let app = compile(&ev, c, df, num_params)?;
161        Ok(Self { app })
162    }
163
164    pub fn compile_string(model: String, config: Config) -> Result<Self> {
165        Self::compile_string_with_funcs(model, config, &Defuns::new(), 0)
166    }
167
168    pub fn compile_string_with_funcs(
169        model: String,
170        config: Config,
171        df: &Defuns,
172        num_params: usize,
173    ) -> Result<Self> {
174        let mut c = Config::from_name("bytecode", config.opt)?;
175        c.set_complex(false);
176        c.set_simd(false);
177        let app = compile_string(model, c, df, num_params)?;
178        Ok(Self { app })
179    }
180
181    pub fn evaluate(&mut self, args: &[f64], outs: &mut [f64]) {
182        let n = args.len() / self.app.count_params;
183        assert!(outs.len() / self.app.count_obs >= n);
184        self.app.evaluate_matrix_bytecode(args, outs, n);
185    }
186
187    pub fn save(&self, file: &str) -> Result<()> {
188        let mut fs = std::fs::File::create(file)?;
189        self.app.save(&mut fs)
190    }
191
192    pub fn load(file: &str) -> Result<Self> {
193        let mut fs = std::fs::File::open(file)?;
194        let app = Application::load(&mut fs)?;
195        Ok(Self { app })
196    }
197}
198
199/********************* InterpretedComplexRunner ************************/
200
201pub struct InterpretedComplexRunner {
202    app: Application,
203}
204
205impl InterpretedComplexRunner {
206    pub fn compile(ev: &ExpressionEvaluator<Complex<f64>>, config: Config) -> Result<Self> {
207        Self::compile_with_funcs(ev, config, &Defuns::new(), 0)
208    }
209
210    pub fn compile_with_funcs(
211        ev: &ExpressionEvaluator<Complex<f64>>,
212        config: Config,
213        df: &Defuns,
214        num_params: usize,
215    ) -> Result<Self> {
216        let mut c = Config::from_name("bytecode", config.opt)?;
217        c.set_complex(true);
218        c.set_simd(false);
219        let app = compile(&ev, c, df, num_params)?;
220        Ok(Self { app })
221    }
222
223    pub fn compile_string(model: String, config: Config) -> Result<Self> {
224        Self::compile_string_with_funcs(model, config, &Defuns::new(), 0)
225    }
226
227    pub fn compile_string_with_funcs(
228        model: String,
229        config: Config,
230        df: &Defuns,
231        num_params: usize,
232    ) -> Result<Self> {
233        let mut c = Config::from_name("bytecode", config.opt)?;
234        c.set_complex(true);
235        c.set_simd(false);
236        let app = compile_string(model, c, df, num_params)?;
237        Ok(Self { app })
238    }
239
240    pub fn evaluate(&mut self, args: &[Complex<f64>], outs: &mut [Complex<f64>]) {
241        let n = (2 * args.len()) / self.app.count_params;
242        assert!((2 * outs.len()) / self.app.count_obs >= n);
243
244        let args = flatten_vec(args);
245        let outs = flatten_vec_mut(outs);
246
247        self.app.evaluate_matrix_bytecode(args, outs, n);
248    }
249
250    pub fn save(&self, file: &str) -> Result<()> {
251        let mut fs = std::fs::File::create(file)?;
252        self.app.save(&mut fs)
253    }
254
255    pub fn load(file: &str) -> Result<Self> {
256        let mut fs = std::fs::File::open(file)?;
257        let app = Application::load(&mut fs)?;
258        Ok(Self { app })
259    }
260}