use crate::tensor::Tensor;
use crate::buffer::Buffer;
use crate::shape::Shape;
use crate::dtypes::DType;
use crate::device::Device;
use crate::errors::{EtensorError, EtensorResult};
/// Fused element-wise `relu(a + b)` on the CPU.
///
/// Both operands must have identical `dims`; there is no broadcasting. Their data is read as
/// flat F32 slices and combined position by position. The result is a new CPU tensor with
/// `a`'s shape and dtype.
///
/// The clamp uses [`f32::max`], which returns the non-NaN argument. A NaN sum therefore
/// becomes `0.0` instead of propagating.
///
/// NOTE(review): strides are not consulted. This assumes both operands' buffers use the same
/// element order (e.g. both contiguous) — TODO confirm for strided views.
///
/// # Errors
///
/// * [`EtensorError::ShapeMismatch`] if `a.shape.dims != b.shape.dims`.
/// * [`EtensorError::InternalError`] if the two backing buffers hold different numbers of
///   elements even though their dims agree.
/// * Any error returned by `as_f32_slice` for either operand.
pub fn add_relu_forward(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
    if a.shape.dims != b.shape.dims {
        return Err(EtensorError::ShapeMismatch {
            expected: a.shape.dims.clone(),
            got: b.shape.dims.clone(),
        });
    }
    let slice_a = a.data.as_f32_slice()?;
    let slice_b = b.data.as_f32_slice()?;
    // `zip` would silently stop at the shorter slice and yield a buffer whose length
    // disagrees with the output shape; surface the inconsistency instead.
    if slice_a.len() != slice_b.len() {
        return Err(EtensorError::InternalError(format!(
            "Fused AddReLU: operand buffers differ in length ({} vs {}) despite equal dims.",
            slice_a.len(),
            slice_b.len()
        )));
    }
    let out_vec: Vec<f32> = slice_a
        .iter()
        .zip(slice_b)
        .map(|(x, y)| (x + y).max(0.0))
        .collect();
    Ok(Tensor::new(
        Buffer::from_f32_vec(out_vec),
        a.shape.clone(),
        Device::Cpu,
        a.dtype,
        false,
    ))
}
/// Fused `y = x · w + b` on the CPU.
///
/// `x` is `[m, k]`, `w` is `[k, n]` and `b` is `[n]`. The bias row is broadcast over all `m`
/// rows, and the result is a new CPU F32 tensor of shape `[m, n]` laid out row-major.
///
/// The output buffer is pre-filled with the tiled bias. `matrixmultiply::sgemm` is then
/// called with `beta = 1.0`, so the product is accumulated on top of the bias and no
/// separate add pass is needed.
///
/// `x`, `w` and `b` are addressed through their `shape.strides`, measured from the start of
/// the slice returned by `as_f32_slice`. Negative strides are rejected.
///
/// # Errors
///
/// * [`EtensorError::InternalError`] in any of these cases:
///   - the ranks are not 2 / 2 / 1;
///   - a stride is negative;
///   - an input's buffer is shorter than the span its shape and strides address;
///   - `m * n` overflows `usize`.
/// * [`EtensorError::ShapeMismatch`] in either of these cases:
///   - the inner dimensions of `x` and `w` differ;
///   - the bias length is not `n`.
/// * Any error returned by `as_f32_slice` for an input.
pub fn linear_forward(x: &Tensor, w: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
    if x.shape.rank() != 2 || w.shape.rank() != 2 || b.shape.rank() != 1 {
        return Err(EtensorError::InternalError(
            "Fused Linear requires 2D Input, 2D Weight, and 1D Bias.".to_string(),
        ));
    }
    let m = x.shape.dims[0];
    let k_x = x.shape.dims[1];
    let k_w = w.shape.dims[0];
    let n = w.shape.dims[1];
    if k_x != k_w {
        return Err(EtensorError::ShapeMismatch {
            expected: vec![m, k_x],
            got: vec![k_w, n],
        });
    }
    if b.shape.dims[0] != n {
        return Err(EtensorError::ShapeMismatch {
            expected: vec![n],
            got: b.shape.dims.clone(),
        });
    }

    let slice_x = x.data.as_f32_slice()?;
    let slice_w = w.data.as_f32_slice()?;
    let slice_b = b.data.as_f32_slice()?;

    let stride_x0 = x.shape.strides[0] as isize;
    let stride_x1 = x.shape.strides[1] as isize;
    let stride_w0 = w.shape.strides[0] as isize;
    let stride_w1 = w.shape.strides[1] as isize;
    let stride_b = b.shape.strides[0] as isize;

    // sgemm dereferences raw pointers, so first check that every element it (and the
    // bias gather below) will touch lies inside the backing slices. Without this, a
    // tensor whose buffer is shorter than its shape implies would cause UB.
    ensure_strided_fits("input", slice_x.len(), m, k_x, stride_x0, stride_x1)?;
    ensure_strided_fits("weight", slice_w.len(), k_w, n, stride_w0, stride_w1)?;
    ensure_strided_fits("bias", slice_b.len(), 1, n, 0, stride_b)?;

    let out_len = m.checked_mul(n).ok_or_else(|| {
        EtensorError::InternalError("Fused Linear: output size m * n overflows usize.".to_string())
    })?;

    // Gather exactly `n` bias values (honouring the bias stride), then tile that row `m`
    // times so sgemm can accumulate the product on top of it.
    let bias_step = stride_b.unsigned_abs();
    let bias_row: Vec<f32> = (0..n).map(|j| slice_b[j * bias_step]).collect();
    let mut out_vec = bias_row.repeat(m);
    debug_assert_eq!(out_vec.len(), out_len);

    // SAFETY: `ensure_strided_fits` proved that every element addressed by the
    // non-negative strides of `x` ([m, k]) and `w` ([k, n]) is inside `slice_x` /
    // `slice_w`. `out_vec` holds exactly `m * n` initialised elements and is addressed
    // row-major (row stride `n`, column stride 1), so all writes stay in bounds. The
    // output is a fresh allocation and cannot alias either input.
    unsafe {
        matrixmultiply::sgemm(
            m,
            k_x,
            n,
            1.0,
            slice_x.as_ptr(),
            stride_x0,
            stride_x1,
            slice_w.as_ptr(),
            stride_w0,
            stride_w1,
            1.0,
            out_vec.as_mut_ptr(),
            n as isize,
            1,
        );
    }
    Ok(Tensor::new(
        Buffer::from_f32_vec(out_vec),
        Shape::new(vec![m, n]),
        Device::Cpu,
        DType::F32,
        false,
    ))
}

/// Checks that a `rows x cols` view with the given element strides fits in `len` elements.
///
/// Empty views always fit. Returns `InternalError` naming `what` in any of these cases:
/// - a stride is negative;
/// - the addressed span overflows `usize`;
/// - the span needs more than `len` elements.
fn ensure_strided_fits(
    what: &str,
    len: usize,
    rows: usize,
    cols: usize,
    row_stride: isize,
    col_stride: isize,
) -> EtensorResult<()> {
    if rows == 0 || cols == 0 {
        return Ok(());
    }
    if row_stride < 0 || col_stride < 0 {
        return Err(EtensorError::InternalError(format!(
            "Fused Linear: {} has negative strides ({}, {}), which are not supported.",
            what, row_stride, col_stride
        )));
    }
    // Offset of the last addressed element, plus one.
    let needed = (rows - 1)
        .checked_mul(row_stride.unsigned_abs())
        .and_then(|r| {
            (cols - 1)
                .checked_mul(col_stride.unsigned_abs())
                .and_then(|c| r.checked_add(c))
        })
        .and_then(|last| last.checked_add(1));
    match needed {
        Some(needed) if needed <= len => Ok(()),
        _ => Err(EtensorError::InternalError(format!(
            "Fused Linear: {} buffer holds {} elements, too few for shape [{}, {}] with strides ({}, {}).",
            what, len, rows, cols, row_stride, col_stride
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU F32 tensor from `data` with shape `dims`.
    fn make_test_tensor(data: Vec<f32>, dims: Vec<usize>) -> Tensor {
        Tensor::new(
            Buffer::from_f32_vec(data),
            Shape::new(dims),
            Device::Cpu,
            DType::F32,
            false,
        )
    }

    #[test]
    fn test_cpu_fusion_add_relu() {
        let a = make_test_tensor(vec![-2.0, 1.0, 3.0], vec![3]);
        let b = make_test_tensor(vec![1.0, -5.0, 2.0], vec![3]);
        let c = add_relu_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();
        assert_eq!(slice, &[0.0, 0.0, 5.0]);
    }

    #[test]
    fn test_cpu_fusion_add_relu_shape_mismatch() {
        let a = make_test_tensor(vec![1.0, 2.0, 3.0], vec![3]);
        let b = make_test_tensor(vec![1.0, 2.0], vec![2]);
        let result = add_relu_forward(&a, &b);
        assert!(matches!(
            result,
            Err(EtensorError::ShapeMismatch { ref expected, ref got })
                if *expected == vec![3] && *got == vec![2]
        ));
    }

    #[test]
    fn test_cpu_fusion_linear() {
        let x = make_test_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let w = make_test_tensor(vec![2.0, 0.0, 0.0, 2.0], vec![2, 2]);
        let b = make_test_tensor(vec![10.0, 20.0], vec![2]);
        let y = linear_forward(&x, &w, &b).unwrap();
        let slice = y.data.as_f32_slice().unwrap();
        assert_eq!(y.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[12.0, 24.0, 16.0, 28.0]);
    }

    #[test]
    fn test_cpu_fusion_linear_non_square() {
        // The weight is asymmetric, so swapping its row and column strides would change the result.
        let x = make_test_tensor(vec![1.0, 2.0, 3.0], vec![1, 3]);
        let w = make_test_tensor(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], vec![3, 2]);
        let b = make_test_tensor(vec![0.5, -1.0], vec![2]);
        let y = linear_forward(&x, &w, &b).unwrap();
        let slice = y.data.as_f32_slice().unwrap();
        assert_eq!(y.shape.dims, vec![1, 2]);
        assert_eq!(slice, &[4.5, 4.0]);
    }

    #[test]
    fn test_cpu_fusion_linear_inner_dim_mismatch() {
        let x = make_test_tensor(vec![1.0; 6], vec![2, 3]);
        let w = make_test_tensor(vec![1.0; 4], vec![2, 2]);
        let b = make_test_tensor(vec![0.0, 0.0], vec![2]);
        let result = linear_forward(&x, &w, &b);
        assert!(matches!(
            result,
            Err(EtensorError::ShapeMismatch { ref expected, ref got })
                if *expected == vec![2, 3] && *got == vec![2, 2]
        ));
    }

    #[test]
    fn test_cpu_fusion_linear_bias_mismatch() {
        let x = make_test_tensor(vec![1.0; 4], vec![2, 2]);
        let w = make_test_tensor(vec![1.0; 4], vec![2, 2]);
        let b = make_test_tensor(vec![0.0; 3], vec![3]);
        let result = linear_forward(&x, &w, &b);
        assert!(matches!(
            result,
            Err(EtensorError::ShapeMismatch { ref expected, ref got })
                if *expected == vec![2] && *got == vec![3]
        ));
    }

    #[test]
    fn test_cpu_fusion_linear_rank_error() {
        let x = make_test_tensor(vec![1.0; 4], vec![4]);
        let w = make_test_tensor(vec![1.0; 4], vec![2, 2]);
        let b = make_test_tensor(vec![0.0, 0.0], vec![2]);
        let result = linear_forward(&x, &w, &b);
        assert!(matches!(result, Err(EtensorError::InternalError(_))));
    }
}