Skip to main content

diffsol/op/
closure_with_sens.rs

1use std::cell::RefCell;
2
3use crate::{
4    jacobian::{find_jacobian_non_zeros, find_sens_non_zeros, JacobianColoring},
5    Matrix, MatrixSparsity, NonLinearOp, NonLinearOpJacobian, NonLinearOpSens, Op, Vector,
6};
7
8use super::{BuilderOp, OpStatistics, ParameterisedOp};
9
10pub struct ClosureWithSens<M, F, G, H>
11where
12    M: Matrix,
13{
14    func: F,
15    jacobian_action: G,
16    sens_action: H,
17    nstates: usize,
18    nparams: usize,
19    nout: usize,
20    coloring: Option<JacobianColoring<M>>,
21    sens_coloring: Option<JacobianColoring<M>>,
22    sparsity: Option<M::Sparsity>,
23    sens_sparsity: Option<M::Sparsity>,
24    statistics: RefCell<OpStatistics>,
25    ctx: M::C,
26}
27
28impl<M, F, G, H> ClosureWithSens<M, F, G, H>
29where
30    M: Matrix,
31    F: Fn(&M::V, &M::V, M::T, &mut M::V),
32    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
33    H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
34{
35    pub fn new(
36        func: F,
37        jacobian_action: G,
38        sens_action: H,
39        nstates: usize,
40        nparams: usize,
41        nout: usize,
42        ctx: M::C,
43    ) -> Self {
44        Self {
45            func,
46            jacobian_action,
47            sens_action,
48            nstates,
49            nout,
50            nparams,
51            statistics: RefCell::new(OpStatistics::default()),
52            coloring: None,
53            sparsity: None,
54            sens_coloring: None,
55            sens_sparsity: None,
56            ctx,
57        }
58    }
59
60    pub fn calculate_jacobian_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) {
61        let op = ParameterisedOp { op: self, p };
62        let non_zeros = find_jacobian_non_zeros(&op, y0, t0);
63        self.sparsity = Some(
64            MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone())
65                .expect("invalid sparsity pattern"),
66        );
67        self.coloring = Some(JacobianColoring::new(
68            self.sparsity.as_ref().unwrap(),
69            &non_zeros,
70            self.ctx.clone(),
71        ));
72    }
73    pub fn calculate_sens_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) {
74        let op = ParameterisedOp { op: self, p };
75        let non_zeros = find_sens_non_zeros(&op, y0, t0);
76        let nparams = p.len();
77        self.sens_sparsity = Some(
78            MatrixSparsity::try_from_indices(self.nout(), nparams, non_zeros.clone())
79                .expect("invalid sparsity pattern"),
80        );
81        self.sens_coloring = Some(JacobianColoring::new(
82            self.sens_sparsity.as_ref().unwrap(),
83            &non_zeros,
84            self.ctx.clone(),
85        ));
86    }
87}
88
89impl<M, F, G, H> BuilderOp for ClosureWithSens<M, F, G, H>
90where
91    M: Matrix,
92    F: Fn(&M::V, &M::V, M::T, &mut M::V),
93    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
94    H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
95{
96    fn set_nstates(&mut self, nstates: usize) {
97        self.nstates = nstates;
98    }
99    fn set_nout(&mut self, nout: usize) {
100        self.nout = nout;
101    }
102    fn set_nparams(&mut self, nparams: usize) {
103        self.nparams = nparams;
104    }
105
106    fn calculate_sparsity(&mut self, y0: &Self::V, t0: Self::T, p: &Self::V) {
107        self.calculate_jacobian_sparsity(y0, t0, p);
108        self.calculate_sens_sparsity(y0, t0, p);
109    }
110}
111
112impl<M, F, G, H> Op for ClosureWithSens<M, F, G, H>
113where
114    M: Matrix,
115{
116    type V = M::V;
117    type T = M::T;
118    type M = M;
119    type C = M::C;
120    fn nstates(&self) -> usize {
121        self.nstates
122    }
123    fn nout(&self) -> usize {
124        self.nout
125    }
126    fn nparams(&self) -> usize {
127        self.nparams
128    }
129    fn statistics(&self) -> OpStatistics {
130        self.statistics.borrow().clone()
131    }
132    fn context(&self) -> &Self::C {
133        &self.ctx
134    }
135}
136
137impl<M, F, G, H> NonLinearOp for ParameterisedOp<'_, ClosureWithSens<M, F, G, H>>
138where
139    M: Matrix,
140    F: Fn(&M::V, &M::V, M::T, &mut M::V),
141    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
142    H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
143{
144    fn call_inplace(&self, x: &M::V, t: M::T, y: &mut M::V) {
145        self.op.statistics.borrow_mut().increment_call();
146        (self.op.func)(x, self.p, t, y)
147    }
148}
149
150impl<M, F, G, H> NonLinearOpJacobian for ParameterisedOp<'_, ClosureWithSens<M, F, G, H>>
151where
152    M: Matrix,
153    F: Fn(&M::V, &M::V, M::T, &mut M::V),
154    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
155    H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
156{
157    fn jac_mul_inplace(&self, x: &M::V, t: M::T, v: &M::V, y: &mut M::V) {
158        self.op.statistics.borrow_mut().increment_jac_mul();
159        (self.op.jacobian_action)(x, self.p, t, v, y)
160    }
161    fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
162        self.op.statistics.borrow_mut().increment_matrix();
163        if let Some(coloring) = self.op.coloring.as_ref() {
164            coloring.jacobian_inplace(self, x, t, y);
165        } else {
166            self._default_jacobian_inplace(x, t, y);
167        }
168    }
169    fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
170        self.op.sparsity.clone()
171    }
172}
173
174impl<M, F, G, H> NonLinearOpSens for ParameterisedOp<'_, ClosureWithSens<M, F, G, H>>
175where
176    M: Matrix,
177    F: Fn(&M::V, &M::V, M::T, &mut M::V),
178    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
179    H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
180{
181    fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
182        (self.op.sens_action)(x, self.p, t, v, y);
183    }
184
185    fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
186        if let Some(coloring) = self.op.sens_coloring.as_ref() {
187            coloring.sens_inplace(self, x, t, y);
188        } else {
189            self._default_sens_inplace(x, t, y);
190        }
191    }
192    fn sens_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
193        self.op.sens_sparsity.clone()
194    }
195}