// diffsl/execution/interface.rs

1use anyhow::{anyhow, Result};
2use std::collections::HashMap;
3
/// Floating-point scalar used for every real-valued buffer passed to or from
/// the JIT-compiled code.
type RealType = f64;
/// Unsigned integer used for thread indices and dimension counts in the
/// JIT-compiled ABI.
type UIntType = u32;

/// Signature of the optional `barrier_init` export (no arguments, no result).
pub type BarrierInitFunc = unsafe extern "C" fn();

/// Signature of the `set_constants` export.
///
/// Receives only `thread_id` / `thread_dim`, which presumably identify the
/// calling thread among `thread_dim` threads (confirm against the code
/// generator).
pub type SetConstantsFunc = unsafe extern "C" fn(thread_id: UIntType, thread_dim: UIntType);
10
11pub type StopFunc = unsafe extern "C" fn(
12    time: RealType,
13    u: *const RealType,
14    data: *mut RealType,
15    root: *mut RealType,
16    thread_id: UIntType,
17    thread_dim: UIntType,
18);
19pub type RhsFunc = unsafe extern "C" fn(
20    time: RealType,
21    u: *const RealType,
22    data: *mut RealType,
23    rr: *mut RealType,
24    thread_id: UIntType,
25    thread_dim: UIntType,
26);
27pub type RhsGradFunc = unsafe extern "C" fn(
28    time: RealType,
29    u: *const RealType,
30    du: *const RealType,
31    data: *const RealType,
32    ddata: *mut RealType,
33    rr: *const RealType,
34    drr: *mut RealType,
35    thread_id: UIntType,
36    thread_dim: UIntType,
37);
38pub type RhsRevGradFunc = unsafe extern "C" fn(
39    time: RealType,
40    u: *const RealType,
41    du: *mut RealType,
42    data: *const RealType,
43    ddata: *mut RealType,
44    rr: *const RealType,
45    drr: *mut RealType,
46    thread_id: UIntType,
47    thread_dim: UIntType,
48);
49pub type RhsSensGradFunc = unsafe extern "C" fn(
50    time: RealType,
51    u: *const RealType,
52    data: *const RealType,
53    ddata: *mut RealType,
54    rr: *const RealType,
55    drr: *mut RealType,
56    thread_id: UIntType,
57    thread_dim: UIntType,
58);
59pub type RhsSensRevGradFunc = unsafe extern "C" fn(
60    time: RealType,
61    u: *const RealType,
62    data: *const RealType,
63    ddata: *mut RealType,
64    rr: *const RealType,
65    drr: *mut RealType,
66    thread_id: UIntType,
67    thread_dim: UIntType,
68);
69pub type MassFunc = unsafe extern "C" fn(
70    time: RealType,
71    v: *const RealType,
72    data: *mut RealType,
73    mv: *mut RealType,
74    thread_id: UIntType,
75    thread_dim: UIntType,
76);
77pub type MassRevGradFunc = unsafe extern "C" fn(
78    time: RealType,
79    v: *const RealType,
80    dv: *mut RealType,
81    data: *const RealType,
82    ddata: *mut RealType,
83    mv: *const RealType,
84    dmv: *mut RealType,
85    thread_id: UIntType,
86    thread_dim: UIntType,
87);
88pub type U0Func = unsafe extern "C" fn(
89    u: *mut RealType,
90    data: *mut RealType,
91    thread_id: UIntType,
92    thread_dim: UIntType,
93);
94pub type U0GradFunc = unsafe extern "C" fn(
95    u: *const RealType,
96    du: *mut RealType,
97    data: *const RealType,
98    ddata: *mut RealType,
99    thread_id: UIntType,
100    thread_dim: UIntType,
101);
102pub type U0RevGradFunc = unsafe extern "C" fn(
103    u: *const RealType,
104    du: *mut RealType,
105    data: *const RealType,
106    ddata: *mut RealType,
107    thread_id: UIntType,
108    thread_dim: UIntType,
109);
110pub type CalcOutFunc = unsafe extern "C" fn(
111    time: RealType,
112    u: *const RealType,
113    data: *mut RealType,
114    out: *mut RealType,
115    thread_id: UIntType,
116    thread_dim: UIntType,
117);
118pub type CalcOutGradFunc = unsafe extern "C" fn(
119    time: RealType,
120    u: *const RealType,
121    du: *const RealType,
122    data: *const RealType,
123    ddata: *mut RealType,
124    out: *const RealType,
125    dout: *mut RealType,
126    thread_id: UIntType,
127    thread_dim: UIntType,
128);
129pub type CalcOutRevGradFunc = unsafe extern "C" fn(
130    time: RealType,
131    u: *const RealType,
132    du: *mut RealType,
133    data: *const RealType,
134    ddata: *mut RealType,
135    out: *const RealType,
136    dout: *mut RealType,
137    thread_id: UIntType,
138    thread_dim: UIntType,
139);
140pub type CalcOutSensGradFunc = unsafe extern "C" fn(
141    time: RealType,
142    u: *const RealType,
143    data: *const RealType,
144    ddata: *mut RealType,
145    out: *const RealType,
146    dout: *mut RealType,
147    thread_id: UIntType,
148    thread_dim: UIntType,
149);
150pub type CalcOutSensRevGradFunc = unsafe extern "C" fn(
151    time: RealType,
152    u: *const RealType,
153    data: *const RealType,
154    ddata: *mut RealType,
155    out: *const RealType,
156    dout: *mut RealType,
157    thread_id: UIntType,
158    thread_dim: UIntType,
159);
160pub type GetDimsFunc = unsafe extern "C" fn(
161    states: *mut UIntType,
162    inputs: *mut UIntType,
163    outputs: *mut UIntType,
164    data: *mut UIntType,
165    stop: *mut UIntType,
166    has_mass: *mut UIntType,
167);
168pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const RealType, data: *mut RealType);
169pub type GetInputsFunc = unsafe extern "C" fn(inputs: *mut RealType, data: *const RealType);
170pub type SetInputsGradFunc = unsafe extern "C" fn(
171    inputs: *const RealType,
172    dinputs: *const RealType,
173    data: *const RealType,
174    ddata: *mut RealType,
175);
176pub type SetInputsRevGradFunc = unsafe extern "C" fn(
177    inputs: *const RealType,
178    dinputs: *mut RealType,
179    data: *const RealType,
180    ddata: *mut RealType,
181);
182pub type SetIdFunc = unsafe extern "C" fn(id: *mut RealType);
183pub type GetTensorFunc = unsafe extern "C" fn(
184    data: *const RealType,
185    tensor_data: *mut *mut RealType,
186    tensor_size: *mut UIntType,
187);
188pub type GetConstantFunc =
189    unsafe extern "C" fn(tensor_data: *mut *const RealType, tensor_size: *mut UIntType);
190
191pub(crate) struct JitFunctions {
192    pub(crate) set_u0: U0Func,
193    pub(crate) rhs: RhsFunc,
194    pub(crate) mass: MassFunc,
195    pub(crate) calc_out: CalcOutFunc,
196    pub(crate) calc_stop: StopFunc,
197    pub(crate) set_id: SetIdFunc,
198    pub(crate) get_dims: GetDimsFunc,
199    pub(crate) set_inputs: SetInputsFunc,
200    pub(crate) get_inputs: GetInputsFunc,
201    #[allow(dead_code)]
202    pub(crate) barrier_init: Option<BarrierInitFunc>,
203    pub(crate) set_constants: SetConstantsFunc,
204}
205
206impl JitFunctions {
207    pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
208        // check if all required symbols are present
209        let required_symbols = [
210            "set_u0",
211            "rhs",
212            "mass",
213            "calc_out",
214            "calc_stop",
215            "set_id",
216            "get_dims",
217            "set_inputs",
218            "get_inputs",
219            "set_constants",
220        ];
221        for symbol in &required_symbols {
222            if !symbol_map.contains_key(*symbol) {
223                return Err(anyhow!("Missing required symbol: {}", symbol));
224            }
225        }
226        let set_u0 = unsafe { std::mem::transmute::<*const u8, U0Func>(symbol_map["set_u0"]) };
227        let rhs = unsafe { std::mem::transmute::<*const u8, RhsFunc>(symbol_map["rhs"]) };
228        let mass = unsafe { std::mem::transmute::<*const u8, MassFunc>(symbol_map["mass"]) };
229        let calc_out =
230            unsafe { std::mem::transmute::<*const u8, CalcOutFunc>(symbol_map["calc_out"]) };
231        let calc_stop =
232            unsafe { std::mem::transmute::<*const u8, StopFunc>(symbol_map["calc_stop"]) };
233        let set_id = unsafe { std::mem::transmute::<*const u8, SetIdFunc>(symbol_map["set_id"]) };
234        let get_dims =
235            unsafe { std::mem::transmute::<*const u8, GetDimsFunc>(symbol_map["get_dims"]) };
236        let set_inputs =
237            unsafe { std::mem::transmute::<*const u8, SetInputsFunc>(symbol_map["set_inputs"]) };
238        let get_inputs =
239            unsafe { std::mem::transmute::<*const u8, GetInputsFunc>(symbol_map["get_inputs"]) };
240        let barrier_init = symbol_map.get("barrier_init").map(|func_ptr| unsafe {
241            std::mem::transmute::<*const u8, BarrierInitFunc>(*func_ptr)
242        });
243        let set_constants = unsafe {
244            std::mem::transmute::<*const u8, SetConstantsFunc>(symbol_map["set_constants"])
245        };
246
247        Ok(Self {
248            set_u0,
249            rhs,
250            mass,
251            calc_out,
252            calc_stop,
253            set_id,
254            get_dims,
255            set_inputs,
256            get_inputs,
257            barrier_init,
258            set_constants,
259        })
260    }
261}
262
263pub(crate) struct JitGradFunctions {
264    pub(crate) set_u0_grad: U0GradFunc,
265    pub(crate) rhs_grad: RhsGradFunc,
266    pub(crate) calc_out_grad: CalcOutGradFunc,
267    pub(crate) set_inputs_grad: SetInputsGradFunc,
268}
269
270impl JitGradFunctions {
271    pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
272        // check if all required symbols are present
273        let required_symbols = [
274            "set_u0_grad",
275            "rhs_grad",
276            "calc_out_grad",
277            "set_inputs_grad",
278        ];
279        for symbol in &required_symbols {
280            if !symbol_map.contains_key(*symbol) {
281                return Err(anyhow!("Missing required symbol: {}", symbol));
282            }
283        }
284        let set_u0_grad =
285            unsafe { std::mem::transmute::<*const u8, U0GradFunc>(symbol_map["set_u0_grad"]) };
286        let rhs_grad =
287            unsafe { std::mem::transmute::<*const u8, RhsGradFunc>(symbol_map["rhs_grad"]) };
288        let calc_out_grad = unsafe {
289            std::mem::transmute::<*const u8, CalcOutGradFunc>(symbol_map["calc_out_grad"])
290        };
291        let set_inputs_grad = unsafe {
292            std::mem::transmute::<*const u8, SetInputsGradFunc>(symbol_map["set_inputs_grad"])
293        };
294
295        Ok(Self {
296            set_u0_grad,
297            rhs_grad,
298            calc_out_grad,
299            set_inputs_grad,
300        })
301    }
302}
303
304pub(crate) struct JitGradRFunctions {
305    pub(crate) set_u0_rgrad: U0RevGradFunc,
306    pub(crate) rhs_rgrad: RhsRevGradFunc,
307    pub(crate) mass_rgrad: MassRevGradFunc,
308    pub(crate) calc_out_rgrad: CalcOutRevGradFunc,
309    pub(crate) set_inputs_rgrad: SetInputsRevGradFunc,
310}
311
312impl JitGradRFunctions {
313    pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
314        let required_symbols = [
315            "set_u0_rgrad",
316            "rhs_rgrad",
317            "mass_rgrad",
318            "calc_out_rgrad",
319            "set_inputs_rgrad",
320        ];
321        for symbol in &required_symbols {
322            if !symbol_map.contains_key(*symbol) {
323                return Err(anyhow!("Missing required symbol: {}", symbol));
324            }
325        }
326        let set_u0_rgrad =
327            unsafe { std::mem::transmute::<*const u8, U0RevGradFunc>(symbol_map["set_u0_rgrad"]) };
328        let rhs_rgrad =
329            unsafe { std::mem::transmute::<*const u8, RhsRevGradFunc>(symbol_map["rhs_rgrad"]) };
330        let mass_rgrad =
331            unsafe { std::mem::transmute::<*const u8, MassRevGradFunc>(symbol_map["mass_rgrad"]) };
332        let calc_out_rgrad = unsafe {
333            std::mem::transmute::<*const u8, CalcOutRevGradFunc>(symbol_map["calc_out_rgrad"])
334        };
335        let set_inputs_rgrad = unsafe {
336            std::mem::transmute::<*const u8, SetInputsRevGradFunc>(symbol_map["set_inputs_rgrad"])
337        };
338
339        Ok(Self {
340            set_u0_rgrad,
341            rhs_rgrad,
342            mass_rgrad,
343            calc_out_rgrad,
344            set_inputs_rgrad,
345        })
346    }
347}
348
349pub(crate) struct JitSensGradFunctions {
350    pub(crate) rhs_sgrad: RhsSensGradFunc,
351    pub(crate) calc_out_sgrad: CalcOutSensGradFunc,
352}
353
354impl JitSensGradFunctions {
355    pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
356        let required_symbols = ["rhs_sgrad", "calc_out_sgrad"];
357        for symbol in &required_symbols {
358            if !symbol_map.contains_key(*symbol) {
359                return Err(anyhow!("Missing required symbol: {}", symbol));
360            }
361        }
362        let rhs_sgrad =
363            unsafe { std::mem::transmute::<*const u8, RhsSensGradFunc>(symbol_map["rhs_sgrad"]) };
364        let calc_out_sgrad = unsafe {
365            std::mem::transmute::<*const u8, CalcOutSensGradFunc>(symbol_map["calc_out_sgrad"])
366        };
367
368        Ok(Self {
369            rhs_sgrad,
370            calc_out_sgrad,
371        })
372    }
373}
374
375pub(crate) struct JitSensRevGradFunctions {
376    pub(crate) rhs_rgrad: RhsSensRevGradFunc,
377    pub(crate) calc_out_rgrad: CalcOutSensRevGradFunc,
378}
379
380impl JitSensRevGradFunctions {
381    pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
382        let required_symbols = ["rhs_srgrad", "calc_out_srgrad"];
383        for symbol in &required_symbols {
384            if !symbol_map.contains_key(*symbol) {
385                return Err(anyhow!("Missing required symbol: {}", symbol));
386            }
387        }
388        let rhs_rgrad = unsafe {
389            std::mem::transmute::<*const u8, RhsSensRevGradFunc>(symbol_map["rhs_srgrad"])
390        };
391        let calc_out_rgrad = unsafe {
392            std::mem::transmute::<*const u8, CalcOutSensRevGradFunc>(symbol_map["calc_out_srgrad"])
393        };
394
395        Ok(Self {
396            rhs_rgrad,
397            calc_out_rgrad,
398        })
399    }
400}
401
402pub(crate) struct JitGetTensorFunctions {
403    pub(crate) data_map: HashMap<String, GetTensorFunc>,
404    pub(crate) constant_map: HashMap<String, GetConstantFunc>,
405}
406
407impl JitGetTensorFunctions {
408    pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
409        let mut data_map = HashMap::new();
410        let mut constant_map = HashMap::new();
411        let data_prefix = "get_tensor_";
412        let constant_prefix = "get_constant_";
413        for (name, func_ptr) in symbol_map.iter() {
414            if name.starts_with(data_prefix) {
415                let func = unsafe { std::mem::transmute::<*const u8, GetTensorFunc>(*func_ptr) };
416                data_map.insert(name.strip_prefix(data_prefix).unwrap().to_string(), func);
417            } else if name.starts_with(constant_prefix) {
418                let func = unsafe { std::mem::transmute::<*const u8, GetConstantFunc>(*func_ptr) };
419                constant_map.insert(
420                    name.strip_prefix(constant_prefix).unwrap().to_string(),
421                    func,
422                );
423            }
424        }
425        Ok(Self {
426            data_map,
427            constant_map,
428        })
429    }
430}