elemental/
matrix.rs

1//! Provides an abstraction over matrix behaviors.
2
3use std::ops::{
4    Add,
5    Sub,
6    Mul,
7    Index,
8    IndexMut,
9};
10
11use crate::error::*;
12
/// A dense matrix of `f64` entries stored in row-major order.
///
/// Entry `(i, j)` lives at `vals[i * cols + j]`. Construction does not check
/// that `vals.len() == rows * cols`; keeping the two consistent is the
/// caller's responsibility.
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Entries in row-major order.
    vals: Vec<f64>,
}

impl Matrix {
    /// Constructs a `rows` x `cols` matrix from row-major `vals`.
    ///
    /// `vals` is stored as given; its length is not validated against
    /// `rows * cols`.
    pub fn new(rows: usize, cols: usize, vals: Vec<f64>) -> Self {
        Self { rows, cols, vals }
    }

    /// Returns a new matrix with every entry multiplied by `scalar`.
    ///
    /// `self` is left unchanged; the result keeps the same dimensions.
    pub fn scalar_multiply(&self, scalar: f64) -> Self {
        Self {
            rows: self.rows,
            cols: self.cols,
            vals: self.vals.iter().map(|&v| scalar * v).collect(),
        }
    }

    /// Returns the number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Returns the number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Returns an owned copy of the entries, in row-major order.
    pub fn copy_vals(&self) -> Vec<f64> {
        self.vals.clone()
    }

    /// Borrows the entries, in row-major order.
    pub fn vals(&self) -> &Vec<f64> {
        &self.vals
    }

    /// Mutably borrows the entries.
    ///
    /// The returned `Vec` can be resized, which would leave it inconsistent
    /// with `rows()` / `cols()`.
    pub fn vals_mut(&mut self) -> &mut Vec<f64> {
        &mut self.vals
    }

    /// Returns a 0 x 0 matrix with no entries.
    pub fn empty() -> Self {
        Self::new(0, 0, Vec::new())
    }
}
75
76/// Defines matrix addition.
77impl Add for Matrix {
78    type Output = Self;
79
80    fn add(self, other: Self) -> Self {
81        if self.rows() != other.rows() || self.cols() != other.cols() {
82            throw(ImproperDimensions);
83            Self::new(0, 0, Vec::new());
84        }
85
86        let mut output_vals = Vec::new();
87
88        for (i, j) in self.vals().iter().zip(other.vals().iter()) {
89            output_vals.push(i + j);
90        }
91
92        Self {
93            vals: output_vals,
94            ..self
95        }
96    }
97}
98
99/// Defines matrix subtraction.
100impl Sub for Matrix {
101    type Output = Self;
102
103    fn sub(self, other: Self) -> Self {
104        if self.rows() != other.rows() || self.cols() != other.cols() {
105            throw(ImproperDimensions);
106            Self::new(0, 0, Vec::new());
107        }
108
109        let mut output_vals = Vec::new();
110
111        for (i, j) in self.vals().iter().zip(other.vals().iter()) {
112            output_vals.push(i - j);
113        }
114
115        Self {
116            vals: output_vals,
117            ..self
118        }
119    }
120}
121
122/// Defines matrix multiplication.
123impl Mul for Matrix {
124    type Output = Self;
125
126    #[allow(unused_variables)]
127    fn mul(self, other: Self) -> Self {
128        Self::new(0, 0, Vec::new())
129    }
130}
131
132/// Defines matrix indexing.
133impl Index<[usize; 2]> for Matrix {
134    type Output = f64;
135
136    fn index(&self, index: [usize; 2]) -> &Self::Output {
137        let i = index[0];
138        let j = index[1];
139        &self.vals()[i*self.cols() + j]
140    }
141}
142
143/// Defines matrix indexing with mutable permission.
144impl IndexMut<[usize; 2]> for Matrix {
145    fn index_mut(&mut self, index: [usize; 2]) -> &mut Self::Output {
146        let i = index[0];
147        let j = index[1];
148        let cols = self.cols();
149        &mut self.vals_mut()[i*cols + j]
150    }
151}