1use crate::{
2 ConstantOp, ConstantOpSens, ConstantOpSensAdjoint, Context, LinearOp, LinearOpTranspose,
3 Matrix, NonLinearOp, NonLinearOpAdjoint, NonLinearOpSens, NonLinearOpSensAdjoint, Scalar,
4 Vector,
5};
6
7use nonlinear_op::NonLinearOpJacobian;
8use serde::Serialize;
9
10pub mod bdf;
11pub mod closure;
12pub mod closure_no_jac;
13pub mod closure_with_adjoint;
14pub mod closure_with_sens;
15pub mod constant_closure;
16pub mod constant_closure_with_adjoint;
17pub mod constant_closure_with_sens;
18pub mod constant_op;
19pub mod init;
20pub mod linear_closure;
21pub mod linear_closure_with_adjoint;
22pub mod linear_op;
23pub mod linearise;
24pub mod matrix;
25pub mod nonlinear_op;
26pub mod sdirk;
27pub mod stoch;
28pub mod unit;
29
30pub trait Op {
36 type T: Scalar;
37 type V: Vector<T = Self::T, C = Self::C>;
38 type M: Matrix<T = Self::T, V = Self::V, C = Self::C>;
39 type C: Context;
40
41 fn context(&self) -> &Self::C;
43
44 fn nstates(&self) -> usize;
46
47 fn nout(&self) -> usize;
49
50 fn nparams(&self) -> usize;
52
53 fn statistics(&self) -> OpStatistics {
55 OpStatistics::default()
56 }
57}
58
59pub struct ParameterisedOp<'a, C: Op> {
61 pub op: &'a C,
62 pub p: &'a C::V,
63}
64
65impl<'a, C: Op> ParameterisedOp<'a, C> {
66 pub fn new(op: &'a C, p: &'a C::V) -> Self {
67 Self { op, p }
68 }
69}
70
71pub trait BuilderOp: Op {
73 fn set_nstates(&mut self, nstates: usize);
74 fn set_nparams(&mut self, nparams: usize);
75 fn set_nout(&mut self, nout: usize);
76 fn calculate_sparsity(&mut self, y0: &Self::V, t0: Self::T, p: &Self::V);
77}
78
79impl<C: Op> Op for ParameterisedOp<'_, C> {
80 type V = C::V;
81 type T = C::T;
82 type M = C::M;
83 type C = C::C;
84 fn nstates(&self) -> usize {
85 self.op.nstates()
86 }
87 fn nout(&self) -> usize {
88 self.op.nout()
89 }
90 fn nparams(&self) -> usize {
91 self.op.nparams()
92 }
93 fn statistics(&self) -> OpStatistics {
94 self.op.statistics()
95 }
96 fn context(&self) -> &Self::C {
97 self.op.context()
98 }
99}
100
101#[derive(Default, Clone, Serialize, Debug)]
103pub struct OpStatistics {
104 pub number_of_calls: usize,
106 pub number_of_jac_muls: usize,
108 pub number_of_matrix_evals: usize,
110 pub number_of_jac_adj_muls: usize,
112}
113
114impl OpStatistics {
115 pub fn new() -> Self {
116 Self {
117 number_of_jac_muls: 0,
118 number_of_calls: 0,
119 number_of_matrix_evals: 0,
120 number_of_jac_adj_muls: 0,
121 }
122 }
123
124 pub fn increment_call(&mut self) {
125 self.number_of_calls += 1;
126 }
127
128 pub fn increment_jac_mul(&mut self) {
129 self.number_of_jac_muls += 1;
130 }
131
132 pub fn increment_jac_adj_mul(&mut self) {
133 self.number_of_jac_adj_muls += 1;
134 }
135
136 pub fn increment_matrix(&mut self) {
137 self.number_of_matrix_evals += 1;
138 }
139}
140
141impl<C: Op> Op for &C {
142 type T = C::T;
143 type V = C::V;
144 type M = C::M;
145 type C = C::C;
146 fn nstates(&self) -> usize {
147 C::nstates(*self)
148 }
149 fn nout(&self) -> usize {
150 C::nout(*self)
151 }
152 fn nparams(&self) -> usize {
153 C::nparams(*self)
154 }
155 fn statistics(&self) -> OpStatistics {
156 C::statistics(*self)
157 }
158 fn context(&self) -> &Self::C {
159 C::context(*self)
160 }
161}
162
163impl<C: Op> Op for &mut C {
164 type T = C::T;
165 type V = C::V;
166 type M = C::M;
167 type C = C::C;
168 fn nstates(&self) -> usize {
169 C::nstates(*self)
170 }
171 fn nout(&self) -> usize {
172 C::nout(*self)
173 }
174 fn nparams(&self) -> usize {
175 C::nparams(*self)
176 }
177 fn statistics(&self) -> OpStatistics {
178 C::statistics(*self)
179 }
180 fn context(&self) -> &Self::C {
181 C::context(*self)
182 }
183}
184
185impl<C: NonLinearOp> NonLinearOp for &C {
186 fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
187 C::call_inplace(*self, x, t, y)
188 }
189}
190
191impl<C: NonLinearOpJacobian> NonLinearOpJacobian for &C {
192 fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
193 C::jac_mul_inplace(*self, x, t, v, y)
194 }
195 fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
196 C::jacobian_inplace(*self, x, t, y)
197 }
198 fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
199 C::jacobian_sparsity(*self)
200 }
201}
202
203impl<C: NonLinearOpAdjoint> NonLinearOpAdjoint for &C {
204 fn adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
205 C::adjoint_inplace(*self, x, t, y)
206 }
207 fn adjoint_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
208 C::adjoint_sparsity(*self)
209 }
210 fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
211 C::jac_transpose_mul_inplace(*self, x, t, v, y)
212 }
213}
214
215impl<C: NonLinearOpSens> NonLinearOpSens for &C {
216 fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
217 C::sens_mul_inplace(*self, x, t, v, y)
218 }
219 fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
220 C::sens_inplace(*self, x, t, y)
221 }
222
223 fn sens_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
224 C::sens_sparsity(*self)
225 }
226}
227
228impl<C: NonLinearOpSensAdjoint> NonLinearOpSensAdjoint for &C {
229 fn sens_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
230 C::sens_transpose_mul_inplace(*self, x, t, v, y)
231 }
232 fn sens_adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
233 C::sens_adjoint_inplace(*self, x, t, y)
234 }
235 fn sens_adjoint_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
236 C::sens_adjoint_sparsity(*self)
237 }
238}
239
240impl<C: LinearOp> LinearOp for &C {
241 fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
242 C::gemv_inplace(*self, x, t, beta, y)
243 }
244 fn sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
245 C::sparsity(*self)
246 }
247 fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
248 C::matrix_inplace(*self, t, y)
249 }
250}
251
252impl<C: LinearOpTranspose> LinearOpTranspose for &C {
253 fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
254 C::gemv_transpose_inplace(*self, x, t, beta, y)
255 }
256 fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) {
257 C::transpose_inplace(*self, t, y)
258 }
259 fn transpose_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
260 C::transpose_sparsity(*self)
261 }
262}
263
264impl<C: ConstantOp> ConstantOp for &C {
265 fn call_inplace(&self, t: Self::T, y: &mut Self::V) {
266 C::call_inplace(*self, t, y)
267 }
268}
269
270impl<C: ConstantOpSens> ConstantOpSens for &C {
271 fn sens_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) {
272 C::sens_mul_inplace(*self, t, v, y)
273 }
274 fn sens_inplace(&self, t: Self::T, y: &mut Self::M) {
275 C::sens_inplace(*self, t, y)
276 }
277 fn sens_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
278 C::sens_sparsity(*self)
279 }
280}
281
282impl<C: ConstantOpSensAdjoint> ConstantOpSensAdjoint for &C {
283 fn sens_transpose_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) {
284 C::sens_transpose_mul_inplace(*self, t, v, y)
285 }
286 fn sens_adjoint_inplace(&self, t: Self::T, y: &mut Self::M) {
287 C::sens_adjoint_inplace(*self, t, y)
288 }
289 fn sens_adjoint_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
290 C::sens_adjoint_sparsity(*self)
291 }
292}
293
#[cfg(test)]
mod tests {
    use std::cell::RefCell;

    use crate::{
        context::nalgebra::NalgebraContext, matrix::dense_nalgebra_serial::NalgebraMat, ConstantOp,
        ConstantOpSens, ConstantOpSensAdjoint, LinearOp, LinearOpTranspose, NonLinearOp,
        NonLinearOpAdjoint, NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Vector,
    };

    use super::{Op, OpStatistics, ParameterisedOp};

    type M = NalgebraMat<f64>;

    /// Test operator that behaves like the identity and counts its evaluations.
    struct ForwardingOp {
        ctx: NalgebraContext,
        stats: RefCell<OpStatistics>,
    }

    impl ForwardingOp {
        fn new() -> Self {
            Self {
                ctx: NalgebraContext,
                stats: RefCell::new(OpStatistics::new()),
            }
        }
    }

    impl Op for ForwardingOp {
        type T = f64;
        type V = crate::NalgebraVec<f64>;
        type M = M;
        type C = NalgebraContext;

        fn context(&self) -> &Self::C {
            &self.ctx
        }
        fn nstates(&self) -> usize {
            2
        }
        fn nout(&self) -> usize {
            2
        }
        fn nparams(&self) -> usize {
            2
        }
        fn statistics(&self) -> OpStatistics {
            self.stats.borrow().clone()
        }
    }

    impl NonLinearOp for ForwardingOp {
        fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
            self.stats.borrow_mut().increment_call();
            y.copy_from(x);
        }
    }

    impl NonLinearOpJacobian for ForwardingOp {
        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
            self.stats.borrow_mut().increment_jac_mul();
            y.copy_from(v);
        }
    }

    impl NonLinearOpAdjoint for ForwardingOp {
        fn jac_transpose_mul_inplace(
            &self,
            _x: &Self::V,
            _t: Self::T,
            v: &Self::V,
            y: &mut Self::V,
        ) {
            self.stats.borrow_mut().increment_jac_adj_mul();
            y.copy_from(v);
        }
    }

    impl NonLinearOpSens for ForwardingOp {
        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
            y.fill(0.0);
        }
    }

    impl NonLinearOpSensAdjoint for ForwardingOp {
        fn sens_transpose_mul_inplace(
            &self,
            _x: &Self::V,
            _t: Self::T,
            _v: &Self::V,
            y: &mut Self::V,
        ) {
            y.fill(0.0);
        }
    }

    impl LinearOp for ForwardingOp {
        fn gemv_inplace(&self, x: &Self::V, _t: Self::T, beta: Self::T, y: &mut Self::V) {
            self.stats.borrow_mut().increment_call();
            y.axpy(1.0, x, beta);
        }
    }

    impl LinearOpTranspose for ForwardingOp {
        fn gemv_transpose_inplace(&self, x: &Self::V, _t: Self::T, beta: Self::T, y: &mut Self::V) {
            self.stats.borrow_mut().increment_jac_adj_mul();
            y.axpy(1.0, x, beta);
        }
    }

    impl ConstantOp for ForwardingOp {
        fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
            self.stats.borrow_mut().increment_call();
            y.copy_from(&Self::V::from_vec(vec![1.0, 2.0], self.ctx));
        }
    }

    impl ConstantOpSens for ForwardingOp {
        fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
            y.fill(0.0);
        }
    }

    impl ConstantOpSensAdjoint for ForwardingOp {
        fn sens_transpose_mul_inplace(&self, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
            y.fill(0.0);
        }
    }

    #[test]
    fn op_statistics_increment_methods_update_counters() {
        let mut stats = OpStatistics::new();
        stats.increment_call();
        stats.increment_jac_mul();
        stats.increment_jac_adj_mul();
        stats.increment_matrix();
        assert_eq!(stats.number_of_calls, 1);
        assert_eq!(stats.number_of_jac_muls, 1);
        assert_eq!(stats.number_of_jac_adj_muls, 1);
        assert_eq!(stats.number_of_matrix_evals, 1);
    }

    #[test]
    fn parameterised_op_and_reference_forwarding_delegate_to_inner_operator() {
        let op = ForwardingOp::new();
        let p = crate::NalgebraVec::from_vec(vec![1.0, 2.0], NalgebraContext);
        let pop = ParameterisedOp::new(&op, &p);
        assert_eq!(pop.nstates(), 2);
        assert_eq!(pop.nout(), 2);
        assert_eq!(pop.nparams(), 2);
        let _pop_ctx: &NalgebraContext = pop.context();

        // Every call below is made with trait-path syntax on `&op_ref` (a `&&ForwardingOp`),
        // which fixes `Self = &ForwardingOp` and therefore exercises the `impl ... for &C`
        // forwarding impls. Plain method syntax such as `op.jac_mul_inplace(..)` would resolve
        // to the impl on `ForwardingOp` itself and leave the forwarding code untested.
        let op_ref = &op;
        let x = crate::NalgebraVec::from_vec(vec![3.0, 4.0], NalgebraContext);
        let zeros: crate::NalgebraVec<f64> = crate::NalgebraVec::zeros(2, NalgebraContext);
        let mut y = crate::NalgebraVec::zeros(2, NalgebraContext);

        assert_eq!(Op::nstates(&op_ref), 2);
        assert_eq!(Op::nout(&op_ref), 2);
        assert_eq!(Op::nparams(&op_ref), 2);

        NonLinearOp::call_inplace(&op_ref, &x, 0.0, &mut y);
        y.assert_eq_st(&x, 1e-12);

        NonLinearOpJacobian::jac_mul_inplace(&op_ref, &x, 0.0, &x, &mut y);
        y.assert_eq_st(&x, 1e-12);

        NonLinearOpAdjoint::jac_transpose_mul_inplace(&op_ref, &x, 0.0, &x, &mut y);
        y.assert_eq_st(&x, 1e-12);

        NonLinearOpSens::sens_mul_inplace(&op_ref, &x, 0.0, &x, &mut y);
        y.assert_eq_st(&zeros, 1e-12);

        y.copy_from(&x);
        NonLinearOpSensAdjoint::sens_transpose_mul_inplace(&op_ref, &x, 0.0, &x, &mut y);
        y.assert_eq_st(&zeros, 1e-12);

        LinearOp::gemv_inplace(&op_ref, &x, 0.0, 0.0, &mut y);
        y.assert_eq_st(&x, 1e-12);

        LinearOpTranspose::gemv_transpose_inplace(&op_ref, &x, 0.0, 0.0, &mut y);
        y.assert_eq_st(&x, 1e-12);

        let mut y_const = crate::NalgebraVec::zeros(2, NalgebraContext);
        ConstantOp::call_inplace(&op_ref, 0.0, &mut y_const);
        y_const.assert_eq_st(
            &crate::NalgebraVec::from_vec(vec![1.0, 2.0], NalgebraContext),
            1e-12,
        );

        ConstantOpSens::sens_mul_inplace(&op_ref, 0.0, &x, &mut y_const);
        y_const.assert_eq_st(&zeros, 1e-12);

        y_const.copy_from(&x);
        ConstantOpSensAdjoint::sens_transpose_mul_inplace(&op_ref, 0.0, &x, &mut y_const);
        y_const.assert_eq_st(&zeros, 1e-12);

        // calls: NonLinearOp::call_inplace + gemv_inplace + ConstantOp::call_inplace;
        // jac muls: jac_mul_inplace; adjoint muls: jac_transpose_mul + gemv_transpose.
        let stats = pop.statistics();
        assert_eq!(stats.number_of_calls, 3);
        assert_eq!(stats.number_of_jac_muls, 1);
        assert_eq!(stats.number_of_jac_adj_muls, 2);
        assert_eq!(stats.number_of_matrix_evals, 0);

        let via_ref = Op::statistics(&op_ref);
        assert_eq!(via_ref.number_of_calls, stats.number_of_calls);
        assert_eq!(via_ref.number_of_jac_muls, stats.number_of_jac_muls);
        assert_eq!(via_ref.number_of_jac_adj_muls, stats.number_of_jac_adj_muls);
        assert_eq!(via_ref.number_of_matrix_evals, stats.number_of_matrix_evals);
    }

    #[test]
    fn mutable_reference_forwards_op_queries() {
        let mut op = ForwardingOp::new();
        op.stats.borrow_mut().increment_matrix();

        // Fully-qualified calls pin `Self = &mut ForwardingOp` so the `impl Op for &mut C`
        // forwarding impl is the code under test.
        let op_mut = &mut op;
        assert_eq!(<&mut ForwardingOp as Op>::nstates(&op_mut), 2);
        assert_eq!(<&mut ForwardingOp as Op>::nout(&op_mut), 2);
        assert_eq!(<&mut ForwardingOp as Op>::nparams(&op_mut), 2);
        assert_eq!(
            <&mut ForwardingOp as Op>::statistics(&op_mut).number_of_matrix_evals,
            1
        );
        let _ctx: &NalgebraContext = <&mut ForwardingOp as Op>::context(&op_mut);
    }
}