1use std::ops::{Index, IndexMut};
2
/// A dense matrix of `f64` values stored in column-major order.
///
/// The entry at `(row, col)` is stored at `data[col * nrows + row]`, so each
/// column occupies a contiguous run of `nrows` values. Constructors keep
/// `data.len() == nrows * ncols`.
#[derive(Clone, Debug, PartialEq)]
pub struct DenseMatrix {
    /// Number of rows (length of each stored column).
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Entries laid out column after column.
    data: Vec<f64>,
}
9
10impl DenseMatrix {
11 pub fn zeros(nrows: usize, ncols: usize) -> Self {
12 Self {
13 nrows,
14 ncols,
15 data: vec![0.0; nrows * ncols],
16 }
17 }
18
19 pub fn identity(nrows: usize, ncols: usize) -> Self {
20 let mut out = Self::zeros(nrows, ncols);
21 let diag = nrows.min(ncols);
22 for i in 0..diag {
23 out[(i, i)] = 1.0;
24 }
25 out
26 }
27
28 pub fn from_diagonal_element(nrows: usize, ncols: usize, value: f64) -> Self {
29 let mut out = Self::zeros(nrows, ncols);
30 let diag = nrows.min(ncols);
31 for i in 0..diag {
32 out[(i, i)] = value;
33 }
34 out
35 }
36
37 pub fn from_column_slice(nrows: usize, ncols: usize, data: &[f64]) -> Self {
38 assert_eq!(data.len(), nrows * ncols);
39 Self {
40 nrows,
41 ncols,
42 data: data.to_vec(),
43 }
44 }
45
46 pub fn from_row_slice(nrows: usize, ncols: usize, data: &[f64]) -> Self {
47 assert_eq!(data.len(), nrows * ncols);
48 let mut out = Self::zeros(nrows, ncols);
49 for row in 0..nrows {
50 for col in 0..ncols {
51 out[(row, col)] = data[row * ncols + col];
52 }
53 }
54 out
55 }
56
57 pub fn from_fn(nrows: usize, ncols: usize, mut f: impl FnMut(usize, usize) -> f64) -> Self {
58 let mut out = Self::zeros(nrows, ncols);
59 for col in 0..ncols {
60 for row in 0..nrows {
61 out[(row, col)] = f(row, col);
62 }
63 }
64 out
65 }
66
67 pub fn nrows(&self) -> usize {
68 self.nrows
69 }
70
71 pub fn ncols(&self) -> usize {
72 self.ncols
73 }
74
75 pub fn as_slice(&self) -> &[f64] {
76 &self.data
77 }
78
79 pub(crate) fn as_mut_slice(&mut self) -> &mut [f64] {
80 &mut self.data
81 }
82
83 pub fn into_vec(self) -> Vec<f64> {
84 self.data
85 }
86
87 pub(crate) fn column(&self, col: usize) -> &[f64] {
88 let start = col * self.nrows;
89 &self.data[start..start + self.nrows]
90 }
91
92 pub(crate) fn column_mut(&mut self, col: usize) -> &mut [f64] {
93 let start = col * self.nrows;
94 &mut self.data[start..start + self.nrows]
95 }
96
97 pub(crate) fn set_column(&mut self, col: usize, values: &[f64]) {
98 assert_eq!(values.len(), self.nrows);
99 self.column_mut(col).copy_from_slice(values);
100 }
101
102 pub(crate) fn get(&self, row: usize, col: usize) -> f64 {
103 self[(row, col)]
104 }
105
106 pub(crate) fn set(&mut self, row: usize, col: usize, value: f64) {
107 self[(row, col)] = value;
108 }
109
110 pub(crate) fn copy_columns_from(
111 &mut self,
112 dst_start: usize,
113 src: &Self,
114 src_start: usize,
115 count: usize,
116 ) {
117 assert_eq!(self.nrows, src.nrows);
118 for offset in 0..count {
119 self.column_mut(dst_start + offset)
120 .copy_from_slice(src.column(src_start + offset));
121 }
122 }
123
124 pub(crate) fn transpose(&self) -> Self {
125 let mut out = Self::zeros(self.ncols, self.nrows);
126 for row in 0..self.nrows {
127 for col in 0..self.ncols {
128 out[(col, row)] = self[(row, col)];
129 }
130 }
131 out
132 }
133
134 pub(crate) fn mul(&self, rhs: &Self) -> Self {
135 assert_eq!(self.ncols, rhs.nrows);
136 let mut out = Self::zeros(self.nrows, rhs.ncols);
137 for out_col in 0..rhs.ncols {
138 for k in 0..self.ncols {
139 let rhs_value = rhs[(k, out_col)];
140 if rhs_value.abs() <= 1e-30 {
141 continue;
142 }
143 for row in 0..self.nrows {
144 out[(row, out_col)] += self[(row, k)] * rhs_value;
145 }
146 }
147 }
148 out
149 }
150
151 pub(crate) fn transpose_mul(&self, rhs: &Self) -> Self {
152 assert_eq!(self.nrows, rhs.nrows);
153 let mut out = Self::zeros(self.ncols, rhs.ncols);
154 for out_col in 0..rhs.ncols {
155 for left_col in 0..self.ncols {
156 out[(left_col, out_col)] = self
157 .column(left_col)
158 .iter()
159 .zip(rhs.column(out_col))
160 .map(|(a, b)| a * b)
161 .sum();
162 }
163 }
164 out
165 }
166
167 pub(crate) fn mul_vector(&self, rhs: &[f64]) -> Vec<f64> {
168 assert_eq!(self.ncols, rhs.len());
169 let mut out = vec![0.0; self.nrows];
170 for col in 0..self.ncols {
171 let rhs_value = rhs[col];
172 if rhs_value.abs() <= 1e-30 {
173 continue;
174 }
175 for row in 0..self.nrows {
176 out[row] += self[(row, col)] * rhs_value;
177 }
178 }
179 out
180 }
181
182 pub(crate) fn select_columns(&self, indices: &[usize]) -> Self {
183 let mut out = Self::zeros(self.nrows, indices.len());
184 for (dst_col, &src_col) in indices.iter().enumerate() {
185 out.column_mut(dst_col)
186 .copy_from_slice(self.column(src_col));
187 }
188 out
189 }
190
191 pub(crate) fn to_row_major(&self) -> Vec<f64> {
192 let mut out = vec![0.0; self.nrows * self.ncols];
193 for row in 0..self.nrows {
194 for col in 0..self.ncols {
195 out[row * self.ncols + col] = self[(row, col)];
196 }
197 }
198 out
199 }
200
201 pub(crate) fn from_row_major(nrows: usize, ncols: usize, data: &[f64]) -> Self {
202 Self::from_row_slice(nrows, ncols, data)
203 }
204}
205
206impl Index<(usize, usize)> for DenseMatrix {
207 type Output = f64;
208
209 fn index(&self, index: (usize, usize)) -> &Self::Output {
210 &self.data[index.1 * self.nrows + index.0]
211 }
212}
213
214impl IndexMut<(usize, usize)> for DenseMatrix {
215 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
216 &mut self.data[index.1 * self.nrows + index.0]
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dense_matrix_roundtrips_column_major() {
        let mat = DenseMatrix::from_column_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(mat.nrows(), 3);
        assert_eq!(mat.ncols(), 2);
        assert_eq!(mat[(2, 1)], 6.0);
        assert_eq!(mat.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn transpose_mul_matches_manual_result() {
        let a = DenseMatrix::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = DenseMatrix::from_row_slice(3, 2, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let gram = a.transpose_mul(&b);
        assert_eq!(
            gram,
            DenseMatrix::from_row_slice(2, 2, &[89.0, 98.0, 116.0, 128.0])
        );
    }

    #[test]
    fn mul_matches_manual_result() {
        let a = DenseMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = DenseMatrix::from_row_slice(3, 2, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        assert_eq!(
            a.mul(&b),
            DenseMatrix::from_row_slice(2, 2, &[58.0, 64.0, 139.0, 154.0])
        );
    }

    #[test]
    fn mul_vector_matches_manual_result() {
        let a = DenseMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // The middle zero also exercises the skip-zero path.
        assert_eq!(a.mul_vector(&[1.0, 0.0, -1.0]), vec![-2.0, -2.0]);
    }

    #[test]
    fn transpose_and_row_major_agree() {
        let a = DenseMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(a.to_row_major(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let t = a.transpose();
        assert_eq!((t.nrows(), t.ncols()), (3, 2));
        assert_eq!(t.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn diagonal_constructors_handle_rectangular_shapes() {
        assert_eq!(
            DenseMatrix::identity(2, 3).as_slice(),
            &[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]
        );
        assert_eq!(
            DenseMatrix::from_diagonal_element(3, 2, 5.0).as_slice(),
            &[5.0, 0.0, 0.0, 0.0, 5.0, 0.0]
        );
    }

    #[test]
    fn select_and_copy_columns() {
        let a = DenseMatrix::from_column_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(a.select_columns(&[2, 0]).as_slice(), &[5.0, 6.0, 1.0, 2.0]);

        let mut dst = DenseMatrix::zeros(2, 3);
        dst.copy_columns_from(1, &a, 0, 2);
        assert_eq!(dst.as_slice(), &[0.0, 0.0, 1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn from_fn_visits_entries_column_major() {
        let mut calls = Vec::new();
        let m = DenseMatrix::from_fn(2, 2, |row, col| {
            calls.push((row, col));
            (10 * row + col) as f64
        });
        assert_eq!(calls, vec![(0, 0), (1, 0), (0, 1), (1, 1)]);
        assert_eq!(m.to_row_major(), vec![0.0, 1.0, 10.0, 11.0]);
    }
}