// MA_rs 0.1.0
//
// A machine-learning framework written in Rust, for use in Rust.
// Documentation
/// Flat element storage for a tensor, tagged by element type.
///
/// Each variant owns its elements in a plain `Vec`; the dimensions that give
/// the buffer its meaning are stored separately by the owning tensor.
#[derive(Debug, Clone, PartialEq)]
pub enum TensorData {
    /// 64-bit floating point elements.
    F64(Vec<f64>),
    /// 32-bit signed integer elements.
    I32(Vec<i32>),
}

/// Element type tag used to request or describe a tensor's storage.
///
/// The enum carries no payload, so it is `Copy` and cheap to pass by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 64-bit floating point.
    F64,
    /// 32-bit signed integer.
    I32,
}

/// Errors reported by tensor operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorError {
    /// The two operands of an element-wise operation store different element types.
    TypeMismatch,
    /// The two operands of an element-wise operation have different shapes.
    ShapeMismatch,
    /// A reduction (`sum`, `mean`) was requested on a tensor with nothing to
    /// reduce, e.g. one whose `shape` is empty.
    EmptyTensor,
}

impl std::fmt::Display for TensorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::TypeMismatch => "tensor element types do not match",
            Self::ShapeMismatch => "tensor shapes do not match",
            Self::EmptyTensor => "tensor is empty",
        };
        f.write_str(message)
    }
}

// Lets callers propagate the error with `?` into `Box<dyn std::error::Error>`.
impl std::error::Error for TensorError {}

/// A multi-dimensional array: a flat element buffer plus the shape describing it.
#[derive(Debug)]
pub struct Tensor {
    /// Element storage; the variant determines the element type.
    pub data: TensorData,
    /// Extent of each dimension. The constructors that allocate storage use
    /// the product of these extents as the element count.
    pub shape: Vec<usize>,
    // Optional per-dimension strides. Every constructor and operation in the
    // visible `impl` sets this to `None`.
    // NOTE(review): presumably reserved for non-contiguous views — confirm.
    strides: Option<Vec<usize>>,
    // Element type tag supplied at construction.
    dtype: DType,
}

impl Tensor {
    /// Builds a tensor from `f64` values, stored with the requested element type.
    ///
    /// For `DType::I32` every value is converted with an `as i32` cast (the
    /// fractional part is dropped, out-of-range values saturate, NaN becomes 0)
    /// so that the stored data variant always agrees with `dtype`. Without this
    /// a tensor tagged `I32` would hold `F64` data and be rejected with
    /// `TypeMismatch` by operations against other `I32` tensors.
    ///
    /// `shape` is stored as given; it is not validated against `data.len()`.
    pub fn new(data: Vec<f64>, shape: Vec<usize>, dtype: DType) -> Self {
        let data = match dtype {
            DType::F64 => TensorData::F64(data),
            // Narrowing is intentional: the caller asked for integer storage.
            DType::I32 => TensorData::I32(data.into_iter().map(|x| x as i32).collect()),
        };
        Self {
            data,
            shape,
            strides: None,
            dtype,
        }
    }

    /// Builds a tensor of the given `shape` whose elements are all zero.
    ///
    /// The element count is the product of `shape` (1 for an empty `shape`).
    pub fn zeroes(shape: Vec<usize>, dtype: DType) -> Self {
        Self::filled(shape, dtype, 0.0, 0)
    }

    /// Builds a tensor of the given `shape` whose elements are all one.
    ///
    /// The element count is the product of `shape` (1 for an empty `shape`).
    pub fn ones(shape: Vec<usize>, dtype: DType) -> Self {
        Self::filled(shape, dtype, 1.0, 1)
    }

    /// Element-wise sum of `self` and `tensor`.
    ///
    /// # Errors
    ///
    /// * `TensorError::ShapeMismatch` if the shapes (or buffer lengths) differ.
    /// * `TensorError::TypeMismatch` if the element types differ.
    pub fn add(&self, tensor: &Self) -> Result<Self, TensorError> {
        self.elementwise(tensor, |x, y| x + y, |x, y| x + y)
    }

    /// Element-wise product of `self` and `tensor`.
    ///
    /// # Errors
    ///
    /// * `TensorError::ShapeMismatch` if the shapes (or buffer lengths) differ.
    /// * `TensorError::TypeMismatch` if the element types differ.
    pub fn mul(&self, tensor: &Self) -> Result<Self, TensorError> {
        self.elementwise(tensor, |x, y| x * y, |x, y| x * y)
    }

    /// Element-wise difference `self - tensor`.
    ///
    /// # Errors
    ///
    /// * `TensorError::ShapeMismatch` if the shapes (or buffer lengths) differ.
    /// * `TensorError::TypeMismatch` if the element types differ.
    pub fn sub(&self, tensor: &Self) -> Result<Self, TensorError> {
        self.elementwise(tensor, |x, y| x - y, |x, y| x - y)
    }

    /// Element-wise quotient `self / tensor`.
    ///
    /// `F64` division follows IEEE 754 (dividing by zero yields ±inf or NaN).
    /// `I32` division truncates toward zero.
    ///
    /// # Errors
    ///
    /// * `TensorError::ShapeMismatch` if the shapes (or buffer lengths) differ.
    /// * `TensorError::TypeMismatch` if the element types differ.
    ///
    /// # Panics
    ///
    /// For `I32` tensors, panics if any divisor is zero or on `i32::MIN / -1`.
    pub fn div(&self, tensor: &Self) -> Result<Self, TensorError> {
        self.elementwise(tensor, |x, y| x / y, |x, y| x / y)
    }

    /// Sum of all elements, as `f64` (`I32` elements are widened losslessly).
    ///
    /// # Errors
    ///
    /// `TensorError::EmptyTensor` if `shape` is empty.
    pub fn sum(&self) -> Result<f64, TensorError> {
        if self.shape.is_empty() {
            return Err(TensorError::EmptyTensor);
        }

        Ok(match &self.data {
            TensorData::F64(values) => values.iter().sum::<f64>(),
            TensorData::I32(values) => values.iter().map(|&x| f64::from(x)).sum::<f64>(),
        })
    }

    /// Arithmetic mean of all elements.
    ///
    /// The divisor is the number of stored elements, i.e. exactly the values
    /// that `sum` adds up.
    ///
    /// # Errors
    ///
    /// `TensorError::EmptyTensor` if `shape` is empty or the tensor holds no
    /// elements (which would otherwise produce `0.0 / 0.0 = NaN`).
    pub fn mean(&self) -> Result<f64, TensorError> {
        let total = self.sum()?;

        let count = match &self.data {
            TensorData::F64(values) => values.len(),
            TensorData::I32(values) => values.len(),
        };
        if count == 0 {
            return Err(TensorError::EmptyTensor);
        }

        // No lossless usize -> f64 conversion exists; counts beyond 2^53 are
        // not representable exactly but are far outside practical sizes.
        Ok(total / count as f64)
    }

    /// Allocates a tensor of `shape` filled with `float_fill` or `int_fill`,
    /// whichever matches `dtype`.
    fn filled(shape: Vec<usize>, dtype: DType, float_fill: f64, int_fill: i32) -> Self {
        let size: usize = shape.iter().product();
        let data = match dtype {
            DType::F64 => TensorData::F64(vec![float_fill; size]),
            DType::I32 => TensorData::I32(vec![int_fill; size]),
        };
        Self {
            data,
            shape,
            strides: None,
            dtype,
        }
    }

    /// Applies a binary operation element by element, using `float_op` for
    /// `F64` operands and `int_op` for `I32` operands.
    ///
    /// Integer arithmetic uses whatever `int_op` does; with the plain operators
    /// overflow panics when overflow checks are enabled and wraps otherwise.
    fn elementwise(
        &self,
        tensor: &Self,
        float_op: impl Fn(f64, f64) -> f64,
        int_op: impl Fn(i32, i32) -> i32,
    ) -> Result<Self, TensorError> {
        if self.shape != tensor.shape {
            return Err(TensorError::ShapeMismatch);
        }

        let (data, dtype) = match (&self.data, &tensor.data) {
            (TensorData::F64(a), TensorData::F64(b)) => {
                // `zip` would silently truncate to the shorter buffer, leaving
                // a result whose data no longer matches its shape.
                if a.len() != b.len() {
                    return Err(TensorError::ShapeMismatch);
                }
                let values = a.iter().zip(b).map(|(&x, &y)| float_op(x, y)).collect();
                (TensorData::F64(values), DType::F64)
            }
            (TensorData::I32(a), TensorData::I32(b)) => {
                if a.len() != b.len() {
                    return Err(TensorError::ShapeMismatch);
                }
                let values = a.iter().zip(b).map(|(&x, &y)| int_op(x, y)).collect();
                (TensorData::I32(values), DType::I32)
            }
            (TensorData::F64(_), TensorData::I32(_))
            | (TensorData::I32(_), TensorData::F64(_)) => {
                return Err(TensorError::TypeMismatch);
            }
        };

        Ok(Self {
            data,
            shape: self.shape.clone(),
            strides: None,
            dtype,
        })
    }
}