// Source file: fdars_core/matrix.rs
//! Column-major matrix type for functional data analysis.
//!
//! [`FdMatrix`] provides safe, dimension-tracked access to the flat column-major
//! data layout used throughout this crate. It eliminates manual `data[i + j * n]`
//! index arithmetic and carries dimensions alongside the data.

7use nalgebra::DMatrix;
8
/// Dense `f64` matrix stored column by column, as used for functional data.
///
/// All elements live in one flat `Vec<f64>` in column-major (Fortran) order:
/// entry `(row, col)` sits at flat offset `row + col * nrows`, so every column
/// is a contiguous slice while rows are strided.
///
/// # Conventions
///
/// Rows usually hold observations (curves) and columns hold evaluation
/// points. A surface observed on an `m1 x m2` grid is flattened into
/// `m1 * m2` columns.
///
/// # Examples
///
/// ```
/// use fdars_core::matrix::FdMatrix;
///
/// // Two observations evaluated at three points, written column by column.
/// let mat = FdMatrix::from_column_major(vec![
///     0.5, 1.5, // point 0
///     2.5, 3.5, // point 1
///     4.5, 5.5, // point 2
/// ], 2, 3).unwrap();
///
/// assert_eq!(mat.shape(), (2, 3));
/// assert_eq!(mat[(1, 0)], 1.5); // observation 1 at point 0
/// assert_eq!(mat[(0, 2)], 4.5); // observation 0 at point 2
/// assert_eq!(mat.column(1), &[2.5, 3.5]);
/// ```
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FdMatrix {
    /// Flat column-major storage of length `nrows * ncols`.
    ///
    /// NOTE(review): the constructors enforce that length, but the derived
    /// `Deserialize` (with the `serde` feature) does not re-check it.
    data: Vec<f64>,
    /// Number of rows (typically observations).
    nrows: usize,
    /// Number of columns (typically evaluation points).
    ncols: usize,
}

46impl FdMatrix {
47    /// Create from flat column-major data with dimension validation.
48    ///
49    /// Returns `Err` if `data.len() != nrows * ncols`.
50    pub fn from_column_major(
51        data: Vec<f64>,
52        nrows: usize,
53        ncols: usize,
54    ) -> Result<Self, crate::FdarError> {
55        if data.len() != nrows * ncols {
56            return Err(crate::FdarError::InvalidDimension {
57                parameter: "data",
58                expected: format!("{}", nrows * ncols),
59                actual: format!("{}", data.len()),
60            });
61        }
62        Ok(Self { data, nrows, ncols })
63    }
64
65    /// Create from a borrowed slice (copies the data).
66    ///
67    /// Returns `Err` if `data.len() != nrows * ncols`.
68    pub fn from_slice(data: &[f64], nrows: usize, ncols: usize) -> Result<Self, crate::FdarError> {
69        if data.len() != nrows * ncols {
70            return Err(crate::FdarError::InvalidDimension {
71                parameter: "data",
72                expected: format!("{}", nrows * ncols),
73                actual: format!("{}", data.len()),
74            });
75        }
76        Ok(Self {
77            data: data.to_vec(),
78            nrows,
79            ncols,
80        })
81    }
82
83    /// Create a zero-filled matrix.
84    pub fn zeros(nrows: usize, ncols: usize) -> Self {
85        Self {
86            data: vec![0.0; nrows * ncols],
87            nrows,
88            ncols,
89        }
90    }
91
92    /// Number of rows.
93    #[inline]
94    pub fn nrows(&self) -> usize {
95        self.nrows
96    }
97
98    /// Number of columns.
99    #[inline]
100    pub fn ncols(&self) -> usize {
101        self.ncols
102    }
103
104    /// Dimensions as `(nrows, ncols)`.
105    #[inline]
106    pub fn shape(&self) -> (usize, usize) {
107        (self.nrows, self.ncols)
108    }
109
110    /// Total number of elements.
111    #[inline]
112    pub fn len(&self) -> usize {
113        self.data.len()
114    }
115
116    /// Whether the matrix is empty.
117    #[inline]
118    pub fn is_empty(&self) -> bool {
119        self.data.is_empty()
120    }
121
122    /// Get a contiguous column slice (zero-copy).
123    ///
124    /// # Panics
125    /// Panics if `col >= ncols`.
126    #[inline]
127    pub fn column(&self, col: usize) -> &[f64] {
128        let start = col * self.nrows;
129        &self.data[start..start + self.nrows]
130    }
131
132    /// Get a mutable contiguous column slice (zero-copy).
133    ///
134    /// # Panics
135    /// Panics if `col >= ncols`.
136    #[inline]
137    pub fn column_mut(&mut self, col: usize) -> &mut [f64] {
138        let start = col * self.nrows;
139        &mut self.data[start..start + self.nrows]
140    }
141
142    /// Extract a single row as a new `Vec<f64>`.
143    ///
144    /// This is an O(ncols) operation because rows are not contiguous
145    /// in column-major layout.
146    pub fn row(&self, row: usize) -> Vec<f64> {
147        (0..self.ncols)
148            .map(|j| self.data[row + j * self.nrows])
149            .collect()
150    }
151
152    /// Copy a single row into a pre-allocated buffer (zero allocation).
153    ///
154    /// # Panics
155    /// Panics if `buf.len() < ncols` or `row >= nrows`.
156    #[inline]
157    pub fn row_to_buf(&self, row: usize, buf: &mut [f64]) {
158        debug_assert!(
159            row < self.nrows,
160            "row {row} out of bounds for {} rows",
161            self.nrows
162        );
163        debug_assert!(
164            buf.len() >= self.ncols,
165            "buffer len {} < ncols {}",
166            buf.len(),
167            self.ncols
168        );
169        let n = self.nrows;
170        for j in 0..self.ncols {
171            buf[j] = self.data[row + j * n];
172        }
173    }
174
175    /// Compute the dot product of two rows without materializing either one.
176    ///
177    /// The rows may come from different matrices (which must have the same `ncols`).
178    ///
179    /// # Panics
180    /// Panics (in debug) if `row_a >= self.nrows`, `row_b >= other.nrows`,
181    /// or `self.ncols != other.ncols`.
182    #[inline]
183    pub fn row_dot(&self, row_a: usize, other: &FdMatrix, row_b: usize) -> f64 {
184        debug_assert_eq!(self.ncols, other.ncols, "ncols mismatch in row_dot");
185        let na = self.nrows;
186        let nb = other.nrows;
187        let mut sum = 0.0;
188        for j in 0..self.ncols {
189            sum += self.data[row_a + j * na] * other.data[row_b + j * nb];
190        }
191        sum
192    }
193
194    /// Compute the squared L2 distance between two rows without allocation.
195    ///
196    /// Equivalent to `||self.row(row_a) - other.row(row_b)||^2` but without
197    /// materializing either row vector.
198    ///
199    /// # Panics
200    /// Panics (in debug) if `row_a >= self.nrows`, `row_b >= other.nrows`,
201    /// or `self.ncols != other.ncols`.
202    #[inline]
203    pub fn row_l2_sq(&self, row_a: usize, other: &FdMatrix, row_b: usize) -> f64 {
204        debug_assert_eq!(self.ncols, other.ncols, "ncols mismatch in row_l2_sq");
205        let na = self.nrows;
206        let nb = other.nrows;
207        let mut sum = 0.0;
208        for j in 0..self.ncols {
209            let d = self.data[row_a + j * na] - other.data[row_b + j * nb];
210            sum += d * d;
211        }
212        sum
213    }
214
215    /// Iterate over rows, yielding each row as a `Vec<f64>`.
216    ///
217    /// More efficient than [`to_row_major()`](Self::to_row_major) when only a
218    /// subset of rows are needed or when processing rows one at a time, because
219    /// it materializes only one row at a time instead of allocating the entire
220    /// transposed matrix up front.
221    ///
222    /// Because `FdMatrix` uses column-major storage, row elements are not
223    /// contiguous and a zero-copy row slice is not possible. Each yielded
224    /// `Vec<f64>` is an O(ncols) allocation.
225    ///
226    /// # Examples
227    ///
228    /// ```
229    /// use fdars_core::matrix::FdMatrix;
230    ///
231    /// let mat = FdMatrix::from_column_major(vec![
232    ///     1.0, 2.0,   // col 0
233    ///     3.0, 4.0,   // col 1
234    ///     5.0, 6.0,   // col 2
235    /// ], 2, 3).unwrap();
236    ///
237    /// let rows: Vec<Vec<f64>> = mat.iter_rows().collect();
238    /// assert_eq!(rows, vec![vec![1.0, 3.0, 5.0], vec![2.0, 4.0, 6.0]]);
239    /// ```
240    pub fn iter_rows(&self) -> impl Iterator<Item = Vec<f64>> + '_ {
241        (0..self.nrows).map(move |i| self.row(i))
242    }
243
244    /// Iterate over columns, yielding each column as a slice `&[f64]`.
245    ///
246    /// Zero-copy because `FdMatrix` uses column-major storage and each column
247    /// is a contiguous block in memory.
248    ///
249    /// # Examples
250    ///
251    /// ```
252    /// use fdars_core::matrix::FdMatrix;
253    ///
254    /// let mat = FdMatrix::from_column_major(vec![
255    ///     1.0, 2.0,   // col 0
256    ///     3.0, 4.0,   // col 1
257    ///     5.0, 6.0,   // col 2
258    /// ], 2, 3).unwrap();
259    ///
260    /// let cols: Vec<&[f64]> = mat.iter_columns().collect();
261    /// assert_eq!(cols, vec![&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]);
262    /// ```
263    pub fn iter_columns(&self) -> impl Iterator<Item = &[f64]> {
264        (0..self.ncols).map(move |j| self.column(j))
265    }
266
267    /// Extract all rows as `Vec<Vec<f64>>`.
268    ///
269    /// Equivalent to the former `extract_curves` function.
270    pub fn rows(&self) -> Vec<Vec<f64>> {
271        (0..self.nrows).map(|i| self.row(i)).collect()
272    }
273
274    /// Produce a single contiguous flat buffer in row-major order.
275    ///
276    /// Row `i` occupies `buf[i * ncols..(i + 1) * ncols]`. This is a single
277    /// allocation versus `nrows` allocations from `rows()`, and gives better
278    /// cache locality for row-oriented iteration.
279    pub fn to_row_major(&self) -> Vec<f64> {
280        let mut buf = vec![0.0; self.nrows * self.ncols];
281        for i in 0..self.nrows {
282            for j in 0..self.ncols {
283                buf[i * self.ncols + j] = self.data[i + j * self.nrows];
284            }
285        }
286        buf
287    }
288
289    /// Flat slice of the underlying column-major data (zero-copy).
290    #[inline]
291    pub fn as_slice(&self) -> &[f64] {
292        &self.data
293    }
294
295    /// Mutable flat slice of the underlying column-major data.
296    #[inline]
297    pub fn as_mut_slice(&mut self) -> &mut [f64] {
298        &mut self.data
299    }
300
301    /// Consume and return the underlying `Vec<f64>`.
302    pub fn into_vec(self) -> Vec<f64> {
303        self.data
304    }
305
306    /// Convert to a nalgebra `DMatrix<f64>`.
307    ///
308    /// This copies the data into nalgebra's storage. Both use column-major
309    /// layout, so the copy is a simple memcpy.
310    pub fn to_dmatrix(&self) -> DMatrix<f64> {
311        DMatrix::from_column_slice(self.nrows, self.ncols, &self.data)
312    }
313
314    /// Create from a nalgebra `DMatrix<f64>`.
315    ///
316    /// Both use column-major layout so this is a direct copy.
317    pub fn from_dmatrix(mat: &DMatrix<f64>) -> Self {
318        let (nrows, ncols) = mat.shape();
319        Self {
320            data: mat.as_slice().to_vec(),
321            nrows,
322            ncols,
323        }
324    }
325
326    /// Get element at (row, col) with bounds checking.
327    #[inline]
328    pub fn get(&self, row: usize, col: usize) -> Option<f64> {
329        if row < self.nrows && col < self.ncols {
330            Some(self.data[row + col * self.nrows])
331        } else {
332            None
333        }
334    }
335
336    /// Extract a submatrix by row and column indices.
337    ///
338    /// Returns a new `FdMatrix` containing only the specified rows and columns.
339    ///
340    /// # Examples
341    ///
342    /// ```
343    /// use fdars_core::matrix::FdMatrix;
344    ///
345    /// let mat = FdMatrix::from_column_major(
346    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
347    ///     3, 3,
348    /// ).unwrap();
349    ///
350    /// // Extract rows [0, 2] and columns [1, 2]
351    /// let sub = mat.submatrix(&[0, 2], &[1, 2]);
352    /// assert_eq!(sub.shape(), (2, 2));
353    /// assert_eq!(sub[(0, 0)], 4.0); // row 0, col 1 of original
354    /// assert_eq!(sub[(1, 1)], 9.0); // row 2, col 2 of original
355    /// ```
356    pub fn submatrix(&self, rows: &[usize], cols: &[usize]) -> Self {
357        let nr = rows.len();
358        let nc = cols.len();
359        let mut data = vec![0.0; nr * nc];
360        for (jj, &col) in cols.iter().enumerate() {
361            for (ii, &row) in rows.iter().enumerate() {
362                data[ii + jj * nr] = self.data[row + col * self.nrows];
363            }
364        }
365        Self {
366            data,
367            nrows: nr,
368            ncols: nc,
369        }
370    }
371
372    /// Extract a submatrix selecting only specific rows (all columns kept).
373    ///
374    /// Equivalent to `submatrix(rows, &(0..ncols).collect::<Vec<_>>())` but
375    /// more efficient.
376    pub fn select_rows(&self, rows: &[usize]) -> Self {
377        let nr = rows.len();
378        let nc = self.ncols;
379        let mut data = vec![0.0; nr * nc];
380        for j in 0..nc {
381            for (ii, &row) in rows.iter().enumerate() {
382                data[ii + j * nr] = self.data[row + j * self.nrows];
383            }
384        }
385        Self {
386            data,
387            nrows: nr,
388            ncols: nc,
389        }
390    }
391
392    /// Extract a submatrix selecting only specific columns (all rows kept).
393    ///
394    /// Efficient for column-major layout — each column is a contiguous copy.
395    pub fn select_columns(&self, cols: &[usize]) -> Self {
396        let nr = self.nrows;
397        let nc = cols.len();
398        let mut data = vec![0.0; nr * nc];
399        for (jj, &col) in cols.iter().enumerate() {
400            let src = &self.data[col * nr..(col + 1) * nr];
401            data[jj * nr..(jj + 1) * nr].copy_from_slice(src);
402        }
403        Self {
404            data,
405            nrows: nr,
406            ncols: nc,
407        }
408    }
409
410    /// Set element at (row, col) with bounds checking.
411    #[inline]
412    pub fn set(&mut self, row: usize, col: usize, value: f64) -> bool {
413        if row < self.nrows && col < self.ncols {
414            self.data[row + col * self.nrows] = value;
415            true
416        } else {
417            false
418        }
419    }
420}
421
422impl std::ops::Index<(usize, usize)> for FdMatrix {
423    type Output = f64;
424
425    #[inline]
426    fn index(&self, (row, col): (usize, usize)) -> &f64 {
427        debug_assert!(
428            row < self.nrows && col < self.ncols,
429            "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
430            row,
431            col,
432            self.nrows,
433            self.ncols
434        );
435        &self.data[row + col * self.nrows]
436    }
437}
438
439impl std::ops::IndexMut<(usize, usize)> for FdMatrix {
440    #[inline]
441    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut f64 {
442        debug_assert!(
443            row < self.nrows && col < self.ncols,
444            "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
445            row,
446            col,
447            self.nrows,
448            self.ncols
449        );
450        &mut self.data[row + col * self.nrows]
451    }
452}
453
454impl std::fmt::Display for FdMatrix {
455    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456        write!(f, "FdMatrix({}x{})", self.nrows, self.ncols)
457    }
458}
459
460/// A set of multidimensional functional curves in R^d.
461///
462/// Each dimension is stored as a separate [`FdMatrix`] (n curves × m points).
463/// For d=1 this is equivalent to a single `FdMatrix`.
464#[derive(Debug, Clone, PartialEq)]
465pub struct FdCurveSet {
466    /// One matrix per coordinate dimension, each n × m.
467    pub dims: Vec<FdMatrix>,
468}
469
470impl FdCurveSet {
471    /// Number of coordinate dimensions (d).
472    pub fn ndim(&self) -> usize {
473        self.dims.len()
474    }
475
476    /// Number of curves (n).
477    pub fn ncurves(&self) -> usize {
478        if self.dims.is_empty() {
479            0
480        } else {
481            self.dims[0].nrows()
482        }
483    }
484
485    /// Number of evaluation points (m).
486    pub fn npoints(&self) -> usize {
487        if self.dims.is_empty() {
488            0
489        } else {
490            self.dims[0].ncols()
491        }
492    }
493
494    /// Wrap a single 1D `FdMatrix` as a `FdCurveSet`.
495    pub fn from_1d(data: FdMatrix) -> Self {
496        Self { dims: vec![data] }
497    }
498
499    /// Build from multiple dimension matrices.
500    ///
501    /// Returns `Err` if `dims` is empty or if dimensions are inconsistent.
502    pub fn from_dims(dims: Vec<FdMatrix>) -> Result<Self, crate::FdarError> {
503        if dims.is_empty() {
504            return Err(crate::FdarError::InvalidDimension {
505                parameter: "dims",
506                expected: "non-empty".to_string(),
507                actual: "empty".to_string(),
508            });
509        }
510        let (n, m) = dims[0].shape();
511        if dims.iter().any(|d| d.shape() != (n, m)) {
512            return Err(crate::FdarError::InvalidDimension {
513                parameter: "dims",
514                expected: format!("all ({n}, {m})"),
515                actual: "inconsistent shapes".to_string(),
516            });
517        }
518        Ok(Self { dims })
519    }
520
521    /// Extract the R^d point for a given curve and time index.
522    pub fn point(&self, curve: usize, time_idx: usize) -> Vec<f64> {
523        self.dims.iter().map(|d| d[(curve, time_idx)]).collect()
524    }
525}
526
527impl std::fmt::Display for FdCurveSet {
528    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
529        write!(
530            f,
531            "FdCurveSet(d={}, n={}, m={})",
532            self.ndim(),
533            self.ncurves(),
534            self.npoints()
535        )
536    }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// 3 x 4 fixture whose entries are 1..=12 in column-major order.
    fn sample_3x4() -> FdMatrix {
        // 3 rows, 4 columns, column-major
        let data = vec![
            1.0, 2.0, 3.0, // col 0
            4.0, 5.0, 6.0, // col 1
            7.0, 8.0, 9.0, // col 2
            10.0, 11.0, 12.0, // col 3
        ];
        FdMatrix::from_column_major(data, 3, 4).unwrap()
    }

    #[test]
    fn test_from_column_major_valid() {
        let mat = sample_3x4();
        assert_eq!(mat.nrows(), 3);
        assert_eq!(mat.ncols(), 4);
        assert_eq!(mat.shape(), (3, 4));
        assert_eq!(mat.len(), 12);
        assert!(!mat.is_empty());
    }

    #[test]
    fn test_from_column_major_invalid() {
        assert!(FdMatrix::from_column_major(vec![1.0, 2.0], 3, 4).is_err());
    }

    #[test]
    fn test_from_slice() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mat = FdMatrix::from_slice(&data, 2, 3).unwrap();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
    }

    #[test]
    fn test_from_slice_invalid() {
        assert!(FdMatrix::from_slice(&[1.0, 2.0], 3, 3).is_err());
    }

    #[test]
    fn test_zeros() {
        let mat = FdMatrix::zeros(2, 3);
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 3);
        for j in 0..3 {
            for i in 0..2 {
                assert_eq!(mat[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn test_index() {
        let mat = sample_3x4();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(2, 0)], 3.0);
        assert_eq!(mat[(0, 1)], 4.0);
        assert_eq!(mat[(1, 1)], 5.0);
        assert_eq!(mat[(2, 3)], 12.0);
    }

    #[test]
    #[should_panic]
    fn test_index_column_out_of_bounds_panics() {
        let mat = sample_3x4();
        let _value: f64 = mat[(0, 4)];
    }

    #[test]
    fn test_index_mut() {
        let mut mat = sample_3x4();
        mat[(1, 2)] = 99.0;
        assert_eq!(mat[(1, 2)], 99.0);
    }

    #[test]
    fn test_column() {
        let mat = sample_3x4();
        assert_eq!(mat.column(0), &[1.0, 2.0, 3.0]);
        assert_eq!(mat.column(1), &[4.0, 5.0, 6.0]);
        assert_eq!(mat.column(3), &[10.0, 11.0, 12.0]);
    }

    #[test]
    #[should_panic]
    fn test_column_out_of_bounds_panics() {
        let mat = sample_3x4();
        let _col = mat.column(4);
    }

    #[test]
    fn test_column_mut() {
        let mut mat = sample_3x4();
        mat.column_mut(1)[0] = 99.0;
        assert_eq!(mat[(0, 1)], 99.0);
    }

    #[test]
    fn test_row() {
        let mat = sample_3x4();
        assert_eq!(mat.row(0), vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(mat.row(1), vec![2.0, 5.0, 8.0, 11.0]);
        assert_eq!(mat.row(2), vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_rows() {
        let mat = sample_3x4();
        let rows = mat.rows();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(rows[2], vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_as_slice() {
        let mat = sample_3x4();
        let expected = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        assert_eq!(mat.as_slice(), expected.as_slice());
    }

    #[test]
    fn test_into_vec() {
        let mat = sample_3x4();
        let v = mat.into_vec();
        assert_eq!(v.len(), 12);
        assert_eq!(v[0], 1.0);
    }

    #[test]
    fn test_get_bounds_check() {
        let mat = sample_3x4();
        assert_eq!(mat.get(0, 0), Some(1.0));
        assert_eq!(mat.get(2, 3), Some(12.0));
        assert_eq!(mat.get(3, 0), None); // row out of bounds
        assert_eq!(mat.get(0, 4), None); // col out of bounds
    }

    #[test]
    fn test_set_bounds_check() {
        let mut mat = sample_3x4();
        assert!(mat.set(1, 1, 99.0));
        assert_eq!(mat[(1, 1)], 99.0);
        assert!(!mat.set(5, 0, 99.0)); // out of bounds
    }

    #[test]
    fn test_nalgebra_roundtrip() {
        let mat = sample_3x4();
        let dmat = mat.to_dmatrix();
        assert_eq!(dmat.nrows(), 3);
        assert_eq!(dmat.ncols(), 4);
        assert_eq!(dmat[(0, 0)], 1.0);
        assert_eq!(dmat[(1, 2)], 8.0);

        let back = FdMatrix::from_dmatrix(&dmat);
        assert_eq!(mat, back);
    }

    #[test]
    fn test_empty() {
        let mat = FdMatrix::zeros(0, 0);
        assert!(mat.is_empty());
        assert_eq!(mat.len(), 0);
    }

    #[test]
    fn test_single_element() {
        let mat = FdMatrix::from_column_major(vec![42.0], 1, 1).unwrap();
        assert_eq!(mat[(0, 0)], 42.0);
        assert_eq!(mat.column(0), &[42.0]);
        assert_eq!(mat.row(0), vec![42.0]);
    }

    #[test]
    fn test_display() {
        let mat = sample_3x4();
        assert_eq!(format!("{}", mat), "FdMatrix(3x4)");
    }

    #[test]
    fn test_clone() {
        let mat = sample_3x4();
        let cloned = mat.clone();
        assert_eq!(mat, cloned);
    }

    #[test]
    fn test_as_mut_slice() {
        let mut mat = FdMatrix::zeros(2, 2);
        let s = mat.as_mut_slice();
        s[0] = 1.0;
        s[1] = 2.0;
        s[2] = 3.0;
        s[3] = 4.0;
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
        assert_eq!(mat[(1, 1)], 4.0);
    }

    #[test]
    fn test_submatrix() {
        let mat = FdMatrix::from_column_major(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            3,
            3,
        )
        .unwrap();
        let sub = mat.submatrix(&[0, 2], &[1, 2]);
        assert_eq!(sub.shape(), (2, 2));
        assert_eq!(sub.as_slice(), &[4.0, 6.0, 7.0, 9.0]);
    }

    #[test]
    fn test_select_rows() {
        let mat = sample_3x4();
        let sel = mat.select_rows(&[2, 0]);
        assert_eq!(sel.shape(), (2, 4));
        assert_eq!(sel.row(0), vec![3.0, 6.0, 9.0, 12.0]);
        assert_eq!(sel.row(1), vec![1.0, 4.0, 7.0, 10.0]);
    }

    #[test]
    fn test_select_rows_duplicates_and_empty() {
        let mat = sample_3x4();
        let dup = mat.select_rows(&[1, 1]);
        assert_eq!(dup.row(0), dup.row(1));
        assert_eq!(dup.row(0), vec![2.0, 5.0, 8.0, 11.0]);

        let none = mat.select_rows(&[]);
        assert_eq!(none.shape(), (0, 4));
        assert!(none.is_empty());
    }

    #[test]
    fn test_select_columns() {
        let mat = sample_3x4();
        let sel = mat.select_columns(&[3, 1]);
        assert_eq!(sel.shape(), (3, 2));
        assert_eq!(sel.as_slice(), &[10.0, 11.0, 12.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_fd_curve_set_empty() {
        assert!(FdCurveSet::from_dims(vec![]).is_err());
        // `from_dims` refuses an empty set, but the public field still allows one.
        let cs = FdCurveSet { dims: vec![] };
        assert_eq!(cs.ndim(), 0);
        assert_eq!(cs.ncurves(), 0);
        assert_eq!(cs.npoints(), 0);
        assert_eq!(format!("{}", cs), "FdCurveSet(d=0, n=0, m=0)");
    }

    #[test]
    fn test_fd_curve_set_from_1d() {
        let mat = sample_3x4();
        let cs = FdCurveSet::from_1d(mat.clone());
        assert_eq!(cs.ndim(), 1);
        assert_eq!(cs.ncurves(), 3);
        assert_eq!(cs.npoints(), 4);
        assert_eq!(cs.point(0, 0), vec![1.0]);
        assert_eq!(cs.point(1, 2), vec![8.0]);
    }

    #[test]
    fn test_fd_curve_set_from_dims_consistent() {
        let m1 = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let m2 = FdMatrix::from_column_major(vec![5.0, 6.0, 7.0, 8.0], 2, 2).unwrap();
        let cs = FdCurveSet::from_dims(vec![m1, m2]).unwrap();
        assert_eq!(cs.ndim(), 2);
        assert_eq!(cs.point(0, 0), vec![1.0, 5.0]);
        assert_eq!(cs.point(1, 1), vec![4.0, 8.0]);
        assert_eq!(format!("{}", cs), "FdCurveSet(d=2, n=2, m=2)");
    }

    #[test]
    fn test_fd_curve_set_from_dims_inconsistent() {
        let m1 = FdMatrix::from_column_major(vec![1.0, 2.0], 2, 1).unwrap();
        let m2 = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
        assert!(FdCurveSet::from_dims(vec![m1, m2]).is_err());
    }

    #[test]
    fn test_to_row_major() {
        let mat = sample_3x4();
        let rm = mat.to_row_major();
        // Row 0: [1,4,7,10], Row 1: [2,5,8,11], Row 2: [3,6,9,12]
        assert_eq!(
            rm,
            vec![1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0]
        );
    }

    #[test]
    fn test_row_to_buf() {
        let mat = sample_3x4();
        let mut buf = vec![0.0; 4];
        mat.row_to_buf(0, &mut buf);
        assert_eq!(buf, vec![1.0, 4.0, 7.0, 10.0]);
        mat.row_to_buf(1, &mut buf);
        assert_eq!(buf, vec![2.0, 5.0, 8.0, 11.0]);
        mat.row_to_buf(2, &mut buf);
        assert_eq!(buf, vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_row_to_buf_larger_buffer() {
        let mat = sample_3x4();
        let mut buf = vec![99.0; 6]; // bigger than ncols
        mat.row_to_buf(0, &mut buf);
        assert_eq!(&buf[..4], &[1.0, 4.0, 7.0, 10.0]);
        // Remaining elements unchanged
        assert_eq!(buf[4], 99.0);
    }

    #[test]
    fn test_row_dot_same_matrix() {
        let mat = sample_3x4();
        // row0 = [1, 4, 7, 10], row1 = [2, 5, 8, 11]
        // dot = 1*2 + 4*5 + 7*8 + 10*11 = 2 + 20 + 56 + 110 = 188
        assert_eq!(mat.row_dot(0, &mat, 1), 188.0);
        // self dot: row0 . row0 = 1+16+49+100 = 166
        assert_eq!(mat.row_dot(0, &mat, 0), 166.0);
    }

    #[test]
    fn test_row_dot_different_matrices() {
        let mat1 = sample_3x4();
        let data2 = vec![
            10.0, 20.0, 30.0, // col 0
            40.0, 50.0, 60.0, // col 1
            70.0, 80.0, 90.0, // col 2
            100.0, 110.0, 120.0, // col 3
        ];
        let mat2 = FdMatrix::from_column_major(data2, 3, 4).unwrap();
        // mat1 row0 = [1, 4, 7, 10], mat2 row0 = [10, 40, 70, 100]
        // dot = 10 + 160 + 490 + 1000 = 1660
        assert_eq!(mat1.row_dot(0, &mat2, 0), 1660.0);
    }

    #[test]
    fn test_row_l2_sq_identical() {
        let mat = sample_3x4();
        assert_eq!(mat.row_l2_sq(0, &mat, 0), 0.0);
        assert_eq!(mat.row_l2_sq(1, &mat, 1), 0.0);
    }

    #[test]
    fn test_row_l2_sq_different() {
        let mat = sample_3x4();
        // row0 = [1,4,7,10], row1 = [2,5,8,11]
        // diff = [-1,-1,-1,-1], sq sum = 4
        assert_eq!(mat.row_l2_sq(0, &mat, 1), 4.0);
    }

    #[test]
    fn test_row_l2_sq_cross_matrix() {
        let mat1 = FdMatrix::from_column_major(vec![0.0, 0.0], 1, 2).unwrap();
        let mat2 = FdMatrix::from_column_major(vec![3.0, 4.0], 1, 2).unwrap();
        // row0 = [0, 0], row0 = [3, 4], sq dist = 9 + 16 = 25
        assert_eq!(mat1.row_l2_sq(0, &mat2, 0), 25.0);
    }

    #[test]
    fn test_column_major_layout_matches_manual() {
        // Verify that FdMatrix[(i, j)] == data[i + j * n] for all i, j
        let n = 5;
        let m = 7;
        let data: Vec<f64> = (0..n * m).map(|x| x as f64).collect();
        let mat = FdMatrix::from_column_major(data.clone(), n, m).unwrap();

        for j in 0..m {
            for i in 0..n {
                assert_eq!(mat[(i, j)], data[i + j * n]);
            }
        }
    }

    #[test]
    fn test_iter_rows() {
        let mat = sample_3x4();
        let rows: Vec<Vec<f64>> = mat.iter_rows().collect();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(rows[1], vec![2.0, 5.0, 8.0, 11.0]);
        assert_eq!(rows[2], vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_iter_rows_matches_rows() {
        let mat = sample_3x4();
        let from_iter: Vec<Vec<f64>> = mat.iter_rows().collect();
        let from_rows = mat.rows();
        assert_eq!(from_iter, from_rows);
    }

    #[test]
    fn test_iter_rows_partial() {
        // Verify that taking only a subset avoids full materialization
        let mat = sample_3x4();
        let first_two: Vec<Vec<f64>> = mat.iter_rows().take(2).collect();
        assert_eq!(first_two.len(), 2);
        assert_eq!(first_two[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(first_two[1], vec![2.0, 5.0, 8.0, 11.0]);
    }

    #[test]
    fn test_iter_rows_empty() {
        let mat = FdMatrix::zeros(0, 0);
        let rows: Vec<Vec<f64>> = mat.iter_rows().collect();
        assert!(rows.is_empty());
    }

    #[test]
    fn test_iter_rows_single_row() {
        let mat = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 1, 3).unwrap();
        let rows: Vec<Vec<f64>> = mat.iter_rows().collect();
        assert_eq!(rows, vec![vec![1.0, 2.0, 3.0]]);
    }

    #[test]
    fn test_iter_rows_single_column() {
        let mat = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
        let rows: Vec<Vec<f64>> = mat.iter_rows().collect();
        assert_eq!(rows, vec![vec![1.0], vec![2.0], vec![3.0]]);
    }

    #[test]
    fn test_iter_columns() {
        let mat = sample_3x4();
        let cols: Vec<&[f64]> = mat.iter_columns().collect();
        assert_eq!(cols.len(), 4);
        assert_eq!(cols[0], &[1.0, 2.0, 3.0]);
        assert_eq!(cols[1], &[4.0, 5.0, 6.0]);
        assert_eq!(cols[2], &[7.0, 8.0, 9.0]);
        assert_eq!(cols[3], &[10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_iter_columns_partial() {
        let mat = sample_3x4();
        let first_two: Vec<&[f64]> = mat.iter_columns().take(2).collect();
        assert_eq!(first_two.len(), 2);
        assert_eq!(first_two[0], &[1.0, 2.0, 3.0]);
        assert_eq!(first_two[1], &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_iter_columns_empty() {
        let mat = FdMatrix::zeros(0, 0);
        let cols: Vec<&[f64]> = mat.iter_columns().collect();
        assert!(cols.is_empty());
    }

    #[test]
    fn test_iter_columns_single_column() {
        let mat = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
        let cols: Vec<&[f64]> = mat.iter_columns().collect();
        assert_eq!(cols, vec![&[1.0, 2.0, 3.0]]);
    }

    #[test]
    fn test_iter_columns_single_row() {
        let mat = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 1, 3).unwrap();
        let cols: Vec<&[f64]> = mat.iter_columns().collect();
        assert_eq!(cols, vec![&[1.0_f64] as &[f64], &[2.0], &[3.0]]);
    }

    #[test]
    fn test_iter_rows_enumerate() {
        let mat = sample_3x4();
        for (i, row) in mat.iter_rows().enumerate() {
            assert_eq!(row, mat.row(i));
        }
    }

    #[test]
    fn test_iter_columns_enumerate() {
        let mat = sample_3x4();
        for (j, col) in mat.iter_columns().enumerate() {
            assert_eq!(col, mat.column(j));
        }
    }
}