1use anyhow::{anyhow, Result};
2use std::collections::HashMap;
3
/// Floating-point scalar used for every real-valued buffer in the JIT function signatures below.
// NOTE(review): presumably must match the real type emitted by the code generator — TODO confirm.
type RealType = f64;
/// Unsigned integer used for thread indices, model dimensions and tensor sizes in the
/// JIT function signatures below.
type UIntType = u32;
6
/// Signature of the optional `barrier_init` entry point (resolved, if present, by
/// `JitFunctions::new`). Takes no arguments and returns nothing.
pub type BarrierInitFunc = unsafe extern "C" fn();

/// Signature of the required `set_constants` entry point (resolved by `JitFunctions::new`).
///
/// `thread_id` / `thread_dim` presumably identify the calling thread and the total number
/// of threads, as in the other entry points — TODO confirm against the code generator.
pub type SetConstantsFunc = unsafe extern "C" fn(thread_id: UIntType, thread_dim: UIntType);
10
/// Signature of the `calc_stop` entry point (resolved by `JitFunctions::new`).
///
/// Reads the time `time` and state `u` (read-only) and writes through `root`; `data` is
/// passed mutably. `root` presumably receives the stop/root-function values — TODO confirm.
pub type StopFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    data: *mut RealType,
    root: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `rhs` entry point (resolved by `JitFunctions::new`).
///
/// Reads the time `time` and state `u` (read-only) and writes the right-hand side into
/// `rr`; `data` is passed mutably.
pub type RhsFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    data: *mut RealType,
    rr: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `rhs_grad` entry point (resolved by `JitGradFunctions::new`).
///
/// Each `d`-prefixed pointer pairs with the primal argument of the same name. `du` is
/// read-only, while `ddata` and `drr` are written through. By naming this is presumably
/// the forward-mode derivative of `rhs` — TODO confirm.
pub type RhsGradFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    du: *const RealType,
    data: *const RealType,
    ddata: *mut RealType,
    rr: *const RealType,
    drr: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `rhs_rgrad` entry point (resolved by `JitGradRFunctions::new`).
///
/// Same arguments as `RhsGradFunc`, except that `du` is also mutable. Presumably the
/// reverse-mode (adjoint) derivative of `rhs` — TODO confirm.
pub type RhsRevGradFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    du: *mut RealType,
    data: *const RealType,
    ddata: *mut RealType,
    rr: *const RealType,
    drr: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `rhs_sgrad` entry point (resolved by `JitSensGradFunctions::new`).
///
/// Like `RhsGradFunc` but without a state tangent `du`. Presumably the derivative with
/// respect to quantities stored in `data` only — TODO confirm.
pub type RhsSensGradFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    data: *const RealType,
    ddata: *mut RealType,
    rr: *const RealType,
    drr: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `rhs_srgrad` entry point (resolved by `JitSensRevGradFunctions::new`
/// into its `rhs_rgrad` field).
///
/// Same argument list as `RhsSensGradFunc`; presumably its reverse-mode counterpart.
pub type RhsSensRevGradFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    data: *const RealType,
    ddata: *mut RealType,
    rr: *const RealType,
    drr: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `mass` entry point (resolved by `JitFunctions::new`).
///
/// Reads `v` (read-only) and writes `mv`; `data` is passed mutably. Presumably computes
/// the mass matrix applied to `v` — TODO confirm.
pub type MassFunc = unsafe extern "C" fn(
    time: RealType,
    v: *const RealType,
    data: *mut RealType,
    mv: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `mass_rgrad` entry point (resolved by `JitGradRFunctions::new`).
///
/// `d`-prefixed pointers pair with the primal argument of the same name and are all
/// mutable. Presumably the reverse-mode derivative of `mass` — TODO confirm.
pub type MassRevGradFunc = unsafe extern "C" fn(
    time: RealType,
    v: *const RealType,
    dv: *mut RealType,
    data: *const RealType,
    ddata: *mut RealType,
    mv: *const RealType,
    dmv: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `set_u0` entry point (resolved by `JitFunctions::new`).
///
/// Writes through `u` (presumably the initial state); `data` is passed mutably.
pub type U0Func = unsafe extern "C" fn(
    u: *mut RealType,
    data: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `set_u0_grad` entry point (resolved by `JitGradFunctions::new`).
///
/// `du` and `ddata` are the mutable `d`-counterparts of `u` and `data`. Presumably the
/// forward-mode derivative of `set_u0` — TODO confirm.
pub type U0GradFunc = unsafe extern "C" fn(
    u: *const RealType,
    du: *mut RealType,
    data: *const RealType,
    ddata: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `set_u0_rgrad` entry point (resolved by `JitGradRFunctions::new`).
///
/// Identical argument list to `U0GradFunc`; presumably the reverse-mode derivative of
/// `set_u0` — TODO confirm.
pub type U0RevGradFunc = unsafe extern "C" fn(
    u: *const RealType,
    du: *mut RealType,
    data: *const RealType,
    ddata: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `calc_out` entry point (resolved by `JitFunctions::new`).
///
/// Reads the time `time` and state `u` (read-only) and writes through `out`; `data` is
/// passed mutably.
pub type CalcOutFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    data: *mut RealType,
    out: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `calc_out_grad` entry point (resolved by `JitGradFunctions::new`).
///
/// `du` is read-only; `ddata` and `dout` are written through. Presumably the forward-mode
/// derivative of `calc_out` — TODO confirm.
pub type CalcOutGradFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    du: *const RealType,
    data: *const RealType,
    ddata: *mut RealType,
    out: *const RealType,
    dout: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `calc_out_rgrad` entry point (resolved by `JitGradRFunctions::new`).
///
/// Same arguments as `CalcOutGradFunc` except that `du` is mutable. Presumably the
/// reverse-mode derivative of `calc_out` — TODO confirm.
pub type CalcOutRevGradFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    du: *mut RealType,
    data: *const RealType,
    ddata: *mut RealType,
    out: *const RealType,
    dout: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `calc_out_sgrad` entry point (resolved by `JitSensGradFunctions::new`).
///
/// Like `CalcOutGradFunc` but without a state tangent `du`.
pub type CalcOutSensGradFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    data: *const RealType,
    ddata: *mut RealType,
    out: *const RealType,
    dout: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `calc_out_srgrad` entry point (resolved by
/// `JitSensRevGradFunctions::new` into its `calc_out_rgrad` field).
///
/// Same argument list as `CalcOutSensGradFunc`; presumably its reverse-mode counterpart.
pub type CalcOutSensRevGradFunc = unsafe extern "C" fn(
    time: RealType,
    u: *const RealType,
    data: *const RealType,
    ddata: *mut RealType,
    out: *const RealType,
    dout: *mut RealType,
    thread_id: UIntType,
    thread_dim: UIntType,
);
/// Signature of the `get_dims` entry point (resolved by `JitFunctions::new`).
///
/// Every argument is a `*mut UIntType`, presumably an out-parameter receiving the number of
/// states, inputs, outputs, data entries and stop functions, plus a `has_mass` flag —
/// TODO confirm.
pub type GetDimsFunc = unsafe extern "C" fn(
    states: *mut UIntType,
    inputs: *mut UIntType,
    outputs: *mut UIntType,
    data: *mut UIntType,
    stop: *mut UIntType,
    has_mass: *mut UIntType,
);
/// Signature of the `set_inputs` entry point: reads `inputs` and writes into `data`.
pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const RealType, data: *mut RealType);
/// Signature of the `get_inputs` entry point: reads `data` and writes into `inputs`.
pub type GetInputsFunc = unsafe extern "C" fn(inputs: *mut RealType, data: *const RealType);
/// Signature of the `set_inputs_grad` entry point (resolved by `JitGradFunctions::new`).
///
/// `dinputs` is read-only and `ddata` is written through; presumably the forward-mode
/// derivative of `set_inputs` — TODO confirm.
pub type SetInputsGradFunc = unsafe extern "C" fn(
    inputs: *const RealType,
    dinputs: *const RealType,
    data: *const RealType,
    ddata: *mut RealType,
);
/// Signature of the `set_inputs_rgrad` entry point (resolved by `JitGradRFunctions::new`).
///
/// Both `dinputs` and `ddata` are mutable; presumably the reverse-mode derivative of
/// `set_inputs` — TODO confirm.
pub type SetInputsRevGradFunc = unsafe extern "C" fn(
    inputs: *const RealType,
    dinputs: *mut RealType,
    data: *const RealType,
    ddata: *mut RealType,
);
/// Signature of the `set_id` entry point (resolved by `JitFunctions::new`): fills the
/// caller-provided buffer `id`.
pub type SetIdFunc = unsafe extern "C" fn(id: *mut RealType);
/// Signature of every `get_tensor_<name>` symbol (collected by `JitGetTensorFunctions::new`).
///
/// Writes a pointer through `tensor_data` and a size through `tensor_size`; presumably the
/// location and length of tensor `<name>` inside `data` — TODO confirm.
pub type GetTensorFunc = unsafe extern "C" fn(
    data: *const RealType,
    tensor_data: *mut *mut RealType,
    tensor_size: *mut UIntType,
);
/// Signature of every `get_constant_<name>` symbol (collected by
/// `JitGetTensorFunctions::new`). Like `GetTensorFunc` but takes no `data` buffer and
/// yields a read-only pointer.
pub type GetConstantFunc =
    unsafe extern "C" fn(tensor_data: *mut *const RealType, tensor_size: *mut UIntType);
190
191pub(crate) struct JitFunctions {
192 pub(crate) set_u0: U0Func,
193 pub(crate) rhs: RhsFunc,
194 pub(crate) mass: MassFunc,
195 pub(crate) calc_out: CalcOutFunc,
196 pub(crate) calc_stop: StopFunc,
197 pub(crate) set_id: SetIdFunc,
198 pub(crate) get_dims: GetDimsFunc,
199 pub(crate) set_inputs: SetInputsFunc,
200 pub(crate) get_inputs: GetInputsFunc,
201 #[allow(dead_code)]
202 pub(crate) barrier_init: Option<BarrierInitFunc>,
203 pub(crate) set_constants: SetConstantsFunc,
204}
205
206impl JitFunctions {
207 pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
208 let required_symbols = [
210 "set_u0",
211 "rhs",
212 "mass",
213 "calc_out",
214 "calc_stop",
215 "set_id",
216 "get_dims",
217 "set_inputs",
218 "get_inputs",
219 "set_constants",
220 ];
221 for symbol in &required_symbols {
222 if !symbol_map.contains_key(*symbol) {
223 return Err(anyhow!("Missing required symbol: {}", symbol));
224 }
225 }
226 let set_u0 = unsafe { std::mem::transmute::<*const u8, U0Func>(symbol_map["set_u0"]) };
227 let rhs = unsafe { std::mem::transmute::<*const u8, RhsFunc>(symbol_map["rhs"]) };
228 let mass = unsafe { std::mem::transmute::<*const u8, MassFunc>(symbol_map["mass"]) };
229 let calc_out =
230 unsafe { std::mem::transmute::<*const u8, CalcOutFunc>(symbol_map["calc_out"]) };
231 let calc_stop =
232 unsafe { std::mem::transmute::<*const u8, StopFunc>(symbol_map["calc_stop"]) };
233 let set_id = unsafe { std::mem::transmute::<*const u8, SetIdFunc>(symbol_map["set_id"]) };
234 let get_dims =
235 unsafe { std::mem::transmute::<*const u8, GetDimsFunc>(symbol_map["get_dims"]) };
236 let set_inputs =
237 unsafe { std::mem::transmute::<*const u8, SetInputsFunc>(symbol_map["set_inputs"]) };
238 let get_inputs =
239 unsafe { std::mem::transmute::<*const u8, GetInputsFunc>(symbol_map["get_inputs"]) };
240 let barrier_init = symbol_map.get("barrier_init").map(|func_ptr| unsafe {
241 std::mem::transmute::<*const u8, BarrierInitFunc>(*func_ptr)
242 });
243 let set_constants = unsafe {
244 std::mem::transmute::<*const u8, SetConstantsFunc>(symbol_map["set_constants"])
245 };
246
247 Ok(Self {
248 set_u0,
249 rhs,
250 mass,
251 calc_out,
252 calc_stop,
253 set_id,
254 get_dims,
255 set_inputs,
256 get_inputs,
257 barrier_init,
258 set_constants,
259 })
260 }
261}
262
263pub(crate) struct JitGradFunctions {
264 pub(crate) set_u0_grad: U0GradFunc,
265 pub(crate) rhs_grad: RhsGradFunc,
266 pub(crate) calc_out_grad: CalcOutGradFunc,
267 pub(crate) set_inputs_grad: SetInputsGradFunc,
268}
269
270impl JitGradFunctions {
271 pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
272 let required_symbols = [
274 "set_u0_grad",
275 "rhs_grad",
276 "calc_out_grad",
277 "set_inputs_grad",
278 ];
279 for symbol in &required_symbols {
280 if !symbol_map.contains_key(*symbol) {
281 return Err(anyhow!("Missing required symbol: {}", symbol));
282 }
283 }
284 let set_u0_grad =
285 unsafe { std::mem::transmute::<*const u8, U0GradFunc>(symbol_map["set_u0_grad"]) };
286 let rhs_grad =
287 unsafe { std::mem::transmute::<*const u8, RhsGradFunc>(symbol_map["rhs_grad"]) };
288 let calc_out_grad = unsafe {
289 std::mem::transmute::<*const u8, CalcOutGradFunc>(symbol_map["calc_out_grad"])
290 };
291 let set_inputs_grad = unsafe {
292 std::mem::transmute::<*const u8, SetInputsGradFunc>(symbol_map["set_inputs_grad"])
293 };
294
295 Ok(Self {
296 set_u0_grad,
297 rhs_grad,
298 calc_out_grad,
299 set_inputs_grad,
300 })
301 }
302}
303
304pub(crate) struct JitGradRFunctions {
305 pub(crate) set_u0_rgrad: U0RevGradFunc,
306 pub(crate) rhs_rgrad: RhsRevGradFunc,
307 pub(crate) mass_rgrad: MassRevGradFunc,
308 pub(crate) calc_out_rgrad: CalcOutRevGradFunc,
309 pub(crate) set_inputs_rgrad: SetInputsRevGradFunc,
310}
311
312impl JitGradRFunctions {
313 pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
314 let required_symbols = [
315 "set_u0_rgrad",
316 "rhs_rgrad",
317 "mass_rgrad",
318 "calc_out_rgrad",
319 "set_inputs_rgrad",
320 ];
321 for symbol in &required_symbols {
322 if !symbol_map.contains_key(*symbol) {
323 return Err(anyhow!("Missing required symbol: {}", symbol));
324 }
325 }
326 let set_u0_rgrad =
327 unsafe { std::mem::transmute::<*const u8, U0RevGradFunc>(symbol_map["set_u0_rgrad"]) };
328 let rhs_rgrad =
329 unsafe { std::mem::transmute::<*const u8, RhsRevGradFunc>(symbol_map["rhs_rgrad"]) };
330 let mass_rgrad =
331 unsafe { std::mem::transmute::<*const u8, MassRevGradFunc>(symbol_map["mass_rgrad"]) };
332 let calc_out_rgrad = unsafe {
333 std::mem::transmute::<*const u8, CalcOutRevGradFunc>(symbol_map["calc_out_rgrad"])
334 };
335 let set_inputs_rgrad = unsafe {
336 std::mem::transmute::<*const u8, SetInputsRevGradFunc>(symbol_map["set_inputs_rgrad"])
337 };
338
339 Ok(Self {
340 set_u0_rgrad,
341 rhs_rgrad,
342 mass_rgrad,
343 calc_out_rgrad,
344 set_inputs_rgrad,
345 })
346 }
347}
348
349pub(crate) struct JitSensGradFunctions {
350 pub(crate) rhs_sgrad: RhsSensGradFunc,
351 pub(crate) calc_out_sgrad: CalcOutSensGradFunc,
352}
353
354impl JitSensGradFunctions {
355 pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
356 let required_symbols = ["rhs_sgrad", "calc_out_sgrad"];
357 for symbol in &required_symbols {
358 if !symbol_map.contains_key(*symbol) {
359 return Err(anyhow!("Missing required symbol: {}", symbol));
360 }
361 }
362 let rhs_sgrad =
363 unsafe { std::mem::transmute::<*const u8, RhsSensGradFunc>(symbol_map["rhs_sgrad"]) };
364 let calc_out_sgrad = unsafe {
365 std::mem::transmute::<*const u8, CalcOutSensGradFunc>(symbol_map["calc_out_sgrad"])
366 };
367
368 Ok(Self {
369 rhs_sgrad,
370 calc_out_sgrad,
371 })
372 }
373}
374
375pub(crate) struct JitSensRevGradFunctions {
376 pub(crate) rhs_rgrad: RhsSensRevGradFunc,
377 pub(crate) calc_out_rgrad: CalcOutSensRevGradFunc,
378}
379
380impl JitSensRevGradFunctions {
381 pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
382 let required_symbols = ["rhs_srgrad", "calc_out_srgrad"];
383 for symbol in &required_symbols {
384 if !symbol_map.contains_key(*symbol) {
385 return Err(anyhow!("Missing required symbol: {}", symbol));
386 }
387 }
388 let rhs_rgrad = unsafe {
389 std::mem::transmute::<*const u8, RhsSensRevGradFunc>(symbol_map["rhs_srgrad"])
390 };
391 let calc_out_rgrad = unsafe {
392 std::mem::transmute::<*const u8, CalcOutSensRevGradFunc>(symbol_map["calc_out_srgrad"])
393 };
394
395 Ok(Self {
396 rhs_rgrad,
397 calc_out_rgrad,
398 })
399 }
400}
401
402pub(crate) struct JitGetTensorFunctions {
403 pub(crate) data_map: HashMap<String, GetTensorFunc>,
404 pub(crate) constant_map: HashMap<String, GetConstantFunc>,
405}
406
407impl JitGetTensorFunctions {
408 pub(crate) fn new(symbol_map: &HashMap<String, *const u8>) -> Result<Self> {
409 let mut data_map = HashMap::new();
410 let mut constant_map = HashMap::new();
411 let data_prefix = "get_tensor_";
412 let constant_prefix = "get_constant_";
413 for (name, func_ptr) in symbol_map.iter() {
414 if name.starts_with(data_prefix) {
415 let func = unsafe { std::mem::transmute::<*const u8, GetTensorFunc>(*func_ptr) };
416 data_map.insert(name.strip_prefix(data_prefix).unwrap().to_string(), func);
417 } else if name.starts_with(constant_prefix) {
418 let func = unsafe { std::mem::transmute::<*const u8, GetConstantFunc>(*func_ptr) };
419 constant_map.insert(
420 name.strip_prefix(constant_prefix).unwrap().to_string(),
421 func,
422 );
423 }
424 }
425 Ok(Self {
426 data_map,
427 constant_map,
428 })
429 }
430}