Skip to main content

symjit_bridge/
runners.rs

1use crate::compile;
2use anyhow::Result;
3use num_complex::Complex;
4use symbolica::evaluate::ExpressionEvaluator;
5use symjit::Storage;
6pub use symjit::{Application, Config};
7
/// Reinterprets a slice of `T` as a flat slice of `f64` over the same memory.
///
/// The returned slice has `v.len() * size_of::<T>() / size_of::<f64>()`
/// elements and borrows from `v`, so no data is copied.
///
/// `T` is expected to be made up entirely of `f64` values with no padding
/// (for example `[f64; N]` or a `#[repr(C)]` pair of `f64`s); this cannot be
/// checked here and remains the caller's responsibility.
///
/// # Panics
///
/// Panics if `T` is less strictly aligned than `f64`, because the resulting
/// `&[f64]` could then be misaligned.
fn flatten_vec<T>(v: &[T]) -> &[f64] {
    // Both operands are compile-time constants, so this check costs nothing
    // at run time for valid element types.
    assert!(
        std::mem::align_of::<T>() >= std::mem::align_of::<f64>(),
        "flatten_vec: element type is less aligned than f64"
    );

    let len = v.len() * std::mem::size_of::<T>() / std::mem::size_of::<f64>();

    // SAFETY: the pointer comes from a live slice and is aligned for `f64`
    // (checked above); `len` f64s span at most the bytes owned by `v`; the
    // returned lifetime is tied to `v` by the signature.
    unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<f64>(), len) }
}
16
/// Reinterprets a mutable slice of `T` as a flat mutable slice of `f64` over
/// the same memory.
///
/// The returned slice has `v.len() * size_of::<T>() / size_of::<f64>()`
/// elements and mutably borrows `v`, so writes through it modify `v` in place.
///
/// `T` is expected to be made up entirely of `f64` values with no padding,
/// and every `f64` bit pattern written must leave a valid `T`; this cannot be
/// checked here and remains the caller's responsibility.
///
/// # Panics
///
/// Panics if `T` is less strictly aligned than `f64`, because the resulting
/// `&mut [f64]` could then be misaligned.
fn flatten_vec_mut<T>(v: &mut [T]) -> &mut [f64] {
    // Both operands are compile-time constants, so this check costs nothing
    // at run time for valid element types.
    assert!(
        std::mem::align_of::<T>() >= std::mem::align_of::<f64>(),
        "flatten_vec_mut: element type is less aligned than f64"
    );

    let len = v.len() * std::mem::size_of::<T>() / std::mem::size_of::<f64>();

    // SAFETY: the pointer comes from a live, exclusively borrowed slice and is
    // aligned for `f64` (checked above); `len` f64s span at most the bytes
    // owned by `v`; the returned lifetime is tied to the exclusive borrow.
    unsafe { std::slice::from_raw_parts_mut(v.as_mut_ptr().cast::<f64>(), len) }
}
25
26/********************* CompiledRealRunner ************************/
27
28pub struct CompiledRealRunner {
29    config: Config,
30    app: Application,
31}
32
33impl CompiledRealRunner {
34    pub fn compile(ev: &ExpressionEvaluator<f64>, mut config: Config) -> Result<Self> {
35        config.set_complex(false);
36        config.set_simd(false);
37        let app = compile(&ev, config)?;
38        Ok(Self { config, app })
39    }
40
41    pub fn evaluate(&mut self, args: &[f64], outs: &mut [f64]) {
42        let n = args.len() / self.app.count_params;
43        assert!(outs.len() / self.app.count_obs >= n);
44
45        if self.config.use_threads() {
46            self.app.evaluate_matrix_without_threads(args, outs, n);
47        } else {
48            self.app.evaluate_matrix_with_threads(args, outs, n);
49        }
50    }
51
52    pub fn save(&self, file: &str) -> Result<()> {
53        let mut fs = std::fs::File::create(file)?;
54        self.app.save(&mut fs)
55    }
56
57    pub fn load(file: &str) -> Result<Self> {
58        let mut fs = std::fs::File::open(file)?;
59        let app = Application::load(&mut fs)?;
60        let config = *app.prog.config();
61        Ok(Self { config, app })
62    }
63}
64
65/************************ CompiledComplexRunner ***************************/
66
67pub struct CompiledComplexRunner {
68    config: Config,
69    app: Application,
70}
71
72impl CompiledComplexRunner {
73    pub fn compile(ev: &ExpressionEvaluator<Complex<f64>>, mut config: Config) -> Result<Self> {
74        config.set_complex(true);
75        config.set_simd(false);
76        let app = compile(&ev, config)?;
77        Ok(CompiledComplexRunner { config, app })
78    }
79
80    pub fn evaluate(&mut self, args: &[Complex<f64>], outs: &mut [Complex<f64>]) {
81        let n = (2 * args.len()) / self.app.count_params;
82        assert!(2 * outs.len() / self.app.count_obs >= n);
83
84        let args = flatten_vec(args);
85        let outs = flatten_vec_mut(outs);
86
87        if self.config.use_threads() {
88            self.app.evaluate_matrix_without_threads(args, outs, n);
89        } else {
90            self.app.evaluate_matrix_with_threads(args, outs, n);
91        }
92    }
93
94    pub fn save(&self, file: &str) -> Result<()> {
95        let mut fs = std::fs::File::create(file)?;
96        self.app.save(&mut fs)
97    }
98
99    pub fn load(file: &str) -> Result<Self> {
100        let mut fs = std::fs::File::open(file)?;
101        let app = Application::load(&mut fs)?;
102        let config = *app.prog.config();
103        Ok(Self { config, app })
104    }
105}
106
/**************************** CompiledSimdRealRunner ****************************/
108
109pub struct CompiledSimdRealRunner {
110    config: Config,
111    app: Application,
112}
113
114impl CompiledSimdRealRunner {
115    pub fn compile(ev: &ExpressionEvaluator<f64>, mut config: Config) -> Result<Self> {
116        config.set_complex(false);
117        config.set_simd(true);
118        let app = compile(&ev, config)?;
119        Ok(Self { config, app })
120    }
121
122    pub fn evaluate<T>(&mut self, args: &[T], outs: &mut [T]) {
123        let n = args.len() / self.app.count_params;
124        assert!(outs.len() / self.app.count_obs >= n);
125
126        let args = flatten_vec(args);
127        let outs = flatten_vec_mut(outs);
128
129        if self.config.use_threads() {
130            self.app
131                .evaluate_matrix_without_threads_simd(args, outs, n, false);
132        } else {
133            self.app
134                .evaluate_matrix_with_threads_simd(args, outs, n, false);
135        }
136    }
137
138    pub fn save(&self, file: &str) -> Result<()> {
139        let mut fs = std::fs::File::create(file)?;
140        self.app.save(&mut fs)
141    }
142
143    pub fn load(file: &str) -> Result<Self> {
144        let mut fs = std::fs::File::open(file)?;
145        let app = Application::load(&mut fs)?;
146        let config = *app.prog.config();
147        Ok(Self { config, app })
148    }
149}
150
/**************************** CompiledSimdComplexRunner ****************************/
152
153pub struct CompiledSimdComplexRunner {
154    config: Config,
155    app: Application,
156}
157
158impl CompiledSimdComplexRunner {
159    pub fn compile(ev: &ExpressionEvaluator<Complex<f64>>, mut config: Config) -> Result<Self> {
160        config.set_complex(true);
161        config.set_simd(true);
162        let app = compile(&ev, config)?;
163        Ok(Self { config, app })
164    }
165
166    pub fn evaluate<T>(&mut self, args: &[Complex<T>], outs: &mut [Complex<T>]) {
167        let n = (2 * args.len()) / self.app.count_params;
168        assert!(2 * outs.len() / self.app.count_obs >= n);
169
170        let args = flatten_vec(args);
171        let outs = flatten_vec_mut(outs);
172
173        if self.config.use_threads() {
174            self.app
175                .evaluate_matrix_without_threads_simd(args, outs, n, false);
176        } else {
177            self.app
178                .evaluate_matrix_with_threads_simd(args, outs, n, false);
179        }
180    }
181
182    pub fn save(&self, file: &str) -> Result<()> {
183        let mut fs = std::fs::File::create(file)?;
184        self.app.save(&mut fs)
185    }
186
187    pub fn load(file: &str) -> Result<Self> {
188        let mut fs = std::fs::File::open(file)?;
189        let app = Application::load(&mut fs)?;
190        let config = *app.prog.config();
191        Ok(Self { config, app })
192    }
193}
194
/**************************** CompiledScatteredSimdRealRunner ****************************/
196
197pub struct CompiledScatteredSimdRealRunner {
198    config: Config,
199    app: Application,
200}
201
202impl CompiledScatteredSimdRealRunner {
203    pub fn compile(ev: &ExpressionEvaluator<f64>, mut config: Config) -> Result<Self> {
204        config.set_complex(false);
205        config.set_simd(true);
206        let app = compile(&ev, config)?;
207        Ok(Self { config, app })
208    }
209
210    pub fn evaluate(&mut self, args: &[f64], outs: &mut [f64]) {
211        let n = args.len() / self.app.count_params;
212        assert!(outs.len() / self.app.count_obs >= n);
213
214        if self.config.use_threads() {
215            self.app
216                .evaluate_matrix_without_threads_simd(args, outs, n, true);
217        } else {
218            self.app
219                .evaluate_matrix_with_threads_simd(args, outs, n, true);
220        }
221    }
222
223    pub fn save(&self, file: &str) -> Result<()> {
224        let mut fs = std::fs::File::create(file)?;
225        self.app.save(&mut fs)
226    }
227
228    pub fn load(file: &str) -> Result<Self> {
229        let mut fs = std::fs::File::open(file)?;
230        let app = Application::load(&mut fs)?;
231        let config = *app.prog.config();
232        Ok(Self { config, app })
233    }
234}
235
/**************************** CompiledScatteredSimdComplexRunner ****************************/
237
238pub struct CompiledScatteredSimdComplexRunner {
239    config: Config,
240    app: Application,
241}
242
243impl CompiledScatteredSimdComplexRunner {
244    pub fn compile(ev: &ExpressionEvaluator<Complex<f64>>, mut config: Config) -> Result<Self> {
245        config.set_complex(true);
246        config.set_simd(true);
247        let app = compile(&ev, config)?;
248        Ok(Self { config, app })
249    }
250
251    pub fn evaluate(&mut self, args: &[Complex<f64>], outs: &mut [Complex<f64>]) {
252        let n = (2 * args.len()) / self.app.count_params;
253        assert!(2 * outs.len() / self.app.count_obs >= n);
254
255        let args = flatten_vec(args);
256        let outs = flatten_vec_mut(outs);
257
258        if self.config.use_threads() {
259            self.app
260                .evaluate_matrix_without_threads_simd(args, outs, n, true);
261        } else {
262            self.app
263                .evaluate_matrix_with_threads_simd(args, outs, n, true);
264        }
265    }
266
267    pub fn save(&self, file: &str) -> Result<()> {
268        let mut fs = std::fs::File::create(file)?;
269        self.app.save(&mut fs)
270    }
271
272    pub fn load(file: &str) -> Result<Self> {
273        let mut fs = std::fs::File::open(file)?;
274        let app = Application::load(&mut fs)?;
275        let config = *app.prog.config();
276        Ok(Self { config, app })
277    }
278}
279
280/********************* InterpretedRealRunner ************************/
281
282pub struct InterpretedRealRunner {
283    app: Application,
284}
285
286impl InterpretedRealRunner {
287    pub fn compile(ev: &ExpressionEvaluator<f64>, config: Config) -> Result<Self> {
288        let mut c = Config::from_name("bytecode", config.opt)?;
289        c.set_complex(false);
290        c.set_simd(false);
291        let app = compile(&ev, c)?;
292        Ok(Self { app })
293    }
294
295    pub fn evaluate(&mut self, args: &[f64], outs: &mut [f64]) {
296        let n = args.len() / self.app.count_params;
297        assert!(outs.len() / self.app.count_obs >= n);
298        self.app.evaluate_matrix_bytecode(args, outs, n);
299    }
300
301    pub fn save(&self, file: &str) -> Result<()> {
302        let mut fs = std::fs::File::create(file)?;
303        self.app.save(&mut fs)
304    }
305
306    pub fn load(file: &str) -> Result<Self> {
307        let mut fs = std::fs::File::open(file)?;
308        let app = Application::load(&mut fs)?;
309        Ok(Self { app })
310    }
311}
312
313/********************* InterpretedComplexRunner ************************/
314
315pub struct InterpretedComplexRunner {
316    app: Application,
317}
318
319impl InterpretedComplexRunner {
320    pub fn compile(ev: &ExpressionEvaluator<Complex<f64>>, config: Config) -> Result<Self> {
321        let mut c = Config::from_name("bytecode", config.opt)?;
322        c.set_complex(true);
323        c.set_simd(false);
324        let app = compile(&ev, c)?;
325        Ok(Self { app })
326    }
327
328    pub fn evaluate(&mut self, args: &[Complex<f64>], outs: &mut [Complex<f64>]) {
329        let n = (2 * args.len()) / self.app.count_params;
330        assert!((2 * outs.len()) / self.app.count_obs >= n);
331
332        let args = flatten_vec(args);
333        let outs = flatten_vec_mut(outs);
334
335        self.app.evaluate_matrix_bytecode(args, outs, n);
336    }
337
338    pub fn save(&self, file: &str) -> Result<()> {
339        let mut fs = std::fs::File::create(file)?;
340        self.app.save(&mut fs)
341    }
342
343    pub fn load(file: &str) -> Result<Self> {
344        let mut fs = std::fs::File::open(file)?;
345        let app = Application::load(&mut fs)?;
346        Ok(Self { app })
347    }
348}