Skip to main content

etensor_core/
dispatch.rs

1//! The Execution Router and Native Operator Overloading.
2//! 
//! The Dispatcher intercepts mathematical operations, verifies that the operands'
//! shapes, devices, and dtypes are compatible, executes the forward-pass kernels,
//! and records the backward-pass history onto the Tape.
5
6use std::ops::{Add, Mul};
7use crate::tensor::Tensor;
8use crate::errors::{EtensorError, EtensorResult};
9use crate::autograd::tape::record;
10use crate::autograd::nodes::{
11    AddBackward, MulBackward, MatMulBackward, 
12    SumAllBackward, ReluBackward, SigmoidBackward
13};
14use crate::backends::traits::Backend;
15use crate::backends::cpu::CpuBackend;
16    
/// The central routing engine for all mathematical operations.
///
/// `Dispatcher` is a unit struct: it carries no state and only serves as a
/// namespace for the associated functions that route each tensor operation
/// (`Dispatcher::add`, `Dispatcher::mul`, `Dispatcher::matmul`, ...).
pub struct Dispatcher;
19
20impl Dispatcher {
21    /// Dispatches an element-wise addition operation: out = a + b
22    pub fn add(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
23        if a.shape.dims != b.shape.dims {
24            return Err(EtensorError::ShapeMismatch {
25                expected: a.shape.dims.clone(),
26                got: b.shape.dims.clone(),
27            });
28        }
29        if a.device != b.device {
30            return Err(EtensorError::DeviceMismatch {
31                expected: a.device.to_string(), got: b.device.to_string(),
32            });
33        }
34        if a.dtype != b.dtype {
35            return Err(EtensorError::DTypeMismatch {
36                expected: a.dtype.to_string(), got: b.dtype.to_string(),
37            });
38        }
39
40        let mut out_tensor = match a.device {
41            crate::device::Device::Cpu => CpuBackend::add(a, b)?,
42            #[cfg(feature = "cuda-native")]
43            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
44            #[cfg(feature = "torch")]
45            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
46        };
47
48        let requires_grad = a.requires_grad || b.requires_grad;
49        out_tensor.requires_grad = requires_grad;
50
51        if requires_grad {
52            record(Box::new(AddBackward {
53                output_id: out_tensor.id,
54                lhs_id: if a.requires_grad { Some(a.id) } else { None },
55                rhs_id: if b.requires_grad { Some(b.id) } else { None },
56            }));
57        }
58
59        Ok(out_tensor)
60    }
61
62    /// Dispatches an element-wise multiplication operation: out = a * b
63    pub fn mul(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
64        if a.shape.dims != b.shape.dims {
65            return Err(EtensorError::ShapeMismatch {
66                expected: a.shape.dims.clone(), got: b.shape.dims.clone(),
67            });
68        }
69        if a.device != b.device {
70            return Err(EtensorError::DeviceMismatch {
71                expected: a.device.to_string(), got: b.device.to_string(),
72            });
73        }
74        if a.dtype != b.dtype {
75            return Err(EtensorError::DTypeMismatch {
76                expected: a.dtype.to_string(), got: b.dtype.to_string(),
77            });
78        }
79
80        let mut out_tensor = match a.device {
81            crate::device::Device::Cpu => CpuBackend::mul(a, b)?,
82            #[cfg(feature = "cuda-native")]
83            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
84            #[cfg(feature = "torch")]
85            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
86        };
87
88        let requires_grad = a.requires_grad || b.requires_grad;
89        out_tensor.requires_grad = requires_grad;
90
91        if requires_grad {
92            record(Box::new(MulBackward {
93                output_id: out_tensor.id,
94                lhs_id: if a.requires_grad { Some(a.id) } else { None },
95                rhs_id: if b.requires_grad { Some(b.id) } else { None },
96                lhs_data: a.data.clone(), 
97                rhs_data: b.data.clone(),
98            }));
99        }
100
101        Ok(out_tensor)
102    }
103
104    /// Dispatches a matrix multiplication: out = a @ b
105    pub fn matmul(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
106        if a.device != b.device {
107            return Err(EtensorError::DeviceMismatch {
108                expected: a.device.to_string(), got: b.device.to_string(),
109            });
110        }
111        if a.dtype != b.dtype {
112            return Err(EtensorError::DTypeMismatch {
113                expected: a.dtype.to_string(), got: b.dtype.to_string(),
114            });
115        }
116
117        let mut out_tensor = match a.device {
118            crate::device::Device::Cpu => CpuBackend::matmul(a, b)?,
119            #[cfg(feature = "cuda-native")]
120            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
121            #[cfg(feature = "torch")]
122            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
123        };
124
125        let requires_grad = a.requires_grad || b.requires_grad;
126        out_tensor.requires_grad = requires_grad;
127
128        if requires_grad {
129            record(Box::new(MatMulBackward {
130                output_id: out_tensor.id,
131                lhs_id: if a.requires_grad { Some(a.id) } else { None },
132                rhs_id: if b.requires_grad { Some(b.id) } else { None },
133                lhs_data: a.data.clone(),
134                rhs_data: b.data.clone(),
135                lhs_shape: a.shape.clone(),
136                rhs_shape: b.shape.clone(),
137            }));
138        }
139
140        Ok(out_tensor)
141    }
142
143    /// Dispatches a global sum reduction: out = sum(a)
144    pub fn sum_all(a: &Tensor) -> EtensorResult<Tensor> {
145        let mut out_tensor = match a.device {
146            crate::device::Device::Cpu => CpuBackend::sum_all(a)?,
147            #[cfg(feature = "cuda-native")]
148            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
149            #[cfg(feature = "torch")]
150            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
151        };
152
153        out_tensor.requires_grad = a.requires_grad;
154
155        if a.requires_grad {
156            record(Box::new(SumAllBackward {
157                output_id: out_tensor.id,
158                input_id: a.id,
159                input_shape: a.shape.clone(),
160            }));
161        }
162
163        Ok(out_tensor)
164    }
165
166    /// Dispatches a ReLU activation: out = max(0, a)
167    pub fn relu(a: &Tensor) -> EtensorResult<Tensor> {
168        let mut out_tensor = match a.device {
169            crate::device::Device::Cpu => CpuBackend::relu(a)?,
170            #[cfg(feature = "cuda-native")]
171            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
172            #[cfg(feature = "torch")]
173            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
174        };
175
176        out_tensor.requires_grad = a.requires_grad;
177
178        if a.requires_grad {
179            record(Box::new(ReluBackward {
180                output_id: out_tensor.id,
181                input_id: a.id,
182                input_data: a.data.clone(), // ReLU needs the input x to know if it was > 0
183            }));
184        }
185
186        Ok(out_tensor)
187    }
188
189    /// Dispatches a Sigmoid activation: out = 1 / (1 + exp(-a))
190    pub fn sigmoid(a: &Tensor) -> EtensorResult<Tensor> {
191        let mut out_tensor = match a.device {
192            crate::device::Device::Cpu => CpuBackend::sigmoid(a)?,
193            #[cfg(feature = "cuda-native")]
194            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
195            #[cfg(feature = "torch")]
196            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
197        };
198
199        out_tensor.requires_grad = a.requires_grad;
200
201        if a.requires_grad {
202            record(Box::new(SigmoidBackward {
203                output_id: out_tensor.id,
204                input_id: a.id,
205                output_data: out_tensor.data.clone(), // Sigmoid cleverly uses its output y for the derivative: y * (1 - y)
206            }));
207        }
208
209        Ok(out_tensor)
210    }
211}
212
213// =====================================================================
214// NATIVE RUST OPERATOR OVERLOADING
215// =====================================================================
216
217impl Add for &Tensor {
218    type Output = Tensor;
219    fn add(self, rhs: Self) -> Self::Output {
220        Dispatcher::add(self, rhs).expect("Tensor addition failed!")
221    }
222}
223
224impl Mul for &Tensor {
225    type Output = Tensor;
226    fn mul(self, rhs: Self) -> Self::Output {
227        Dispatcher::mul(self, rhs).expect("Tensor multiplication failed!")
228    }
229}
230
231// =====================================================================
232// UNIT TESTS
233// =====================================================================
#[cfg(test)]
mod tests {
    // Smoke tests for the dispatcher's forward path, its operand validation,
    // and the operator overloads. The math and gradient logic are validated
    // by the backend- and autograd-specific tests.
    use super::*;
    use crate::buffer::Buffer;
    use crate::device::Device;
    use crate::dtypes::DType;
    use crate::shape::Shape;

    /// Builds a 1-D `f32` CPU tensor whose single dimension is `data.len()`.
    fn make_tensor(data: Vec<f32>, requires_grad: bool) -> Tensor {
        let len = data.len();
        Tensor::new(
            Buffer::from_f32_vec(data),
            Shape::new(vec![len]),
            Device::Cpu,
            DType::F32,
            requires_grad,
        )
    }

    #[test]
    fn test_dispatcher_add_forward_logic() {
        let a = make_tensor(vec![1.0, 2.0, 3.0], false);
        let b = make_tensor(vec![4.0, 5.0, 6.0], false);
        let c = Dispatcher::add(&a, &b).unwrap();
        assert_eq!(c.data.as_f32_slice().unwrap(), &[5.0, 7.0, 9.0]);
        assert!(!c.requires_grad);
    }

    #[test]
    fn test_dispatcher_add_rejects_shape_mismatch() {
        let a = make_tensor(vec![1.0, 2.0], false);
        let b = make_tensor(vec![1.0, 2.0, 3.0], false);
        // `matches!` avoids requiring `Debug` on `Tensor` (as `unwrap_err` would).
        assert!(matches!(
            Dispatcher::add(&a, &b),
            Err(EtensorError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_dispatcher_mul_rejects_shape_mismatch() {
        let a = make_tensor(vec![1.0, 2.0], false);
        let b = make_tensor(vec![1.0, 2.0, 3.0], false);
        assert!(matches!(
            Dispatcher::mul(&a, &b),
            Err(EtensorError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_operator_overloading() {
        let a = make_tensor(vec![2.0, 4.0], false);
        let b = make_tensor(vec![3.0, 5.0], false);
        let c = &a + &b;
        let d = &a * &b;
        assert_eq!(c.data.as_f32_slice().unwrap(), &[5.0, 9.0]);
        assert_eq!(d.data.as_f32_slice().unwrap(), &[6.0, 20.0]);
    }
}
267}