1use nalgebra::DMatrix;
8
/// Dense `f64` matrix kept in one flat buffer in column-major order.
///
/// The element at `(row, col)` lives at `data[row + col * nrows]`, i.e. each column is a
/// contiguous run of `nrows` values.
#[derive(Debug, Clone)]
#[derive(PartialEq)]
pub struct FdMatrix {
    /// Column-major element storage holding `nrows * ncols` values.
    data: Vec<f64>,
    /// Number of rows (the length of every column).
    nrows: usize,
    /// Number of columns.
    ncols: usize,
}
44
45impl FdMatrix {
46 pub fn from_column_major(data: Vec<f64>, nrows: usize, ncols: usize) -> Option<Self> {
50 if data.len() != nrows * ncols {
51 return None;
52 }
53 Some(Self { data, nrows, ncols })
54 }
55
56 pub fn from_slice(data: &[f64], nrows: usize, ncols: usize) -> Option<Self> {
60 if data.len() != nrows * ncols {
61 return None;
62 }
63 Some(Self {
64 data: data.to_vec(),
65 nrows,
66 ncols,
67 })
68 }
69
70 pub fn zeros(nrows: usize, ncols: usize) -> Self {
72 Self {
73 data: vec![0.0; nrows * ncols],
74 nrows,
75 ncols,
76 }
77 }
78
79 #[inline]
81 pub fn nrows(&self) -> usize {
82 self.nrows
83 }
84
85 #[inline]
87 pub fn ncols(&self) -> usize {
88 self.ncols
89 }
90
91 #[inline]
93 pub fn shape(&self) -> (usize, usize) {
94 (self.nrows, self.ncols)
95 }
96
97 #[inline]
99 pub fn len(&self) -> usize {
100 self.data.len()
101 }
102
103 #[inline]
105 pub fn is_empty(&self) -> bool {
106 self.data.is_empty()
107 }
108
109 #[inline]
114 pub fn column(&self, col: usize) -> &[f64] {
115 let start = col * self.nrows;
116 &self.data[start..start + self.nrows]
117 }
118
119 #[inline]
124 pub fn column_mut(&mut self, col: usize) -> &mut [f64] {
125 let start = col * self.nrows;
126 &mut self.data[start..start + self.nrows]
127 }
128
129 pub fn row(&self, row: usize) -> Vec<f64> {
134 (0..self.ncols)
135 .map(|j| self.data[row + j * self.nrows])
136 .collect()
137 }
138
139 pub fn rows(&self) -> Vec<Vec<f64>> {
143 (0..self.nrows).map(|i| self.row(i)).collect()
144 }
145
146 pub fn to_row_major(&self) -> Vec<f64> {
152 let mut buf = vec![0.0; self.nrows * self.ncols];
153 for i in 0..self.nrows {
154 for j in 0..self.ncols {
155 buf[i * self.ncols + j] = self.data[i + j * self.nrows];
156 }
157 }
158 buf
159 }
160
161 #[inline]
163 pub fn as_slice(&self) -> &[f64] {
164 &self.data
165 }
166
167 #[inline]
169 pub fn as_mut_slice(&mut self) -> &mut [f64] {
170 &mut self.data
171 }
172
173 pub fn into_vec(self) -> Vec<f64> {
175 self.data
176 }
177
178 pub fn to_dmatrix(&self) -> DMatrix<f64> {
183 DMatrix::from_column_slice(self.nrows, self.ncols, &self.data)
184 }
185
186 pub fn from_dmatrix(mat: &DMatrix<f64>) -> Self {
190 let (nrows, ncols) = mat.shape();
191 Self {
192 data: mat.as_slice().to_vec(),
193 nrows,
194 ncols,
195 }
196 }
197
198 #[inline]
200 pub fn get(&self, row: usize, col: usize) -> Option<f64> {
201 if row < self.nrows && col < self.ncols {
202 Some(self.data[row + col * self.nrows])
203 } else {
204 None
205 }
206 }
207
208 #[inline]
210 pub fn set(&mut self, row: usize, col: usize, value: f64) -> bool {
211 if row < self.nrows && col < self.ncols {
212 self.data[row + col * self.nrows] = value;
213 true
214 } else {
215 false
216 }
217 }
218}
219
220impl std::ops::Index<(usize, usize)> for FdMatrix {
221 type Output = f64;
222
223 #[inline]
224 fn index(&self, (row, col): (usize, usize)) -> &f64 {
225 debug_assert!(
226 row < self.nrows && col < self.ncols,
227 "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
228 row,
229 col,
230 self.nrows,
231 self.ncols
232 );
233 &self.data[row + col * self.nrows]
234 }
235}
236
237impl std::ops::IndexMut<(usize, usize)> for FdMatrix {
238 #[inline]
239 fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut f64 {
240 debug_assert!(
241 row < self.nrows && col < self.ncols,
242 "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
243 row,
244 col,
245 self.nrows,
246 self.ncols
247 );
248 &mut self.data[row + col * self.nrows]
249 }
250}
251
252impl std::fmt::Display for FdMatrix {
253 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 write!(f, "FdMatrix({}x{})", self.nrows, self.ncols)
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x4 fixture holding 1..=12 laid out column by column.
    fn sample_3x4() -> FdMatrix {
        let data = vec![
            1.0, 2.0, 3.0, // col 0
            4.0, 5.0, 6.0, // col 1
            7.0, 8.0, 9.0, // col 2
            10.0, 11.0, 12.0, // col 3
        ];
        FdMatrix::from_column_major(data, 3, 4).unwrap()
    }

    #[test]
    fn test_from_column_major_valid() {
        let mat = sample_3x4();
        assert_eq!(mat.nrows(), 3);
        assert_eq!(mat.ncols(), 4);
        assert_eq!(mat.shape(), (3, 4));
        assert_eq!(mat.len(), 12);
        assert!(!mat.is_empty());
    }

    #[test]
    fn test_from_column_major_invalid() {
        assert!(FdMatrix::from_column_major(vec![1.0, 2.0], 3, 4).is_none());
    }

    #[test]
    fn test_from_slice() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mat = FdMatrix::from_slice(&data, 2, 3).unwrap();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
    }

    #[test]
    fn test_from_slice_invalid() {
        assert!(FdMatrix::from_slice(&[1.0, 2.0], 3, 3).is_none());
    }

    #[test]
    fn test_from_slice_matches_from_column_major() {
        let data: Vec<f64> = (0..6).map(|x| x as f64).collect();
        let borrowed = FdMatrix::from_slice(&data, 3, 2).unwrap();
        let owned = FdMatrix::from_column_major(data, 3, 2).unwrap();
        assert_eq!(borrowed, owned);
    }

    #[test]
    fn test_zeros() {
        let mat = FdMatrix::zeros(2, 3);
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 3);
        for j in 0..3 {
            for i in 0..2 {
                assert_eq!(mat[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn test_index() {
        let mat = sample_3x4();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(2, 0)], 3.0);
        assert_eq!(mat[(0, 1)], 4.0);
        assert_eq!(mat[(1, 1)], 5.0);
        assert_eq!(mat[(2, 3)], 12.0);
    }

    #[test]
    fn test_index_mut() {
        let mut mat = sample_3x4();
        mat[(1, 2)] = 99.0;
        assert_eq!(mat[(1, 2)], 99.0);
    }

    #[test]
    fn test_column() {
        let mat = sample_3x4();
        assert_eq!(mat.column(0), &[1.0, 2.0, 3.0]);
        assert_eq!(mat.column(1), &[4.0, 5.0, 6.0]);
        assert_eq!(mat.column(3), &[10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_column_mut() {
        let mut mat = sample_3x4();
        mat.column_mut(1)[0] = 99.0;
        assert_eq!(mat[(0, 1)], 99.0);
    }

    #[test]
    fn test_row() {
        let mat = sample_3x4();
        assert_eq!(mat.row(0), vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(mat.row(1), vec![2.0, 5.0, 8.0, 11.0]);
        assert_eq!(mat.row(2), vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_rows() {
        let mat = sample_3x4();
        let rows = mat.rows();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(rows[2], vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_to_row_major() {
        let mat = sample_3x4();
        let expected = vec![
            1.0, 4.0, 7.0, 10.0, // row 0
            2.0, 5.0, 8.0, 11.0, // row 1
            3.0, 6.0, 9.0, 12.0, // row 2
        ];
        assert_eq!(mat.to_row_major(), expected);
        // Must agree with concatenating the individually extracted rows.
        assert_eq!(mat.to_row_major(), mat.rows().concat());
    }

    #[test]
    fn test_to_row_major_empty() {
        assert!(FdMatrix::zeros(0, 3).to_row_major().is_empty());
        assert!(FdMatrix::zeros(0, 0).to_row_major().is_empty());
    }

    #[test]
    fn test_as_slice() {
        let mat = sample_3x4();
        let expected = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        assert_eq!(mat.as_slice(), expected.as_slice());
    }

    #[test]
    fn test_into_vec() {
        let mat = sample_3x4();
        let v = mat.into_vec();
        assert_eq!(v.len(), 12);
        assert_eq!(v[0], 1.0);
    }

    #[test]
    fn test_get_bounds_check() {
        let mat = sample_3x4();
        assert_eq!(mat.get(0, 0), Some(1.0));
        assert_eq!(mat.get(2, 3), Some(12.0));
        assert_eq!(mat.get(3, 0), None); // row out of range
        assert_eq!(mat.get(0, 4), None); // column out of range
    }

    #[test]
    fn test_get_on_empty() {
        let mat = FdMatrix::zeros(0, 0);
        assert_eq!(mat.get(0, 0), None);
    }

    #[test]
    fn test_set_bounds_check() {
        let mut mat = sample_3x4();
        assert!(mat.set(1, 1, 99.0));
        assert_eq!(mat[(1, 1)], 99.0);
        assert!(!mat.set(5, 0, 99.0)); // row out of range
    }

    #[test]
    fn test_set_out_of_bounds_leaves_data_unchanged() {
        let mut mat = sample_3x4();
        let before = mat.clone();
        assert!(!mat.set(3, 0, -1.0));
        assert!(!mat.set(0, 4, -1.0));
        assert_eq!(mat, before);
    }

    #[test]
    fn test_nalgebra_roundtrip() {
        let mat = sample_3x4();
        let dmat = mat.to_dmatrix();
        assert_eq!(dmat.nrows(), 3);
        assert_eq!(dmat.ncols(), 4);
        assert_eq!(dmat[(0, 0)], 1.0);
        assert_eq!(dmat[(1, 2)], 8.0);

        let back = FdMatrix::from_dmatrix(&dmat);
        assert_eq!(mat, back);
    }

    #[test]
    fn test_empty() {
        let mat = FdMatrix::zeros(0, 0);
        assert!(mat.is_empty());
        assert_eq!(mat.len(), 0);
    }

    #[test]
    fn test_single_element() {
        let mat = FdMatrix::from_column_major(vec![42.0], 1, 1).unwrap();
        assert_eq!(mat[(0, 0)], 42.0);
        assert_eq!(mat.column(0), &[42.0]);
        assert_eq!(mat.row(0), vec![42.0]);
    }

    #[test]
    fn test_display() {
        let mat = sample_3x4();
        assert_eq!(format!("{}", mat), "FdMatrix(3x4)");
    }

    #[test]
    fn test_clone() {
        let mat = sample_3x4();
        let cloned = mat.clone();
        assert_eq!(mat, cloned);
    }

    #[test]
    fn test_as_mut_slice() {
        let mut mat = FdMatrix::zeros(2, 2);
        let s = mat.as_mut_slice();
        s[0] = 1.0;
        s[1] = 2.0;
        s[2] = 3.0;
        s[3] = 4.0;
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
        assert_eq!(mat[(1, 1)], 4.0);
    }

    #[test]
    fn test_column_major_layout_matches_manual() {
        let n = 5;
        let m = 7;
        let data: Vec<f64> = (0..n * m).map(|x| x as f64).collect();
        let mat = FdMatrix::from_column_major(data.clone(), n, m).unwrap();

        for j in 0..m {
            for i in 0..n {
                assert_eq!(mat[(i, j)], data[i + j * n]);
            }
        }
    }
}