Skip to main content

diffsol/op/
linear_closure_with_adjoint.rs

1use std::cell::RefCell;
2
3use crate::{
4    find_matrix_non_zeros, find_transpose_non_zeros, jacobian::JacobianColoring,
5    matrix::sparsity::MatrixSparsity, LinearOp, LinearOpTranspose, Matrix, Op,
6};
7
8use super::{BuilderOp, OpStatistics, ParameterisedOp};
9
10pub struct LinearClosureWithAdjoint<M, F, G>
11where
12    M: Matrix,
13    F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
14    G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
15{
16    func: F,
17    func_adjoint: G,
18    nstates: usize,
19    nout: usize,
20    nparams: usize,
21    coloring: Option<JacobianColoring<M>>,
22    sparsity: Option<M::Sparsity>,
23    coloring_adjoint: Option<JacobianColoring<M>>,
24    sparsity_adjoint: Option<M::Sparsity>,
25    statistics: RefCell<OpStatistics>,
26    ctx: M::C,
27}
28
29impl<M, F, G> LinearClosureWithAdjoint<M, F, G>
30where
31    M: Matrix,
32    F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
33    G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
34{
35    pub fn new(
36        func: F,
37        func_adjoint: G,
38        nstates: usize,
39        nout: usize,
40        nparams: usize,
41        ctx: M::C,
42    ) -> Self {
43        Self {
44            func,
45            func_adjoint,
46            nstates,
47            statistics: RefCell::new(OpStatistics::default()),
48            nout,
49            nparams,
50            coloring: None,
51            sparsity: None,
52            coloring_adjoint: None,
53            sparsity_adjoint: None,
54            ctx,
55        }
56    }
57
58    pub fn calculate_sparsity(&mut self, t0: M::T, p: &M::V) {
59        let op = ParameterisedOp { op: self, p };
60        let non_zeros = find_matrix_non_zeros(&op, t0);
61        self.sparsity = Some(
62            MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone())
63                .expect("invalid sparsity pattern"),
64        );
65        self.coloring = Some(JacobianColoring::new(
66            self.sparsity.as_ref().unwrap(),
67            &non_zeros,
68            self.ctx.clone(),
69        ));
70    }
71    pub fn calculate_adjoint_sparsity(&mut self, t0: M::T, p: &M::V) {
72        let op = ParameterisedOp { op: self, p };
73        let non_zeros = find_transpose_non_zeros(&op, t0);
74        self.sparsity_adjoint = Some(
75            MatrixSparsity::try_from_indices(self.nstates, self.nout, non_zeros.clone())
76                .expect("invalid sparsity pattern"),
77        );
78        self.coloring_adjoint = Some(JacobianColoring::new(
79            self.sparsity_adjoint.as_ref().unwrap(),
80            &non_zeros,
81            self.ctx.clone(),
82        ));
83    }
84}
85
86impl<M, F, G> Op for LinearClosureWithAdjoint<M, F, G>
87where
88    M: Matrix,
89    F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
90    G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
91{
92    type V = M::V;
93    type T = M::T;
94    type M = M;
95    type C = M::C;
96    fn nstates(&self) -> usize {
97        self.nstates
98    }
99    fn nout(&self) -> usize {
100        self.nout
101    }
102    fn nparams(&self) -> usize {
103        self.nparams
104    }
105    fn context(&self) -> &Self::C {
106        &self.ctx
107    }
108
109    fn statistics(&self) -> OpStatistics {
110        self.statistics.borrow().clone()
111    }
112}
113
114impl<M, F, G> BuilderOp for LinearClosureWithAdjoint<M, F, G>
115where
116    M: Matrix,
117    F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
118    G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
119{
120    fn calculate_sparsity(&mut self, _y0: &Self::V, t0: Self::T, p: &Self::V) {
121        self.calculate_sparsity(t0, p);
122        self.calculate_adjoint_sparsity(t0, p);
123    }
124    fn set_nout(&mut self, nout: usize) {
125        self.nout = nout;
126    }
127    fn set_nparams(&mut self, nparams: usize) {
128        self.nparams = nparams;
129    }
130    fn set_nstates(&mut self, nstates: usize) {
131        self.nstates = nstates;
132    }
133}
134
135impl<M, F, G> LinearOp for ParameterisedOp<'_, LinearClosureWithAdjoint<M, F, G>>
136where
137    M: Matrix,
138    F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
139    G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
140{
141    fn gemv_inplace(&self, x: &M::V, t: M::T, beta: M::T, y: &mut M::V) {
142        self.op.statistics.borrow_mut().increment_call();
143        (self.op.func)(x, self.p, t, beta, y)
144    }
145
146    fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
147        self.op.statistics.borrow_mut().increment_matrix();
148        if let Some(coloring) = &self.op.coloring {
149            coloring.matrix_inplace(self, t, y);
150        } else {
151            self._default_matrix_inplace(t, y);
152        }
153    }
154    fn sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
155        self.op.sparsity.clone()
156    }
157}
158
159impl<M, F, G> LinearOpTranspose for ParameterisedOp<'_, LinearClosureWithAdjoint<M, F, G>>
160where
161    M: Matrix,
162    F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
163    G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
164{
165    fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
166        (self.op.func_adjoint)(x, self.p, t, beta, y)
167    }
168    fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) {
169        if let Some(coloring) = &self.op.coloring_adjoint {
170            coloring.matrix_inplace(self, t, y);
171        } else {
172            self._default_transpose_inplace(t, y);
173        }
174    }
175
176    fn transpose_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
177        self.op.sparsity_adjoint.clone()
178    }
179}
180
181#[cfg(test)]
182mod tests {
183    use crate::{
184        context::nalgebra::NalgebraContext, matrix::dense_nalgebra_serial::NalgebraMat,
185        matrix::Matrix, DenseMatrix, LinearOp, LinearOpTranspose, Op, Vector,
186    };
187
188    use super::{super::BuilderOp, LinearClosureWithAdjoint};
189
190    type M = NalgebraMat<f64>;
191    type V = crate::NalgebraVec<f64>;
192
193    fn forward(x: &V, p: &V, _t: f64, beta: f64, y: &mut V) {
194        let out = V::from_vec(
195            vec![
196                p.get_index(0) * x.get_index(0),
197                x.get_index(0) + p.get_index(1) * x.get_index(1),
198            ],
199            NalgebraContext,
200        );
201        y.axpy(1.0, &out, beta);
202    }
203
204    fn adjoint(x: &V, p: &V, _t: f64, beta: f64, y: &mut V) {
205        let out = V::from_vec(
206            vec![
207                p.get_index(0) * x.get_index(0) + x.get_index(1),
208                p.get_index(1) * x.get_index(1),
209            ],
210            NalgebraContext,
211        );
212        y.axpy(1.0, &out, beta);
213    }
214
215    type TestFn = fn(&V, &V, f64, f64, &mut V);
216
217    fn make_op() -> LinearClosureWithAdjoint<M, TestFn, TestFn> {
218        LinearClosureWithAdjoint::new(forward, adjoint, 2, 2, 2, NalgebraContext)
219    }
220
221    #[test]
222    fn linear_closure_with_adjoint_builds_matrices_and_tracks_statistics() {
223        let mut op = make_op();
224        op.set_nstates(2);
225        op.set_nout(2);
226        op.set_nparams(2);
227
228        let y0 = V::from_vec(vec![1.0, 1.0], NalgebraContext);
229        let p = V::from_vec(vec![2.0, 3.0], NalgebraContext);
230        BuilderOp::calculate_sparsity(&mut op, &y0, 0.0, &p);
231
232        assert_eq!(op.nstates(), 2);
233        assert_eq!(op.nout(), 2);
234        assert_eq!(op.nparams(), 2);
235
236        let pop = crate::ParameterisedOp::new(&op, &p);
237        let matrix = pop.matrix(0.0);
238        assert_eq!(matrix.get_index(0, 0), 2.0);
239        assert_eq!(matrix.get_index(1, 0), 1.0);
240        assert_eq!(matrix.get_index(0, 1), 0.0);
241        assert_eq!(matrix.get_index(1, 1), 3.0);
242        assert!(pop.sparsity().is_some());
243
244        let mut transpose = M::zeros(2, 2, NalgebraContext);
245        pop.transpose_inplace(0.0, &mut transpose);
246        assert_eq!(transpose.get_index(0, 0), 2.0);
247        assert_eq!(transpose.get_index(1, 0), 0.0);
248        assert_eq!(transpose.get_index(0, 1), 0.0);
249        assert_eq!(transpose.get_index(1, 1), 3.0);
250        assert!(pop.transpose_sparsity().is_some());
251
252        let x = V::from_vec(vec![4.0, 5.0], NalgebraContext);
253        let mut y = V::from_vec(vec![1.0, 1.0], NalgebraContext);
254        pop.gemv_inplace(&x, 0.0, 0.5, &mut y);
255        y.assert_eq_st(&V::from_vec(vec![8.5, 19.5], NalgebraContext), 1e-12);
256
257        let stats = pop.statistics();
258        assert!(stats.number_of_calls >= 1);
259        assert!(stats.number_of_matrix_evals >= 1);
260    }
261}