use crate::common::IntegrateFloat;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
/// Approximates a banded Jacobian `∂f/∂y` at `(t, y)` with forward differences.
///
/// Band convention: `lower` is the number of sub-diagonals and `upper` the number of
/// super-diagonals, i.e. `J[i][j]` may be non-zero only when
/// `i - lower <= j <= i + upper`. Equivalently, column `j` can only be non-zero in
/// rows `j - upper ..= j + lower` (clamped to the matrix), which is the window filled
/// here. This is the same convention as LAPACK's KL/KU and SciPy's `lband`/`uband`.
///
/// Each column `j` costs one evaluation of `f`, with `y[j]` shifted by
/// `1e-8 * max(1 + |y[j]|, 1)`. Entries outside the band are left at zero.
///
/// # Arguments
/// * `f` - right-hand side, evaluated as `f(t, y)`.
/// * `t` - time at which the Jacobian is approximated.
/// * `y` - state at which the Jacobian is approximated.
/// * `f_current` - `f(t, y)`, supplied by the caller to avoid re-evaluating it.
/// * `lower`, `upper` - number of sub- and super-diagonals of the band.
///
/// # Returns
/// An `n × n` matrix (`n = y.len()`) holding the difference quotients inside the band.
#[allow(dead_code)]
pub fn compute_banded_jacobian<F, Func>(
    f: &Func,
    t: F,
    y: &Array1<F>,
    f_current: &Array1<F>,
    lower: usize,
    upper: usize,
) -> Array2<F>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let n = y.len();
    let mut jac = Array2::<F>::zeros((n, n));
    let eps = F::from_f64(1e-8).expect("Operation failed");
    // One work buffer for all columns: perturb a single entry, evaluate, then restore
    // it exactly by assignment, so `f` sees the same inputs as with a fresh clone.
    let mut y_work = y.clone();
    for j in 0..n {
        // Column j is structurally non-zero only in rows j-upper ..= j+lower.
        // Saturating arithmetic keeps very wide bands (e.g. usize::MAX) from overflowing.
        let row_start = j.saturating_sub(upper);
        let row_end = j.saturating_add(lower).saturating_add(1).min(n);

        let perturbation = eps * (F::one() + y[j].abs()).max(F::one());
        y_work[j] += perturbation;
        let f_perturbed = f(t, y_work.view());
        y_work[j] = y[j];

        for i in row_start..row_end {
            jac[[i, j]] = (f_perturbed[i] - f_current[i]) / perturbation;
        }
    }
    jac
}
/// Approximates a block-diagonal Jacobian `∂f/∂y` at `(t, y)` with forward differences.
///
/// The state is split into consecutive blocks of `block_size` components; the last
/// block may be shorter. For each column `j`, only the rows belonging to `j`'s own
/// block are filled. All other entries are left at zero.
///
/// Each column costs one evaluation of `f`, with `y[j]` shifted by
/// `1e-8 * max(1 + |y[j]|, 1)`.
///
/// # Arguments
/// * `f` - right-hand side, evaluated as `f(t, y)`.
/// * `t` - time at which the Jacobian is approximated.
/// * `y` - state at which the Jacobian is approximated.
/// * `f_current` - `f(t, y)`, supplied by the caller.
/// * `block_size` - size of each diagonal block; `1` gives a purely diagonal Jacobian.
///
/// # Returns
/// An `n × n` matrix (`n = y.len()`).
///
/// # Panics
/// Panics if `block_size == 0` while `y` is non-empty.
#[allow(dead_code)]
pub fn compute_diagonal_jacobian<F, Func>(
    f: &Func,
    t: F,
    y: &Array1<F>,
    f_current: &Array1<F>,
    block_size: usize,
) -> Array2<F>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let n = y.len();
    assert!(
        n == 0 || block_size > 0,
        "compute_diagonal_jacobian: block_size must be positive"
    );
    let mut jac = Array2::<F>::zeros((n, n));
    let eps = F::from_f64(1e-8).expect("Operation failed");
    // Reuse a single buffer instead of cloning `y` for every column; the perturbed
    // entry is restored by assignment, so each call to `f` sees exactly `y` + one shift.
    let mut y_work = y.clone();
    for j in 0..n {
        let block_start = (j / block_size) * block_size;
        let block_end = (block_start + block_size).min(n);

        let perturbation = eps * (F::one() + y[j].abs()).max(F::one());
        y_work[j] += perturbation;
        let f_perturbed = f(t, y_work.view());
        y_work[j] = y[j];

        for i in block_start..block_end {
            jac[[i, j]] = (f_perturbed[i] - f_current[i]) / perturbation;
        }
    }
    jac
}
/// Approximates `∂f/∂y` at `(t, y)` with forward differences, perturbing every column
/// that shares a color simultaneously.
///
/// `coloring[j]` is the color of column `j`. For each color that is actually used,
/// `f` is evaluated once, in ascending color order, with all columns of that color
/// shifted by `1e-8 * max(1 + |y[j]|, 1)`.
///
/// NOTE(review): every column of a color group receives the *full* difference column
/// `(f_perturbed - f_current) / h_j` in all `n` rows. If a row depends on several
/// columns of the same color, each of those columns gets their combined change. The
/// result is therefore only a valid Jacobian when each color holds a single column,
/// or when the caller discards entries outside a known sparsity pattern for which the
/// coloring is structurally orthogonal. This function has no access to that pattern.
///
/// # Arguments
/// * `f` - right-hand side, evaluated as `f(t, y)`.
/// * `t` - time at which the Jacobian is approximated.
/// * `y` - state at which the Jacobian is approximated.
/// * `f_current` - `f(t, y)`, supplied by the caller.
/// * `coloring` - color of each column; entries beyond `y.len()` are ignored.
///
/// # Returns
/// An `n × n` matrix (`n = y.len()`).
///
/// # Panics
/// Panics if `coloring.len() < y.len()`.
#[allow(dead_code)]
pub fn compute_colored_jacobian<F, Func>(
    f: &Func,
    t: F,
    y: &Array1<F>,
    f_current: &Array1<F>,
    coloring: &[usize],
) -> Array2<F>
where
    F: IntegrateFloat,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let n = y.len();
    assert!(
        coloring.len() >= n,
        "compute_colored_jacobian: coloring must assign a color to every component of y"
    );
    let mut jac = Array2::<F>::zeros((n, n));
    let eps = F::from_f64(1e-8).expect("Operation failed");

    // Bucket columns by color once (O(n)) instead of rescanning all columns for every
    // color id. A BTreeMap keeps ascending color order, and unused color ids cost no
    // evaluation of `f`, regardless of how large the ids are.
    let mut groups: std::collections::BTreeMap<usize, Vec<usize>> =
        std::collections::BTreeMap::new();
    for (j, &color) in coloring[..n].iter().enumerate() {
        groups.entry(color).or_default().push(j);
    }

    let mut y_work = y.clone();
    for columns in groups.values() {
        // The step is always positive: 1 + |y| >= 1, and `max` returns 1 for NaN.
        let steps: Vec<F> = columns
            .iter()
            .map(|&j| eps * (F::one() + y[j].abs()).max(F::one()))
            .collect();

        for (&j, &h) in columns.iter().zip(&steps) {
            y_work[j] += h;
        }
        let f_perturbed = f(t, y_work.view());
        // Restore by assignment so the next group starts from the exact original state.
        for &j in columns {
            y_work[j] = y[j];
        }

        for (&j, &h) in columns.iter().zip(&steps) {
            for i in 0..n {
                jac[[i, j]] = (f_perturbed[i] - f_current[i]) / h;
            }
        }
    }
    jac
}
/// Builds a column coloring for a banded matrix with `lower` sub-diagonals and
/// `upper` super-diagonals.
///
/// Column `j` gets color `j % (lower + upper + 1)`, so two columns share a color only
/// when they are at least `lower + upper + 1` apart. Such columns never have a non-zero
/// in the same row of the band. The result has `n` entries, and colors lie in
/// `0..lower + upper + 1`.
#[allow(dead_code)]
pub fn generate_banded_coloring(n: usize, lower: usize, upper: usize) -> Vec<usize> {
    let period = lower + upper + 1;
    (0..n).map(|column| column % period).collect()
}
/// Applies Broyden's rank-one ("good" Broyden) secant update to `jac` in place:
///
/// `J ← J + (Δf − J Δy) Δyᵀ / (Δyᵀ Δy)`
///
/// The update is skipped when `‖Δy‖² <= 1e-14` (or is NaN), since such a step carries
/// no usable secant information.
///
/// # Arguments
/// * `jac` - current Jacobian approximation, assumed to be at least `n × n`
///   with `n = delta_y.len()`; updated in place.
/// * `delta_y` - step in the state.
/// * `delta_f` - corresponding change in the right-hand side.
///
/// # Panics
/// Panics if `delta_f.len() != delta_y.len()`, or (via indexing) if `jac` is smaller
/// than `n × n` when an update is performed.
#[allow(dead_code)]
pub fn broyden_update<F>(jac: &mut Array2<F>, delta_y: &Array1<F>, delta_f: &Array1<F>)
where
    F: IntegrateFloat,
{
    let n = delta_y.len();
    assert_eq!(
        delta_f.len(),
        n,
        "broyden_update: delta_f and delta_y must have the same length"
    );
    let dy_norm_squared = delta_y.iter().map(|&x| x * x).sum::<F>();
    // Check the step size first so the O(n²) product J·Δy is only formed when used.
    if dy_norm_squared > F::from_f64(1e-14).expect("Operation failed") {
        // Row i of the update depends only on row i of J, so each row can be
        // corrected in place right after computing its own (J·Δy)_i.
        for i in 0..n {
            let mut jac_dy_i = F::zero();
            for j in 0..n {
                jac_dy_i += jac[[i, j]] * delta_y[j];
            }
            let correction = delta_f[i] - jac_dy_i;
            for j in 0..n {
                jac[[i, j]] += correction * delta_y[j] / dy_norm_squared;
            }
        }
    }
}
/// Applies a block-diagonal Broyden (secant) update to `jac` in place.
///
/// `delta_y` is split into consecutive blocks of `block_size` components; the last
/// block may be shorter. For each block `B` with `‖Δy_B‖² > 1e-14`, only the diagonal
/// sub-block is updated:
///
/// `J_BB ← J_BB + (Δf_B − J_BB Δy_B) Δy_Bᵀ / (Δy_Bᵀ Δy_B)`
///
/// Entries outside the diagonal blocks are neither read nor modified.
///
/// # Arguments
/// * `jac` - current Jacobian approximation, updated in place.
/// * `delta_y` - step in the state.
/// * `delta_f` - corresponding change in the right-hand side.
/// * `block_size` - size of each diagonal block.
///
/// # Panics
/// Panics if `block_size == 0`, and (via indexing) if `delta_f` or `jac` is too small
/// for a block that gets updated.
#[allow(dead_code)]
pub fn block_update<F>(
    jac: &mut Array2<F>,
    delta_y: &Array1<F>,
    delta_f: &Array1<F>,
    block_size: usize,
) where
    F: IntegrateFloat,
{
    assert!(block_size > 0, "block_update: block_size must be positive");
    let n = delta_y.len();
    let threshold = F::from_f64(1e-14).expect("Operation failed");
    for block in 0..n.div_ceil(block_size) {
        let start = block * block_size;
        let end = (start + block_size).min(n);

        let dy_norm_squared = (start..end).map(|k| delta_y[k] * delta_y[k]).sum::<F>();
        // Tiny (or NaN) block steps carry no usable secant information; skip before
        // spending O(block_size²) on J_BB·Δy_B.
        if dy_norm_squared > threshold {
            // Index straight into the full arrays instead of copying the block out;
            // row i's correction only reads row i, so it can be applied immediately.
            for i in start..end {
                let mut jac_dy_i = F::zero();
                for j in start..end {
                    jac_dy_i += jac[[i, j]] * delta_y[j];
                }
                let correction = delta_f[i] - jac_dy_i;
                for j in start..end {
                    jac[[i, j]] += correction * delta_y[j] / dy_norm_squared;
                }
            }
        }
    }
}