#![feature(trait_alias)]
#![allow(clippy::needless_return)]
use rayon::prelude::*;
use std::collections::{HashMap, HashSet};
/// Advances a state vector by one fixed `time_step` per call to
/// [`Model::advance`], as `next_state = IRM · current_state`.
///
/// Column `c` of the cached impulse-response matrix (IRM) is the state reached
/// by integrating a unit impulse at point `c` for `time_step`. The matrix is
/// rebuilt lazily: [`Model::touch`] and [`Model::delete`] only record which
/// points changed, and affected columns are recomputed on the next `advance`.
#[derive(Clone, Debug)]
pub struct Model {
    /// Cached impulse-response matrix, stored in CSR form.
    irm: SparseMatrix,
    /// Points marked via `touch` since the last matrix update.
    _touched: HashSet<usize>,
    /// Points marked via `delete` since the last matrix update.
    removed: HashSet<usize>,
    /// Interval covered by one `advance`.
    time_step: f64,
    /// Column entries with magnitude below this are pruned.
    cutoff: f64,
    /// Allowed integration error per unit of simulated time.
    error_tolerance: f64,
    /// Smallest integration sub-step.
    min_dt: f64,
}
/// How many finer sub-steps one coarse integration step is checked against.
/// It is also the growth factor applied to the step after an accepted step.
const SUBDIVIDE: usize = 2;
impl Model {
pub fn new(time_step: f64, error_tolerance: f64, min_dt: f64, cutoff: f64) -> Model {
assert!(time_step > 0.0);
assert!(error_tolerance > 0.0);
assert!(min_dt > 0.0);
assert!(min_dt <= time_step);
assert!(cutoff >= 0.0);
Model {
time_step,
error_tolerance,
min_dt,
cutoff,
_touched: Default::default(),
removed: Default::default(),
irm: Default::default(),
}
}
pub fn len(&self) -> usize {
let mut len = self.irm.len();
for point in self.touched() {
len = len.max(*point + 1);
}
return len;
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn density(&self) -> f64 {
return self.irm.data.len() as f64 / self.irm.len().pow(2) as f64;
}
pub fn touch(&mut self, point: usize) {
self._touched.insert(point);
self.removed.remove(&point);
}
pub fn touched(&self) -> impl Iterator<Item = &usize> {
self._touched.iter()
}
pub fn delete(&mut self, point: usize) {
debug_assert!(point < self.len());
self.removed.insert(point);
self._touched.remove(&point);
}
pub fn advance(
&mut self,
current_state: &[f64],
next_state: &mut [f64],
derivative: impl Derivative,
) {
if !self._touched.is_empty() || !self.removed.is_empty() {
self.update_irm(derivative);
}
self.irm.x_vector(current_state, next_state);
}
fn update_irm(&mut self, derivative: impl Derivative) {
let mut touched = std::mem::take(&mut self._touched);
let removed = std::mem::take(&mut self.removed);
let mut touching_touched = vec![]; for &point in touched.iter().chain(&removed) {
if point < self.irm.len() {
let row_start = self.irm.row_ranges[point];
let row_end = self.irm.row_ranges[point + 1];
for tt in &self.irm.column_indices[row_start..row_end] {
if removed.contains(tt) {
continue;
}
touching_touched.push(*tt);
}
} else {
self.irm.resize(point + 1);
}
}
for &tt in &touching_touched {
touched.insert(tt);
}
let mut results: HashMap<_, _> = touched
.par_iter()
.map(|&point| {
let mut state = SparseVector::new();
state.insert(point, 1.0);
state = self.integrate(state, &derivative);
let mut state: Vec<(usize, f64)> = state.drain().collect();
debug_assert!(state.iter().all(|(idx, _val)| *idx < self.len()));
let mut sum_removed_values = 0.0;
state.retain(|(_, value)| {
if value.abs() < self.cutoff {
sum_removed_values += *value;
return false;
}
return true;
});
sum_removed_values /= state.len() as f64;
for (_, value) in &mut state {
*value += sum_removed_values;
}
return (point, state);
})
.collect();
results.reserve(removed.len());
for &column in removed.iter() {
results.insert(column, Default::default());
}
self.irm.write_columns(&results);
}
#[doc(hidden)]
pub fn integrate(&self, mut state: SparseVector, derivative: &impl Derivative) -> SparseVector {
let mut t = 0.0;
let mut low_res = None;
let min_dt = self.min_dt * SUBDIVIDE as f64;
let mut dt = min_dt;
let mut final_iteration = false;
while t < self.time_step {
if dt > self.time_step - t {
dt = self.time_step - t;
final_iteration = true;
}
if low_res.is_none() {
low_res = Some(Self::crank_nicolson(state.clone(), derivative, dt));
}
let mut first_subdivision_dt = None;
let mut first_subdivision_state = None;
let mut high_res = state.clone();
for i in 0..SUBDIVIDE {
let high_res_dt = dt / SUBDIVIDE as f64;
if !high_res_dt.is_normal() {
panic!("Failed to find time step which satisfies requested accuracy!")
}
high_res = Self::crank_nicolson(high_res, derivative, high_res_dt);
if i == 0 {
first_subdivision_dt = Some(high_res_dt);
first_subdivision_state = Some(high_res.clone());
}
}
let error = max_abs_diff(low_res.as_ref().unwrap(), &high_res);
let error_ok = error <= self.error_tolerance * dt;
if dt <= min_dt || error_ok {
std::mem::swap(&mut state, &mut high_res);
low_res.take();
if final_iteration {
break;
}
debug_assert!(t + dt > t);
t += dt;
if error_ok {
dt *= SUBDIVIDE as f64;
} else if cfg!(debug_assertions) {
eprintln!("Warning, max_integration_steps has compromised the accuracy by a factor of {}!",
error / self.error_tolerance / dt);
}
} else {
dt = first_subdivision_dt.unwrap();
if dt >= min_dt {
low_res.replace(first_subdivision_state.unwrap());
} else {
dt = min_dt;
low_res.take();
}
}
}
return state;
}
fn crank_nicolson(
mut state: SparseVector,
derivative: &impl Derivative,
dt: f64,
) -> SparseVector {
let mut deriv = SparseVector::with_capacity(state.len() + state.len() / 2);
derivative(&state, &mut deriv);
clean_sparse_vector(&mut deriv);
let iterations = 1;
for _ in 0..iterations {
let mut halfway = state.clone();
add_multiply(&mut halfway, &deriv, dt / 2.0);
deriv.clear();
derivative(&halfway, &mut deriv);
clean_sparse_vector(&mut deriv);
}
add_multiply(&mut state, &deriv, dt); clean_sparse_vector(&mut state);
return state;
}
}
/// Sparse vector keyed by point index. Keys that are absent read as zero.
pub type SparseVector = std::collections::HashMap<usize, f64>;
/// Removes explicit zero entries from `x`.
///
/// In debug builds this also panics if any remaining entry is NaN or infinite.
fn clean_sparse_vector(x: &mut SparseVector) {
    x.retain(|_, value| value != &0.0);
    if cfg!(debug_assertions) && x.values().any(|value| !value.is_finite()) {
        panic!("Derivative was not finite!");
    }
}
/// Computes `b += x * a` in place. Entries missing from `b` are treated as zero.
///
/// Every key of `a` ends up present in `b`, even if the result is zero.
fn add_multiply(b: &mut SparseVector, a: &SparseVector, x: f64) {
    for (&idx, &a_value) in a {
        // One hash lookup per entry instead of `get` followed by `insert`.
        *b.entry(idx).or_insert(0.0) += a_value * x;
    }
}
/// Largest absolute difference between `a` and `b` over the union of their
/// keys. Missing entries read as zero, so two empty vectors give `0.0`.
#[doc(hidden)]
pub fn max_abs_diff(a: &SparseVector, b: &SparseVector) -> f64 {
    let from_a = a
        .iter()
        .map(|(idx, a_value)| (a_value - b.get(idx).unwrap_or(&0.0)).abs());
    let from_b = b
        .iter()
        .map(|(idx, b_value)| (b_value - a.get(idx).unwrap_or(&0.0)).abs());
    from_a.chain(from_b).fold(0.0_f64, f64::max)
}
/// One matrix entry in coordinate (COO) form. It is used to merge new columns
/// into a CSR matrix.
#[derive(Debug)]
struct SparseCoordinate {
    /// Row index of the entry.
    row: usize,
    /// Column index of the entry.
    column: usize,
    /// Stored value.
    value: f64,
}
/// Sparse matrix in compressed sparse row (CSR) layout.
///
/// Row `r`'s entries occupy `data[row_ranges[r]..row_ranges[r + 1]]`. Their
/// column indices occupy the same range of `column_indices`.
#[derive(Debug, Clone)]
struct SparseMatrix {
    /// Stored values, grouped by row.
    pub data: Vec<f64>,
    /// Row boundaries into `data`; always one longer than the row count.
    pub row_ranges: Vec<usize>,
    /// Column index of each stored value, parallel to `data`.
    pub column_indices: Vec<usize>,
}
impl Default for SparseMatrix {
fn default() -> SparseMatrix {
SparseMatrix {
data: vec![],
column_indices: vec![],
row_ranges: vec![0],
}
}
}
impl SparseMatrix {
fn len(&self) -> usize {
self.row_ranges.len() - 1
}
fn resize(&mut self, new_size: usize) {
assert!(new_size >= self.len()); self.row_ranges.resize(new_size + 1, self.data.len());
}
fn write_columns(&mut self, csr_data: &HashMap<usize, Vec<(usize, f64)>>) {
let mut coords = Vec::with_capacity(csr_data.values().map(|srv| srv.len()).sum());
for (c_idx, row) in csr_data.iter() {
for (r_idx, value) in row {
coords.push(SparseCoordinate {
row: *r_idx,
column: *c_idx,
value: *value,
});
}
}
coords.par_sort_unstable_by(|a, b| a.row.cmp(&b.row));
let mut insert_iter = coords.iter().peekable();
let mut result = SparseMatrix::default();
let max_new_len = self.data.len() + coords.len();
result.data.reserve(max_new_len);
result.column_indices.reserve(max_new_len);
result.row_ranges.reserve(self.row_ranges.len());
for (row, (row_start, row_end)) in self
.row_ranges
.iter()
.zip(self.row_ranges.iter().skip(1))
.enumerate()
{
for index in *row_start..*row_end {
let column = self.column_indices[index];
if !csr_data.contains_key(&column) {
result.data.push(self.data[index]);
result.column_indices.push(column);
}
}
while insert_iter.peek().is_some() && insert_iter.peek().unwrap().row == row {
let coord = insert_iter.next().unwrap();
result.data.push(coord.value);
result.column_indices.push(coord.column);
}
result.row_ranges.push(result.data.len());
}
std::mem::swap(self, &mut result);
}
fn x_vector(&self, src: &[f64], dst: &mut [f64]) {
assert_eq!(src.len(), self.len(), "src.len() != self.len()");
assert_eq!(dst.len(), self.len(), "dst.len() != self.len()");
dst.par_iter_mut().enumerate().for_each(|(row, dst)| {
let row_start = self.row_ranges[row];
let row_end = self.row_ranges[row + 1];
const V: usize = 4; let mut sums = [0.0; V];
let mut chunk = row_start;
if let Some(row_end_chunk) = row_end.checked_sub(V - 1) {
while chunk < row_end_chunk {
for offset in 0..V {
let index = chunk + offset;
unsafe {
sums[offset] += self.data.get_unchecked(index)
* src.get_unchecked(*self.column_indices.get_unchecked(index));
}
}
chunk += V;
}
}
for index in chunk..row_end {
unsafe {
sums[0] += self.data.get_unchecked(index)
* src.get_unchecked(*self.column_indices.get_unchecked(index));
}
}
*dst = sums.iter().sum();
});
}
}
/// Callback `f(state, out)` that writes the time derivative of `state` into `out`.
///
/// It must be `Sync` because impulse responses are integrated in parallel.
pub trait Derivative = Fn(&SparseVector, &mut SparseVector) + Sync;