1use nalgebra::DMatrix;
8
/// Dense `f64` matrix stored contiguously in column-major order.
///
/// Element `(i, j)` lives at `data[i + j * nrows]`, so each column is a contiguous
/// run of `nrows` values. Constructors are intended to keep
/// `data.len() == nrows * ncols`.
#[derive(Clone, Debug, PartialEq)]
pub struct FdMatrix {
    /// Flat element buffer, one column after another.
    data: Vec<f64>,
    /// Number of rows (the length of every column).
    nrows: usize,
    /// Number of columns.
    ncols: usize,
}
44
45impl FdMatrix {
46 pub fn from_column_major(data: Vec<f64>, nrows: usize, ncols: usize) -> Option<Self> {
50 if data.len() != nrows * ncols {
51 return None;
52 }
53 Some(Self { data, nrows, ncols })
54 }
55
56 pub fn from_slice(data: &[f64], nrows: usize, ncols: usize) -> Option<Self> {
60 if data.len() != nrows * ncols {
61 return None;
62 }
63 Some(Self {
64 data: data.to_vec(),
65 nrows,
66 ncols,
67 })
68 }
69
70 pub fn zeros(nrows: usize, ncols: usize) -> Self {
72 Self {
73 data: vec![0.0; nrows * ncols],
74 nrows,
75 ncols,
76 }
77 }
78
79 #[inline]
81 pub fn nrows(&self) -> usize {
82 self.nrows
83 }
84
85 #[inline]
87 pub fn ncols(&self) -> usize {
88 self.ncols
89 }
90
91 #[inline]
93 pub fn shape(&self) -> (usize, usize) {
94 (self.nrows, self.ncols)
95 }
96
97 #[inline]
99 pub fn len(&self) -> usize {
100 self.data.len()
101 }
102
103 #[inline]
105 pub fn is_empty(&self) -> bool {
106 self.data.is_empty()
107 }
108
109 #[inline]
114 pub fn column(&self, col: usize) -> &[f64] {
115 let start = col * self.nrows;
116 &self.data[start..start + self.nrows]
117 }
118
119 #[inline]
124 pub fn column_mut(&mut self, col: usize) -> &mut [f64] {
125 let start = col * self.nrows;
126 &mut self.data[start..start + self.nrows]
127 }
128
129 pub fn row(&self, row: usize) -> Vec<f64> {
134 (0..self.ncols)
135 .map(|j| self.data[row + j * self.nrows])
136 .collect()
137 }
138
139 pub fn rows(&self) -> Vec<Vec<f64>> {
143 (0..self.nrows).map(|i| self.row(i)).collect()
144 }
145
146 #[inline]
148 pub fn as_slice(&self) -> &[f64] {
149 &self.data
150 }
151
152 #[inline]
154 pub fn as_mut_slice(&mut self) -> &mut [f64] {
155 &mut self.data
156 }
157
158 pub fn into_vec(self) -> Vec<f64> {
160 self.data
161 }
162
163 pub fn to_dmatrix(&self) -> DMatrix<f64> {
168 DMatrix::from_column_slice(self.nrows, self.ncols, &self.data)
169 }
170
171 pub fn from_dmatrix(mat: &DMatrix<f64>) -> Self {
175 let (nrows, ncols) = mat.shape();
176 Self {
177 data: mat.as_slice().to_vec(),
178 nrows,
179 ncols,
180 }
181 }
182
183 #[inline]
185 pub fn get(&self, row: usize, col: usize) -> Option<f64> {
186 if row < self.nrows && col < self.ncols {
187 Some(self.data[row + col * self.nrows])
188 } else {
189 None
190 }
191 }
192
193 #[inline]
195 pub fn set(&mut self, row: usize, col: usize, value: f64) -> bool {
196 if row < self.nrows && col < self.ncols {
197 self.data[row + col * self.nrows] = value;
198 true
199 } else {
200 false
201 }
202 }
203}
204
205impl std::ops::Index<(usize, usize)> for FdMatrix {
206 type Output = f64;
207
208 #[inline]
209 fn index(&self, (row, col): (usize, usize)) -> &f64 {
210 debug_assert!(
211 row < self.nrows && col < self.ncols,
212 "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
213 row,
214 col,
215 self.nrows,
216 self.ncols
217 );
218 &self.data[row + col * self.nrows]
219 }
220}
221
222impl std::ops::IndexMut<(usize, usize)> for FdMatrix {
223 #[inline]
224 fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut f64 {
225 debug_assert!(
226 row < self.nrows && col < self.ncols,
227 "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
228 row,
229 col,
230 self.nrows,
231 self.ncols
232 );
233 &mut self.data[row + col * self.nrows]
234 }
235}
236
237impl std::fmt::Display for FdMatrix {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 write!(f, "FdMatrix({}x{})", self.nrows, self.ncols)
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x4 fixture whose columns are [1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12].
    fn sample_3x4() -> FdMatrix {
        let data = vec![
            1.0, 2.0, 3.0, // column 0
            4.0, 5.0, 6.0, // column 1
            7.0, 8.0, 9.0, // column 2
            10.0, 11.0, 12.0, // column 3
        ];
        FdMatrix::from_column_major(data, 3, 4).unwrap()
    }

    #[test]
    fn test_from_column_major_valid() {
        let mat = sample_3x4();
        assert_eq!(mat.nrows(), 3);
        assert_eq!(mat.ncols(), 4);
        assert_eq!(mat.shape(), (3, 4));
        assert_eq!(mat.len(), 12);
        assert!(!mat.is_empty());
    }

    #[test]
    fn test_from_column_major_invalid() {
        assert!(FdMatrix::from_column_major(vec![1.0, 2.0], 3, 4).is_none());
        // Too many elements is rejected just like too few.
        assert!(FdMatrix::from_column_major(vec![0.0; 13], 3, 4).is_none());
    }

    #[test]
    fn test_from_slice() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mat = FdMatrix::from_slice(&data, 2, 3).unwrap();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
    }

    #[test]
    fn test_from_slice_invalid() {
        assert!(FdMatrix::from_slice(&[1.0, 2.0], 3, 3).is_none());
    }

    #[test]
    fn test_zeros() {
        let mat = FdMatrix::zeros(2, 3);
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 3);
        for j in 0..3 {
            for i in 0..2 {
                assert_eq!(mat[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn test_zero_sized_shapes() {
        let no_rows = FdMatrix::from_column_major(Vec::new(), 0, 5).unwrap();
        assert_eq!(no_rows.shape(), (0, 5));
        assert!(no_rows.is_empty());
        assert!(no_rows.rows().is_empty());

        let no_cols = FdMatrix::zeros(3, 0);
        assert_eq!(no_cols.shape(), (3, 0));
        assert!(no_cols.is_empty());
        assert_eq!(no_cols.rows(), vec![Vec::<f64>::new(); 3]);
        assert_eq!(no_cols.get(0, 0), None);
    }

    #[test]
    fn test_index() {
        let mat = sample_3x4();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(2, 0)], 3.0);
        assert_eq!(mat[(0, 1)], 4.0);
        assert_eq!(mat[(1, 1)], 5.0);
        assert_eq!(mat[(2, 3)], 12.0);
    }

    #[test]
    fn test_index_mut() {
        let mut mat = sample_3x4();
        mat[(1, 2)] = 99.0;
        assert_eq!(mat[(1, 2)], 99.0);
    }

    #[test]
    fn test_column() {
        let mat = sample_3x4();
        assert_eq!(mat.column(0), &[1.0, 2.0, 3.0]);
        assert_eq!(mat.column(1), &[4.0, 5.0, 6.0]);
        assert_eq!(mat.column(3), &[10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_column_mut() {
        let mut mat = sample_3x4();
        mat.column_mut(1)[0] = 99.0;
        assert_eq!(mat[(0, 1)], 99.0);
    }

    #[test]
    fn test_row() {
        let mat = sample_3x4();
        assert_eq!(mat.row(0), vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(mat.row(1), vec![2.0, 5.0, 8.0, 11.0]);
        assert_eq!(mat.row(2), vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_rows() {
        let mat = sample_3x4();
        let rows = mat.rows();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(rows[2], vec![3.0, 6.0, 9.0, 12.0]);
    }

    #[test]
    fn test_as_slice() {
        let mat = sample_3x4();
        let expected = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        assert_eq!(mat.as_slice(), &expected[..]);
    }

    #[test]
    fn test_into_vec() {
        let mat = sample_3x4();
        let v = mat.into_vec();
        assert_eq!(v.len(), 12);
        assert_eq!(v[0], 1.0);
    }

    #[test]
    fn test_get_bounds_check() {
        let mat = sample_3x4();
        assert_eq!(mat.get(0, 0), Some(1.0));
        assert_eq!(mat.get(2, 3), Some(12.0));
        assert_eq!(mat.get(3, 0), None); // row out of bounds
        assert_eq!(mat.get(0, 4), None); // column out of bounds
    }

    #[test]
    fn test_set_bounds_check() {
        let mut mat = sample_3x4();
        assert!(mat.set(1, 1, 99.0));
        assert_eq!(mat[(1, 1)], 99.0);
        assert!(!mat.set(5, 0, 99.0)); // row out of bounds
    }

    #[test]
    fn test_set_out_of_bounds_leaves_matrix_untouched() {
        let mut mat = sample_3x4();
        assert!(!mat.set(0, 4, 1.0)); // column out of bounds
        assert!(!mat.set(3, 0, 1.0)); // row out of bounds
        assert_eq!(mat, sample_3x4());
    }

    #[test]
    fn test_nalgebra_roundtrip() {
        let mat = sample_3x4();
        let dmat = mat.to_dmatrix();
        assert_eq!(dmat.nrows(), 3);
        assert_eq!(dmat.ncols(), 4);
        assert_eq!(dmat[(0, 0)], 1.0);
        assert_eq!(dmat[(1, 2)], 8.0);

        let back = FdMatrix::from_dmatrix(&dmat);
        assert_eq!(mat, back);
    }

    #[test]
    fn test_nalgebra_roundtrip_empty() {
        let mat = FdMatrix::zeros(0, 3);
        let back = FdMatrix::from_dmatrix(&mat.to_dmatrix());
        assert_eq!(back.shape(), (0, 3));
        assert_eq!(mat, back);
    }

    #[test]
    fn test_empty() {
        let mat = FdMatrix::zeros(0, 0);
        assert!(mat.is_empty());
        assert_eq!(mat.len(), 0);
    }

    #[test]
    fn test_single_element() {
        let mat = FdMatrix::from_column_major(vec![42.0], 1, 1).unwrap();
        assert_eq!(mat[(0, 0)], 42.0);
        assert_eq!(mat.column(0), &[42.0]);
        assert_eq!(mat.row(0), vec![42.0]);
    }

    #[test]
    fn test_display() {
        let mat = sample_3x4();
        assert_eq!(format!("{}", mat), "FdMatrix(3x4)");
    }

    #[test]
    fn test_clone() {
        let mat = sample_3x4();
        let cloned = mat.clone();
        assert_eq!(mat, cloned);
    }

    #[test]
    fn test_as_mut_slice() {
        let mut mat = FdMatrix::zeros(2, 2);
        let s = mat.as_mut_slice();
        s[0] = 1.0;
        s[1] = 2.0;
        s[2] = 3.0;
        s[3] = 4.0;
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
        assert_eq!(mat[(1, 1)], 4.0);
    }

    #[test]
    fn test_column_major_layout_matches_manual() {
        // Non-square shape so a row/column mix-up in the offset formula is caught.
        let n = 5;
        let m = 7;
        let data: Vec<f64> = (0..n * m).map(|x| x as f64).collect();
        let mat = FdMatrix::from_column_major(data.clone(), n, m).unwrap();

        for j in 0..m {
            for i in 0..n {
                assert_eq!(mat[(i, j)], data[i + j * n]);
            }
        }
    }
}