// multrix/lib.rs

1pub mod multrix {
2    use rand::Rng;
3    use rayon::prelude::*;
4    use std::fs;
5
6    /// A matrix struct that stores the values in a one-dimensional vector, where each row is stored
7    /// contiguously.
8    pub struct Matrix {
9        data: Vec<f64>,
10        rows: usize,
11        cols: usize,
12    }
13
14    impl Matrix {
15        /// Creates a new (square) identity matrix with the given dimension.
16        pub fn new_identity(dimension: usize) -> Matrix {
17            let mut data: Vec<f64> = vec![0.0; dimension * dimension];
18
19            for i in 0..dimension {
20                for j in 0..dimension {
21                    data[i * dimension + j] = {
22                        if i == j {
23                            1.0
24                        } else {
25                            0.0
26                        }
27                    };
28                }
29            }
30            Matrix { data, rows: dimension, cols: dimension }
31        }
32
33        /// Creates a new matrix with the given dimensions and random values.
34        pub fn new_rand(rows: usize, cols: usize) -> Matrix {
35            let mut data: Vec<f64> = vec![0.0; cols * rows];
36
37            let mut rng = rand::thread_rng();
38            for i in 0..rows {
39                for j in 0..cols {
40                    data[i * cols + j] = rng.gen_range(0..10) as f64;
41                }
42            }
43            Matrix { data, rows, cols }
44        }
45
46        /// Creates a new matrix with the given dimensions reading the values from a file.
47        /// The file must contain a comma-separated list of numbers, with each row on a new line.
48        /// The last element on each row may or may not be followed by a comma.
49        ///
50        /// # Panics
51        /// The function panics if the dimensions are incorrect, if it fails to read from the file,
52        /// or the file contains invalid data and numbers cannot be parsed.
53        pub fn new_from_file(filename: &str) -> Matrix {
54            let contents = match fs::read_to_string(filename) {
55                Ok(v) => v,
56                Err(e) => {
57                    eprintln!(
58                        "{} failed to read from file '{}': {:?}",
59                        "Error:", filename, e
60                    );
61                    panic!();
62                }
63            };
64            let mut data = Vec::new();
65            let rows = contents.lines().count();
66            let mut cols = 0;
67            for line in contents.lines() {
68                for num_str in line.split(',') {
69                    data.push(num_str.parse().unwrap());
70                }
71                if cols == 0 {
72                    cols = line.split(',').count();
73                }
74            }
75            Matrix { data, rows, cols }
76        }
77
78        /// Creates a new matrix with the given dimensions from a one-dimensional vector containing
79        /// the values, where each row is stored contiguously.
80        ///
81        /// # Panics
82        /// The function panics if the provided dimensions are different than the vector length.
83        pub fn new_from_vec(data: Vec<f64>, rows: usize, cols: usize) -> Matrix {
84            assert_eq!(data.len(), rows * cols, "Invalid matrix dimensions");
85            Matrix { data, rows, cols }
86        }
87
88        /// Creates a new matrix with the given dimensions from a two-dimensional vector containing
89        /// the values.
90        pub fn new_from_vec_vec(data: Vec<Vec<f64>>) -> Matrix {
91            let rows = data.len();
92            let cols = data[0].len();
93            let mut data_vec = Vec::with_capacity(rows * cols);
94            for row in data {
95                for element in row {
96                    data_vec.push(element);
97                }
98            }
99            Matrix { data: data_vec, rows, cols }
100        }
101
102        /// Gets the value at the given row and column indices.
103        pub fn get(&self, row: usize, col: usize) -> f64 {
104            self.data[row * self.cols + col]
105        }
106
107        /// Sets the value at the given row and column indices.
108        pub fn set(&mut self, row: usize, col: usize, value: f64) {
109            self.data[row * self.cols + col] = value;
110        }
111
112        /// Returns the current matrix transposed (rows and columns swapped).
113        pub fn transpose(&self) -> Matrix {
114            let mut data = vec![0.0; self.rows * self.cols];
115            for i in 0..self.rows {
116                for j in 0..self.cols {
117                    data[j * self.rows + i] = self.data[i * self.cols + j];
118                }
119            }
120            Matrix { data, rows: self.cols, cols: self.rows }
121        }
122
123        /// Returns whether the two matrices are conformable for multiplication.
124        pub fn is_conformable(&self, other: &Matrix) -> bool {
125            self.cols == other.rows
126        }
127
128        /// Adds the given matrix to the current one and returns the result.
129        ///
130        /// # Panics
131        /// The function panics if the matrices cannot be added: they must have the same dimensions.
132        fn addition(self, other: Matrix) -> Matrix {
133            assert_eq!(self.rows, other.rows, "Matrices cannot be added");
134            assert_eq!(self.cols, other.cols, "Matrices cannot be added");
135            let mut data = vec![0.0; self.rows * self.cols];
136            for i in 0..self.rows * self.cols {
137                data[i] = self.data[i] + other.data[i];
138            }
139            Matrix { data, rows: self.rows, cols: self.cols }
140        }
141
142        /// Negates the sign of the current matrix and returns the result.
143        fn negation(self) -> Matrix {
144            let mut data = vec![0.0; self.rows * self.cols];
145            for i in 0..self.rows * self.cols {
146                data[i] = -self.data[i];
147            }
148            Matrix { data, rows: self.rows, cols: self.cols }
149        }
150
151        /// Returns the product between the current matrix and the given one, and uses only one thread.
152        ///
153        /// # Panics
154        /// The function panics if the matrices cannot be multiplied: the number of columns of the
155        /// first matrix must be equal to the number of rows of the second matrix.
156        pub fn product(self, other: Matrix) -> Matrix {
157            if self.cols != other.rows {
158                panic!("Matrices cannot be multiplied");
159            }
160
161            let rows = self.rows;
162            let cols = other.cols;
163            let mut data = vec![0.0; cols * rows];
164            for i in 0..cols * rows {
165                let mut c = 0.0;
166                let row = i / cols;
167                let col = i % cols;
168                for k in 0..self.cols {
169                    c += self.data[row * self.cols + k] * other.data[k * other.cols + col];
170                }
171                data[i] = c;
172            }
173            Matrix { data, rows, cols }
174        }
175
176        /// Returns the product between the current matrix and the given one, and uses multiple threads.
177        ///
178        /// # Panics
179        /// The function panics if the matrices cannot be multiplied: the number of columns of the
180        /// first matrix must be equal to the number of rows of the second matrix.
181        pub fn parallel_product(self, other: Matrix) -> Matrix {
182            if self.cols != other.rows {
183                panic!("Matrices cannot be multiplied");
184            }
185
186            let rows = self.rows;
187            let cols = other.cols;
188            let mut data = vec![0.0; cols * rows];
189
190            data.par_iter_mut().enumerate().for_each(|(i, c)| {
191                let row = i / cols;
192                let col = i % cols;
193                for k in 0..self.cols {
194                    *c += self.data[row * self.cols + k] * other.data[k * other.cols + col];
195                }
196            });
197
198            Matrix { data, rows, cols }
199        }
200
201        /// Writes the matrix to the given file in the same comma-separated format as the input.
202        ///
203        /// # Panics
204        /// The function panics if it fails to write to the file.
205        pub fn write_to_file(&self, filename: &str) {
206            let mut matrix_str = String::with_capacity(self.rows * self.cols * 2);
207            for (i, element) in self.data.iter().enumerate() {
208                matrix_str.push_str(&format!("{},", element));
209                if (i + 1) % self.cols == 0 {
210                    matrix_str.push('\n');
211                }
212            }
213
214            match fs::write(filename, matrix_str) {
215                Ok(_) => {}
216                Err(e) => {
217                    eprintln!("Error: failed to write to file '{}': {:?}", filename, e);
218                    panic!();
219                }
220            }
221        }
222
223    }
224    use std::ops::Add;
225    impl Add for Matrix {
226        type Output = Matrix;
227        fn add(self, other: Matrix) -> Matrix {
228            self.addition(other)
229        }
230    }
231
232    use std::ops::Neg;
233    impl Neg for Matrix {
234        type Output = Matrix;
235        fn neg(self) -> Matrix {
236            self.negation()
237        }
238    }
239
240    use std::ops::Mul;
241    impl Mul for Matrix {
242        type Output = Matrix;
243        fn mul(self, other: Matrix) -> Matrix {
244            self.parallel_product(other)
245        }
246    }
247
248    use std::ops::Sub;
249    impl Sub for Matrix {
250        type Output = Matrix;
251        fn sub(self, other: Matrix) -> Matrix {
252            self + (-other)
253        }
254    }
255
256    use std::fmt;
257    impl fmt::Display for Matrix {
258        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
259            for (i, element) in self.data.iter().enumerate() {
260                write!(f, "{},", element)?;
261                if (i + 1) % self.cols == 0 {
262                    writeln!(f)?;
263                }
264            }
265            Ok(())
266        }
267    }
268}