use crate::kernel::{ExprId, ExprPool};
use crate::poly::{ConversionError, UniPoly};
/// Rewrites a polynomial in `var` into nested Horner form.
///
/// `expr` is first converted to a [`UniPoly`] in `var`; its integer
/// coefficients are then reassembled inside `pool` as
/// `c0 + var * (c1 + var * (c2 + ...))`.
///
/// # Errors
///
/// Propagates the [`ConversionError`] returned by [`UniPoly::from_symbolic`]
/// when `expr` cannot be converted.
pub fn horner(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<ExprId, ConversionError> {
    let poly = UniPoly::from_symbolic(expr, var, pool)?;
    Ok(build_horner(&poly.coefficients_i64(), var, pool))
}
/// Builds the Horner-form expression for `coeffs` in `pool`.
///
/// `coeffs[k]` is used as the coefficient of `var^k`, so the result has the
/// shape `coeffs[0] + var * (coeffs[1] + var * (... coeffs[n - 1]))`.
/// An empty coefficient slice yields the integer `0`.
fn build_horner(coeffs: &[i64], var: ExprId, pool: &ExprPool) -> ExprId {
    match coeffs.split_last() {
        None => pool.integer(0_i32),
        Some((&leading, lower)) => {
            // Start from the highest-degree coefficient and wrap outward,
            // one `c + var * inner` layer per remaining coefficient.
            lower
                .iter()
                .rev()
                .fold(pool.integer(leading), |inner, &c| {
                    let scaled = pool.mul(vec![var, inner]);
                    let constant = pool.integer(c);
                    pool.add(vec![constant, scaled])
                })
        }
    }
}
/// Emits C source text for a function that evaluates a polynomial in Horner form.
///
/// `expr` is converted to a [`UniPoly`] in `var`, and its integer
/// coefficients are rendered as a nested C expression over `var_name`.
/// The generated function is named `fn_name`, takes one `double` argument
/// called `var_name` and returns a `double`. Both names are inserted into
/// the output verbatim; they are not checked for being valid C identifiers.
///
/// # Errors
///
/// Propagates the [`ConversionError`] returned by [`UniPoly::from_symbolic`]
/// when `expr` cannot be converted.
pub fn emit_horner_c(
    expr: ExprId,
    var: ExprId,
    var_name: &str,
    fn_name: &str,
    pool: &ExprPool,
) -> Result<String, ConversionError> {
    let poly = UniPoly::from_symbolic(expr, var, pool)?;
    let return_expr = build_c_horner(&poly.coefficients_i64(), var_name);
    let source = format!(
        "double {}(double {}) {{\n return {};\n}}\n",
        fn_name, var_name, return_expr
    );
    Ok(source)
}
/// Evaluates a polynomial at `x` using Horner's scheme.
///
/// `coeffs[k]` is the coefficient of `x^k` (constant term first). An empty
/// slice evaluates to `0.0`.
#[inline]
pub fn eval_horner_f64(coeffs: &[f64], x: f64) -> f64 {
    match coeffs.split_last() {
        // Seed with the highest-degree coefficient, then fold in the rest
        // from high degree to low: acc <- c + x * acc.
        Some((&leading, lower)) => lower.iter().rev().fold(leading, |acc, &c| c + x * acc),
        None => 0.0,
    }
}
/// Evaluates the polynomial `coeffs` at every point of `xs`, writing the
/// results to the matching positions of `out`.
///
/// Points are processed four at a time through the `wide::f64x4` kernel;
/// the final `xs.len() % 4` points fall back to [`eval_horner_f64`].
///
/// # Panics
///
/// Panics if `xs` and `out` have different lengths.
pub fn eval_horner_f64_batch(coeffs: &[f64], xs: &[f64], out: &mut [f64]) {
    assert_eq!(xs.len(), out.len());

    // Split both slices at the largest multiple of four so the vector and
    // scalar parts can be walked independently.
    let vector_len = xs.len() - xs.len() % 4;
    let (x_wide, x_tail) = xs.split_at(vector_len);
    let (out_wide, out_tail) = out.split_at_mut(vector_len);

    for (lanes_in, lanes_out) in x_wide.chunks_exact(4).zip(out_wide.chunks_exact_mut(4)) {
        let packed = wide::f64x4::new([lanes_in[0], lanes_in[1], lanes_in[2], lanes_in[3]]);
        lanes_out.copy_from_slice(&eval_horner_f64x4(coeffs, packed).to_array());
    }

    for (o, x) in out_tail.iter_mut().zip(x_tail) {
        *o = eval_horner_f64(coeffs, *x);
    }
}
/// Four-lane version of the scalar Horner evaluation: evaluates the
/// polynomial `coeffs` (constant term first) at each lane of `x`.
///
/// An empty coefficient slice yields zero in every lane.
#[inline]
fn eval_horner_f64x4(coeffs: &[f64], x: wide::f64x4) -> wide::f64x4 {
    match coeffs.split_last() {
        // Broadcast each coefficient to all lanes and apply
        // acc <- c + x * acc from the highest degree downward.
        Some((&leading, lower)) => lower
            .iter()
            .rev()
            .fold(wide::f64x4::splat(leading), |acc, &c| {
                wide::f64x4::splat(c) + x * acc
            }),
        None => wide::f64x4::splat(0.0),
    }
}
fn build_c_horner(coeffs: &[i64], var: &str) -> String {
if coeffs.is_empty() {
return "0.0".to_string();
}
let n = coeffs.len();
let mut result = format!("{}.0", coeffs[n - 1]);
for k in (0..n - 1).rev() {
let ck = format!("{}.0", coeffs[k]);
result = format!("{} + {} * ({})", ck, var, result);
}
result
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jit::eval_interp;
    use crate::kernel::{Domain, ExprPool};
    use std::collections::HashMap;

    /// Creates a fresh expression pool for a single test.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    /// Counts `Mul` nodes in the tree rooted at `expr`, descending through
    /// `Mul`, `Add` and `Pow` nodes; every other node kind contributes zero.
    fn count_muls(expr: ExprId, pool: &ExprPool) -> usize {
        use crate::kernel::ExprData;
        match pool.get(expr) {
            ExprData::Pow { base, exp } => count_muls(base, pool) + count_muls(exp, pool),
            ExprData::Add(args) => args.iter().map(|&a| count_muls(a, pool)).sum(),
            ExprData::Mul(args) => 1 + args.iter().map(|&a| count_muls(a, pool)).sum::<usize>(),
            _ => 0,
        }
    }

    /// 2x + 1 evaluated at x = 3 must still be 7 after the rewrite.
    #[test]
    fn horner_linear() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let one = pool.integer(1_i32);
        let expr = pool.add(vec![two_x, one]);

        let h = horner(expr, x, &pool).unwrap();

        let mut env = HashMap::new();
        env.insert(x, 3.0f64);
        let val = eval_interp(h, &env, &pool).unwrap();
        assert!((val - 7.0).abs() < 1e-10, "expected 7.0, got {val}");
    }

    /// x^2 + 2x + 1 must agree with (x + 1)^2 at several sample points.
    #[test]
    fn horner_quadratic() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let one = pool.integer(1_i32);
        let expr = pool.add(vec![x2, two_x, one]);

        let h = horner(expr, x, &pool).unwrap();

        let samples = [-2.0f64, -1.0, 0.0, 1.0, 2.0, 3.0];
        let mut env = HashMap::new();
        for v in samples {
            env.insert(x, v);
            let actual = eval_interp(h, &env, &pool).unwrap();
            let expected = (v + 1.0).powi(2);
            assert!(
                (actual - expected).abs() < 1e-9,
                "v={v}: expected {expected}, got {actual}"
            );
        }
    }

    /// The Horner form of 1 + x + x^2 + ... + x^10 needs at most ten `Mul` nodes.
    #[test]
    fn horner_degree_10_op_count() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = (1_i32..=10).fold(pool.integer(1_i32), |sum, k| {
            let xk = pool.pow(x, pool.integer(k));
            pool.add(vec![sum, xk])
        });

        let h = horner(expr, x, &pool).unwrap();

        let muls = count_muls(h, &pool);
        assert!(
            muls <= 10,
            "Horner form should use ≤ 10 multiplications, got {muls}"
        );
    }

    /// The emitted C text mentions the requested function name and `double`.
    #[test]
    fn emit_horner_c_quadratic() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let one = pool.integer(1_i32);
        let expr = pool.add(vec![x2, two_x, one]);

        let code = emit_horner_c(expr, x, "x", "eval_quad", &pool).unwrap();

        assert!(code.contains("eval_quad"), "function name not in output");
        assert!(code.contains("double"), "return type not in output");
    }

    /// A constant expression survives the rewrite and evaluates without bindings.
    #[test]
    fn horner_constant() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let five = pool.integer(5_i32);

        let h = horner(five, x, &pool).unwrap();

        let env = HashMap::new();
        let val = eval_interp(h, &env, &pool).unwrap();
        assert!((val - 5.0).abs() < 1e-10);
    }

    /// The scalar evaluator agrees with the hand-expanded nested form of 1 + 2x + 3x^2.
    #[test]
    fn eval_horner_f64_matches_interp() {
        let coeffs = [1.0, 2.0, 3.0];
        for x in [-1.0, 0.0, 0.5, 2.0, 10.0] {
            let expected = 1.0 + x * (2.0 + x * 3.0);
            let scalar = eval_horner_f64(&coeffs, x);
            assert!((scalar - expected).abs() < 1e-12, "x={x}");
        }
    }

    /// The batch evaluator agrees with the scalar one; seven points cover
    /// both one full four-wide chunk and a three-point scalar tail.
    #[test]
    fn eval_horner_f64_batch_matches_scalar() {
        let coeffs = [1.0, 2.0, 3.0];
        let xs = [-1.0, 0.0, 0.5, 2.0, 10.0, 3.0, 7.0];
        let mut out = vec![0.0; xs.len()];

        eval_horner_f64_batch(&coeffs, &xs, &mut out);

        for (&x, &got) in xs.iter().zip(&out) {
            assert!((got - eval_horner_f64(&coeffs, x)).abs() < 1e-12);
        }
    }
}