// neco_eigensolve/dense_block.rs

1use std::ops::{Index, IndexMut};
2
/// Dense `f64` matrix stored in column-major order.
///
/// Entry `(row, col)` lives at `data[col * nrows + row]`; the constructors
/// keep `data.len() == nrows * ncols`.
#[derive(Clone, Debug, PartialEq)]
pub struct DenseMatrix {
    /// Number of rows, i.e. the length of each stored column.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// All entries, column after column.
    data: Vec<f64>,
}
9
10impl DenseMatrix {
11    pub fn zeros(nrows: usize, ncols: usize) -> Self {
12        Self {
13            nrows,
14            ncols,
15            data: vec![0.0; nrows * ncols],
16        }
17    }
18
19    pub fn identity(nrows: usize, ncols: usize) -> Self {
20        let mut out = Self::zeros(nrows, ncols);
21        let diag = nrows.min(ncols);
22        for i in 0..diag {
23            out[(i, i)] = 1.0;
24        }
25        out
26    }
27
28    pub fn from_diagonal_element(nrows: usize, ncols: usize, value: f64) -> Self {
29        let mut out = Self::zeros(nrows, ncols);
30        let diag = nrows.min(ncols);
31        for i in 0..diag {
32            out[(i, i)] = value;
33        }
34        out
35    }
36
37    pub fn from_column_slice(nrows: usize, ncols: usize, data: &[f64]) -> Self {
38        assert_eq!(data.len(), nrows * ncols);
39        Self {
40            nrows,
41            ncols,
42            data: data.to_vec(),
43        }
44    }
45
46    pub fn from_row_slice(nrows: usize, ncols: usize, data: &[f64]) -> Self {
47        assert_eq!(data.len(), nrows * ncols);
48        let mut out = Self::zeros(nrows, ncols);
49        for row in 0..nrows {
50            for col in 0..ncols {
51                out[(row, col)] = data[row * ncols + col];
52            }
53        }
54        out
55    }
56
57    pub fn from_fn(nrows: usize, ncols: usize, mut f: impl FnMut(usize, usize) -> f64) -> Self {
58        let mut out = Self::zeros(nrows, ncols);
59        for col in 0..ncols {
60            for row in 0..nrows {
61                out[(row, col)] = f(row, col);
62            }
63        }
64        out
65    }
66
67    pub fn nrows(&self) -> usize {
68        self.nrows
69    }
70
71    pub fn ncols(&self) -> usize {
72        self.ncols
73    }
74
75    pub fn as_slice(&self) -> &[f64] {
76        &self.data
77    }
78
79    pub(crate) fn as_mut_slice(&mut self) -> &mut [f64] {
80        &mut self.data
81    }
82
83    pub fn into_vec(self) -> Vec<f64> {
84        self.data
85    }
86
87    pub(crate) fn column(&self, col: usize) -> &[f64] {
88        let start = col * self.nrows;
89        &self.data[start..start + self.nrows]
90    }
91
92    pub(crate) fn column_mut(&mut self, col: usize) -> &mut [f64] {
93        let start = col * self.nrows;
94        &mut self.data[start..start + self.nrows]
95    }
96
97    pub(crate) fn set_column(&mut self, col: usize, values: &[f64]) {
98        assert_eq!(values.len(), self.nrows);
99        self.column_mut(col).copy_from_slice(values);
100    }
101
102    pub(crate) fn get(&self, row: usize, col: usize) -> f64 {
103        self[(row, col)]
104    }
105
106    pub(crate) fn set(&mut self, row: usize, col: usize, value: f64) {
107        self[(row, col)] = value;
108    }
109
110    pub(crate) fn copy_columns_from(
111        &mut self,
112        dst_start: usize,
113        src: &Self,
114        src_start: usize,
115        count: usize,
116    ) {
117        assert_eq!(self.nrows, src.nrows);
118        for offset in 0..count {
119            self.column_mut(dst_start + offset)
120                .copy_from_slice(src.column(src_start + offset));
121        }
122    }
123
124    pub(crate) fn transpose(&self) -> Self {
125        let mut out = Self::zeros(self.ncols, self.nrows);
126        for row in 0..self.nrows {
127            for col in 0..self.ncols {
128                out[(col, row)] = self[(row, col)];
129            }
130        }
131        out
132    }
133
134    pub(crate) fn mul(&self, rhs: &Self) -> Self {
135        assert_eq!(self.ncols, rhs.nrows);
136        let mut out = Self::zeros(self.nrows, rhs.ncols);
137        for out_col in 0..rhs.ncols {
138            for k in 0..self.ncols {
139                let rhs_value = rhs[(k, out_col)];
140                if rhs_value.abs() <= 1e-30 {
141                    continue;
142                }
143                for row in 0..self.nrows {
144                    out[(row, out_col)] += self[(row, k)] * rhs_value;
145                }
146            }
147        }
148        out
149    }
150
151    pub(crate) fn transpose_mul(&self, rhs: &Self) -> Self {
152        assert_eq!(self.nrows, rhs.nrows);
153        let mut out = Self::zeros(self.ncols, rhs.ncols);
154        for out_col in 0..rhs.ncols {
155            for left_col in 0..self.ncols {
156                out[(left_col, out_col)] = self
157                    .column(left_col)
158                    .iter()
159                    .zip(rhs.column(out_col))
160                    .map(|(a, b)| a * b)
161                    .sum();
162            }
163        }
164        out
165    }
166
167    pub(crate) fn mul_vector(&self, rhs: &[f64]) -> Vec<f64> {
168        assert_eq!(self.ncols, rhs.len());
169        let mut out = vec![0.0; self.nrows];
170        for col in 0..self.ncols {
171            let rhs_value = rhs[col];
172            if rhs_value.abs() <= 1e-30 {
173                continue;
174            }
175            for row in 0..self.nrows {
176                out[row] += self[(row, col)] * rhs_value;
177            }
178        }
179        out
180    }
181
182    pub(crate) fn select_columns(&self, indices: &[usize]) -> Self {
183        let mut out = Self::zeros(self.nrows, indices.len());
184        for (dst_col, &src_col) in indices.iter().enumerate() {
185            out.column_mut(dst_col)
186                .copy_from_slice(self.column(src_col));
187        }
188        out
189    }
190
191    pub(crate) fn to_row_major(&self) -> Vec<f64> {
192        let mut out = vec![0.0; self.nrows * self.ncols];
193        for row in 0..self.nrows {
194            for col in 0..self.ncols {
195                out[row * self.ncols + col] = self[(row, col)];
196            }
197        }
198        out
199    }
200
201    pub(crate) fn from_row_major(nrows: usize, ncols: usize, data: &[f64]) -> Self {
202        Self::from_row_slice(nrows, ncols, data)
203    }
204}
205
206impl Index<(usize, usize)> for DenseMatrix {
207    type Output = f64;
208
209    fn index(&self, index: (usize, usize)) -> &Self::Output {
210        &self.data[index.1 * self.nrows + index.0]
211    }
212}
213
214impl IndexMut<(usize, usize)> for DenseMatrix {
215    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
216        &mut self.data[index.1 * self.nrows + index.0]
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Column-major input is stored verbatim and indexed as `(row, col)`.
    #[test]
    fn dense_matrix_roundtrips_column_major() {
        let mat = DenseMatrix::from_column_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(mat.nrows(), 3);
        assert_eq!(mat.ncols(), 2);
        assert_eq!(mat[(2, 1)], 6.0);
        assert_eq!(mat.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    /// `a^T * b` computed column-by-column matches a hand-computed Gram matrix.
    #[test]
    fn transpose_mul_matches_manual_result() {
        let a = DenseMatrix::from_row_slice(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = DenseMatrix::from_row_slice(3, 2, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let gram = a.transpose_mul(&b);
        assert_eq!(
            gram,
            DenseMatrix::from_row_slice(2, 2, &[89.0, 98.0, 116.0, 128.0])
        );
    }

    /// `a * b` matches a hand-computed 2x2 product.
    #[test]
    fn mul_matches_manual_result() {
        let a = DenseMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = DenseMatrix::from_row_slice(3, 2, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        assert_eq!(
            a.mul(&b),
            DenseMatrix::from_row_slice(2, 2, &[58.0, 64.0, 139.0, 154.0])
        );
    }

    /// Matrix-vector product, including a zero entry in the vector.
    #[test]
    fn mul_vector_matches_manual_result() {
        let a = DenseMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(a.mul_vector(&[1.0, 0.0, -1.0]), vec![-2.0, -2.0]);
    }

    /// Transpose, row-major export and row-major import agree with each other.
    #[test]
    fn transpose_and_row_major_agree() {
        let a = DenseMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(a.to_row_major(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(
            a.transpose(),
            DenseMatrix::from_row_slice(3, 2, &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0])
        );
        assert_eq!(DenseMatrix::from_row_major(2, 3, &a.to_row_major()), a);
    }

    /// Diagonal constructors and `from_fn` produce column-major storage.
    #[test]
    fn constructors_use_column_major_layout() {
        assert_eq!(
            DenseMatrix::identity(2, 3).as_slice(),
            &[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]
        );
        assert_eq!(
            DenseMatrix::from_diagonal_element(3, 2, 5.0).as_slice(),
            &[5.0, 0.0, 0.0, 0.0, 5.0, 0.0]
        );
        let m = DenseMatrix::from_fn(2, 2, |row, col| (10 * row + col) as f64);
        assert_eq!(m.as_slice(), &[0.0, 10.0, 1.0, 11.0]);
    }

    /// `select_columns` copies columns in the order requested.
    #[test]
    fn select_columns_copies_in_requested_order() {
        let a = DenseMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(
            a.select_columns(&[2, 0]),
            DenseMatrix::from_row_slice(2, 2, &[3.0, 1.0, 6.0, 4.0])
        );
    }
}
243}