// mathrc 0.2.4
//
// Rust Mathematics Library — documentation.
use crate::err::vector_err::VectorErr;
use std::fmt;
use std::ops::{Add, Index, IndexMut, Mul, Neg, Sub};

/// Core vector operations: norm, inner product and normalization.
pub trait VectorOps {
    /// Length (norm) of the vector.
    ///
    /// Despite the name, this is a geometric length rather than a component
    /// count; the `Vector` implementation in this module uses the Euclidean norm.
    fn len(&self) -> f64;

    /// Inner (dot) product of `self` with `other`.
    ///
    /// # Errors
    /// Implementations are expected to return an error when the operands are
    /// incompatible (for `Vector`: when their dimensions differ).
    fn dot(&self, other: &Self) -> Result<f64, VectorErr>
    where
        Self: Sized;

    /// Returns a copy of the vector scaled to unit length.
    ///
    /// # Errors
    /// Implementations are expected to return an error when no direction
    /// exists (for `Vector`: when the norm is zero).
    fn normalize(&self) -> Result<Self, VectorErr>
    where
        Self: Sized;
}

/// A dense vector of `f64` components with arbitrary dimension.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector {
    // Components in order; the vector's dimension is `vec.len()`.
    vec: Vec<f64>,
}

impl Vector {
    /// Creates a vector that takes ownership of the given components.
    pub fn new(vec: Vec<f64>) -> Self {
        Self { vec }
    }

    /// Euclidean (L2) norm, `sqrt(Σ xᵢ²)`.
    ///
    /// Despite the name, this is the geometric length, not the number of
    /// components. An empty or all-zero vector has norm zero. A NaN component
    /// yields NaN. An infinite component yields infinity, or NaN if a NaN is
    /// also present.
    ///
    /// Components are divided by the largest magnitude before squaring, so the
    /// result neither overflows to infinity (components above ~1e154) nor
    /// collapses to zero (components below ~1e-154). The result may differ
    /// from the naive formula in the last bit.
    pub fn len(&self) -> f64 {
        // `f64::max` ignores NaN operands, so `scale` is never NaN.
        let scale = self.vec.iter().fold(0.0_f64, |m, x| m.max(x.abs()));
        // When `scale` is 0 (empty, all zeros, or only NaNs) or infinite, scaling
        // would compute 0/0 or inf/inf. In those cases the naive formula already
        // gives the correct 0, NaN or infinity.
        if scale == 0.0 || scale.is_infinite() {
            return self.vec.iter().map(|x| x * x).sum::<f64>().sqrt();
        }
        // Every scaled component lies in [-1, 1], and the largest is exactly ±1.
        // So the sum is in [1, n] and cannot overflow or underflow.
        let sum_sq: f64 = self
            .vec
            .iter()
            .map(|x| {
                let r = x / scale;
                r * r
            })
            .sum();
        scale * sum_sq.sqrt()
    }

    /// Cross product `self × other` of two 3-dimensional vectors.
    ///
    /// # Errors
    /// Returns [`VectorErr::DimensionMismatch`] unless both operands have
    /// exactly three components.
    pub fn cross(&self, other: &Self) -> Result<Self, VectorErr> {
        if self.vec.len() != 3 || other.vec.len() != 3 {
            return Err(VectorErr::DimensionMismatch);
        }
        Ok(Self {
            vec: vec![
                self.vec[1] * other.vec[2] - self.vec[2] * other.vec[1],
                self.vec[2] * other.vec[0] - self.vec[0] * other.vec[2],
                self.vec[0] * other.vec[1] - self.vec[1] * other.vec[0],
            ],
        })
    }
}

impl VectorOps for Vector {
    fn dot(&self, other: &Self) -> Result<f64, VectorErr> {
        if self.vec.len() != other.vec.len() {
            return Err(VectorErr::DimensionMismatch);
        }
        Ok(self.vec.iter().zip(&other.vec).map(|(a, b)| a * b).sum())
    }

    fn len(&self) -> f64 {
        self.vec.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    fn normalize(&self) -> Result<Self, VectorErr> {
        let len = self.len();
        if len == 0.0 {
            return Err(VectorErr::ZeroVector);
        }
        Ok(Self {
            vec: self.vec.iter().map(|x| x / len).collect(),
        })
    }
}

impl Add for Vector {
    type Output = Result<Self, VectorErr>;

    /// Component-wise sum `self + other`.
    ///
    /// # Errors
    /// Returns [`VectorErr::DimensionMismatch`] if the operands have different
    /// numbers of components.
    fn add(mut self, other: Self) -> Result<Self, VectorErr> {
        if self.vec.len() != other.vec.len() {
            return Err(VectorErr::DimensionMismatch);
        }
        // `self` is owned, so accumulate into its buffer instead of allocating
        // a new Vec for the result.
        for (a, b) in self.vec.iter_mut().zip(&other.vec) {
            *a += b;
        }
        Ok(self)
    }
}

impl Sub for Vector {
    type Output = Result<Self, VectorErr>;

    /// Component-wise difference `self - other`.
    ///
    /// # Errors
    /// Returns [`VectorErr::DimensionMismatch`] if the operands have different
    /// numbers of components.
    fn sub(mut self, other: Self) -> Result<Self, VectorErr> {
        if self.vec.len() != other.vec.len() {
            return Err(VectorErr::DimensionMismatch);
        }
        // `self` is owned, so subtract into its buffer instead of allocating
        // a new Vec for the result.
        for (a, b) in self.vec.iter_mut().zip(&other.vec) {
            *a -= b;
        }
        Ok(self)
    }
}

impl Mul<f64> for Vector {
    type Output = Self;

    /// Multiplies every component by `scalar`.
    fn mul(self, scalar: f64) -> Self {
        let scaled = self
            .vec
            .into_iter()
            .map(|component| component * scalar)
            .collect();
        Self { vec: scaled }
    }
}

impl Neg for Vector {
    type Output = Self;

    /// Flips the sign of every component.
    fn neg(self) -> Self {
        let Self { vec } = self;
        Self {
            vec: vec.into_iter().map(|component| -component).collect(),
        }
    }
}

impl Index<usize> for Vector {
    type Output = f64;

    /// Borrows the component at position `i`.
    ///
    /// # Panics
    /// Panics if `i` is not less than the number of components.
    fn index(&self, i: usize) -> &f64 {
        let components: &[f64] = &self.vec;
        &components[i]
    }
}

impl IndexMut<usize> for Vector {
    /// Mutably borrows the component at position `i`.
    ///
    /// # Panics
    /// Panics if `i` is not less than the number of components.
    fn index_mut(&mut self, i: usize) -> &mut f64 {
        let components: &mut [f64] = &mut self.vec;
        &mut components[i]
    }
}

impl fmt::Display for Vector {
    /// Formats the vector as `[x0, x1, ...]`, e.g. `[1, 2.5]`; an empty vector
    /// prints as `[]`.
    ///
    /// A precision in the format spec (e.g. `{:.2}`) is applied to every
    /// component. Other formatter options such as width are ignored, as before.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let precision = f.precision();
        // Write straight into the formatter rather than building a Vec<String>
        // and a joined String first.
        f.write_str("[")?;
        for (i, x) in self.vec.iter().enumerate() {
            if i > 0 {
                f.write_str(", ")?;
            }
            match precision {
                Some(p) => write!(f, "{:.*}", p, x)?,
                None => write!(f, "{}", x)?,
            }
        }
        f.write_str("]")
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_add() {
        let a = Vector::new(vec![1.0, 2.0, 3.0]);
        let b = Vector::new(vec![4.0, 5.0, 6.0]);
        assert_eq!((a + b).unwrap(), Vector::new(vec![5.0, 7.0, 9.0]));
    }

    #[test]
    fn test_sub() {
        let a = Vector::new(vec![4.0, 5.0, 6.0]);
        let b = Vector::new(vec![1.0, 2.0, 3.0]);
        assert_eq!((a - b).unwrap(), Vector::new(vec![3.0, 3.0, 3.0]));
    }

    #[test]
    fn test_scale_and_neg() {
        let a = Vector::new(vec![1.0, -2.0]);
        assert_eq!(a.clone() * 2.0, Vector::new(vec![2.0, -4.0]));
        assert_eq!(-a, Vector::new(vec![-1.0, 2.0]));
    }

    #[test]
    fn test_dot() {
        let a = Vector::new(vec![1.0, 2.0, 3.0]);
        let b = Vector::new(vec![4.0, 5.0, 6.0]);
        assert_eq!(a.dot(&b).unwrap(), 32.0);
    }

    #[test]
    fn test_dot_dimension_mismatch() {
        let a = Vector::new(vec![1.0, 2.0]);
        let b = Vector::new(vec![1.0, 2.0, 3.0]);
        assert_eq!(a.dot(&b).unwrap_err(), VectorErr::DimensionMismatch);
    }

    #[test]
    fn test_cross() {
        let a = Vector::new(vec![1.0, 0.0, 0.0]);
        let b = Vector::new(vec![0.0, 1.0, 0.0]);
        assert_eq!(a.cross(&b).unwrap(), Vector::new(vec![0.0, 0.0, 1.0]));
    }

    #[test]
    fn test_cross_requires_three_dimensions() {
        let a = Vector::new(vec![1.0, 0.0]);
        let b = Vector::new(vec![0.0, 1.0]);
        assert_eq!(a.cross(&b).unwrap_err(), VectorErr::DimensionMismatch);
    }

    #[test]
    fn test_normalize() {
        let a = Vector::new(vec![3.0, 4.0]);
        let n = a.normalize().unwrap();
        assert!((n.len() - 1.0).abs() < 1e-10);
        // Direction must be preserved, not just the length.
        assert!((n[0] - 0.6).abs() < 1e-12);
        assert!((n[1] - 0.8).abs() < 1e-12);
    }

    #[test]
    fn test_normalize_zero_vector() {
        let z = Vector::new(vec![0.0, 0.0]);
        assert_eq!(z.normalize().unwrap_err(), VectorErr::ZeroVector);
    }

    #[test]
    fn test_index_mut() {
        let mut v = Vector::new(vec![1.0, 2.0, 3.0]);
        v[1] = 9.0;
        assert_eq!(v[0], 1.0);
        assert_eq!(v[1], 9.0);
        assert_eq!(v[2], 3.0);
    }

    #[test]
    fn test_display() {
        assert_eq!(Vector::new(vec![1.0, 2.5]).to_string(), "[1, 2.5]");
        assert_eq!(Vector::new(vec![]).to_string(), "[]");
    }

    #[test]
    fn test_dimension_mismatch() {
        let a = Vector::new(vec![1.0, 2.0]);
        let b = Vector::new(vec![1.0, 2.0, 3.0]);
        assert_eq!((a + b).unwrap_err(), VectorErr::DimensionMismatch);
    }
}