use crate::kernel::{ExprId, ExprPool};
use crate::poly::{ConversionError, UniPoly};
/// Rewrites `expr`, viewed as a univariate polynomial in `var`, into nested
/// (Horner) form built from the polynomial's `i64` coefficients.
///
/// # Errors
///
/// Propagates the error from [`UniPoly::from_symbolic`] when `expr` cannot be
/// converted into a univariate polynomial in `var`.
pub fn horner(expr: ExprId, var: ExprId, pool: &ExprPool) -> Result<ExprId, ConversionError> {
    let poly = UniPoly::from_symbolic(expr, var, pool)?;
    let nested = build_horner(&poly.coefficients_i64(), var, pool);
    Ok(nested)
}
/// Builds `c0 + var * (c1 + var * (... + var * c_{n-1}))` in `pool`, where
/// `coeffs[k]` ends up multiplying `var^k`.
///
/// Uses `n - 1` multiplications and `n - 1` additions for `n` coefficients.
/// An empty slice yields the integer constant `0`.
fn build_horner(coeffs: &[i64], var: ExprId, pool: &ExprPool) -> ExprId {
    match coeffs.split_last() {
        None => pool.integer(0_i32),
        // Start from the highest-order coefficient and wrap outward, one
        // `var * (...)` layer per lower-order coefficient.
        Some((&leading, lower)) => lower.iter().rev().fold(pool.integer(leading), |acc, &c| {
            let shifted = pool.mul(vec![var, acc]);
            pool.add(vec![pool.integer(c), shifted])
        }),
    }
}
/// Emits C source for a function `double fn_name(double var_name)` that
/// evaluates `expr` (a univariate polynomial in `var`) using Horner's scheme.
///
/// `var_name` and `fn_name` are spliced into the output verbatim.
/// NOTE(review): they are not checked to be valid C identifiers. Coefficients
/// are emitted as `double` literals, so `i64` values beyond 2^53 lose precision
/// in the generated code.
///
/// # Errors
///
/// Propagates the error from [`UniPoly::from_symbolic`] when `expr` cannot be
/// converted into a univariate polynomial in `var`.
pub fn emit_horner_c(
    expr: ExprId,
    var: ExprId,
    var_name: &str,
    fn_name: &str,
    pool: &ExprPool,
) -> Result<String, ConversionError> {
    let poly = UniPoly::from_symbolic(expr, var, pool)?;
    let return_expr = build_c_horner(&poly.coefficients_i64(), var_name);
    let source = format!(
        "double {}(double {}) {{\n return {};\n}}\n",
        fn_name, var_name, return_expr
    );
    Ok(source)
}
/// Renders `coeffs` as a nested C expression
/// `c0.0 + var * (c1.0 + var * (... (c_{n-1}.0)))`, where `coeffs[k]`
/// multiplies `var^k`. An empty slice renders as `"0.0"`.
///
/// The string is built left to right in a single buffer: every prefix
/// `"ck.0 + var * ("`, then the leading coefficient, then all the closing
/// parentheses. This keeps the work linear in the output size, whereas
/// re-wrapping the accumulated string with `format!` each step is quadratic.
fn build_c_horner(coeffs: &[i64], var: &str) -> String {
    use std::fmt::Write;

    let (leading, lower) = match coeffs.split_last() {
        Some((&leading, lower)) => (leading, lower),
        None => return "0.0".to_string(),
    };

    let mut out = String::new();
    for c in lower {
        write!(out, "{}.0 + {} * (", c, var).expect("writing to a String cannot fail");
    }
    write!(out, "{}.0", leading).expect("writing to a String cannot fail");
    // One `)` closes each `var * (` opened above.
    out.push_str(&")".repeat(lower.len()));
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jit::eval_interp;
    use crate::kernel::{Domain, ExprPool};
    use std::collections::HashMap;

    /// Fresh expression pool for each test.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn horner_linear() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.add(vec![
            pool.mul(vec![pool.integer(2_i32), x]),
            pool.integer(1_i32),
        ]);
        let h = horner(expr, x, &pool).unwrap();
        let mut env = HashMap::new();
        env.insert(x, 3.0f64);
        let val = eval_interp(h, &env, &pool).unwrap();
        assert!((val - 7.0).abs() < 1e-10, "expected 7.0, got {val}");
    }

    #[test]
    fn horner_quadratic() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let one = pool.integer(1_i32);
        let expr = pool.add(vec![x2, two_x, one]);
        let h = horner(expr, x, &pool).unwrap();
        let mut env = HashMap::new();
        for v in [-2.0f64, -1.0, 0.0, 1.0, 2.0, 3.0] {
            env.insert(x, v);
            let expected = (v + 1.0).powi(2);
            let actual = eval_interp(h, &env, &pool).unwrap();
            assert!(
                (actual - expected).abs() < 1e-9,
                "v={v}: expected {expected}, got {actual}"
            );
        }
    }

    #[test]
    fn horner_degree_10_op_count() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let mut expr = pool.integer(1_i32);
        for k in 1_i32..=10 {
            let xk = pool.pow(x, pool.integer(k));
            expr = pool.add(vec![expr, xk]);
        }
        let h = horner(expr, x, &pool).unwrap();
        let muls = count_muls(h, &pool);
        assert!(
            muls <= 10,
            "Horner form should use ≤ 10 multiplications, got {muls}"
        );
    }

    /// Counts `Mul` nodes reachable through `Mul`, `Add` and `Pow` nodes.
    fn count_muls(expr: ExprId, pool: &ExprPool) -> usize {
        use crate::kernel::ExprData;
        match pool.get(expr) {
            ExprData::Mul(args) => 1 + args.iter().map(|&a| count_muls(a, pool)).sum::<usize>(),
            ExprData::Add(args) => args.iter().map(|&a| count_muls(a, pool)).sum(),
            ExprData::Pow { base, exp } => count_muls(base, pool) + count_muls(exp, pool),
            _ => 0,
        }
    }

    #[test]
    fn emit_horner_c_quadratic() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let one = pool.integer(1_i32);
        let expr = pool.add(vec![x2, two_x, one]);
        let code = emit_horner_c(expr, x, "x", "eval_quad", &pool).unwrap();
        assert!(code.contains("eval_quad"), "function name not in output");
        assert!(code.contains("double"), "return type not in output");
        // x^2 + 2x + 1 has coefficients 1, 2, 1 (symmetric, so independent of
        // ordering), which fully determines the nested body.
        assert!(
            code.contains("1.0 + x * (2.0 + x * (1.0))"),
            "unexpected Horner body in: {code}"
        );
    }

    #[test]
    fn build_c_horner_shapes() {
        assert_eq!(build_c_horner(&[], "x"), "0.0");
        assert_eq!(build_c_horner(&[5], "x"), "5.0");
        assert_eq!(
            build_c_horner(&[-3, 0, 7], "t"),
            "-3.0 + t * (0.0 + t * (7.0))"
        );
    }

    #[test]
    fn horner_constant() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let five = pool.integer(5_i32);
        let h = horner(five, x, &pool).unwrap();
        let env = HashMap::new();
        let val = eval_interp(h, &env, &pool).unwrap();
        assert!((val - 5.0).abs() < 1e-10);
    }
}