/// Advective (CFL) step-size limit: `cfl_number * grid_spacing / |max_velocity|`.
///
/// Only the magnitude of `max_velocity` is used. When that magnitude is below
/// `1e-30` the flow imposes no limit and `f64::MAX` is returned as an
/// "unbounded" sentinel instead of dividing by (almost) zero.
pub fn cfl_timestep(max_velocity: f64, grid_spacing: f64, cfl_number: f64) -> f64 {
    const STAGNANT_SPEED: f64 = 1e-30;
    match max_velocity.abs() {
        speed if speed < STAGNANT_SPEED => f64::MAX,
        speed => cfl_number * grid_spacing / speed,
    }
}
/// Courant step-size limit for a signal travelling at speed `c` on a grid of
/// spacing `dx`: `cfl * dx / |c|`.
///
/// The sign of `c` is ignored. A speed whose magnitude is below `1e-30` is
/// treated as "no constraint" and yields `f64::MAX`.
pub fn courant_dt(dx: f64, c: f64, cfl: f64) -> f64 {
    let wave_speed = c.abs();
    if wave_speed < 1e-30 {
        f64::MAX
    } else {
        cfl * dx / wave_speed
    }
}
/// Diffusive step-size limit: `safety * dx^2 / (2 * |nu|)`.
///
/// Only the magnitude of `nu` matters. When it is below `1e-30` diffusion
/// places no restriction on the step and `f64::MAX` is returned.
pub fn diffusive_dt(dx: f64, nu: f64, safety: f64) -> f64 {
    const NEGLIGIBLE_DIFFUSIVITY: f64 = 1e-30;

    let diffusivity = nu.abs();
    if diffusivity < NEGLIGIBLE_DIFFUSIVITY {
        return f64::MAX;
    }

    let numerator = safety * dx * dx;
    let denominator = 2.0 * diffusivity;
    numerator / denominator
}
/// Proposes a new step size from an error estimate.
///
/// The previous step `dt_old` is scaled by
/// `(tolerance / error)^(1 / (order + 1))`, with the scale factor limited to
/// the range `[0.1, 5.0]` so a single update can neither collapse nor explode
/// the step.
///
/// Special cases:
/// * `error < 1e-30` (including zero and negative values) returns the maximum
///   growth, `dt_old * 5.0`.
/// * If the scale factor evaluates to NaN — e.g. `error` is NaN or `tolerance`
///   is negative — the step is shrunk by the minimum factor, `dt_old * 0.1`,
///   instead of returning a NaN step size.
pub fn richardson_dt(dt_old: f64, error: f64, tolerance: f64, order: u32) -> f64 {
    const MIN_FACTOR: f64 = 0.1;
    const MAX_FACTOR: f64 = 5.0;

    if error < 1e-30 {
        return dt_old * MAX_FACTOR;
    }

    let exponent = 1.0 / (f64::from(order) + 1.0);
    let ratio = (tolerance / error).powf(exponent);

    // `f64::clamp` passes NaN straight through; an unusable estimate is
    // handled like the worst possible error and gets the strongest shrink.
    let factor = if ratio.is_nan() {
        MIN_FACTOR
    } else {
        ratio.clamp(MIN_FACTOR, MAX_FACTOR)
    };
    dt_old * factor
}
/// Static step-size settings for one time domain.
///
/// The values are stored exactly as given; nothing here checks that
/// `min_dt <= max_dt` or that the factors are positive.
#[derive(Debug, Clone)]
pub struct TimeDomainConfig {
    /// Lower bound for the domain's step size.
    pub min_dt: f64,
    /// Upper bound for the domain's step size.
    pub max_dt: f64,
    /// CFL factor associated with the domain.
    pub cfl_factor: f64,
    /// Multiplier applied to a step size to obtain the effective step.
    pub safety_factor: f64,
}

impl TimeDomainConfig {
    /// Bundles the four settings into a configuration without validating them.
    pub fn new(min_dt: f64, max_dt: f64, cfl_factor: f64, safety_factor: f64) -> Self {
        TimeDomainConfig {
            safety_factor,
            cfl_factor,
            max_dt,
            min_dt,
        }
    }
}
/// Mutable per-domain state: the current step size, the latest error
/// estimate, a step counter, and the configuration it was built from.
#[derive(Debug, Clone)]
pub struct TimeDomainState {
    /// Step size currently selected for this domain.
    pub current_dt: f64,
    /// Most recently recorded error estimate (starts at zero).
    pub error_estimate: f64,
    /// Number of steps counted for this domain so far.
    pub step_count: u64,
    /// Settings this state was created from.
    pub config: TimeDomainConfig,
}

impl TimeDomainState {
    /// Creates the initial state for `config`.
    ///
    /// The starting step is `max_dt * safety_factor`; it is not clamped to
    /// `[min_dt, max_dt]`. The error estimate and step counter start at zero.
    pub fn from_config(config: TimeDomainConfig) -> Self {
        let current_dt = config.max_dt * config.safety_factor;
        TimeDomainState {
            current_dt,
            error_estimate: 0.0,
            step_count: 0,
            config,
        }
    }
}
/// One domain's entry in the plan produced by `UnifiedTimeStepper::step_schedule`.
#[derive(Debug, Clone, PartialEq)]
pub struct StepSchedule {
/// Index of the domain within the stepper's `domains` vector.
pub domain_idx: usize,
/// Number of substeps the domain takes during one global step.
pub n_substeps: usize,
/// Size of each substep: the global step divided by `n_substeps`
/// (or the global step itself when `n_substeps` is zero).
pub substep_dt: f64,
}
/// Drives several time domains with one shared global step.
///
/// Each domain contributes an effective step `current_dt * safety_factor`.
/// The global step is the smallest of these, bounded by the domains'
/// `min_dt` / `max_dt` settings. A domain whose effective step is smaller
/// than the global step is assigned more than one substep per global step.
#[derive(Debug, Clone)]
pub struct UnifiedTimeStepper {
    /// Per-domain state, addressed by domain index.
    pub domains: Vec<TimeDomainState>,
    /// Step size used for the next global step.
    pub global_dt: f64,
    /// Simulated time accumulated by `advance_global_time`.
    pub global_time: f64,
    /// Substeps per global step for each domain (parallel to `domains`).
    pub subcycle_ratios: Vec<usize>,
}

impl UnifiedTimeStepper {
    /// Builds a stepper with one domain per configuration.
    ///
    /// The initial global step is the smallest effective step across the
    /// domains; it is not bounded by `min_dt` / `max_dt` until
    /// `compute_global_dt` is called.
    pub fn new(configs: Vec<TimeDomainConfig>) -> Self {
        let domains: Vec<TimeDomainState> = configs
            .into_iter()
            .map(TimeDomainState::from_config)
            .collect();
        let domain_count = domains.len();
        let mut stepper = Self {
            domains,
            global_dt: 0.0,
            global_time: 0.0,
            subcycle_ratios: vec![1; domain_count],
        };
        stepper.global_dt = stepper.raw_global_dt();
        stepper.recompute_subcycle_ratios();
        stepper
    }

    /// Restricts `value` to `[lo, hi]` without panicking.
    ///
    /// Unlike `f64::clamp`, inverted bounds (`lo > hi`) are tolerated: the
    /// upper bound wins, so the result never exceeds `hi`. A NaN `value` is
    /// replaced by the lower bound (subject to `hi`).
    fn bound_dt(value: f64, lo: f64, hi: f64) -> f64 {
        value.max(lo).min(hi)
    }

    /// Smallest effective step over all domains, or `f64::MAX` if there are none.
    fn raw_global_dt(&self) -> f64 {
        self.domains
            .iter()
            .map(|d| d.current_dt * d.config.safety_factor)
            .fold(f64::MAX, f64::min)
    }

    /// Recomputes `subcycle_ratios` from the current `global_dt`.
    ///
    /// A domain whose effective step is at least `global_dt` takes one
    /// substep; otherwise it takes `ceil(global_dt / effective_dt)` substeps
    /// (at least one). With a non-positive `global_dt` every ratio is one.
    fn recompute_subcycle_ratios(&mut self) {
        // Both vectors are public, so callers may have changed `domains`
        // without touching the ratios; bring the lengths back in sync first.
        self.subcycle_ratios.resize(self.domains.len(), 1);

        let global_dt = self.global_dt;
        if global_dt <= 0.0 {
            self.subcycle_ratios.fill(1);
            return;
        }

        for (ratio, domain) in self.subcycle_ratios.iter_mut().zip(&self.domains) {
            let effective_dt = domain.current_dt * domain.config.safety_factor;
            *ratio = if effective_dt >= global_dt {
                1
            } else {
                ((global_dt / effective_dt).ceil() as usize).max(1)
            };
        }
    }

    /// Recomputes, stores, and returns the global step.
    ///
    /// The smallest effective step is raised to the largest `min_dt` and
    /// capped at the smallest `max_dt` among the domains. If those two bounds
    /// contradict each other (the domains' allowed ranges do not overlap) the
    /// cap takes precedence rather than panicking. Subcycle ratios are
    /// refreshed afterwards.
    pub fn compute_global_dt(&mut self) -> f64 {
        let raw_dt = self.raw_global_dt();
        let tightest_min = self
            .domains
            .iter()
            .map(|d| d.config.min_dt)
            .fold(0.0_f64, f64::max);
        let tightest_max = self
            .domains
            .iter()
            .map(|d| d.config.max_dt)
            .fold(f64::MAX, f64::min);
        self.global_dt = Self::bound_dt(raw_dt, tightest_min, tightest_max);
        self.recompute_subcycle_ratios();
        self.global_dt
    }

    /// Refreshes `subcycle_ratios` for the current `global_dt` without
    /// recomputing the global step itself.
    pub fn compute_subcycle_ratios(&mut self) {
        self.recompute_subcycle_ratios();
    }

    /// Advances `global_time` by `global_dt` and adds each domain's substep
    /// count to its `step_count`.
    ///
    /// A domain without a matching ratio entry counts one step. The counter
    /// saturates instead of overflowing.
    pub fn advance_global_time(&mut self) {
        self.global_time += self.global_dt;
        for (i, domain) in self.domains.iter_mut().enumerate() {
            let substeps = self.subcycle_ratios.get(i).copied().unwrap_or(1);
            domain.step_count = domain.step_count.saturating_add(substeps as u64);
        }
    }

    /// Sets a domain's step size, bounded by that domain's `min_dt` and
    /// `max_dt`.
    ///
    /// Returns `false` (and changes nothing) when `domain_idx` is out of
    /// range. Inverted bounds in the configuration do not panic; the upper
    /// bound wins. A NaN `new_dt` is replaced by the lower bound. The global
    /// step and subcycle ratios are not recomputed here.
    pub fn update_domain_dt(&mut self, domain_idx: usize, new_dt: f64) -> bool {
        match self.domains.get_mut(domain_idx) {
            Some(domain) => {
                domain.current_dt =
                    Self::bound_dt(new_dt, domain.config.min_dt, domain.config.max_dt);
                true
            }
            None => false,
        }
    }

    /// Records a domain's latest error estimate.
    ///
    /// Returns `false` (and changes nothing) when `domain_idx` is out of range.
    pub fn update_domain_error(&mut self, domain_idx: usize, error: f64) -> bool {
        match self.domains.get_mut(domain_idx) {
            Some(domain) => {
                domain.error_estimate = error;
                true
            }
            None => false,
        }
    }

    /// Returns one `StepSchedule` per domain for the current global step.
    ///
    /// Each substep is `global_dt / n_substeps`. A ratio of zero (possible
    /// only if the public vector was edited by hand) leaves the substep equal
    /// to `global_dt`; a missing ratio entry is treated as one substep.
    pub fn step_schedule(&self) -> Vec<StepSchedule> {
        (0..self.domains.len())
            .map(|domain_idx| {
                let n_substeps = self.subcycle_ratios.get(domain_idx).copied().unwrap_or(1);
                let substep_dt = if n_substeps > 0 {
                    self.global_dt / n_substeps as f64
                } else {
                    self.global_dt
                };
                StepSchedule {
                    domain_idx,
                    n_substeps,
                    substep_dt,
                }
            })
            .collect()
    }
}
// Unit tests for the step-size helper functions and `UnifiedTimeStepper`.
#[cfg(test)]
mod tests {
use super::*;
// --- cfl_timestep / courant_dt / diffusive_dt ---
// 0.5 * 1.0 / 10.0 = 0.05.
#[test]
fn test_cfl_timestep_basic() {
let dt = cfl_timestep(10.0, 1.0, 0.5);
assert!((dt - 0.05).abs() < 1e-12, "dt = {dt}");
}
// Zero velocity imposes no limit: the sentinel f64::MAX is returned.
#[test]
fn test_cfl_timestep_zero_velocity_returns_max() {
let dt = cfl_timestep(0.0, 1.0, 0.5);
assert_eq!(dt, f64::MAX);
}
#[test]
fn test_courant_dt_basic() {
let dt = courant_dt(0.1, 340.0, 0.8);
let expected = 0.8 * 0.1 / 340.0;
assert!((dt - expected).abs() < 1e-14, "dt = {dt}");
}
#[test]
fn test_courant_dt_zero_speed() {
assert_eq!(courant_dt(0.1, 0.0, 0.8), f64::MAX);
}
#[test]
fn test_diffusive_dt_basic() {
let dt = diffusive_dt(0.01, 1e-3, 0.9);
let expected = 0.9 * 0.01 * 0.01 / (2.0 * 1e-3);
assert!((dt - expected).abs() < 1e-14, "dt = {dt}");
}
#[test]
fn test_diffusive_dt_zero_nu() {
assert_eq!(diffusive_dt(0.01, 0.0, 0.9), f64::MAX);
}
// --- richardson_dt ---
#[test]
fn test_richardson_increases_dt_when_error_below_tolerance() {
let dt_old = 0.01;
let error = 1e-6;
let tolerance = 1e-4;
let dt_new = richardson_dt(dt_old, error, tolerance, 2);
assert!(
dt_new > dt_old,
"dt_new={dt_new} should be > dt_old={dt_old}"
);
}
#[test]
fn test_richardson_decreases_dt_when_error_above_tolerance() {
let dt_old = 0.01;
let error = 1e-2;
let tolerance = 1e-4;
let dt_new = richardson_dt(dt_old, error, tolerance, 2);
assert!(
dt_new < dt_old,
"dt_new={dt_new} should be < dt_old={dt_old}"
);
}
// Zero error takes the early-return path: dt_old times the maximum factor 5.0.
#[test]
fn test_richardson_zero_error_gives_max_growth() {
let dt_old = 0.01;
let dt_new = richardson_dt(dt_old, 0.0, 1e-4, 2);
assert!((dt_new - dt_old * 5.0).abs() < 1e-14);
}
// An error of exactly 1e-30 is not below the zero-error threshold, so the
// ratio path runs and the huge ratio is clamped to the 5.0 growth limit.
#[test]
fn test_richardson_clamped_growth() {
let dt_new = richardson_dt(0.01, 1e-30, 1.0, 1);
assert!((dt_new - 0.05).abs() < 1e-10);
}
// A very large error is clamped to the 0.1 shrink limit: 0.01 * 0.1 = 0.001.
#[test]
fn test_richardson_clamped_shrink() {
let dt_new = richardson_dt(0.01, 1e10, 1e-4, 2);
assert!((dt_new - 0.001).abs() < 1e-10);
}
// --- UnifiedTimeStepper: global dt and subcycle ratios ---
// Repeating the same update is idempotent; global dt is target_dt * safety_factor.
#[test]
fn test_single_domain_dt_converges() {
let cfg = TimeDomainConfig::new(1e-6, 1e-2, 0.5, 0.9);
let mut stepper = UnifiedTimeStepper::new(vec![cfg]);
for _ in 0..20 {
let target_dt = 5e-4;
stepper.update_domain_dt(0, target_dt);
stepper.compute_global_dt();
}
let effective = stepper.global_dt;
let expected = 5e-4 * 0.9;
assert!(
(effective - expected).abs() < 1e-10,
"effective={effective}, expected={expected}"
);
}
// The global dt equals the fast domain's effective dt (1e-4 * 0.9). Neither
// domain's effective dt is below the global dt, so both ratios are 1.
#[test]
fn test_two_domains_subcycling() {
let fast = TimeDomainConfig::new(1e-6, 1e-2, 0.25, 0.9);
let slow = TimeDomainConfig::new(1e-5, 1e-1, 0.5, 0.9);
let mut stepper = UnifiedTimeStepper::new(vec![fast, slow]);
stepper.update_domain_dt(0, 1e-4);
stepper.update_domain_dt(1, 1e-3);
stepper.compute_global_dt();
let g = stepper.global_dt;
let fast_eff = 1e-4 * 0.9;
assert!(
(g - fast_eff).abs() < 1e-12,
"global_dt={g}, expected={fast_eff}"
);
assert_eq!(stepper.subcycle_ratios[0], 1);
assert_eq!(stepper.subcycle_ratios[1], 1);
}
// With safety_factor 1.0 the global dt is simply the smallest current_dt.
#[test]
fn test_global_dt_is_minimum_across_domains() {
let d1 = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let d2 = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let d3 = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let mut stepper = UnifiedTimeStepper::new(vec![d1, d2, d3]);
stepper.update_domain_dt(0, 0.1);
stepper.update_domain_dt(1, 0.05);
stepper.update_domain_dt(2, 0.2);
let g = stepper.compute_global_dt();
assert!((g - 0.05).abs() < 1e-14, "global_dt={g}");
}
// --- UnifiedTimeStepper: step_schedule ---
#[test]
fn test_schedule_single_domain() {
let cfg = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let mut stepper = UnifiedTimeStepper::new(vec![cfg]);
stepper.update_domain_dt(0, 0.01);
stepper.compute_global_dt();
let sched = stepper.step_schedule();
assert_eq!(sched.len(), 1);
assert_eq!(sched[0].domain_idx, 0);
assert_eq!(sched[0].n_substeps, 1);
assert!((sched[0].substep_dt - 0.01).abs() < 1e-14);
}
// Both domains take one substep of the global dt (0.01), including the
// domain whose own dt is larger.
#[test]
fn test_schedule_multi_domain() {
let d1 = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let d2 = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let mut stepper = UnifiedTimeStepper::new(vec![d1, d2]);
stepper.update_domain_dt(0, 0.01);
stepper.update_domain_dt(1, 0.1);
stepper.compute_global_dt();
let sched = stepper.step_schedule();
assert_eq!(sched.len(), 2);
assert_eq!(sched[0].n_substeps, 1);
assert!((sched[0].substep_dt - 0.01).abs() < 1e-14);
assert_eq!(sched[1].n_substeps, 1);
assert!((sched[1].substep_dt - 0.01).abs() < 1e-14);
}
// --- UnifiedTimeStepper: advancing time ---
// One advance adds global_dt to the clock and one step to the domain counter.
#[test]
fn test_advance_global_time() {
let cfg = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let mut stepper = UnifiedTimeStepper::new(vec![cfg]);
stepper.update_domain_dt(0, 0.01);
stepper.compute_global_dt();
let dt = stepper.global_dt;
stepper.advance_global_time();
assert!((stepper.global_time - dt).abs() < 1e-14);
assert_eq!(stepper.domains[0].step_count, 1); }
#[test]
fn test_advance_accumulates() {
let cfg = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let mut stepper = UnifiedTimeStepper::new(vec![cfg]);
stepper.update_domain_dt(0, 0.01);
stepper.compute_global_dt();
for _ in 0..100 {
stepper.advance_global_time();
}
let expected_time = 100.0 * stepper.global_dt;
assert!(
(stepper.global_time - expected_time).abs() < 1e-10,
"time={}, expected={}",
stepper.global_time,
expected_time
);
}
// --- UnifiedTimeStepper: per-domain updates ---
#[test]
fn test_update_domain_error() {
let cfg = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let mut stepper = UnifiedTimeStepper::new(vec![cfg]);
assert!(stepper.update_domain_error(0, 1e-5));
assert!((stepper.domains[0].error_estimate - 1e-5).abs() < 1e-20);
}
// An out-of-range index is reported with `false` rather than a panic.
#[test]
fn test_update_domain_error_invalid_index() {
let cfg = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let mut stepper = UnifiedTimeStepper::new(vec![cfg]);
assert!(!stepper.update_domain_error(5, 1e-5));
}
// Requests above max_dt or below min_dt are clamped to the respective bound.
#[test]
fn test_update_domain_dt_clamps() {
let cfg = TimeDomainConfig::new(1e-6, 1e-2, 0.5, 1.0);
let mut stepper = UnifiedTimeStepper::new(vec![cfg]);
stepper.update_domain_dt(0, 1.0);
assert!((stepper.domains[0].current_dt - 1e-2).abs() < 1e-14);
stepper.update_domain_dt(0, 1e-10);
assert!((stepper.domains[0].current_dt - 1e-6).abs() < 1e-14);
}
#[test]
fn test_update_domain_dt_invalid_index() {
let cfg = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let mut stepper = UnifiedTimeStepper::new(vec![cfg]);
assert!(!stepper.update_domain_dt(99, 0.01));
}
// --- Integration: richardson_dt feeding the stepper ---
// First a large error shrinks the step, then a tiny error is fed back.
// NOTE(review): the final assertion only checks current_dt > min_dt; it does
// not compare against the shrunken value, so it does not strictly prove growth.
#[test]
fn test_richardson_driven_adaptation() {
let cfg = TimeDomainConfig::new(1e-8, 1.0, 0.5, 0.9);
let mut stepper = UnifiedTimeStepper::new(vec![cfg]);
stepper.update_domain_dt(0, 0.01);
let tolerance = 1e-4;
let order = 2_u32;
let error = 1e-2;
let new_dt = richardson_dt(stepper.domains[0].current_dt, error, tolerance, order);
stepper.update_domain_dt(0, new_dt);
stepper.compute_global_dt();
assert!(stepper.global_dt < 0.01 * 0.9, "should have shrunk");
let error2 = 1e-10;
let new_dt2 = richardson_dt(stepper.domains[0].current_dt, error2, tolerance, order);
stepper.update_domain_dt(0, new_dt2);
stepper.compute_global_dt();
assert!(
stepper.domains[0].current_dt > stepper.domains[0].config.min_dt,
"should have grown"
);
}
// --- Edge cases ---
// The global dt follows the smaller domain dt (0.025); both ratios stay 1
// because no domain's effective dt is below the global dt.
#[test]
fn test_subcycling_fast_domain() {
let d0 = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let d1 = TimeDomainConfig::new(1e-6, 1.0, 0.5, 1.0);
let mut stepper = UnifiedTimeStepper::new(vec![d0, d1]);
stepper.update_domain_dt(0, 0.1);
stepper.update_domain_dt(1, 0.025);
stepper.compute_global_dt();
assert_eq!(stepper.subcycle_ratios[0], 1);
assert_eq!(stepper.subcycle_ratios[1], 1);
assert!((stepper.global_dt - 0.025).abs() < 1e-14);
}
// A stepper with no domains is valid and yields an empty schedule.
#[test]
fn test_empty_domains() {
let stepper = UnifiedTimeStepper::new(vec![]);
assert_eq!(stepper.domains.len(), 0);
assert_eq!(stepper.subcycle_ratios.len(), 0);
assert!(stepper.step_schedule().is_empty());
}
#[test]
fn test_identical_domains() {
let c1 = TimeDomainConfig::new(1e-6, 1.0, 0.5, 0.9);
let c2 = TimeDomainConfig::new(1e-6, 1.0, 0.5, 0.9);
let mut stepper = UnifiedTimeStepper::new(vec![c1, c2]);
stepper.update_domain_dt(0, 0.01);
stepper.update_domain_dt(1, 0.01);
stepper.compute_global_dt();
assert_eq!(stepper.subcycle_ratios[0], 1);
assert_eq!(stepper.subcycle_ratios[1], 1);
assert!((stepper.global_dt - 0.01 * 0.9).abs() < 1e-14);
}
// The safety factor 0.01 pushes the effective dt (0.001 * 0.01) below min_dt,
// so compute_global_dt must raise the result back to min_dt.
#[test]
fn test_global_dt_respects_min_dt_bound() {
let c1 = TimeDomainConfig::new(0.001, 1.0, 0.5, 0.01);
let mut stepper = UnifiedTimeStepper::new(vec![c1]);
stepper.update_domain_dt(0, 0.001); let g = stepper.compute_global_dt();
assert!(g >= 0.001, "global_dt={g} should be >= min_dt=0.001");
}
// Only the magnitude of the velocity is used.
#[test]
fn test_negative_velocity_cfl() {
let dt = cfl_timestep(-100.0, 1.0, 0.5);
let expected = 0.5 * 1.0 / 100.0;
assert!((dt - expected).abs() < 1e-14);
}
// Only the magnitude of nu is used: 0.9 * 0.1^2 / (2 * 0.01).
#[test]
fn test_diffusive_dt_negative_nu() {
let dt = diffusive_dt(0.1, -0.01, 0.9);
let expected = 0.9 * 0.01 / 0.02;
assert!((dt - expected).abs() < 1e-14);
}
}