// diffsol/op/mod.rs

1use crate::{
2    ConstantOp, ConstantOpSens, ConstantOpSensAdjoint, Context, LinearOp, LinearOpTranspose,
3    Matrix, NonLinearOp, NonLinearOpAdjoint, NonLinearOpSens, NonLinearOpSensAdjoint, Scalar,
4    Vector,
5};
6
7use nonlinear_op::NonLinearOpJacobian;
8use serde::Serialize;
9
10pub mod bdf;
11pub mod closure;
12pub mod closure_no_jac;
13pub mod closure_with_adjoint;
14pub mod closure_with_sens;
15pub mod constant_closure;
16pub mod constant_closure_with_adjoint;
17pub mod constant_closure_with_sens;
18pub mod constant_op;
19pub mod init;
20pub mod linear_closure;
21pub mod linear_closure_with_adjoint;
22pub mod linear_op;
23pub mod linearise;
24pub mod matrix;
25pub mod nonlinear_op;
26pub mod sdirk;
27pub mod stoch;
28pub mod unit;
29
30/// A generic operator trait.
31///
32/// Op is a trait for operators that, given a paramter vector `p`, operates on an input vector `x` to produce an output vector `y`.
33/// It defines the number of states (i.e. length of `x`), the number of outputs (i.e. length of `y`), and number of parameters (i.e. length of `p`) of the operator.
34/// It also defines the type of the scalar, vector, and matrices used in the operator.
35pub trait Op {
36    type T: Scalar;
37    type V: Vector<T = Self::T, C = Self::C>;
38    type M: Matrix<T = Self::T, V = Self::V, C = Self::C>;
39    type C: Context;
40
41    /// return the context of the operator
42    fn context(&self) -> &Self::C;
43
44    /// Return the number of input states of the operator.
45    fn nstates(&self) -> usize;
46
47    /// Return the number of outputs of the operator.
48    fn nout(&self) -> usize;
49
50    /// Return the number of parameters of the operator.
51    fn nparams(&self) -> usize;
52
53    /// Return statistics about the operator (e.g. how many times it was called, how many times the jacobian was computed, etc.)
54    fn statistics(&self) -> OpStatistics {
55        OpStatistics::default()
56    }
57}
58
59/// A wrapper for an operator that parameterises it with a parameter vector.
60pub struct ParameterisedOp<'a, C: Op> {
61    pub op: &'a C,
62    pub p: &'a C::V,
63}
64
65impl<'a, C: Op> ParameterisedOp<'a, C> {
66    pub fn new(op: &'a C, p: &'a C::V) -> Self {
67        Self { op, p }
68    }
69}
70
71/// trait interface for operators used in the [builder pattern](crate::OdeBuilder)
72pub trait BuilderOp: Op {
73    fn set_nstates(&mut self, nstates: usize);
74    fn set_nparams(&mut self, nparams: usize);
75    fn set_nout(&mut self, nout: usize);
76    fn calculate_sparsity(&mut self, y0: &Self::V, t0: Self::T, p: &Self::V);
77}
78
79impl<C: Op> Op for ParameterisedOp<'_, C> {
80    type V = C::V;
81    type T = C::T;
82    type M = C::M;
83    type C = C::C;
84    fn nstates(&self) -> usize {
85        self.op.nstates()
86    }
87    fn nout(&self) -> usize {
88        self.op.nout()
89    }
90    fn nparams(&self) -> usize {
91        self.op.nparams()
92    }
93    fn statistics(&self) -> OpStatistics {
94        self.op.statistics()
95    }
96    fn context(&self) -> &Self::C {
97        self.op.context()
98    }
99}
100
101/// Useful statistics about an operator.
102#[derive(Default, Clone, Serialize, Debug)]
103pub struct OpStatistics {
104    /// number of times the operator was called
105    pub number_of_calls: usize,
106    /// number of times the jacobian-vector product was computed
107    pub number_of_jac_muls: usize,
108    /// number of times the jacobian matrix was evaluated
109    pub number_of_matrix_evals: usize,
110    /// number of times the adjoint jacobian-vector product was computed
111    pub number_of_jac_adj_muls: usize,
112}
113
114impl OpStatistics {
115    pub fn new() -> Self {
116        Self {
117            number_of_jac_muls: 0,
118            number_of_calls: 0,
119            number_of_matrix_evals: 0,
120            number_of_jac_adj_muls: 0,
121        }
122    }
123
124    pub fn increment_call(&mut self) {
125        self.number_of_calls += 1;
126    }
127
128    pub fn increment_jac_mul(&mut self) {
129        self.number_of_jac_muls += 1;
130    }
131
132    pub fn increment_jac_adj_mul(&mut self) {
133        self.number_of_jac_adj_muls += 1;
134    }
135
136    pub fn increment_matrix(&mut self) {
137        self.number_of_matrix_evals += 1;
138    }
139}
140
141impl<C: Op> Op for &C {
142    type T = C::T;
143    type V = C::V;
144    type M = C::M;
145    type C = C::C;
146    fn nstates(&self) -> usize {
147        C::nstates(*self)
148    }
149    fn nout(&self) -> usize {
150        C::nout(*self)
151    }
152    fn nparams(&self) -> usize {
153        C::nparams(*self)
154    }
155    fn statistics(&self) -> OpStatistics {
156        C::statistics(*self)
157    }
158    fn context(&self) -> &Self::C {
159        C::context(*self)
160    }
161}
162
163impl<C: Op> Op for &mut C {
164    type T = C::T;
165    type V = C::V;
166    type M = C::M;
167    type C = C::C;
168    fn nstates(&self) -> usize {
169        C::nstates(*self)
170    }
171    fn nout(&self) -> usize {
172        C::nout(*self)
173    }
174    fn nparams(&self) -> usize {
175        C::nparams(*self)
176    }
177    fn statistics(&self) -> OpStatistics {
178        C::statistics(*self)
179    }
180    fn context(&self) -> &Self::C {
181        C::context(*self)
182    }
183}
184
185impl<C: NonLinearOp> NonLinearOp for &C {
186    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
187        C::call_inplace(*self, x, t, y)
188    }
189}
190
191impl<C: NonLinearOpJacobian> NonLinearOpJacobian for &C {
192    fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
193        C::jac_mul_inplace(*self, x, t, v, y)
194    }
195    fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
196        C::jacobian_inplace(*self, x, t, y)
197    }
198    fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
199        C::jacobian_sparsity(*self)
200    }
201}
202
203impl<C: NonLinearOpAdjoint> NonLinearOpAdjoint for &C {
204    fn adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
205        C::adjoint_inplace(*self, x, t, y)
206    }
207    fn adjoint_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
208        C::adjoint_sparsity(*self)
209    }
210    fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
211        C::jac_transpose_mul_inplace(*self, x, t, v, y)
212    }
213}
214
215impl<C: NonLinearOpSens> NonLinearOpSens for &C {
216    fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
217        C::sens_mul_inplace(*self, x, t, v, y)
218    }
219    fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
220        C::sens_inplace(*self, x, t, y)
221    }
222
223    fn sens_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
224        C::sens_sparsity(*self)
225    }
226}
227
228impl<C: NonLinearOpSensAdjoint> NonLinearOpSensAdjoint for &C {
229    fn sens_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
230        C::sens_transpose_mul_inplace(*self, x, t, v, y)
231    }
232    fn sens_adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
233        C::sens_adjoint_inplace(*self, x, t, y)
234    }
235    fn sens_adjoint_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
236        C::sens_adjoint_sparsity(*self)
237    }
238}
239
240impl<C: LinearOp> LinearOp for &C {
241    fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
242        C::gemv_inplace(*self, x, t, beta, y)
243    }
244    fn sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
245        C::sparsity(*self)
246    }
247    fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
248        C::matrix_inplace(*self, t, y)
249    }
250}
251
252impl<C: LinearOpTranspose> LinearOpTranspose for &C {
253    fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
254        C::gemv_transpose_inplace(*self, x, t, beta, y)
255    }
256    fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) {
257        C::transpose_inplace(*self, t, y)
258    }
259    fn transpose_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
260        C::transpose_sparsity(*self)
261    }
262}
263
264impl<C: ConstantOp> ConstantOp for &C {
265    fn call_inplace(&self, t: Self::T, y: &mut Self::V) {
266        C::call_inplace(*self, t, y)
267    }
268}
269
270impl<C: ConstantOpSens> ConstantOpSens for &C {
271    fn sens_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) {
272        C::sens_mul_inplace(*self, t, v, y)
273    }
274    fn sens_inplace(&self, t: Self::T, y: &mut Self::M) {
275        C::sens_inplace(*self, t, y)
276    }
277    fn sens_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
278        C::sens_sparsity(*self)
279    }
280}
281
282impl<C: ConstantOpSensAdjoint> ConstantOpSensAdjoint for &C {
283    fn sens_transpose_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) {
284        C::sens_transpose_mul_inplace(*self, t, v, y)
285    }
286    fn sens_adjoint_inplace(&self, t: Self::T, y: &mut Self::M) {
287        C::sens_adjoint_inplace(*self, t, y)
288    }
289    fn sens_adjoint_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
290        C::sens_adjoint_sparsity(*self)
291    }
292}
293
#[cfg(test)]
mod tests {
    use std::cell::RefCell;

    use crate::{
        context::nalgebra::NalgebraContext, matrix::dense_nalgebra_serial::NalgebraMat, ConstantOp,
        ConstantOpSens, ConstantOpSensAdjoint, LinearOp, LinearOpTranspose, NonLinearOp,
        NonLinearOpAdjoint, NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Vector,
    };

    use super::{Op, OpStatistics, ParameterisedOp};

    type M = NalgebraMat<f64>;

    /// Test double with 2 states, 2 outputs and 2 parameters.
    ///
    /// Its nonlinear action copies `x` into `y`; the jacobian and adjoint products copy `v` into
    /// `y`; the sensitivity products write zeros; the constant action writes `[1.0, 2.0]`. Calls,
    /// jacobian products and adjoint products are counted in `stats`, so tests can check which
    /// entry points a wrapper actually reached.
    struct ForwardingOp {
        ctx: NalgebraContext,
        stats: RefCell<OpStatistics>,
    }

    impl ForwardingOp {
        fn new() -> Self {
            Self {
                ctx: NalgebraContext,
                stats: RefCell::new(OpStatistics::new()),
            }
        }
    }

    impl Op for ForwardingOp {
        type T = f64;
        type V = crate::NalgebraVec<f64>;
        type M = M;
        type C = NalgebraContext;

        fn context(&self) -> &Self::C {
            &self.ctx
        }
        fn nstates(&self) -> usize {
            2
        }
        fn nout(&self) -> usize {
            2
        }
        fn nparams(&self) -> usize {
            2
        }
        fn statistics(&self) -> OpStatistics {
            self.stats.borrow().clone()
        }
    }

    impl NonLinearOp for ForwardingOp {
        fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
            self.stats.borrow_mut().increment_call();
            y.copy_from(x);
        }
    }

    impl NonLinearOpJacobian for ForwardingOp {
        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
            self.stats.borrow_mut().increment_jac_mul();
            y.copy_from(v);
        }
    }

    impl NonLinearOpAdjoint for ForwardingOp {
        fn jac_transpose_mul_inplace(
            &self,
            _x: &Self::V,
            _t: Self::T,
            v: &Self::V,
            y: &mut Self::V,
        ) {
            self.stats.borrow_mut().increment_jac_adj_mul();
            y.copy_from(v);
        }
    }

    impl NonLinearOpSens for ForwardingOp {
        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
            y.fill(0.0);
        }
    }

    impl NonLinearOpSensAdjoint for ForwardingOp {
        fn sens_transpose_mul_inplace(
            &self,
            _x: &Self::V,
            _t: Self::T,
            _v: &Self::V,
            y: &mut Self::V,
        ) {
            y.fill(0.0);
        }
    }

    impl LinearOp for ForwardingOp {
        fn gemv_inplace(&self, x: &Self::V, _t: Self::T, beta: Self::T, y: &mut Self::V) {
            self.stats.borrow_mut().increment_call();
            y.axpy(1.0, x, beta);
        }
    }

    impl LinearOpTranspose for ForwardingOp {
        fn gemv_transpose_inplace(&self, x: &Self::V, _t: Self::T, beta: Self::T, y: &mut Self::V) {
            self.stats.borrow_mut().increment_jac_adj_mul();
            y.axpy(1.0, x, beta);
        }
    }

    impl ConstantOp for ForwardingOp {
        fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
            self.stats.borrow_mut().increment_call();
            y.copy_from(&Self::V::from_vec(vec![1.0, 2.0], self.ctx));
        }
    }

    impl ConstantOpSens for ForwardingOp {
        fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
            y.fill(0.0);
        }
    }

    impl ConstantOpSensAdjoint for ForwardingOp {
        fn sens_transpose_mul_inplace(&self, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
            y.fill(0.0);
        }
    }

    #[test]
    fn op_statistics_increment_methods_update_counters() {
        let mut stats = OpStatistics::new();
        stats.increment_call();
        stats.increment_jac_mul();
        stats.increment_jac_adj_mul();
        stats.increment_matrix();
        assert_eq!(stats.number_of_calls, 1);
        assert_eq!(stats.number_of_jac_muls, 1);
        assert_eq!(stats.number_of_jac_adj_muls, 1);
        assert_eq!(stats.number_of_matrix_evals, 1);
    }

    #[test]
    fn parameterised_op_and_reference_forwarding_delegate_to_inner_operator() {
        let op = ForwardingOp::new();
        let p = crate::NalgebraVec::from_vec(vec![1.0, 2.0], NalgebraContext);
        let pop = ParameterisedOp::new(&op, &p);
        assert_eq!(pop.nstates(), 2);
        assert_eq!(pop.nout(), 2);
        assert_eq!(pop.nparams(), 2);

        // Sizes through the `Op` impl for `&C`.
        let op_ref = &op;
        assert_eq!(<&ForwardingOp as Op>::nstates(&op_ref), 2);
        assert_eq!(<&ForwardingOp as Op>::nout(&op_ref), 2);
        assert_eq!(<&ForwardingOp as Op>::nparams(&op_ref), 2);

        let x = crate::NalgebraVec::from_vec(vec![3.0, 4.0], NalgebraContext);
        let mut y = crate::NalgebraVec::zeros(2, NalgebraContext);

        // Every call below passes `&&op`, so `Self = &ForwardingOp`. That forces the blanket
        // `&C` forwarding impls to run. Calling the methods on `op` directly would resolve to
        // `ForwardingOp`'s own impls and bypass the code under test.
        NonLinearOp::call_inplace(&&op, &x, 0.0, &mut y);
        y.assert_eq_st(&x, 1e-12);

        NonLinearOpJacobian::jac_mul_inplace(&&op, &x, 0.0, &x, &mut y);
        y.assert_eq_st(&x, 1e-12);

        NonLinearOpAdjoint::jac_transpose_mul_inplace(&&op, &x, 0.0, &x, &mut y);
        y.assert_eq_st(&x, 1e-12);

        NonLinearOpSens::sens_mul_inplace(&&op, &x, 0.0, &x, &mut y);
        y.assert_eq_st(&crate::NalgebraVec::zeros(2, NalgebraContext), 1e-12);

        NonLinearOpSensAdjoint::sens_transpose_mul_inplace(&&op, &x, 0.0, &x, &mut y);
        y.assert_eq_st(&crate::NalgebraVec::zeros(2, NalgebraContext), 1e-12);

        LinearOp::gemv_inplace(&&op, &x, 0.0, 0.0, &mut y);
        y.assert_eq_st(&x, 1e-12);

        LinearOpTranspose::gemv_transpose_inplace(&&op, &x, 0.0, 0.0, &mut y);
        y.assert_eq_st(&x, 1e-12);

        let mut y_const = crate::NalgebraVec::zeros(2, NalgebraContext);
        <&ForwardingOp as ConstantOp>::call_inplace(&&op, 0.0, &mut y_const);
        y_const.assert_eq_st(
            &crate::NalgebraVec::from_vec(vec![1.0, 2.0], NalgebraContext),
            1e-12,
        );

        // Exact counts through the `ParameterisedOp` wrapper:
        // - calls: nonlinear call + gemv + constant call
        // - adjoint products: jac_transpose_mul + gemv_transpose
        let op_ref_stats = pop.statistics();
        assert_eq!(op_ref_stats.number_of_calls, 3);
        assert_eq!(op_ref_stats.number_of_jac_muls, 1);
        assert_eq!(op_ref_stats.number_of_jac_adj_muls, 2);
        assert_eq!(op_ref_stats.number_of_matrix_evals, 0);
        // The same snapshot through the `&C` forwarding impl.
        assert_eq!(<&ForwardingOp as Op>::statistics(&op_ref).number_of_calls, 3);

        // Exercise the `Op` impl for `&mut C` on a separate, untouched operator.
        let mut op_mut = ForwardingOp::new();
        let op_mut_ref = &mut op_mut;
        assert_eq!(<&mut ForwardingOp as Op>::nstates(&op_mut_ref), 2);
        assert_eq!(<&mut ForwardingOp as Op>::nout(&op_mut_ref), 2);
        assert_eq!(<&mut ForwardingOp as Op>::nparams(&op_mut_ref), 2);
        assert_eq!(
            <&mut ForwardingOp as Op>::statistics(&op_mut_ref).number_of_calls,
            0
        );
    }
}