1use std::ops::{Add, Mul};
7use crate::tensor::Tensor;
8use crate::errors::{EtensorError, EtensorResult};
9use crate::autograd::tape::record;
10use crate::autograd::nodes::{
11 AddBackward, MulBackward, MatMulBackward,
12 SumAllBackward, ReluBackward, SigmoidBackward
13};
14use crate::backends::traits::Backend;
15use crate::backends::cpu::CpuBackend;
16
17pub struct Dispatcher;
19
20impl Dispatcher {
21 pub fn add(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
23 if a.shape.dims != b.shape.dims {
24 return Err(EtensorError::ShapeMismatch {
25 expected: a.shape.dims.clone(),
26 got: b.shape.dims.clone(),
27 });
28 }
29 if a.device != b.device {
30 return Err(EtensorError::DeviceMismatch {
31 expected: a.device.to_string(), got: b.device.to_string(),
32 });
33 }
34 if a.dtype != b.dtype {
35 return Err(EtensorError::DTypeMismatch {
36 expected: a.dtype.to_string(), got: b.dtype.to_string(),
37 });
38 }
39
40 let mut out_tensor = match a.device {
41 crate::device::Device::Cpu => CpuBackend::add(a, b)?,
42 #[cfg(feature = "cuda-native")]
43 crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
44 #[cfg(feature = "torch")]
45 crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
46 };
47
48 let requires_grad = a.requires_grad || b.requires_grad;
49 out_tensor.requires_grad = requires_grad;
50
51 if requires_grad {
52 record(Box::new(AddBackward {
53 output_id: out_tensor.id,
54 lhs_id: if a.requires_grad { Some(a.id) } else { None },
55 rhs_id: if b.requires_grad { Some(b.id) } else { None },
56 }));
57 }
58
59 Ok(out_tensor)
60 }
61
62 pub fn mul(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
64 if a.shape.dims != b.shape.dims {
65 return Err(EtensorError::ShapeMismatch {
66 expected: a.shape.dims.clone(), got: b.shape.dims.clone(),
67 });
68 }
69 if a.device != b.device {
70 return Err(EtensorError::DeviceMismatch {
71 expected: a.device.to_string(), got: b.device.to_string(),
72 });
73 }
74 if a.dtype != b.dtype {
75 return Err(EtensorError::DTypeMismatch {
76 expected: a.dtype.to_string(), got: b.dtype.to_string(),
77 });
78 }
79
80 let mut out_tensor = match a.device {
81 crate::device::Device::Cpu => CpuBackend::mul(a, b)?,
82 #[cfg(feature = "cuda-native")]
83 crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
84 #[cfg(feature = "torch")]
85 crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
86 };
87
88 let requires_grad = a.requires_grad || b.requires_grad;
89 out_tensor.requires_grad = requires_grad;
90
91 if requires_grad {
92 record(Box::new(MulBackward {
93 output_id: out_tensor.id,
94 lhs_id: if a.requires_grad { Some(a.id) } else { None },
95 rhs_id: if b.requires_grad { Some(b.id) } else { None },
96 lhs_data: a.data.clone(),
97 rhs_data: b.data.clone(),
98 }));
99 }
100
101 Ok(out_tensor)
102 }
103
104 pub fn matmul(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
106 if a.device != b.device {
107 return Err(EtensorError::DeviceMismatch {
108 expected: a.device.to_string(), got: b.device.to_string(),
109 });
110 }
111 if a.dtype != b.dtype {
112 return Err(EtensorError::DTypeMismatch {
113 expected: a.dtype.to_string(), got: b.dtype.to_string(),
114 });
115 }
116
117 let mut out_tensor = match a.device {
118 crate::device::Device::Cpu => CpuBackend::matmul(a, b)?,
119 #[cfg(feature = "cuda-native")]
120 crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
121 #[cfg(feature = "torch")]
122 crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
123 };
124
125 let requires_grad = a.requires_grad || b.requires_grad;
126 out_tensor.requires_grad = requires_grad;
127
128 if requires_grad {
129 record(Box::new(MatMulBackward {
130 output_id: out_tensor.id,
131 lhs_id: if a.requires_grad { Some(a.id) } else { None },
132 rhs_id: if b.requires_grad { Some(b.id) } else { None },
133 lhs_data: a.data.clone(),
134 rhs_data: b.data.clone(),
135 lhs_shape: a.shape.clone(),
136 rhs_shape: b.shape.clone(),
137 }));
138 }
139
140 Ok(out_tensor)
141 }
142
143 pub fn sum_all(a: &Tensor) -> EtensorResult<Tensor> {
145 let mut out_tensor = match a.device {
146 crate::device::Device::Cpu => CpuBackend::sum_all(a)?,
147 #[cfg(feature = "cuda-native")]
148 crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
149 #[cfg(feature = "torch")]
150 crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
151 };
152
153 out_tensor.requires_grad = a.requires_grad;
154
155 if a.requires_grad {
156 record(Box::new(SumAllBackward {
157 output_id: out_tensor.id,
158 input_id: a.id,
159 input_shape: a.shape.clone(),
160 }));
161 }
162
163 Ok(out_tensor)
164 }
165
166 pub fn relu(a: &Tensor) -> EtensorResult<Tensor> {
168 let mut out_tensor = match a.device {
169 crate::device::Device::Cpu => CpuBackend::relu(a)?,
170 #[cfg(feature = "cuda-native")]
171 crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
172 #[cfg(feature = "torch")]
173 crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
174 };
175
176 out_tensor.requires_grad = a.requires_grad;
177
178 if a.requires_grad {
179 record(Box::new(ReluBackward {
180 output_id: out_tensor.id,
181 input_id: a.id,
182 input_data: a.data.clone(), }));
184 }
185
186 Ok(out_tensor)
187 }
188
189 pub fn sigmoid(a: &Tensor) -> EtensorResult<Tensor> {
191 let mut out_tensor = match a.device {
192 crate::device::Device::Cpu => CpuBackend::sigmoid(a)?,
193 #[cfg(feature = "cuda-native")]
194 crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
195 #[cfg(feature = "torch")]
196 crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
197 };
198
199 out_tensor.requires_grad = a.requires_grad;
200
201 if a.requires_grad {
202 record(Box::new(SigmoidBackward {
203 output_id: out_tensor.id,
204 input_id: a.id,
205 output_data: out_tensor.data.clone(), }));
207 }
208
209 Ok(out_tensor)
210 }
211}
212
213impl Add for &Tensor {
218 type Output = Tensor;
219 fn add(self, rhs: Self) -> Self::Output {
220 Dispatcher::add(self, rhs).expect("Tensor addition failed!")
221 }
222}
223
224impl Mul for &Tensor {
225 type Output = Tensor;
226 fn mul(self, rhs: Self) -> Self::Output {
227 Dispatcher::mul(self, rhs).expect("Tensor multiplication failed!")
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::Buffer;
    use crate::device::Device;
    use crate::dtypes::DType;
    use crate::shape::Shape;

    /// Builds a one-dimensional `F32` CPU tensor whose length equals `data.len()`.
    fn make_tensor(data: Vec<f32>, requires_grad: bool) -> Tensor {
        // The length must be read before `data` is moved into the buffer.
        let element_count = data.len();
        let buffer = Buffer::from_f32_vec(data);
        let shape = Shape::new(vec![element_count]);
        Tensor::new(buffer, shape, Device::Cpu, DType::F32, requires_grad)
    }

    /// `Dispatcher::add` sums element-wise and leaves `requires_grad` unset
    /// when neither operand tracks gradients.
    #[test]
    fn test_dispatcher_add_forward_logic() {
        let lhs = make_tensor(vec![1.0, 2.0, 3.0], false);
        let rhs = make_tensor(vec![4.0, 5.0, 6.0], false);

        let sum = Dispatcher::add(&lhs, &rhs).unwrap();

        assert_eq!(sum.data.as_f32_slice().unwrap(), &[5.0, 7.0, 9.0]);
        assert!(!sum.requires_grad);
    }

    /// The `+` and `*` operators on `&Tensor` route through the dispatcher.
    #[test]
    fn test_operator_overloading() {
        let lhs = make_tensor(vec![2.0, 4.0], false);
        let rhs = make_tensor(vec![3.0, 5.0], false);

        let sum = &lhs + &rhs;
        let product = &lhs * &rhs;

        assert_eq!(sum.data.as_f32_slice().unwrap(), &[5.0, 9.0]);
        assert_eq!(product.data.as_f32_slice().unwrap(), &[6.0, 20.0]);
    }
}