1use numra_core::Scalar;
8
9pub trait FdeSystem<S: Scalar> {
17 fn dim(&self) -> usize;
19
20 fn alpha(&self) -> S;
24
25 fn rhs(&self, t: S, y: &[S], f: &mut [S]);
27
28 fn is_valid_order(&self) -> bool {
30 let alpha = self.alpha();
31 alpha > S::ZERO && alpha <= S::ONE
32 }
33}
34
35#[derive(Clone, Debug)]
48pub struct FdeOptions<S: Scalar> {
49 pub dt: S,
51 pub max_steps: usize,
53 pub tol: S,
55 pub max_iter: usize,
57}
58
59impl<S: Scalar> Default for FdeOptions<S> {
60 fn default() -> Self {
61 Self {
62 dt: S::from_f64(0.01),
63 max_steps: 100_000,
64 tol: S::from_f64(1e-10),
65 max_iter: 100,
66 }
67 }
68}
69
70impl<S: Scalar> FdeOptions<S> {
71 pub fn dt(mut self, dt: S) -> Self {
72 self.dt = dt;
73 self
74 }
75
76 pub fn max_steps(mut self, max_steps: usize) -> Self {
77 self.max_steps = max_steps;
78 self
79 }
80
81 pub fn tol(mut self, tol: S) -> Self {
82 self.tol = tol;
83 self
84 }
85}
86
/// Work counters reported alongside a solve.
///
/// Both counters start at zero (derived `Default`) and are presumably
/// incremented by the solver implementations.
#[derive(Clone, Debug, Default)]
pub struct FdeStats {
    /// Number of right-hand-side evaluations (presumably calls to `FdeSystem::rhs`).
    pub n_rhs: usize,
    /// Number of steps taken.
    pub n_steps: usize,
}
95
96#[derive(Clone, Debug)]
98pub struct FdeResult<S: Scalar> {
99 pub t: Vec<S>,
101 pub y: Vec<S>,
103 pub dim: usize,
105 pub stats: FdeStats,
107 pub success: bool,
109 pub message: String,
111}
112
113impl<S: Scalar> FdeResult<S> {
114 pub fn new(t: Vec<S>, y: Vec<S>, dim: usize, stats: FdeStats) -> Self {
115 Self {
116 t,
117 y,
118 dim,
119 stats,
120 success: true,
121 message: String::new(),
122 }
123 }
124
125 pub fn failed(message: String, stats: FdeStats) -> Self {
126 Self {
127 t: Vec::new(),
128 y: Vec::new(),
129 dim: 0,
130 stats,
131 success: false,
132 message,
133 }
134 }
135
136 pub fn len(&self) -> usize {
137 self.t.len()
138 }
139
140 pub fn is_empty(&self) -> bool {
141 self.t.is_empty()
142 }
143
144 pub fn t_final(&self) -> Option<S> {
145 self.t.last().copied()
146 }
147
148 pub fn y_final(&self) -> Option<Vec<S>> {
149 if self.t.is_empty() {
150 None
151 } else {
152 let start = (self.t.len() - 1) * self.dim;
153 Some(self.y[start..start + self.dim].to_vec())
154 }
155 }
156
157 pub fn y_at(&self, i: usize) -> &[S] {
158 let start = i * self.dim;
159 &self.y[start..start + self.dim]
160 }
161}
162
163pub trait FdeSolver<S: Scalar> {
165 fn solve<Sys: FdeSystem<S>>(
167 system: &Sys,
168 t0: S,
169 tf: S,
170 y0: &[S],
171 options: &FdeOptions<S>,
172 ) -> Result<FdeResult<S>, String>;
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar test problem of order 0.5 with right-hand side `-y`.
    struct Relaxation;

    impl FdeSystem<f64> for Relaxation {
        fn dim(&self) -> usize {
            1
        }
        fn alpha(&self) -> f64 {
            0.5
        }
        fn rhs(&self, _t: f64, y: &[f64], f: &mut [f64]) {
            f[0] = -y[0];
        }
    }

    /// One-dimensional system with a configurable order and a zero right-hand
    /// side, used to probe `is_valid_order` at and beyond its bounds.
    struct FixedOrder(f64);

    impl FdeSystem<f64> for FixedOrder {
        fn dim(&self) -> usize {
            1
        }
        fn alpha(&self) -> f64 {
            self.0
        }
        fn rhs(&self, _t: f64, _y: &[f64], f: &mut [f64]) {
            f[0] = 0.0;
        }
    }

    #[test]
    fn test_fde_system_trait() {
        let sys = Relaxation;
        assert_eq!(sys.dim(), 1);
        assert!((sys.alpha() - 0.5).abs() < 1e-10);
        assert!(sys.is_valid_order());

        let mut out = [0.0];
        sys.rhs(0.0, &[1.0], &mut out);
        // Expect f = -y = -1.
        assert!((out[0] + 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_fde_options() {
        let opts = FdeOptions::<f64>::default().dt(0.001);
        assert!((opts.dt - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_invalid_alpha_zero() {
        assert!(!FixedOrder(0.0).is_valid_order(), "alpha=0 should be invalid");
    }

    #[test]
    fn test_invalid_alpha_negative() {
        assert!(!FixedOrder(-0.5).is_valid_order(), "alpha=-0.5 should be invalid");
    }

    #[test]
    fn test_alpha_exactly_one() {
        assert!(FixedOrder(1.0).is_valid_order(), "alpha=1.0 should be valid");
    }
}