// etensor_core/backends/cpu/fusion.rs
use crate::tensor::Tensor;
use crate::buffer::Buffer;
use crate::shape::Shape;
use crate::dtypes::DType;
use crate::device::Device;
use crate::errors::{EtensorError, EtensorResult};
13
14pub fn add_relu_forward(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
19 if a.shape.dims != b.shape.dims {
20 return Err(EtensorError::ShapeMismatch {
21 expected: a.shape.dims.clone(),
22 got: b.shape.dims.clone(),
23 });
24 }
25
26 let slice_a = a.data.as_f32_slice()?;
27 let slice_b = b.data.as_f32_slice()?;
28
29 let out_vec: Vec<f32> = slice_a.iter().zip(slice_b).map(|(x, y)| (x + y).max(0.0)).collect();
30
31 Ok(Tensor::new(
32 Buffer::from_f32_vec(out_vec),
33 a.shape.clone(),
34 Device::Cpu,
35 a.dtype,
36 false, ))
38}
39
40pub fn linear_forward(x: &Tensor, w: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
43 if x.shape.rank() != 2 || w.shape.rank() != 2 || b.shape.rank() != 1 {
44 return Err(EtensorError::InternalError(
45 "Fused Linear requires 2D Input, 2D Weight, and 1D Bias.".to_string(),
46 ));
47 }
48
49 let m = x.shape.dims[0];
50 let k_x = x.shape.dims[1];
51 let k_w = w.shape.dims[0];
52 let n = w.shape.dims[1];
53
54 if k_x != k_w {
55 return Err(EtensorError::ShapeMismatch {
56 expected: vec![m, k_x],
57 got: vec![k_w, n],
58 });
59 }
60
61 if b.shape.dims[0] != n {
63 return Err(EtensorError::ShapeMismatch {
64 expected: vec![n],
65 got: b.shape.dims.clone(),
66 });
67 }
68
69 let slice_x = x.data.as_f32_slice()?;
70 let slice_w = w.data.as_f32_slice()?;
71 let slice_b = b.data.as_f32_slice()?;
72
73 let mut out_vec = Vec::with_capacity(m * n);
76 for _ in 0..m {
77 out_vec.extend_from_slice(slice_b);
78 }
79
80 let stride_x0 = x.shape.strides[0] as isize;
81 let stride_x1 = x.shape.strides[1] as isize;
82 let stride_w0 = w.shape.strides[0] as isize;
83 let stride_w1 = w.shape.strides[1] as isize;
84
85 unsafe {
87 matrixmultiply::sgemm(
88 m, k_x, n,
89 1.0, slice_x.as_ptr(),
91 stride_x0, stride_x1, slice_w.as_ptr(),
93 stride_w0, stride_w1, 1.0, out_vec.as_mut_ptr(),
96 n as isize, 1, );
98 }
99
100 Ok(Tensor::new(
101 Buffer::from_f32_vec(out_vec),
102 Shape::new(vec![m, n]),
103 Device::Cpu,
104 DType::F32,
105 false,
106 ))
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an f32 CPU tensor from raw values and dimensions, passing
    /// `false` as the trailing flag of `Tensor::new`.
    fn make_test_tensor(data: Vec<f32>, dims: Vec<usize>) -> Tensor {
        Tensor::new(
            Buffer::from_f32_vec(data),
            Shape::new(dims),
            Device::Cpu,
            DType::F32,
            false,
        )
    }

    #[test]
    fn test_cpu_fusion_add_relu() {
        let a = make_test_tensor(vec![-2.0, 1.0, 3.0], vec![3]);
        let b = make_test_tensor(vec![1.0, -5.0, 2.0], vec![3]);

        let c = add_relu_forward(&a, &b).unwrap();
        let slice = c.data.as_f32_slice().unwrap();

        // Sums are [-1, -4, 5]; the negative ones clamp to zero.
        assert_eq!(slice, &[0.0, 0.0, 5.0]);
    }

    #[test]
    fn test_cpu_fusion_add_relu_rejects_shape_mismatch() {
        let a = make_test_tensor(vec![1.0, 2.0, 3.0], vec![3]);
        let b = make_test_tensor(vec![1.0, 2.0], vec![2]);

        let result = add_relu_forward(&a, &b);

        assert!(matches!(result, Err(EtensorError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_cpu_fusion_linear() {
        let x = make_test_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let w = make_test_tensor(vec![2.0, 0.0, 0.0, 2.0], vec![2, 2]);

        let b = make_test_tensor(vec![10.0, 20.0], vec![2]);

        let y = linear_forward(&x, &w, &b).unwrap();
        let slice = y.data.as_f32_slice().unwrap();

        // x * w = [[2, 4], [6, 8]]; adding the bias row gives the values below.
        assert_eq!(y.shape.dims, vec![2, 2]);
        assert_eq!(slice, &[12.0, 24.0, 16.0, 28.0]);
    }

    #[test]
    fn test_cpu_fusion_linear_non_square() {
        // [1, 3] * [3, 2] exercises m != k != n so a transposed or mis-strided
        // product cannot pass by accident.
        let x = make_test_tensor(vec![1.0, 2.0, 3.0], vec![1, 3]);
        let w = make_test_tensor(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], vec![3, 2]);
        let b = make_test_tensor(vec![0.5, -0.5], vec![2]);

        let y = linear_forward(&x, &w, &b).unwrap();
        let slice = y.data.as_f32_slice().unwrap();

        // x * w = [4, 5]; plus bias = [4.5, 4.5].
        assert_eq!(y.shape.dims, vec![1, 2]);
        assert_eq!(slice, &[4.5, 4.5]);
    }

    #[test]
    fn test_cpu_fusion_linear_rejects_bad_rank() {
        let x = make_test_tensor(vec![1.0, 2.0], vec![2]);
        let w = make_test_tensor(vec![2.0, 0.0, 0.0, 2.0], vec![2, 2]);
        let b = make_test_tensor(vec![10.0, 20.0], vec![2]);

        let result = linear_forward(&x, &w, &b);

        assert!(matches!(result, Err(EtensorError::InternalError(_))));
    }

    #[test]
    fn test_cpu_fusion_linear_rejects_inner_dim_mismatch() {
        let x = make_test_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let w = make_test_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        let b = make_test_tensor(vec![10.0, 20.0], vec![2]);

        let result = linear_forward(&x, &w, &b);

        assert!(matches!(result, Err(EtensorError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_cpu_fusion_linear_rejects_bias_mismatch() {
        let x = make_test_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let w = make_test_tensor(vec![2.0, 0.0, 0.0, 2.0], vec![2, 2]);
        let b = make_test_tensor(vec![10.0, 20.0, 30.0], vec![3]);

        let result = linear_forward(&x, &w, &b);

        assert!(matches!(result, Err(EtensorError::ShapeMismatch { .. })));
    }
}