1use num_traits::{One, Zero};
2use std::{
3 cell::RefCell,
4 ops::{AddAssign, SubAssign},
5 rc::Rc,
6};
7
8use crate::{
9 error::DiffsolError, op::nonlinear_op::NonLinearOpJacobian, AugmentedOdeEquations,
10 Checkpointing, ConstantOp, ConstantOpSensAdjoint, LinearOp, LinearOpTranspose, Matrix,
11 NonLinearOp, NonLinearOpAdjoint, NonLinearOpSensAdjoint, OdeEquations, OdeEquationsAdjoint,
12 OdeEquationsRef, OdeSolverMethod, OdeSolverProblem, Op, Vector,
13};
14
15pub struct AdjointContext<'a, Eqn, Method>
16where
17 Eqn: OdeEquations,
18 Method: OdeSolverMethod<'a, Eqn>,
19{
20 checkpointer: Checkpointing<'a, Eqn, Method>,
21 x: Eqn::V,
22 index: usize,
23 max_index: usize,
24 last_t: Option<Eqn::T>,
25 col: Eqn::V,
26}
27
28impl<'a, Eqn, Method> AdjointContext<'a, Eqn, Method>
29where
30 Eqn: OdeEquations,
31 Method: OdeSolverMethod<'a, Eqn>,
32{
33 pub fn new(checkpointer: Checkpointing<'a, Eqn, Method>, max_index: usize) -> Self {
34 let ctx = checkpointer.problem().eqn.context();
35 let x = <Eqn::V as Vector>::zeros(checkpointer.problem().eqn.rhs().nstates(), ctx.clone());
36 let mut col = <Eqn::V as Vector>::zeros(max_index, ctx.clone());
37 let index = 0;
38 col.set_index(0, Eqn::T::one());
39 Self {
40 checkpointer,
41 x,
42 index,
43 max_index,
44 col,
45 last_t: None,
46 }
47 }
48
49 pub fn set_state(&mut self, t: Eqn::T) {
50 if let Some(last_t) = self.last_t {
51 if last_t == t {
52 return;
53 }
54 }
55 self.last_t = Some(t);
56 self.checkpointer.interpolate(t, &mut self.x).unwrap();
57 self.checkpointer.problem().eqn.rhs().call(&self.x, t);
61 }
62
63 pub fn state(&self) -> &Eqn::V {
64 &self.x
65 }
66
67 pub fn col(&self) -> &Eqn::V {
68 &self.col
69 }
70
71 pub fn set_index(&mut self, index: usize) {
72 self.col.set_index(self.index, Eqn::T::zero());
73 self.index = index;
74 self.col.set_index(self.index, Eqn::T::one());
75 }
76}
77
78pub struct AdjointMass<'a, Eqn>
79where
80 Eqn: OdeEquations,
81{
82 eqn: &'a Eqn,
83}
84
85impl<'a, Eqn> AdjointMass<'a, Eqn>
86where
87 Eqn: OdeEquations,
88{
89 pub fn new(eqn: &'a Eqn) -> Self {
90 Self { eqn }
91 }
92}
93
94impl<Eqn> Op for AdjointMass<'_, Eqn>
95where
96 Eqn: OdeEquations,
97{
98 type T = Eqn::T;
99 type V = Eqn::V;
100 type M = Eqn::M;
101 type C = Eqn::C;
102
103 fn nstates(&self) -> usize {
104 self.eqn.rhs().nstates()
105 }
106 fn nout(&self) -> usize {
107 self.eqn.rhs().nstates()
108 }
109 fn nparams(&self) -> usize {
110 self.eqn.rhs().nparams()
111 }
112 fn context(&self) -> &Self::C {
113 self.eqn.context()
114 }
115}
116
117impl<Eqn> LinearOp for AdjointMass<'_, Eqn>
118where
119 Eqn: OdeEquationsAdjoint,
120{
121 fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
122 self.eqn
123 .mass()
124 .unwrap()
125 .gemv_transpose_inplace(x, t, beta, y);
126 }
127
128 fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
129 self.eqn.mass().unwrap().transpose_inplace(t, y);
130 }
131}
132
133pub struct AdjointInit<'a, Eqn>
134where
135 Eqn: OdeEquations,
136{
137 eqn: &'a Eqn,
138}
139
140impl<'a, Eqn> AdjointInit<'a, Eqn>
141where
142 Eqn: OdeEquations,
143{
144 pub fn new(eqn: &'a Eqn) -> Self {
145 Self { eqn }
146 }
147}
148
149impl<Eqn> Op for AdjointInit<'_, Eqn>
150where
151 Eqn: OdeEquations,
152{
153 type T = Eqn::T;
154 type V = Eqn::V;
155 type M = Eqn::M;
156 type C = Eqn::C;
157
158 fn nstates(&self) -> usize {
159 self.eqn.rhs().nstates()
160 }
161 fn nout(&self) -> usize {
162 self.eqn.rhs().nstates()
163 }
164 fn nparams(&self) -> usize {
165 self.eqn.rhs().nparams()
166 }
167 fn context(&self) -> &Self::C {
168 self.eqn.context()
169 }
170}
171
172impl<Eqn> ConstantOp for AdjointInit<'_, Eqn>
173where
174 Eqn: OdeEquations,
175{
176 fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
177 y.fill(Eqn::T::zero());
178 }
179}
180
181pub struct AdjointRhs<'a, Eqn, Method>
190where
191 Eqn: OdeEquations,
192 Method: OdeSolverMethod<'a, Eqn>,
193{
194 eqn: &'a Eqn,
195 context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
196 tmp: RefCell<Eqn::V>,
197 with_out: bool,
198}
199
200impl<'a, Eqn, Method> AdjointRhs<'a, Eqn, Method>
201where
202 Eqn: OdeEquations,
203 Method: OdeSolverMethod<'a, Eqn>,
204{
205 pub fn new(
206 eqn: &'a Eqn,
207 context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
208 with_out: bool,
209 ) -> Self {
210 let tmp_n = if with_out { eqn.rhs().nstates() } else { 0 };
211 let tmp = RefCell::new(<Eqn::V as Vector>::zeros(tmp_n, eqn.context().clone()));
212 Self {
213 eqn,
214 context,
215 tmp,
216 with_out,
217 }
218 }
219}
220
221impl<'a, Eqn, Method> Op for AdjointRhs<'a, Eqn, Method>
222where
223 Eqn: OdeEquations,
224 Method: OdeSolverMethod<'a, Eqn>,
225{
226 type T = Eqn::T;
227 type V = Eqn::V;
228 type M = Eqn::M;
229 type C = Eqn::C;
230
231 fn nstates(&self) -> usize {
232 self.eqn.rhs().nstates()
233 }
234 fn nout(&self) -> usize {
235 self.eqn.rhs().nstates()
236 }
237 fn nparams(&self) -> usize {
238 self.eqn.rhs().nparams()
239 }
240 fn context(&self) -> &Self::C {
241 self.eqn.context()
242 }
243}
244
245impl<'a, Eqn, Method> NonLinearOp for AdjointRhs<'a, Eqn, Method>
246where
247 Eqn: OdeEquationsAdjoint,
248 Method: OdeSolverMethod<'a, Eqn>,
249{
250 fn call_inplace(&self, lambda: &Self::V, t: Self::T, y: &mut Self::V) {
252 self.context.borrow_mut().set_state(t);
253 let context = self.context.borrow();
254 let x = context.state();
255
256 self.eqn.rhs().jac_transpose_mul_inplace(x, t, lambda, y);
258
259 if self.with_out {
261 let col = context.col();
262 let mut tmp = self.tmp.borrow_mut();
263 self.eqn
264 .out()
265 .unwrap()
266 .jac_transpose_mul_inplace(x, t, col, &mut tmp);
267 y.add_assign(&*tmp);
268 }
269 }
270}
271
272impl<'a, Eqn, Method> NonLinearOpJacobian for AdjointRhs<'a, Eqn, Method>
273where
274 Eqn: OdeEquationsAdjoint,
275 Method: OdeSolverMethod<'a, Eqn>,
276{
277 fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
279 self.context.borrow_mut().set_state(t);
280 let context = self.context.borrow();
281 let x = context.state();
282 self.eqn.rhs().jac_transpose_mul_inplace(x, t, v, y);
283 }
284 fn jacobian_inplace(&self, _x: &Self::V, t: Self::T, y: &mut Self::M) {
285 self.context.borrow_mut().set_state(t);
286 let context = self.context.borrow();
287 let x = context.state();
288 self.eqn.rhs().adjoint_inplace(x, t, y);
289 }
290 fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
291 self.eqn.rhs().adjoint_sparsity()
292 }
293}
294
295pub struct AdjointOut<'a, Eqn, Method>
304where
305 Eqn: OdeEquations,
306 Method: OdeSolverMethod<'a, Eqn>,
307{
308 eqn: &'a Eqn,
309 context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
310 tmp: RefCell<Eqn::V>,
311 with_out: bool,
312}
313
314impl<'a, Eqn, Method> AdjointOut<'a, Eqn, Method>
315where
316 Eqn: OdeEquations,
317 Method: OdeSolverMethod<'a, Eqn>,
318{
319 pub fn new(
320 eqn: &'a Eqn,
321 context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
322 with_out: bool,
323 ) -> Self {
324 let tmp_n = if with_out { eqn.rhs().nparams() } else { 0 };
325 let tmp = RefCell::new(<Eqn::V as Vector>::zeros(tmp_n, eqn.context().clone()));
326 Self {
327 eqn,
328 context,
329 tmp,
330 with_out,
331 }
332 }
333}
334
335impl<'a, Eqn, Method> Op for AdjointOut<'a, Eqn, Method>
336where
337 Eqn: OdeEquations,
338 Method: OdeSolverMethod<'a, Eqn>,
339{
340 type T = Eqn::T;
341 type V = Eqn::V;
342 type M = Eqn::M;
343 type C = Eqn::C;
344
345 fn nstates(&self) -> usize {
346 self.eqn.rhs().nstates()
347 }
348 fn nout(&self) -> usize {
349 self.eqn.rhs().nparams()
350 }
351 fn nparams(&self) -> usize {
352 self.eqn.rhs().nparams()
353 }
354 fn context(&self) -> &Self::C {
355 self.eqn.context()
356 }
357}
358
359impl<'a, Eqn, Method> NonLinearOp for AdjointOut<'a, Eqn, Method>
360where
361 Eqn: OdeEquationsAdjoint,
362 Method: OdeSolverMethod<'a, Eqn>,
363{
364 fn call_inplace(&self, lambda: &Self::V, t: Self::T, y: &mut Self::V) {
366 self.context.borrow_mut().set_state(t);
367 let context = self.context.borrow();
368 let x = context.state();
369 self.eqn.rhs().sens_transpose_mul_inplace(x, t, lambda, y);
370
371 if self.with_out {
372 let col = context.col();
373 let mut tmp = self.tmp.borrow_mut();
374 self.eqn
375 .out()
376 .unwrap()
377 .sens_transpose_mul_inplace(x, t, col, &mut tmp);
378 y.add_assign(&*tmp);
379 }
380 }
381}
382
383impl<'a, Eqn, Method> NonLinearOpJacobian for AdjointOut<'a, Eqn, Method>
384where
385 Eqn: OdeEquationsAdjoint,
386 Method: OdeSolverMethod<'a, Eqn>,
387{
388 fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
390 self.context.borrow_mut().set_state(t);
391 let context = self.context.borrow();
392 let x = context.state();
393 self.eqn.rhs().sens_transpose_mul_inplace(x, t, v, y);
394 }
395 fn jacobian_inplace(&self, _x: &Self::V, t: Self::T, y: &mut Self::M) {
396 self.context.borrow_mut().set_state(t);
397 let context = self.context.borrow();
398 let x = context.state();
399 self.eqn.rhs().sens_adjoint_inplace(x, t, y);
400 }
401 fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
402 self.eqn.rhs().sens_adjoint_sparsity()
403 }
404}
405
/// The adjoint system of a forward ODE problem, expressed as a set of [`OdeEquations`].
///
/// Its operators ([`AdjointRhs`], [`AdjointOut`]) share one [`AdjointContext`]. That context
/// supplies the forward state at the requested time and the currently selected output index.
pub struct AdjointEquations<'a, Eqn, Method>
where
    Eqn: OdeEquations,
    Method: OdeSolverMethod<'a, Eqn>,
{
    // Borrowed forward equations.
    eqn: &'a Eqn,
    // Adjoint right-hand side (shares `context`).
    rhs: AdjointRhs<'a, Eqn, Method>,
    // Adjoint output operator, of length `nparams` (shares `context`).
    out: AdjointOut<'a, Eqn, Method>,
    // Transposed forward mass; `Some` only when the forward equations have a mass matrix.
    mass: Option<AdjointMass<'a, Eqn>>,
    // Shared context (forward state cache + selector vector).
    context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
    // Scratch buffer of length `nparams`, used by `correct_sg_for_init`.
    tmp: RefCell<Eqn::V>,
    // Scratch buffer of length `nstates`, used by `correct_sg_for_init` when there is a mass.
    tmp2: RefCell<Eqn::V>,
    // Zero initial condition.
    init: AdjointInit<'a, Eqn>,
    // Taken from the problem's `sens_atol` / `sens_rtol`.
    atol: Option<&'a Eqn::V>,
    rtol: Option<Eqn::T>,
    // Taken from the problem's `param_rtol` / `param_atol`.
    out_rtol: Option<Eqn::T>,
    out_atol: Option<&'a Eqn::V>,
}
430
431impl<'a, Eqn, Method> Clone for AdjointEquations<'a, Eqn, Method>
432where
433 Eqn: OdeEquations,
434 Method: OdeSolverMethod<'a, Eqn>,
435{
436 fn clone(&self) -> Self {
437 let context = Rc::new(RefCell::new(AdjointContext::new(
438 self.context.borrow().checkpointer.clone(),
439 self.context.borrow().max_index,
440 )));
441 let rhs = AdjointRhs::new(self.eqn, context.clone(), self.rhs.with_out);
442 let init = AdjointInit::new(self.eqn);
443 let out = AdjointOut::new(self.eqn, context.clone(), self.out.with_out);
444 let tmp = self.tmp.clone();
445 let tmp2 = self.tmp2.clone();
446 let atol = self.atol;
447 let rtol = self.rtol;
448 let out_atol = self.out_atol;
449 let out_rtol = self.out_rtol;
450 let mass = self.eqn.mass().map(|_m| AdjointMass::new(self.eqn));
451 Self {
452 rhs,
453 init,
454 mass,
455 context,
456 out,
457 tmp,
458 tmp2,
459 eqn: self.eqn,
460 atol,
461 rtol,
462 out_rtol,
463 out_atol,
464 }
465 }
466}
467
468impl<'a, Eqn, Method> AdjointEquations<'a, Eqn, Method>
469where
470 Eqn: OdeEquationsAdjoint,
471 Method: OdeSolverMethod<'a, Eqn>,
472{
473 pub(crate) fn new(
474 problem: &'a OdeSolverProblem<Eqn>,
475 context: Rc<RefCell<AdjointContext<'a, Eqn, Method>>>,
476 with_out: bool,
477 ) -> Self {
478 let eqn = &problem.eqn;
479 let rhs = AdjointRhs::new(eqn, context.clone(), with_out);
480 let init = AdjointInit::new(eqn);
481 let out = AdjointOut::new(eqn, context.clone(), with_out);
482 let tmp = RefCell::new(<Eqn::V as Vector>::zeros(
483 eqn.rhs().nparams(),
484 eqn.context().clone(),
485 ));
486 let tmp2 = RefCell::new(<Eqn::V as Vector>::zeros(
487 eqn.rhs().nstates(),
488 eqn.context().clone(),
489 ));
490 let atol = problem.sens_atol.as_ref();
491 let rtol = problem.sens_rtol;
492 let out_atol = problem.param_atol.as_ref();
493 let out_rtol = problem.param_rtol;
494 let mass = eqn.mass().map(|_m| AdjointMass::new(eqn));
495 Self {
496 rhs,
497 init,
498 mass,
499 context,
500 out,
501 tmp,
502 tmp2,
503 eqn,
504 atol,
505 rtol,
506 out_rtol,
507 out_atol,
508 }
509 }
510
511 pub fn eqn(&self) -> &'a Eqn {
512 self.eqn
513 }
514
515 pub fn correct_sg_for_init(&self, t: Eqn::T, s: &[Eqn::V], sg: &mut [Eqn::V]) {
516 let mut tmp = self.tmp.borrow_mut();
517 for (s_i, sg_i) in s.iter().zip(sg.iter_mut()) {
518 if let Some(mass) = self.eqn.mass() {
519 let mut tmp2 = self.tmp2.borrow_mut();
520 mass.call_transpose_inplace(s_i, t, &mut tmp2);
521 self.eqn
522 .init()
523 .sens_transpose_mul_inplace(t, &tmp2, &mut tmp);
524 sg_i.sub_assign(&*tmp);
525 } else {
526 self.eqn.init().sens_transpose_mul_inplace(t, s_i, &mut tmp);
527 sg_i.sub_assign(&*tmp);
528 }
529 }
530 }
531
532 pub fn interpolate_forward_state(&self, t: Eqn::T, y: &mut Eqn::V) -> Result<(), DiffsolError> {
533 self.context.borrow_mut().set_state(t);
534 let context = self.context.borrow();
535 context.checkpointer.interpolate(t, y)
536 }
537}
538
539impl<'a, Eqn, Method> std::fmt::Debug for AdjointEquations<'a, Eqn, Method>
540where
541 Eqn: OdeEquations,
542 Method: OdeSolverMethod<'a, Eqn>,
543{
544 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
545 f.debug_struct("AdjointEquations").finish()
546 }
547}
548
549impl<'a, Eqn, Method> Op for AdjointEquations<'a, Eqn, Method>
550where
551 Eqn: OdeEquations,
552 Method: OdeSolverMethod<'a, Eqn>,
553{
554 type T = Eqn::T;
555 type V = Eqn::V;
556 type M = Eqn::M;
557 type C = Eqn::C;
558
559 fn nstates(&self) -> usize {
560 self.eqn.rhs().nstates()
561 }
562 fn nout(&self) -> usize {
563 self.eqn.rhs().nout()
564 }
565 fn nparams(&self) -> usize {
566 self.eqn.rhs().nparams()
567 }
568 fn context(&self) -> &Self::C {
569 self.eqn.context()
570 }
571}
572
573impl<'a, 'b, Eqn, Method> OdeEquationsRef<'a> for AdjointEquations<'b, Eqn, Method>
574where
575 Eqn: OdeEquationsAdjoint,
576 Method: OdeSolverMethod<'b, Eqn>,
577{
578 type Rhs = &'a AdjointRhs<'b, Eqn, Method>;
579 type Mass = &'a AdjointMass<'b, Eqn>;
580 type Root = <Eqn as OdeEquationsRef<'a>>::Root;
581 type Init = &'a AdjointInit<'b, Eqn>;
582 type Out = &'a AdjointOut<'b, Eqn, Method>;
583}
584
585impl<'a, Eqn, Method> OdeEquations for AdjointEquations<'a, Eqn, Method>
586where
587 Eqn: OdeEquationsAdjoint,
588 Method: OdeSolverMethod<'a, Eqn>,
589{
590 fn rhs(&self) -> &AdjointRhs<'a, Eqn, Method> {
591 &self.rhs
592 }
593 fn mass(&self) -> Option<&AdjointMass<'a, Eqn>> {
594 self.mass.as_ref()
595 }
596 fn root(&self) -> Option<<Eqn as OdeEquationsRef<'_>>::Root> {
597 None
598 }
599 fn init(&self) -> &AdjointInit<'a, Eqn> {
600 &self.init
601 }
602 fn out(&self) -> Option<&AdjointOut<'a, Eqn, Method>> {
603 Some(&self.out)
604 }
605 fn set_params(&mut self, p: &Self::V) {
606 self.eqn.set_params(p);
607 }
608 fn get_params(&self, p: &mut Self::V) {
609 self.eqn.get_params(p);
610 }
611}
612
613impl<'a, Eqn, Method> AugmentedOdeEquations<Eqn> for AdjointEquations<'a, Eqn, Method>
614where
615 Eqn: OdeEquationsAdjoint,
616 Method: OdeSolverMethod<'a, Eqn>,
617{
618 fn include_in_error_control(&self) -> bool {
619 self.atol.is_some() && self.rtol.is_some()
620 }
621 fn include_out_in_error_control(&self) -> bool {
622 self.out().is_some() && self.out_atol.is_some() && self.out_rtol.is_some()
623 }
624
625 fn atol(&self) -> Option<&Eqn::V> {
626 self.atol
627 }
628 fn out_atol(&self) -> Option<&Eqn::V> {
629 self.out_atol
630 }
631 fn out_rtol(&self) -> Option<Eqn::T> {
632 self.out_rtol
633 }
634 fn rtol(&self) -> Option<Eqn::T> {
635 self.rtol
636 }
637
638 fn max_index(&self) -> usize {
639 self.context.borrow().max_index
640 }
641
642 fn set_index(&mut self, index: usize) {
643 self.context.borrow_mut().set_index(index);
644 }
645
646 fn update_rhs_out_state(&mut self, _y: &Eqn::V, _dy: &Eqn::V, _t: Eqn::T) {}
647
648 fn integrate_main_eqn(&self) -> bool {
649 false
650 }
651}
652
/// Unit tests for the adjoint operators on the exponential-decay test model, with both dense
/// (nalgebra) and sparse (faer) backends.
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, rc::Rc};

    use crate::{
        matrix::dense_nalgebra_serial::NalgebraMat,
        ode_equations::{
            adjoint_equations::AdjointEquations,
            test_models::exponential_decay::exponential_decay_problem_adjoint,
        },
        AdjointContext, AugmentedOdeEquations, Checkpointing, DenseMatrix, FaerSparseLU,
        FaerSparseMat, FaerVec, Matrix, MatrixCommon, NalgebraVec, NonLinearOp,
        NonLinearOpJacobian, OdeEquations, Op, RkState, Vector,
    };
    type Mcpu = NalgebraMat<f64>;
    type Vcpu = NalgebraVec<f64>;
    type LS = crate::NalgebraLU<f64>;

    #[test]
    fn test_rhs_exponential() {
        let (problem, _soln) = exponential_decay_problem_adjoint::<Mcpu>(true);
        let ctx = problem.eqn.context();
        // A single forward state at t = 0 with y = [1, 1]; it is used both to build the solver
        // and (duplicated) as the checkpoint list, so interpolation at t = 0 yields this y.
        let state = RkState {
            t: 0.0,
            y: Vcpu::from_vec(vec![1.0, 1.0], *ctx),
            dy: Vcpu::from_vec(vec![1.0, 1.0], *ctx),
            g: Vcpu::zeros(0, *ctx),
            dg: Vcpu::zeros(0, *ctx),
            sg: Vec::new(),
            dsg: Vec::new(),
            s: Vec::new(),
            ds: Vec::new(),
            h: 0.0,
        };
        let nout = problem.eqn.out().unwrap().nout();
        let solver = problem.esdirk34_solver::<LS>(state.clone()).unwrap();
        let checkpointer = Checkpointing::new(solver, 0, vec![state.clone(), state.clone()], None);
        let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer, nout)));
        // Without the output term, the adjoint rhs is just the transposed-Jacobian product.
        let adj_eqn = AdjointEquations::new(&problem, context.clone(), false);
        let v = Vcpu::from_vec(vec![1.0, 2.0], *ctx);
        let f = adj_eqn.rhs.call(&v, state.t);
        let f_expect = Vcpu::from_vec(vec![0.1, 0.2], *ctx);
        f.assert_eq_st(&f_expect, 1e-10);

        // Same shared context, now with the output terms included.
        let mut adj_eqn = AdjointEquations::new(&problem, context, true);

        let adjoint = adj_eqn.rhs.jacobian(&state.y, state.t);
        assert_eq!(adjoint.nrows(), 2);
        assert_eq!(adjoint.ncols(), 2);
        assert_eq!(adjoint.get_index(0, 0), 0.1);
        assert_eq!(adjoint.get_index(1, 1), 0.1);

        // Select output component 0 before evaluating the output-dependent operators.
        adj_eqn.set_index(0);
        let out = adj_eqn.out.call(&v, state.t);
        let out_expect = Vcpu::from_vec(vec![3.0, 0.0], *ctx);
        out.assert_eq_st(&out_expect, 1e-10);

        // With the output term, the rhs differs from the `with_out = false` result above.
        let f = adj_eqn.rhs.call(&v, state.t);
        let f_expect = Vcpu::from_vec(vec![-0.9, -1.8], *ctx);
        f.assert_eq_st(&f_expect, 1e-10);
    }

    #[test]
    fn test_rhs_exponential_sparse() {
        let (problem, _soln) = exponential_decay_problem_adjoint::<FaerSparseMat<f64>>(true);
        let ctx = problem.eqn.context();
        // Same single-point forward state as the dense test, using faer vectors.
        let state = RkState {
            t: 0.0,
            y: FaerVec::from_vec(vec![1.0, 1.0], *ctx),
            dy: FaerVec::from_vec(vec![1.0, 1.0], *ctx),
            g: FaerVec::zeros(0, *ctx),
            dg: FaerVec::zeros(0, *ctx),
            sg: Vec::new(),
            dsg: Vec::new(),
            s: Vec::new(),
            ds: Vec::new(),
            h: 0.0,
        };
        let nout = problem.eqn.out().unwrap().nout();
        let solver = problem
            .esdirk34_solver::<FaerSparseLU<f64>>(state.clone())
            .unwrap();
        let checkpointer = Checkpointing::new(solver, 0, vec![state.clone(), state.clone()], None);
        let context = Rc::new(RefCell::new(AdjointContext::new(checkpointer, nout)));
        let mut adj_eqn = AdjointEquations::new(&problem, context, true);

        // The sparse adjoint Jacobian must be diagonal with 0.1 on the diagonal.
        let adjoint = adj_eqn.rhs.jacobian(&state.y, state.t);
        assert_eq!(adjoint.nrows(), 2);
        assert_eq!(adjoint.ncols(), 2);
        for (i, j, v) in adjoint.triplet_iter() {
            if i == j {
                assert_eq!(v, 0.1);
            } else {
                assert_eq!(v, 0.0);
            }
        }

        // Select output component 0; the rhs then includes the output term.
        adj_eqn.set_index(0);
        let v = FaerVec::from_vec(vec![1.0, 2.0], *ctx);
        let f = adj_eqn.rhs.call(&v, state.t);
        let f_expect = FaerVec::from_vec(vec![-0.9, -1.8], *ctx);
        f.assert_eq_st(&f_expect, 1e-10);
    }
}