eom/
semi_implicit.rs

1//! semi-implicit schemes
2
3use ndarray::*;
4use ndarray_linalg::*;
5
6use super::traits::*;
7
/// Linear ODE with diagonalized matrix (exactly solvable)
///
/// Holds the diagonal of the model's (already diagonalized) linear part, as
/// returned by `SemiImplicit::diag`, together with its element-wise
/// exponential `exp(diag * dt)`. One step of the linear flow is therefore an
/// element-wise multiplication (see `TimeEvolution::iterate`).
#[derive(Debug, Clone)]
pub struct Diagonal<F: SemiImplicit> {
    /// Element-wise `exp(diag * dt)`: the per-step propagator applied by `iterate`.
    exp_diag: Array<F::Scalar, F::Dim>,
    /// Diagonal entries of the linear part; kept so `exp_diag` can be rebuilt
    /// when the step size changes.
    diag: Array<F::Scalar, F::Dim>,
    /// Step size reported by `TimeStep::get_dt`.
    dt: <F::Scalar as Scalar>::Real,
}
15
16impl<F: SemiImplicit> TimeStep for Diagonal<F> {
17    type Time = <F::Scalar as Scalar>::Real;
18
19    fn get_dt(&self) -> Self::Time {
20        self.dt
21    }
22    fn set_dt(&mut self, dt: Self::Time) {
23        Zip::from(&mut self.exp_diag)
24            .and(&self.diag)
25            .for_each(|a, &b| {
26                *a = b.mul_real(dt).exp();
27            });
28    }
29}
30
31impl<F: SemiImplicit> ModelSpec for Diagonal<F> {
32    type Scalar = F::Scalar;
33    type Dim = F::Dim;
34
35    fn model_size(&self) -> <Self::Dim as Dimension>::Pattern {
36        self.exp_diag.dim()
37    }
38}
39
40impl<F: SemiImplicit> TimeEvolution for Diagonal<F> {
41    fn iterate<'a, S>(
42        &mut self,
43        x: &'a mut ArrayBase<S, Self::Dim>,
44    ) -> &'a mut ArrayBase<S, Self::Dim>
45    where
46        S: DataMut<Elem = Self::Scalar>,
47    {
48        for (val, d) in x.iter_mut().zip(self.exp_diag.iter()) {
49            *val *= *d;
50        }
51        x
52    }
53}
54
55impl<F: SemiImplicit> Diagonal<F> {
56    fn new(f: F, dt: <Self as TimeStep>::Time) -> Self {
57        let diag = f.diag();
58        let mut exp_diag = diag.to_owned();
59        for v in exp_diag.iter_mut() {
60            *v = v.mul_real(dt).exp();
61        }
62        Diagonal { exp_diag, diag, dt }
63    }
64}
65
/// Fourth-order Runge-Kutta scheme for semi-implicit models whose linear part
/// is diagonal.
///
/// The linear part is advanced exactly by a [`Diagonal`] propagator and the
/// nonlinear part (the model's `nlin`) through RK4 stages.
#[derive(Debug, Clone)]
pub struct DiagRK4<F: SemiImplicit> {
    /// The wrapped model; supplies the nonlinear term via `nlin`.
    nlin: F,
    /// Exact linear propagator built for half a step (`dt / 2`).
    lin: Diagonal<F>,
    /// Full step size; used for the RK4 stage weights in `iterate`.
    dt: <Diagonal<F> as TimeStep>::Time,
    /// Scratch: copy of the incoming state, later reused to accumulate the update.
    x: Array<F::Scalar, F::Dim>,
    /// Scratch: incoming state advanced by one half-step of the linear flow.
    lx: Array<F::Scalar, F::Dim>,
    /// Scratch: first RK4 stage evaluation of the nonlinear term.
    k1: Array<F::Scalar, F::Dim>,
    /// Scratch: second RK4 stage evaluation of the nonlinear term.
    k2: Array<F::Scalar, F::Dim>,
    /// Scratch: third RK4 stage evaluation of the nonlinear term.
    k3: Array<F::Scalar, F::Dim>,
}
77
78impl<F: SemiImplicit> Scheme for DiagRK4<F> {
79    type Core = F;
80    fn new(nlin: F, dt: Self::Time) -> Self {
81        let lin = Diagonal::new(nlin.clone(), dt / F::Scalar::real(2.0));
82        let x = Array::zeros(lin.model_size());
83        let lx = Array::zeros(lin.model_size());
84        let k1 = Array::zeros(lin.model_size());
85        let k2 = Array::zeros(lin.model_size());
86        let k3 = Array::zeros(lin.model_size());
87        DiagRK4 {
88            nlin,
89            lin,
90            dt,
91            x,
92            lx,
93            k1,
94            k2,
95            k3,
96        }
97    }
98    fn core(&self) -> &Self::Core {
99        &self.nlin
100    }
101    fn core_mut(&mut self) -> &mut Self::Core {
102        &mut self.nlin
103    }
104}
105
106impl<F: SemiImplicit> TimeStep for DiagRK4<F> {
107    type Time = <Diagonal<F> as TimeStep>::Time;
108
109    fn get_dt(&self) -> Self::Time {
110        self.dt
111    }
112
113    fn set_dt(&mut self, dt: Self::Time) {
114        self.lin.set_dt(dt / F::Scalar::real(2.0));
115    }
116}
117
118impl<F: SemiImplicit> ModelSpec for DiagRK4<F> {
119    type Scalar = F::Scalar;
120    type Dim = F::Dim;
121
122    fn model_size(&self) -> <Self::Dim as Dimension>::Pattern {
123        self.nlin.model_size() // TODO check
124    }
125}
126
impl<F: SemiImplicit> TimeEvolution for DiagRK4<F> {
    /// Advance `x` by one step of integrating-factor (Lawson) RK4.
    ///
    /// Let `E = exp(diag * dt / 2)` be the half-step propagator in `self.lin`,
    /// `N` the nonlinear term (`self.nlin.nlin`) and `u` the incoming state.
    /// One step computes:
    ///
    /// ```text
    /// a  = N(u)
    /// b  = N(E (u + dt/2 a))
    /// c  = N(E u + dt/2 b)
    /// d  = N(E (E u + dt c))
    /// u' = E^2 u + dt/6 (E^2 a + 2 E (b + c) + d)
    /// ```
    ///
    /// The result is written into the caller's buffer `x`, which is returned.
    /// Relies on `self.lin` having been built for half of `self.dt`.
    fn iterate<'a, S>(
        &mut self,
        x: &'a mut ArrayBase<S, Self::Dim>,
    ) -> &'a mut ArrayBase<S, Self::Dim>
    where
        S: DataMut<Elem = Self::Scalar>,
    {
        // constants
        let dt = self.dt;
        let dt_2 = self.dt / F::Scalar::real(2.0);
        let dt_3 = self.dt / F::Scalar::real(3.0);
        let dt_6 = self.dt / F::Scalar::real(6.0);
        // operators: `l` applies E (half-step linear flow), `f` provides N
        let l = &mut self.lin;
        let f = &mut self.nlin;
        // calc
        // self.x = u (kept because `x` is overwritten by the stage evaluations)
        self.x.zip_mut_with(x, |buf, x| *buf = *x);
        // self.lx = E u
        self.lx.zip_mut_with(x, |buf, lx| *buf = *lx);
        l.iterate(&mut self.lx);
        // a = N(u). `f.nlin` returns a `&'a mut` tied to its argument, so every
        // k1..k4 below is the caller's buffer `x` reused in place.
        let k1 = f.nlin(x);
        self.k1.zip_mut_with(k1, |buf, k1| *buf = *k1);
        // b = N(E (u + dt/2 a))
        Zip::from(&mut *k1).and(&self.x).for_each(|k1, &x_| {
            *k1 = x_ + k1.mul_real(dt_2);
        });
        let k2 = f.nlin(l.iterate(k1));
        self.k2.zip_mut_with(k2, |buf, k| *buf = *k);
        // c = N(E u + dt/2 b)
        Zip::from(&mut *k2).and(&self.lx).for_each(|k2, &lx| {
            *k2 = lx + k2.mul_real(dt_2);
        });
        let k3 = f.nlin(k2);
        self.k3.zip_mut_with(k3, |buf, k| *buf = *k);
        // d = N(E (E u + dt c))
        Zip::from(&mut *k3).and(&self.lx).for_each(|k3, &lx| {
            *k3 = lx + k3.mul_real(dt);
        });
        let k4 = f.nlin(l.iterate(k3));
        // self.x = E (E (u + dt/6 a) + dt/3 (b + c))
        //        = E^2 u + dt/6 E^2 a + dt/3 E (b + c)
        Zip::from(&mut self.x)
            .and(&self.k1)
            .for_each(|x_, k1_| *x_ += k1_.mul_real(dt_6));
        l.iterate(&mut self.x);
        Zip::from(&mut self.x)
            .and(&self.k2)
            .and(&self.k3)
            .for_each(|x_, &k2_, &k3_| *x_ += (k2_ + k3_).mul_real(dt_3));
        l.iterate(&mut self.x);
        // u' = self.x + dt/6 d, written into the caller's buffer
        Zip::from(&mut *k4).and(&self.x).for_each(|k4, &x_| {
            *k4 = x_ + k4.mul_real(dt_6);
        });
        k4
    }
}