Skip to main content

diffsol_c/
solution_wrapper.rs

1use std::sync::{Arc, Mutex, MutexGuard};
2
3use diffsol::DiffsolError;
4use serde::{Serialize, Serializer, ser::SerializeStruct};
5
6use crate::{
7    error::DiffsolJsError,
8    host_array::{FromHostArray, HostArray},
9    solution::Solution,
10};
11
12#[derive(Clone)]
13pub struct SolutionWrapper(Arc<Mutex<Box<dyn Solution>>>);
14
15impl SolutionWrapper {
16    pub(crate) fn new(solution: Box<dyn Solution>) -> Self {
17        Self(Arc::new(Mutex::new(solution)))
18    }
19
20    fn guard(&self) -> Result<MutexGuard<'_, Box<dyn Solution>>, DiffsolJsError> {
21        self.0
22            .lock()
23            .map_err(|_| DiffsolError::Other("Solution mutex poisoned".to_string()).into())
24    }
25
26    pub fn get_ys(&self) -> Result<HostArray, DiffsolJsError> {
27        let guard = self.guard()?;
28        Ok(guard.get_ys())
29    }
30
31    pub fn get_ts(&self) -> Result<HostArray, DiffsolJsError> {
32        let guard = self.guard()?;
33        Ok(guard.get_ts())
34    }
35
36    pub fn get_sens(&self) -> Result<Vec<HostArray>, DiffsolJsError> {
37        let guard = self.guard()?;
38        Ok(guard.get_sens())
39    }
40}
41
42impl Serialize for SolutionWrapper {
43    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
44    where
45        S: Serializer,
46    {
47        let ts_host = self.get_ts().map_err(serde::ser::Error::custom)?;
48        let ys_host = self.get_ys().map_err(serde::ser::Error::custom)?;
49
50        let ts = Vec::<f64>::from_host_array(ts_host).map_err(serde::ser::Error::custom)?;
51        let ys = Vec::<Vec<f64>>::from_host_array(ys_host).map_err(serde::ser::Error::custom)?;
52        let sensitivities = self
53            .get_sens()
54            .map_err(serde::ser::Error::custom)?
55            .into_iter()
56            .map(Vec::<Vec<f64>>::from_host_array)
57            .collect::<Result<Vec<_>, _>>()
58            .map_err(serde::ser::Error::custom)?;
59
60        let mut state = serializer.serialize_struct("SolutionWrapper", 3)?;
61        state.serialize_field("ts", &ts)?;
62        state.serialize_field("ys", &ys)?;
63        state.serialize_field("sensitivities", &sensitivities)?;
64        state.end()
65    }
66}