Skip to main content

aeon_tk/solver/
intergrate.rs

1use crate::{
2    image::{ImageMut, ImageRef, ImageShared},
3    kernel::SystemBoundaryConds,
4    mesh::{Function, FunctionBorrowMut, Mesh},
5};
6use datasize::DataSize;
7use rayon::iter::{ParallelBridge, ParallelIterator};
8use reborrow::{Reborrow, ReborrowMut};
9
/// Method to be used for numerical integration of ODEs.
#[derive(Clone, Copy, Debug, Default, DataSize)]
pub enum Method {
    /// First order accurate Euler integration. This is the default method.
    #[default]
    ForwardEuler,
    /// Four-stage Runge-Kutta integration; `Integrator::step` combines the
    /// stage derivatives with weights 1/6, 1/3, 1/3 and 1/6.
    RK4,
    /// Same stages as [`Method::RK4`], after which `Integrator::step` calls
    /// `Mesh::dissipation::<6>` on the result with the contained `f64`.
    ///
    /// NOTE(review): "KO6" presumably stands for sixth-order Kreiss-Oliger
    /// dissipation and the payload for its strength — confirm against `Mesh`.
    RK4KO6(f64),
}
19
/// ODE integrator pairing a numerical [`Method`] with a reusable scratch buffer.
#[derive(Clone, Debug, DataSize)]
pub struct Integrator {
    /// Numerical method used when stepping.
    pub method: Method,
    /// Intermediate data storage. It is cleared and resized by every call to
    /// `step` and `scratch`, so its contents do not survive between calls.
    tmp: Vec<f64>,
}
27
28impl Integrator {
29    /// Constructs a new integrator which is set to use the given method.
30    pub fn new(method: Method) -> Self {
31        Self {
32            method,
33            tmp: Vec::new(),
34        }
35    }
36
37    /// Step the integrator forwards in time.
38    pub fn step<const N: usize, C: SystemBoundaryConds<N> + Sync, F: Function<N> + Sync>(
39        &mut self,
40        mesh: &mut Mesh<N>,
41        order: usize,
42        conditions: C,
43        mut deriv: F,
44        h: f64,
45        mut result: ImageMut,
46    ) -> Result<(), F::Error>
47    where
48        F::Error: Send,
49    {
50        assert!(mesh.num_nodes() == result.num_nodes());
51        let num_channels = result.num_channels();
52
53        // Number of degrees of freedom required to store one system.
54        let dimension = num_channels * result.num_nodes();
55        self.tmp.clear();
56
57        match self.method {
58            Method::ForwardEuler => {
59                // Resize temporary vector to appropriate size
60                self.tmp.resize(dimension, 0.0);
61                // Retrieve reference to tmp `SystemVec`.
62                let mut tmp = ImageMut::from_storage(&mut self.tmp, num_channels);
63
64                // First step
65                Self::copy_from(tmp.rb_mut(), result.rb());
66                mesh.apply(order, conditions.clone(), deriv, tmp.rb_mut())?;
67                Self::fused_multiply_add_assign(result, h, tmp.rb());
68
69                Ok(())
70            }
71            Method::RK4 | Method::RK4KO6(..) => {
72                self.tmp.resize(2 * dimension, 0.0);
73
74                let (tmp1, tmp2) = self.tmp.split_at_mut(dimension);
75                let mut tmp = ImageMut::from_storage(tmp1, num_channels);
76                let mut update = ImageMut::from_storage(tmp2, num_channels);
77
78                mesh.fill_boundary(order, conditions.clone(), result.rb_mut());
79
80                // K1
81                Self::copy_from(tmp.rb_mut(), result.rb());
82                deriv.preprocess(mesh, tmp.rb_mut())?;
83                mesh.apply(
84                    order,
85                    conditions.clone(),
86                    FunctionBorrowMut(&mut deriv),
87                    tmp.rb_mut(),
88                )?;
89                Self::fused_multiply_add_assign(update.rb_mut(), 1. / 6., tmp.rb());
90
91                // K2
92                Self::fused_multiply_add_dest(tmp.rb_mut(), result.rb(), h / 2.0);
93                mesh.fill_boundary(order, conditions.clone(), tmp.rb_mut());
94                deriv.preprocess(mesh, tmp.rb_mut())?;
95                mesh.apply(
96                    order,
97                    conditions.clone(),
98                    FunctionBorrowMut(&mut deriv),
99                    tmp.rb_mut(),
100                )?;
101                Self::fused_multiply_add_assign(update.rb_mut(), 1. / 3., tmp.rb());
102
103                // K3
104                Self::fused_multiply_add_dest(tmp.rb_mut(), result.rb(), h / 2.0);
105                mesh.fill_boundary(order, conditions.clone(), tmp.rb_mut());
106                deriv.preprocess(mesh, tmp.rb_mut())?;
107                mesh.apply(
108                    order,
109                    conditions.clone(),
110                    FunctionBorrowMut(&mut deriv),
111                    tmp.rb_mut(),
112                )?;
113                Self::fused_multiply_add_assign(update.rb_mut(), 1. / 3., tmp.rb());
114
115                // K4
116                Self::fused_multiply_add_dest(tmp.rb_mut(), result.rb(), h);
117                mesh.fill_boundary(order, conditions.clone(), tmp.rb_mut());
118                deriv.preprocess(mesh, tmp.rb_mut())?;
119                mesh.apply(
120                    order,
121                    conditions.clone(),
122                    FunctionBorrowMut(&mut deriv),
123                    tmp.rb_mut(),
124                )?;
125                Self::fused_multiply_add_assign(update.rb_mut(), 1. / 6., tmp.rb());
126
127                // Sum everything
128                Self::fused_multiply_add_assign(result.rb_mut(), h, update.rb());
129
130                if let Method::RK4KO6(diss) = self.method {
131                    mesh.fill_boundary_to_extent(order, 3, conditions.clone(), result.rb_mut());
132                    deriv.preprocess(mesh, result.rb_mut())?;
133                    mesh.dissipation::<6>(diss, result.rb_mut());
134                }
135
136                Ok(())
137            }
138        }
139    }
140
141    fn copy_from(dest: ImageMut, source: ImageRef) {
142        let shared: ImageShared = dest.into();
143        source.channels().par_bridge().for_each(|field| {
144            unsafe { shared.channel_mut(field) }.copy_from_slice(source.channel(field))
145        });
146    }
147
148    /// Performs operation `dest = dest + h * b`
149    fn fused_multiply_add_assign(dest: ImageMut, h: f64, b: ImageRef) {
150        let shared: ImageShared = dest.into();
151        b.channels().par_bridge().for_each(|field| {
152            let dest = unsafe { shared.channel_mut(field) };
153            let src = b.channel(field);
154
155            dest.iter_mut().zip(src).for_each(|(a, b)| *a += h * b);
156        });
157    }
158
159    /// Performs operation `dest = a + h * dest`
160    fn fused_multiply_add_dest(dest: ImageMut, a: ImageRef, h: f64) {
161        let shared: ImageShared = dest.into();
162        a.channels().par_bridge().for_each(|field| {
163            let dest = unsafe { shared.channel_mut(field) };
164            let a = a.channel(field);
165            dest.iter_mut().zip(a.iter()).for_each(|(d, a)| {
166                *d = a + h * *d;
167            });
168        });
169    }
170
171    // Allocates `len` elements using the intergrator's scratch data.
172    pub fn scratch(&mut self, len: usize) -> &mut [f64] {
173        self.tmp.clear();
174        self.tmp.resize(len, 0.0);
175        &mut self.tmp
176    }
177}