use scirs2_symbolic::cas::{closed_form_step, LineSearchError as SymLineSearchError};
use scirs2_symbolic::eml::eval::{eval_real, EvalCtx};
use scirs2_symbolic::eml::LoweredOp;
/// Failure modes of [`SymbolicLineSearch`] construction and evaluation.
#[derive(Debug)]
pub enum OptLineSearchError {
    /// `closed_form_step` could not derive the step expression α*.
    Symbolic(SymLineSearchError),
    /// `eval_real` rejected the α* expression; carries the evaluator's message text.
    EvalError(String),
    /// The point handed to `SymbolicLineSearch::eval` is too short.
    DimensionMismatch {
        /// Minimum slice length the α* expression needs.
        expected: usize,
        /// Length of the slice that was actually supplied.
        got: usize,
    },
}

impl std::fmt::Display for OptLineSearchError {
    /// Renders a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            OptLineSearchError::Symbolic(inner) => {
                write!(f, "symbolic line-search error: {inner}")
            }
            OptLineSearchError::EvalError(msg) => {
                write!(f, "line-search eval error: {msg}")
            }
            OptLineSearchError::DimensionMismatch { expected, got } => {
                let (want, have) = (expected, got);
                write!(
                    f,
                    "dimension mismatch: α* expression expects {want} variables, x has {have}"
                )
            }
        }
    }
}

/// Marker impl so the error composes with `?` and `Box<dyn Error>`; no `source()` is
/// exposed for the wrapped symbolic error.
impl std::error::Error for OptLineSearchError {}
/// Line search whose step length α* is derived symbolically once, then evaluated per point.
///
/// [`SymbolicLineSearch::new`] calls `closed_form_step` to build an expression for α* along
/// a fixed search direction; [`SymbolicLineSearch::eval`] then evaluates that expression at
/// a concrete point with `eval_real`.
#[derive(Debug)]
pub struct SymbolicLineSearch {
    /// Step-length expression returned by `closed_form_step`.
    alpha_expr: LoweredOp,
    /// Minimum length of the point slice passed to `eval`: `max(x_vars) + 1`, or 0 when
    /// `x_vars` is empty.
    required_len: usize,
}

impl SymbolicLineSearch {
    /// Derives the closed-form step expression for objective `f` along `direction`.
    ///
    /// `x_vars` lists the variable indices of `f` that the search moves, and `direction`
    /// supplies the direction component expressions (both forwarded unchanged to
    /// `closed_form_step`).
    ///
    /// # Errors
    ///
    /// Returns [`OptLineSearchError::Symbolic`] when `closed_form_step` fails, e.g. with
    /// `DegenerateDirection` for an all-zero direction.
    pub fn new(
        f: &LoweredOp,
        x_vars: &[usize],
        direction: &[LoweredOp],
    ) -> Result<Self, OptLineSearchError> {
        let alpha_expr =
            closed_form_step(f, x_vars, direction).map_err(OptLineSearchError::Symbolic)?;
        // Variables are referenced by index, so the evaluation slice must reach the largest
        // index, not merely hold `x_vars.len()` entries (e.g. `x_vars = [2]` needs len >= 3).
        // NOTE(review): assumes `EvalCtx::new(x)` binds `Var(i)` to `x[i]`, and that α* only
        // references variables from `x_vars`; variables of `f` outside `x_vars` are not
        // accounted for here.
        let required_len = x_vars
            .iter()
            .max()
            .map_or(0, |&max_idx| max_idx.saturating_add(1));
        Ok(Self {
            alpha_expr,
            required_len,
        })
    }

    /// Evaluates the step length α* at point `x`.
    ///
    /// # Errors
    ///
    /// - [`OptLineSearchError::DimensionMismatch`] if `x` is shorter than the largest
    ///   searched variable index + 1.
    /// - [`OptLineSearchError::EvalError`] if `eval_real` fails; its message is preserved.
    pub fn eval(&self, x: &[f64]) -> Result<f64, OptLineSearchError> {
        if x.len() < self.required_len {
            return Err(OptLineSearchError::DimensionMismatch {
                expected: self.required_len,
                got: x.len(),
            });
        }
        let ctx = EvalCtx::new(x);
        eval_real(&self.alpha_expr, &ctx).map_err(|e| OptLineSearchError::EvalError(e.to_string()))
    }

    /// Borrows the symbolic α* expression built at construction.
    #[must_use]
    pub fn alpha_expr(&self) -> &LoweredOp {
        &self.alpha_expr
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the variable node `Var(i)`.
    fn var(i: usize) -> LoweredOp {
        LoweredOp::Var(i)
    }

    /// Shorthand for the constant node `Const(v)`.
    fn c(v: f64) -> LoweredOp {
        LoweredOp::Const(v)
    }

    /// Builds f(x) = (x0 - 5)^2, whose minimum is at x0 = 5.
    fn shifted_quadratic() -> LoweredOp {
        let inner = LoweredOp::Sub(Box::new(var(0)), Box::new(c(5.0)));
        LoweredOp::Mul(Box::new(inner.clone()), Box::new(inner))
    }

    #[test]
    fn test_shifted_quadratic_step_from_origin() {
        let f = shifted_quadratic();
        let ls = SymbolicLineSearch::new(&f, &[0], &[c(1.0)]).expect("build");
        let alpha = ls.eval(&[0.0]).expect("eval");
        assert!((alpha - 5.0).abs() < 1e-10, "expected 5.0, got {alpha}");
        let x_new = 0.0 + alpha * 1.0;
        assert!(
            (x_new - 5.0).abs() < 1e-10,
            "step should land at x=5 (minimum), got x={x_new}"
        );
    }

    #[test]
    fn test_degenerate_direction_propagates() {
        let f = var(0);
        let result = SymbolicLineSearch::new(&f, &[0], &[c(0.0)]);
        assert!(
            matches!(
                result,
                Err(OptLineSearchError::Symbolic(
                    SymLineSearchError::DegenerateDirection
                ))
            ),
            "expected DegenerateDirection, got: {result:?}"
        );
    }

    /// An empty point cannot bind `Var(0)`, so `eval` must reject it up front rather than
    /// reaching the evaluator.
    #[test]
    fn test_short_point_reports_dimension_mismatch() {
        let f = shifted_quadratic();
        let ls = SymbolicLineSearch::new(&f, &[0], &[c(1.0)]).expect("build");
        let result = ls.eval(&[]);
        assert!(
            matches!(
                result,
                Err(OptLineSearchError::DimensionMismatch {
                    expected: 1,
                    got: 0
                })
            ),
            "expected DimensionMismatch {{ expected: 1, got: 0 }}, got: {result:?}"
        );
    }
}