Skip to main content

diffsol/op/
closure.rs

1use std::cell::RefCell;
2
3use crate::{
4    find_jacobian_non_zeros, jacobian::JacobianColoring, Matrix, MatrixSparsity, NonLinearOp,
5    NonLinearOpJacobian, Op,
6};
7
8use super::{BuilderOp, OpStatistics, ParameterisedOp};
9
10pub struct Closure<M, F, G>
11where
12    M: Matrix,
13    F: Fn(&M::V, &M::V, M::T, &mut M::V),
14    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
15{
16    func: F,
17    jacobian_action: G,
18    nstates: usize,
19    nout: usize,
20    nparams: usize,
21    coloring: Option<JacobianColoring<M>>,
22    sparsity: Option<M::Sparsity>,
23    statistics: RefCell<OpStatistics>,
24    ctx: M::C,
25}
26
27impl<M, F, G> Closure<M, F, G>
28where
29    M: Matrix,
30    F: Fn(&M::V, &M::V, M::T, &mut M::V),
31    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
32{
33    pub fn new(
34        func: F,
35        jacobian_action: G,
36        nstates: usize,
37        nout: usize,
38        nparams: usize,
39        ctx: M::C,
40    ) -> Self {
41        Self {
42            func,
43            jacobian_action,
44            nstates,
45            nparams,
46            nout,
47            statistics: RefCell::new(OpStatistics::default()),
48            coloring: None,
49            sparsity: None,
50            ctx,
51        }
52    }
53    pub fn calculate_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) {
54        let param_op = ParameterisedOp { op: self, p };
55        let non_zeros = find_jacobian_non_zeros(&param_op, y0, t0);
56        self.sparsity = Some(
57            MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone())
58                .expect("invalid sparsity pattern"),
59        );
60        self.coloring = Some(JacobianColoring::new(
61            self.sparsity.as_ref().unwrap(),
62            &non_zeros,
63            self.ctx.clone(),
64        ));
65    }
66}
67
68impl<M, F, G> BuilderOp for Closure<M, F, G>
69where
70    M: Matrix,
71    F: Fn(&M::V, &M::V, M::T, &mut M::V),
72    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
73{
74    fn calculate_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) {
75        self.calculate_sparsity(y0, t0, p);
76    }
77
78    fn set_nstates(&mut self, nstates: usize) {
79        self.nstates = nstates;
80    }
81    fn set_nout(&mut self, nout: usize) {
82        self.nout = nout;
83    }
84    fn set_nparams(&mut self, nparams: usize) {
85        self.nparams = nparams;
86    }
87}
88
89impl<M, F, G> Op for Closure<M, F, G>
90where
91    M: Matrix,
92    F: Fn(&M::V, &M::V, M::T, &mut M::V),
93    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
94{
95    type V = M::V;
96    type T = M::T;
97    type M = M;
98    type C = M::C;
99
100    fn context(&self) -> &Self::C {
101        &self.ctx
102    }
103
104    fn nstates(&self) -> usize {
105        self.nstates
106    }
107    fn nout(&self) -> usize {
108        self.nout
109    }
110    fn nparams(&self) -> usize {
111        self.nparams
112    }
113    fn statistics(&self) -> OpStatistics {
114        self.statistics.borrow().clone()
115    }
116}
117
118impl<M, F, G> NonLinearOp for ParameterisedOp<'_, Closure<M, F, G>>
119where
120    M: Matrix,
121    F: Fn(&M::V, &M::V, M::T, &mut M::V),
122    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
123{
124    fn call_inplace(&self, x: &M::V, t: M::T, y: &mut M::V) {
125        self.op.statistics.borrow_mut().increment_call();
126        (self.op.func)(x, self.p, t, y)
127    }
128}
129
130impl<M, F, G> NonLinearOpJacobian for ParameterisedOp<'_, Closure<M, F, G>>
131where
132    M: Matrix,
133    F: Fn(&M::V, &M::V, M::T, &mut M::V),
134    G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
135{
136    fn jac_mul_inplace(&self, x: &M::V, t: M::T, v: &M::V, y: &mut M::V) {
137        self.op.statistics.borrow_mut().increment_jac_mul();
138        (self.op.jacobian_action)(x, self.p, t, v, y)
139    }
140    fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
141        self.op.statistics.borrow_mut().increment_matrix();
142        if let Some(coloring) = self.op.coloring.as_ref() {
143            coloring.jacobian_inplace(self, x, t, y);
144        } else {
145            self._default_jacobian_inplace(x, t, y);
146        }
147    }
148    fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
149        self.op.sparsity.clone()
150    }
151}