use crate::ir::eval::fold_fma_literal;
use crate::ir::Expr;
/// Tries to rewrite `fma(a, b, c)` (that is, `a * b + c`) into a cheaper expression.
///
/// Rules are attempted in this order; the first one that applies wins:
///
/// 1. Whatever [`fold_fma_literal`] returns for the three operands, if it returns anything.
/// 2. `a` is the literal `1.0`  ->  `b + c`.
/// 3. `b` is the literal `1.0`  ->  `a + c`.
/// 4. `a` and `b` are both `f32` literals, one of them is zero and the other is finite,
///    so the product is an exactly known signed zero:
///    * product is `-0.0` (operand signs differ)  ->  `c`, because `-0.0 + x == x` for every `x`;
///    * product is `+0.0` (operand signs agree)   ->  `0.0 + c`, because `+0.0 + x` differs from
///      `x` when `x` is `-0.0` (the sum is `+0.0` under round-to-nearest), so the addition
///      has to be kept to preserve the sign of zero.
///
/// The finiteness requirement in rule 4 keeps `0 * inf` and `0 * NaN` (both NaN) from being
/// discarded.
///
/// Returns `None` when no rule applies.
pub(super) fn simplify_fma(a: &Expr, b: &Expr, c: &Expr) -> Option<Expr> {
    if let Some(folded) = fold_fma_literal(a, b, c) {
        return Some(folded);
    }

    // Multiplying by exactly 1.0 is lossless, so the fused operation equals a plain add.
    if matches!(a, Expr::LitF32(v) if *v == 1.0) {
        return Some(Expr::add(b.clone(), c.clone()));
    }
    if matches!(b, Expr::LitF32(v) if *v == 1.0) {
        return Some(Expr::add(a.clone(), c.clone()));
    }

    if let (Expr::LitF32(x), Expr::LitF32(y)) = (a, b) {
        // `== 0.0` matches both +0.0 and -0.0.
        let zero_times_finite =
            (*x == 0.0 && y.is_finite()) || (x.is_finite() && *y == 0.0);
        if zero_times_finite {
            // The sign of a zero product is the XOR of the operand signs.
            let product_is_negative_zero = x.is_sign_negative() != y.is_sign_negative();
            return Some(if product_is_negative_zero {
                // -0.0 is a true additive identity, so the addend alone is exact.
                c.clone()
            } else {
                // +0.0 + (-0.0) == +0.0, so dropping the add would flip the sign of a
                // negative-zero addend; keep the addition explicit.
                Expr::add(Expr::LitF32(0.0), c.clone())
            });
        }
    }

    None
}