1use std::cell::RefCell;
2
3use crate::{
4 find_matrix_non_zeros, find_transpose_non_zeros, jacobian::JacobianColoring,
5 matrix::sparsity::MatrixSparsity, LinearOp, LinearOpTranspose, Matrix, Op,
6};
7
8use super::{BuilderOp, OpStatistics, ParameterisedOp};
9
10pub struct LinearClosureWithAdjoint<M, F, G>
11where
12 M: Matrix,
13 F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
14 G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
15{
16 func: F,
17 func_adjoint: G,
18 nstates: usize,
19 nout: usize,
20 nparams: usize,
21 coloring: Option<JacobianColoring<M>>,
22 sparsity: Option<M::Sparsity>,
23 coloring_adjoint: Option<JacobianColoring<M>>,
24 sparsity_adjoint: Option<M::Sparsity>,
25 statistics: RefCell<OpStatistics>,
26 ctx: M::C,
27}
28
29impl<M, F, G> LinearClosureWithAdjoint<M, F, G>
30where
31 M: Matrix,
32 F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
33 G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
34{
35 pub fn new(
36 func: F,
37 func_adjoint: G,
38 nstates: usize,
39 nout: usize,
40 nparams: usize,
41 ctx: M::C,
42 ) -> Self {
43 Self {
44 func,
45 func_adjoint,
46 nstates,
47 statistics: RefCell::new(OpStatistics::default()),
48 nout,
49 nparams,
50 coloring: None,
51 sparsity: None,
52 coloring_adjoint: None,
53 sparsity_adjoint: None,
54 ctx,
55 }
56 }
57
58 pub fn calculate_sparsity(&mut self, t0: M::T, p: &M::V) {
59 let op = ParameterisedOp { op: self, p };
60 let non_zeros = find_matrix_non_zeros(&op, t0);
61 self.sparsity = Some(
62 MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone())
63 .expect("invalid sparsity pattern"),
64 );
65 self.coloring = Some(JacobianColoring::new(
66 self.sparsity.as_ref().unwrap(),
67 &non_zeros,
68 self.ctx.clone(),
69 ));
70 }
71 pub fn calculate_adjoint_sparsity(&mut self, t0: M::T, p: &M::V) {
72 let op = ParameterisedOp { op: self, p };
73 let non_zeros = find_transpose_non_zeros(&op, t0);
74 self.sparsity_adjoint = Some(
75 MatrixSparsity::try_from_indices(self.nstates, self.nout, non_zeros.clone())
76 .expect("invalid sparsity pattern"),
77 );
78 self.coloring_adjoint = Some(JacobianColoring::new(
79 self.sparsity_adjoint.as_ref().unwrap(),
80 &non_zeros,
81 self.ctx.clone(),
82 ));
83 }
84}
85
86impl<M, F, G> Op for LinearClosureWithAdjoint<M, F, G>
87where
88 M: Matrix,
89 F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
90 G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
91{
92 type V = M::V;
93 type T = M::T;
94 type M = M;
95 type C = M::C;
96 fn nstates(&self) -> usize {
97 self.nstates
98 }
99 fn nout(&self) -> usize {
100 self.nout
101 }
102 fn nparams(&self) -> usize {
103 self.nparams
104 }
105 fn context(&self) -> &Self::C {
106 &self.ctx
107 }
108
109 fn statistics(&self) -> OpStatistics {
110 self.statistics.borrow().clone()
111 }
112}
113
114impl<M, F, G> BuilderOp for LinearClosureWithAdjoint<M, F, G>
115where
116 M: Matrix,
117 F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
118 G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
119{
120 fn calculate_sparsity(&mut self, _y0: &Self::V, t0: Self::T, p: &Self::V) {
121 self.calculate_sparsity(t0, p);
122 self.calculate_adjoint_sparsity(t0, p);
123 }
124 fn set_nout(&mut self, nout: usize) {
125 self.nout = nout;
126 }
127 fn set_nparams(&mut self, nparams: usize) {
128 self.nparams = nparams;
129 }
130 fn set_nstates(&mut self, nstates: usize) {
131 self.nstates = nstates;
132 }
133}
134
135impl<M, F, G> LinearOp for ParameterisedOp<'_, LinearClosureWithAdjoint<M, F, G>>
136where
137 M: Matrix,
138 F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
139 G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
140{
141 fn gemv_inplace(&self, x: &M::V, t: M::T, beta: M::T, y: &mut M::V) {
142 self.op.statistics.borrow_mut().increment_call();
143 (self.op.func)(x, self.p, t, beta, y)
144 }
145
146 fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
147 self.op.statistics.borrow_mut().increment_matrix();
148 if let Some(coloring) = &self.op.coloring {
149 coloring.matrix_inplace(self, t, y);
150 } else {
151 self._default_matrix_inplace(t, y);
152 }
153 }
154 fn sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
155 self.op.sparsity.clone()
156 }
157}
158
159impl<M, F, G> LinearOpTranspose for ParameterisedOp<'_, LinearClosureWithAdjoint<M, F, G>>
160where
161 M: Matrix,
162 F: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
163 G: Fn(&M::V, &M::V, M::T, M::T, &mut M::V),
164{
165 fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
166 (self.op.func_adjoint)(x, self.p, t, beta, y)
167 }
168 fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) {
169 if let Some(coloring) = &self.op.coloring_adjoint {
170 coloring.matrix_inplace(self, t, y);
171 } else {
172 self._default_transpose_inplace(t, y);
173 }
174 }
175
176 fn transpose_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
177 self.op.sparsity_adjoint.clone()
178 }
179}
180
181#[cfg(test)]
182mod tests {
183 use crate::{
184 context::nalgebra::NalgebraContext, matrix::dense_nalgebra_serial::NalgebraMat,
185 matrix::Matrix, DenseMatrix, LinearOp, LinearOpTranspose, Op, Vector,
186 };
187
188 use super::{super::BuilderOp, LinearClosureWithAdjoint};
189
190 type M = NalgebraMat<f64>;
191 type V = crate::NalgebraVec<f64>;
192
193 fn forward(x: &V, p: &V, _t: f64, beta: f64, y: &mut V) {
194 let out = V::from_vec(
195 vec![
196 p.get_index(0) * x.get_index(0),
197 x.get_index(0) + p.get_index(1) * x.get_index(1),
198 ],
199 NalgebraContext,
200 );
201 y.axpy(1.0, &out, beta);
202 }
203
204 fn adjoint(x: &V, p: &V, _t: f64, beta: f64, y: &mut V) {
205 let out = V::from_vec(
206 vec![
207 p.get_index(0) * x.get_index(0) + x.get_index(1),
208 p.get_index(1) * x.get_index(1),
209 ],
210 NalgebraContext,
211 );
212 y.axpy(1.0, &out, beta);
213 }
214
215 type TestFn = fn(&V, &V, f64, f64, &mut V);
216
217 fn make_op() -> LinearClosureWithAdjoint<M, TestFn, TestFn> {
218 LinearClosureWithAdjoint::new(forward, adjoint, 2, 2, 2, NalgebraContext)
219 }
220
221 #[test]
222 fn linear_closure_with_adjoint_builds_matrices_and_tracks_statistics() {
223 let mut op = make_op();
224 op.set_nstates(2);
225 op.set_nout(2);
226 op.set_nparams(2);
227
228 let y0 = V::from_vec(vec![1.0, 1.0], NalgebraContext);
229 let p = V::from_vec(vec![2.0, 3.0], NalgebraContext);
230 BuilderOp::calculate_sparsity(&mut op, &y0, 0.0, &p);
231
232 assert_eq!(op.nstates(), 2);
233 assert_eq!(op.nout(), 2);
234 assert_eq!(op.nparams(), 2);
235
236 let pop = crate::ParameterisedOp::new(&op, &p);
237 let matrix = pop.matrix(0.0);
238 assert_eq!(matrix.get_index(0, 0), 2.0);
239 assert_eq!(matrix.get_index(1, 0), 1.0);
240 assert_eq!(matrix.get_index(0, 1), 0.0);
241 assert_eq!(matrix.get_index(1, 1), 3.0);
242 assert!(pop.sparsity().is_some());
243
244 let mut transpose = M::zeros(2, 2, NalgebraContext);
245 pop.transpose_inplace(0.0, &mut transpose);
246 assert_eq!(transpose.get_index(0, 0), 2.0);
247 assert_eq!(transpose.get_index(1, 0), 0.0);
248 assert_eq!(transpose.get_index(0, 1), 0.0);
249 assert_eq!(transpose.get_index(1, 1), 3.0);
250 assert!(pop.transpose_sparsity().is_some());
251
252 let x = V::from_vec(vec![4.0, 5.0], NalgebraContext);
253 let mut y = V::from_vec(vec![1.0, 1.0], NalgebraContext);
254 pop.gemv_inplace(&x, 0.0, 0.5, &mut y);
255 y.assert_eq_st(&V::from_vec(vec![8.5, 19.5], NalgebraContext), 1e-12);
256
257 let stats = pop.statistics();
258 assert!(stats.number_of_calls >= 1);
259 assert!(stats.number_of_matrix_evals >= 1);
260 }
261}