use crate::filtration::ScenarioFiltration;
use crate::proc::{Process, ProcessUniverse};
use crate::rng::BaseRng;
/// Advances the scenario by one step, filling the values at time index `t_idx + 1` from those
/// at `t_idx` with a two-stage (predict, then average) update.
///
/// The statements run in a fixed order, which matters both for the sequence of random draws and
/// because a provisional value is written to the filtration between the two stages:
///
/// 1. A single value from `rng.sample(t_idx, 0)` is mapped to a sign: `1.0` if it is greater
///    than `0.5`, `-1.0` otherwise.
/// 2. Every incrementor of every `Process::Levy` is sampled exactly once; both stages reuse
///    these draws. Non-Lévy processes get an empty list.
/// 3. Stage one: for each Lévy process, the sum of `coefficient(current time) * increment`.
/// 4. Predictor: `start + stage_one + Σ coefficient(current time) * sign * sqrt(dt)`, where the
///    sum runs over Wiener incrementors only, is written to slot `t_idx + 1`.
/// 5. Stage two: the same sum as stage one, with coefficients evaluated at the next time.
/// 6. Corrector: `start + 0.5 * (stage_one + stage_two)` overwrites slot `t_idx + 1` for every
///    index listed in `levy_process_indices`.
/// 7. Each index in `algebraic_process_indices` that refers to a `Process::Algebraic` gets the
///    value of its first coefficient evaluated at the next time.
///
/// # Panics
///
/// Panics if `t_idx + 1` is not a valid index into `filtration.times`, if a Lévy process has
/// fewer coefficients than incrementors, if an algebraic process has no coefficients, or if any
/// coefficient evaluation returns an error/`None` (the results are unwrapped).
pub fn runge_kutta_iteration(
    filtration: &mut ScenarioFiltration,
    process_universe: &ProcessUniverse,
    t_idx: usize,
    rng: &mut dyn BaseRng,
) {
    let process_count = process_universe.processes.len();

    let t_start = filtration.times[t_idx];
    let next_idx = t_idx + 1;
    let t_end = filtration.times[next_idx];
    let step = (t_end - t_start).into_inner();
    let step_root = step.sqrt();

    // NOTE(review): the 0.5 threshold presumably assumes `sample` is uniform on [0, 1) — confirm
    // against the `BaseRng` contract.
    let sign = if rng.sample(t_idx, 0) > 0.5 { 1.0 } else { -1.0 };

    // Draw every increment up front so that both stages see identical noise.
    let mut drawn = Vec::with_capacity(process_count);
    for p_idx in 0..process_count {
        let row = match &process_universe.processes[p_idx] {
            Process::Levy(levy) => {
                let mut samples = Vec::new();
                for incr in levy.incrementors.iter() {
                    samples.push(incr.sample(t_idx, filtration, rng));
                }
                samples
            }
            _ => Vec::new(),
        };
        drawn.push(row);
    }

    // Values at the start of the step, captured before slot `t_idx + 1` is touched.
    let start: Vec<_> = (0..process_count)
        .map(|p_idx| filtration.get(t_idx, p_idx))
        .collect();

    // Stage one: increments weighted by coefficients evaluated at the current time.
    let mut stage_one = vec![0.0; process_count];
    for p_idx in 0..process_count {
        let Process::Levy(levy) = &process_universe.processes[p_idx] else {
            continue;
        };
        for (inc_idx, &delta) in drawn[p_idx].iter().enumerate() {
            let coeff = levy.coefficients[inc_idx]
                .eval(t_start, filtration)
                .unwrap();
            stage_one[p_idx] += coeff * delta;
        }
    }

    // Predictor: provisional value for the next slot, with an extra signed sqrt(dt) term for
    // Wiener-driven components. Presumably stage two's coefficient evaluation reads it back.
    for p_idx in 0..process_count {
        let Process::Levy(levy) = &process_universe.processes[p_idx] else {
            continue;
        };
        let mut perturbation = 0.0;
        for (inc_idx, incr) in levy.incrementors.iter().enumerate() {
            if !incr.is_wiener() {
                continue;
            }
            let coeff = levy.coefficients[inc_idx]
                .eval(t_start, filtration)
                .unwrap();
            perturbation += coeff * sign * step_root;
        }
        let predicted = start[p_idx] + stage_one[p_idx] + perturbation;
        filtration.set(next_idx, p_idx, predicted);
    }

    // Stage two: same increments, coefficients evaluated at the next time.
    let mut stage_two = vec![0.0; process_count];
    for p_idx in 0..process_count {
        let Process::Levy(levy) = &process_universe.processes[p_idx] else {
            continue;
        };
        for (inc_idx, &delta) in drawn[p_idx].iter().enumerate() {
            let coeff = levy.coefficients[inc_idx]
                .eval(t_end, filtration)
                .unwrap();
            stage_two[p_idx] += coeff * delta;
        }
    }

    // Corrector: replace the provisional value with the average of the two stages.
    for p_idx in &process_universe.levy_process_indices {
        let idx = *p_idx;
        let corrected = start[idx] + 0.5 * (stage_one[idx] + stage_two[idx]);
        filtration.set(next_idx, idx, corrected);
    }

    // Algebraic processes are evaluated last, after the Lévy values for the next slot are final.
    for p_idx in &process_universe.algebraic_process_indices {
        let idx = *p_idx;
        let Process::Algebraic(alg) = &process_universe.processes[idx] else {
            continue;
        };
        let value = alg.coefficients[0].eval(t_end, filtration).unwrap();
        filtration.set(next_idx, idx, value);
    }
}