use crate::{
image::{ImageMut, ImageRef, ImageShared},
kernel::SystemBoundaryConds,
mesh::{Function, FunctionBorrowMut, Mesh},
};
use datasize::DataSize;
use rayon::iter::{ParallelBridge, ParallelIterator};
use reborrow::{Reborrow, ReborrowMut};
/// Time-stepping scheme selected for an [`Integrator`].
#[derive(Clone, Copy, Debug, Default, DataSize)]
pub enum Method {
/// Single-stage scheme: the supplied function is applied once per step and
/// `h` times the outcome is added to the state. This is the default variant.
#[default]
ForwardEuler,
/// Four-stage scheme using the weights 1/6, 1/3, 1/3, 1/6 and the stage
/// offsets `h/2`, `h/2`, `h`.
RK4,
/// Same stages as [`Method::RK4`]; afterwards the carried `f64` is passed as
/// the first argument to `Mesh::dissipation::<6>`.
/// NOTE(review): presumably a Kreiss-Oliger dissipation strength — confirm.
RK4KO6(f64),
}
/// Explicit time stepper that owns a reusable scratch buffer.
#[derive(Clone, Debug, DataSize)]
pub struct Integrator {
/// Scheme used by `Integrator::step`.
pub method: Method,
/// Scratch storage; cleared and resized by every call to `step` and `scratch`.
tmp: Vec<f64>,
}
impl Integrator {
/// Creates an integrator that steps with `method`.
///
/// The scratch buffer starts out empty; it is sized on demand by `step` and
/// `scratch`.
pub fn new(method: Method) -> Self {
    let tmp = Vec::new();
    Self { method, tmp }
}
/// Advances `result` in place by one step of size `h` using `self.method`.
///
/// `order` and `conditions` are forwarded unchanged to the `mesh` calls made
/// here (`fill_boundary`, `apply`, `fill_boundary_to_extent`). `deriv` is
/// handed to `Mesh::apply`, which is presumably expected to replace the state
/// it receives with that state's time derivative (TODO confirm against
/// `Mesh::apply`).
///
/// The internal scratch buffer is cleared and resized on every call, so any
/// data previously written through `Integrator::scratch` is lost.
///
/// # Errors
///
/// Returns the first error produced by `deriv.preprocess` or `mesh.apply`.
/// On the RK4 paths `result` may already have been touched by
/// `fill_boundary` when such an error is returned.
///
/// # Panics
///
/// Panics if `mesh.num_nodes()` differs from `result.num_nodes()`.
pub fn step<const N: usize, C: SystemBoundaryConds<N> + Sync, F: Function<N> + Sync>(
    &mut self,
    mesh: &mut Mesh<N>,
    order: usize,
    conditions: C,
    mut deriv: F,
    h: f64,
    mut result: ImageMut,
) -> Result<(), F::Error>
where
    F::Error: Send,
{
    assert!(mesh.num_nodes() == result.num_nodes());

    let channel_count = result.num_channels();
    // Number of `f64` values needed to hold one full copy of the system.
    let system_len = channel_count * result.num_nodes();

    // Clearing before the resize below guarantees the scratch space is zeroed.
    self.tmp.clear();

    match self.method {
        Method::ForwardEuler => {
            self.tmp.resize(system_len, 0.0);
            let mut stage = ImageMut::from_storage(&mut self.tmp, channel_count);

            // NOTE(review): unlike the RK4 path, this branch calls neither
            // `fill_boundary` nor `deriv.preprocess` before `apply` — confirm
            // that this is intentional.
            Self::copy_from(stage.rb_mut(), result.rb());
            mesh.apply(order, conditions.clone(), deriv, stage.rb_mut())?;
            Self::fused_multiply_add_assign(result, h, stage.rb());
        }
        Method::RK4 | Method::RK4KO6(..) => {
            // First half holds the current stage, second half the weighted sum
            // of all stage outcomes (starts at zero thanks to the resize).
            self.tmp.resize(2 * system_len, 0.0);
            let (stage_storage, sum_storage) = self.tmp.split_at_mut(system_len);
            let mut stage = ImageMut::from_storage(stage_storage, channel_count);
            let mut weighted_sum = ImageMut::from_storage(sum_storage, channel_count);

            // One entry per stage: how far to move from `result` along the
            // previous stage's outcome to build this stage's input (`None`
            // for the first stage, which starts from `result` itself), and
            // the weight of this stage in the final combination.
            let schedule: [(Option<f64>, f64); 4] = [
                (None, 1. / 6.),
                (Some(h / 2.0), 1. / 3.),
                (Some(h / 2.0), 1. / 3.),
                (Some(h), 1. / 6.),
            ];

            for (advance, weight) in schedule {
                match advance {
                    None => {
                        mesh.fill_boundary(order, conditions.clone(), result.rb_mut());
                        Self::copy_from(stage.rb_mut(), result.rb());
                    }
                    Some(scale) => {
                        // stage = result + scale * (previous stage outcome)
                        Self::fused_multiply_add_dest(stage.rb_mut(), result.rb(), scale);
                        mesh.fill_boundary(order, conditions.clone(), stage.rb_mut());
                    }
                }

                deriv.preprocess(mesh, stage.rb_mut())?;
                mesh.apply(
                    order,
                    conditions.clone(),
                    FunctionBorrowMut(&mut deriv),
                    stage.rb_mut(),
                )?;
                Self::fused_multiply_add_assign(weighted_sum.rb_mut(), weight, stage.rb());
            }

            // result += h * (k1/6 + k2/3 + k3/3 + k4/6)
            Self::fused_multiply_add_assign(result.rb_mut(), h, weighted_sum.rb());

            if let Method::RK4KO6(strength) = self.method {
                mesh.fill_boundary_to_extent(order, 3, conditions.clone(), result.rb_mut());
                deriv.preprocess(mesh, result.rb_mut())?;
                mesh.dissipation::<6>(strength, result.rb_mut());
            }
        }
    }

    Ok(())
}
/// Copies every channel of `source` into the same-indexed channel of `dest`.
///
/// Channels are dispatched to rayon through `par_bridge`. The channel indices
/// come from `source`, so `dest` is expected to provide at least those
/// channels. `copy_from_slice` panics if a pair of channels differs in length.
fn copy_from(dest: ImageMut, source: ImageRef) {
    let target: ImageShared = dest.into();
    source.channels().par_bridge().for_each(|channel| {
        let input = source.channel(channel);
        // SAFETY (assumed, TODO confirm against `ImageShared::channel_mut`):
        // relies on `channels()` yielding each channel once, so no two
        // parallel tasks obtain a mutable view of the same channel.
        let output = unsafe { target.channel_mut(channel) };
        output.copy_from_slice(input);
    });
}
/// Performs `dest += h * b` element-wise, channel by channel.
///
/// Channels are dispatched to rayon through `par_bridge`, using the channel
/// indices of `b`. Within a channel the two sequences are zipped, so the
/// update stops at the shorter of the two. Despite the name, the arithmetic is
/// a separate multiply and add, not `f64::mul_add`.
fn fused_multiply_add_assign(dest: ImageMut, h: f64, b: ImageRef) {
    let target: ImageShared = dest.into();
    b.channels().par_bridge().for_each(|channel| {
        // SAFETY (assumed, TODO confirm against `ImageShared::channel_mut`):
        // relies on `channels()` yielding each channel once, so no two
        // parallel tasks obtain a mutable view of the same channel.
        let accumulator = unsafe { target.channel_mut(channel) };
        let increment = b.channel(channel);
        for (total, term) in accumulator.iter_mut().zip(increment) {
            *total += h * term;
        }
    });
}
/// Performs `dest = a + h * dest` element-wise, channel by channel.
///
/// Channels are dispatched to rayon through `par_bridge`, using the channel
/// indices of `a`. Within a channel the two sequences are zipped, so the
/// update stops at the shorter of the two. Despite the name, the arithmetic is
/// a separate multiply and add, not `f64::mul_add`.
fn fused_multiply_add_dest(dest: ImageMut, a: ImageRef, h: f64) {
    let target: ImageShared = dest.into();
    a.channels().par_bridge().for_each(|channel| {
        let base = a.channel(channel);
        // SAFETY (assumed, TODO confirm against `ImageShared::channel_mut`):
        // relies on `channels()` yielding each channel once, so no two
        // parallel tasks obtain a mutable view of the same channel.
        let output = unsafe { target.channel_mut(channel) };
        for (value, origin) in output.iter_mut().zip(base.iter()) {
            *value = origin + h * *value;
        }
    });
}
/// Returns the internal scratch buffer as a slice of exactly `len` zeros.
///
/// The storage is shared with `step`, which clears and resizes it, so the
/// contents do not survive a subsequent `step` (or `scratch`) call.
pub fn scratch(&mut self, len: usize) -> &mut [f64] {
    let buffer = &mut self.tmp;
    // Clearing first ensures every element, not just newly added ones, is zero.
    buffer.clear();
    buffer.resize(len, 0.0);
    buffer.as_mut_slice()
}
}