use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use crate::simd::SimdOps;
use scirs2_core::ndarray::Array1;
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::marker::PhantomData;
use super::core::{ArrayExpr, BinaryExpr, Expr, ScalarExpr, UnaryExpr};
use super::enhanced::ReductionExpr;
/// Common interface for expressions that evaluate several operations in one fused step.
///
/// The implementors visible in this file also implement [`Expr`]; this trait adds an
/// explicit "materialize now" entry point plus size/shape accessors.
pub trait FusedOp<T: Clone> {
/// Evaluates the fused operation and returns the result as an owned [`Array`].
fn eval_fused(&self) -> Array<T>;
/// Returns the element count reported for the fused result.
fn fused_size(&self) -> usize;
/// Returns the shape reported for the fused result.
fn fused_shape(&self) -> &[usize];
}
/// Three-operand element-wise chain that computes `op2(op1(a, b), c)` per element.
///
/// The intermediate `op1(a, b)` value is produced and consumed element by element,
/// so no intermediate array is built for it.
pub struct FusedElementWiseChain<T, A, B, C, F1, F2>
where
T: Clone,
A: Expr<T>,
B: Expr<T>,
C: Expr<T>,
F1: Fn(T, T) -> T,
F2: Fn(T, T) -> T,
{
/// Left operand of `op1`.
a: A,
/// Right operand of `op1`.
b: B,
/// Right operand of `op2`.
c: C,
/// Inner binary operator, applied to `(a, b)`.
op1: F1,
/// Outer binary operator, applied to `(op1(a, b), c)`.
op2: F2,
/// Shape copied from `a` at construction time.
shape: Vec<usize>,
/// Ties the otherwise unused element type `T` to the struct.
_phantom: PhantomData<T>,
}
impl<T, A, B, C, F1, F2> FusedElementWiseChain<T, A, B, C, F1, F2>
where
    T: Clone,
    A: Expr<T>,
    B: Expr<T>,
    C: Expr<T>,
    F1: Fn(T, T) -> T,
    F2: Fn(T, T) -> T,
{
    /// Builds a fused chain evaluating `op2(op1(a, b), c)` element by element.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ShapeMismatch`] when the shape of `b` or `c` differs
    /// from the shape of `a`. `b` is checked first; `expected` carries the shape of
    /// `a` and `actual` the shape of the first operand that disagrees.
    pub fn new(a: A, b: B, c: C, op1: F1, op2: F2) -> Result<Self> {
        let shape = a.shape().to_vec();
        for candidate in [b.shape(), c.shape()] {
            if shape.as_slice() != candidate {
                return Err(NumRs2Error::ShapeMismatch {
                    expected: shape,
                    actual: candidate.to_vec(),
                });
            }
        }
        Ok(Self {
            a,
            b,
            c,
            op1,
            op2,
            shape,
            _phantom: PhantomData,
        })
    }
}
impl<T, A, B, C, F1, F2> Expr<T> for FusedElementWiseChain<T, A, B, C, F1, F2>
where
    T: Clone,
    A: Expr<T>,
    B: Expr<T>,
    C: Expr<T>,
    F1: Fn(T, T) -> T,
    F2: Fn(T, T) -> T,
{
    /// Computes `op2(op1(a[index], b[index]), c[index])`.
    #[inline(always)]
    fn eval_at(&self, index: usize) -> T {
        // All three operands are read (in a, b, c order) before either operator runs.
        let (lhs, rhs, tail) = (
            self.a.eval_at(index),
            self.b.eval_at(index),
            self.c.eval_at(index),
        );
        let partial = (self.op1)(lhs, rhs);
        (self.op2)(partial, tail)
    }

    /// Element count, forwarded from operand `a`.
    #[inline]
    fn size(&self) -> usize {
        self.a.size()
    }

    /// Shape captured from operand `a` when the chain was constructed.
    #[inline]
    fn shape(&self) -> &[usize] {
        &self.shape
    }
}
impl<T, A, B, C, F1, F2> FusedOp<T> for FusedElementWiseChain<T, A, B, C, F1, F2>
where
    T: Clone,
    A: Expr<T>,
    B: Expr<T>,
    C: Expr<T>,
    F1: Fn(T, T) -> T,
    F2: Fn(T, T) -> T,
{
    /// Materializes the chain by calling `eval_at` for every index in `0..a.size()`
    /// and reshaping the collected values to the stored shape.
    fn eval_fused(&self) -> Array<T> {
        let values: Vec<T> = (0..self.a.size()).map(|i| self.eval_at(i)).collect();
        Array::from_vec(values).reshape(&self.shape)
    }

    /// Element count, forwarded from operand `a`.
    fn fused_size(&self) -> usize {
        self.a.size()
    }

    /// Shape captured from operand `a` at construction.
    fn fused_shape(&self) -> &[usize] {
        &self.shape
    }
}
/// Expression applying `fma_fn(x, mul_scalar, add_scalar)` to every element `x` of `expr`.
pub struct FusedScalarBroadcast<T, E>
where
T: Clone,
E: Expr<T>,
{
/// Source expression supplying the per-element input.
expr: E,
/// Scalar passed as the second argument of `fma_fn`.
mul_scalar: T,
/// Scalar passed as the third argument of `fma_fn`.
add_scalar: T,
/// Function pointer combining `(element, mul_scalar, add_scalar)`; the `f64`/`f32`
/// helper constructors in this file supply `mul_add`.
fma_fn: fn(T, T, T) -> T,
}
impl<T, E> FusedScalarBroadcast<T, E>
where
T: Clone,
E: Expr<T>,
{
pub fn new(expr: E, mul_scalar: T, add_scalar: T, fma_fn: fn(T, T, T) -> T) -> Self {
Self {
expr,
mul_scalar,
add_scalar,
fma_fn,
}
}
}
impl<T, E> Expr<T> for FusedScalarBroadcast<T, E>
where
    T: Clone,
    E: Expr<T>,
{
    /// Evaluates the wrapped expression at `index` and feeds it, together with clones
    /// of both scalars, to the stored combining function.
    #[inline(always)]
    fn eval_at(&self, index: usize) -> T {
        let input = self.expr.eval_at(index);
        let scale = self.mul_scalar.clone();
        let offset = self.add_scalar.clone();
        (self.fma_fn)(input, scale, offset)
    }

    /// Element count of the wrapped expression.
    #[inline]
    fn size(&self) -> usize {
        self.expr.size()
    }

    /// Shape of the wrapped expression.
    #[inline]
    fn shape(&self) -> &[usize] {
        self.expr.shape()
    }
}
impl<T, E> FusedOp<T> for FusedScalarBroadcast<T, E>
where
    T: Clone,
    E: Expr<T>,
{
    /// Materializes the broadcast by evaluating every index of the wrapped expression
    /// and reshaping the collected values to that expression's shape.
    fn eval_fused(&self) -> Array<T> {
        let values: Vec<T> = (0..self.expr.size()).map(|i| self.eval_at(i)).collect();
        Array::from_vec(values).reshape(self.expr.shape())
    }

    /// Element count of the wrapped expression.
    fn fused_size(&self) -> usize {
        self.expr.size()
    }

    /// Shape of the wrapped expression.
    fn fused_shape(&self) -> &[usize] {
        self.expr.shape()
    }
}
/// Builds a [`FusedScalarBroadcast`] over `f64` that evaluates
/// `x.mul_add(mul_scalar, add_scalar)` for every element `x` of `expr`.
pub fn fused_scalar_broadcast_f64<E: Expr<f64>>(
    expr: E,
    mul_scalar: f64,
    add_scalar: f64,
) -> FusedScalarBroadcast<f64, E> {
    // `f64::mul_add(x, a, b)` is exactly `x.mul_add(a, b)`, so it can serve as the
    // combining function directly.
    FusedScalarBroadcast::new(expr, mul_scalar, add_scalar, f64::mul_add)
}
/// Builds a [`FusedScalarBroadcast`] over `f32` that evaluates
/// `x.mul_add(mul_scalar, add_scalar)` for every element `x` of `expr`.
pub fn fused_scalar_broadcast_f32<E: Expr<f32>>(
    expr: E,
    mul_scalar: f32,
    add_scalar: f32,
) -> FusedScalarBroadcast<f32, E> {
    // `f32::mul_add(x, a, b)` is exactly `x.mul_add(a, b)`, so it can serve as the
    // combining function directly.
    FusedScalarBroadcast::new(expr, mul_scalar, add_scalar, f32::mul_add)
}
/// `f64` fused multiply-add expression computing `a * b + c` per element via `mul_add`.
pub struct FusedMultiplyAdd<A, B, C>
where
A: Expr<f64>,
B: Expr<f64>,
C: Expr<f64>,
{
/// Multiplicand expression.
a: A,
/// Multiplier expression.
b: B,
/// Addend expression.
c: C,
/// Shape copied from `a` at construction time.
shape: Vec<usize>,
}
impl<A, B, C> FusedMultiplyAdd<A, B, C>
where
    A: Expr<f64>,
    B: Expr<f64>,
    C: Expr<f64>,
{
    /// Builds a fused multiply-add over three expressions of identical shape.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ShapeMismatch`] when the shape of `b` or `c` differs
    /// from the shape of `a`. `b` is checked first; `expected` carries the shape of
    /// `a` and `actual` the shape of the first operand that disagrees.
    pub fn new(a: A, b: B, c: C) -> Result<Self> {
        let shape = a.shape().to_vec();
        for candidate in [b.shape(), c.shape()] {
            if shape.as_slice() != candidate {
                return Err(NumRs2Error::ShapeMismatch {
                    expected: shape,
                    actual: candidate.to_vec(),
                });
            }
        }
        Ok(Self { a, b, c, shape })
    }
}
impl<A, B, C> Expr<f64> for FusedMultiplyAdd<A, B, C>
where
    A: Expr<f64>,
    B: Expr<f64>,
    C: Expr<f64>,
{
    /// Computes `a[index].mul_add(b[index], c[index])`.
    #[inline(always)]
    fn eval_at(&self, index: usize) -> f64 {
        let (multiplicand, multiplier, addend) = (
            self.a.eval_at(index),
            self.b.eval_at(index),
            self.c.eval_at(index),
        );
        multiplicand.mul_add(multiplier, addend)
    }

    /// Element count, forwarded from operand `a`.
    #[inline]
    fn size(&self) -> usize {
        self.a.size()
    }

    /// Shape captured from operand `a` at construction.
    #[inline]
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Materializes the result, choosing the evaluation strategy by element count:
    /// inputs with fewer than 32 elements use the scalar loop, everything else goes
    /// through the array-level `simd_fma` path.
    fn eval(&self) -> Array<f64> {
        const SIMD_THRESHOLD: usize = 32;
        if self.a.size() < SIMD_THRESHOLD {
            self.eval_scalar_fma()
        } else {
            self.eval_simd_fma()
        }
    }
}
impl<A, B, C> FusedMultiplyAdd<A, B, C>
where
    A: Expr<f64>,
    B: Expr<f64>,
    C: Expr<f64>,
{
    /// Array-level path: fully evaluates each operand, copies it into a flat
    /// `Array1`, and hands the three views to `f64::simd_fma`.
    ///
    /// The flat result is reshaped to the stored shape before being returned.
    fn eval_simd_fma(&self) -> Array<f64> {
        // Evaluate all operands first, then convert each to a 1-D ndarray buffer.
        let (a_arr, b_arr, c_arr) = (self.a.eval(), self.b.eval(), self.c.eval());
        let a_nd = Array1::from_vec(a_arr.to_vec());
        let b_nd = Array1::from_vec(b_arr.to_vec());
        let c_nd = Array1::from_vec(c_arr.to_vec());
        let fused = f64::simd_fma(&a_nd.view(), &b_nd.view(), &c_nd.view());
        Array::from_vec(fused.to_vec()).reshape(&self.shape)
    }

    /// Scalar path: computes `a[i].mul_add(b[i], c[i])` index by index without
    /// evaluating the operands into temporary arrays.
    fn eval_scalar_fma(&self) -> Array<f64> {
        // `eval_at` performs exactly the per-index `mul_add` this path needs.
        let values: Vec<f64> = (0..self.a.size()).map(|i| self.eval_at(i)).collect();
        Array::from_vec(values).reshape(&self.shape)
    }
}
impl<A, B, C> FusedOp<f64> for FusedMultiplyAdd<A, B, C>
where
A: Expr<f64>,
B: Expr<f64>,
C: Expr<f64>,
{
fn eval_fused(&self) -> Array<f64> {
self.eval()
}
fn fused_size(&self) -> usize {
self.a.size()
}
fn fused_shape(&self) -> &[usize] {
&self.shape
}
}
/// Element-wise binary operation fused with a reduction over its results.
///
/// `reduce` combines `elem_op(a[i], b[i])` for every index with `reduce_op`, so the
/// element-wise results are never collected into an intermediate array.
pub struct FusedReduction<T, A, B, ElemOp, RedOp, Identity>
where
T: Clone,
A: Expr<T>,
B: Expr<T>,
ElemOp: Fn(T, T) -> T,
RedOp: Fn(T, T) -> T,
Identity: Fn() -> T,
{
/// Left operand of `elem_op`.
a: A,
/// Right operand of `elem_op`.
b: B,
/// Per-element binary operator.
elem_op: ElemOp,
/// Accumulating operator, called as `reduce_op(accumulator, element_result)`.
reduce_op: RedOp,
/// Produces the value returned by `reduce` when the input is empty.
identity: Identity,
/// Ties the otherwise unused element type `T` to the struct.
_phantom: PhantomData<T>,
}
impl<T, A, B, ElemOp, RedOp, Identity> FusedReduction<T, A, B, ElemOp, RedOp, Identity>
where
    T: Clone,
    A: Expr<T>,
    B: Expr<T>,
    ElemOp: Fn(T, T) -> T,
    RedOp: Fn(T, T) -> T,
    Identity: Fn() -> T,
{
    /// Builds a fused "element-wise then reduce" operation over `a` and `b`.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ShapeMismatch`] (expected = shape of `a`, actual =
    /// shape of `b`) when the two operand shapes differ.
    pub fn new(a: A, b: B, elem_op: ElemOp, reduce_op: RedOp, identity: Identity) -> Result<Self> {
        let shapes_agree = a.shape() == b.shape();
        if !shapes_agree {
            return Err(NumRs2Error::ShapeMismatch {
                expected: a.shape().to_vec(),
                actual: b.shape().to_vec(),
            });
        }
        Ok(Self {
            a,
            b,
            elem_op,
            reduce_op,
            identity,
            _phantom: PhantomData,
        })
    }

    /// Runs the fused reduction and returns the accumulated value.
    ///
    /// For an empty input (`a.size() == 0`) the result is `identity()`. Otherwise the
    /// accumulator is seeded with `elem_op(a[0], b[0])` and each later element result
    /// is folded in with `reduce_op(acc, elem)`; `identity` is not consulted.
    pub fn reduce(&self) -> T {
        let len = self.a.size();
        if len == 0 {
            return (self.identity)();
        }
        // Element-wise result at a single index.
        let element = |i: usize| (self.elem_op)(self.a.eval_at(i), self.b.eval_at(i));
        (1..len).fold(element(0), |acc, i| (self.reduce_op)(acc, element(i)))
    }
}
/// Computes the dot product `Σ a[i] * b[i]` in a single fused pass.
///
/// # Errors
///
/// Propagates the shape-mismatch error from [`FusedReduction::new`] when the two
/// operands have different shapes.
pub fn fused_dot_product<'a>(a: ArrayExpr<'a, f64>, b: ArrayExpr<'a, f64>) -> Result<f64> {
    FusedReduction::new(a, b, |x, y| x * y, |acc, v| acc + v, || 0.0)
        .map(|reduction| reduction.reduce())
}
/// Computes `Σ a[i]²` in a single fused pass; returns `0.0` for an empty array.
///
/// The array is paired with itself, so both operands always share a shape; should
/// constructing the reduction fail anyway, the function falls back to `0.0`.
pub fn fused_sum_of_squares(a: &Array<f64>) -> f64 {
    if a.size() == 0 {
        return 0.0;
    }
    FusedReduction::new(
        ArrayExpr::new(a),
        ArrayExpr::new(a),
        |x, y| x * y,
        |acc, v| acc + v,
        || 0.0,
    )
    .map_or(0.0, |reduction| reduction.reduce())
}
/// Computes the sum of absolute differences `Σ |a[i] - b[i]|` in a single fused pass.
///
/// # Errors
///
/// Propagates the shape-mismatch error from [`FusedReduction::new`] when the two
/// operands have different shapes.
pub fn fused_sum_abs_diff<'a>(a: ArrayExpr<'a, f64>, b: ArrayExpr<'a, f64>) -> Result<f64> {
    FusedReduction::new(a, b, |x, y| (x - y).abs(), |acc, v| acc + v, || 0.0)
        .map(|reduction| reduction.reduce())
}
/// Computes `a * b + c` element-wise for three arrays via `f64::simd_fma`.
///
/// Each input is copied into a flat `Array1` before the call; the flat result is
/// reshaped to the shape of `a`.
///
/// # Errors
///
/// Returns [`NumRs2Error::ShapeMismatch`] when the shape of `b` or `c` differs from
/// the shape of `a` (`b` is checked first).
pub fn simd_fma_arrays(a: &Array<f64>, b: &Array<f64>, c: &Array<f64>) -> Result<Array<f64>> {
    let expected = a.shape();
    for other in [b, c] {
        let actual = other.shape();
        if expected != actual {
            return Err(NumRs2Error::ShapeMismatch { expected, actual });
        }
    }
    let a_nd = Array1::from_vec(a.to_vec());
    let b_nd = Array1::from_vec(b.to_vec());
    let c_nd = Array1::from_vec(c.to_vec());
    let fused = f64::simd_fma(&a_nd.view(), &b_nd.view(), &c_nd.view());
    Ok(Array::from_vec(fused.to_vec()).reshape(&expected))
}
/// Computes `x.mul_add(scalar_mul, scalar_add)` for every element `x` of `a`.
///
/// The result has the same shape as `a`. Elements are processed in order with a
/// single iterator pass; this yields exactly the same values as processing the data
/// in fixed-size chunks followed by a remainder, without the per-element indexing.
pub fn simd_fused_scalar_broadcast(a: &Array<f64>, scalar_mul: f64, scalar_add: f64) -> Array<f64> {
    let result: Vec<f64> = a
        .to_vec()
        .into_iter()
        .map(|x| x.mul_add(scalar_mul, scalar_add))
        .collect();
    Array::from_vec(result).reshape(&a.shape())
}
/// Computes the dot product of two equally shaped arrays via `f64::simd_dot`.
///
/// Both inputs are copied into flat `Array1` buffers before the call.
///
/// # Errors
///
/// Returns [`NumRs2Error::ShapeMismatch`] (expected = shape of `a`, actual = shape of
/// `b`) when the shapes differ.
pub fn simd_fused_dot_product(a: &Array<f64>, b: &Array<f64>) -> Result<f64> {
    let (expected, actual) = (a.shape(), b.shape());
    if expected != actual {
        return Err(NumRs2Error::ShapeMismatch { expected, actual });
    }
    let lhs = Array1::from_vec(a.to_vec());
    let rhs = Array1::from_vec(b.to_vec());
    Ok(f64::simd_dot(&lhs.view(), &rhs.view()))
}
/// Kind of fusion opportunity reported in a `FusionAnalysis`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FusionPattern {
/// Chain of element-wise operations (see `FusedElementWiseChain`).
ElementWiseChain,
/// Scalar multiply plus scalar add (see `FusedScalarBroadcast`).
ScalarBroadcast,
/// Multiply followed by add (see `FusedMultiplyAdd`).
FusedMultiplyAdd,
/// Element-wise operation fused with a reduction (see `FusedReduction`).
ReductionFusion,
/// No fusion opportunity.
None,
}
/// Result of a fusion check: the detected pattern plus fixed estimates of its benefit.
#[derive(Debug, Clone)]
pub struct FusionAnalysis {
/// Detected fusion pattern.
pub pattern: FusionPattern,
/// Estimated speedup; the constructors in this file use `1.0` for "no fusion" and
/// hard-coded values between `1.5` and `2.5` otherwise (presumably a multiplicative
/// factor relative to unfused evaluation).
pub estimated_speedup: f64,
/// Number of allocations the fusion is estimated to eliminate.
pub allocations_eliminated: usize,
/// Number of passes the fusion is estimated to eliminate.
pub passes_eliminated: usize,
}
impl FusionAnalysis {
pub fn none() -> Self {
Self {
pattern: FusionPattern::None,
estimated_speedup: 1.0,
allocations_eliminated: 0,
passes_eliminated: 0,
}
}
pub fn element_wise_chain() -> Self {
Self {
pattern: FusionPattern::ElementWiseChain,
estimated_speedup: 1.5,
allocations_eliminated: 1,
passes_eliminated: 1,
}
}
pub fn scalar_broadcast() -> Self {
Self {
pattern: FusionPattern::ScalarBroadcast,
estimated_speedup: 1.8,
allocations_eliminated: 1,
passes_eliminated: 1,
}
}
pub fn fma() -> Self {
Self {
pattern: FusionPattern::FusedMultiplyAdd,
estimated_speedup: 2.0,
allocations_eliminated: 1,
passes_eliminated: 1,
}
}
pub fn reduction() -> Self {
Self {
pattern: FusionPattern::ReductionFusion,
estimated_speedup: 2.5,
allocations_eliminated: 1,
passes_eliminated: 1,
}
}
}
/// Stateless collection of heuristics that map simple facts about an expression
/// (sizes, which operations are present) to a [`FusionAnalysis`].
pub struct FusionDetector;

impl FusionDetector {
    /// Reports an element-wise chain when both sizes are equal and non-zero.
    pub fn detect_element_wise_chain(size_outer: usize, size_inner: usize) -> FusionAnalysis {
        let fusible = size_inner > 0 && size_outer == size_inner;
        if fusible {
            FusionAnalysis::element_wise_chain()
        } else {
            FusionAnalysis::none()
        }
    }

    /// Reports a scalar broadcast only when both a multiply and an add are present.
    pub fn detect_scalar_broadcast(has_mul: bool, has_add: bool) -> FusionAnalysis {
        match (has_mul, has_add) {
            (true, true) => FusionAnalysis::scalar_broadcast(),
            _ => FusionAnalysis::none(),
        }
    }

    /// Reports a reduction fusion only when an element-wise op feeds a reduction.
    pub fn detect_reduction_fusion(is_element_wise: bool, is_reduction: bool) -> FusionAnalysis {
        match (is_element_wise, is_reduction) {
            (true, true) => FusionAnalysis::reduction(),
            _ => FusionAnalysis::none(),
        }
    }

    /// Reports an FMA only when a multiply and an add are present and sizes match.
    pub fn detect_fma(has_multiply: bool, has_add: bool, sizes_match: bool) -> FusionAnalysis {
        match (has_multiply, has_add, sizes_match) {
            (true, true, true) => FusionAnalysis::fma(),
            _ => FusionAnalysis::none(),
        }
    }

    /// Classifies a three-operand expression by operand sizes.
    ///
    /// * `a` empty or `a` and `b` of different size: no fusion.
    /// * all three sizes equal: FMA.
    /// * only `a` and `b` equal: element-wise chain.
    pub fn analyze_ternary(a_size: usize, b_size: usize, c_size: usize) -> FusionAnalysis {
        // Every fusible case needs a non-empty `a` that matches `b`.
        if a_size == 0 || a_size != b_size {
            return FusionAnalysis::none();
        }
        if b_size == c_size {
            FusionAnalysis::fma()
        } else {
            FusionAnalysis::element_wise_chain()
        }
    }
}
/// Four-operand fused expression computing `op2(op1(a, b), op3(c, d))` per element.
pub struct FusedQuadOp<T, A, B, C, D, F1, F2, F3>
where
T: Clone,
A: Expr<T>,
B: Expr<T>,
C: Expr<T>,
D: Expr<T>,
F1: Fn(T, T) -> T,
F2: Fn(T, T) -> T,
F3: Fn(T, T) -> T,
{
/// Left operand of `op1`.
a: A,
/// Right operand of `op1`.
b: B,
/// Left operand of `op3`.
c: C,
/// Right operand of `op3`.
d: D,
/// Operator for the left pair `(a, b)`.
op1: F1,
/// Operator combining the two pair results: `op2(op1(..), op3(..))`.
op2: F2,
/// Operator for the right pair `(c, d)`.
op3: F3,
/// Shape copied from `a` at construction time.
shape: Vec<usize>,
/// Ties the otherwise unused element type `T` to the struct.
_phantom: PhantomData<T>,
}
impl<T, A, B, C, D, F1, F2, F3> FusedQuadOp<T, A, B, C, D, F1, F2, F3>
where
    T: Clone,
    A: Expr<T>,
    B: Expr<T>,
    C: Expr<T>,
    D: Expr<T>,
    F1: Fn(T, T) -> T,
    F2: Fn(T, T) -> T,
    F3: Fn(T, T) -> T,
{
    /// Builds a fused expression evaluating `op2(op1(a, b), op3(c, d))` per element.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ShapeMismatch`] when the shape of `b`, `c` or `d`
    /// differs from the shape of `a`. Operands are checked in that order; `expected`
    /// carries the shape of `a` and `actual` the shape of the first operand that
    /// actually disagrees (not unconditionally the shape of `b`).
    pub fn new(a: A, b: B, c: C, d: D, op1: F1, op2: F2, op3: F3) -> Result<Self> {
        let shape = a.shape().to_vec();
        for candidate in [b.shape(), c.shape(), d.shape()] {
            if shape.as_slice() != candidate {
                return Err(NumRs2Error::ShapeMismatch {
                    expected: shape,
                    actual: candidate.to_vec(),
                });
            }
        }
        Ok(Self {
            // Reuse the vector built for validation instead of copying it again.
            shape,
            a,
            b,
            c,
            d,
            op1,
            op2,
            op3,
            _phantom: PhantomData,
        })
    }
}
impl<T, A, B, C, D, F1, F2, F3> Expr<T> for FusedQuadOp<T, A, B, C, D, F1, F2, F3>
where
    T: Clone,
    A: Expr<T>,
    B: Expr<T>,
    C: Expr<T>,
    D: Expr<T>,
    F1: Fn(T, T) -> T,
    F2: Fn(T, T) -> T,
    F3: Fn(T, T) -> T,
{
    /// Computes `op2(op1(a[index], b[index]), op3(c[index], d[index]))`.
    #[inline(always)]
    fn eval_at(&self, index: usize) -> T {
        // The left pair is fully reduced before the right pair is read.
        let left_pair = {
            let (x, y) = (self.a.eval_at(index), self.b.eval_at(index));
            (self.op1)(x, y)
        };
        let right_pair = {
            let (x, y) = (self.c.eval_at(index), self.d.eval_at(index));
            (self.op3)(x, y)
        };
        (self.op2)(left_pair, right_pair)
    }

    /// Element count, forwarded from operand `a`.
    #[inline]
    fn size(&self) -> usize {
        self.a.size()
    }

    /// Shape captured from operand `a` at construction.
    #[inline]
    fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }
}
impl<T, A, B, C, D, F1, F2, F3> FusedOp<T> for FusedQuadOp<T, A, B, C, D, F1, F2, F3>
where
T: Clone,
A: Expr<T>,
B: Expr<T>,
C: Expr<T>,
D: Expr<T>,
F1: Fn(T, T) -> T,
F2: Fn(T, T) -> T,
F3: Fn(T, T) -> T,
{
fn eval_fused(&self) -> Array<T> {
self.eval()
}
fn fused_size(&self) -> usize {
self.a.size()
}
fn fused_shape(&self) -> &[usize] {
&self.shape
}
}
/// Two unary operators fused into one per-element step: `op2(op1(x))`.
pub struct FusedUnaryChain<T, E, F1, F2>
where
T: Clone,
E: Expr<T>,
F1: Fn(T) -> T,
F2: Fn(T) -> T,
{
/// Source expression supplying each input element.
expr: E,
/// Operator applied first.
op1: F1,
/// Operator applied to the output of `op1`.
op2: F2,
/// Ties the otherwise unused element type `T` to the struct.
_phantom: PhantomData<T>,
}
impl<T, E, F1, F2> FusedUnaryChain<T, E, F1, F2>
where
T: Clone,
E: Expr<T>,
F1: Fn(T) -> T,
F2: Fn(T) -> T,
{
pub fn new(expr: E, op1: F1, op2: F2) -> Self {
Self {
expr,
op1,
op2,
_phantom: PhantomData,
}
}
}
impl<T, E, F1, F2> Expr<T> for FusedUnaryChain<T, E, F1, F2>
where
    T: Clone,
    E: Expr<T>,
    F1: Fn(T) -> T,
    F2: Fn(T) -> T,
{
    /// Computes `op2(op1(expr[index]))`.
    #[inline(always)]
    fn eval_at(&self, index: usize) -> T {
        let input = self.expr.eval_at(index);
        let intermediate = (self.op1)(input);
        (self.op2)(intermediate)
    }

    /// Element count of the wrapped expression.
    #[inline]
    fn size(&self) -> usize {
        self.expr.size()
    }

    /// Shape of the wrapped expression.
    #[inline]
    fn shape(&self) -> &[usize] {
        self.expr.shape()
    }
}
/// Builder that applies a sequence of `f64` operations in place to one owned buffer.
///
/// The builder owns a copy of the source data (`from_array` calls `to_vec`); the `'a`
/// lifetime appears only in the `PhantomData` marker.
pub struct FusionBuilder<'a> {
/// Working buffer that each builder step mutates in place.
data: Vec<f64>,
/// Shape of the source array, used to reshape the final result.
shape: Vec<usize>,
/// Number of builder steps applied so far.
fusions_applied: usize,
/// Marker carrying the lifetime of the array passed to `from_array`.
_lifetime: PhantomData<&'a ()>,
}
impl<'a> FusionBuilder<'a> {
    /// Starts a builder from a copy of `array`'s data and shape.
    pub fn from_array(array: &'a Array<f64>) -> Self {
        let (data, shape) = (array.to_vec(), array.shape());
        Self {
            data,
            shape,
            fusions_applied: 0,
            _lifetime: PhantomData,
        }
    }

    /// Applies `op(&mut own_element, other_element)` pairwise and counts one step.
    ///
    /// Pairing stops at the shorter of the two buffers, so elements past the common
    /// length are left untouched and no error is reported for a length mismatch.
    fn zip_apply(mut self, other: &Array<f64>, op: impl Fn(&mut f64, f64)) -> Self {
        for (slot, rhs) in self.data.iter_mut().zip(other.to_vec()) {
            op(slot, rhs);
        }
        self.fusions_applied += 1;
        self
    }

    /// Adds `other` element-wise into the buffer (up to the common length).
    pub fn add_expr(self, other: &Array<f64>) -> Self {
        self.zip_apply(other, |lhs, rhs| *lhs += rhs)
    }

    /// Multiplies the buffer element-wise by `other` (up to the common length).
    pub fn mul_expr(self, other: &Array<f64>) -> Self {
        self.zip_apply(other, |lhs, rhs| *lhs *= rhs)
    }

    /// Subtracts `other` element-wise from the buffer (up to the common length).
    pub fn sub_expr(self, other: &Array<f64>) -> Self {
        self.zip_apply(other, |lhs, rhs| *lhs -= rhs)
    }

    /// Adds `scalar` to every element.
    pub fn add_scalar(self, scalar: f64) -> Self {
        self.map(|v| v + scalar)
    }

    /// Multiplies every element by `scalar`.
    pub fn mul_scalar(self, scalar: f64) -> Self {
        self.map(|v| v * scalar)
    }

    /// Replaces every element `v` with `v.mul_add(mul, add)`.
    pub fn fma_scalar(self, mul: f64, add: f64) -> Self {
        self.map(|v| v.mul_add(mul, add))
    }

    /// Replaces every element `v` with `op(v)` and counts one step.
    pub fn map<F: Fn(f64) -> f64>(mut self, op: F) -> Self {
        self.data.iter_mut().for_each(|slot| *slot = op(*slot));
        self.fusions_applied += 1;
        self
    }

    /// Consumes the builder and returns the sum of the buffer.
    pub fn sum(self) -> f64 {
        self.data.into_iter().sum()
    }

    /// Consumes the builder and returns the product of the buffer.
    pub fn product(self) -> f64 {
        self.data.into_iter().product()
    }

    /// Consumes the builder and returns the buffer as an array with the source shape.
    pub fn eval_fused(self) -> Array<f64> {
        let Self { data, shape, .. } = self;
        Array::from_vec(data).reshape(&shape)
    }

    /// Number of builder steps applied so far.
    pub fn fusions_applied(&self) -> usize {
        self.fusions_applied
    }
}
// Unit tests for the fused expression types, the free fused/SIMD helper functions,
// the fusion detector heuristics, and the `FusionBuilder`.
#[cfg(test)]
mod tests {
use super::*;
use crate::array::Array;
use crate::expr::core::{ArrayExpr, Expr, LazyEval};
use approx::assert_relative_eq;
// --- FusedElementWiseChain: (a op1 b) op2 c ---
#[test]
fn test_fused_element_wise_chain_add_mul() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let b = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0]);
let c = Array::from_vec(vec![2.0, 3.0, 4.0, 5.0]);
let fused = FusedElementWiseChain::new(
ArrayExpr::new(&a),
ArrayExpr::new(&b),
ArrayExpr::new(&c),
|x, y| x + y,
|r, z| r * z,
)
.expect("Fused chain creation should succeed");
let result = fused.eval_fused();
assert_eq!(result.to_vec(), vec![22.0, 66.0, 132.0, 220.0]);
}
#[test]
fn test_fused_element_wise_chain_sub_mul() {
let a = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0]);
let b = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let c = Array::from_vec(vec![3.0, 3.0, 3.0, 3.0]);
let fused = FusedElementWiseChain::new(
ArrayExpr::new(&a),
ArrayExpr::new(&b),
ArrayExpr::new(&c),
|x, y| x - y,
|r, z| r * z,
)
.expect("Fused chain creation should succeed");
let result = fused.eval_fused();
assert_eq!(result.to_vec(), vec![27.0, 54.0, 81.0, 108.0]);
}
// --- FusedScalarBroadcast: x * mul + add ---
#[test]
fn test_fused_scalar_broadcast_mul_add() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let fused = fused_scalar_broadcast_f64(ArrayExpr::new(&a), 2.0, 3.0);
let result = fused.eval_fused();
assert_eq!(result.to_vec(), vec![5.0, 7.0, 9.0, 11.0]);
}
#[test]
fn test_fused_scalar_broadcast_f32() {
let a = Array::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
let fused = fused_scalar_broadcast_f32(ArrayExpr::new(&a), 3.0, 1.0);
let result = fused.eval_fused();
assert_eq!(result.to_vec(), vec![4.0f32, 7.0, 10.0, 13.0]);
}
// --- FusedMultiplyAdd: a * b + c ---
#[test]
fn test_fused_multiply_add_basic() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let b = Array::from_vec(vec![2.0, 3.0, 4.0, 5.0]);
let c = Array::from_vec(vec![10.0, 10.0, 10.0, 10.0]);
let fma = FusedMultiplyAdd::new(ArrayExpr::new(&a), ArrayExpr::new(&b), ArrayExpr::new(&c))
.expect("FMA creation should succeed");
let result = fma.eval_fused();
assert_eq!(result.to_vec(), vec![12.0, 16.0, 22.0, 30.0]);
}
// 64 elements is above the 32-element threshold used by `FusedMultiplyAdd::eval`.
#[test]
fn test_fused_multiply_add_simd_path() {
let size = 64;
let a_data: Vec<f64> = (0..size).map(|i| i as f64).collect();
let b_data: Vec<f64> = (0..size).map(|i| (i + 1) as f64).collect();
let c_data: Vec<f64> = vec![100.0; size];
let a = Array::from_vec(a_data.clone());
let b = Array::from_vec(b_data.clone());
let c = Array::from_vec(c_data.clone());
let fma = FusedMultiplyAdd::new(ArrayExpr::new(&a), ArrayExpr::new(&b), ArrayExpr::new(&c))
.expect("FMA creation should succeed");
let result = fma.eval();
let result_data = result.to_vec();
for i in 0..size {
let expected = a_data[i].mul_add(b_data[i], c_data[i]);
assert_relative_eq!(result_data[i], expected, epsilon = 1e-10);
}
}
#[test]
fn test_fma_shape_mismatch() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0]);
let b = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let c = Array::from_vec(vec![1.0, 2.0, 3.0]);
let result =
FusedMultiplyAdd::new(ArrayExpr::new(&a), ArrayExpr::new(&b), ArrayExpr::new(&c));
assert!(result.is_err());
}
// --- Fused reductions ---
#[test]
fn test_fused_dot_product() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let b = Array::from_vec(vec![2.0, 3.0, 4.0, 5.0]);
let dot = fused_dot_product(ArrayExpr::new(&a), ArrayExpr::new(&b))
.expect("Fused dot product should succeed");
assert_relative_eq!(dot, 40.0, epsilon = 1e-10);
}
#[test]
fn test_fused_sum_of_squares() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let sos = fused_sum_of_squares(&a);
assert_relative_eq!(sos, 30.0, epsilon = 1e-10);
}
#[test]
fn test_fused_sum_abs_diff() {
let a = Array::from_vec(vec![1.0, 5.0, 3.0, 10.0]);
let b = Array::from_vec(vec![2.0, 3.0, 7.0, 8.0]);
let sad = fused_sum_abs_diff(ArrayExpr::new(&a), ArrayExpr::new(&b))
.expect("Fused sum-abs-diff should succeed");
assert_relative_eq!(sad, 9.0, epsilon = 1e-10);
}
#[test]
fn test_shape_mismatch_chain() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0]);
let b = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let c = Array::from_vec(vec![1.0, 2.0, 3.0]);
let result = FusedElementWiseChain::new(
ArrayExpr::new(&a),
ArrayExpr::new(&b),
ArrayExpr::new(&c),
|x, y| x + y,
|r, z| r * z,
);
assert!(result.is_err());
}
// A plain (unfused) binary expression must still evaluate correctly.
#[test]
fn test_non_fusible_preserves_correctness() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let b = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0]);
let add_expr = BinaryExpr::new(ArrayExpr::new(&a), ArrayExpr::new(&b), |x: f64, y: f64| {
x + y
})
.expect("Binary expr should succeed");
let result = add_expr.eval();
assert_eq!(result.to_vec(), vec![11.0, 22.0, 33.0, 44.0]);
}
// --- SIMD helpers across sizes just below, at, and above powers of two ---
#[test]
fn test_simd_fma_arrays_alignment() {
let sizes = [1, 3, 7, 15, 17, 31, 33, 63, 65, 127, 129, 255, 257];
for &size in &sizes {
let a_data: Vec<f64> = (0..size).map(|i| i as f64).collect();
let b_data: Vec<f64> = (0..size).map(|i| (i + 1) as f64).collect();
let c_data: Vec<f64> = vec![1.0; size];
let a = Array::from_vec(a_data.clone());
let b = Array::from_vec(b_data.clone());
let c = Array::from_vec(c_data.clone());
let result =
simd_fma_arrays(&a, &b, &c).expect("SIMD FMA should succeed for all sizes");
let result_data = result.to_vec();
for i in 0..size {
let expected = a_data[i].mul_add(b_data[i], c_data[i]);
assert_relative_eq!(result_data[i], expected, epsilon = 1e-10,);
}
}
}
#[test]
fn test_simd_fused_scalar_broadcast_alignment() {
for size in [1, 5, 7, 8, 9, 15, 16, 17, 63, 64, 65] {
let data: Vec<f64> = (0..size).map(|i| i as f64 + 1.0).collect();
let a = Array::from_vec(data.clone());
let result = simd_fused_scalar_broadcast(&a, 2.0, 3.0);
let result_data = result.to_vec();
for i in 0..size {
let expected = data[i].mul_add(2.0, 3.0);
assert_relative_eq!(result_data[i], expected, epsilon = 1e-10);
}
}
}
// --- Fused results compared against straightforward unfused computation ---
#[test]
fn test_fusion_vs_unfused_correctness() {
let n = 1000;
let a_data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1).collect();
let b_data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.2 + 1.0).collect();
let c_data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.05).collect();
let a = Array::from_vec(a_data.clone());
let b = Array::from_vec(b_data.clone());
let c = Array::from_vec(c_data.clone());
let unfused: Vec<f64> = (0..n)
.map(|i| (a_data[i] + b_data[i]) * c_data[i])
.collect();
let fused = FusedElementWiseChain::new(
ArrayExpr::new(&a),
ArrayExpr::new(&b),
ArrayExpr::new(&c),
|x, y| x + y,
|r, z| r * z,
)
.expect("Fused creation should succeed");
let fused_result = fused.eval_fused().to_vec();
for i in 0..n {
assert_relative_eq!(fused_result[i], unfused[i], epsilon = 1e-10);
}
}
// The reference here uses separate multiply and add, hence the looser 1e-8 tolerance.
#[test]
fn test_fma_vs_manual_correctness() {
let n = 500;
let a_data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.3).collect();
let b_data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.7 + 0.5).collect();
let c_data: Vec<f64> = (0..n).map(|i| (i as f64) * 0.1 + 2.0).collect();
let a = Array::from_vec(a_data.clone());
let b = Array::from_vec(b_data.clone());
let c = Array::from_vec(c_data.clone());
let manual: Vec<f64> = (0..n).map(|i| a_data[i] * b_data[i] + c_data[i]).collect();
let fma = FusedMultiplyAdd::new(ArrayExpr::new(&a), ArrayExpr::new(&b), ArrayExpr::new(&c))
.expect("FMA creation should succeed");
let fma_result = fma.eval().to_vec();
for i in 0..n {
assert_relative_eq!(fma_result[i], manual[i], epsilon = 1e-8);
}
}
// --- FusedQuadOp and FusedUnaryChain ---
// Computes (a + b) * (c - d).
#[test]
fn test_mixed_fused_ops() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0]);
let c = Array::from_vec(vec![10.0, 10.0, 10.0, 10.0]);
let d = Array::from_vec(vec![2.0, 2.0, 2.0, 2.0]);
let quad = FusedQuadOp::new(
ArrayExpr::new(&a),
ArrayExpr::new(&b),
ArrayExpr::new(&c),
ArrayExpr::new(&d),
|x, y| x + y, |l, r| l * r, |x, y| x - y, )
.expect("Quad op creation should succeed");
let result = quad.eval_fused();
assert_eq!(result.to_vec(), vec![48.0, 64.0, 80.0, 96.0]);
}
#[test]
fn test_fused_unary_chain() {
let a = Array::from_vec(vec![1.0, 4.0, 9.0, 16.0]);
let chain = FusedUnaryChain::new(ArrayExpr::new(&a), |x: f64| x.sqrt(), |x: f64| x * 2.0);
let result = chain.eval();
assert_eq!(result.to_vec(), vec![2.0, 4.0, 6.0, 8.0]);
}
// --- FusionDetector heuristics ---
#[test]
fn test_fusion_detector_element_wise() {
let analysis = FusionDetector::detect_element_wise_chain(100, 100);
assert_eq!(analysis.pattern, FusionPattern::ElementWiseChain);
assert!(analysis.estimated_speedup > 1.0);
assert_eq!(analysis.allocations_eliminated, 1);
}
#[test]
fn test_fusion_detector_scalar_broadcast() {
let analysis = FusionDetector::detect_scalar_broadcast(true, true);
assert_eq!(analysis.pattern, FusionPattern::ScalarBroadcast);
let no_fusion = FusionDetector::detect_scalar_broadcast(true, false);
assert_eq!(no_fusion.pattern, FusionPattern::None);
}
#[test]
fn test_fusion_detector_reduction() {
let analysis = FusionDetector::detect_reduction_fusion(true, true);
assert_eq!(analysis.pattern, FusionPattern::ReductionFusion);
let no_fusion = FusionDetector::detect_reduction_fusion(false, true);
assert_eq!(no_fusion.pattern, FusionPattern::None);
}
#[test]
fn test_fusion_detector_fma() {
let analysis = FusionDetector::detect_fma(true, true, true);
assert_eq!(analysis.pattern, FusionPattern::FusedMultiplyAdd);
let no_fusion = FusionDetector::detect_fma(true, false, true);
assert_eq!(no_fusion.pattern, FusionPattern::None);
}
#[test]
fn test_fusion_detector_ternary_analysis() {
let fma = FusionDetector::analyze_ternary(100, 100, 100);
assert_eq!(fma.pattern, FusionPattern::FusedMultiplyAdd);
let chain = FusionDetector::analyze_ternary(100, 100, 50);
assert_eq!(chain.pattern, FusionPattern::ElementWiseChain);
let none = FusionDetector::analyze_ternary(100, 50, 50);
assert_eq!(none.pattern, FusionPattern::None);
}
// --- FusionBuilder ---
#[test]
fn test_fusion_builder_add_mul() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let b = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0]);
let c = Array::from_vec(vec![2.0, 2.0, 2.0, 2.0]);
let result = FusionBuilder::from_array(&a)
.add_expr(&b)
.mul_expr(&c)
.eval_fused();
assert_eq!(result.to_vec(), vec![22.0, 44.0, 66.0, 88.0]);
}
#[test]
fn test_fusion_builder_scalar_fma() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let result = FusionBuilder::from_array(&a)
.fma_scalar(2.0, 3.0)
.eval_fused();
assert_eq!(result.to_vec(), vec![5.0, 7.0, 9.0, 11.0]);
}
// Computes a * 3 + b + 1.
#[test]
fn test_fusion_builder_complex_chain() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let b = Array::from_vec(vec![2.0, 2.0, 2.0, 2.0]);
let result = FusionBuilder::from_array(&a)
.mul_scalar(3.0) .add_expr(&b) .add_scalar(1.0) .eval_fused();
assert_eq!(result.to_vec(), vec![6.0, 9.0, 12.0, 15.0]);
}
#[test]
fn test_fusion_builder_fusions_count() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0]);
let builder = FusionBuilder::from_array(&a)
.mul_scalar(2.0)
.add_scalar(1.0)
.map(|x| x * x);
assert_eq!(builder.fusions_applied(), 3);
}
#[test]
fn test_fusion_builder_sum_reduction() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let b = Array::from_vec(vec![2.0, 3.0, 4.0, 5.0]);
let dot = FusionBuilder::from_array(&a).mul_expr(&b).sum();
assert_relative_eq!(dot, 40.0, epsilon = 1e-10);
}
// --- SIMD free functions: basic values and error path ---
#[test]
fn test_simd_fma_arrays_basic() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let b = Array::from_vec(vec![2.0, 3.0, 4.0, 5.0]);
let c = Array::from_vec(vec![10.0, 10.0, 10.0, 10.0]);
let result = simd_fma_arrays(&a, &b, &c).expect("SIMD FMA should succeed");
assert_eq!(result.to_vec(), vec![12.0, 16.0, 22.0, 30.0]);
}
#[test]
fn test_simd_fused_dot_product() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let b = Array::from_vec(vec![2.0, 3.0, 4.0, 5.0]);
let dot = simd_fused_dot_product(&a, &b).expect("SIMD dot product should succeed");
assert_relative_eq!(dot, 40.0, epsilon = 1e-10);
}
#[test]
fn test_simd_fused_dot_product_shape_mismatch() {
let a = Array::from_vec(vec![1.0, 2.0, 3.0]);
let b = Array::from_vec(vec![1.0, 2.0]);
let result = simd_fused_dot_product(&a, &b);
assert!(result.is_err());
}
// --- Edge cases: empty and single-element inputs ---
#[test]
fn test_fused_reduction_empty() {
let a = Array::<f64>::from_vec(vec![]);
let b = Array::<f64>::from_vec(vec![]);
let fused = FusedReduction::new(
ArrayExpr::new(&a),
ArrayExpr::new(&b),
|x, y| x * y,
|acc, v| acc + v,
|| 0.0,
)
.expect("Fused reduction on empty arrays should succeed");
assert_relative_eq!(fused.reduce(), 0.0, epsilon = 1e-10);
}
#[test]
fn test_fused_scalar_broadcast_single_element() {
let a = Array::from_vec(vec![5.0]);
let fused = fused_scalar_broadcast_f64(ArrayExpr::new(&a), 3.0, 7.0);
let result = fused.eval_fused();
assert_eq!(result.to_vec(), vec![22.0]);
}
#[test]
fn test_fused_sum_of_squares_empty() {
let a = Array::<f64>::from_vec(vec![]);
assert_relative_eq!(fused_sum_of_squares(&a), 0.0, epsilon = 1e-10);
}
#[test]
fn test_fusion_analysis_properties() {
let none = FusionAnalysis::none();
assert_eq!(none.pattern, FusionPattern::None);
assert_relative_eq!(none.estimated_speedup, 1.0, epsilon = 1e-10);
assert_eq!(none.allocations_eliminated, 0);
let fma = FusionAnalysis::fma();
assert_eq!(fma.pattern, FusionPattern::FusedMultiplyAdd);
assert!(fma.estimated_speedup > 1.0);
assert!(fma.allocations_eliminated > 0);
}
}