use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::parallel_ops::*;
use scirs2_sparse::{csr_array::CsrArray, sparray::SparseArray};
use super::coloring::determine_column_groups;
use super::finite_diff::{compute_step_sizes, SparseFiniteDiffOptions};
use crate::error::OptimizeError;
#[allow(dead_code)]
fn update_sparse_value(matrix: &mut CsrArray<f64>, row: usize, col: usize, value: f64) {
if matrix.get(row, col) != 0.0 && matrix.set(row, col, value).is_err() {
}
}
#[allow(dead_code)]
fn exists_in_sparsity(matrix: &CsrArray<f64>, row: usize, col: usize) -> bool {
matrix.get(row, col) != 0.0
}
#[allow(dead_code)]
pub fn sparse_hessian<F, G>(
func: F,
grad: Option<G>,
x: &ArrayView1<f64>,
f0: Option<f64>,
g0: Option<&Array1<f64>>,
sparsity_pattern: Option<&CsrArray<f64>>,
options: Option<SparseFiniteDiffOptions>,
) -> Result<CsrArray<f64>, OptimizeError>
where
F: Fn(&ArrayView1<f64>) -> f64 + Sync,
G: Fn(&ArrayView1<f64>) -> Array1<f64> + Sync + 'static,
{
let options = options.unwrap_or_default();
let n = x.len();
if let Some(gradient_fn) = grad {
return compute_hessian_from_gradient(gradient_fn, x, g0, sparsity_pattern, &options);
}
let sparsity_owned: CsrArray<f64>;
let sparsity = match sparsity_pattern {
Some(p) => {
if p.shape().0 != n || p.shape().1 != n {
return Err(OptimizeError::ValueError(format!(
"Sparsity _pattern shape {:?} does not match input dimension {}",
p.shape(),
n
)));
}
p
}
None => {
let mut data = Vec::with_capacity(n * n);
let mut rows = Vec::with_capacity(n * n);
let mut cols = Vec::with_capacity(n * n);
for i in 0..n {
for j in 0..n {
data.push(1.0);
rows.push(i);
cols.push(j);
}
}
sparsity_owned = CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?;
&sparsity_owned
}
};
let symmetric_sparsity = make_symmetric_sparsity(sparsity)?;
let result = match options.method.as_str() {
"2-point" => {
let f0_val = f0.unwrap_or_else(|| func(x));
compute_hessian_2point(func, x, f0_val, &symmetric_sparsity, &options)
}
"3-point" => compute_hessian_3point(func, x, &symmetric_sparsity, &options),
"cs" => compute_hessian_complex_step(func, x, &symmetric_sparsity, &options),
_ => Err(OptimizeError::ValueError(format!(
"Unknown method: {}. Valid options are '2-point', '3-point', and 'cs'",
options.method
))),
}?;
fill_symmetric_hessian(&result)
}
#[allow(dead_code)]
fn compute_hessian_from_gradient<G>(
grad_fn: G,
x: &ArrayView1<f64>,
g0: Option<&Array1<f64>>,
sparsity_pattern: Option<&CsrArray<f64>>,
options: &SparseFiniteDiffOptions,
) -> Result<CsrArray<f64>, OptimizeError>
where
G: Fn(&ArrayView1<f64>) -> Array1<f64> + Sync + 'static,
{
let _n = x.len();
let g0_owned: Array1<f64>;
let g0_ref = match g0 {
Some(g) => g,
None => {
g0_owned = grad_fn(x);
&g0_owned
}
};
let jac_options = SparseFiniteDiffOptions {
method: options.method.clone(),
rel_step: options.rel_step,
abs_step: options.abs_step,
bounds: options.bounds.clone(),
parallel: options.parallel.clone(),
seed: options.seed,
max_group_size: options.max_group_size,
};
let hessian = super::jacobian::sparse_jacobian(
grad_fn,
x,
Some(g0_ref),
sparsity_pattern,
Some(jac_options),
)?;
fill_symmetric_hessian(&hessian)
}
#[allow(dead_code)]
fn compute_hessian_2point<F>(
func: F,
x: &ArrayView1<f64>,
f0: f64,
sparsity: &CsrArray<f64>,
options: &SparseFiniteDiffOptions,
) -> Result<CsrArray<f64>, OptimizeError>
where
F: Fn(&ArrayView1<f64>) -> f64 + Sync,
{
let _n = x.len();
let groups = determine_column_groups(sparsity, None, None)?;
let h = compute_step_sizes(x, options);
let (rows, cols, _) = sparsity.find();
let (m, n) = sparsity.shape();
let zeros = vec![0.0; rows.len()];
let mut hess = CsrArray::from_triplets(&rows.to_vec(), &cols.to_vec(), &zeros, (m, n), false)?;
let mut x_perturbed = x.to_owned();
let parallel = options
.parallel
.as_ref()
.map(|p| p.num_workers.unwrap_or(1) > 1)
.unwrap_or(false);
let diag_evals: Vec<f64> = if parallel {
(0..n)
.into_par_iter()
.map(|i| {
let mut x_local = x.to_owned();
x_local[i] += h[i];
func(&x_local.view())
})
.collect()
} else {
let mut diag_vals = vec![0.0; n];
for i in 0..n {
x_perturbed[i] += h[i];
diag_vals[i] = func(&x_perturbed.view());
x_perturbed[i] = x[i];
}
diag_vals
};
for i in 0..n {
let d2f_dxi2 = (diag_evals[i] - 2.0 * f0 + diag_evals[i]) / (h[i] * h[i]);
update_sparse_value(&mut hess, i, i, d2f_dxi2);
}
if parallel {
let derivatives: Vec<(usize, usize, f64)> = groups
.par_iter()
.flat_map(|group| {
let mut derivatives = Vec::new();
let mut x_local = x.to_owned();
for &j in group {
for i in 0..j {
if exists_in_sparsity(&hess, i, j) {
x_local[i] += h[i];
x_local[j] += h[j];
let f_ij = func(&x_local.view());
x_local[j] = x[j];
let f_i = diag_evals[i];
x_local[i] = x[i];
x_local[j] += h[j];
let f_j = diag_evals[j];
let d2f_dxidxj = (f_ij - f_i - f_j + f0) / (h[i] * h[j]);
derivatives.push((i, j, d2f_dxidxj));
x_local[j] = x[j];
}
}
}
derivatives
})
.collect();
for (i, j, d2f_dxidxj) in derivatives {
if hess.set(i, j, d2f_dxidxj).is_err() {
}
}
} else {
for group in &groups {
for &j in group {
for i in 0..j {
if exists_in_sparsity(&hess, i, j) {
x_perturbed[i] += h[i];
x_perturbed[j] += h[j];
let f_ij = func(&x_perturbed.view());
let d2f_dxidxj =
(f_ij - diag_evals[i] - diag_evals[j] + f0) / (h[i] * h[j]);
update_sparse_value(&mut hess, i, j, d2f_dxidxj);
x_perturbed[i] = x[i];
x_perturbed[j] = x[j];
}
}
}
}
}
Ok(hess)
}
#[allow(dead_code)]
fn compute_hessian_3point<F>(
func: F,
x: &ArrayView1<f64>,
sparsity: &CsrArray<f64>,
options: &SparseFiniteDiffOptions,
) -> Result<CsrArray<f64>, OptimizeError>
where
F: Fn(&ArrayView1<f64>) -> f64 + Sync,
{
let n = x.len();
let groups = determine_column_groups(sparsity, None, None)?;
let h = compute_step_sizes(x, options);
let (rows, cols, _) = sparsity.find();
let (m, n_cols) = sparsity.shape();
let zeros = vec![0.0; rows.len()];
let mut hess =
CsrArray::from_triplets(&rows.to_vec(), &cols.to_vec(), &zeros, (m, n_cols), false)?;
let mut x_perturbed = x.to_owned();
let parallel = options
.parallel
.as_ref()
.map(|p| p.num_workers.unwrap_or(1) > 1)
.unwrap_or(false);
let diag_evals: Vec<(f64, f64)> = if parallel {
(0..n)
.into_par_iter()
.map(|i| {
let mut x_local = x.to_owned();
x_local[i] += h[i];
let f_plus = func(&x_local.view());
x_local[i] = x[i] - h[i];
let f_minus = func(&x_local.view());
(f_plus, f_minus)
})
.collect()
} else {
let mut diag_vals = vec![(0.0, 0.0); n];
for i in 0..n {
x_perturbed[i] += h[i];
let f_plus = func(&x_perturbed.view());
x_perturbed[i] = x[i] - h[i];
let f_minus = func(&x_perturbed.view());
diag_vals[i] = (f_plus, f_minus);
x_perturbed[i] = x[i];
}
diag_vals
};
let f0 = func(x);
for i in 0..n {
let (f_plus, f_minus) = diag_evals[i];
let d2f_dxi2 = (f_plus - 2.0 * f0 + f_minus) / (h[i] * h[i]);
update_sparse_value(&mut hess, i, i, d2f_dxi2);
}
if parallel {
let derivatives: Vec<(usize, usize, f64)> = groups
.par_iter()
.flat_map(|group| {
let mut derivatives = Vec::new();
let mut x_local = x.to_owned();
for &j in group {
for i in 0..j {
if exists_in_sparsity(&hess, i, j) {
x_local[i] += h[i];
x_local[j] += h[j];
let f_pp = func(&x_local.view());
x_local[j] = x[j] - h[j];
let f_pm = func(&x_local.view());
x_local[i] = x[i] - h[i];
x_local[j] = x[j] + h[j];
let f_mp = func(&x_local.view());
x_local[j] = x[j] - h[j];
let f_mm = func(&x_local.view());
let d2f_dxidxj = (f_pp - f_pm - f_mp + f_mm) / (4.0 * h[i] * h[j]);
derivatives.push((i, j, d2f_dxidxj));
x_local[i] = x[i];
x_local[j] = x[j];
}
}
}
derivatives
})
.collect();
for (i, j, d2f_dxidxj) in derivatives {
if hess.set(i, j, d2f_dxidxj).is_err() {
}
}
} else {
for group in &groups {
for &j in group {
for i in 0..j {
if exists_in_sparsity(&hess, i, j) {
x_perturbed[i] += h[i];
x_perturbed[j] += h[j];
let f_pp = func(&x_perturbed.view());
x_perturbed[j] = x[j] - h[j];
let f_pm = func(&x_perturbed.view());
x_perturbed[i] = x[i] - h[i];
x_perturbed[j] = x[j] + h[j];
let f_mp = func(&x_perturbed.view());
x_perturbed[j] = x[j] - h[j];
let f_mm = func(&x_perturbed.view());
let d2f_dxidxj = (f_pp - f_pm - f_mp + f_mm) / (4.0 * h[i] * h[j]);
update_sparse_value(&mut hess, i, j, d2f_dxidxj);
x_perturbed[i] = x[i];
x_perturbed[j] = x[j];
}
}
}
}
}
Ok(hess)
}
#[allow(dead_code)]
fn compute_hessian_complex_step<F>(
func: F,
x: &ArrayView1<f64>,
sparsity: &CsrArray<f64>,
options: &SparseFiniteDiffOptions,
) -> Result<CsrArray<f64>, OptimizeError>
where
F: Fn(&ArrayView1<f64>) -> f64 + Sync,
{
let n = x.len();
let h = options.abs_step.unwrap_or(1e-20);
let groups = determine_column_groups(sparsity, None, None)?;
let (rows, cols, _) = sparsity.find();
let zeros = vec![0.0; rows.len()];
let mut hess = CsrArray::from_triplets(&rows.to_vec(), &cols.to_vec(), &zeros, (n, n), false)?;
let parallel = options
.parallel
.as_ref()
.map(|p| p.num_workers.unwrap_or(1) > 1)
.unwrap_or(false);
let _f0 = func(x);
if parallel {
let derivatives: Vec<(usize, usize, f64)> = groups
.par_iter()
.flat_map(|group| {
let mut derivatives = Vec::new();
for &j in group {
if exists_in_sparsity(&hess, j, j) {
let d2f_dxj2 = compute_hessian_diagonal_complex_step(&func, x, j, h);
derivatives.push((j, j, d2f_dxj2));
}
for i in 0..j {
if exists_in_sparsity(&hess, i, j) {
let d2f_dxidxj = compute_hessian_mixed_complex_step(&func, x, i, j, h);
derivatives.push((i, j, d2f_dxidxj));
}
}
}
derivatives
})
.collect();
for (i, j, derivative) in derivatives {
if hess.set(i, j, derivative).is_err() {
}
}
} else {
for group in &groups {
for &j in group {
if exists_in_sparsity(&hess, j, j) {
let d2f_dxj2 = compute_hessian_diagonal_complex_step(&func, x, j, h);
update_sparse_value(&mut hess, j, j, d2f_dxj2);
}
for i in 0..j {
if exists_in_sparsity(&hess, i, j) {
let d2f_dxidxj = compute_hessian_mixed_complex_step(&func, x, i, j, h);
update_sparse_value(&mut hess, i, j, d2f_dxidxj);
}
}
}
}
}
Ok(hess)
}
#[allow(dead_code)]
fn compute_hessian_diagonal_complex_step<F>(func: &F, x: &ArrayView1<f64>, i: usize, h: f64) -> f64
where
F: Fn(&ArrayView1<f64>) -> f64,
{
let mut x_plus = x.to_owned();
let mut x_minus = x.to_owned();
let mut x_plus2 = x.to_owned();
let mut x_minus2 = x.to_owned();
x_plus[i] += h;
x_minus[i] -= h;
x_plus2[i] += 2.0 * h;
x_minus2[i] -= 2.0 * h;
let f_plus = func(&x_plus.view());
let f_minus = func(&x_minus.view());
let f_plus2 = func(&x_plus2.view());
let f_minus2 = func(&x_minus2.view());
let f0 = func(x);
(-f_plus2 + 16.0 * f_plus - 30.0 * f0 + 16.0 * f_minus - f_minus2) / (12.0 * h * h)
}
#[allow(dead_code)]
fn compute_hessian_mixed_complex_step<F>(
func: &F,
x: &ArrayView1<f64>,
i: usize,
j: usize,
h: f64,
) -> f64
where
F: Fn(&ArrayView1<f64>) -> f64,
{
let mut x_pp = x.to_owned();
x_pp[i] += h;
x_pp[j] += h;
let f_pp = func(&x_pp.view());
let mut x_pm = x.to_owned();
x_pm[i] += h;
x_pm[j] -= h;
let f_pm = func(&x_pm.view());
let mut x_mp = x.to_owned();
x_mp[i] -= h;
x_mp[j] += h;
let f_mp = func(&x_mp.view());
let mut x_mm = x.to_owned();
x_mm[i] -= h;
x_mm[j] -= h;
let f_mm = func(&x_mm.view());
(f_pp - f_pm - f_mp + f_mm) / (4.0 * h * h)
}
#[allow(dead_code)]
fn make_symmetric_sparsity(sparsity: &CsrArray<f64>) -> Result<CsrArray<f64>, OptimizeError> {
let (m, n) = sparsity.shape();
if m != n {
return Err(OptimizeError::ValueError(
"Sparsity pattern must be square for Hessian computation".to_string(),
));
}
let dense = sparsity.to_array();
let dense_transposed = dense.t().to_owned();
let mut data = Vec::new();
let mut rows = Vec::new();
let mut cols = Vec::new();
for i in 0..n {
for j in 0..n {
if dense[[i, j]] > 0.0 || dense_transposed[[i, j]] > 0.0 {
rows.push(i);
cols.push(j);
data.push(1.0); }
}
}
Ok(CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?)
}
#[allow(dead_code)]
fn fill_symmetric_hessian(upper: &CsrArray<f64>) -> Result<CsrArray<f64>, OptimizeError> {
let (n, _) = upper.shape();
if n != upper.shape().1 {
return Err(OptimizeError::ValueError(
"Hessian matrix must be square".to_string(),
));
}
let upper_dense = upper.to_array();
let mut data = Vec::new();
let mut rows = Vec::new();
let mut cols = Vec::new();
for i in 0..n {
for j in 0..n {
let value = upper_dense[[i, j]];
if value != 0.0 {
rows.push(i);
cols.push(j);
data.push(value);
if i != j {
rows.push(j);
cols.push(i);
data.push(value);
}
}
}
}
let full = CsrArray::from_triplets(&rows, &cols, &data, (n, n), false)?;
Ok(full)
}