use crate::Float;
use crate::Tensor;
use crate::error::{CoreError, Result};
/// Element-wise expression tree over borrowed tensors and scalar constants.
///
/// Building a tree performs no arithmetic; the work happens in `Expr::eval`, which
/// computes every output element by walking the tree once per flat index.
/// All `Input` leaves must report the same shape; a tree without any `Input`
/// leaf evaluates to a tensor of shape `[1]`.
pub enum Expr<'a, T: Float> {
/// Leaf that reads the element at the current flat index from a borrowed tensor.
Input(&'a Tensor<T>),
/// Leaf constant yielded at every index; it places no constraint on the shape.
Scalar(T),
/// Element-wise `lhs + rhs`.
Add(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
/// Element-wise `lhs - rhs`.
Sub(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
/// Element-wise `lhs * rhs`.
Mul(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
/// Element-wise `lhs / rhs`.
Div(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
/// Element-wise negation, evaluated as `T::zero() - x`.
Neg(Box<Expr<'a, T>>),
/// Element-wise square root (`T::sqrt`).
Sqrt(Box<Expr<'a, T>>),
/// Element-wise exponential (`T::exp`).
Exp(Box<Expr<'a, T>>),
/// Element-wise logarithm (`T::ln`).
Ln(Box<Expr<'a, T>>),
/// Element-wise absolute value (`T::abs`).
Abs(Box<Expr<'a, T>>),
/// Element-wise sine (`T::sin`).
Sin(Box<Expr<'a, T>>),
/// Element-wise cosine (`T::cos`).
Cos(Box<Expr<'a, T>>),
/// Element-wise power: first operand raised to the second via `T::powf`.
Pow(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
/// Multiply-add `a * b + c`, evaluated as a separate multiply followed by an add.
Fma(Box<Expr<'a, T>>, Box<Expr<'a, T>>, Box<Expr<'a, T>>),
/// Clamp of the operand to `(min, max)`: values below `min` become `min`,
/// values above `max` become `max`, anything else passes through unchanged.
Clamp(Box<Expr<'a, T>>, T, T),
}
// The builder methods `add`, `sub`, `mul`, `div` and `neg` are inherent methods that
// share their names with the `std::ops` operator-trait methods, which is what
// `clippy::should_implement_trait` flags; the lint is silenced for this impl.
#[allow(clippy::should_implement_trait)]
impl<'a, T: Float> Expr<'a, T> {
    /// Creates a leaf node that reads its values from `tensor`.
    pub fn input(tensor: &'a Tensor<T>) -> Self {
        Self::Input(tensor)
    }

    /// Creates a leaf node that yields `val` at every index.
    pub fn scalar(val: T) -> Self {
        Self::Scalar(val)
    }

    /// Builds the element-wise sum `self + other`.
    pub fn add(self, other: Self) -> Self {
        Self::Add(Box::new(self), Box::new(other))
    }

    /// Builds the element-wise difference `self - other`.
    pub fn sub(self, other: Self) -> Self {
        Self::Sub(Box::new(self), Box::new(other))
    }

    /// Builds the element-wise product `self * other`.
    pub fn mul(self, other: Self) -> Self {
        Self::Mul(Box::new(self), Box::new(other))
    }

    /// Builds the element-wise quotient `self / other`.
    pub fn div(self, other: Self) -> Self {
        Self::Div(Box::new(self), Box::new(other))
    }

    /// Builds the element-wise negation of `self`.
    pub fn neg(self) -> Self {
        Self::Neg(Box::new(self))
    }

    /// Builds the element-wise square root of `self`.
    pub fn sqrt(self) -> Self {
        Self::Sqrt(Box::new(self))
    }

    /// Builds the element-wise exponential of `self`.
    pub fn exp(self) -> Self {
        Self::Exp(Box::new(self))
    }

    /// Builds the element-wise `ln` of `self`.
    pub fn ln(self) -> Self {
        Self::Ln(Box::new(self))
    }

    /// Builds the element-wise absolute value of `self`.
    pub fn abs(self) -> Self {
        Self::Abs(Box::new(self))
    }

    /// Builds the element-wise sine of `self`.
    pub fn sin(self) -> Self {
        Self::Sin(Box::new(self))
    }

    /// Builds the element-wise cosine of `self`.
    pub fn cos(self) -> Self {
        Self::Cos(Box::new(self))
    }

    /// Builds the element-wise power `self.powf(other)`.
    pub fn pow(self, other: Self) -> Self {
        Self::Pow(Box::new(self), Box::new(other))
    }

    /// Builds the multiply-add `self * b + c`.
    pub fn fma(self, b: Self, c: Self) -> Self {
        Self::Fma(Box::new(self), Box::new(b), Box::new(c))
    }

    /// Builds a node that limits `self` to `min` from below and `max` from above.
    pub fn clamp(self, min: T, max: T) -> Self {
        Self::Clamp(Box::new(self), min, max)
    }

    /// Evaluates the whole tree and materialises the result as a new tensor.
    ///
    /// The output shape is the one reported by `collect_shape`; one value is computed
    /// per flat index in `0..product(shape)`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `collect_shape` and any error returned by
    /// `Tensor::from_vec`.
    pub fn eval(&self) -> Result<Tensor<T>> {
        let shape = collect_shape(self)?;
        let numel: usize = shape.iter().product();
        let values: Vec<T> = (0..numel).map(|flat| self.eval_at(flat)).collect();
        Tensor::from_vec(values, shape)
    }

    /// Computes the value of this sub-expression at flat index `idx`.
    ///
    /// `Input` leaves index directly into `as_slice()`, so `idx` must be in bounds
    /// for every input tensor in the tree (slice indexing panics otherwise).
    fn eval_at(&self, idx: usize) -> T {
        // Shorthand for evaluating a child node at the same flat index.
        let at = |node: &Self| node.eval_at(idx);

        match self {
            // Leaves.
            Self::Input(tensor) => tensor.as_slice()[idx],
            Self::Scalar(constant) => *constant,

            // Binary arithmetic.
            Self::Add(lhs, rhs) => at(lhs) + at(rhs),
            Self::Sub(lhs, rhs) => at(lhs) - at(rhs),
            Self::Mul(lhs, rhs) => at(lhs) * at(rhs),
            Self::Div(lhs, rhs) => at(lhs) / at(rhs),
            Self::Pow(base, exponent) => at(base).powf(at(exponent)),

            // Unary functions; negation is expressed as subtraction from zero.
            Self::Neg(inner) => T::zero() - at(inner),
            Self::Sqrt(inner) => at(inner).sqrt(),
            Self::Exp(inner) => at(inner).exp(),
            Self::Ln(inner) => at(inner).ln(),
            Self::Abs(inner) => at(inner).abs(),
            Self::Sin(inner) => at(inner).sin(),
            Self::Cos(inner) => at(inner).cos(),

            // Separate multiply and add (no fused operation is used).
            Self::Fma(x, y, z) => at(x) * at(y) + at(z),

            // The lower bound is tested first; a value for which both comparisons
            // are false is returned unchanged.
            Self::Clamp(inner, min, max) => {
                let value = at(inner);
                match (value < *min, value > *max) {
                    (true, _) => *min,
                    (false, true) => *max,
                    (false, false) => value,
                }
            }
        }
    }
}
/// Determines the output shape of `expr`.
///
/// Returns the shape recorded by `collect_shape_inner`, or `[1]` when the walk
/// recorded none.
///
/// # Errors
///
/// Propagates any error produced by `collect_shape_inner`.
fn collect_shape<T: Float>(expr: &Expr<'_, T>) -> Result<Vec<usize>> {
    let mut found = None::<Vec<usize>>;
    collect_shape_inner(expr, &mut found)?;
    Ok(match found {
        Some(dims) => dims,
        // No shape was recorded: fall back to a single-element shape.
        None => vec![1],
    })
}
/// Visits every node of `expr` and reconciles the shape of each `Input` leaf
/// with `shape`.
///
/// The first `Input` encountered stores its shape in `shape`; every later
/// `Input` must report exactly the same shape. `Scalar` leaves are ignored.
///
/// # Errors
///
/// Returns `CoreError::DimensionMismatch` as soon as an `Input` disagrees with the
/// recorded shape; `expected` is the recorded shape and `got` is the offending one.
fn collect_shape_inner<T: Float>(expr: &Expr<'_, T>, shape: &mut Option<Vec<usize>>) -> Result<()> {
    // Explicit work stack in place of recursion. Children are pushed right-to-left
    // so they are popped left-to-right, keeping the visiting order (and therefore
    // which shape ends up as `expected`) that of a pre-order walk.
    let mut pending = vec![expr];

    while let Some(node) = pending.pop() {
        match node {
            Expr::Scalar(_) => {}
            Expr::Input(tensor) => {
                let dims = tensor.shape();
                match shape {
                    None => *shape = Some(dims.to_vec()),
                    Some(seen) if seen.as_slice() == dims => {}
                    Some(seen) => {
                        return Err(CoreError::DimensionMismatch {
                            expected: seen.clone(),
                            got: dims.to_vec(),
                        });
                    }
                }
            }
            Expr::Neg(inner)
            | Expr::Sqrt(inner)
            | Expr::Exp(inner)
            | Expr::Ln(inner)
            | Expr::Abs(inner)
            | Expr::Sin(inner)
            | Expr::Cos(inner)
            | Expr::Clamp(inner, _, _) => pending.push(&**inner),
            Expr::Add(lhs, rhs)
            | Expr::Sub(lhs, rhs)
            | Expr::Mul(lhs, rhs)
            | Expr::Div(lhs, rhs)
            | Expr::Pow(lhs, rhs) => {
                pending.push(&**rhs);
                pending.push(&**lhs);
            }
            Expr::Fma(first, second, third) => {
                pending.push(&**third);
                pending.push(&**second);
                pending.push(&**first);
            }
        }
    }

    Ok(())
}
pub fn eval_expr<T: Float>(expr: &Expr<'_, T>) -> Result<Tensor<T>> {
expr.eval()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `got` and `expected` have the same length and agree element-wise
    /// within `tol`.
    ///
    /// The explicit length check matters: `zip` stops at the shorter side, so without
    /// it a truncated result would pass the comparison vacuously.
    fn assert_all_close(got: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(got.len(), expected.len(), "length mismatch");
        for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
            assert!((g - e).abs() < tol, "index {i}: got {g}, expected {e}");
        }
    }

    #[test]
    fn test_expr_basic_arithmetic() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], vec![2, 2]).unwrap();
        let c = Tensor::from_vec(vec![2.0, 2.0, 2.0, 2.0], vec![2, 2]).unwrap();
        let result = Expr::input(&a)
            .add(Expr::input(&b))
            .mul(Expr::input(&c))
            .eval()
            .unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice(), &[22.0, 44.0, 66.0, 88.0]);
    }

    #[test]
    fn test_expr_sub_div_neg() {
        let a = Tensor::from_vec(vec![8.0_f64, 6.0, 4.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![2.0, 3.0, 4.0], vec![3]).unwrap();
        // -((a - b) / b)
        let result = Expr::input(&a)
            .sub(Expr::input(&b))
            .div(Expr::input(&b))
            .neg()
            .eval()
            .unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.as_slice(), &[-3.0, -1.0, 0.0]);
    }

    #[test]
    fn test_expr_unary_ops() {
        let a = Tensor::from_vec(vec![-4.0_f64, -9.0, -16.0], vec![3]).unwrap();
        let result = Expr::input(&a).abs().sqrt().eval().unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.as_slice(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_expr_pow_ln_sin() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();

        let squared = Expr::input(&a).pow(Expr::scalar(2.0)).eval().unwrap();
        assert_all_close(squared.as_slice(), &[1.0, 4.0, 9.0], 1e-12);

        let logs = Expr::input(&a).ln().eval().unwrap();
        assert_all_close(
            logs.as_slice(),
            &[1.0_f64.ln(), 2.0_f64.ln(), 3.0_f64.ln()],
            1e-12,
        );

        let sines = Expr::input(&a).sin().eval().unwrap();
        assert_all_close(
            sines.as_slice(),
            &[1.0_f64.sin(), 2.0_f64.sin(), 3.0_f64.sin()],
            1e-12,
        );
    }

    #[test]
    fn test_expr_fma() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0], vec![3]).unwrap();
        let c = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();
        let result = Expr::input(&a)
            .fma(Expr::input(&b), Expr::input(&c))
            .eval()
            .unwrap();
        assert_eq!(result.as_slice(), &[14.0, 30.0, 48.0]);
    }

    #[test]
    fn test_expr_clamp() {
        // Covers below-range, both boundaries, in-range and above-range values.
        let a = Tensor::from_vec(vec![-2.0_f64, 0.0, 0.5, 1.0, 3.0], vec![5]).unwrap();
        let result = Expr::input(&a).clamp(0.0, 1.0).eval().unwrap();
        assert_eq!(result.shape(), &[5]);
        assert_eq!(result.as_slice(), &[0.0, 0.0, 0.5, 1.0, 1.0]);
    }

    #[test]
    fn test_expr_scalar_broadcast() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let result = Expr::input(&a).add(Expr::scalar(2.0)).eval().unwrap();
        assert_eq!(result.as_slice(), &[3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_expr_scalar_only() {
        // With no `Input` leaf the shape falls back to a single element.
        let result = Expr::scalar(7.0_f64)
            .mul(Expr::scalar(3.0))
            .eval()
            .unwrap();
        assert_eq!(result.shape(), &[1]);
        assert_eq!(result.as_slice(), &[21.0]);
    }

    #[test]
    fn test_expr_shape_mismatch() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let err = Expr::input(&a).add(Expr::input(&b)).eval();
        assert!(err.is_err());
        match err.unwrap_err() {
            CoreError::DimensionMismatch { expected, got } => {
                assert_eq!(expected, vec![3]);
                assert_eq!(got, vec![4]);
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_expr_shape_mismatch_same_numel_nested() {
        // Same element count but different dimensions must still be rejected, even
        // when the offending input sits below a unary node in the last operand.
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let outcome = Expr::input(&a)
            .fma(Expr::scalar(1.0), Expr::input(&b).abs())
            .eval();
        match outcome {
            Err(CoreError::DimensionMismatch { expected, got }) => {
                assert_eq!(expected, vec![2, 2]);
                assert_eq!(got, vec![4]);
            }
            Err(other) => panic!("expected DimensionMismatch, got {other:?}"),
            Ok(_) => panic!("expected DimensionMismatch, got a tensor"),
        }
    }

    #[test]
    fn test_expr_complex_chain() {
        let a = Tensor::from_vec(vec![0.0_f64, 2.0, 4.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![0.0, core::f64::consts::PI, 0.0], vec![3]).unwrap();
        // exp(a * 0.5) + cos(b)
        let result = Expr::input(&a)
            .mul(Expr::scalar(0.5))
            .exp()
            .add(Expr::input(&b).cos())
            .eval()
            .unwrap();
        let expected = [
            (0.0_f64 * 0.5).exp() + 0.0_f64.cos(),
            (2.0_f64 * 0.5).exp() + core::f64::consts::PI.cos(),
            (4.0_f64 * 0.5).exp() + 0.0_f64.cos(),
        ];
        assert_all_close(result.as_slice(), &expected, 1e-12);
    }
}