1use crate::{
2 image::{ImageMut, ImageRef, ImageShared},
3 kernel::SystemBoundaryConds,
4 mesh::{Function, FunctionBorrowMut, Mesh},
5};
6use datasize::DataSize;
7use rayon::iter::{ParallelBridge, ParallelIterator};
8use reborrow::{Reborrow, ReborrowMut};
9
10#[derive(Clone, Copy, Debug, Default, DataSize)]
12pub enum Method {
13 #[default]
15 ForwardEuler,
16 RK4,
17 RK4KO6(f64),
18}
19
20#[derive(Clone, Debug, DataSize)]
21pub struct Integrator {
22 pub method: Method,
24 tmp: Vec<f64>,
26}
27
28impl Integrator {
29 pub fn new(method: Method) -> Self {
31 Self {
32 method,
33 tmp: Vec::new(),
34 }
35 }
36
37 pub fn step<const N: usize, C: SystemBoundaryConds<N> + Sync, F: Function<N> + Sync>(
39 &mut self,
40 mesh: &mut Mesh<N>,
41 order: usize,
42 conditions: C,
43 mut deriv: F,
44 h: f64,
45 mut result: ImageMut,
46 ) -> Result<(), F::Error>
47 where
48 F::Error: Send,
49 {
50 assert!(mesh.num_nodes() == result.num_nodes());
51 let num_channels = result.num_channels();
52
53 let dimension = num_channels * result.num_nodes();
55 self.tmp.clear();
56
57 match self.method {
58 Method::ForwardEuler => {
59 self.tmp.resize(dimension, 0.0);
61 let mut tmp = ImageMut::from_storage(&mut self.tmp, num_channels);
63
64 Self::copy_from(tmp.rb_mut(), result.rb());
66 mesh.apply(order, conditions.clone(), deriv, tmp.rb_mut())?;
67 Self::fused_multiply_add_assign(result, h, tmp.rb());
68
69 Ok(())
70 }
71 Method::RK4 | Method::RK4KO6(..) => {
72 self.tmp.resize(2 * dimension, 0.0);
73
74 let (tmp1, tmp2) = self.tmp.split_at_mut(dimension);
75 let mut tmp = ImageMut::from_storage(tmp1, num_channels);
76 let mut update = ImageMut::from_storage(tmp2, num_channels);
77
78 mesh.fill_boundary(order, conditions.clone(), result.rb_mut());
79
80 Self::copy_from(tmp.rb_mut(), result.rb());
82 deriv.preprocess(mesh, tmp.rb_mut())?;
83 mesh.apply(
84 order,
85 conditions.clone(),
86 FunctionBorrowMut(&mut deriv),
87 tmp.rb_mut(),
88 )?;
89 Self::fused_multiply_add_assign(update.rb_mut(), 1. / 6., tmp.rb());
90
91 Self::fused_multiply_add_dest(tmp.rb_mut(), result.rb(), h / 2.0);
93 mesh.fill_boundary(order, conditions.clone(), tmp.rb_mut());
94 deriv.preprocess(mesh, tmp.rb_mut())?;
95 mesh.apply(
96 order,
97 conditions.clone(),
98 FunctionBorrowMut(&mut deriv),
99 tmp.rb_mut(),
100 )?;
101 Self::fused_multiply_add_assign(update.rb_mut(), 1. / 3., tmp.rb());
102
103 Self::fused_multiply_add_dest(tmp.rb_mut(), result.rb(), h / 2.0);
105 mesh.fill_boundary(order, conditions.clone(), tmp.rb_mut());
106 deriv.preprocess(mesh, tmp.rb_mut())?;
107 mesh.apply(
108 order,
109 conditions.clone(),
110 FunctionBorrowMut(&mut deriv),
111 tmp.rb_mut(),
112 )?;
113 Self::fused_multiply_add_assign(update.rb_mut(), 1. / 3., tmp.rb());
114
115 Self::fused_multiply_add_dest(tmp.rb_mut(), result.rb(), h);
117 mesh.fill_boundary(order, conditions.clone(), tmp.rb_mut());
118 deriv.preprocess(mesh, tmp.rb_mut())?;
119 mesh.apply(
120 order,
121 conditions.clone(),
122 FunctionBorrowMut(&mut deriv),
123 tmp.rb_mut(),
124 )?;
125 Self::fused_multiply_add_assign(update.rb_mut(), 1. / 6., tmp.rb());
126
127 Self::fused_multiply_add_assign(result.rb_mut(), h, update.rb());
129
130 if let Method::RK4KO6(diss) = self.method {
131 mesh.fill_boundary_to_extent(order, 3, conditions.clone(), result.rb_mut());
132 deriv.preprocess(mesh, result.rb_mut())?;
133 mesh.dissipation::<6>(diss, result.rb_mut());
134 }
135
136 Ok(())
137 }
138 }
139 }
140
141 fn copy_from(dest: ImageMut, source: ImageRef) {
142 let shared: ImageShared = dest.into();
143 source.channels().par_bridge().for_each(|field| {
144 unsafe { shared.channel_mut(field) }.copy_from_slice(source.channel(field))
145 });
146 }
147
148 fn fused_multiply_add_assign(dest: ImageMut, h: f64, b: ImageRef) {
150 let shared: ImageShared = dest.into();
151 b.channels().par_bridge().for_each(|field| {
152 let dest = unsafe { shared.channel_mut(field) };
153 let src = b.channel(field);
154
155 dest.iter_mut().zip(src).for_each(|(a, b)| *a += h * b);
156 });
157 }
158
159 fn fused_multiply_add_dest(dest: ImageMut, a: ImageRef, h: f64) {
161 let shared: ImageShared = dest.into();
162 a.channels().par_bridge().for_each(|field| {
163 let dest = unsafe { shared.channel_mut(field) };
164 let a = a.channel(field);
165 dest.iter_mut().zip(a.iter()).for_each(|(d, a)| {
166 *d = a + h * *d;
167 });
168 });
169 }
170
171 pub fn scratch(&mut self, len: usize) -> &mut [f64] {
173 self.tmp.clear();
174 self.tmp.resize(len, 0.0);
175 &mut self.tmp
176 }
177}