Skip to main content

symjit_bridge/
runners.rs

1use crate::{compile, compile_string};
2use anyhow::Result;
3use symbolica::evaluate::ExpressionEvaluator;
4use symjit::Storage;
5pub use symjit::{Applet, Application, Complex, Config, Defuns, Element};
6
/// Reinterprets a slice of `T` as a slice of `f64` covering the same memory.
///
/// The returned slice has `v.len() * size_of::<T>() / size_of::<f64>()` elements and
/// borrows from `v`. In this file it is used to hand a `&[Complex<f64>]` buffer to an
/// API that takes `&[f64]`.
///
/// `T` must consist solely of initialized `f64`-compatible data (no padding bytes);
/// that part cannot be checked here and remains the caller's responsibility.
///
/// # Panics
///
/// Panics if `size_of::<T>()` is not a multiple of `size_of::<f64>()` or if `T` is less
/// strictly aligned than `f64`, because the reinterpretation would then be unsound.
fn flatten_vec<T>(v: &[T]) -> &[f64] {
    let elem_size = std::mem::size_of::<T>();
    let f64_size = std::mem::size_of::<f64>();
    assert!(
        elem_size % f64_size == 0 && std::mem::align_of::<T>() >= std::mem::align_of::<f64>(),
        "flatten_vec: element type cannot be viewed as f64 values"
    );

    // SAFETY: the assertion above guarantees the pointer is suitably aligned for `f64`
    // and that the slice's byte length is a whole number of `f64`s. The memory stays
    // borrowed from `v` for the lifetime of the returned slice.
    unsafe { std::slice::from_raw_parts(v.as_ptr().cast::<f64>(), v.len() * (elem_size / f64_size)) }
}
15
/// Reinterprets a mutable slice of `T` as a mutable slice of `f64` over the same memory.
///
/// The returned slice has `v.len() * size_of::<T>() / size_of::<f64>()` elements and
/// mutably borrows from `v`, so writes through it are visible in `v`. In this file it is
/// used to hand a `&mut [Complex<f64>]` buffer to an API that takes `&mut [f64]`.
///
/// `T` must consist solely of `f64`-compatible data (no padding bytes) and must accept
/// any `f64` bit pattern in every position; that part cannot be checked here and remains
/// the caller's responsibility.
///
/// # Panics
///
/// Panics if `size_of::<T>()` is not a multiple of `size_of::<f64>()` or if `T` is less
/// strictly aligned than `f64`, because the reinterpretation would then be unsound.
fn flatten_vec_mut<T>(v: &mut [T]) -> &mut [f64] {
    let elem_size = std::mem::size_of::<T>();
    let f64_size = std::mem::size_of::<f64>();
    assert!(
        elem_size % f64_size == 0 && std::mem::align_of::<T>() >= std::mem::align_of::<f64>(),
        "flatten_vec_mut: element type cannot be viewed as f64 values"
    );

    let len = v.len() * (elem_size / f64_size);
    // SAFETY: the assertion above guarantees the pointer is suitably aligned for `f64`
    // and that the slice's byte length is a whole number of `f64`s. The exclusive borrow
    // of `v` is carried over to the returned slice, so no aliasing is introduced.
    unsafe { std::slice::from_raw_parts_mut(v.as_mut_ptr().cast::<f64>(), len) }
}
24
25/********************* CompiledRealRunner ************************/
26
27pub struct CompiledRealRunner {
28    pub app: Application,
29}
30
31impl CompiledRealRunner {
32    pub fn compile(ev: &ExpressionEvaluator<f64>, config: Config) -> Result<Self> {
33        Self::compile_with_funcs(ev, config, Defuns::new(), 0)
34    }
35
36    pub fn compile_with_funcs(
37        ev: &ExpressionEvaluator<f64>,
38        mut config: Config,
39        df: Defuns,
40        num_params: usize,
41    ) -> Result<Self> {
42        config.set_complex(false);
43        config.set_simd(true);
44        let app = compile(&ev, config, df, num_params)?;
45        Ok(Self { app })
46    }
47
48    pub fn compile_string(model: String, config: Config) -> Result<Self> {
49        Self::compile_string_with_funcs(model, config, Defuns::new(), 0)
50    }
51
52    pub fn compile_string_with_funcs(
53        model: String,
54        mut config: Config,
55        df: Defuns,
56        num_params: usize,
57    ) -> Result<Self> {
58        config.set_complex(false);
59        config.set_simd(true);
60        let app = compile_string(model, config, df, num_params)?;
61        Ok(Self { app })
62    }
63
64    pub fn evaluate<T>(&self, args: &[T], outs: &mut [T])
65    where
66        T: Element,
67    {
68        let n = args.len() / self.app.count_params;
69        assert!(outs.len() / self.app.count_obs >= n);
70        self.app.evaluate_matrix(args, outs, n);
71    }
72
73    pub fn save(&self, file: &str) -> Result<()> {
74        let mut fs = std::fs::File::create(file)?;
75        self.app.save(&mut fs)
76    }
77
78    pub fn load(file: &str) -> Result<Self> {
79        let mut fs = std::fs::File::open(file)?;
80        let app = Application::load(&mut fs)?;
81        Ok(Self { app })
82    }
83
84    pub fn seal(self) -> Result<Applet> {
85        self.app.seal()
86    }
87}
88
89/************************ CompiledComplexRunner ***************************/
90
91pub struct CompiledComplexRunner {
92    pub app: Application,
93}
94
95impl CompiledComplexRunner {
96    pub fn compile(ev: &ExpressionEvaluator<Complex<f64>>, config: Config) -> Result<Self> {
97        Self::compile_with_funcs(ev, config, Defuns::new(), 0)
98    }
99
100    pub fn compile_with_funcs(
101        ev: &ExpressionEvaluator<Complex<f64>>,
102        mut config: Config,
103        df: Defuns,
104        num_params: usize,
105    ) -> Result<Self> {
106        config.set_complex(true);
107        config.set_simd(true);
108        let app = compile(&ev, config, df, num_params)?;
109        Ok(CompiledComplexRunner { app })
110    }
111
112    pub fn compile_string(model: String, config: Config) -> Result<Self> {
113        Self::compile_string_with_funcs(model, config, Defuns::new(), 0)
114    }
115
116    pub fn compile_string_with_funcs(
117        model: String,
118        mut config: Config,
119        df: Defuns,
120        num_params: usize,
121    ) -> Result<Self> {
122        config.set_complex(true);
123        config.set_simd(true);
124        let app = compile_string(model, config, df, num_params)?;
125        Ok(CompiledComplexRunner { app })
126    }
127
128    pub fn evaluate<T>(&self, args: &[T], outs: &mut [T])
129    where
130        T: Element,
131    {
132        let n = (2 * args.len()) / self.app.count_params;
133        assert!(2 * outs.len() / self.app.count_obs >= n);
134        self.app.evaluate_matrix(args, outs, n);
135    }
136
137    pub fn save(&self, file: &str) -> Result<()> {
138        let mut fs = std::fs::File::create(file)?;
139        self.app.save(&mut fs)
140    }
141
142    pub fn load(file: &str) -> Result<Self> {
143        let mut fs = std::fs::File::open(file)?;
144        let app = Application::load(&mut fs)?;
145        Ok(Self { app })
146    }
147
148    pub fn seal(self) -> Result<Applet> {
149        self.app.seal()
150    }
151}
152
153/********************* InterpretedRealRunner ************************/
154
155pub struct InterpretedRealRunner {
156    app: Application,
157}
158
159impl InterpretedRealRunner {
160    pub fn compile(ev: &ExpressionEvaluator<f64>, config: Config) -> Result<Self> {
161        Self::compile_with_funcs(ev, config, Defuns::new(), 0)
162    }
163
164    pub fn compile_with_funcs(
165        ev: &ExpressionEvaluator<f64>,
166        config: Config,
167        df: Defuns,
168        num_params: usize,
169    ) -> Result<Self> {
170        let mut c = Config::from_name("bytecode", config.opt)?;
171        c.set_complex(false);
172        c.set_simd(false);
173        let app = compile(&ev, c, df, num_params)?;
174        Ok(Self { app })
175    }
176
177    pub fn compile_string(model: String, config: Config) -> Result<Self> {
178        Self::compile_string_with_funcs(model, config, Defuns::new(), 0)
179    }
180
181    pub fn compile_string_with_funcs(
182        model: String,
183        config: Config,
184        df: Defuns,
185        num_params: usize,
186    ) -> Result<Self> {
187        let mut c = Config::from_name("bytecode", config.opt)?;
188        c.set_complex(false);
189        c.set_simd(false);
190        let app = compile_string(model, c, df, num_params)?;
191        Ok(Self { app })
192    }
193
194    pub fn evaluate(&mut self, args: &[f64], outs: &mut [f64]) {
195        let n = args.len() / self.app.count_params;
196        assert!(outs.len() / self.app.count_obs >= n);
197        self.app.interpret_matrix(args, outs, n);
198    }
199
200    pub fn save(&self, file: &str) -> Result<()> {
201        let mut fs = std::fs::File::create(file)?;
202        self.app.save(&mut fs)
203    }
204
205    pub fn load(file: &str) -> Result<Self> {
206        let mut fs = std::fs::File::open(file)?;
207        let app = Application::load(&mut fs)?;
208        Ok(Self { app })
209    }
210}
211
212/********************* InterpretedComplexRunner ************************/
213
214pub struct InterpretedComplexRunner {
215    app: Application,
216}
217
218impl InterpretedComplexRunner {
219    pub fn compile(ev: &ExpressionEvaluator<Complex<f64>>, config: Config) -> Result<Self> {
220        Self::compile_with_funcs(ev, config, Defuns::new(), 0)
221    }
222
223    pub fn compile_with_funcs(
224        ev: &ExpressionEvaluator<Complex<f64>>,
225        config: Config,
226        df: Defuns,
227        num_params: usize,
228    ) -> Result<Self> {
229        let mut c = Config::from_name("bytecode", config.opt)?;
230        c.set_complex(true);
231        c.set_simd(false);
232        let app = compile(&ev, c, df, num_params)?;
233        Ok(Self { app })
234    }
235
236    pub fn compile_string(model: String, config: Config) -> Result<Self> {
237        Self::compile_string_with_funcs(model, config, Defuns::new(), 0)
238    }
239
240    pub fn compile_string_with_funcs(
241        model: String,
242        config: Config,
243        df: Defuns,
244        num_params: usize,
245    ) -> Result<Self> {
246        let mut c = Config::from_name("bytecode", config.opt)?;
247        c.set_complex(true);
248        c.set_simd(false);
249        let app = compile_string(model, c, df, num_params)?;
250        Ok(Self { app })
251    }
252
253    pub fn evaluate(&mut self, args: &[Complex<f64>], outs: &mut [Complex<f64>]) {
254        let n = (2 * args.len()) / self.app.count_params;
255        assert!((2 * outs.len()) / self.app.count_obs >= n);
256
257        let args = flatten_vec(args);
258        let outs = flatten_vec_mut(outs);
259
260        self.app.interpret_matrix(args, outs, n);
261    }
262
263    pub fn save(&self, file: &str) -> Result<()> {
264        let mut fs = std::fs::File::create(file)?;
265        self.app.save(&mut fs)
266    }
267
268    pub fn load(file: &str) -> Result<Self> {
269        let mut fs = std::fs::File::open(file)?;
270        let app = Application::load(&mut fs)?;
271        Ok(Self { app })
272    }
273}