math_concept/lalg/matrix/
mod.rs1use std::default::Default;
2use std::ops::{Add, AddAssign, Div, Mul};
3
4#[cfg(test)]
5mod tests;
6
/// A dense matrix whose elements are stored contiguously, row after row.
///
/// The element at row `i`, column `j` is kept at offset `i * c + j` of `buf`.
pub struct Matrix<T> {
    /// Number of rows.
    r: usize,
    /// Number of columns.
    c: usize,
    /// Flat element storage, one row after another.
    buf: Vec<T>,
}
32
33impl<T> Matrix<T>
34where
35 T: Copy + Add<Output = T> + Div<Output = T>,
36{
37 pub fn new(r: usize, c: usize, v: Vec<T>) -> Matrix<T> {
40 assert_eq!(r * c, v.len(), "Matrix dimensions and data size differ");
41 Matrix { r, c, buf: v }
42 }
43
44 pub fn from_vec(r: usize, c: usize, v: &Vec<T>) -> Matrix<T> {
47 assert_eq!(r * c, v.len(), "Matrix dimensions and data size differ");
48 Matrix::new(r, c, v.clone())
49 }
50
51 pub fn rows(&self) -> usize {
53 self.r
54 }
55
56 pub fn cols(&self) -> usize {
58 self.c
59 }
60
61 pub fn at(&self, r: usize, c: usize) -> T {
64 assert!(r < self.r && c < self.c, "Out of bounds");
65 self.buf[r * self.c + c]
66 }
67
68 pub fn at_mut(&mut self, r: usize, c: usize) -> &mut T {
71 assert!(r < self.r && c < self.c, "Out of bounds");
72 &mut self.buf[r * self.c + c]
73 }
74
75 pub fn data(&self) -> &Vec<T> {
80 &self.buf
81 }
82
83 pub fn map<U, F>(&self, f: F) -> Matrix<U>
95 where
96 U: Copy + Add<Output = U> + Div<Output = U>,
97 F: Fn(&T) -> U,
98 {
99 let new_v = self.buf.iter().map(f).collect();
100
101 Matrix::new(self.r, self.c, new_v)
102 }
103
104 pub fn diagonal(&self) -> Vec<T> {
117 self.assert_square();
118
119 self.buf
120 .iter()
121 .zip(0..self.buf.len())
122 .filter(|(_, i)| i % self.c == i / self.r)
123 .map(|(&elem, _)| elem)
124 .collect()
125 }
126
127 fn assert_square(&self) {
128 assert_eq!(self.r, self.c, "Matrix is not square");
129 }
130}
131
132impl<'a, T> Add<&'a Matrix<T>> for &'a Matrix<T>
133where
134 T: Copy + Add<Output = T> + Div<Output = T>,
135{
136 type Output = Matrix<T>;
137
138 fn add(self, other: Self) -> Self::Output {
139 assert!(
140 self.r == other.rows() && self.c == other.cols(),
141 "Matrices are not of the same size"
142 );
143
144 let new_v = self
145 .buf
146 .iter()
147 .zip(other.data())
148 .map(|(&x, &y)| x + y)
149 .collect();
150
151 Matrix::new(self.r, self.c, new_v)
152 }
153}
154
155impl<'a, T> Mul<&'a Matrix<T>> for &'a Matrix<T>
156where
157 T: Copy + Add<Output = T> + AddAssign<T> + Mul<Output = T> + Div<Output = T> + Default,
158{
159 type Output = Matrix<T>;
160
161 fn mul(self, other: Self) -> Self::Output {
162 assert_eq!(self.c, other.r);
163
164 let mut v: Vec<T> = vec![Default::default(); self.r * other.c];
165
166 for i in 0..self.r {
167 for j in 0..other.c {
168 for k in 0..self.c {
169 v[i * self.r + j] += self.buf[i * self.c + k] * other.buf[k * other.c + j];
170 }
171 }
172 }
173
174 Matrix::new(self.r, other.c, v)
175 }
176}