Skip to main content

fdars_core/
matrix.rs

1//! Column-major matrix type for functional data analysis.
2//!
3//! [`FdMatrix`] provides safe, dimension-tracked access to the flat column-major
4//! data layout used throughout this crate. It eliminates manual `data[i + j * n]`
5//! index arithmetic and carries dimensions alongside the data.
6
7use nalgebra::DMatrix;
8
/// Dense `f64` matrix stored in column-major (Fortran) order for functional data.
///
/// All entries live in one flat `Vec<f64>`; entry `(row, col)` sits at flat
/// offset `row + col * nrows`, so every column is a contiguous run of memory.
///
/// # Conventions
///
/// In functional data analysis each row usually holds one observation and each
/// column one evaluation point. A 2D surface sampled on an `m1 x m2` grid is
/// flattened into `m1 * m2` columns.
///
/// # Examples
///
/// ```
/// use fdars_core::matrix::FdMatrix;
///
/// // 3 observations sampled at 4 evaluation points, listed column by column.
/// let values = vec![
///     1.0, 2.0, 3.0,    // evaluation point 0 (every observation)
///     4.0, 5.0, 6.0,    // evaluation point 1
///     7.0, 8.0, 9.0,    // evaluation point 2
///     10.0, 11.0, 12.0, // evaluation point 3
/// ];
/// let m = FdMatrix::from_column_major(values, 3, 4).unwrap();
///
/// assert_eq!(m.shape(), (3, 4));
/// assert_eq!(m[(0, 0)], 1.0); // observation 0 at point 0
/// assert_eq!(m[(1, 2)], 8.0); // observation 1 at point 2
/// assert_eq!(m.column(0), &[1.0, 2.0, 3.0]);
/// ```
#[derive(Clone, Debug, PartialEq)]
pub struct FdMatrix {
    /// Column-major element storage; every constructor keeps its length equal
    /// to `nrows * ncols`.
    data: Vec<f64>,
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
}
44
45impl FdMatrix {
46    /// Create from flat column-major data with dimension validation.
47    ///
48    /// Returns `None` if `data.len() != nrows * ncols`.
49    pub fn from_column_major(data: Vec<f64>, nrows: usize, ncols: usize) -> Option<Self> {
50        if data.len() != nrows * ncols {
51            return None;
52        }
53        Some(Self { data, nrows, ncols })
54    }
55
56    /// Create from a borrowed slice (copies the data).
57    ///
58    /// Returns `None` if `data.len() != nrows * ncols`.
59    pub fn from_slice(data: &[f64], nrows: usize, ncols: usize) -> Option<Self> {
60        if data.len() != nrows * ncols {
61            return None;
62        }
63        Some(Self {
64            data: data.to_vec(),
65            nrows,
66            ncols,
67        })
68    }
69
70    /// Create a zero-filled matrix.
71    pub fn zeros(nrows: usize, ncols: usize) -> Self {
72        Self {
73            data: vec![0.0; nrows * ncols],
74            nrows,
75            ncols,
76        }
77    }
78
79    /// Number of rows.
80    #[inline]
81    pub fn nrows(&self) -> usize {
82        self.nrows
83    }
84
85    /// Number of columns.
86    #[inline]
87    pub fn ncols(&self) -> usize {
88        self.ncols
89    }
90
91    /// Dimensions as `(nrows, ncols)`.
92    #[inline]
93    pub fn shape(&self) -> (usize, usize) {
94        (self.nrows, self.ncols)
95    }
96
97    /// Total number of elements.
98    #[inline]
99    pub fn len(&self) -> usize {
100        self.data.len()
101    }
102
103    /// Whether the matrix is empty.
104    #[inline]
105    pub fn is_empty(&self) -> bool {
106        self.data.is_empty()
107    }
108
109    /// Get a contiguous column slice (zero-copy).
110    ///
111    /// # Panics
112    /// Panics if `col >= ncols`.
113    #[inline]
114    pub fn column(&self, col: usize) -> &[f64] {
115        let start = col * self.nrows;
116        &self.data[start..start + self.nrows]
117    }
118
119    /// Get a mutable contiguous column slice (zero-copy).
120    ///
121    /// # Panics
122    /// Panics if `col >= ncols`.
123    #[inline]
124    pub fn column_mut(&mut self, col: usize) -> &mut [f64] {
125        let start = col * self.nrows;
126        &mut self.data[start..start + self.nrows]
127    }
128
129    /// Extract a single row as a new `Vec<f64>`.
130    ///
131    /// This is an O(ncols) operation because rows are not contiguous
132    /// in column-major layout.
133    pub fn row(&self, row: usize) -> Vec<f64> {
134        (0..self.ncols)
135            .map(|j| self.data[row + j * self.nrows])
136            .collect()
137    }
138
139    /// Extract all rows as `Vec<Vec<f64>>`.
140    ///
141    /// Equivalent to the former `extract_curves` function.
142    pub fn rows(&self) -> Vec<Vec<f64>> {
143        (0..self.nrows).map(|i| self.row(i)).collect()
144    }
145
146    /// Produce a single contiguous flat buffer in row-major order.
147    ///
148    /// Row `i` occupies `buf[i * ncols..(i + 1) * ncols]`. This is a single
149    /// allocation versus `nrows` allocations from `rows()`, and gives better
150    /// cache locality for row-oriented iteration.
151    pub fn to_row_major(&self) -> Vec<f64> {
152        let mut buf = vec![0.0; self.nrows * self.ncols];
153        for i in 0..self.nrows {
154            for j in 0..self.ncols {
155                buf[i * self.ncols + j] = self.data[i + j * self.nrows];
156            }
157        }
158        buf
159    }
160
161    /// Flat slice of the underlying column-major data (zero-copy).
162    #[inline]
163    pub fn as_slice(&self) -> &[f64] {
164        &self.data
165    }
166
167    /// Mutable flat slice of the underlying column-major data.
168    #[inline]
169    pub fn as_mut_slice(&mut self) -> &mut [f64] {
170        &mut self.data
171    }
172
173    /// Consume and return the underlying `Vec<f64>`.
174    pub fn into_vec(self) -> Vec<f64> {
175        self.data
176    }
177
178    /// Convert to a nalgebra `DMatrix<f64>`.
179    ///
180    /// This copies the data into nalgebra's storage. Both use column-major
181    /// layout, so the copy is a simple memcpy.
182    pub fn to_dmatrix(&self) -> DMatrix<f64> {
183        DMatrix::from_column_slice(self.nrows, self.ncols, &self.data)
184    }
185
186    /// Create from a nalgebra `DMatrix<f64>`.
187    ///
188    /// Both use column-major layout so this is a direct copy.
189    pub fn from_dmatrix(mat: &DMatrix<f64>) -> Self {
190        let (nrows, ncols) = mat.shape();
191        Self {
192            data: mat.as_slice().to_vec(),
193            nrows,
194            ncols,
195        }
196    }
197
198    /// Get element at (row, col) with bounds checking.
199    #[inline]
200    pub fn get(&self, row: usize, col: usize) -> Option<f64> {
201        if row < self.nrows && col < self.ncols {
202            Some(self.data[row + col * self.nrows])
203        } else {
204            None
205        }
206    }
207
208    /// Set element at (row, col) with bounds checking.
209    #[inline]
210    pub fn set(&mut self, row: usize, col: usize, value: f64) -> bool {
211        if row < self.nrows && col < self.ncols {
212            self.data[row + col * self.nrows] = value;
213            true
214        } else {
215            false
216        }
217    }
218}
219
220impl std::ops::Index<(usize, usize)> for FdMatrix {
221    type Output = f64;
222
223    #[inline]
224    fn index(&self, (row, col): (usize, usize)) -> &f64 {
225        debug_assert!(
226            row < self.nrows && col < self.ncols,
227            "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
228            row,
229            col,
230            self.nrows,
231            self.ncols
232        );
233        &self.data[row + col * self.nrows]
234    }
235}
236
237impl std::ops::IndexMut<(usize, usize)> for FdMatrix {
238    #[inline]
239    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut f64 {
240        debug_assert!(
241            row < self.nrows && col < self.ncols,
242            "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
243            row,
244            col,
245            self.nrows,
246            self.ncols
247        );
248        &mut self.data[row + col * self.nrows]
249    }
250}
251
252impl std::fmt::Display for FdMatrix {
253    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        write!(f, "FdMatrix({}x{})", self.nrows, self.ncols)
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x4 fixture whose flat column-major buffer is `1.0..=12.0`, so element
    /// `(i, j)` equals `1 + i + 3 * j`.
    fn sample_3x4() -> FdMatrix {
        // 3 rows, 4 columns, column-major
        let data = vec![
            1.0, 2.0, 3.0, // col 0
            4.0, 5.0, 6.0, // col 1
            7.0, 8.0, 9.0, // col 2
            10.0, 11.0, 12.0, // col 3
        ];
        FdMatrix::from_column_major(data, 3, 4).unwrap()
    }

    #[test]
    fn test_from_column_major_valid() {
        let mat = sample_3x4();
        assert_eq!(mat.nrows(), 3);
        assert_eq!(mat.ncols(), 4);
        assert_eq!(mat.shape(), (3, 4));
        assert_eq!(mat.len(), 12);
        assert!(!mat.is_empty());
    }

    #[test]
    fn test_from_column_major_invalid() {
        assert!(FdMatrix::from_column_major(vec![1.0, 2.0], 3, 4).is_none());
    }

    #[test]
    fn test_from_column_major_zero_dim_rejects_data() {
        // A zero-sized shape must still be paired with an empty buffer.
        assert!(FdMatrix::from_column_major(vec![1.0], 0, 5).is_none());
        assert!(FdMatrix::from_column_major(Vec::new(), 0, 5).is_some());
        assert!(FdMatrix::from_slice(&[], 0, 5).is_some());
    }

    #[test]
    fn test_from_slice() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mat = FdMatrix::from_slice(&data, 2, 3).unwrap();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
    }

    #[test]
    fn test_from_slice_invalid() {
        assert!(FdMatrix::from_slice(&[1.0, 2.0], 3, 3).is_none());
    }

    #[test]
    fn test_zeros() {
        let mat = FdMatrix::zeros(2, 3);
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 3);
        for j in 0..3 {
            for i in 0..2 {
                assert_eq!(mat[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn test_index() {
        let mat = sample_3x4();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(2, 0)], 3.0);
        assert_eq!(mat[(0, 1)], 4.0);
        assert_eq!(mat[(1, 1)], 5.0);
        assert_eq!(mat[(2, 3)], 12.0);
    }

    #[test]
    fn test_index_mut() {
        let mut mat = sample_3x4();
        mat[(1, 2)] = 99.0;
        assert_eq!(mat[(1, 2)], 99.0);
    }

    #[test]
    fn test_column() {
        let mat = sample_3x4();
        assert_eq!(mat.column(0), &[1.0, 2.0, 3.0]);
        assert_eq!(mat.column(1), &[4.0, 5.0, 6.0]);
        assert_eq!(mat.column(3), &[10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_column_mut() {
        let mut mat = sample_3x4();
        mat.column_mut(1)[0] = 99.0;
        assert_eq!(mat[(0, 1)], 99.0);
    }

    #[test]
    fn test_row() {
        let mat = sample_3x4();
        assert_eq!(mat.row(0), vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(mat.row(1), vec![2.0, 5.0, 8.0, 11.0]);
        assert_eq!(mat.row(2), vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_rows() {
        let mat = sample_3x4();
        let rows = mat.rows();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(rows[2], vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_to_row_major() {
        let mat = sample_3x4();
        assert_eq!(
            mat.to_row_major(),
            vec![1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0]
        );
    }

    #[test]
    fn test_to_row_major_matches_rows() {
        // The single-buffer layout must agree with the per-row extraction.
        let mat = sample_3x4();
        let flattened: Vec<f64> = mat.rows().into_iter().flatten().collect();
        assert_eq!(mat.to_row_major(), flattened);
    }

    #[test]
    fn test_as_slice() {
        let mat = sample_3x4();
        let expected = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        assert_eq!(mat.as_slice(), expected.as_slice());
    }

    #[test]
    fn test_into_vec() {
        let mat = sample_3x4();
        let v = mat.into_vec();
        assert_eq!(v.len(), 12);
        assert_eq!(v[0], 1.0);
    }

    #[test]
    fn test_get_bounds_check() {
        let mat = sample_3x4();
        assert_eq!(mat.get(0, 0), Some(1.0));
        assert_eq!(mat.get(2, 3), Some(12.0));
        assert_eq!(mat.get(3, 0), None); // row out of bounds
        assert_eq!(mat.get(0, 4), None); // col out of bounds
    }

    #[test]
    fn test_set_bounds_check() {
        let mut mat = sample_3x4();
        assert!(mat.set(1, 1, 99.0));
        assert_eq!(mat[(1, 1)], 99.0);
        assert!(!mat.set(5, 0, 99.0)); // out of bounds
    }

    #[test]
    fn test_set_column_out_of_bounds_leaves_data_untouched() {
        let mut mat = sample_3x4();
        assert!(!mat.set(0, 4, 99.0));
        assert_eq!(mat, sample_3x4());
    }

    #[test]
    fn test_nalgebra_roundtrip() {
        let mat = sample_3x4();
        let dmat = mat.to_dmatrix();
        assert_eq!(dmat.nrows(), 3);
        assert_eq!(dmat.ncols(), 4);
        assert_eq!(dmat[(0, 0)], 1.0);
        assert_eq!(dmat[(1, 2)], 8.0);

        let back = FdMatrix::from_dmatrix(&dmat);
        assert_eq!(mat, back);
    }

    #[test]
    fn test_empty() {
        let mat = FdMatrix::zeros(0, 0);
        assert!(mat.is_empty());
        assert_eq!(mat.len(), 0);
    }

    #[test]
    fn test_zero_columns() {
        // Rows exist but carry no samples: everything row-shaped is empty.
        let mat = FdMatrix::zeros(3, 0);
        assert_eq!(mat.shape(), (3, 0));
        assert!(mat.is_empty());
        assert!(mat.row(1).is_empty());
        assert_eq!(mat.rows().len(), 3);
        assert!(mat.to_row_major().is_empty());
        assert_eq!(mat.get(0, 0), None);
    }

    #[test]
    fn test_single_element() {
        let mat = FdMatrix::from_column_major(vec![42.0], 1, 1).unwrap();
        assert_eq!(mat[(0, 0)], 42.0);
        assert_eq!(mat.column(0), &[42.0]);
        assert_eq!(mat.row(0), vec![42.0]);
    }

    #[test]
    fn test_display() {
        let mat = sample_3x4();
        assert_eq!(format!("{}", mat), "FdMatrix(3x4)");
    }

    #[test]
    fn test_clone() {
        let mat = sample_3x4();
        let cloned = mat.clone();
        assert_eq!(mat, cloned);
    }

    #[test]
    fn test_as_mut_slice() {
        let mut mat = FdMatrix::zeros(2, 2);
        let s = mat.as_mut_slice();
        s[0] = 1.0;
        s[1] = 2.0;
        s[2] = 3.0;
        s[3] = 4.0;
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
        assert_eq!(mat[(1, 1)], 4.0);
    }

    #[test]
    fn test_column_major_layout_matches_manual() {
        // Verify that FdMatrix[(i, j)] == data[i + j * n] for all i, j
        let n = 5;
        let m = 7;
        let data: Vec<f64> = (0..n * m).map(|x| x as f64).collect();
        let mat = FdMatrix::from_column_major(data.clone(), n, m).unwrap();

        for j in 0..m {
            for i in 0..n {
                assert_eq!(mat[(i, j)], data[i + j * n]);
            }
        }
    }
}