1use std::cell::RefCell;
2
3use crate::{
4 jacobian::{find_jacobian_non_zeros, find_sens_non_zeros, JacobianColoring},
5 Matrix, MatrixSparsity, NonLinearOp, NonLinearOpJacobian, NonLinearOpSens, Op, Vector,
6};
7
8use super::{BuilderOp, OpStatistics, ParameterisedOp};
9
/// A non-linear operator built from user-supplied closures, with an extra
/// closure used to provide parameter-sensitivity products.
///
/// Every closure receives the state `x`, the parameter vector `p` and the time
/// `t`, and writes its result into a caller-provided output vector:
/// - `func(x, p, t, y)` evaluates the operator (used by `NonLinearOp::call_inplace`);
/// - `jacobian_action(x, p, t, v, y)` is used by `NonLinearOpJacobian::jac_mul_inplace`;
/// - `sens_action(x, p, t, v, y)` is used by `NonLinearOpSens::sens_mul_inplace`.
///
/// Sparsity patterns and colorings start out as `None` and are only populated by
/// `calculate_jacobian_sparsity` / `calculate_sens_sparsity`.
pub struct ClosureWithSens<M, F, G, H>
where
    M: Matrix,
{
    // Operator evaluation closure.
    func: F,
    // Closure backing the Jacobian-vector product.
    jacobian_action: G,
    // Closure backing the sensitivity-vector product.
    sens_action: H,
    // Length of the state vector.
    nstates: usize,
    // Number of parameters reported through `Op::nparams`.
    nparams: usize,
    // Length of the operator output.
    nout: usize,
    // Coloring for the state Jacobian; `None` until sparsity has been computed.
    coloring: Option<JacobianColoring<M>>,
    // Coloring for the sensitivity matrix; `None` until sparsity has been computed.
    sens_coloring: Option<JacobianColoring<M>>,
    // Sparsity pattern of the state Jacobian, if computed.
    sparsity: Option<M::Sparsity>,
    // Sparsity pattern of the sensitivity matrix, if computed.
    sens_sparsity: Option<M::Sparsity>,
    // Call counters; a `RefCell` so the `&self` evaluation methods can update them.
    statistics: RefCell<OpStatistics>,
    // Context cloned into colorings and exposed through `Op::context`.
    ctx: M::C,
}
27
28impl<M, F, G, H> ClosureWithSens<M, F, G, H>
29where
30 M: Matrix,
31 F: Fn(&M::V, &M::V, M::T, &mut M::V),
32 G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
33 H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
34{
35 pub fn new(
36 func: F,
37 jacobian_action: G,
38 sens_action: H,
39 nstates: usize,
40 nparams: usize,
41 nout: usize,
42 ctx: M::C,
43 ) -> Self {
44 Self {
45 func,
46 jacobian_action,
47 sens_action,
48 nstates,
49 nout,
50 nparams,
51 statistics: RefCell::new(OpStatistics::default()),
52 coloring: None,
53 sparsity: None,
54 sens_coloring: None,
55 sens_sparsity: None,
56 ctx,
57 }
58 }
59
60 pub fn calculate_jacobian_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) {
61 let op = ParameterisedOp { op: self, p };
62 let non_zeros = find_jacobian_non_zeros(&op, y0, t0);
63 self.sparsity = Some(
64 MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone())
65 .expect("invalid sparsity pattern"),
66 );
67 self.coloring = Some(JacobianColoring::new(
68 self.sparsity.as_ref().unwrap(),
69 &non_zeros,
70 self.ctx.clone(),
71 ));
72 }
73 pub fn calculate_sens_sparsity(&mut self, y0: &M::V, t0: M::T, p: &M::V) {
74 let op = ParameterisedOp { op: self, p };
75 let non_zeros = find_sens_non_zeros(&op, y0, t0);
76 let nparams = p.len();
77 self.sens_sparsity = Some(
78 MatrixSparsity::try_from_indices(self.nout(), nparams, non_zeros.clone())
79 .expect("invalid sparsity pattern"),
80 );
81 self.sens_coloring = Some(JacobianColoring::new(
82 self.sens_sparsity.as_ref().unwrap(),
83 &non_zeros,
84 self.ctx.clone(),
85 ));
86 }
87}
88
89impl<M, F, G, H> BuilderOp for ClosureWithSens<M, F, G, H>
90where
91 M: Matrix,
92 F: Fn(&M::V, &M::V, M::T, &mut M::V),
93 G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
94 H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
95{
96 fn set_nstates(&mut self, nstates: usize) {
97 self.nstates = nstates;
98 }
99 fn set_nout(&mut self, nout: usize) {
100 self.nout = nout;
101 }
102 fn set_nparams(&mut self, nparams: usize) {
103 self.nparams = nparams;
104 }
105
106 fn calculate_sparsity(&mut self, y0: &Self::V, t0: Self::T, p: &Self::V) {
107 self.calculate_jacobian_sparsity(y0, t0, p);
108 self.calculate_sens_sparsity(y0, t0, p);
109 }
110}
111
112impl<M, F, G, H> Op for ClosureWithSens<M, F, G, H>
113where
114 M: Matrix,
115{
116 type V = M::V;
117 type T = M::T;
118 type M = M;
119 type C = M::C;
120 fn nstates(&self) -> usize {
121 self.nstates
122 }
123 fn nout(&self) -> usize {
124 self.nout
125 }
126 fn nparams(&self) -> usize {
127 self.nparams
128 }
129 fn statistics(&self) -> OpStatistics {
130 self.statistics.borrow().clone()
131 }
132 fn context(&self) -> &Self::C {
133 &self.ctx
134 }
135}
136
137impl<M, F, G, H> NonLinearOp for ParameterisedOp<'_, ClosureWithSens<M, F, G, H>>
138where
139 M: Matrix,
140 F: Fn(&M::V, &M::V, M::T, &mut M::V),
141 G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
142 H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
143{
144 fn call_inplace(&self, x: &M::V, t: M::T, y: &mut M::V) {
145 self.op.statistics.borrow_mut().increment_call();
146 (self.op.func)(x, self.p, t, y)
147 }
148}
149
150impl<M, F, G, H> NonLinearOpJacobian for ParameterisedOp<'_, ClosureWithSens<M, F, G, H>>
151where
152 M: Matrix,
153 F: Fn(&M::V, &M::V, M::T, &mut M::V),
154 G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
155 H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
156{
157 fn jac_mul_inplace(&self, x: &M::V, t: M::T, v: &M::V, y: &mut M::V) {
158 self.op.statistics.borrow_mut().increment_jac_mul();
159 (self.op.jacobian_action)(x, self.p, t, v, y)
160 }
161 fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
162 self.op.statistics.borrow_mut().increment_matrix();
163 if let Some(coloring) = self.op.coloring.as_ref() {
164 coloring.jacobian_inplace(self, x, t, y);
165 } else {
166 self._default_jacobian_inplace(x, t, y);
167 }
168 }
169 fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
170 self.op.sparsity.clone()
171 }
172}
173
174impl<M, F, G, H> NonLinearOpSens for ParameterisedOp<'_, ClosureWithSens<M, F, G, H>>
175where
176 M: Matrix,
177 F: Fn(&M::V, &M::V, M::T, &mut M::V),
178 G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
179 H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
180{
181 fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
182 (self.op.sens_action)(x, self.p, t, v, y);
183 }
184
185 fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
186 if let Some(coloring) = self.op.sens_coloring.as_ref() {
187 coloring.sens_inplace(self, x, t, y);
188 } else {
189 self._default_sens_inplace(x, t, y);
190 }
191 }
192 fn sens_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
193 self.op.sens_sparsity.clone()
194 }
195}