etensor-core 0.0.1

The pure Rust tensor math and autograd engine
Documentation
//! The Execution Router and Native Operator Overloading.
//! 
//! The Dispatcher intercepts mathematical operations, verifies operand compatibility (shape, device, and dtype, where applicable),
//! executes the forward pass kernels, and records the backward pass history onto the Tape.

use std::ops::{Add, Mul};
use crate::tensor::Tensor;
use crate::errors::{EtensorError, EtensorResult};
use crate::autograd::tape::record;
use crate::autograd::nodes::{
    AddBackward, MulBackward, MatMulBackward, 
    SumAllBackward, ReluBackward, SigmoidBackward
};
use crate::backends::traits::Backend;
use crate::backends::cpu::CpuBackend;
    
/// The central routing engine for all mathematical operations.
pub struct Dispatcher;

impl Dispatcher {
    /// Dispatches an element-wise addition operation: out = a + b
    pub fn add(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
        if a.shape.dims != b.shape.dims {
            return Err(EtensorError::ShapeMismatch {
                expected: a.shape.dims.clone(),
                got: b.shape.dims.clone(),
            });
        }
        if a.device != b.device {
            return Err(EtensorError::DeviceMismatch {
                expected: a.device.to_string(), got: b.device.to_string(),
            });
        }
        if a.dtype != b.dtype {
            return Err(EtensorError::DTypeMismatch {
                expected: a.dtype.to_string(), got: b.dtype.to_string(),
            });
        }

        let mut out_tensor = match a.device {
            crate::device::Device::Cpu => CpuBackend::add(a, b)?,
            #[cfg(feature = "cuda-native")]
            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
            #[cfg(feature = "torch")]
            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
        };

        let requires_grad = a.requires_grad || b.requires_grad;
        out_tensor.requires_grad = requires_grad;

        if requires_grad {
            record(Box::new(AddBackward {
                output_id: out_tensor.id,
                lhs_id: if a.requires_grad { Some(a.id) } else { None },
                rhs_id: if b.requires_grad { Some(b.id) } else { None },
            }));
        }

        Ok(out_tensor)
    }

    /// Dispatches an element-wise multiplication operation: out = a * b
    pub fn mul(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
        if a.shape.dims != b.shape.dims {
            return Err(EtensorError::ShapeMismatch {
                expected: a.shape.dims.clone(), got: b.shape.dims.clone(),
            });
        }
        if a.device != b.device {
            return Err(EtensorError::DeviceMismatch {
                expected: a.device.to_string(), got: b.device.to_string(),
            });
        }
        if a.dtype != b.dtype {
            return Err(EtensorError::DTypeMismatch {
                expected: a.dtype.to_string(), got: b.dtype.to_string(),
            });
        }

        let mut out_tensor = match a.device {
            crate::device::Device::Cpu => CpuBackend::mul(a, b)?,
            #[cfg(feature = "cuda-native")]
            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
            #[cfg(feature = "torch")]
            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
        };

        let requires_grad = a.requires_grad || b.requires_grad;
        out_tensor.requires_grad = requires_grad;

        if requires_grad {
            record(Box::new(MulBackward {
                output_id: out_tensor.id,
                lhs_id: if a.requires_grad { Some(a.id) } else { None },
                rhs_id: if b.requires_grad { Some(b.id) } else { None },
                lhs_data: a.data.clone(), 
                rhs_data: b.data.clone(),
            }));
        }

        Ok(out_tensor)
    }

    /// Dispatches a matrix multiplication: out = a @ b
    pub fn matmul(a: &Tensor, b: &Tensor) -> EtensorResult<Tensor> {
        if a.device != b.device {
            return Err(EtensorError::DeviceMismatch {
                expected: a.device.to_string(), got: b.device.to_string(),
            });
        }
        if a.dtype != b.dtype {
            return Err(EtensorError::DTypeMismatch {
                expected: a.dtype.to_string(), got: b.dtype.to_string(),
            });
        }

        let mut out_tensor = match a.device {
            crate::device::Device::Cpu => CpuBackend::matmul(a, b)?,
            #[cfg(feature = "cuda-native")]
            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
            #[cfg(feature = "torch")]
            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
        };

        let requires_grad = a.requires_grad || b.requires_grad;
        out_tensor.requires_grad = requires_grad;

        if requires_grad {
            record(Box::new(MatMulBackward {
                output_id: out_tensor.id,
                lhs_id: if a.requires_grad { Some(a.id) } else { None },
                rhs_id: if b.requires_grad { Some(b.id) } else { None },
                lhs_data: a.data.clone(),
                rhs_data: b.data.clone(),
                lhs_shape: a.shape.clone(),
                rhs_shape: b.shape.clone(),
            }));
        }

        Ok(out_tensor)
    }

    /// Dispatches a global sum reduction: out = sum(a)
    pub fn sum_all(a: &Tensor) -> EtensorResult<Tensor> {
        let mut out_tensor = match a.device {
            crate::device::Device::Cpu => CpuBackend::sum_all(a)?,
            #[cfg(feature = "cuda-native")]
            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
            #[cfg(feature = "torch")]
            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
        };

        out_tensor.requires_grad = a.requires_grad;

        if a.requires_grad {
            record(Box::new(SumAllBackward {
                output_id: out_tensor.id,
                input_id: a.id,
                input_shape: a.shape.clone(),
            }));
        }

        Ok(out_tensor)
    }

    /// Dispatches a ReLU activation: out = max(0, a)
    pub fn relu(a: &Tensor) -> EtensorResult<Tensor> {
        let mut out_tensor = match a.device {
            crate::device::Device::Cpu => CpuBackend::relu(a)?,
            #[cfg(feature = "cuda-native")]
            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
            #[cfg(feature = "torch")]
            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
        };

        out_tensor.requires_grad = a.requires_grad;

        if a.requires_grad {
            record(Box::new(ReluBackward {
                output_id: out_tensor.id,
                input_id: a.id,
                input_data: a.data.clone(), // ReLU needs the input x to know if it was > 0
            }));
        }

        Ok(out_tensor)
    }

    /// Dispatches a Sigmoid activation: out = 1 / (1 + exp(-a))
    pub fn sigmoid(a: &Tensor) -> EtensorResult<Tensor> {
        let mut out_tensor = match a.device {
            crate::device::Device::Cpu => CpuBackend::sigmoid(a)?,
            #[cfg(feature = "cuda-native")]
            crate::device::Device::CudaNative(_) => return Err(EtensorError::InternalError("CUDA pending.".to_string())),
            #[cfg(feature = "torch")]
            crate::device::Device::CudaTorch(_) => return Err(EtensorError::InternalError("Torch pending.".to_string())),
        };

        out_tensor.requires_grad = a.requires_grad;

        if a.requires_grad {
            record(Box::new(SigmoidBackward {
                output_id: out_tensor.id,
                input_id: a.id,
                output_data: out_tensor.data.clone(), // Sigmoid cleverly uses its output y for the derivative: y * (1 - y)
            }));
        }

        Ok(out_tensor)
    }
}

// =====================================================================
// NATIVE RUST OPERATOR OVERLOADING
// =====================================================================

impl Add for &Tensor {
    type Output = Tensor;

    /// Element-wise `&lhs + &rhs`, forwarded to [`Dispatcher::add`].
    ///
    /// # Panics
    /// `Add::add` cannot return a `Result`, so any dispatcher error (shape, device or
    /// dtype mismatch, unimplemented backend) becomes a panic whose message starts
    /// with `Tensor addition failed!`.
    fn add(self, rhs: Self) -> Self::Output {
        match Dispatcher::add(self, rhs) {
            Ok(sum) => sum,
            Err(err) => panic!("Tensor addition failed!: {err:?}"),
        }
    }
}

impl Mul for &Tensor {
    type Output = Tensor;

    /// Element-wise `&lhs * &rhs`, forwarded to [`Dispatcher::mul`].
    ///
    /// # Panics
    /// `Mul::mul` cannot return a `Result`, so any dispatcher error (shape, device or
    /// dtype mismatch, unimplemented backend) becomes a panic whose message starts
    /// with `Tensor multiplication failed!`.
    fn mul(self, rhs: Self) -> Self::Output {
        match Dispatcher::mul(self, rhs) {
            Ok(product) => product,
            Err(err) => panic!("Tensor multiplication failed!: {err:?}"),
        }
    }
}

// =====================================================================
// UNIT TESTS
// =====================================================================
#[cfg(test)]
mod tests {
    // These tests cover the dispatcher's own responsibilities: operand validation,
    // routing to the CPU backend, and `requires_grad` propagation. Kernel numerics
    // and gradient correctness are left to the backend- and autograd-specific tests.
    use super::*;
    use crate::device::Device;
    use crate::dtypes::DType;
    use crate::shape::Shape;
    use crate::buffer::Buffer;

    /// Builds a 1-D f32 CPU tensor from `data`.
    fn make_tensor(data: Vec<f32>, requires_grad: bool) -> Tensor {
        let len = data.len();
        Tensor::new(Buffer::from_f32_vec(data), Shape::new(vec![len]), Device::Cpu, DType::F32, requires_grad)
    }

    #[test]
    fn test_dispatcher_add_forward_logic() {
        let a = make_tensor(vec![1.0, 2.0, 3.0], false);
        let b = make_tensor(vec![4.0, 5.0, 6.0], false);
        let c = Dispatcher::add(&a, &b).unwrap();
        assert_eq!(c.data.as_f32_slice().unwrap(), &[5.0, 7.0, 9.0]);
        assert!(!c.requires_grad); 
    }

    #[test]
    fn test_operator_overloading() {
        let a = make_tensor(vec![2.0, 4.0], false);
        let b = make_tensor(vec![3.0, 5.0], false);
        let c = &a + &b; 
        let d = &a * &b;
        assert_eq!(c.data.as_f32_slice().unwrap(), &[5.0, 9.0]);
        assert_eq!(d.data.as_f32_slice().unwrap(), &[6.0, 20.0]);
    }

    #[test]
    fn test_dispatcher_add_rejects_shape_mismatch() {
        let a = make_tensor(vec![1.0, 2.0, 3.0], false);
        let b = make_tensor(vec![1.0, 2.0], false);
        assert!(matches!(
            Dispatcher::add(&a, &b),
            Err(EtensorError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_dispatcher_mul_rejects_shape_mismatch() {
        let a = make_tensor(vec![1.0, 2.0], false);
        let b = make_tensor(vec![1.0, 2.0, 3.0], false);
        assert!(matches!(
            Dispatcher::mul(&a, &b),
            Err(EtensorError::ShapeMismatch { .. })
        ));
    }

    #[test]
    #[should_panic(expected = "Tensor addition failed!")]
    fn test_operator_add_panics_on_shape_mismatch() {
        let a = make_tensor(vec![1.0, 2.0, 3.0], false);
        let b = make_tensor(vec![1.0, 2.0], false);
        let _ = &a + &b;
    }

    #[test]
    fn test_dispatcher_unary_forward_logic() {
        let x = make_tensor(vec![-1.0, 0.0, 2.0], false);
        let r = Dispatcher::relu(&x).unwrap();
        assert_eq!(r.data.as_f32_slice().unwrap(), &[0.0, 0.0, 2.0]);
        assert!(!r.requires_grad);

        // sigmoid(0) = 1 / (1 + 1) = 0.5 exactly in f32.
        let z = make_tensor(vec![0.0], false);
        let s = Dispatcher::sigmoid(&z).unwrap();
        assert_eq!(s.data.as_f32_slice().unwrap(), &[0.5]);
        assert!(!s.requires_grad);
    }
}