1use nalgebra::DMatrix;
8
/// Dense matrix of `f64` values stored in column-major order.
///
/// Element `(row, col)` lives at flat index `row + col * nrows` of the
/// backing buffer, so each column is a contiguous run of `nrows` values.
///
/// Invariant: `data.len() == nrows * ncols`. The fields are private so that
/// code outside this module can only build a matrix through the checked
/// constructors.
#[derive(Clone, Debug, PartialEq)]
pub struct FdMatrix {
    /// Column-major element buffer of length `nrows * ncols`.
    data: Vec<f64>,
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
}
44
45impl FdMatrix {
46 pub fn from_column_major(data: Vec<f64>, nrows: usize, ncols: usize) -> Option<Self> {
50 if data.len() != nrows * ncols {
51 return None;
52 }
53 Some(Self { data, nrows, ncols })
54 }
55
56 pub fn from_slice(data: &[f64], nrows: usize, ncols: usize) -> Option<Self> {
60 if data.len() != nrows * ncols {
61 return None;
62 }
63 Some(Self {
64 data: data.to_vec(),
65 nrows,
66 ncols,
67 })
68 }
69
70 pub fn zeros(nrows: usize, ncols: usize) -> Self {
72 Self {
73 data: vec![0.0; nrows * ncols],
74 nrows,
75 ncols,
76 }
77 }
78
79 #[inline]
81 pub fn nrows(&self) -> usize {
82 self.nrows
83 }
84
85 #[inline]
87 pub fn ncols(&self) -> usize {
88 self.ncols
89 }
90
91 #[inline]
93 pub fn shape(&self) -> (usize, usize) {
94 (self.nrows, self.ncols)
95 }
96
97 #[inline]
99 pub fn len(&self) -> usize {
100 self.data.len()
101 }
102
103 #[inline]
105 pub fn is_empty(&self) -> bool {
106 self.data.is_empty()
107 }
108
109 #[inline]
114 pub fn column(&self, col: usize) -> &[f64] {
115 let start = col * self.nrows;
116 &self.data[start..start + self.nrows]
117 }
118
119 #[inline]
124 pub fn column_mut(&mut self, col: usize) -> &mut [f64] {
125 let start = col * self.nrows;
126 &mut self.data[start..start + self.nrows]
127 }
128
129 pub fn row(&self, row: usize) -> Vec<f64> {
134 (0..self.ncols)
135 .map(|j| self.data[row + j * self.nrows])
136 .collect()
137 }
138
139 pub fn rows(&self) -> Vec<Vec<f64>> {
143 (0..self.nrows).map(|i| self.row(i)).collect()
144 }
145
146 pub fn to_row_major(&self) -> Vec<f64> {
152 let mut buf = vec![0.0; self.nrows * self.ncols];
153 for i in 0..self.nrows {
154 for j in 0..self.ncols {
155 buf[i * self.ncols + j] = self.data[i + j * self.nrows];
156 }
157 }
158 buf
159 }
160
161 #[inline]
163 pub fn as_slice(&self) -> &[f64] {
164 &self.data
165 }
166
167 #[inline]
169 pub fn as_mut_slice(&mut self) -> &mut [f64] {
170 &mut self.data
171 }
172
173 pub fn into_vec(self) -> Vec<f64> {
175 self.data
176 }
177
178 pub fn to_dmatrix(&self) -> DMatrix<f64> {
183 DMatrix::from_column_slice(self.nrows, self.ncols, &self.data)
184 }
185
186 pub fn from_dmatrix(mat: &DMatrix<f64>) -> Self {
190 let (nrows, ncols) = mat.shape();
191 Self {
192 data: mat.as_slice().to_vec(),
193 nrows,
194 ncols,
195 }
196 }
197
198 #[inline]
200 pub fn get(&self, row: usize, col: usize) -> Option<f64> {
201 if row < self.nrows && col < self.ncols {
202 Some(self.data[row + col * self.nrows])
203 } else {
204 None
205 }
206 }
207
208 #[inline]
210 pub fn set(&mut self, row: usize, col: usize, value: f64) -> bool {
211 if row < self.nrows && col < self.ncols {
212 self.data[row + col * self.nrows] = value;
213 true
214 } else {
215 false
216 }
217 }
218}
219
220impl std::ops::Index<(usize, usize)> for FdMatrix {
221 type Output = f64;
222
223 #[inline]
224 fn index(&self, (row, col): (usize, usize)) -> &f64 {
225 debug_assert!(
226 row < self.nrows && col < self.ncols,
227 "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
228 row,
229 col,
230 self.nrows,
231 self.ncols
232 );
233 &self.data[row + col * self.nrows]
234 }
235}
236
237impl std::ops::IndexMut<(usize, usize)> for FdMatrix {
238 #[inline]
239 fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut f64 {
240 debug_assert!(
241 row < self.nrows && col < self.ncols,
242 "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
243 row,
244 col,
245 self.nrows,
246 self.ncols
247 );
248 &mut self.data[row + col * self.nrows]
249 }
250}
251
252impl std::fmt::Display for FdMatrix {
253 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 write!(f, "FdMatrix({}x{})", self.nrows, self.ncols)
255 }
256}
257
258#[derive(Debug, Clone, PartialEq)]
263pub struct FdCurveSet {
264 pub dims: Vec<FdMatrix>,
266}
267
268impl FdCurveSet {
269 pub fn ndim(&self) -> usize {
271 self.dims.len()
272 }
273
274 pub fn ncurves(&self) -> usize {
276 if self.dims.is_empty() {
277 0
278 } else {
279 self.dims[0].nrows()
280 }
281 }
282
283 pub fn npoints(&self) -> usize {
285 if self.dims.is_empty() {
286 0
287 } else {
288 self.dims[0].ncols()
289 }
290 }
291
292 pub fn from_1d(data: FdMatrix) -> Self {
294 Self { dims: vec![data] }
295 }
296
297 pub fn from_dims(dims: Vec<FdMatrix>) -> Option<Self> {
301 if dims.is_empty() {
302 return None;
303 }
304 let (n, m) = dims[0].shape();
305 if dims.iter().any(|d| d.shape() != (n, m)) {
306 return None;
307 }
308 Some(Self { dims })
309 }
310
311 pub fn point(&self, curve: usize, time_idx: usize) -> Vec<f64> {
313 self.dims.iter().map(|d| d[(curve, time_idx)]).collect()
314 }
315}
316
317impl std::fmt::Display for FdCurveSet {
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 write!(
320 f,
321 "FdCurveSet(d={}, n={}, m={})",
322 self.ndim(),
323 self.ncurves(),
324 self.npoints()
325 )
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    // 3x4 fixture whose column-major buffer is 1.0..=12.0:
    //
    //     1  4  7  10
    //     2  5  8  11
    //     3  6  9  12
    fn sample_3x4() -> FdMatrix {
        let data = vec![
            1.0, 2.0, 3.0, // column 0
            4.0, 5.0, 6.0, // column 1
            7.0, 8.0, 9.0, // column 2
            10.0, 11.0, 12.0, // column 3
        ];
        FdMatrix::from_column_major(data, 3, 4).unwrap()
    }

    #[test]
    fn test_from_column_major_valid() {
        let mat = sample_3x4();
        assert_eq!(mat.nrows(), 3);
        assert_eq!(mat.ncols(), 4);
        assert_eq!(mat.shape(), (3, 4));
        assert_eq!(mat.len(), 12);
        assert!(!mat.is_empty());
    }

    #[test]
    fn test_from_column_major_invalid() {
        assert!(FdMatrix::from_column_major(vec![1.0, 2.0], 3, 4).is_none());
    }

    #[test]
    fn test_from_slice() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mat = FdMatrix::from_slice(&data, 2, 3).unwrap();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
    }

    #[test]
    fn test_from_slice_invalid() {
        assert!(FdMatrix::from_slice(&[1.0, 2.0], 3, 3).is_none());
    }

    #[test]
    fn test_zeros() {
        let mat = FdMatrix::zeros(2, 3);
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 3);
        for j in 0..3 {
            for i in 0..2 {
                assert_eq!(mat[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn test_index() {
        let mat = sample_3x4();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(2, 0)], 3.0);
        assert_eq!(mat[(0, 1)], 4.0);
        assert_eq!(mat[(1, 1)], 5.0);
        assert_eq!(mat[(2, 3)], 12.0);
    }

    #[test]
    #[should_panic]
    fn test_index_column_out_of_bounds_panics() {
        let mat = sample_3x4();
        let _value = mat[(0, 4)];
    }

    #[test]
    fn test_index_mut() {
        let mut mat = sample_3x4();
        mat[(1, 2)] = 99.0;
        assert_eq!(mat[(1, 2)], 99.0);
    }

    #[test]
    fn test_column() {
        let mat = sample_3x4();
        assert_eq!(mat.column(0), &[1.0, 2.0, 3.0]);
        assert_eq!(mat.column(1), &[4.0, 5.0, 6.0]);
        assert_eq!(mat.column(3), &[10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_column_mut() {
        let mut mat = sample_3x4();
        mat.column_mut(1)[0] = 99.0;
        assert_eq!(mat[(0, 1)], 99.0);
    }

    #[test]
    fn test_row() {
        let mat = sample_3x4();
        assert_eq!(mat.row(0), vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(mat.row(1), vec![2.0, 5.0, 8.0, 11.0]);
        assert_eq!(mat.row(2), vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_rows() {
        let mat = sample_3x4();
        let rows = mat.rows();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(rows[2], vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_as_slice() {
        let mat = sample_3x4();
        let expected = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        assert_eq!(mat.as_slice(), &expected[..]);
    }

    #[test]
    fn test_into_vec() {
        let mat = sample_3x4();
        let v = mat.into_vec();
        assert_eq!(v.len(), 12);
        assert_eq!(v[0], 1.0);
    }

    #[test]
    fn test_get_bounds_check() {
        let mat = sample_3x4();
        assert_eq!(mat.get(0, 0), Some(1.0));
        assert_eq!(mat.get(2, 3), Some(12.0));
        // Row 3 would alias element (0, 1) if only the flat index were checked.
        assert_eq!(mat.get(3, 0), None);
        assert_eq!(mat.get(0, 4), None);
    }

    #[test]
    fn test_set_bounds_check() {
        let mut mat = sample_3x4();
        assert!(mat.set(1, 1, 99.0));
        assert_eq!(mat[(1, 1)], 99.0);
        assert!(!mat.set(5, 0, 99.0));
        assert!(!mat.set(0, 4, 99.0));
    }

    #[test]
    fn test_nalgebra_roundtrip() {
        let mat = sample_3x4();
        let dmat = mat.to_dmatrix();
        assert_eq!(dmat.nrows(), 3);
        assert_eq!(dmat.ncols(), 4);
        assert_eq!(dmat[(0, 0)], 1.0);
        assert_eq!(dmat[(1, 2)], 8.0);

        let back = FdMatrix::from_dmatrix(&dmat);
        assert_eq!(mat, back);
    }

    #[test]
    fn test_empty() {
        let mat = FdMatrix::zeros(0, 0);
        assert!(mat.is_empty());
        assert_eq!(mat.len(), 0);
    }

    #[test]
    fn test_zero_row_matrix() {
        let mat = FdMatrix::zeros(0, 3);
        assert!(mat.is_empty());
        assert_eq!(mat.shape(), (0, 3));
        assert!(mat.column(2).is_empty());
        assert!(mat.rows().is_empty());
        assert!(mat.to_row_major().is_empty());
        assert_eq!(mat.get(0, 0), None);
    }

    #[test]
    fn test_single_element() {
        let mat = FdMatrix::from_column_major(vec![42.0], 1, 1).unwrap();
        assert_eq!(mat[(0, 0)], 42.0);
        assert_eq!(mat.column(0), &[42.0]);
        assert_eq!(mat.row(0), vec![42.0]);
    }

    #[test]
    fn test_display() {
        let mat = sample_3x4();
        assert_eq!(format!("{}", mat), "FdMatrix(3x4)");
    }

    #[test]
    fn test_clone() {
        let mat = sample_3x4();
        let cloned = mat.clone();
        assert_eq!(mat, cloned);
    }

    #[test]
    fn test_as_mut_slice() {
        let mut mat = FdMatrix::zeros(2, 2);
        let s = mat.as_mut_slice();
        s[0] = 1.0;
        s[1] = 2.0;
        s[2] = 3.0;
        s[3] = 4.0;
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
        assert_eq!(mat[(1, 1)], 4.0);
    }

    #[test]
    fn test_fd_curve_set_empty() {
        assert!(FdCurveSet::from_dims(vec![]).is_none());
        // `from_dims` refuses an empty list, so build the degenerate set directly.
        let cs = FdCurveSet { dims: vec![] };
        assert_eq!(cs.ndim(), 0);
        assert_eq!(cs.ncurves(), 0);
        assert_eq!(cs.npoints(), 0);
        assert_eq!(format!("{}", cs), "FdCurveSet(d=0, n=0, m=0)");
    }

    #[test]
    fn test_fd_curve_set_from_1d() {
        let mat = sample_3x4();
        let cs = FdCurveSet::from_1d(mat.clone());
        assert_eq!(cs.ndim(), 1);
        assert_eq!(cs.ncurves(), 3);
        assert_eq!(cs.npoints(), 4);
        assert_eq!(cs.point(0, 0), vec![1.0]);
        assert_eq!(cs.point(1, 2), vec![8.0]);
    }

    #[test]
    fn test_fd_curve_set_from_dims_consistent() {
        let m1 = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let m2 = FdMatrix::from_column_major(vec![5.0, 6.0, 7.0, 8.0], 2, 2).unwrap();
        let cs = FdCurveSet::from_dims(vec![m1, m2]).unwrap();
        assert_eq!(cs.ndim(), 2);
        assert_eq!(cs.point(0, 0), vec![1.0, 5.0]);
        assert_eq!(cs.point(1, 1), vec![4.0, 8.0]);
        assert_eq!(format!("{}", cs), "FdCurveSet(d=2, n=2, m=2)");
    }

    #[test]
    fn test_fd_curve_set_from_dims_inconsistent() {
        let m1 = FdMatrix::from_column_major(vec![1.0, 2.0], 2, 1).unwrap();
        let m2 = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
        assert!(FdCurveSet::from_dims(vec![m1, m2]).is_none());
    }

    #[test]
    fn test_to_row_major() {
        let mat = sample_3x4();
        let rm = mat.to_row_major();
        assert_eq!(
            rm,
            vec![1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0]
        );
    }

    #[test]
    fn test_to_row_major_vector_shapes() {
        // A single row and a single column share the same flat buffer order.
        let row = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 1, 3).unwrap();
        assert_eq!(row.to_row_major(), vec![1.0, 2.0, 3.0]);
        let col = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
        assert_eq!(col.to_row_major(), vec![1.0, 2.0, 3.0]);
        assert_eq!(col.row(2), vec![3.0]);
    }

    #[test]
    fn test_column_major_layout_matches_manual() {
        let n = 5;
        let m = 7;
        let data: Vec<f64> = (0..n * m).map(|x| x as f64).collect();
        let mat = FdMatrix::from_column_major(data.clone(), n, m).unwrap();

        for j in 0..m {
            for i in 0..n {
                assert_eq!(mat[(i, j)], data[i + j * n]);
            }
        }
    }
}