use std::sync::Arc;
use scirs2_symbolic::eml::eval::{eval_real, EvalCtx};
use scirs2_symbolic::eml::LoweredOp;
use crate::quadrature::gaussian::gauss_legendre as gauss_legendre_nodes;
/// Errors returned by [`quad_gauss_legendre_symbolic`].
///
/// Every payload is plain data (numbers or owned strings), so the type can be
/// cloned and compared freely, e.g. with `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq)]
pub enum SymbolicQuadError {
    /// Evaluating the symbolic integrand at a quadrature node failed; carries
    /// the evaluator's error message.
    EvalError(String),
    /// The integration bounds `(a, b)` were rejected.
    InvalidInterval(f64, f64),
    /// The requested node count was rejected; carries the requested count.
    InvalidNodeCount(usize),
    /// Computing the Gauss-Legendre nodes/weights failed; carries the
    /// underlying error message.
    NodeComputationError(String),
}
impl std::fmt::Display for SymbolicQuadError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SymbolicQuadError::EvalError(msg) => write!(f, "symbolic eval error: {msg}"),
SymbolicQuadError::InvalidInterval(a, b) => {
write!(f, "invalid interval [{a}, {b}]: must have a < b")
}
SymbolicQuadError::InvalidNodeCount(n) => {
write!(f, "invalid node count {n}: must be >= 1")
}
SymbolicQuadError::NodeComputationError(msg) => {
write!(f, "Gauss-Legendre node computation failed: {msg}")
}
}
}
}
impl std::error::Error for SymbolicQuadError {}
/// Integrates a lowered symbolic expression over `[a, b]` using an `n`-point
/// Gauss-Legendre rule.
///
/// The reference nodes `xi` and weights `wi` come from `gauss_legendre`. The
/// affine map below assumes they live on `[-1, 1]`. They are mapped onto
/// `[a, b]` as
/// `x_i = (b - a)/2 * xi + (a + b)/2` and `w_i = (b - a)/2 * wi`.
/// The integrand is evaluated with a single-element variable binding holding
/// `x_i` (variable index 0). The weighted values are then summed.
///
/// # Errors
///
/// - [`SymbolicQuadError::InvalidInterval`] if either bound is NaN or
///   infinite, or if `a >= b`.
/// - [`SymbolicQuadError::InvalidNodeCount`] if `n == 0`.
/// - [`SymbolicQuadError::NodeComputationError`] if computing the nodes and
///   weights fails.
/// - [`SymbolicQuadError::EvalError`] if evaluating the integrand fails at a
///   node. Integration stops at the first such failure.
pub fn quad_gauss_legendre_symbolic(
    integrand: &Arc<LoweredOp>,
    a: f64,
    b: f64,
    n: usize,
) -> Result<f64, SymbolicQuadError> {
    // Stated positively so that NaN bounds, for which every comparison is
    // false, are rejected rather than slipping past an `a >= b` test.
    // Infinite bounds would make the midpoint/weights `inf` or NaN.
    let interval_ok = a.is_finite() && b.is_finite() && a < b;
    if !interval_ok {
        return Err(SymbolicQuadError::InvalidInterval(a, b));
    }
    if n == 0 {
        return Err(SymbolicQuadError::InvalidNodeCount(n));
    }
    let (xi_vec, wi_vec) = gauss_legendre_nodes(n)
        .map_err(|e| SymbolicQuadError::NodeComputationError(e.to_string()))?;
    // Halve each bound before combining so huge finite bounds cannot overflow
    // `b - a` or `a + b`. Scaling by 0.5 is exact for normal floats, so results
    // match `0.5 * (b - a)` / `0.5 * (a + b)` whenever those do not overflow.
    let half_len = 0.5 * b - 0.5 * a;
    let mid = 0.5 * a + 0.5 * b;
    let mut integral = 0.0_f64;
    // Reused across nodes; only the single bound value changes per iteration.
    let mut binding = [0.0_f64; 1];
    for (xi, wi) in xi_vec.iter().zip(wi_vec.iter()) {
        let x_i = half_len * xi + mid;
        let w_i = half_len * wi;
        binding[0] = x_i;
        let ctx = EvalCtx::new(&binding);
        let val = eval_real(integrand.as_ref(), &ctx)
            .map_err(|e| SymbolicQuadError::EvalError(e.to_string()))?;
        integral += w_i * val;
    }
    Ok(integral)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Integrand that evaluates to the constant `v` everywhere.
    fn constant_integrand(v: f64) -> Arc<LoweredOp> {
        Arc::new(LoweredOp::Const(v))
    }

    /// Integrand that evaluates to the bound variable at index 0.
    fn x_integrand() -> Arc<LoweredOp> {
        Arc::new(LoweredOp::Var(0))
    }

    #[test]
    fn quad_constant() {
        // The integral of 3 over [0, 1] is 3.
        let integrand = constant_integrand(3.0);
        let expected = 3.0;
        let result = quad_gauss_legendre_symbolic(&integrand, 0.0, 1.0, 5).unwrap();
        assert!((result - expected).abs() < 1e-12, "result = {result}");
    }

    #[test]
    fn quad_linear_exact() {
        // A one-point rule is exact for linear integrands: the integral of x over [0, 1] is 1/2.
        let integrand = x_integrand();
        let expected = 0.5;
        let result = quad_gauss_legendre_symbolic(&integrand, 0.0, 1.0, 1).unwrap();
        assert!((result - expected).abs() < 1e-12, "result = {result}");
    }

    #[test]
    fn quad_invalid_interval() {
        let reversed = quad_gauss_legendre_symbolic(&x_integrand(), 1.0, 0.0, 5);
        let is_interval_error =
            matches!(reversed, Err(SymbolicQuadError::InvalidInterval(1.0, 0.0)));
        assert!(is_interval_error);
    }

    #[test]
    fn quad_invalid_node_count() {
        let no_nodes = quad_gauss_legendre_symbolic(&x_integrand(), 0.0, 1.0, 0);
        let is_node_count_error = matches!(no_nodes, Err(SymbolicQuadError::InvalidNodeCount(0)));
        assert!(is_node_count_error);
    }
}