use num_traits::Float;
use std::iter::Sum;
use std::{
fmt,
ops::{Add, Mul, Neg, Sub},
};
use crate::{
err::VectorErr,
vector::{Vector, VectorOps},
};
/// A three-component vector with floating-point components of type `T`.
///
/// The type is `Copy`, so it is passed by value to the arithmetic operators
/// implemented for it. Equality is component-wise via the derived `PartialEq`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Vector3d<T>
where
    T: Float,
{
    /// First (`x`) component.
    pub x: T,
    /// Second (`y`) component.
    pub y: T,
    /// Third (`z`) component.
    pub z: T,
}
impl<T: Float> Vector3d<T> {
pub fn new(x: T, y: T, z: T) -> Self {
Self { x, y, z }
}
pub fn to_vec(&self) -> Vector<T> {
Vector::new(vec![self.x, self.y, self.z])
}
pub fn cross(&self, other: &Self) -> Self {
Self {
x: self.y * other.z - self.z * other.y,
y: self.z * other.x - self.x * other.z,
z: self.x * other.y - self.y * other.x,
}
}
}
impl<T> VectorOps<T> for Vector3d<T>
where
    T: Float + Sum<T>,
{
    /// Returns the dot product of `self` and `other`.
    ///
    /// This implementation never fails: both operands always have exactly
    /// three components, so the result is always `Ok`.
    fn dot(&self, other: &Self) -> Result<T, VectorErr> {
        Ok(self.x * other.x + self.y * other.y + self.z * other.z)
    }

    /// Returns the Euclidean length of the vector.
    ///
    /// The length is computed with chained `hypot` calls rather than
    /// `sqrt(x² + y² + z²)`. Squaring the components directly overflows to
    /// infinity for very large components and underflows to zero for very
    /// small ones, even when the true length is representable.
    fn magnitude(&self) -> T {
        self.x.hypot(self.y).hypot(self.z)
    }

    /// Returns a vector pointing in the same direction with length one.
    ///
    /// # Errors
    ///
    /// Returns [`VectorErr::ZeroVector`] when the magnitude is exactly zero,
    /// because the direction is undefined.
    ///
    /// A NaN magnitude is not rejected: the comparison with zero is false, so
    /// the NaN propagates into the returned components.
    fn normalize(&self) -> Result<Self, VectorErr> {
        let mag = self.magnitude();
        if mag == T::zero() {
            return Err(VectorErr::ZeroVector);
        }
        Ok(Self {
            x: self.x / mag,
            y: self.y / mag,
            z: self.z / mag,
        })
    }
}
/// Component-wise addition: `a + b`.
impl<T> Add for Vector3d<T>
where
    T: Float,
{
    type Output = Self;

    /// Adds each component of `rhs` to the matching component of `self`.
    fn add(self, rhs: Self) -> Self {
        let Self { x: lx, y: ly, z: lz } = self;
        let Self { x: rx, y: ry, z: rz } = rhs;
        Self {
            x: lx + rx,
            y: ly + ry,
            z: lz + rz,
        }
    }
}
/// Component-wise subtraction: `a - b`.
impl<T> Sub for Vector3d<T>
where
    T: Float,
{
    type Output = Self;

    /// Subtracts each component of `rhs` from the matching component of `self`.
    fn sub(self, rhs: Self) -> Self {
        let Self { x: lx, y: ly, z: lz } = self;
        let Self { x: rx, y: ry, z: rz } = rhs;
        Self {
            x: lx - rx,
            y: ly - ry,
            z: lz - rz,
        }
    }
}
/// Scaling by a scalar on the right-hand side: `v * s`.
impl<T> Mul<T> for Vector3d<T>
where
    T: Float,
{
    type Output = Self;

    /// Multiplies every component by `scalar`.
    fn mul(self, scalar: T) -> Self {
        let Self { x, y, z } = self;
        Self {
            x: x * scalar,
            y: y * scalar,
            z: z * scalar,
        }
    }
}
/// Unary negation: `-v`.
impl<T> Neg for Vector3d<T>
where
    T: Float,
{
    type Output = Self;

    /// Flips the sign of every component.
    fn neg(self) -> Self {
        let Self { x, y, z } = self;
        Self {
            x: -x,
            y: -y,
            z: -z,
        }
    }
}
/// Formats the vector as `[x, y, z]`, using each component's own `Display`.
impl<T> fmt::Display for Vector3d<T>
where
    T: Float + fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { x, y, z } = self;
        write!(f, "[{}, {}, {}]", x, y, z)
    }
}
/// Unit tests for `Vector3d`: the arithmetic operators, the `VectorOps`
/// methods, the cross product, and `Display` formatting.
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_add() {
        let a = Vector3d::new(1.0, 2.0, 3.0);
        let b = Vector3d::new(3.0, 4.0, 5.0);
        assert_eq!(a + b, Vector3d::new(4.0, 6.0, 8.0));
    }

    #[test]
    fn test_sub() {
        let a = Vector3d::new(3.0, 4.0, 5.0);
        let b = Vector3d::new(1.0, 2.0, 3.0);
        assert_eq!(a - b, Vector3d::new(2.0, 2.0, 2.0));
    }

    #[test]
    fn test_mul_scalar() {
        let a = Vector3d::new(1.0, 2.0, 3.0);
        assert_eq!(a * 2.0, Vector3d::new(2.0, 4.0, 6.0));
    }

    #[test]
    fn test_neg() {
        let a = Vector3d::new(1.0, -2.0, 3.0);
        assert_eq!(-a, Vector3d::new(-1.0, 2.0, -3.0));
    }

    #[test]
    fn test_dot() {
        let a = Vector3d::new(1.0, 2.0, 3.0);
        let b = Vector3d::new(4.0, 5.0, 6.0);
        assert_eq!(a.dot(&b).unwrap(), 32.0);
    }

    #[test]
    fn test_cross() {
        let a = Vector3d::new(1.0, 0.0, 0.0);
        let b = Vector3d::new(0.0, 1.0, 0.0);
        assert_eq!(a.cross(&b), Vector3d::new(0.0, 0.0, 1.0));
    }

    #[test]
    fn test_cross_is_anticommutative() {
        // Swapping the operands must flip the sign of the result.
        let a = Vector3d::new(1.0, 0.0, 0.0);
        let b = Vector3d::new(0.0, 1.0, 0.0);
        assert_eq!(b.cross(&a), Vector3d::new(0.0, 0.0, -1.0));
    }

    #[test]
    fn test_magnitude() {
        let a = Vector3d::new(3.0, 4.0, 0.0);
        assert_eq!(a.magnitude(), 5.0);
    }

    #[test]
    fn test_normalize() {
        let a = Vector3d::new(3.0, 4.0, 0.0);
        let n = a.normalize().unwrap();
        assert!((n.magnitude() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_normalize_components() {
        // A 3-4-5 triangle gives the unit vector (0.6, 0.8, 0).
        let n = Vector3d::new(3.0, 4.0, 0.0).normalize().unwrap();
        assert!((n.x - 0.6).abs() < 1e-12);
        assert!((n.y - 0.8).abs() < 1e-12);
        assert!(n.z.abs() < 1e-12);
    }

    #[test]
    fn test_zero_normalize() {
        let a = Vector3d::new(0.0, 0.0, 0.0);
        assert_eq!(a.normalize().unwrap_err(), VectorErr::ZeroVector);
    }

    #[test]
    fn test_display() {
        let a = Vector3d::new(1.5_f64, -2.0, 3.0);
        assert_eq!(a.to_string(), "[1.5, -2, 3]");
    }
}