1use ndarray::*;
4use ndarray_linalg::*;
5
6use super::traits::*;
7
/// Propagator for the diagonal linear part of a `SemiImplicit` model.
///
/// Holds the diagonal returned by `SemiImplicit::diag` and the elementwise
/// factors `exp(diag * dt)`. Applying the operator (`iterate`) multiplies the
/// state by these factors, i.e. it solves `du/dt = diag * u` exactly over `dt`.
#[derive(Debug, Clone)]
pub struct Diagonal<F: SemiImplicit> {
    /// Cached `exp(diag * dt)`, one factor per state element.
    exp_diag: Array<F::Scalar, F::Dim>,
    /// The diagonal as returned by `f.diag()`; kept so the factors can be
    /// recomputed when the step changes.
    diag: Array<F::Scalar, F::Dim>,
    /// Step size reported by `get_dt`.
    dt: <F::Scalar as Scalar>::Real,
}
15
16impl<F: SemiImplicit> TimeStep for Diagonal<F> {
17 type Time = <F::Scalar as Scalar>::Real;
18
19 fn get_dt(&self) -> Self::Time {
20 self.dt
21 }
22 fn set_dt(&mut self, dt: Self::Time) {
23 Zip::from(&mut self.exp_diag)
24 .and(&self.diag)
25 .for_each(|a, &b| {
26 *a = b.mul_real(dt).exp();
27 });
28 }
29}
30
31impl<F: SemiImplicit> ModelSpec for Diagonal<F> {
32 type Scalar = F::Scalar;
33 type Dim = F::Dim;
34
35 fn model_size(&self) -> <Self::Dim as Dimension>::Pattern {
36 self.exp_diag.dim()
37 }
38}
39
40impl<F: SemiImplicit> TimeEvolution for Diagonal<F> {
41 fn iterate<'a, S>(
42 &mut self,
43 x: &'a mut ArrayBase<S, Self::Dim>,
44 ) -> &'a mut ArrayBase<S, Self::Dim>
45 where
46 S: DataMut<Elem = Self::Scalar>,
47 {
48 for (val, d) in x.iter_mut().zip(self.exp_diag.iter()) {
49 *val *= *d;
50 }
51 x
52 }
53}
54
55impl<F: SemiImplicit> Diagonal<F> {
56 fn new(f: F, dt: <Self as TimeStep>::Time) -> Self {
57 let diag = f.diag();
58 let mut exp_diag = diag.to_owned();
59 for v in exp_diag.iter_mut() {
60 *v = v.mul_real(dt).exp();
61 }
62 Diagonal { exp_diag, diag, dt }
63 }
64}
65
/// Fourth-order Runge-Kutta scheme for `SemiImplicit` models with a diagonal
/// linear part.
///
/// The linear part is applied exactly through a [`Diagonal`] propagator that
/// works in half steps. The nonlinear part is combined with RK4 weights
/// derived from the full step `dt`.
#[derive(Debug, Clone)]
pub struct DiagRK4<F: SemiImplicit> {
    /// The model. Supplies the nonlinear term (`nlin`) and, at construction,
    /// the diagonal for `lin`.
    nlin: F,
    /// Linear propagator. It is built and reset with `dt / 2`.
    lin: Diagonal<F>,
    /// Full step size; the RK weights in `iterate` are derived from it.
    dt: <Diagonal<F> as TimeStep>::Time,
    /// Work buffer: a copy of the input state, later the accumulated result.
    x: Array<F::Scalar, F::Dim>,
    /// Work buffer: the input state after one linear half step.
    lx: Array<F::Scalar, F::Dim>,
    /// Saved first nonlinear stage of the current step.
    k1: Array<F::Scalar, F::Dim>,
    /// Saved second nonlinear stage of the current step.
    k2: Array<F::Scalar, F::Dim>,
    /// Saved third nonlinear stage of the current step.
    k3: Array<F::Scalar, F::Dim>,
}
77
78impl<F: SemiImplicit> Scheme for DiagRK4<F> {
79 type Core = F;
80 fn new(nlin: F, dt: Self::Time) -> Self {
81 let lin = Diagonal::new(nlin.clone(), dt / F::Scalar::real(2.0));
82 let x = Array::zeros(lin.model_size());
83 let lx = Array::zeros(lin.model_size());
84 let k1 = Array::zeros(lin.model_size());
85 let k2 = Array::zeros(lin.model_size());
86 let k3 = Array::zeros(lin.model_size());
87 DiagRK4 {
88 nlin,
89 lin,
90 dt,
91 x,
92 lx,
93 k1,
94 k2,
95 k3,
96 }
97 }
98 fn core(&self) -> &Self::Core {
99 &self.nlin
100 }
101 fn core_mut(&mut self) -> &mut Self::Core {
102 &mut self.nlin
103 }
104}
105
106impl<F: SemiImplicit> TimeStep for DiagRK4<F> {
107 type Time = <Diagonal<F> as TimeStep>::Time;
108
109 fn get_dt(&self) -> Self::Time {
110 self.dt
111 }
112
113 fn set_dt(&mut self, dt: Self::Time) {
114 self.lin.set_dt(dt / F::Scalar::real(2.0));
115 }
116}
117
118impl<F: SemiImplicit> ModelSpec for DiagRK4<F> {
119 type Scalar = F::Scalar;
120 type Dim = F::Dim;
121
122 fn model_size(&self) -> <Self::Dim as Dimension>::Pattern {
123 self.nlin.model_size() }
125}
126
impl<F: SemiImplicit> TimeEvolution for DiagRK4<F> {
    /// Advances `x` in place by one step of size `self.dt` and returns it.
    ///
    /// In the formulas below, `u` is the input state and `E` is one
    /// application of `self.lin`. `new` and `set_dt` configure `lin` with
    /// `dt / 2`, so `E` multiplies by `exp(diag * dt / 2)`. `N` is
    /// `self.nlin.nlin`. The stages computed are:
    ///
    /// ```text
    /// k1 = N(u)
    /// k2 = N(E (u + dt/2 k1))
    /// k3 = N(E u + dt/2 k2)
    /// k4 = N(E (E u + dt k3))
    /// u' = E (E (u + dt/6 k1) + dt/3 (k2 + k3)) + dt/6 k4
    /// ```
    ///
    /// This is RK4 in integrating-factor (Lawson) form: the linear part is
    /// treated exactly and the nonlinear part gets the classical RK4 weights.
    ///
    /// The caller's `x` serves as scratch space. `f.nlin(..)` is applied to
    /// it, the returned `&mut` is reused for each stage, and the final `k4`
    /// buffer is what gets returned. NOTE(review): this relies on `nlin`
    /// returning the same buffer it was given. The `&'a mut` return type
    /// implies this, but the trait definition is not visible here; confirm.
    fn iterate<'a, S>(
        &mut self,
        x: &'a mut ArrayBase<S, Self::Dim>,
    ) -> &'a mut ArrayBase<S, Self::Dim>
    where
        S: DataMut<Elem = Self::Scalar>,
    {
        // The RK weights use the full step. `lin` encodes dt / 2 on its own,
        // so the two must be kept consistent.
        let dt = self.dt;
        let dt_2 = self.dt / F::Scalar::real(2.0);
        let dt_3 = self.dt / F::Scalar::real(3.0);
        let dt_6 = self.dt / F::Scalar::real(6.0);
        // Borrow the operators as separate fields so the work buffers in
        // `self` stay usable alongside them.
        let l = &mut self.lin;
        let f = &mut self.nlin;
        // Save u before `x` is overwritten: self.x = u, self.lx = E u.
        self.x.zip_mut_with(x, |buf, x| *buf = *x);
        self.lx.zip_mut_with(x, |buf, lx| *buf = *lx);
        l.iterate(&mut self.lx);
        // Stage 1: k1 = N(u), saved in self.k1. The buffer then becomes
        // u + dt/2 k1.
        let k1 = f.nlin(x);
        self.k1.zip_mut_with(k1, |buf, k1| *buf = *k1);
        Zip::from(&mut *k1).and(&self.x).for_each(|k1, &x_| {
            *k1 = x_ + k1.mul_real(dt_2);
        });
        // Stage 2: k2 = N(E (u + dt/2 k1)), saved in self.k2. The buffer then
        // becomes E u + dt/2 k2.
        let k2 = f.nlin(l.iterate(k1));
        self.k2.zip_mut_with(k2, |buf, k| *buf = *k);
        Zip::from(&mut *k2).and(&self.lx).for_each(|k2, &lx| {
            *k2 = lx + k2.mul_real(dt_2);
        });
        // Stage 3: k3 = N(E u + dt/2 k2), saved in self.k3. The buffer then
        // becomes E u + dt k3.
        let k3 = f.nlin(k2);
        self.k3.zip_mut_with(k3, |buf, k| *buf = *k);
        Zip::from(&mut *k3).and(&self.lx).for_each(|k3, &lx| {
            *k3 = lx + k3.mul_real(dt);
        });
        // Stage 4: k4 = N(E (E u + dt k3)), left in the caller's buffer.
        let k4 = f.nlin(l.iterate(k3));
        // Accumulate in self.x (which still holds u):
        // self.x = E (u + dt/6 k1)
        Zip::from(&mut self.x)
            .and(&self.k1)
            .for_each(|x_, k1_| *x_ += k1_.mul_real(dt_6));
        l.iterate(&mut self.x);
        // self.x = E (E (u + dt/6 k1) + dt/3 (k2 + k3))
        Zip::from(&mut self.x)
            .and(&self.k2)
            .and(&self.k3)
            .for_each(|x_, &k2_, &k3_| *x_ += (k2_ + k3_).mul_real(dt_3));
        l.iterate(&mut self.x);
        // Result u' = self.x + dt/6 k4, written into the k4 buffer, which is
        // the buffer handed back to the caller.
        Zip::from(&mut *k4).and(&self.x).for_each(|k4, &x_| {
            *k4 = x_ + k4.mul_real(dt_6);
        });
        k4
    }
}
177}