1use numra_core::Scalar;
8use numra_pde::Grid1D;
9
10#[derive(Clone, Debug, Default)]
12pub enum NoiseCorrelation<S: Scalar> {
13 #[default]
15 White,
16 Colored {
18 correlation_length: S,
20 },
21 TraceClass {
23 n_modes: usize,
25 decay_rate: S,
27 },
28}
29
30pub trait SpdeSystem<S: Scalar> {
38 fn dim(&self) -> usize {
40 1
41 }
42
43 fn drift(&self, t: S, u: &[S], du: &mut [S], grid: &Grid1D<S>);
51
52 fn diffusion(&self, t: S, u: &[S], sigma: &mut [S], grid: &Grid1D<S>);
63
64 fn noise_correlation(&self) -> NoiseCorrelation<S> {
66 NoiseCorrelation::White
67 }
68
69 fn is_additive(&self) -> bool {
71 true
72 }
73}
74
75#[allow(dead_code)]
77pub struct Spde1D<S: Scalar, F, G>
78where
79 F: Fn(S, &[S], &mut [S], &Grid1D<S>),
80 G: Fn(S, &[S], &mut [S], &Grid1D<S>),
81{
82 drift_fn: F,
84 diffusion_fn: G,
86 correlation: NoiseCorrelation<S>,
88 additive: bool,
90}
91
92#[allow(dead_code)]
93impl<S: Scalar, F, G> Spde1D<S, F, G>
94where
95 F: Fn(S, &[S], &mut [S], &Grid1D<S>),
96 G: Fn(S, &[S], &mut [S], &Grid1D<S>),
97{
98 pub fn new(drift: F, diffusion: G) -> Self {
100 Self {
101 drift_fn: drift,
102 diffusion_fn: diffusion,
103 correlation: NoiseCorrelation::White,
104 additive: true,
105 }
106 }
107
108 pub fn with_correlation(mut self, correlation: NoiseCorrelation<S>) -> Self {
110 self.correlation = correlation;
111 self
112 }
113
114 pub fn with_additive(mut self, additive: bool) -> Self {
116 self.additive = additive;
117 self
118 }
119}
120
121impl<S: Scalar, F, G> SpdeSystem<S> for Spde1D<S, F, G>
122where
123 F: Fn(S, &[S], &mut [S], &Grid1D<S>),
124 G: Fn(S, &[S], &mut [S], &Grid1D<S>),
125{
126 fn drift(&self, t: S, u: &[S], du: &mut [S], grid: &Grid1D<S>) {
127 (self.drift_fn)(t, u, du, grid)
128 }
129
130 fn diffusion(&self, t: S, u: &[S], sigma: &mut [S], grid: &Grid1D<S>) {
131 (self.diffusion_fn)(t, u, sigma, grid)
132 }
133
134 fn noise_correlation(&self) -> NoiseCorrelation<S> {
135 self.correlation.clone()
136 }
137
138 fn is_additive(&self) -> bool {
139 self.additive
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a system whose drift is a second-difference Laplacian with zero boundary values and
    /// whose diffusion is the constant 0.1, then checks the defaults set by `new`.
    #[test]
    fn test_spde1d_construction() {
        let spde: Spde1D<f64, _, _> = Spde1D::new(
            |_t, u, du, grid| {
                let dx = grid.dx_uniform();
                let n = u.len();
                for i in 0..n {
                    let u_left = if i == 0 { 0.0 } else { u[i - 1] };
                    let u_right = if i == n - 1 { 0.0 } else { u[i + 1] };
                    du[i] = (u_left - 2.0 * u[i] + u_right) / (dx * dx);
                }
            },
            |_t, _u, sigma, _grid| {
                for s in sigma.iter_mut() {
                    *s = 0.1;
                }
            },
        );

        assert_eq!(spde.dim(), 1);
        assert!(spde.is_additive());
        assert!(matches!(spde.noise_correlation(), NoiseCorrelation::White));
    }

    /// The `with_*` builder methods must override the defaults set by `new`.
    #[test]
    fn test_spde1d_builder_overrides_defaults() {
        let base: Spde1D<f64, _, _> = Spde1D::new(
            |_t, _u, du, _grid| {
                for d in du.iter_mut() {
                    *d = 0.0;
                }
            },
            |_t, _u, sigma, _grid| {
                for s in sigma.iter_mut() {
                    *s = 0.1;
                }
            },
        );
        let spde = base
            .with_correlation(NoiseCorrelation::TraceClass {
                n_modes: 8,
                decay_rate: 2.0,
            })
            .with_additive(false);

        assert!(!spde.is_additive());
        match spde.noise_correlation() {
            NoiseCorrelation::TraceClass {
                n_modes,
                decay_rate,
            } => {
                assert_eq!(n_modes, 8);
                assert!((decay_rate - 2.0).abs() < 1e-12);
            }
            other => panic!("expected TraceClass, got {other:?}"),
        }
    }

    #[test]
    fn test_noise_correlation_default() {
        let correlation: NoiseCorrelation<f64> = NoiseCorrelation::default();
        assert!(matches!(correlation, NoiseCorrelation::White));
    }

    #[test]
    fn test_colored_noise() {
        let correlation: NoiseCorrelation<f64> = NoiseCorrelation::Colored {
            correlation_length: 0.1,
        };
        // Fail loudly on the wrong variant instead of silently skipping the assertion.
        match correlation {
            NoiseCorrelation::Colored { correlation_length } => {
                assert!((correlation_length - 0.1).abs() < 1e-10);
            }
            other => panic!("expected Colored, got {other:?}"),
        }
    }
}