/// Type used for time values (current time, step sizes, sample times) in this module.
// NOTE(review): the unit is not stated anywhere; `LITTLE_H_TO_BIG_H = 1.022e-4` matches
// 100 km/s/Mpc expressed in Myr^-1, which suggests times are in megayears — confirm.
type Time = f64;
/// Integrator for the scale factor `a(t)` of a background cosmology, advanced with RK4
/// substeps.
#[derive(Clone)]
pub struct ScaleFactor {
    // Cosmology whose `da/dt` is integrated.
    params: CosmologicalParameters,
    // Current time, scale factor, cached `da/dt` and expansion branch.
    state: ScaleFactorState,
    // Integrator settings (per-substep limit on the change of `ln a`).
    config: IntegratorConfig,
}
/// Factor multiplied by the dimensionless Hubble parameter `h` to obtain the rate that
/// scales `da/dt` in this module.
// NOTE(review): 100 km/s/Mpc ≈ 1.0227e-4 Myr^-1, so this presumably yields H0 in inverse
// megayears (making `Time` values megayears) — confirm against callers.
pub const LITTLE_H_TO_BIG_H: f64 = 1.022e-4;
/// Density parameters and Hubble parameter describing a background cosmology.
///
/// The `omega_*0` fields are the density parameters at `a = 1`; they enter the squared
/// scale-factor derivative as `Ω_r/a² + Ω_m/a + Ω_de·a² + Ω_k`, scaled by
/// `(h * LITTLE_H_TO_BIG_H)²`.
#[derive(Clone, Debug, PartialEq)]
pub struct CosmologicalParameters {
    /// Matter density parameter (contributes `omega_m0 / a`).
    pub omega_m0: f64,
    /// Dark-energy density parameter (contributes `omega_de0 * a^2`).
    pub omega_de0: f64,
    /// Radiation density parameter (contributes `omega_r0 / a^2`).
    pub omega_r0: f64,
    /// Curvature density parameter (contributes a constant term).
    pub omega_k0: f64,
    /// Dark-energy equation-of-state parameter.
    // Not read anywhere yet: the scale-factor derivative treats dark energy as having constant
    // density (the `a^2` term, i.e. w = -1), so the lint is silenced until `w` is wired in.
    #[allow(dead_code)]
    pub w: f64,
    /// Dimensionless Hubble parameter; multiplied by `LITTLE_H_TO_BIG_H` to get the rate.
    pub h: f64,
}
/// Mutable integration state carried between steps.
#[derive(Clone)]
struct ScaleFactorState {
    /// Current time.
    t: f64,
    /// Scale factor at time `t`.
    a: f64,
    /// Cached `da/dt`; refreshed explicitly by the integrator rather than on every read.
    dadt: f64,
    /// Branch selector: `true` multiplies `da/dt` by +1, `false` by -1.
    expanding: bool,
}
/// Step-size controls used by `ScaleFactor::step_forward`.
#[derive(Clone)]
struct IntegratorConfig {
    /// Upper bound on the per-substep change in `ln a`, which `step_forward` estimates as
    /// `da/dt / a * substep_dt` using `da/dt` from the start of the step.
    max_dloga: f64,
}
impl ScaleFactor {
    /// Creates an integrator whose initial scale factor is `1 / (1 + z0)`.
    ///
    /// * `params` - cosmology used to evaluate `da/dt`.
    /// * `z0` - initial redshift; the starting scale factor is `1 / (1 + z0)`.
    /// * `max_dloga` - largest change of `ln a` allowed in one RK4 substep of
    ///   [`ScaleFactor::step_forward`]; must be positive whenever a step needs subdividing.
    /// * `t0` - initial time; `1.0` when `None`.
    ///
    /// The integrator starts on the expanding branch.
    pub fn new(params: CosmologicalParameters, z0: f64, max_dloga: f64, t0: Option<f64>) -> Self {
        let a0 = 1.0 / (1.0 + z0);
        let state = ScaleFactorState {
            a: a0,
            dadt: Self::derivative_associated(a0, &params, true),
            t: t0.unwrap_or(1.0),
            expanding: true,
        };
        Self {
            params,
            state,
            config: IntegratorConfig { max_dloga },
        }
    }

    /// Recomputes the cached `da/dt` from the current scale factor.
    pub fn update_dadt(&mut self) {
        self.state.dadt = self.derivative();
    }

    /// Evaluates `da/dt` at the current scale factor (freshly, not the cached value).
    pub fn derivative(&self) -> f64 {
        Self::derivative_associated(self.state.a, &self.params, self.state.expanding)
    }

    /// Evaluates `da/dt = ±h·LITTLE_H_TO_BIG_H·sqrt(Ω_r/a² + Ω_m/a + Ω_de·a² + Ω_k)`.
    ///
    /// The sign is `+` when `expanding` and `-` otherwise. Dark energy is treated as having
    /// constant density, so `params.w` is not used. The result is NaN when the radicand is
    /// negative.
    fn derivative_associated(a: f64, params: &CosmologicalParameters, expanding: bool) -> f64 {
        let direction = if expanding { 1.0 } else { -1.0 };
        let radicand = params.omega_r0 / a.powi(2)
            + params.omega_m0 / a
            + params.omega_de0 * a.powi(2)
            + params.omega_k0;
        params.h * LITTLE_H_TO_BIG_H * radicand.sqrt() * direction
    }

    /// Advances time by `dt`, integrating `a` with one or more RK4 substeps.
    ///
    /// The step is split into the smallest number of equal substeps for which
    /// `|da/dt / a * substep_dt|`, using `da/dt` from the start of the step, does not exceed
    /// `max_dloga`. The last substep runs to exactly `t + dt`. Afterwards, the cached `da/dt` is
    /// refreshed. Negative `dt` and negative `da/dt` are limited the same way as positive ones.
    ///
    /// # Panics
    /// Panics if the step needs subdividing but `max_dloga` is not positive, or the estimated
    /// change of `ln a` is not finite. No finite substep count can satisfy the bound then.
    pub fn step_forward(&mut self, dt: f64) {
        self.update_dadt();
        let steps = Self::substep_count(
            self.state.dadt / self.state.a,
            dt,
            self.config.max_dloga,
        );
        let substep_dt = dt / steps as f64;
        let final_time = self.state.t + dt;
        for remaining in (0..steps).rev() {
            let is_last = remaining == 0;
            // Ending on `final_time` exactly avoids accumulating rounding drift from `substep_dt`.
            let h = if is_last {
                final_time - self.state.t
            } else {
                substep_dt
            };
            let params = &self.params;
            let expanding = self.state.expanding;
            let f = |_t: Time, a: f64| Self::derivative_associated(a, params, expanding);
            self.state.a = rk4(f, self.state.t, self.state.a, h, Some(self.state.dadt));
            self.state.t = if is_last {
                final_time
            } else {
                self.state.t + substep_dt
            };
            self.update_dadt();
        }
    }

    /// Smallest substep count `n >= 1` with `|dloga_dt * (dt / n)| <= max_dloga`.
    ///
    /// Returns 1 when the bound already holds for the whole step. It also returns 1 when the
    /// comparison is undefined (NaN rate, step or bound), as a plain `>` test would.
    fn substep_count(dloga_dt: f64, dt: f64, max_dloga: f64) -> usize {
        // `abs` so contracting (`da/dt < 0`) and backward (`dt < 0`) steps are limited too.
        let too_large = |n: usize| (dloga_dt * (dt / n as f64)).abs() > max_dloga;
        if !too_large(1) {
            return 1;
        }
        // Without these conditions the search below could never terminate.
        assert!(
            max_dloga > 0.0 && (dloga_dt * dt).is_finite(),
            "cannot split step: |dln(a)| = {} with max_dloga = {}",
            (dloga_dt * dt).abs(),
            max_dloga
        );
        let mut steps: usize = 2;
        while too_large(steps) {
            steps = steps
                .checked_add(1)
                .expect("substep count overflowed usize");
        }
        steps
    }

    /// Current scale factor `a`.
    pub fn get_a(&self) -> f64 {
        self.state.a
    }

    /// Cached `da/dt`, as last refreshed on construction, by stepping, or by
    /// [`ScaleFactor::update_dadt`].
    pub fn get_dadt(&self) -> f64 {
        self.state.dadt
    }

    /// Current time.
    pub fn get_time(&self) -> f64 {
        self.state.t
    }

    /// Repeatedly calls [`ScaleFactor::step_forward`] with `dt` until `terminate` is met.
    ///
    /// It records `(t, a, da/dt)` before the first step and after every step. Stepping
    /// continues while the monitored quantity (`a` or `t`) is strictly below the target, so
    /// the last sample may overshoot it. A NaN value also ends the loop, because NaN never
    /// compares below the target.
    ///
    /// # Panics
    /// Panics if a step fails to increase the monitored quantity, since the loop could then
    /// never finish. Examples are `dt <= 0` with [`Terminate::Time`], or `da/dt == 0` with
    /// [`Terminate::ScaleFactor`].
    pub fn get_time_series(&mut self, dt: f64, terminate: Terminate) -> TimeSeries {
        let (monitored, target): (fn(&Self) -> f64, f64) = match terminate {
            Terminate::ScaleFactor(final_a) => (Self::get_a, final_a),
            Terminate::Time(final_time) => (Self::get_time, final_time),
        };
        let mut t = vec![self.get_time()];
        let mut a = vec![self.get_a()];
        let mut dadt = vec![self.get_dadt()];
        while monitored(self) < target {
            let before = monitored(self);
            self.step_forward(dt);
            let after = monitored(self);
            assert!(
                after > before || after.is_nan(),
                "get_time_series stalled at {} (target {}) with dt = {}; it would never terminate",
                before,
                target,
                dt
            );
            t.push(self.get_time());
            a.push(self.get_a());
            dadt.push(self.get_dadt());
        }
        TimeSeries {
            t: t.into_boxed_slice(),
            a: a.into_boxed_slice(),
            dadt: dadt.into_boxed_slice(),
        }
    }
}
/// Samples produced by `ScaleFactor::get_time_series`.
///
/// As built there, the three slices have equal length, and index `i` of each slice belongs
/// to the same sample.
#[derive(Debug, Clone, PartialEq)]
pub struct TimeSeries {
    /// Sample times.
    pub t: Box<[Time]>,
    /// Scale factor at each sample time.
    pub a: Box<[f64]>,
    /// Cached `da/dt` at each sample time.
    pub dadt: Box<[f64]>,
}
/// Stopping rule for `ScaleFactor::get_time_series`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Terminate {
    /// Keep stepping while the scale factor is below this value.
    ScaleFactor(f64),
    /// Keep stepping while the time is below this value.
    Time(f64),
}
/// Advances `y' = f(t, y)` by one classical fourth-order Runge–Kutta step of size `h`.
///
/// * `f` - right-hand side, called as `f(t, y)`.
/// * `tn`, `yn` - time and value at the start of the step.
/// * `h` - step size.
/// * `derivative` - optional precomputed `f(tn, yn)`. When `Some`, it is used as the first
///   stage slope instead of calling `f`.
///
/// Returns the RK4 approximation of `y(tn + h)`.
pub(crate) fn rk4<F>(
    f: F,
    tn: Time,
    yn: f64,
    h: f64,
    derivative: Option<f64>,
) -> f64
where
    F: Fn(Time, f64) -> f64,
{
    let t_mid = tn + h / 2.0;
    let t_end = tn + h;
    let slope_start = match derivative {
        Some(known) => known,
        None => f(tn, yn),
    };
    let slope_mid_first = f(t_mid, yn + h * slope_start / 2.0);
    let slope_mid_second = f(t_mid, yn + h * slope_mid_first / 2.0);
    let slope_end = f(t_end, yn + h * slope_mid_second);
    let weighted_sum = slope_start + 2.0 * slope_mid_first + 2.0 * slope_mid_second + slope_end;
    yn + h * weighted_sum / 6.0
}
/// Advances the system `y' = f(t, y)` of `N` components by one classical RK4 step of size `h`.
///
/// Same scheme as the scalar `rk4`, applied component-wise. Returns the approximation of
/// `y(tn + h)`.
pub(crate) fn rk4_multi<F, const N: usize>(
    f: F,
    tn: Time,
    yn: [f64; N],
    h: f64,
) -> [f64; N]
where
    F: Fn(Time, [f64; N]) -> [f64; N],
{
    // Stage state `yn + h * slope / divisor`, built component-wise.
    let probe = |slope: &[f64; N], divisor: f64| -> [f64; N] {
        let mut point = yn;
        for (p, s) in point.iter_mut().zip(slope.iter()) {
            *p = *p + h * s / divisor;
        }
        point
    };
    let t_mid = tn + h / 2.0;
    let k1 = f(tn, yn);
    let k2 = f(t_mid, probe(&k1, 2.0));
    let k3 = f(t_mid, probe(&k2, 2.0));
    let k4 = f(tn + h, probe(&k3, 1.0));
    let mut next = yn;
    for (i, y) in next.iter_mut().enumerate() {
        *y = *y + h * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]) / 6.0;
    }
    next
}
/// Checks that the scalar and fixed-size-array RK4 steppers agree and are correct.
#[test]
fn test_rk4_impls_are_consistent() {
    // y' = 2t does not depend on y, and one RK4 step integrates it exactly (∫0^1 2t dt = 1).
    let scalar = rk4(|t: Time, _y: f64| 2.0 * t, 0.0, 0.0, 1.0, None);
    let vector = rk4_multi(|t: Time, _y: [f64; 1]| [2.0 * t], 0.0, [0.0], 1.0);
    assert_eq!(scalar, vector[0]);
    assert_eq!(scalar, 1.0);

    // y' = y exercises the intermediate stage states (yn + h*k/2, yn + h*k), which the
    // y-independent case above never reads.
    let scalar = rk4(|_t: Time, y: f64| y, 0.0, 1.0, 0.1, None);
    let vector = rk4_multi(|_t: Time, y: [f64; 1]| y, 0.0, [1.0], 0.1);
    assert_eq!(scalar, vector[0]);
    assert!((scalar - 0.1f64.exp()).abs() < 1e-6);

    // Supplying the exact first-stage slope must not change the scalar result.
    let hinted = rk4(|_t: Time, y: f64| y, 0.0, 1.0, 0.1, Some(1.0));
    assert_eq!(hinted, scalar);
}