1#[cfg(not(feature = "std"))]
22use alloc::vec::Vec;
23
24use oxiblas_core::memory::AlignedVec;
25use oxiblas_core::scalar::Scalar;
26
/// Selects which triangle of a square matrix is kept in packed storage.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TriangularKind {
    /// Elements on or above the main diagonal (`row <= col`) are stored.
    Upper,
    /// Elements on or below the main diagonal (`row >= col`) are stored.
    Lower,
}
35
36#[derive(Clone)]
62pub struct PackedMat<T: Scalar> {
63 data: AlignedVec<T>,
65 n: usize,
67 kind: TriangularKind,
69}
70
71impl<T: Scalar> PackedMat<T> {
72 pub fn zeros(n: usize, kind: TriangularKind) -> Self
74 where
75 T: bytemuck::Zeroable,
76 {
77 let len = Self::packed_len(n);
78 PackedMat {
79 data: AlignedVec::zeros(len),
80 n,
81 kind,
82 }
83 }
84
85 pub fn filled(n: usize, kind: TriangularKind, value: T) -> Self {
87 let len = Self::packed_len(n);
88 PackedMat {
89 data: AlignedVec::filled(len, value),
90 n,
91 kind,
92 }
93 }
94
95 pub fn from_slice(n: usize, kind: TriangularKind, data: &[T]) -> Self {
100 let len = Self::packed_len(n);
101 assert_eq!(
102 data.len(),
103 len,
104 "Slice length must equal n*(n+1)/2 = {}",
105 len
106 );
107
108 PackedMat {
109 data: AlignedVec::from_slice(data),
110 n,
111 kind,
112 }
113 }
114
115 #[inline]
117 pub const fn packed_len(n: usize) -> usize {
118 n * (n + 1) / 2
119 }
120
121 #[inline]
123 pub fn dim(&self) -> usize {
124 self.n
125 }
126
127 #[inline]
129 pub fn kind(&self) -> TriangularKind {
130 self.kind
131 }
132
133 #[inline]
135 pub fn len(&self) -> usize {
136 self.data.len()
137 }
138
139 #[inline]
141 pub fn is_empty(&self) -> bool {
142 self.n == 0
143 }
144
145 #[inline]
149 pub fn packed_index(&self, row: usize, col: usize) -> Option<usize> {
150 if row >= self.n || col >= self.n {
151 return None;
152 }
153
154 match self.kind {
155 TriangularKind::Upper => {
156 if row <= col {
157 Some(col * (col + 1) / 2 + row)
159 } else {
160 None
161 }
162 }
163 TriangularKind::Lower => {
164 if row >= col {
165 let offset = self.n * col - col * (col.saturating_sub(1)) / 2;
169 Some(offset + (row - col))
170 } else {
171 None
172 }
173 }
174 }
175 }
176
177 #[inline]
181 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
182 self.packed_index(row, col).map(|idx| &self.data[idx])
183 }
184
185 #[inline]
189 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
190 self.packed_index(row, col).map(|idx| &mut self.data[idx])
191 }
192
193 #[inline]
198 pub fn set(&mut self, row: usize, col: usize, value: T) {
199 let idx = self
200 .packed_index(row, col)
201 .expect("Element outside stored triangle");
202 self.data[idx] = value;
203 }
204
205 #[inline]
207 pub fn as_ptr(&self) -> *const T {
208 self.data.as_ptr()
209 }
210
211 #[inline]
213 pub fn as_mut_ptr(&mut self) -> *mut T {
214 self.data.as_mut_ptr()
215 }
216
217 #[inline]
219 pub fn as_slice(&self) -> &[T] {
220 self.data.as_slice()
221 }
222
223 #[inline]
225 pub fn as_slice_mut(&mut self) -> &mut [T] {
226 self.data.as_mut_slice()
227 }
228
229 pub fn to_dense(&self) -> crate::Mat<T>
231 where
232 T: bytemuck::Zeroable,
233 {
234 let mut mat = crate::Mat::zeros(self.n, self.n);
235
236 for j in 0..self.n {
237 for i in 0..self.n {
238 if let Some(idx) = self.packed_index(i, j) {
239 mat[(i, j)] = self.data[idx];
240 }
241 }
242 }
243
244 mat
245 }
246
247 pub fn from_dense(mat: &crate::MatRef<'_, T>, kind: TriangularKind) -> Self
251 where
252 T: bytemuck::Zeroable,
253 {
254 assert_eq!(mat.nrows(), mat.ncols(), "Matrix must be square");
255 let n = mat.nrows();
256 let mut packed = Self::zeros(n, kind);
257
258 for j in 0..n {
259 for i in 0..n {
260 if packed.packed_index(i, j).is_some() {
261 packed.set(i, j, mat[(i, j)]);
262 }
263 }
264 }
265
266 packed
267 }
268
269 pub fn diagonal(&self) -> Vec<T> {
271 (0..self.n)
272 .map(|i| {
273 *self
274 .get(i, i)
275 .expect("diagonal index should always be valid")
276 })
277 .collect()
278 }
279
280 pub fn set_diagonal(&mut self, diag: &[T]) {
282 assert_eq!(
283 diag.len(),
284 self.n,
285 "Diagonal length must match matrix dimension"
286 );
287 for (i, &val) in diag.iter().enumerate() {
288 self.set(i, i, val);
289 }
290 }
291
292 pub fn fill(&mut self, value: T) {
294 for elem in self.data.as_mut_slice() {
295 *elem = value;
296 }
297 }
298
299 pub fn scale(&mut self, alpha: T) {
301 for elem in self.data.as_mut_slice() {
302 *elem *= alpha;
303 }
304 }
305
306 pub fn transpose(&self) -> Self
310 where
311 T: bytemuck::Zeroable,
312 {
313 let new_kind = match self.kind {
314 TriangularKind::Upper => TriangularKind::Lower,
315 TriangularKind::Lower => TriangularKind::Upper,
316 };
317
318 let mut result = Self::zeros(self.n, new_kind);
319
320 for j in 0..self.n {
321 for i in 0..self.n {
322 if let Some(src_idx) = self.packed_index(i, j) {
323 if result.packed_index(j, i).is_some() {
325 result.set(j, i, self.data[src_idx]);
326 }
327 }
328 }
329 }
330
331 result
332 }
333}
334
335impl<T: Scalar + core::fmt::Debug> core::fmt::Debug for PackedMat<T> {
336 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
337 writeln!(f, "PackedMat {}x{} {:?} {{", self.n, self.n, self.kind)?;
338
339 for i in 0..self.n.min(8) {
340 write!(f, " [")?;
341 for j in 0..self.n.min(8) {
342 if j > 0 {
343 write!(f, ", ")?;
344 }
345 match self.get(i, j) {
346 Some(v) => write!(f, "{:8.4?}", v)?,
347 None => write!(f, " * ")?,
348 }
349 }
350 if self.n > 8 {
351 write!(f, ", ...")?;
352 }
353 writeln!(f, "]")?;
354 }
355 if self.n > 8 {
356 writeln!(f, " ...")?;
357 }
358 write!(f, "}}")
359 }
360}
361
362#[derive(Clone, Copy)]
364pub struct PackedRef<'a, T: Scalar> {
365 ptr: *const T,
367 n: usize,
369 kind: TriangularKind,
371 _marker: core::marker::PhantomData<&'a T>,
373}
374
375impl<'a, T: Scalar> PackedRef<'a, T> {
376 #[inline]
378 pub fn new(ptr: *const T, n: usize, kind: TriangularKind) -> Self {
379 PackedRef {
380 ptr,
381 n,
382 kind,
383 _marker: core::marker::PhantomData,
384 }
385 }
386
387 #[inline]
389 pub fn from_slice(data: &'a [T], n: usize, kind: TriangularKind) -> Self {
390 let expected_len = PackedMat::<T>::packed_len(n);
391 assert_eq!(
392 data.len(),
393 expected_len,
394 "Slice length must equal n*(n+1)/2"
395 );
396 PackedRef::new(data.as_ptr(), n, kind)
397 }
398
399 #[inline]
401 pub fn dim(&self) -> usize {
402 self.n
403 }
404
405 #[inline]
407 pub fn kind(&self) -> TriangularKind {
408 self.kind
409 }
410
411 #[inline]
413 pub fn packed_index(&self, row: usize, col: usize) -> Option<usize> {
414 if row >= self.n || col >= self.n {
415 return None;
416 }
417
418 match self.kind {
419 TriangularKind::Upper => {
420 if row <= col {
421 Some(col * (col + 1) / 2 + row)
422 } else {
423 None
424 }
425 }
426 TriangularKind::Lower => {
427 if row >= col {
428 let offset = self.n * col - col * (col.saturating_sub(1)) / 2;
432 Some(offset + (row - col))
433 } else {
434 None
435 }
436 }
437 }
438 }
439
440 #[inline]
442 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
443 self.packed_index(row, col)
444 .map(|idx| unsafe { &*self.ptr.add(idx) })
445 }
446
447 #[inline]
449 pub fn as_ptr(&self) -> *const T {
450 self.ptr
451 }
452}
453
454unsafe impl<'a, T: Scalar + Send> Send for PackedRef<'a, T> {}
455unsafe impl<'a, T: Scalar + Sync> Sync for PackedRef<'a, T> {}
456
457pub struct PackedMut<'a, T: Scalar> {
459 ptr: *mut T,
461 n: usize,
463 kind: TriangularKind,
465 _marker: core::marker::PhantomData<&'a mut T>,
467}
468
469impl<'a, T: Scalar> PackedMut<'a, T> {
470 #[inline]
472 pub fn new(ptr: *mut T, n: usize, kind: TriangularKind) -> Self {
473 PackedMut {
474 ptr,
475 n,
476 kind,
477 _marker: core::marker::PhantomData,
478 }
479 }
480
481 #[inline]
483 pub fn from_slice(data: &'a mut [T], n: usize, kind: TriangularKind) -> Self {
484 let expected_len = PackedMat::<T>::packed_len(n);
485 assert_eq!(
486 data.len(),
487 expected_len,
488 "Slice length must equal n*(n+1)/2"
489 );
490 PackedMut::new(data.as_mut_ptr(), n, kind)
491 }
492
493 #[inline]
495 pub fn dim(&self) -> usize {
496 self.n
497 }
498
499 #[inline]
501 pub fn kind(&self) -> TriangularKind {
502 self.kind
503 }
504
505 #[inline]
507 pub fn packed_index(&self, row: usize, col: usize) -> Option<usize> {
508 if row >= self.n || col >= self.n {
509 return None;
510 }
511
512 match self.kind {
513 TriangularKind::Upper => {
514 if row <= col {
515 Some(col * (col + 1) / 2 + row)
516 } else {
517 None
518 }
519 }
520 TriangularKind::Lower => {
521 if row >= col {
522 let offset = self.n * col - col * (col.saturating_sub(1)) / 2;
526 Some(offset + (row - col))
527 } else {
528 None
529 }
530 }
531 }
532 }
533
534 #[inline]
536 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
537 self.packed_index(row, col)
538 .map(|idx| unsafe { &*self.ptr.add(idx) })
539 }
540
541 #[inline]
543 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
544 self.packed_index(row, col)
545 .map(|idx| unsafe { &mut *self.ptr.add(idx) })
546 }
547
548 #[inline]
550 pub fn set(&mut self, row: usize, col: usize, value: T) {
551 let idx = self
552 .packed_index(row, col)
553 .expect("Element outside stored triangle");
554 unsafe {
555 *self.ptr.add(idx) = value;
556 }
557 }
558
559 #[inline]
561 pub fn as_ptr(&self) -> *const T {
562 self.ptr
563 }
564
565 #[inline]
567 pub fn as_mut_ptr(&mut self) -> *mut T {
568 self.ptr
569 }
570
571 #[inline]
573 pub fn rb(&self) -> PackedRef<'_, T> {
574 PackedRef::new(self.ptr, self.n, self.kind)
575 }
576
577 #[inline]
579 pub fn rb_mut(&mut self) -> PackedMut<'_, T> {
580 PackedMut::new(self.ptr, self.n, self.kind)
581 }
582}
583
584unsafe impl<'a, T: Scalar + Send> Send for PackedMut<'a, T> {}
585unsafe impl<'a, T: Scalar + Sync> Sync for PackedMut<'a, T> {}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// `(row, col, value)` triples for a 3x3 upper triangle, listed in packed order.
    const UPPER_3X3: [(usize, usize, f64); 6] = [
        (0, 0, 1.0),
        (0, 1, 2.0),
        (1, 1, 3.0),
        (0, 2, 4.0),
        (1, 2, 5.0),
        (2, 2, 6.0),
    ];

    /// `(row, col, value)` triples for a 3x3 lower triangle, listed in packed order.
    const LOWER_3X3: [(usize, usize, f64); 6] = [
        (0, 0, 1.0),
        (1, 0, 2.0),
        (2, 0, 3.0),
        (1, 1, 4.0),
        (2, 1, 5.0),
        (2, 2, 6.0),
    ];

    /// Writes every listed entry into `p` through `set`.
    fn set_all(p: &mut PackedMat<f64>, entries: &[(usize, usize, f64)]) {
        for &(row, col, value) in entries {
            p.set(row, col, value);
        }
    }

    /// Checks that `get` returns every listed entry.
    fn assert_entries(p: &PackedMat<f64>, entries: &[(usize, usize, f64)]) {
        for &(row, col, value) in entries {
            assert_eq!(p.get(row, col), Some(&value));
        }
    }

    #[test]
    fn test_packed_upper_indexing() {
        let mut p: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Upper);

        // Stored positions map to consecutive packed offsets.
        for (expected, &(row, col, _)) in UPPER_3X3.iter().enumerate() {
            assert_eq!(p.packed_index(row, col), Some(expected));
        }

        // Strictly-lower positions have no storage.
        for &(row, col) in &[(1, 0), (2, 0), (2, 1)] {
            assert_eq!(p.packed_index(row, col), None);
        }

        set_all(&mut p, &UPPER_3X3);
        assert_entries(&p, &UPPER_3X3);
    }

    #[test]
    fn test_packed_lower_indexing() {
        let mut p: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Lower);

        // Stored positions map to consecutive packed offsets.
        for (expected, &(row, col, _)) in LOWER_3X3.iter().enumerate() {
            assert_eq!(p.packed_index(row, col), Some(expected));
        }

        // Strictly-upper positions have no storage.
        for &(row, col) in &[(0, 1), (0, 2), (1, 2)] {
            assert_eq!(p.packed_index(row, col), None);
        }

        set_all(&mut p, &LOWER_3X3);
        assert_entries(&p, &LOWER_3X3);
    }

    #[test]
    fn test_packed_len() {
        let cases: [(usize, usize); 6] = [(0, 0), (1, 1), (2, 3), (3, 6), (4, 10), (10, 55)];
        for &(n, expected) in &cases {
            assert_eq!(PackedMat::<f64>::packed_len(n), expected);
        }
    }

    #[test]
    fn test_packed_to_dense() {
        let mut p: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Upper);
        set_all(&mut p, &UPPER_3X3);

        let dense = p.to_dense();
        for &(row, col, value) in &UPPER_3X3 {
            assert_eq!(dense[(row, col)], value);
        }

        // The triangle without storage expands to zeros.
        for &(row, col) in &[(1, 0), (2, 0), (2, 1)] {
            assert_eq!(dense[(row, col)], 0.0);
        }
    }

    #[test]
    fn test_packed_from_dense() {
        use crate::Mat;

        let dense = Mat::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);

        let upper = PackedMat::from_dense(&dense.as_ref(), TriangularKind::Upper);
        assert_entries(
            &upper,
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (0, 2, 3.0),
                (1, 1, 5.0),
                (1, 2, 6.0),
                (2, 2, 9.0),
            ],
        );

        let lower = PackedMat::from_dense(&dense.as_ref(), TriangularKind::Lower);
        assert_entries(
            &lower,
            &[
                (0, 0, 1.0),
                (1, 0, 4.0),
                (2, 0, 7.0),
                (1, 1, 5.0),
                (2, 1, 8.0),
                (2, 2, 9.0),
            ],
        );
    }

    #[test]
    fn test_packed_diagonal() {
        let mut p: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Upper);
        set_all(
            &mut p,
            &[
                (0, 0, 1.0),
                (0, 1, 10.0),
                (1, 1, 2.0),
                (0, 2, 20.0),
                (1, 2, 30.0),
                (2, 2, 3.0),
            ],
        );

        assert_eq!(p.diagonal(), [1.0, 2.0, 3.0]);

        p.set_diagonal(&[10.0, 20.0, 30.0]);
        assert_eq!(p.diagonal(), [10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_packed_transpose() {
        let mut upper: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Upper);
        set_all(&mut upper, &UPPER_3X3);

        let lower = upper.transpose();
        assert_eq!(lower.kind(), TriangularKind::Lower);

        // Every upper entry (i, j) must show up at (j, i) of the transpose.
        assert_entries(
            &lower,
            &[
                (0, 0, 1.0),
                (1, 0, 2.0),
                (1, 1, 3.0),
                (2, 0, 4.0),
                (2, 1, 5.0),
                (2, 2, 6.0),
            ],
        );
    }

    #[test]
    fn test_packed_ref() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let pref = PackedRef::from_slice(&data, 3, TriangularKind::Upper);

        assert_eq!(pref.dim(), 3);
        for &(row, col, value) in &UPPER_3X3 {
            assert_eq!(pref.get(row, col), Some(&value));
        }
    }

    #[test]
    fn test_packed_mut() {
        let mut data = [0.0f64; 6];
        let mut pmut = PackedMut::from_slice(&mut data, 3, TriangularKind::Lower);

        for &(row, col, value) in &LOWER_3X3 {
            pmut.set(row, col, value);
        }

        // Writes through the view land in the backing array in packed order.
        assert_eq!(data, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_packed_scale() {
        let mut p: PackedMat<f64> = PackedMat::zeros(2, TriangularKind::Upper);
        set_all(&mut p, &[(0, 0, 1.0), (0, 1, 2.0), (1, 1, 3.0)]);

        p.scale(2.0);

        assert_entries(&p, &[(0, 0, 2.0), (0, 1, 4.0), (1, 1, 6.0)]);
    }
}