use crate::OdeSystem;
use faer::{ComplexField, Conjugate, SimpleEntity};
use numra_core::Scalar;
use numra_linalg::{DenseMatrix, LUFactorization};
/// Computes consistent initial conditions for `system` at `t0`, starting from `y0`.
///
/// Convenience wrapper around [`compute_consistent_initial_tol`] that uses a
/// residual tolerance of `1e-10` and at most 50 Newton iterations.
///
/// # Errors
/// Propagates any error reported by [`compute_consistent_initial_tol`].
pub fn compute_consistent_initial<S, Sys>(system: &Sys, t0: S, y0: &[S]) -> Result<Vec<S>, String>
where
    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
    Sys: OdeSystem<S>,
{
    const DEFAULT_TOL: f64 = 1e-10;
    const DEFAULT_MAX_ITER: usize = 50;

    let tolerance = S::from_f64(DEFAULT_TOL);
    compute_consistent_initial_tol(system, t0, y0, tolerance, DEFAULT_MAX_ITER)
}
/// Makes `y0` consistent with the algebraic constraints of a DAE by solving for
/// the algebraic components with Newton's method.
///
/// For every index `i` returned by `system.algebraic_indices()`, the constraint
/// is `f_i(t0, y) = 0`, where `f` is evaluated with `system.rhs`. Only those same
/// components of `y` are updated; every other component is returned unchanged.
/// The same index set is used both for the constraint rows and for the
/// unknowns, i.e. the system is assumed to be in semi-explicit form.
///
/// The Jacobian of the constraints with respect to the algebraic variables is
/// approximated by forward differences and solved with an LU factorization.
///
/// If the system does not report a singular mass matrix, or reports no
/// algebraic indices, `y0` is returned unchanged.
///
/// # Arguments
/// * `system` - the DAE/ODE system providing `rhs` and the algebraic index set.
/// * `t0` - time at which the constraints are evaluated.
/// * `y0` - initial guess for the full state vector.
/// * `tol` - convergence threshold on the max-norm of the algebraic residuals
///   (convergence requires `residual < tol`).
/// * `max_iter` - maximum number of Newton updates. The residual produced by
///   the last update is still checked, so `max_iter = 0` only validates `y0`.
///
/// # Errors
/// Returns `Err` when:
/// * an algebraic index is out of range for `y0`;
/// * an algebraic residual evaluates to NaN;
/// * the constraint Jacobian cannot be factorized or solved (which may
///   indicate a DAE of index 2 or higher);
/// * the residual is still not below `tol` after `max_iter` updates.
pub fn compute_consistent_initial_tol<S, Sys>(
    system: &Sys,
    t0: S,
    y0: &[S],
    tol: S,
    max_iter: usize,
) -> Result<Vec<S>, String>
where
    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
    Sys: OdeSystem<S>,
{
    if !system.is_singular_mass() {
        return Ok(y0.to_vec());
    }
    let alg_indices = system.algebraic_indices();
    if alg_indices.is_empty() {
        return Ok(y0.to_vec());
    }

    let dim = y0.len();
    // Report a bad index as an error rather than panicking on indexing below.
    if let Some(&bad) = alg_indices.iter().find(|&&i| i >= dim) {
        return Err(format!(
            "Algebraic index {} is out of range for a state of dimension {}",
            bad, dim
        ));
    }

    let n_alg = alg_indices.len();
    let mut y = y0.to_vec();
    let mut f = vec![S::ZERO; dim];
    let mut f_pert = vec![S::ZERO; dim];
    // Work buffers reused across iterations; every entry is overwritten each time.
    let mut jac_data = vec![S::ZERO; n_alg * n_alg];
    let mut rhs_alg = vec![S::ZERO; n_alg];
    let h_factor = S::EPSILON.sqrt();
    let mut last_residual = S::ZERO;

    // `max_iter` Newton updates; the extra pass (iter == max_iter) only checks
    // the residual produced by the final update.
    for iter in 0..=max_iter {
        system.rhs(t0, &y, &mut f);

        // NaN is the only value that is unordered with itself. Catch it
        // explicitly, because a `max`-style fold may silently drop it.
        if let Some(&bad) = alg_indices
            .iter()
            .find(|&&i| f[i].partial_cmp(&f[i]).is_none())
        {
            return Err(format!(
                "DAE constraint residual for component {} is NaN at iteration {}",
                bad, iter
            ));
        }

        let residual = alg_indices
            .iter()
            .fold(S::ZERO, |acc, &i| acc.max(f[i].abs()));
        if residual < tol {
            return Ok(y);
        }
        last_residual = residual;
        if iter == max_iter {
            break;
        }

        // Forward-difference Jacobian d f_alg / d y_alg, stored row-major.
        for (col, &j) in alg_indices.iter().enumerate() {
            let y_orig = y[j];
            let y_step = y_orig + h_factor * (S::ONE + y_orig.abs());
            // Divide by the perturbation actually applied after rounding,
            // not by the nominal step size.
            let h = y_step - y_orig;
            y[j] = y_step;
            system.rhs(t0, &y, &mut f_pert);
            y[j] = y_orig;
            for (row, &i) in alg_indices.iter().enumerate() {
                jac_data[row * n_alg + col] = (f_pert[i] - f[i]) / h;
            }
        }

        for (slot, &i) in rhs_alg.iter_mut().zip(alg_indices.iter()) {
            *slot = -f[i];
        }

        let jac_mat = DenseMatrix::from_row_major(n_alg, n_alg, &jac_data);
        let lu = LUFactorization::new(&jac_mat).map_err(|e| {
            format!(
                "DAE Jacobian factorization failed at iteration {}: {}. \
                 System may be index-2 or higher.",
                iter, e
            )
        })?;
        let delta = lu.solve(&rhs_alg).map_err(|e| {
            format!(
                "DAE Jacobian solve failed at iteration {}: {}. \
                 Jacobian may be singular; system may be index-2 or higher.",
                iter, e
            )
        })?;

        for (idx, &j) in alg_indices.iter().enumerate() {
            y[j] = y[j] + delta[idx];
        }
    }

    Err(format!(
        "Failed to find consistent initial conditions after {} iterations \
         (residual {} still >= {})",
        max_iter,
        last_residual.to_f64(),
        tol.to_f64()
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DaeProblem;

    /// Fills the leading `n * n` entries of a row-major mass matrix so that only
    /// the first state is differential (`M[0][0] = 1`) and all others are algebraic.
    fn first_state_differential_mass(mass: &mut [f64], n: usize) {
        mass[..n * n].fill(0.0);
        mass[0] = 1.0;
    }

    /// `y1' = -y1` together with the algebraic constraint `0 = y2 - y1^2`.
    fn quadratic_constraint_rhs(_t: f64, y: &[f64], dydt: &mut [f64]) {
        dydt[0] = -y[0];
        dydt[1] = y[1] - y[0] * y[0];
    }

    #[test]
    fn test_dae_consistent_init() {
        let dae = DaeProblem::new(
            quadratic_constraint_rhs,
            |mass: &mut [f64]| first_state_differential_mass(mass, 2),
            0.0,
            1.0,
            vec![2.0, 0.0],
            vec![1],
        );
        let state = compute_consistent_initial(&dae, 0.0, &[2.0, 0.0]).unwrap();

        let violation = state[1] - state[0].powi(2);
        assert!(
            violation.abs() < 1e-10,
            "Constraint not satisfied: {}",
            violation
        );
        assert!((state[0] - 2.0).abs() < 1e-10, "y1 should be unchanged");
        assert!(
            (state[1] - 4.0).abs() < 1e-10,
            "y2 should be 4.0, got {}",
            state[1]
        );
    }

    #[test]
    fn test_dae_consistent_init_already_consistent() {
        let dae = DaeProblem::new(
            quadratic_constraint_rhs,
            |mass: &mut [f64]| first_state_differential_mass(mass, 2),
            0.0,
            1.0,
            vec![2.0, 4.0],
            vec![1],
        );
        let state = compute_consistent_initial(&dae, 0.0, &[2.0, 4.0]).unwrap();

        assert!((state[0] - 2.0).abs() < 1e-10);
        assert!((state[1] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_dae_consistent_init_ode_passthrough() {
        use crate::OdeProblem;

        let ode = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        let state = compute_consistent_initial(&ode, 0.0, &[5.0]).unwrap();

        assert!((state[0] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_dae_consistent_init_multiple_algebraic() {
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
                dydt[1] = y[1] - y[0] * y[0];
                dydt[2] = y[2] - y[0].powi(3);
            },
            |mass: &mut [f64]| first_state_differential_mass(mass, 3),
            0.0,
            1.0,
            vec![3.0, 0.0, 0.0],
            vec![1, 2],
        );
        let state = compute_consistent_initial(&dae, 0.0, &[3.0, 0.0, 0.0]).unwrap();

        assert!((state[0] - 3.0).abs() < 1e-10, "y1 unchanged");
        assert!((state[1] - 9.0).abs() < 1e-10, "y2 = y1^2 = 9");
        assert!((state[2] - 27.0).abs() < 1e-10, "y3 = y1^3 = 27");
    }

    #[test]
    fn test_dae_consistent_init_nonlinear_constraint() {
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = 1.0;
                dydt[1] = y[1] - y[0].sin();
            },
            |mass: &mut [f64]| first_state_differential_mass(mass, 2),
            0.0,
            1.0,
            vec![1.0, 0.0],
            vec![1],
        );
        let state = compute_consistent_initial(&dae, 0.0, &[1.0, 0.0]).unwrap();
        let target = 1.0_f64.sin();

        assert!((state[0] - 1.0).abs() < 1e-10, "y1 unchanged");
        assert!(
            (state[1] - target).abs() < 1e-10,
            "y2 = sin(y1), got {} expected {}",
            state[1],
            target
        );
    }

    #[test]
    fn test_dae_consistent_init_with_tolerance() {
        let dae = DaeProblem::new(
            quadratic_constraint_rhs,
            |mass: &mut [f64]| first_state_differential_mass(mass, 2),
            0.0,
            1.0,
            vec![2.0, 0.0],
            vec![1],
        );
        let state =
            compute_consistent_initial_tol(&dae, 0.0, &[2.0, 0.0], 1e-12, 100).unwrap();

        let violation = state[1] - state[0].powi(2);
        assert!(violation.abs() < 1e-12, "Should satisfy tighter tolerance");
    }

    #[test]
    fn test_dae_init_coupled_constraints() {
        // Two linear constraints coupled through y2 and y3:
        //   0 = y2 + y3 - 1
        //   0 = y2 - y3
        // whose unique solution is y2 = y3 = 0.5.
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
                dydt[1] = y[1] + y[2] - 1.0;
                dydt[2] = y[1] - y[2];
            },
            |mass: &mut [f64]| first_state_differential_mass(mass, 3),
            0.0,
            1.0,
            vec![1.0, 0.7, 0.2],
            vec![1, 2],
        );
        let state = compute_consistent_initial(&dae, 0.0, &[1.0, 0.7, 0.2]).unwrap();

        assert!((state[0] - 1.0).abs() < 1e-10, "y1 unchanged");
        assert!(
            (state[1] - 0.5).abs() < 1e-8,
            "y2 should be 0.5, got {}",
            state[1]
        );
        assert!(
            (state[2] - 0.5).abs() < 1e-8,
            "y3 should be 0.5, got {}",
            state[2]
        );

        let sum_constraint = state[1] + state[2] - 1.0;
        let diff_constraint = state[1] - state[2];
        assert!(
            sum_constraint.abs() < 1e-10,
            "Constraint 1 violated: {}",
            sum_constraint
        );
        assert!(
            diff_constraint.abs() < 1e-10,
            "Constraint 2 violated: {}",
            diff_constraint
        );
    }
}