use crate::bspline::{BSpline, ExtrapolateMode};
use crate::error::{InterpolateError, InterpolateResult};
#[cfg(feature = "linalg")]
use crate::numerical_stability::{
assess_matrix_condition, solve_with_enhanced_monitoring, solve_with_stability_monitoring,
StabilityLevel,
};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::{Debug, Display};
use super::types::{Constraint, ConstraintType};
/// Solves the constrained fit for the interpolation case.
///
/// Both the design matrix and the target vector are multiplied by the same
/// fixed factor (`1e6`) and the scaled system is handed to
/// [`solve_constrained_least_squares`] together with the untouched
/// constraint system.
///
/// # Arguments
///
/// * `design_matrix` - basis functions evaluated at the data sites (one row per sample).
/// * `y` - target values, one per row of `design_matrix`.
/// * `constraint_matrix`, `constraint_rhs` - inequality system forwarded unchanged.
///
/// # Errors
///
/// Propagates whatever [`solve_constrained_least_squares`] returns.
///
/// # Panics
///
/// Panics if `1e6` cannot be represented in `T`.
#[allow(dead_code)]
fn solve_constrained_interpolation<T>(
    design_matrix: &ArrayView2<T>,
    y: &ArrayView1<T>,
    constraint_matrix: &ArrayView2<T>,
    constraint_rhs: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + 'static
        + std::fmt::LowerExp,
{
    // The same scale is applied to both sides of the fitting equations.
    let scale = T::from_f64(1e6).expect("Operation failed");
    let scaled_design = design_matrix.mapv(|entry| entry * scale);
    let scaled_targets = y.mapv(|entry| entry * scale);

    let (design_view, target_view) = (scaled_design.view(), scaled_targets.view());
    solve_constrained_least_squares(&design_view, &target_view, constraint_matrix, constraint_rhs)
}
/// Solves a least-squares problem subject to linear inequality constraints.
///
/// The constraints are interpreted as `constraint_matrix · c >= constraint_rhs`:
/// a row whose residual `row · c - rhs` is negative is considered violated.
///
/// Procedure:
/// 1. Form the normal equations `AᵀA c = Aᵀy`.
/// 2. With no constraint rows, solve them directly (requires the `linalg`
///    feature; otherwise `NotImplemented` is returned).
/// 3. Otherwise start from the unconstrained solution (or zeros if it is
///    unavailable) and repeatedly project onto the most violated constraint
///    row, for at most 100 iterations.
/// 4. Accept the result only if every residual is at least `-1e-6`.
///
/// # Errors
///
/// * `ComputationError` if the unconstrained solve fails, or if the projected
///   iterate still violates a constraint by more than `1e-6`.
/// * `NotImplemented` if there are no constraints and `linalg` is disabled.
///
/// # Panics
///
/// Panics if the array shapes are incompatible for the matrix products.
#[allow(dead_code)]
fn solve_constrained_least_squares<T>(
    design_matrix: &ArrayView2<T>,
    y: &ArrayView1<T>,
    constraint_matrix: &ArrayView2<T>,
    constraint_rhs: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + 'static
        + std::fmt::LowerExp,
{
    const MAX_ITERATIONS: usize = 100;

    // Normal equations. Without `linalg` the products are only formed for
    // their shape checks, hence the underscore-prefixed names.
    let a_transpose = design_matrix.t();
    #[cfg(feature = "linalg")]
    let ata = a_transpose.dot(design_matrix);
    #[cfg(not(feature = "linalg"))]
    let _ata = a_transpose.dot(design_matrix);
    #[cfg(feature = "linalg")]
    let aty = a_transpose.dot(y);
    #[cfg(not(feature = "linalg"))]
    let _aty = a_transpose.dot(y);

    // Unconstrained problem: a single monitored linear solve.
    if constraint_matrix.shape()[0] == 0 {
        #[cfg(feature = "linalg")]
        {
            if let Ok(report) = assess_matrix_condition(&ata.view()) {
                match report.stability_level {
                    StabilityLevel::Poor => {
                        eprintln!(
                            "Warning: Normal equations matrix is poorly conditioned \
                             (condition number: {:.2e}). Results may be unreliable.",
                            report.condition_number
                        );
                    }
                    StabilityLevel::Marginal => {
                        eprintln!(
                            "Info: Normal equations matrix has marginal conditioning \
                             (condition number: {:.2e}). Monitoring solution quality.",
                            report.condition_number
                        );
                    }
                    _ => {}
                }
            }
            return solve_with_stability_monitoring(&ata.view(), &aty.view()).map_err(|_| {
                InterpolateError::ComputationError(
                    "Failed to solve the unconstrained least squares problem with stability monitoring".to_string(),
                )
            });
        }
        #[cfg(not(feature = "linalg"))]
        return Err(InterpolateError::NotImplemented(
            "Linear algebra operations require the 'linalg' feature".to_string(),
        ));
    }

    // Starting point for the projection iterations.
    #[cfg(feature = "linalg")]
    let mut c: Array1<T> = match solve_with_enhanced_monitoring(&ata.view(), &aty.view()) {
        Ok((solution, solve_report)) => {
            if !solve_report.condition_report.is_well_conditioned {
                eprintln!(
                    "Warning: Initial solution for constrained problem computed with \
                     poorly conditioned matrix (condition number: {:.2e})",
                    solve_report.condition_report.condition_number
                );
            }
            solution
        }
        Err(_) => {
            eprintln!(
                "Warning: Stability-monitored solve failed for initial solution. \
                 Using zero initialization."
            );
            Array1::zeros(design_matrix.shape()[1])
        }
    };
    #[cfg(not(feature = "linalg"))]
    let mut c: Array1<T> = Array1::zeros(design_matrix.shape()[1]);

    // Residual of every constraint row; negative entries are violations.
    let residuals = |coeffs: &Array1<T>| constraint_matrix.dot(coeffs) - constraint_rhs;

    let mut constraint_values = residuals(&c);
    if !constraint_values.iter().any(|&value| value < T::zero()) {
        return Ok(c);
    }

    for _ in 0..MAX_ITERATIONS {
        // Most negative residual (first one wins on ties).
        let (worst_idx, worst_violation) = constraint_values.iter().enumerate().fold(
            (0, T::zero()),
            |worst, (idx, &value)| if value < worst.1 { (idx, value) } else { worst },
        );
        if worst_violation >= -T::epsilon() {
            break;
        }

        let constraint_row = constraint_matrix.row(worst_idx);
        let constraint_norm_squared = constraint_row.dot(&constraint_row);
        if constraint_norm_squared < T::epsilon() {
            // A (near-)zero row cannot be satisfied by moving `c`. Nothing
            // would change on further iterations, so stop and let the final
            // feasibility check report the violation.
            break;
        }

        // Move `c` along the row just far enough to zero its residual.
        let step_size = -worst_violation / constraint_norm_squared;
        for (coeff, &direction) in c.iter_mut().zip(constraint_row.iter()) {
            *coeff += step_size * direction;
        }
        constraint_values = residuals(&c);
    }

    // `constraint_values` always reflects the current `c` at this point.
    let tolerance = T::from_f64(1e-6).expect("Operation failed");
    if constraint_values.iter().any(|&value| value < -tolerance) {
        return Err(InterpolateError::ComputationError(
            "Failed to find a solution that satisfies all constraints".to_string(),
        ));
    }
    Ok(c)
}
/// Solves a penalized least-squares problem subject to linear inequality
/// constraints.
///
/// The system matrix is `AᵀA + lambda · P`, where `P` is `penalty_matrix`.
/// Constraints are interpreted as `constraint_matrix · c >= constraint_rhs`.
///
/// With the `linalg` feature the penalized normal equations are solved in
/// `f64` via `scirs2_linalg::solve`. With no constraint rows that solution is
/// returned directly. Otherwise it (or a zero vector if the solve failed or
/// `linalg` is disabled) is used as the starting point for at most 100
/// projections onto the most violated constraint row.
///
/// NOTE(review): unlike `solve_constrained_least_squares`, the iterate is
/// returned without a final feasibility check, so the result may still violate
/// constraints when the iteration limit is reached or a degenerate row is hit.
/// Confirm whether callers rely on this best-effort behaviour.
///
/// # Errors
///
/// * `ComputationError` if there are no constraints and the linear solve fails.
/// * `NotImplemented` if there are no constraints and `linalg` is disabled.
///
/// # Panics
///
/// Panics if the array shapes are incompatible, if `penalty_matrix` is smaller
/// than `AᵀA`, or if a value cannot be converted between `T` and `f64`.
#[allow(dead_code)]
fn solve_constrained_penalized<T>(
    design_matrix: &ArrayView2<T>,
    y: &ArrayView1<T>,
    penalty_matrix: &ArrayView2<T>,
    lambda: T,
    constraint_matrix: &ArrayView2<T>,
    constraint_rhs: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Display
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::ops::DivAssign
        + 'static
        + std::fmt::LowerExp,
{
    const MAX_ITERATIONS: usize = 100;

    let a_transpose = design_matrix.t();
    let mut ata = a_transpose.dot(design_matrix);
    #[cfg(feature = "linalg")]
    let aty = a_transpose.dot(y);
    #[cfg(not(feature = "linalg"))]
    let _aty = a_transpose.dot(y);

    // Add the smoothing penalty: AᵀA + lambda · P.
    for ((i, j), entry) in ata.indexed_iter_mut() {
        *entry += lambda * penalty_matrix[[i, j]];
    }

    // One f64 solve serves both the unconstrained answer and the starting
    // point of the constrained iteration; `None` means the solve failed.
    #[cfg(feature = "linalg")]
    let unconstrained: Option<Array1<T>> = {
        use scirs2_linalg::solve;
        let ata_f64 = ata.mapv(|value| value.to_f64().expect("Operation failed"));
        let aty_f64 = aty.mapv(|value| value.to_f64().expect("Operation failed"));
        solve(&ata_f64.view(), &aty_f64.view(), None)
            .ok()
            .map(|solution| solution.mapv(|value| T::from_f64(value).expect("Operation failed")))
    };

    if constraint_matrix.shape()[0] == 0 {
        #[cfg(feature = "linalg")]
        return unconstrained.ok_or_else(|| {
            InterpolateError::ComputationError(
                "Failed to solve the unconstrained penalized problem".to_string(),
            )
        });
        #[cfg(not(feature = "linalg"))]
        return Err(InterpolateError::NotImplemented(
            "Linear algebra operations require the 'linalg' feature".to_string(),
        ));
    }

    #[cfg(feature = "linalg")]
    let mut c: Array1<T> =
        unconstrained.unwrap_or_else(|| Array1::zeros(design_matrix.shape()[1]));
    #[cfg(not(feature = "linalg"))]
    let mut c: Array1<T> = Array1::zeros(design_matrix.shape()[1]);

    for _ in 0..MAX_ITERATIONS {
        // Negative residuals are violated constraints.
        let constraint_values = constraint_matrix.dot(&c) - constraint_rhs;
        let (worst_idx, worst_violation) = constraint_values.iter().enumerate().fold(
            (0, T::zero()),
            |worst, (idx, &value)| if value < worst.1 { (idx, value) } else { worst },
        );
        if worst_violation >= -T::epsilon() {
            break;
        }

        let constraint_row = constraint_matrix.row(worst_idx);
        let constraint_norm_squared = constraint_row.dot(&constraint_row);
        if constraint_norm_squared < T::epsilon() {
            // A (near-)zero row cannot be fixed by moving `c`; every further
            // iteration would pick the same row again, so stop here.
            break;
        }

        // Move `c` along the row just far enough to zero its residual.
        let step_size = -worst_violation / constraint_norm_squared;
        for (coeff, &direction) in c.iter_mut().zip(constraint_row.iter()) {
            *coeff += step_size * direction;
        }
    }
    Ok(c)
}
#[allow(dead_code)]
pub fn solve_constrained_system<T>(
x: &ArrayView1<T>,
y: &ArrayView1<T>,
knots: &ArrayView1<T>,
degree: usize,
constraints: &[Constraint<T>],
) -> InterpolateResult<Array1<T>>
where
T: Float
+ FromPrimitive
+ Debug
+ Display
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static
+ std::fmt::LowerExp,
{
let n_coeffs = knots.len() - degree - 1;
let mut design_matrix = Array2::zeros((x.len(), n_coeffs));
for (i, &x_val) in x.iter().enumerate() {
for j in 0..n_coeffs {
let basis = BSpline::basis_element(degree, j, knots, ExtrapolateMode::Extrapolate)?;
design_matrix[[i, j]] = basis.evaluate(x_val)?;
}
}
let (constraint_matrix, constraint_rhs) =
generate_constraint_matrices(x, knots, degree, constraints)?;
solve_constrained_interpolation(
&design_matrix.view(),
y,
&constraint_matrix.view(),
&constraint_rhs.view(),
)
}
#[allow(dead_code)]
pub fn solve_penalized_system<T>(
x: &ArrayView1<T>,
y: &ArrayView1<T>,
knots: &ArrayView1<T>,
degree: usize,
constraints: &[Constraint<T>],
lambda: T,
) -> InterpolateResult<Array1<T>>
where
T: Float
+ FromPrimitive
+ Debug
+ Display
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static
+ std::fmt::LowerExp,
{
let n_coeffs = knots.len() - degree - 1;
let mut design_matrix = Array2::zeros((x.len(), n_coeffs));
for (i, &x_val) in x.iter().enumerate() {
for j in 0..n_coeffs {
let basis = BSpline::basis_element(degree, j, knots, ExtrapolateMode::Extrapolate)?;
design_matrix[[i, j]] = basis.evaluate(x_val)?;
}
}
let penalty_matrix = create_penalty_matrix(n_coeffs, degree)?;
let (constraint_matrix, constraint_rhs) =
generate_constraint_matrices(x, knots, degree, constraints)?;
solve_constrained_penalized(
&design_matrix.view(),
y,
&penalty_matrix.view(),
lambda,
&constraint_matrix.view(),
&constraint_rhs.view(),
)
}
/// Builds the `n × n` second-difference penalty matrix `DᵀD`.
///
/// Each row of `D` is the stencil `[1, -2, 1]` applied to three consecutive
/// coefficients, so `cᵀ (DᵀD) c` is the sum of squared second differences of
/// `c`.
///
/// Returns an all-zero matrix when `degree < 2` or when `n < 3` (fewer than
/// three coefficients have no second difference).
///
/// # Panics
///
/// Panics if `2.0` cannot be represented in `T`.
#[allow(dead_code)]
fn create_penalty_matrix<T>(n: usize, degree: usize) -> InterpolateResult<Array2<T>>
where
    T: Float + FromPrimitive + std::ops::AddAssign + std::ops::SubAssign,
{
    let mut penalty = Array2::<T>::zeros((n, n));
    if degree < 2 {
        return Ok(penalty);
    }

    let one = T::one();
    let two = T::from_f64(2.0).expect("Operation failed");
    let stencil = [one, -two, one];

    // `saturating_sub` keeps the range empty for n < 2 instead of underflowing.
    for start in 0..n.saturating_sub(2) {
        // Accumulate the outer product of the stencil with itself.
        for (row_offset, &row_weight) in stencil.iter().enumerate() {
            for (col_offset, &col_weight) in stencil.iter().enumerate() {
                penalty[[start + row_offset, start + col_offset]] += row_weight * col_weight;
            }
        }
    }
    Ok(penalty)
}
#[allow(dead_code)]
pub fn generate_constraint_matrices<T>(
x: &ArrayView1<T>,
knots: &ArrayView1<T>,
degree: usize,
constraints: &[Constraint<T>],
) -> InterpolateResult<(Array2<T>, Array1<T>)>
where
T: Float
+ FromPrimitive
+ Debug
+ Display
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::RemAssign
+ 'static
+ std::fmt::LowerExp,
{
let n_coeffs = knots.len() - degree - 1;
let x_min = x[0];
let x_max = x[x.len() - 1];
let mut total_constraints = 0;
for constraint in constraints {
let _constraint_x_min = constraint.x_min.unwrap_or(x_min);
let _constraint_x_max = constraint.x_max.unwrap_or(x_max);
let n_eval = 10;
match constraint.constraint_type {
ConstraintType::MonotoneIncreasing | ConstraintType::MonotoneDecreasing => {
total_constraints += n_eval - 1;
}
ConstraintType::Convex | ConstraintType::Concave => {
total_constraints += n_eval;
}
_ => {
total_constraints += n_eval;
}
}
}
let mut constraint_matrix = Array2::zeros((total_constraints, n_coeffs));
let mut constraint_rhs = Array1::zeros(total_constraints);
let mut constraint_idx = 0;
let extrapolate = ExtrapolateMode::Extrapolate;
for constraint in constraints {
let constraint_x_min = constraint.x_min.unwrap_or(x_min);
let constraint_x_max = constraint.x_max.unwrap_or(x_max);
let n_eval = 10;
let mut eval_points = Vec::new();
for i in 0..n_eval {
let t = i as f64 / (n_eval - 1) as f64;
let x_val = constraint_x_min
+ T::from_f64(t).expect("Operation failed") * (constraint_x_max - constraint_x_min);
eval_points.push(x_val);
}
match constraint.constraint_type {
ConstraintType::MonotoneIncreasing => {
for i in 0..eval_points.len() - 1 {
let x_val = (eval_points[i] + eval_points[i + 1])
/ T::from_f64(2.0).expect("Operation failed");
for j in 0..n_coeffs {
let basis = BSpline::basis_element(degree, j, knots, extrapolate)?;
constraint_matrix[[constraint_idx, j]] = basis.derivative(x_val, 1)?;
}
constraint_rhs[constraint_idx] = T::zero();
constraint_idx += 1;
}
}
ConstraintType::MonotoneDecreasing => {
for i in 0..eval_points.len() - 1 {
let x_val = (eval_points[i] + eval_points[i + 1])
/ T::from_f64(2.0).expect("Operation failed");
for j in 0..n_coeffs {
let basis = BSpline::basis_element(degree, j, knots, extrapolate)?;
constraint_matrix[[constraint_idx, j]] = -basis.derivative(x_val, 1)?;
}
constraint_rhs[constraint_idx] = T::zero();
constraint_idx += 1;
}
}
ConstraintType::Convex => {
for &x_val in eval_points.iter() {
for j in 0..n_coeffs {
let basis = BSpline::basis_element(degree, j, knots, extrapolate)?;
constraint_matrix[[constraint_idx, j]] = basis.derivative(x_val, 2)?;
}
constraint_rhs[constraint_idx] = T::zero();
constraint_idx += 1;
}
}
ConstraintType::Concave => {
for &x_val in eval_points.iter() {
for j in 0..n_coeffs {
let basis = BSpline::basis_element(degree, j, knots, extrapolate)?;
constraint_matrix[[constraint_idx, j]] = -basis.derivative(x_val, 2)?;
}
constraint_rhs[constraint_idx] = T::zero();
constraint_idx += 1;
}
}
ConstraintType::Positive => {
for &x_val in eval_points.iter() {
for j in 0..n_coeffs {
let basis = BSpline::basis_element(degree, j, knots, extrapolate)?;
constraint_matrix[[constraint_idx, j]] = basis.evaluate(x_val)?;
}
constraint_rhs[constraint_idx] = T::zero();
constraint_idx += 1;
}
}
ConstraintType::UpperBound => {
let upper_bound = constraint.parameter.unwrap_or(T::one());
for &x_val in eval_points.iter() {
for j in 0..n_coeffs {
let basis = BSpline::basis_element(degree, j, knots, extrapolate)?;
constraint_matrix[[constraint_idx, j]] = basis.evaluate(x_val)?;
}
constraint_rhs[constraint_idx] = upper_bound;
constraint_idx += 1;
}
}
ConstraintType::LowerBound => {
let lower_bound = constraint.parameter.unwrap_or(T::zero());
for &x_val in eval_points.iter() {
for j in 0..n_coeffs {
let basis = BSpline::basis_element(degree, j, knots, extrapolate)?;
constraint_matrix[[constraint_idx, j]] = -basis.evaluate(x_val)?;
}
constraint_rhs[constraint_idx] = -lower_bound;
constraint_idx += 1;
}
}
}
}
if constraint_idx < total_constraints {
constraint_matrix = constraint_matrix
.slice(s![0..constraint_idx, ..])
.to_owned();
constraint_rhs = constraint_rhs.slice(s![0..constraint_idx]).to_owned();
}
Ok((constraint_matrix, constraint_rhs))
}