1use nalgebra::DMatrix;
8
9#[derive(Debug, Clone, PartialEq)]
39pub struct FdMatrix {
40 data: Vec<f64>,
41 nrows: usize,
42 ncols: usize,
43}
44
45impl FdMatrix {
46 pub fn from_column_major(
50 data: Vec<f64>,
51 nrows: usize,
52 ncols: usize,
53 ) -> Result<Self, crate::FdarError> {
54 if data.len() != nrows * ncols {
55 return Err(crate::FdarError::InvalidDimension {
56 parameter: "data",
57 expected: format!("{}", nrows * ncols),
58 actual: format!("{}", data.len()),
59 });
60 }
61 Ok(Self { data, nrows, ncols })
62 }
63
64 pub fn from_slice(data: &[f64], nrows: usize, ncols: usize) -> Result<Self, crate::FdarError> {
68 if data.len() != nrows * ncols {
69 return Err(crate::FdarError::InvalidDimension {
70 parameter: "data",
71 expected: format!("{}", nrows * ncols),
72 actual: format!("{}", data.len()),
73 });
74 }
75 Ok(Self {
76 data: data.to_vec(),
77 nrows,
78 ncols,
79 })
80 }
81
82 pub fn zeros(nrows: usize, ncols: usize) -> Self {
84 Self {
85 data: vec![0.0; nrows * ncols],
86 nrows,
87 ncols,
88 }
89 }
90
91 #[inline]
93 pub fn nrows(&self) -> usize {
94 self.nrows
95 }
96
97 #[inline]
99 pub fn ncols(&self) -> usize {
100 self.ncols
101 }
102
103 #[inline]
105 pub fn shape(&self) -> (usize, usize) {
106 (self.nrows, self.ncols)
107 }
108
109 #[inline]
111 pub fn len(&self) -> usize {
112 self.data.len()
113 }
114
115 #[inline]
117 pub fn is_empty(&self) -> bool {
118 self.data.is_empty()
119 }
120
121 #[inline]
126 pub fn column(&self, col: usize) -> &[f64] {
127 let start = col * self.nrows;
128 &self.data[start..start + self.nrows]
129 }
130
131 #[inline]
136 pub fn column_mut(&mut self, col: usize) -> &mut [f64] {
137 let start = col * self.nrows;
138 &mut self.data[start..start + self.nrows]
139 }
140
141 pub fn row(&self, row: usize) -> Vec<f64> {
146 (0..self.ncols)
147 .map(|j| self.data[row + j * self.nrows])
148 .collect()
149 }
150
151 #[inline]
156 pub fn row_to_buf(&self, row: usize, buf: &mut [f64]) {
157 debug_assert!(
158 row < self.nrows,
159 "row {row} out of bounds for {} rows",
160 self.nrows
161 );
162 debug_assert!(
163 buf.len() >= self.ncols,
164 "buffer len {} < ncols {}",
165 buf.len(),
166 self.ncols
167 );
168 let n = self.nrows;
169 for j in 0..self.ncols {
170 buf[j] = self.data[row + j * n];
171 }
172 }
173
174 #[inline]
182 pub fn row_dot(&self, row_a: usize, other: &FdMatrix, row_b: usize) -> f64 {
183 debug_assert_eq!(self.ncols, other.ncols, "ncols mismatch in row_dot");
184 let na = self.nrows;
185 let nb = other.nrows;
186 let mut sum = 0.0;
187 for j in 0..self.ncols {
188 sum += self.data[row_a + j * na] * other.data[row_b + j * nb];
189 }
190 sum
191 }
192
193 #[inline]
202 pub fn row_l2_sq(&self, row_a: usize, other: &FdMatrix, row_b: usize) -> f64 {
203 debug_assert_eq!(self.ncols, other.ncols, "ncols mismatch in row_l2_sq");
204 let na = self.nrows;
205 let nb = other.nrows;
206 let mut sum = 0.0;
207 for j in 0..self.ncols {
208 let d = self.data[row_a + j * na] - other.data[row_b + j * nb];
209 sum += d * d;
210 }
211 sum
212 }
213
214 pub fn rows(&self) -> Vec<Vec<f64>> {
218 (0..self.nrows).map(|i| self.row(i)).collect()
219 }
220
221 pub fn to_row_major(&self) -> Vec<f64> {
227 let mut buf = vec![0.0; self.nrows * self.ncols];
228 for i in 0..self.nrows {
229 for j in 0..self.ncols {
230 buf[i * self.ncols + j] = self.data[i + j * self.nrows];
231 }
232 }
233 buf
234 }
235
236 #[inline]
238 pub fn as_slice(&self) -> &[f64] {
239 &self.data
240 }
241
242 #[inline]
244 pub fn as_mut_slice(&mut self) -> &mut [f64] {
245 &mut self.data
246 }
247
248 pub fn into_vec(self) -> Vec<f64> {
250 self.data
251 }
252
253 pub fn to_dmatrix(&self) -> DMatrix<f64> {
258 DMatrix::from_column_slice(self.nrows, self.ncols, &self.data)
259 }
260
261 pub fn from_dmatrix(mat: &DMatrix<f64>) -> Self {
265 let (nrows, ncols) = mat.shape();
266 Self {
267 data: mat.as_slice().to_vec(),
268 nrows,
269 ncols,
270 }
271 }
272
273 #[inline]
275 pub fn get(&self, row: usize, col: usize) -> Option<f64> {
276 if row < self.nrows && col < self.ncols {
277 Some(self.data[row + col * self.nrows])
278 } else {
279 None
280 }
281 }
282
283 #[inline]
285 pub fn set(&mut self, row: usize, col: usize, value: f64) -> bool {
286 if row < self.nrows && col < self.ncols {
287 self.data[row + col * self.nrows] = value;
288 true
289 } else {
290 false
291 }
292 }
293}
294
295impl std::ops::Index<(usize, usize)> for FdMatrix {
296 type Output = f64;
297
298 #[inline]
299 fn index(&self, (row, col): (usize, usize)) -> &f64 {
300 debug_assert!(
301 row < self.nrows && col < self.ncols,
302 "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
303 row,
304 col,
305 self.nrows,
306 self.ncols
307 );
308 &self.data[row + col * self.nrows]
309 }
310}
311
312impl std::ops::IndexMut<(usize, usize)> for FdMatrix {
313 #[inline]
314 fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut f64 {
315 debug_assert!(
316 row < self.nrows && col < self.ncols,
317 "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
318 row,
319 col,
320 self.nrows,
321 self.ncols
322 );
323 &mut self.data[row + col * self.nrows]
324 }
325}
326
327impl std::fmt::Display for FdMatrix {
328 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329 write!(f, "FdMatrix({}x{})", self.nrows, self.ncols)
330 }
331}
332
333#[derive(Debug, Clone, PartialEq)]
338pub struct FdCurveSet {
339 pub dims: Vec<FdMatrix>,
341}
342
343impl FdCurveSet {
344 pub fn ndim(&self) -> usize {
346 self.dims.len()
347 }
348
349 pub fn ncurves(&self) -> usize {
351 if self.dims.is_empty() {
352 0
353 } else {
354 self.dims[0].nrows()
355 }
356 }
357
358 pub fn npoints(&self) -> usize {
360 if self.dims.is_empty() {
361 0
362 } else {
363 self.dims[0].ncols()
364 }
365 }
366
367 pub fn from_1d(data: FdMatrix) -> Self {
369 Self { dims: vec![data] }
370 }
371
372 pub fn from_dims(dims: Vec<FdMatrix>) -> Result<Self, crate::FdarError> {
376 if dims.is_empty() {
377 return Err(crate::FdarError::InvalidDimension {
378 parameter: "dims",
379 expected: "non-empty".to_string(),
380 actual: "empty".to_string(),
381 });
382 }
383 let (n, m) = dims[0].shape();
384 if dims.iter().any(|d| d.shape() != (n, m)) {
385 return Err(crate::FdarError::InvalidDimension {
386 parameter: "dims",
387 expected: format!("all ({n}, {m})"),
388 actual: "inconsistent shapes".to_string(),
389 });
390 }
391 Ok(Self { dims })
392 }
393
394 pub fn point(&self, curve: usize, time_idx: usize) -> Vec<f64> {
396 self.dims.iter().map(|d| d[(curve, time_idx)]).collect()
397 }
398}
399
400impl std::fmt::Display for FdCurveSet {
401 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402 write!(
403 f,
404 "FdCurveSet(d={}, n={}, m={})",
405 self.ndim(),
406 self.ncurves(),
407 self.npoints()
408 )
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x4 fixture whose column-major buffer is `1.0..=12.0`, so column `j`
    /// holds `3j + 1, 3j + 2, 3j + 3`.
    fn sample_3x4() -> FdMatrix {
        let values: Vec<f64> = (1..=12u8).map(f64::from).collect();
        FdMatrix::from_column_major(values, 3, 4).unwrap()
    }

    /// A buffer of the right length yields the requested shape.
    #[test]
    fn test_from_column_major_valid() {
        let m = sample_3x4();
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 4);
        assert_eq!(m.shape(), (3, 4));
        assert_eq!(m.len(), 12);
        assert!(!m.is_empty());
    }

    /// A buffer that is too short is rejected.
    #[test]
    fn test_from_column_major_invalid() {
        let outcome = FdMatrix::from_column_major(vec![1.0, 2.0], 3, 4);
        assert!(outcome.is_err());
    }

    /// `from_slice` interprets its input as column-major.
    #[test]
    fn test_from_slice() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let m = FdMatrix::from_slice(&values, 2, 3).unwrap();
        let expected = [((0, 0), 1.0), ((1, 0), 2.0), ((0, 1), 3.0)];
        for &(pos, want) in expected.iter() {
            assert_eq!(m[pos], want);
        }
    }

    /// A slice that is too short is rejected.
    #[test]
    fn test_from_slice_invalid() {
        let outcome = FdMatrix::from_slice(&[1.0, 2.0], 3, 3);
        assert!(outcome.is_err());
    }

    /// `zeros` produces the requested shape with every element `0.0`.
    #[test]
    fn test_zeros() {
        let m = FdMatrix::zeros(2, 3);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 3);
        for (i, j) in (0..3).flat_map(|j| (0..2).map(move |i| (i, j))) {
            assert_eq!(m[(i, j)], 0.0);
        }
    }

    /// `(row, col)` indexing follows the column-major layout.
    #[test]
    fn test_index() {
        let m = sample_3x4();
        let expected = [
            ((0, 0), 1.0),
            ((1, 0), 2.0),
            ((2, 0), 3.0),
            ((0, 1), 4.0),
            ((1, 1), 5.0),
            ((2, 3), 12.0),
        ];
        for &(pos, want) in expected.iter() {
            assert_eq!(m[pos], want);
        }
    }

    /// Writing through `IndexMut` is visible through `Index`.
    #[test]
    fn test_index_mut() {
        let mut m = sample_3x4();
        let target = (1, 2);
        m[target] = 99.0;
        assert_eq!(m[target], 99.0);
    }

    /// Columns are returned as contiguous slices.
    #[test]
    fn test_column() {
        let m = sample_3x4();
        let expected = [
            (0, [1.0, 2.0, 3.0]),
            (1, [4.0, 5.0, 6.0]),
            (3, [10.0, 11.0, 12.0]),
        ];
        for (col, want) in expected.iter() {
            assert_eq!(m.column(*col), want);
        }
    }

    /// Writing through a mutable column slice changes the matrix.
    #[test]
    fn test_column_mut() {
        let mut m = sample_3x4();
        let col = m.column_mut(1);
        col[0] = 99.0;
        assert_eq!(m[(0, 1)], 99.0);
    }

    /// Each row gathers one element per column.
    #[test]
    fn test_row() {
        let m = sample_3x4();
        let expected = [
            vec![1.0, 4.0, 7.0, 10.0],
            vec![2.0, 5.0, 8.0, 11.0],
            vec![3.0, 6.0, 9.0, 12.0],
        ];
        for (i, want) in expected.iter().enumerate() {
            assert_eq!(&m.row(i), want);
        }
    }

    /// `rows` returns one vector per row.
    #[test]
    fn test_rows() {
        let all = sample_3x4().rows();
        assert_eq!(all.len(), 3);
        assert_eq!(all[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(all[2], vec![3.0, 6.0, 9.0, 12.0]);
    }

    /// `as_slice` exposes the buffer exactly as it was supplied.
    #[test]
    fn test_as_slice() {
        let m = sample_3x4();
        let want = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        assert_eq!(m.as_slice(), want.as_slice());
    }

    /// `into_vec` hands back the whole buffer.
    #[test]
    fn test_into_vec() {
        let buffer = sample_3x4().into_vec();
        assert_eq!(buffer.len(), 12);
        assert_eq!(buffer[0], 1.0);
    }

    /// `get` returns `None` for an out-of-range row or column.
    #[test]
    fn test_get_bounds_check() {
        let m = sample_3x4();
        let cases = [
            ((0, 0), Some(1.0)),
            ((2, 3), Some(12.0)),
            ((3, 0), None),
            ((0, 4), None),
        ];
        for &((row, col), want) in cases.iter() {
            assert_eq!(m.get(row, col), want);
        }
    }

    /// `set` reports whether the write happened.
    #[test]
    fn test_set_bounds_check() {
        let mut m = sample_3x4();
        let stored = m.set(1, 1, 99.0);
        assert!(stored);
        assert_eq!(m[(1, 1)], 99.0);
        let rejected = m.set(5, 0, 99.0);
        assert!(!rejected);
    }

    /// Converting to a `DMatrix` and back preserves shape and contents.
    #[test]
    fn test_nalgebra_roundtrip() {
        let original = sample_3x4();
        let converted = original.to_dmatrix();
        assert_eq!(converted.nrows(), 3);
        assert_eq!(converted.ncols(), 4);
        assert_eq!(converted[(0, 0)], 1.0);
        assert_eq!(converted[(1, 2)], 8.0);

        let restored = FdMatrix::from_dmatrix(&converted);
        assert_eq!(original, restored);
    }

    /// A 0x0 matrix is empty.
    #[test]
    fn test_empty() {
        let m = FdMatrix::zeros(0, 0);
        assert!(m.is_empty());
        assert_eq!(m.len(), 0);
    }

    /// A 1x1 matrix behaves consistently across all accessors.
    #[test]
    fn test_single_element() {
        let m = FdMatrix::from_column_major(vec![42.0], 1, 1).unwrap();
        assert_eq!(m[(0, 0)], 42.0);
        assert_eq!(m.column(0), &[42.0]);
        assert_eq!(m.row(0), vec![42.0]);
    }

    /// `Display` prints only the shape.
    #[test]
    fn test_display() {
        let rendered = format!("{}", sample_3x4());
        assert_eq!(rendered, "FdMatrix(3x4)");
    }

    /// A clone compares equal to its source.
    #[test]
    fn test_clone() {
        let source = sample_3x4();
        let copy = source.clone();
        assert_eq!(source, copy);
    }

    /// Writes through `as_mut_slice` land at the column-major positions.
    #[test]
    fn test_as_mut_slice() {
        let mut m = FdMatrix::zeros(2, 2);
        m.as_mut_slice().copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let expected = [((0, 0), 1.0), ((1, 0), 2.0), ((0, 1), 3.0), ((1, 1), 4.0)];
        for &(pos, want) in expected.iter() {
            assert_eq!(m[pos], want);
        }
    }

    /// `from_dims` rejects an empty list; a directly built empty set reports
    /// zero for every size.
    #[test]
    fn test_fd_curve_set_empty() {
        assert!(FdCurveSet::from_dims(Vec::new()).is_err());
        let cs = FdCurveSet::from_dims(Vec::new())
            .unwrap_or_else(|_| FdCurveSet { dims: Vec::new() });
        assert_eq!(cs.ndim(), 0);
        assert_eq!(cs.ncurves(), 0);
        assert_eq!(cs.npoints(), 0);
        assert_eq!(format!("{}", cs), "FdCurveSet(d=0, n=0, m=0)");
    }

    /// A single matrix becomes a one-dimensional curve set.
    #[test]
    fn test_fd_curve_set_from_1d() {
        let m = sample_3x4();
        let cs = FdCurveSet::from_1d(m.clone());
        assert_eq!(cs.ndim(), 1);
        assert_eq!(cs.ncurves(), 3);
        assert_eq!(cs.npoints(), 4);
        assert_eq!(cs.point(0, 0), vec![1.0]);
        assert_eq!(cs.point(1, 2), vec![8.0]);
    }

    /// Matrices of equal shape are accepted and queried per dimension.
    #[test]
    fn test_fd_curve_set_from_dims_consistent() {
        let first = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let second = FdMatrix::from_column_major(vec![5.0, 6.0, 7.0, 8.0], 2, 2).unwrap();
        let cs = FdCurveSet::from_dims(vec![first, second]).unwrap();
        assert_eq!(cs.ndim(), 2);
        assert_eq!(cs.point(0, 0), vec![1.0, 5.0]);
        assert_eq!(cs.point(1, 1), vec![4.0, 8.0]);
        assert_eq!(format!("{}", cs), "FdCurveSet(d=2, n=2, m=2)");
    }

    /// Matrices of different shapes are rejected.
    #[test]
    fn test_fd_curve_set_from_dims_inconsistent() {
        let two_rows = FdMatrix::from_column_major(vec![1.0, 2.0], 2, 1).unwrap();
        let three_rows = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
        let outcome = FdCurveSet::from_dims(vec![two_rows, three_rows]);
        assert!(outcome.is_err());
    }

    /// `to_row_major` lists the rows one after another.
    #[test]
    fn test_to_row_major() {
        let flattened = sample_3x4().to_row_major();
        let want = vec![
            1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0,
        ];
        assert_eq!(flattened, want);
    }

    /// `row_to_buf` overwrites an exactly sized buffer with each row.
    #[test]
    fn test_row_to_buf() {
        let m = sample_3x4();
        let mut scratch = vec![0.0; 4];
        let expected = [
            vec![1.0, 4.0, 7.0, 10.0],
            vec![2.0, 5.0, 8.0, 11.0],
            vec![3.0, 6.0, 9.0, 12.0],
        ];
        for (i, want) in expected.iter().enumerate() {
            m.row_to_buf(i, &mut scratch);
            assert_eq!(&scratch, want);
        }
    }

    /// Slots past `ncols` in a larger buffer are left untouched.
    #[test]
    fn test_row_to_buf_larger_buffer() {
        let m = sample_3x4();
        let mut scratch = vec![99.0; 6];
        m.row_to_buf(0, &mut scratch);
        assert_eq!(&scratch[..4], &[1.0, 4.0, 7.0, 10.0]);
        assert_eq!(scratch[4], 99.0);
    }

    /// Row dot products within one matrix.
    #[test]
    fn test_row_dot_same_matrix() {
        let m = sample_3x4();
        // [1, 4, 7, 10] . [2, 5, 8, 11]
        assert_eq!(m.row_dot(0, &m, 1), 188.0);
        // [1, 4, 7, 10] . [1, 4, 7, 10]
        assert_eq!(m.row_dot(0, &m, 0), 166.0);
    }

    /// Row dot product across two matrices.
    #[test]
    fn test_row_dot_different_matrices() {
        let base = sample_3x4();
        let scaled_values: Vec<f64> = vec![
            10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0,
        ];
        let scaled = FdMatrix::from_column_major(scaled_values, 3, 4).unwrap();
        // [1, 4, 7, 10] . [10, 40, 70, 100]
        assert_eq!(base.row_dot(0, &scaled, 0), 1660.0);
    }

    /// The squared distance of a row to itself is zero.
    #[test]
    fn test_row_l2_sq_identical() {
        let m = sample_3x4();
        for row in 0..2 {
            assert_eq!(m.row_l2_sq(row, &m, row), 0.0);
        }
    }

    /// Adjacent fixture rows differ by 1 in each of the 4 columns.
    #[test]
    fn test_row_l2_sq_different() {
        let m = sample_3x4();
        assert_eq!(m.row_l2_sq(0, &m, 1), 4.0);
    }

    /// Squared distance across two matrices: 3^2 + 4^2.
    #[test]
    fn test_row_l2_sq_cross_matrix() {
        let origin = FdMatrix::from_column_major(vec![0.0, 0.0], 1, 2).unwrap();
        let target = FdMatrix::from_column_major(vec![3.0, 4.0], 1, 2).unwrap();
        assert_eq!(origin.row_l2_sq(0, &target, 0), 25.0);
    }

    /// Indexing agrees with the manual `i + j * nrows` offset.
    #[test]
    fn test_column_major_layout_matches_manual() {
        let (n, m) = (5, 7);
        let values: Vec<f64> = (0..n * m).map(|x| x as f64).collect();
        let mat = FdMatrix::from_column_major(values.clone(), n, m).unwrap();

        for (i, j) in (0..m).flat_map(|j| (0..n).map(move |i| (i, j))) {
            assert_eq!(mat[(i, j)], values[i + j * n]);
        }
    }
}