diffsol/ode_equations/
adjoint_equations.rs

1use num_traits::{One, Zero};
2use std::{
3    cell::RefCell,
4    ops::{AddAssign, SubAssign},
5    rc::Rc,
6};
7
8use crate::{
9    error::DiffsolError, op::nonlinear_op::NonLinearOpJacobian, AugmentedOdeEquations,
10    Checkpointing, ConstantOp, ConstantOpSensAdjoint, LinearOp, LinearOpTranspose, Matrix,
11    NonLinearOp, NonLinearOpAdjoint, NonLinearOpSensAdjoint, OdeEquations, OdeEquationsAdjoint,
12    OdeEquationsRef, OdeSolverMethod, OdeSolverProblem, Op, Vector,
13};
14
15pub struct AdjointContext<'a, Eqn, Method>
16where
17    Eqn: OdeEquations,
18    Method: OdeSolverMethod<'a, Eqn>,
19{
20    checkpointer: Checkpointing<'a, Eqn, Method>,
21    x: Eqn::V,
22    index: usize,
23    max_index: usize,
24    last_t: Option<Eqn::T>,
25    col: Eqn::V,
26}
27
28impl<'a, Eqn, Method> AdjointContext<'a, Eqn, Method>
29where
30    Eqn: OdeEquations,
31    Method: OdeSolverMethod<'a, Eqn>,
32{
33    pub fn new(checkpointer: Checkpointing<'a, Eqn, Method>, max_index: usize) -> Self {
34        let ctx = checkpointer.problem().eqn.context();
35        let x = <Eqn::V as Vector>::zeros(checkpointer.problem().eqn.rhs().nstates(), ctx.clone());
36        let mut col = <Eqn::V as Vector>::zeros(max_index, ctx.clone());
37        let index = 0;
38        col.set_index(0, Eqn::T::one());
39        Self {
40            checkpointer,
41            x,
42            index,
43            max_index,
44            col,
45            last_t: None,
46        }
47    }
48
49    pub fn set_state(&mut self, t: Eqn::T) {
50        if let Some(last_t) = self.last_t {
51            if last_t == t {
52                return;
53            }
54        }
55        self.last_t = Some(t);
56        self.checkpointer.interpolate(t, &mut self.x).unwrap();
57        // for diffsl, we need to set data for the adjoint state!
58        // basically just involves calling the normal rhs function with the new self.x
59        // todo: this seems a bit hacky, perhaps a dedicated function on the trait for this?
60        self.checkpointer.problem().eqn.rhs().call(&self.x, t);
61    }
62
63    pub fn state(&self) -> &Eqn::V {
64        &self.x
65    }
66
67    pub fn col(&self) -> &Eqn::V {
68        &self.col
69    }
70
71    pub fn set_index(&mut self, index: usize) {
72        self.col.set_index(self.index, Eqn::T::zero());
73        self.index = index;
74        self.col.set_index(self.index, Eqn::T::one());
75    }
76}
77
78pub struct AdjointMass<'a, Eqn>
79where
80    Eqn: OdeEquations,
81{
82    eqn: &'a Eqn,
83}
84
85impl<'a, Eqn> AdjointMass<'a, Eqn>
86where
87    Eqn: OdeEquations,
88{
89    pub fn new(eqn: &'a Eqn) -> Self {
90        Self { eqn }
91    }
92}
93
94impl<Eqn> Op for AdjointMass<'_, Eqn>
95where
96    Eqn: OdeEquations,
97{
98    type T = Eqn::T;
99    type V = Eqn::V;
100    type M = Eqn::M;
101    type C = Eqn::C;
102
103    fn nstates(&self) -> usize {
104        self.eqn.rhs().nstates()
105    }
106    fn nout(&self) -> usize {
107        self.eqn.rhs().nstates()
108    }
109    fn nparams(&self) -> usize {
110        self.eqn.rhs().nparams()
111    }
112    fn context(&self) -> &Self::C {
113        self.eqn.context()
114    }
115}
116
117impl<Eqn> LinearOp for AdjointMass<'_, Eqn>
118where
119    Eqn: OdeEquationsAdjoint,
120{
121    fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
122        self.eqn
123            .mass()
124            .unwrap()
125            .gemv_transpose_inplace(x, t, beta, y);
126    }
127
128    fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
129        self.eqn.mass().unwrap().transpose_inplace(t, y);
130    }
131}
132
133pub struct AdjointInit<'a, Eqn>
134where
135    Eqn: OdeEquations,
136{
137    eqn: &'a Eqn,
138}
139
140impl<'a, Eqn> AdjointInit<'a, Eqn>
141where
142    Eqn: OdeEquations,
143{
144    pub fn new(eqn: &'a Eqn) -> Self {
145        Self { eqn }
146    }
147}
148
149impl<Eqn> Op for AdjointInit<'_, Eqn>
150where
151    Eqn: OdeEquations,
152{
153    type T = Eqn::T;
154    type V = Eqn::V;
155    type M = Eqn::M;
156    type C = Eqn::C;
157
158    fn nstates(&self) -> usize {
159        self.eqn.rhs().nstates()
160    }
161    fn nout(&self) -> usize {
162        self.eqn.rhs().nstates()
163    }
164    fn nparams(&self) -> usize {
165        self.eqn.rhs().nparams()
166    }
167    fn context(&self) -> &Self::C {
168        self.eqn.context()
169    }
170}
171
172impl<Eqn> ConstantOp for AdjointInit<'_, Eqn>
173where
174    Eqn: OdeEquations,
175{
176    fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
177        y.fill(Eqn::T::zero());
178    }
179}
180
181/// Right-hand side of the adjoint equations is:
182///
183/// F(λ, x, t) = -f^T_x(x, t) λ - g^T_x(x,t)
184///
185/// f_x is the partial derivative of the right-hand side with respect to the state vector.
186/// g_x is the partial derivative of the functional g with respect to the state vector.
187///
188/// We need the current state x(t), which is obtained from the checkpointed forward solve at the current time step.
189pub struct AdjointRhs<'a, Eqn, Method>
190where
191    Eqn: OdeEquations,
192    Method: OdeSolverMethod<'a, Eqn>,
193{
194    eqn: &'a Eqn,
195    context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
196    tmp: RefCell<Eqn::V>,
197    with_out: bool,
198}
199
200impl<'a, Eqn, Method> AdjointRhs<'a, Eqn, Method>
201where
202    Eqn: OdeEquations,
203    Method: OdeSolverMethod<'a, Eqn>,
204{
205    pub fn new(
206        eqn: &'a Eqn,
207        context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
208        with_out: bool,
209    ) -> Self {
210        let tmp_n = if with_out { eqn.rhs().nstates() } else { 0 };
211        let tmp = RefCell::new(<Eqn::V as Vector>::zeros(tmp_n, eqn.context().clone()));
212        Self {
213            eqn,
214            context,
215            tmp,
216            with_out,
217        }
218    }
219}
220
221impl<'a, Eqn, Method> Op for AdjointRhs<'a, Eqn, Method>
222where
223    Eqn: OdeEquations,
224    Method: OdeSolverMethod<'a, Eqn>,
225{
226    type T = Eqn::T;
227    type V = Eqn::V;
228    type M = Eqn::M;
229    type C = Eqn::C;
230
231    fn nstates(&self) -> usize {
232        self.eqn.rhs().nstates()
233    }
234    fn nout(&self) -> usize {
235        self.eqn.rhs().nstates()
236    }
237    fn nparams(&self) -> usize {
238        self.eqn.rhs().nparams()
239    }
240    fn context(&self) -> &Self::C {
241        self.eqn.context()
242    }
243}
244
245impl<'a, Eqn, Method> NonLinearOp for AdjointRhs<'a, Eqn, Method>
246where
247    Eqn: OdeEquationsAdjoint,
248    Method: OdeSolverMethod<'a, Eqn>,
249{
250    /// F(λ, x, t) = -f^T_x(x, t) λ - g^T_x(x,t)
251    fn call_inplace(&self, lambda: &Self::V, t: Self::T, y: &mut Self::V) {
252        self.context.borrow_mut().set_state(t);
253        let context = self.context.borrow();
254        let x = context.state();
255
256        // y = -f^T_x(x, t) λ
257        self.eqn.rhs().jac_transpose_mul_inplace(x, t, lambda, y);
258
259        // y = -f^T_x(x, t) λ - g^T_x(x,t)
260        if self.with_out {
261            let col = context.col();
262            let mut tmp = self.tmp.borrow_mut();
263            self.eqn
264                .out()
265                .unwrap()
266                .jac_transpose_mul_inplace(x, t, col, &mut tmp);
267            y.add_assign(&*tmp);
268        }
269    }
270}
271
272impl<'a, Eqn, Method> NonLinearOpJacobian for AdjointRhs<'a, Eqn, Method>
273where
274    Eqn: OdeEquationsAdjoint,
275    Method: OdeSolverMethod<'a, Eqn>,
276{
277    // J = -f^T_x(x, t)
278    fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
279        self.context.borrow_mut().set_state(t);
280        let context = self.context.borrow();
281        let x = context.state();
282        self.eqn.rhs().jac_transpose_mul_inplace(x, t, v, y);
283    }
284    fn jacobian_inplace(&self, _x: &Self::V, t: Self::T, y: &mut Self::M) {
285        self.context.borrow_mut().set_state(t);
286        let context = self.context.borrow();
287        let x = context.state();
288        self.eqn.rhs().adjoint_inplace(x, t, y);
289    }
290    fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
291        self.eqn.rhs().adjoint_sparsity()
292    }
293}
294
295/// Output of the adjoint equations is:
296///
297/// F(λ, x, t) = -g_p^T(x, t) - f_p^T(x, t) λ
298///
299/// f_p is the partial derivative of the right-hand side with respect to the parameter vector
300/// g_p is the partial derivative of the functional g with respect to the parameter vector
301///
302/// We need the current state x(t), which is obtained from the checkpointed forward solve at the current time step.
303pub struct AdjointOut<'a, Eqn, Method>
304where
305    Eqn: OdeEquations,
306    Method: OdeSolverMethod<'a, Eqn>,
307{
308    eqn: &'a Eqn,
309    context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
310    tmp: RefCell<Eqn::V>,
311    with_out: bool,
312}
313
314impl<'a, Eqn, Method> AdjointOut<'a, Eqn, Method>
315where
316    Eqn: OdeEquations,
317    Method: OdeSolverMethod<'a, Eqn>,
318{
319    pub fn new(
320        eqn: &'a Eqn,
321        context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
322        with_out: bool,
323    ) -> Self {
324        let tmp_n = if with_out { eqn.rhs().nparams() } else { 0 };
325        let tmp = RefCell::new(<Eqn::V as Vector>::zeros(tmp_n, eqn.context().clone()));
326        Self {
327            eqn,
328            context,
329            tmp,
330            with_out,
331        }
332    }
333}
334
335impl<'a, Eqn, Method> Op for AdjointOut<'a, Eqn, Method>
336where
337    Eqn: OdeEquations,
338    Method: OdeSolverMethod<'a, Eqn>,
339{
340    type T = Eqn::T;
341    type V = Eqn::V;
342    type M = Eqn::M;
343    type C = Eqn::C;
344
345    fn nstates(&self) -> usize {
346        self.eqn.rhs().nstates()
347    }
348    fn nout(&self) -> usize {
349        self.eqn.rhs().nparams()
350    }
351    fn nparams(&self) -> usize {
352        self.eqn.rhs().nparams()
353    }
354    fn context(&self) -> &Self::C {
355        self.eqn.context()
356    }
357}
358
359impl<'a, Eqn, Method> NonLinearOp for AdjointOut<'a, Eqn, Method>
360where
361    Eqn: OdeEquationsAdjoint,
362    Method: OdeSolverMethod<'a, Eqn>,
363{
364    /// F(λ, x, t) = -g_p(x, t) - λ^T f_p(x, t)
365    fn call_inplace(&self, lambda: &Self::V, t: Self::T, y: &mut Self::V) {
366        self.context.borrow_mut().set_state(t);
367        let context = self.context.borrow();
368        let x = context.state();
369        self.eqn.rhs().sens_transpose_mul_inplace(x, t, lambda, y);
370
371        if self.with_out {
372            let col = context.col();
373            let mut tmp = self.tmp.borrow_mut();
374            self.eqn
375                .out()
376                .unwrap()
377                .sens_transpose_mul_inplace(x, t, col, &mut tmp);
378            y.add_assign(&*tmp);
379        }
380    }
381}
382
383impl<'a, Eqn, Method> NonLinearOpJacobian for AdjointOut<'a, Eqn, Method>
384where
385    Eqn: OdeEquationsAdjoint,
386    Method: OdeSolverMethod<'a, Eqn>,
387{
388    // J = -f_p(x, t)
389    fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
390        self.context.borrow_mut().set_state(t);
391        let context = self.context.borrow();
392        let x = context.state();
393        self.eqn.rhs().sens_transpose_mul_inplace(x, t, v, y);
394    }
395    fn jacobian_inplace(&self, _x: &Self::V, t: Self::T, y: &mut Self::M) {
396        self.context.borrow_mut().set_state(t);
397        let context = self.context.borrow();
398        let x = context.state();
399        self.eqn.rhs().sens_adjoint_inplace(x, t, y);
400    }
401    fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
402        self.eqn.rhs().sens_adjoint_sparsity()
403    }
404}
405
406/// Adjoint equations for ODEs
407///
408/// M * dλ/dt = -f^T_x(x, t) λ - g^T_x(x,t)
409/// λ(T) = 0
410/// g(λ, x, t) = -g_p(x, t) - λ^T f_p(x, t)
411///
412pub struct AdjointEquations<'a, Eqn, Method>
413where
414    Eqn: OdeEquations,
415    Method: OdeSolverMethod<'a, Eqn>,
416{
417    eqn: &'a Eqn,
418    rhs: AdjointRhs<'a, Eqn, Method>,
419    out: AdjointOut<'a, Eqn, Method>,
420    mass: Option<AdjointMass<'a, Eqn>>,
421    context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
422    tmp: RefCell<Eqn::V>,
423    tmp2: RefCell<Eqn::V>,
424    init: AdjointInit<'a, Eqn>,
425    atol: Option<&'a Eqn::V>,
426    rtol: Option<Eqn::T>,
427    out_rtol: Option<Eqn::T>,
428    out_atol: Option<&'a Eqn::V>,
429}
430
431impl<'a, Eqn, Method> Clone for AdjointEquations<'a, Eqn, Method>
432where
433    Eqn: OdeEquations,
434    Method: OdeSolverMethod<'a, Eqn>,
435{
436    fn clone(&self) -> Self {
437        let context = Rc::new(RefCell::new(AdjointContext::new(
438            self.context.borrow().checkpointer.clone(),
439            self.context.borrow().max_index,
440        )));
441        let rhs = AdjointRhs::new(self.eqn, context.clone(), self.rhs.with_out);
442        let init = AdjointInit::new(self.eqn);
443        let out = AdjointOut::new(self.eqn, context.clone(), self.out.with_out);
444        let tmp = self.tmp.clone();
445        let tmp2 = self.tmp2.clone();
446        let atol = self.atol;
447        let rtol = self.rtol;
448        let out_atol = self.out_atol;
449        let out_rtol = self.out_rtol;
450        let mass = self.eqn.mass().map(|_m| AdjointMass::new(self.eqn));
451        Self {
452            rhs,
453            init,
454            mass,
455            context,
456            out,
457            tmp,
458            tmp2,
459            eqn: self.eqn,
460            atol,
461            rtol,
462            out_rtol,
463            out_atol,
464        }
465    }
466}
467
468impl<'a, Eqn, Method> AdjointEquations<'a, Eqn, Method>
469where
470    Eqn: OdeEquationsAdjoint,
471    Method: OdeSolverMethod<'a, Eqn>,
472{
473    pub(crate) fn new(
474        problem: &'a OdeSolverProblem<Eqn>,
475        context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
476        with_out: bool,
477    ) -> Self {
478        let eqn = &problem.eqn;
479        let rhs = AdjointRhs::new(eqn, context.clone(), with_out);
480        let init = AdjointInit::new(eqn);
481        let out = AdjointOut::new(eqn, context.clone(), with_out);
482        let tmp = RefCell::new(<Eqn::V as Vector>::zeros(
483            eqn.rhs().nparams(),
484            eqn.context().clone(),
485        ));
486        let tmp2 = RefCell::new(<Eqn::V as Vector>::zeros(
487            eqn.rhs().nstates(),
488            eqn.context().clone(),
489        ));
490        let atol = problem.sens_atol.as_ref();
491        let rtol = problem.sens_rtol;
492        let out_atol = problem.param_atol.as_ref();
493        let out_rtol = problem.param_rtol;
494        let mass = eqn.mass().map(|_m| AdjointMass::new(eqn));
495        Self {
496            rhs,
497            init,
498            mass,
499            context,
500            out,
501            tmp,
502            tmp2,
503            eqn,
504            atol,
505            rtol,
506            out_rtol,
507            out_atol,
508        }
509    }
510
511    pub fn eqn(&self) -> &'a Eqn {
512        self.eqn
513    }
514
515    pub fn correct_sg_for_init(&self, t: Eqn::T, s: &[Eqn::V], sg: &mut [Eqn::V]) {
516        let mut tmp = self.tmp.borrow_mut();
517        for (s_i, sg_i) in s.iter().zip(sg.iter_mut()) {
518            if let Some(mass) = self.eqn.mass() {
519                let mut tmp2 = self.tmp2.borrow_mut();
520                mass.call_transpose_inplace(s_i, t, &mut tmp2);
521                self.eqn
522                    .init()
523                    .sens_transpose_mul_inplace(t, &tmp2, &mut tmp);
524                sg_i.sub_assign(&*tmp);
525            } else {
526                self.eqn.init().sens_transpose_mul_inplace(t, s_i, &mut tmp);
527                sg_i.sub_assign(&*tmp);
528            }
529        }
530    }
531
532    pub fn interpolate_forward_state(&self, t: Eqn::T, y: &mut Eqn::V) -> Result<(), DiffsolError> {
533        self.context.borrow_mut().set_state(t);
534        let context = self.context.borrow();
535        context.checkpointer.interpolate(t, y)
536    }
537}
538
539impl<'a, Eqn, Method> std::fmt::Debug for AdjointEquations<'a, Eqn, Method>
540where
541    Eqn: OdeEquations,
542    Method: OdeSolverMethod<'a, Eqn>,
543{
544    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
545        f.debug_struct("AdjointEquations").finish()
546    }
547}
548
549impl<'a, Eqn, Method> Op for AdjointEquations<'a, Eqn, Method>
550where
551    Eqn: OdeEquations,
552    Method: OdeSolverMethod<'a, Eqn>,
553{
554    type T = Eqn::T;
555    type V = Eqn::V;
556    type M = Eqn::M;
557    type C = Eqn::C;
558
559    fn nstates(&self) -> usize {
560        self.eqn.rhs().nstates()
561    }
562    fn nout(&self) -> usize {
563        self.eqn.rhs().nout()
564    }
565    fn nparams(&self) -> usize {
566        self.eqn.rhs().nparams()
567    }
568    fn context(&self) -> &Self::C {
569        self.eqn.context()
570    }
571}
572
573impl<'a, 'b, Eqn, Method> OdeEquationsRef<'a> for AdjointEquations<'b, Eqn, Method>
574where
575    Eqn: OdeEquationsAdjoint,
576    Method: OdeSolverMethod<'b, Eqn>,
577{
578    type Rhs = &'a AdjointRhs<'b, Eqn, Method>;
579    type Mass = &'a AdjointMass<'b, Eqn>;
580    type Root = <Eqn as OdeEquationsRef<'a>>::Root;
581    type Init = &'a AdjointInit<'b, Eqn>;
582    type Out = &'a AdjointOut<'b, Eqn, Method>;
583}
584
585impl<'a, Eqn, Method> OdeEquations for AdjointEquations<'a, Eqn, Method>
586where
587    Eqn: OdeEquationsAdjoint,
588    Method: OdeSolverMethod<'a, Eqn>,
589{
590    fn rhs(&self) -> &AdjointRhs<'a, Eqn, Method> {
591        &self.rhs
592    }
593    fn mass(&self) -> Option<&AdjointMass<'a, Eqn>> {
594        self.mass.as_ref()
595    }
596    fn root(&self) -> Option<<Eqn as OdeEquationsRef<'_>>::Root> {
597        None
598    }
599    fn init(&self) -> &AdjointInit<'a, Eqn> {
600        &self.init
601    }
602    fn out(&self) -> Option<&AdjointOut<'a, Eqn, Method>> {
603        Some(&self.out)
604    }
605    fn set_params(&mut self, p: &Self::V) {
606        self.eqn.set_params(p);
607    }
608    fn get_params(&self, p: &mut Self::V) {
609        self.eqn.get_params(p);
610    }
611}
612
613impl<'a, Eqn, Method> AugmentedOdeEquations<Eqn> for AdjointEquations<'a, Eqn, Method>
614where
615    Eqn: OdeEquationsAdjoint,
616    Method: OdeSolverMethod<'a, Eqn>,
617{
618    fn include_in_error_control(&self) -> bool {
619        self.atol.is_some() && self.rtol.is_some()
620    }
621    fn include_out_in_error_control(&self) -> bool {
622        self.out().is_some() && self.out_atol.is_some() && self.out_rtol.is_some()
623    }
624
625    fn atol(&self) -> Option<&Eqn::V> {
626        self.atol
627    }
628    fn out_atol(&self) -> Option<&Eqn::V> {
629        self.out_atol
630    }
631    fn out_rtol(&self) -> Option<Eqn::T> {
632        self.out_rtol
633    }
634    fn rtol(&self) -> Option<Eqn::T> {
635        self.rtol
636    }
637
638    fn max_index(&self) -> usize {
639        self.context.borrow().max_index
640    }
641
642    fn set_index(&mut self, index: usize) {
643        self.context.borrow_mut().set_index(index);
644    }
645
646    fn update_rhs_out_state(&mut self, _y: &Eqn::V, _dy: &Eqn::V, _t: Eqn::T) {}
647
648    fn integrate_main_eqn(&self) -> bool {
649        false
650    }
651}
652
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, rc::Rc};

    use crate::{
        matrix::dense_nalgebra_serial::NalgebraMat,
        ode_equations::{
            adjoint_equations::AdjointEquations,
            test_models::exponential_decay::exponential_decay_problem_adjoint,
        },
        AdjointContext, AugmentedOdeEquations, Checkpointing, DenseMatrix, FaerSparseLU,
        FaerSparseMat, FaerVec, Matrix, MatrixCommon, NalgebraVec, NonLinearOp,
        NonLinearOpJacobian, OdeEquations, Op, RkState, Vector,
    };
    type Mcpu = NalgebraMat<f64>;
    type Vcpu = NalgebraVec<f64>;
    type LS = crate::NalgebraLU<f64>;

    /// Dense check of the adjoint rhs (with and without the output term), its
    /// Jacobian, and the adjoint output for the exponential-decay model.
    #[test]
    fn test_rhs_exponential() {
        // dy/dt = -ay (p = [a])
        // a = 0.1
        // NOTE(review): the adjoint output checked below has two entries, so the
        // model appears to carry a second parameter (with a zero f_p column).
        // Confirm against exponential_decay_problem_adjoint.
        let (problem, _soln) = exponential_decay_problem_adjoint::<Mcpu>(true);
        let ctx = problem.eqn.context();
        let state = RkState {
            t: 0.0,
            y: Vcpu::from_vec(vec![1.0, 1.0], *ctx),
            dy: Vcpu::from_vec(vec![1.0, 1.0], *ctx),
            g: Vcpu::zeros(0, *ctx),
            dg: Vcpu::zeros(0, *ctx),
            sg: Vec::new(),
            dsg: Vec::new(),
            s: Vec::new(),
            ds: Vec::new(),
            h: 0.0,
        };
        let nout = problem.eqn.out().unwrap().nout();
        let solver = problem.esdirk34_solver::<LS>(state.clone()).unwrap();
        // The checkpointer is seeded with two identical states at t = 0, so the
        // interpolated forward state should be x = (1, 1).
        let checkpointer = Checkpointing::new(solver, 0, vec![state.clone(), state.clone()], None);
        let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer, nout)));
        // Without the output term: F(λ, x, t) = -f_x^T(x, t) λ
        let adj_eqn = AdjointEquations::new(&problem, context.clone(), false);
        // f_x = |-a  0|
        //       | 0 -a|
        // F(λ, t) = |a 0| |1| = | a| = |0.1|
        //           |0 a| |2|   |2a|   |0.2|
        let v = Vcpu::from_vec(vec![1.0, 2.0], *ctx);
        let f = adj_eqn.rhs.call(&v, state.t);
        let f_expect = Vcpu::from_vec(vec![0.1, 0.2], *ctx);
        f.assert_eq_st(&f_expect, 1e-10);

        // With the output term, sharing the same context.
        let mut adj_eqn = AdjointEquations::new(&problem, context, true);

        // f_x^T = |-a  0|
        //         | 0 -a|
        // J = -f_x^T, so the diagonal entries are a = 0.1
        let adjoint = adj_eqn.rhs.jacobian(&state.y, state.t);
        assert_eq!(adjoint.nrows(), 2);
        assert_eq!(adjoint.ncols(), 2);
        assert_eq!(adjoint.get_index(0, 0), 0.1);
        assert_eq!(adjoint.get_index(1, 1), 0.1);

        // g_x = |1 2|
        //       |3 4|
        // S = -g_x^T(x, t)
        // so S = |-1 -3|
        //        |-2 -4|
        // and S e_0 = (-1, -2)

        // f_p^T = |-x_1 -x_2|
        //         |   0    0|
        // g_p = |0 0|
        //       |0 0|
        // out(λ, x, t) = -g_p^T(x, t) e_0 - f_p^T(x, t) λ
        //              = |x_1 x_2| |1| + |0| = |3|   (x = (1, 1))
        //                |  0   0| |2|   |0|   |0|
        adj_eqn.set_index(0);
        let out = adj_eqn.out.call(&v, state.t);
        let out_expect = Vcpu::from_vec(vec![3.0, 0.0], *ctx);
        out.assert_eq_st(&out_expect, 1e-10);

        // F(λ, x, t) = -f_x^T(x, t) λ - g_x^T(x, t) e_0
        // f_x = |-a  0|
        //       | 0 -a|
        // F(λ, t) = |a 0| |1| - |1.0| = | a - 1| = |-0.9|
        //           |0 a| |2|   |2.0|   |2a - 2|   |-1.8|
        let f = adj_eqn.rhs.call(&v, state.t);
        let f_expect = Vcpu::from_vec(vec![-0.9, -1.8], *ctx);
        f.assert_eq_st(&f_expect, 1e-10);
    }

    /// Sparse (faer) check of the adjoint Jacobian and of the adjoint rhs
    /// including the output term, for the exponential-decay model.
    #[test]
    fn test_rhs_exponential_sparse() {
        // dy/dt = -ay (p = [a])
        // a = 0.1
        let (problem, _soln) = exponential_decay_problem_adjoint::<FaerSparseMat<f64>>(true);
        let ctx = problem.eqn.context();
        let state = RkState {
            t: 0.0,
            y: FaerVec::from_vec(vec![1.0, 1.0], *ctx),
            dy: FaerVec::from_vec(vec![1.0, 1.0], *ctx),
            g: FaerVec::zeros(0, *ctx),
            dg: FaerVec::zeros(0, *ctx),
            sg: Vec::new(),
            dsg: Vec::new(),
            s: Vec::new(),
            ds: Vec::new(),
            h: 0.0,
        };
        let nout = problem.eqn.out().unwrap().nout();
        let solver = problem
            .esdirk34_solver::<FaerSparseLU<f64>>(state.clone())
            .unwrap();
        let checkpointer = Checkpointing::new(solver, 0, vec![state.clone(), state.clone()], None);
        let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer, nout)));
        let mut adj_eqn = AdjointEquations::new(&problem, context, true);

        // f_x^T = |-a  0|
        //         | 0 -a|
        // J = -f_x^T: diagonal entries are a = 0.1, off-diagonal (if stored) are 0
        let adjoint = adj_eqn.rhs.jacobian(&state.y, state.t);
        assert_eq!(adjoint.nrows(), 2);
        assert_eq!(adjoint.ncols(), 2);
        for (i, j, v) in adjoint.triplet_iter() {
            if i == j {
                assert_eq!(v, 0.1);
            } else {
                assert_eq!(v, 0.0);
            }
        }

        // g_x = |1 2|
        //       |3 4|
        // S = -g_x^T(x, t)
        // so S = |-1 -3|
        //        |-2 -4|

        // F(λ, x, t) = -f_x^T(x, t) λ - g_x^T(x, t) e_0
        // f_x = |-a  0|
        //       | 0 -a|
        // F(λ, t) = |a 0| |1| - |1.0| = | a - 1| = |-0.9|
        //           |0 a| |2|   |2.0|   |2a - 2|   |-1.8|
        adj_eqn.set_index(0);
        let v = FaerVec::from_vec(vec![1.0, 2.0], *ctx);
        let f = adj_eqn.rhs.call(&v, state.t);
        let f_expect = FaerVec::from_vec(vec![-0.9, -1.8], *ctx);
        f.assert_eq_st(&f_expect, 1e-10);
    }
}