1use numra_core::Scalar;
8
9pub trait IdeSystem<S: Scalar> {
16 fn dim(&self) -> usize;
18
19 fn rhs(&self, t: S, y: &[S], f: &mut [S]);
23
24 fn kernel(&self, t: S, s: S, y_s: &[S], k: &mut [S]);
32
33 fn is_convolution_kernel(&self) -> bool {
37 false
38 }
39}
40
41#[derive(Clone, Debug)]
43pub struct IdeOptions<S: Scalar> {
44 pub dt: S,
46 pub max_steps: usize,
48 pub tol: S,
50 pub max_iter: usize,
52 pub quad_points: usize,
54}
55
56impl<S: Scalar> Default for IdeOptions<S> {
57 fn default() -> Self {
58 Self {
59 dt: S::from_f64(0.01),
60 max_steps: 100_000,
61 tol: S::from_f64(1e-10),
62 max_iter: 100,
63 quad_points: 4, }
65 }
66}
67
68impl<S: Scalar> IdeOptions<S> {
69 pub fn dt(mut self, dt: S) -> Self {
70 self.dt = dt;
71 self
72 }
73
74 pub fn max_steps(mut self, max_steps: usize) -> Self {
75 self.max_steps = max_steps;
76 self
77 }
78
79 pub fn tol(mut self, tol: S) -> Self {
80 self.tol = tol;
81 self
82 }
83
84 pub fn quad_points(mut self, n: usize) -> Self {
85 self.quad_points = n;
86 self
87 }
88}
89
/// Work counters attached to every [`IdeResult`].
///
/// `IdeStats::default()` starts all counters at zero; solvers are expected to
/// fill them in.
#[derive(Default, Debug, Clone)]
pub struct IdeStats {
    /// Count of right-hand-side (`rhs`) evaluations.
    pub n_rhs: usize,
    /// Count of kernel evaluations.
    pub n_kernel: usize,
    /// Count of steps taken.
    pub n_steps: usize,
}
100
101#[derive(Clone, Debug)]
103pub struct IdeResult<S: Scalar> {
104 pub t: Vec<S>,
106 pub y: Vec<S>,
108 pub dim: usize,
110 pub stats: IdeStats,
112 pub success: bool,
114 pub message: String,
116}
117
118impl<S: Scalar> IdeResult<S> {
119 pub fn new(t: Vec<S>, y: Vec<S>, dim: usize, stats: IdeStats) -> Self {
120 Self {
121 t,
122 y,
123 dim,
124 stats,
125 success: true,
126 message: String::new(),
127 }
128 }
129
130 pub fn failed(message: String, stats: IdeStats) -> Self {
131 Self {
132 t: Vec::new(),
133 y: Vec::new(),
134 dim: 0,
135 stats,
136 success: false,
137 message,
138 }
139 }
140
141 pub fn len(&self) -> usize {
142 self.t.len()
143 }
144
145 pub fn is_empty(&self) -> bool {
146 self.t.is_empty()
147 }
148
149 pub fn t_final(&self) -> Option<S> {
150 self.t.last().copied()
151 }
152
153 pub fn y_final(&self) -> Option<Vec<S>> {
154 if self.t.is_empty() {
155 None
156 } else {
157 let start = (self.t.len() - 1) * self.dim;
158 Some(self.y[start..start + self.dim].to_vec())
159 }
160 }
161
162 pub fn y_at(&self, i: usize) -> &[S] {
163 let start = i * self.dim;
164 &self.y[start..start + self.dim]
165 }
166}
167
168pub trait IdeSolver<S: Scalar> {
170 fn solve<Sys: IdeSystem<S>>(
172 system: &Sys,
173 t0: S,
174 tf: S,
175 y0: &[S],
176 options: &IdeOptions<S>,
177 ) -> Result<IdeResult<S>, String>;
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// One-dimensional system with rhs `-y` and the convolution kernel
    /// `exp(-(t - s)) * y(s)`.
    struct TestIde;

    impl IdeSystem<f64> for TestIde {
        fn dim(&self) -> usize {
            1
        }

        fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
            f[0] = -y[0];
        }

        fn kernel(&self, t: f64, s: f64, y_s: &[f64], k: &mut [f64]) {
            k[0] = (-(t - s)).exp() * y_s[0];
        }

        fn is_convolution_kernel(&self) -> bool {
            true
        }
    }

    /// System that keeps the trait's default `is_convolution_kernel`.
    struct NonConvolutionIde;

    impl IdeSystem<f64> for NonConvolutionIde {
        fn dim(&self) -> usize {
            1
        }

        fn rhs(&self, t: f64, _y: &[f64], f: &mut [f64]) {
            f[0] = t;
        }

        fn kernel(&self, t: f64, s: f64, y_s: &[f64], k: &mut [f64]) {
            k[0] = t * s * y_s[0];
        }
    }

    #[test]
    fn test_ide_system_trait() {
        let sys = TestIde;
        assert_eq!(sys.dim(), 1);
        assert!(sys.is_convolution_kernel());

        let mut f = [0.0];
        sys.rhs(0.0, &[1.0], &mut f);
        assert!((f[0] - (-1.0)).abs() < 1e-10);

        let mut k = [0.0];
        sys.kernel(1.0, 0.5, &[2.0], &mut k);
        assert!((k[0] - 2.0 * (-0.5_f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_is_convolution_kernel_defaults_to_false() {
        assert!(!NonConvolutionIde.is_convolution_kernel());
    }

    #[test]
    fn test_ide_options() {
        let opts: IdeOptions<f64> = IdeOptions::default().dt(0.001).quad_points(8);
        assert!((opts.dt - 0.001).abs() < 1e-10);
        assert_eq!(opts.quad_points, 8);
    }

    #[test]
    fn test_ide_options_defaults() {
        let opts: IdeOptions<f64> = IdeOptions::default();
        assert!((opts.dt - 0.01).abs() < 1e-15);
        assert!((opts.tol - 1e-10).abs() < 1e-15);
        assert_eq!(opts.max_iter, 100);
        assert_eq!(opts.quad_points, 4);
        assert_eq!(opts.max_steps, 100_000);
    }

    #[test]
    fn test_ide_options_tol() {
        let opts: IdeOptions<f64> = IdeOptions::default().tol(1e-6);
        assert!((opts.tol - 1e-6).abs() < 1e-15);
    }

    #[test]
    fn test_ide_result_accessors() {
        let t = vec![0.0, 0.5, 1.0];
        let y = vec![1.0, 2.0, 0.8, 1.5, 0.6, 1.2];
        let stats = IdeStats {
            n_rhs: 10,
            n_kernel: 20,
            n_steps: 2,
        };
        let result = IdeResult::new(t, y, 2, stats);

        assert!(result.success);
        assert!(result.message.is_empty());
        assert_eq!(result.dim, 2);
        assert_eq!(result.stats.n_rhs, 10);
        assert_eq!(result.stats.n_kernel, 20);
        assert_eq!(result.stats.n_steps, 2);

        assert_eq!(result.len(), 3);
        assert!(!result.is_empty());

        assert!((result.t_final().unwrap() - 1.0).abs() < 1e-15);

        let yf = result.y_final().unwrap();
        assert_eq!(yf.len(), 2);
        assert!((yf[0] - 0.6).abs() < 1e-15);
        assert!((yf[1] - 1.2).abs() < 1e-15);

        let y0 = result.y_at(0);
        assert!((y0[0] - 1.0).abs() < 1e-15);
        assert!((y0[1] - 2.0).abs() < 1e-15);

        let y1 = result.y_at(1);
        assert!((y1[0] - 0.8).abs() < 1e-15);
        assert!((y1[1] - 1.5).abs() < 1e-15);
    }

    #[test]
    fn test_ide_result_failed() {
        let stats = IdeStats {
            n_rhs: 5,
            n_kernel: 3,
            n_steps: 1,
        };
        let result: IdeResult<f64> = IdeResult::failed("something went wrong".to_string(), stats);

        assert!(!result.success);
        assert!(result.message.contains("something went wrong"));
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        assert!(result.t_final().is_none());
        assert!(result.y_final().is_none());
    }

    #[test]
    fn test_ide_options_max_steps() {
        let opts: IdeOptions<f64> = IdeOptions::default().max_steps(42);
        assert_eq!(opts.max_steps, 42);

        let default_opts: IdeOptions<f64> = IdeOptions::default();
        assert_eq!(default_opts.max_steps, 100_000);
    }
}