diffsol_c/
solution_wrapper.rs1use std::sync::{Arc, Mutex, MutexGuard};
2
3use diffsol::DiffsolError;
4use serde::{ser::SerializeStruct, Serialize, Serializer};
5
6use crate::{
7 error::DiffsolJsError,
8 host_array::{FromHostArray, HostArray},
9 solution::Solution,
10};
11
12#[derive(Clone)]
13pub struct SolutionWrapper(Arc<Mutex<Box<dyn Solution>>>);
14
15impl SolutionWrapper {
16 pub(crate) fn new(solution: Box<dyn Solution>) -> Self {
17 Self(Arc::new(Mutex::new(solution)))
18 }
19
20 fn guard(&self) -> Result<MutexGuard<'_, Box<dyn Solution>>, DiffsolJsError> {
21 self.0
22 .lock()
23 .map_err(|_| DiffsolError::Other("Solution mutex poisoned".to_string()).into())
24 }
25
26 pub fn get_ys(&self) -> Result<HostArray, DiffsolJsError> {
27 let guard = self.guard()?;
28 Ok(guard.get_ys())
29 }
30
31 pub fn get_ts(&self) -> Result<HostArray, DiffsolJsError> {
32 let guard = self.guard()?;
33 Ok(guard.get_ts())
34 }
35
36 pub fn get_sens(&self) -> Result<Vec<HostArray>, DiffsolJsError> {
37 let guard = self.guard()?;
38 Ok(guard.get_sens())
39 }
40}
41
42impl Serialize for SolutionWrapper {
43 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
44 where
45 S: Serializer,
46 {
47 let ts_host = self.get_ts().map_err(serde::ser::Error::custom)?;
48 let ys_host = self.get_ys().map_err(serde::ser::Error::custom)?;
49
50 let ts = Vec::<f64>::from_host_array(ts_host).map_err(serde::ser::Error::custom)?;
51 let ys = Vec::<Vec<f64>>::from_host_array(ys_host).map_err(serde::ser::Error::custom)?;
52 let sensitivities = self
53 .get_sens()
54 .map_err(serde::ser::Error::custom)?
55 .into_iter()
56 .map(Vec::<Vec<f64>>::from_host_array)
57 .collect::<Result<Vec<_>, _>>()
58 .map_err(serde::ser::Error::custom)?;
59
60 let mut state = serializer.serialize_struct("SolutionWrapper", 3)?;
61 state.serialize_field("ts", &ts)?;
62 state.serialize_field("ys", &ys)?;
63 state.serialize_field("sensitivities", &sensitivities)?;
64 state.end()
65 }
66}