1use numra_core::Scalar;
8
9pub trait IdeSystem<S: Scalar> {
16 fn dim(&self) -> usize;
18
19 fn rhs(&self, t: S, y: &[S], f: &mut [S]);
23
24 fn kernel(&self, t: S, s: S, y_s: &[S], k: &mut [S]);
32
33 fn is_convolution_kernel(&self) -> bool {
37 false
38 }
39}
40
41#[derive(Clone, Debug)]
52pub struct IdeOptions<S: Scalar> {
53 pub dt: S,
55 pub max_steps: usize,
57 pub tol: S,
59 pub max_iter: usize,
61 pub quad_points: usize,
63}
64
65impl<S: Scalar> Default for IdeOptions<S> {
66 fn default() -> Self {
67 Self {
68 dt: S::from_f64(0.01),
69 max_steps: 100_000,
70 tol: S::from_f64(1e-10),
71 max_iter: 100,
72 quad_points: 4, }
74 }
75}
76
77impl<S: Scalar> IdeOptions<S> {
78 pub fn dt(mut self, dt: S) -> Self {
79 self.dt = dt;
80 self
81 }
82
83 pub fn max_steps(mut self, max_steps: usize) -> Self {
84 self.max_steps = max_steps;
85 self
86 }
87
88 pub fn tol(mut self, tol: S) -> Self {
89 self.tol = tol;
90 self
91 }
92
93 pub fn quad_points(mut self, n: usize) -> Self {
94 self.quad_points = n;
95 self
96 }
97}
98
/// Counters describing how much work an IDE solve performed.
///
/// `IdeStats::default()` starts every counter at zero.
#[derive(Default, Debug, Clone)]
pub struct IdeStats {
    /// Number of right-hand-side evaluations (presumably calls to `IdeSystem::rhs`).
    pub n_rhs: usize,
    /// Number of kernel evaluations (presumably calls to `IdeSystem::kernel`).
    pub n_kernel: usize,
    /// Number of time steps taken.
    pub n_steps: usize,
}
109
110#[derive(Clone, Debug)]
112pub struct IdeResult<S: Scalar> {
113 pub t: Vec<S>,
115 pub y: Vec<S>,
117 pub dim: usize,
119 pub stats: IdeStats,
121 pub success: bool,
123 pub message: String,
125}
126
127impl<S: Scalar> IdeResult<S> {
128 pub fn new(t: Vec<S>, y: Vec<S>, dim: usize, stats: IdeStats) -> Self {
129 Self {
130 t,
131 y,
132 dim,
133 stats,
134 success: true,
135 message: String::new(),
136 }
137 }
138
139 pub fn failed(message: String, stats: IdeStats) -> Self {
140 Self {
141 t: Vec::new(),
142 y: Vec::new(),
143 dim: 0,
144 stats,
145 success: false,
146 message,
147 }
148 }
149
150 pub fn len(&self) -> usize {
151 self.t.len()
152 }
153
154 pub fn is_empty(&self) -> bool {
155 self.t.is_empty()
156 }
157
158 pub fn t_final(&self) -> Option<S> {
159 self.t.last().copied()
160 }
161
162 pub fn y_final(&self) -> Option<Vec<S>> {
163 if self.t.is_empty() {
164 None
165 } else {
166 let start = (self.t.len() - 1) * self.dim;
167 Some(self.y[start..start + self.dim].to_vec())
168 }
169 }
170
171 pub fn y_at(&self, i: usize) -> &[S] {
172 let start = i * self.dim;
173 &self.y[start..start + self.dim]
174 }
175}
176
177pub trait IdeSolver<S: Scalar> {
179 fn solve<Sys: IdeSystem<S>>(
181 system: &Sys,
182 t0: S,
183 tf: S,
184 y0: &[S],
185 options: &IdeOptions<S>,
186 ) -> Result<IdeResult<S>, String>;
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar test problem: `rhs = -y` with the convolution kernel `exp(-(t - s)) * y(s)`.
    struct TestIde;

    impl IdeSystem<f64> for TestIde {
        fn dim(&self) -> usize {
            1
        }

        fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
            f[0] = -y[0];
        }

        fn kernel(&self, t: f64, s: f64, y_s: &[f64], k: &mut [f64]) {
            k[0] = (-(t - s)).exp() * y_s[0];
        }

        fn is_convolution_kernel(&self) -> bool {
            true
        }
    }

    /// System that keeps the trait's default `is_convolution_kernel`. Its kernel
    /// `t * s * y(s)` is not a function of `t - s` alone.
    struct NonConvolutionIde;

    impl IdeSystem<f64> for NonConvolutionIde {
        fn dim(&self) -> usize {
            1
        }

        fn rhs(&self, _t: f64, _y: &[f64], f: &mut [f64]) {
            f[0] = 0.0;
        }

        fn kernel(&self, t: f64, s: f64, y_s: &[f64], k: &mut [f64]) {
            k[0] = t * s * y_s[0];
        }
    }

    #[test]
    fn test_ide_system_trait() {
        let sys = TestIde;
        assert_eq!(sys.dim(), 1);
        assert!(sys.is_convolution_kernel());

        let mut f = [0.0];
        sys.rhs(0.0, &[1.0], &mut f);
        assert!((f[0] - (-1.0)).abs() < 1e-10);

        let mut k = [0.0];
        sys.kernel(1.0, 0.5, &[2.0], &mut k);
        assert!((k[0] - 2.0 * (-0.5_f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_is_convolution_kernel_defaults_to_false() {
        let sys = NonConvolutionIde;
        assert!(!sys.is_convolution_kernel());

        let mut k = [0.0];
        sys.kernel(2.0, 0.5, &[3.0], &mut k);
        assert!((k[0] - 3.0).abs() < 1e-15);
    }

    #[test]
    fn test_ide_options() {
        let opts: IdeOptions<f64> = IdeOptions::default().dt(0.001).quad_points(8);
        assert!((opts.dt - 0.001).abs() < 1e-10);
        assert_eq!(opts.quad_points, 8);
    }

    #[test]
    fn test_ide_options_defaults() {
        let opts: IdeOptions<f64> = IdeOptions::default();
        assert!((opts.dt - 0.01).abs() < 1e-15);
        assert!((opts.tol - 1e-10).abs() < 1e-20);
        assert_eq!(opts.max_steps, 100_000);
        assert_eq!(opts.max_iter, 100);
        assert_eq!(opts.quad_points, 4);
    }

    #[test]
    fn test_ide_result_accessors() {
        let t = vec![0.0, 0.5, 1.0];
        let y = vec![1.0, 2.0, 0.8, 1.5, 0.6, 1.2];
        let stats = IdeStats {
            n_rhs: 10,
            n_kernel: 20,
            n_steps: 2,
        };
        let result = IdeResult::new(t, y, 2, stats);

        assert!(result.success);
        assert!(result.message.is_empty());
        assert_eq!(result.dim, 2);
        assert_eq!(result.stats.n_steps, 2);

        assert_eq!(result.len(), 3);
        assert!(!result.is_empty());

        assert!((result.t_final().unwrap() - 1.0).abs() < 1e-15);

        let yf = result.y_final().unwrap();
        assert_eq!(yf.len(), 2);
        assert!((yf[0] - 0.6).abs() < 1e-15);
        assert!((yf[1] - 1.2).abs() < 1e-15);

        let y0 = result.y_at(0);
        assert!((y0[0] - 1.0).abs() < 1e-15);
        assert!((y0[1] - 2.0).abs() < 1e-15);

        let y1 = result.y_at(1);
        assert!((y1[0] - 0.8).abs() < 1e-15);
        assert!((y1[1] - 1.5).abs() < 1e-15);

        // The last row must agree with y_final.
        assert_eq!(result.y_at(2), yf.as_slice());
    }

    #[test]
    fn test_ide_result_failed() {
        let stats = IdeStats {
            n_rhs: 5,
            n_kernel: 3,
            n_steps: 1,
        };
        let result: IdeResult<f64> = IdeResult::failed("something went wrong".to_string(), stats);

        assert!(!result.success);
        assert!(result.message.contains("something went wrong"));
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        assert!(result.t_final().is_none());
        assert!(result.y_final().is_none());
    }

    #[test]
    fn test_ide_options_max_steps() {
        let opts: IdeOptions<f64> = IdeOptions::default().max_steps(42);
        assert_eq!(opts.max_steps, 42);

        let default_opts: IdeOptions<f64> = IdeOptions::default();
        assert_eq!(default_opts.max_steps, 100_000);
    }
}
297}