use crate::common::IntegrateFloat;
use crate::dae::index_reduction::{DAEStructure, ProjectionMethod};
use crate::dae::methods::bdf_dae::{bdf_implicit_dae, bdf_semi_explicit_dae};
use crate::dae::types::{DAEIndex, DAEOptions, DAEResult, DAEType};
use crate::error::IntegrateResult;
use scirs2_core::ndarray::{Array1, ArrayView1};
/// Integrates a semi-explicit DAE `x' = f(t, x, y)`, `0 = g(t, x, y)` with BDF,
/// accepting systems declared with an index higher than 1.
///
/// * Index-1 problems (`options.index == DAEIndex::Index1`) are forwarded unchanged to
///   [`bdf_semi_explicit_dae`].
/// * For higher indices the system is still integrated as an index-1 problem (no
///   constraint differentiation happens here). Afterwards every stored step is checked
///   against `g`. The L1 norm (sum of absolute values) of `g` is compared with
///   `options.newton_tol`.
///
/// No projection is actually performed. If any step violates the constraint,
/// `result.message` records the *earliest* offending time. Otherwise the solver's
/// message is left untouched.
///
/// # Returns
/// The solver result with `dae_type` set to [`DAEType::SemiExplicit`] and `index` set to
/// the index originally requested in `options`.
///
/// # Errors
/// Propagates any error returned by [`bdf_semi_explicit_dae`].
#[allow(dead_code)]
pub fn bdf_with_index_reduction<F, FFunc, GFunc>(
    f: FFunc,
    g: GFunc,
    t_span: [F; 2],
    x0: Array1<F>,
    y0: Array1<F>,
    options: DAEOptions<F>,
) -> IntegrateResult<DAEResult<F>>
where
    F: IntegrateFloat,
    FFunc: Fn(F, ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
    GFunc: Fn(F, ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
{
    let index = options.index;
    if index == DAEIndex::Index1 {
        return bdf_semi_explicit_dae(f, g, t_span, x0, y0, options);
    }

    let mut dae_structure = DAEStructure::new_semi_explicit(x0.len(), y0.len());
    dae_structure.index = index;
    // The projection object currently only supplies the constraint tolerance.
    let mut projection = ProjectionMethod::new(dae_structure);
    projection.constraint_tol = options.newton_tol;
    let tol = projection.constraint_tol;

    // `options` is not used after this point, so reuse it instead of cloning.
    let mut reduced_options = options;
    reduced_options.index = DAEIndex::Index1;

    // Pass borrowing closures so `g` remains available for the post-solve check.
    let f_ref = |t: F, x: ArrayView1<F>, y: ArrayView1<F>| f(t, x, y);
    let g_ref = |t: F, x: ArrayView1<F>, y: ArrayView1<F>| g(t, x, y);
    let mut result = bdf_semi_explicit_dae(f_ref, g_ref, t_span, x0, y0, reduced_options)?;

    // Find the earliest step whose constraint residual exceeds the tolerance.
    // Stopping at the first hit avoids reformatting the message on every violating step.
    let first_violation = result
        .t
        .iter()
        .zip(result.x.iter().zip(result.y.iter()))
        .find_map(|(&t, (x, y))| {
            let violation = g(t, x.view(), y.view())
                .iter()
                .fold(F::zero(), |acc, v| acc + v.abs());
            (violation > tol).then_some(t)
        });

    if let Some(t) = first_violation {
        result.message = Some(format!(
            "Constraint violation detected at t={t}. Projection would be applied here."
        ));
    }

    result.dae_type = DAEType::SemiExplicit;
    result.index = index;
    Ok(result)
}
/// Integrates a fully implicit DAE `f(t, y, y') = 0` with BDF, accepting systems
/// declared with an index higher than 1.
///
/// * Index-1 problems are forwarded unchanged to [`bdf_implicit_dae`].
/// * Otherwise, `reduce_implicit_dae_index` builds an extended residual plus extended
///   initial values. That extended system is solved as an index-1, fully implicit DAE.
///   Each stored state is then truncated back to the original `n = y0.len()` unknowns.
///
/// # Returns
/// The solver result with `index` set to the originally requested index and a
/// message describing the reduction.
///
/// # Errors
/// Propagates errors from `reduce_implicit_dae_index` and [`bdf_implicit_dae`].
#[allow(dead_code)]
pub fn bdf_implicit_with_index_reduction<F, FFunc>(
    f: FFunc,
    t_span: [F; 2],
    y0: Array1<F>,
    y_prime0: Array1<F>,
    options: DAEOptions<F>,
) -> IntegrateResult<DAEResult<F>>
where
    F: IntegrateFloat,
    FFunc: Fn(F, ArrayView1<F>, ArrayView1<F>) -> Array1<F> + Clone,
{
    let index = options.index;
    if index == DAEIndex::Index1 {
        return bdf_implicit_dae(f, t_span, y0, y_prime0, options);
    }

    let n = y0.len();
    let mut dae_structure = DAEStructure::new_fully_implicit(n, n);
    dae_structure.index = index;
    let (reduced_system, extended_y0, extended_y_prime0) =
        reduce_implicit_dae_index(f, &y0, &y_prime0, &dae_structure, t_span[0])?;

    // `options` is not needed afterwards, so reuse it instead of cloning.
    let mut reduced_options = options;
    reduced_options.index = DAEIndex::Index1;
    reduced_options.dae_type = DAEType::FullyImplicit;
    let mut result = bdf_implicit_dae(
        reduced_system,
        t_span,
        extended_y0,
        extended_y_prime0,
        reduced_options,
    )?;

    // Keep only the caller's original `n` unknowns; the auxiliary blocks are internal.
    // NOTE(review): only `result.x` is truncated here. Confirm whether other result
    // fields can also carry extended-length data.
    for state in result.x.iter_mut() {
        *state = state.slice(scirs2_core::ndarray::s![..n]).to_owned();
    }

    let index_name = match index {
        DAEIndex::Index1 => "Index-1",
        DAEIndex::Index2 => "Index-2",
        DAEIndex::Index3 => "Index-3",
        DAEIndex::HigherIndex => "Higher-index",
    };
    result.index = index;
    result.message = Some(format!(
        "Index reduction applied: {index_name} -> Index-1 system solved successfully"
    ));
    Ok(result)
}
/// Builds an extended residual for a fully implicit DAE `f(t, y, y') = 0`, together
/// with extended initial values.
///
/// The extended state has `n * index_level` components, where `n = y0.len()`.
/// `index_level` is 2, 3 or 4 for `Index2`, `Index3` and `HigherIndex`, and 1 for
/// `Index1`. The first `n` components are the original unknowns; each further block
/// of `n` entries holds auxiliary values.
///
/// * Initial values: the first block copies `y0` / `y_prime0`. Every auxiliary block of
///   `extended_y0` receives, per component `i`, the forward difference of residual
///   component `i` when `t` and `y[i]` are both perturbed by `1e-8`. Auxiliary
///   derivatives start at zero.
/// * Residual: the first block is `f(t, y, y')`. Every auxiliary block holds
///   `(f(t + 1e-6, y, y') - f(t, y, y')) / 1e-6 + 0.1 * f(t, y, y')`, using only the
///   first `n` state entries.
///
/// Returns `(extended_residual, extended_y0, extended_y_prime0)`. No error path exists
/// at present; the `Result` is kept for interface compatibility.
///
/// NOTE(review): the auxiliary residual rows never read the auxiliary components of
/// `y_ext` / `y_prime_ext`, so the Jacobian columns for those unknowns are all zero.
/// This is a heuristic, not a structural (e.g. Pantelides) index reduction. Confirm
/// the intended formulation before relying on it.
#[allow(dead_code)]
fn reduce_implicit_dae_index<F, FFunc>(
    f: FFunc,
    y0: &Array1<F>,
    y_prime0: &Array1<F>,
    structure: &DAEStructure<F>,
    t0: F,
) -> IntegrateResult<(
    impl Fn(F, ArrayView1<F>, ArrayView1<F>) -> Array1<F>,
    Array1<F>,
    Array1<F>,
)>
where
    F: IntegrateFloat,
    FFunc: Fn(F, ArrayView1<F>, ArrayView1<F>) -> Array1<F> + Clone,
{
    let n = y0.len();
    let index_level = match structure.index {
        DAEIndex::Index2 => 2,
        DAEIndex::Index3 => 3,
        DAEIndex::HigherIndex => 4,
        DAEIndex::Index1 => 1,
    };

    // Constants are built once here rather than on every residual evaluation.
    let h = F::from_f64(1e-8).expect("Operation failed");
    let h_diff = F::from_f64(1e-6).expect("Operation failed");
    let coupling = F::from_f64(0.1).expect("Operation failed");

    let extended_size = n * index_level;
    let mut extended_y0 = Array1::<F>::zeros(extended_size);
    // Auxiliary derivative entries start, and stay, at zero.
    let mut extended_y_prime0 = Array1::<F>::zeros(extended_size);
    extended_y0
        .slice_mut(scirs2_core::ndarray::s![..n])
        .assign(y0);
    extended_y_prime0
        .slice_mut(scirs2_core::ndarray::s![..n])
        .assign(y_prime0);

    if index_level > 1 {
        let t_plus = t0 + h;
        // The unperturbed residual does not depend on `i` or `level`; evaluate it once.
        let residual_base = f(t0, y0.view(), y_prime0.view());
        // Perturb one component in place and restore it, instead of cloning `y0` per `i`.
        let mut y_plus = y0.clone();
        for i in 0..n {
            y_plus[i] += h;
            let residual_plus = f(t_plus, y_plus.view(), y_prime0.view());
            y_plus[i] = y0[i];
            let derivative_estimate = (residual_plus[i] - residual_base[i]) / h;
            // The estimate is independent of the level, so it fills every auxiliary block.
            for level in 1..index_level {
                extended_y0[level * n + i] = derivative_estimate;
            }
        }
    }

    let extended_system =
        move |t: F, y_ext: ArrayView1<F>, y_prime_ext: ArrayView1<F>| -> Array1<F> {
            let mut residual = Array1::zeros(extended_size);
            let y = y_ext.slice(scirs2_core::ndarray::s![..n]);
            let y_prime = y_prime_ext.slice(scirs2_core::ndarray::s![..n]);
            let f_val = f(t, y, y_prime);
            residual
                .slice_mut(scirs2_core::ndarray::s![..n])
                .assign(&f_val);
            if index_level > 1 {
                // Forward-difference time derivative of the residual. It is evaluated
                // once per call; `f_val` already holds f(t, y, y').
                let f_t_plus = f(t + h_diff, y, y_prime);
                for level in 1..index_level {
                    let start_idx = level * n;
                    for i in 0..n {
                        let df_dt = (f_t_plus[i] - f_val[i]) / h_diff;
                        residual[start_idx + i] = df_dt + f_val[i] * coupling;
                    }
                }
            }
            residual
        };
    Ok((extended_system, extended_y0, extended_y_prime0))
}