1use oxiblas_core::memory::AlignedVec;
22use oxiblas_core::scalar::Scalar;
23
/// Which triangle of a square matrix is kept in packed storage.
///
/// The packed types in this module (`PackedMat`, `PackedRef`, `PackedMut`)
/// lay the stored triangle out column by column, diagonal included.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TriangularKind {
    /// Upper triangle: element `(row, col)` is stored iff `row <= col`.
    Upper,
    /// Lower triangle: element `(row, col)` is stored iff `row >= col`.
    Lower,
}
32
33#[derive(Clone)]
59pub struct PackedMat<T: Scalar> {
60 data: AlignedVec<T>,
62 n: usize,
64 kind: TriangularKind,
66}
67
68impl<T: Scalar> PackedMat<T> {
69 pub fn zeros(n: usize, kind: TriangularKind) -> Self
71 where
72 T: bytemuck::Zeroable,
73 {
74 let len = Self::packed_len(n);
75 PackedMat {
76 data: AlignedVec::zeros(len),
77 n,
78 kind,
79 }
80 }
81
82 pub fn filled(n: usize, kind: TriangularKind, value: T) -> Self {
84 let len = Self::packed_len(n);
85 PackedMat {
86 data: AlignedVec::filled(len, value),
87 n,
88 kind,
89 }
90 }
91
92 pub fn from_slice(n: usize, kind: TriangularKind, data: &[T]) -> Self {
97 let len = Self::packed_len(n);
98 assert_eq!(
99 data.len(),
100 len,
101 "Slice length must equal n*(n+1)/2 = {}",
102 len
103 );
104
105 PackedMat {
106 data: AlignedVec::from_slice(data),
107 n,
108 kind,
109 }
110 }
111
112 #[inline]
114 pub const fn packed_len(n: usize) -> usize {
115 n * (n + 1) / 2
116 }
117
118 #[inline]
120 pub fn dim(&self) -> usize {
121 self.n
122 }
123
124 #[inline]
126 pub fn kind(&self) -> TriangularKind {
127 self.kind
128 }
129
130 #[inline]
132 pub fn len(&self) -> usize {
133 self.data.len()
134 }
135
136 #[inline]
138 pub fn is_empty(&self) -> bool {
139 self.n == 0
140 }
141
142 #[inline]
146 pub fn packed_index(&self, row: usize, col: usize) -> Option<usize> {
147 if row >= self.n || col >= self.n {
148 return None;
149 }
150
151 match self.kind {
152 TriangularKind::Upper => {
153 if row <= col {
154 Some(col * (col + 1) / 2 + row)
156 } else {
157 None
158 }
159 }
160 TriangularKind::Lower => {
161 if row >= col {
162 let offset = self.n * col - col * (col.saturating_sub(1)) / 2;
166 Some(offset + (row - col))
167 } else {
168 None
169 }
170 }
171 }
172 }
173
174 #[inline]
178 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
179 self.packed_index(row, col).map(|idx| &self.data[idx])
180 }
181
182 #[inline]
186 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
187 self.packed_index(row, col).map(|idx| &mut self.data[idx])
188 }
189
190 #[inline]
195 pub fn set(&mut self, row: usize, col: usize, value: T) {
196 let idx = self
197 .packed_index(row, col)
198 .expect("Element outside stored triangle");
199 self.data[idx] = value;
200 }
201
202 #[inline]
204 pub fn as_ptr(&self) -> *const T {
205 self.data.as_ptr()
206 }
207
208 #[inline]
210 pub fn as_mut_ptr(&mut self) -> *mut T {
211 self.data.as_mut_ptr()
212 }
213
214 #[inline]
216 pub fn as_slice(&self) -> &[T] {
217 self.data.as_slice()
218 }
219
220 #[inline]
222 pub fn as_slice_mut(&mut self) -> &mut [T] {
223 self.data.as_mut_slice()
224 }
225
226 pub fn to_dense(&self) -> crate::Mat<T>
228 where
229 T: bytemuck::Zeroable,
230 {
231 let mut mat = crate::Mat::zeros(self.n, self.n);
232
233 for j in 0..self.n {
234 for i in 0..self.n {
235 if let Some(idx) = self.packed_index(i, j) {
236 mat[(i, j)] = self.data[idx];
237 }
238 }
239 }
240
241 mat
242 }
243
244 pub fn from_dense(mat: &crate::MatRef<'_, T>, kind: TriangularKind) -> Self
248 where
249 T: bytemuck::Zeroable,
250 {
251 assert_eq!(mat.nrows(), mat.ncols(), "Matrix must be square");
252 let n = mat.nrows();
253 let mut packed = Self::zeros(n, kind);
254
255 for j in 0..n {
256 for i in 0..n {
257 if packed.packed_index(i, j).is_some() {
258 packed.set(i, j, mat[(i, j)]);
259 }
260 }
261 }
262
263 packed
264 }
265
266 pub fn diagonal(&self) -> Vec<T> {
268 (0..self.n).map(|i| *self.get(i, i).unwrap()).collect()
269 }
270
271 pub fn set_diagonal(&mut self, diag: &[T]) {
273 assert_eq!(
274 diag.len(),
275 self.n,
276 "Diagonal length must match matrix dimension"
277 );
278 for (i, &val) in diag.iter().enumerate() {
279 self.set(i, i, val);
280 }
281 }
282
283 pub fn fill(&mut self, value: T) {
285 for elem in self.data.as_mut_slice() {
286 *elem = value;
287 }
288 }
289
290 pub fn scale(&mut self, alpha: T) {
292 for elem in self.data.as_mut_slice() {
293 *elem *= alpha;
294 }
295 }
296
297 pub fn transpose(&self) -> Self
301 where
302 T: bytemuck::Zeroable,
303 {
304 let new_kind = match self.kind {
305 TriangularKind::Upper => TriangularKind::Lower,
306 TriangularKind::Lower => TriangularKind::Upper,
307 };
308
309 let mut result = Self::zeros(self.n, new_kind);
310
311 for j in 0..self.n {
312 for i in 0..self.n {
313 if let Some(src_idx) = self.packed_index(i, j) {
314 if result.packed_index(j, i).is_some() {
316 result.set(j, i, self.data[src_idx]);
317 }
318 }
319 }
320 }
321
322 result
323 }
324}
325
326impl<T: Scalar + core::fmt::Debug> core::fmt::Debug for PackedMat<T> {
327 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
328 writeln!(f, "PackedMat {}x{} {:?} {{", self.n, self.n, self.kind)?;
329
330 for i in 0..self.n.min(8) {
331 write!(f, " [")?;
332 for j in 0..self.n.min(8) {
333 if j > 0 {
334 write!(f, ", ")?;
335 }
336 match self.get(i, j) {
337 Some(v) => write!(f, "{:8.4?}", v)?,
338 None => write!(f, " * ")?,
339 }
340 }
341 if self.n > 8 {
342 write!(f, ", ...")?;
343 }
344 writeln!(f, "]")?;
345 }
346 if self.n > 8 {
347 writeln!(f, " ...")?;
348 }
349 write!(f, "}}")
350 }
351}
352
353#[derive(Clone, Copy)]
355pub struct PackedRef<'a, T: Scalar> {
356 ptr: *const T,
358 n: usize,
360 kind: TriangularKind,
362 _marker: core::marker::PhantomData<&'a T>,
364}
365
366impl<'a, T: Scalar> PackedRef<'a, T> {
367 #[inline]
369 pub fn new(ptr: *const T, n: usize, kind: TriangularKind) -> Self {
370 PackedRef {
371 ptr,
372 n,
373 kind,
374 _marker: core::marker::PhantomData,
375 }
376 }
377
378 #[inline]
380 pub fn from_slice(data: &'a [T], n: usize, kind: TriangularKind) -> Self {
381 let expected_len = PackedMat::<T>::packed_len(n);
382 assert_eq!(
383 data.len(),
384 expected_len,
385 "Slice length must equal n*(n+1)/2"
386 );
387 PackedRef::new(data.as_ptr(), n, kind)
388 }
389
390 #[inline]
392 pub fn dim(&self) -> usize {
393 self.n
394 }
395
396 #[inline]
398 pub fn kind(&self) -> TriangularKind {
399 self.kind
400 }
401
402 #[inline]
404 pub fn packed_index(&self, row: usize, col: usize) -> Option<usize> {
405 if row >= self.n || col >= self.n {
406 return None;
407 }
408
409 match self.kind {
410 TriangularKind::Upper => {
411 if row <= col {
412 Some(col * (col + 1) / 2 + row)
413 } else {
414 None
415 }
416 }
417 TriangularKind::Lower => {
418 if row >= col {
419 let offset = self.n * col - col * (col.saturating_sub(1)) / 2;
423 Some(offset + (row - col))
424 } else {
425 None
426 }
427 }
428 }
429 }
430
431 #[inline]
433 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
434 self.packed_index(row, col)
435 .map(|idx| unsafe { &*self.ptr.add(idx) })
436 }
437
438 #[inline]
440 pub fn as_ptr(&self) -> *const T {
441 self.ptr
442 }
443}
444
445unsafe impl<'a, T: Scalar + Send> Send for PackedRef<'a, T> {}
446unsafe impl<'a, T: Scalar + Sync> Sync for PackedRef<'a, T> {}
447
448pub struct PackedMut<'a, T: Scalar> {
450 ptr: *mut T,
452 n: usize,
454 kind: TriangularKind,
456 _marker: core::marker::PhantomData<&'a mut T>,
458}
459
460impl<'a, T: Scalar> PackedMut<'a, T> {
461 #[inline]
463 pub fn new(ptr: *mut T, n: usize, kind: TriangularKind) -> Self {
464 PackedMut {
465 ptr,
466 n,
467 kind,
468 _marker: core::marker::PhantomData,
469 }
470 }
471
472 #[inline]
474 pub fn from_slice(data: &'a mut [T], n: usize, kind: TriangularKind) -> Self {
475 let expected_len = PackedMat::<T>::packed_len(n);
476 assert_eq!(
477 data.len(),
478 expected_len,
479 "Slice length must equal n*(n+1)/2"
480 );
481 PackedMut::new(data.as_mut_ptr(), n, kind)
482 }
483
484 #[inline]
486 pub fn dim(&self) -> usize {
487 self.n
488 }
489
490 #[inline]
492 pub fn kind(&self) -> TriangularKind {
493 self.kind
494 }
495
496 #[inline]
498 pub fn packed_index(&self, row: usize, col: usize) -> Option<usize> {
499 if row >= self.n || col >= self.n {
500 return None;
501 }
502
503 match self.kind {
504 TriangularKind::Upper => {
505 if row <= col {
506 Some(col * (col + 1) / 2 + row)
507 } else {
508 None
509 }
510 }
511 TriangularKind::Lower => {
512 if row >= col {
513 let offset = self.n * col - col * (col.saturating_sub(1)) / 2;
517 Some(offset + (row - col))
518 } else {
519 None
520 }
521 }
522 }
523 }
524
525 #[inline]
527 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
528 self.packed_index(row, col)
529 .map(|idx| unsafe { &*self.ptr.add(idx) })
530 }
531
532 #[inline]
534 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
535 self.packed_index(row, col)
536 .map(|idx| unsafe { &mut *self.ptr.add(idx) })
537 }
538
539 #[inline]
541 pub fn set(&mut self, row: usize, col: usize, value: T) {
542 let idx = self
543 .packed_index(row, col)
544 .expect("Element outside stored triangle");
545 unsafe {
546 *self.ptr.add(idx) = value;
547 }
548 }
549
550 #[inline]
552 pub fn as_ptr(&self) -> *const T {
553 self.ptr
554 }
555
556 #[inline]
558 pub fn as_mut_ptr(&mut self) -> *mut T {
559 self.ptr
560 }
561
562 #[inline]
564 pub fn rb(&self) -> PackedRef<'_, T> {
565 PackedRef::new(self.ptr, self.n, self.kind)
566 }
567
568 #[inline]
570 pub fn rb_mut(&mut self) -> PackedMut<'_, T> {
571 PackedMut::new(self.ptr, self.n, self.kind)
572 }
573}
574
575unsafe impl<'a, T: Scalar + Send> Send for PackedMut<'a, T> {}
576unsafe impl<'a, T: Scalar + Sync> Sync for PackedMut<'a, T> {}
577
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_packed_upper_indexing() {
        let mut p: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Upper);

        assert_eq!(p.packed_index(0, 0), Some(0));
        assert_eq!(p.packed_index(0, 1), Some(1));
        assert_eq!(p.packed_index(1, 1), Some(2));
        assert_eq!(p.packed_index(0, 2), Some(3));
        assert_eq!(p.packed_index(1, 2), Some(4));
        assert_eq!(p.packed_index(2, 2), Some(5));

        assert_eq!(p.packed_index(1, 0), None);
        assert_eq!(p.packed_index(2, 0), None);
        assert_eq!(p.packed_index(2, 1), None);

        p.set(0, 0, 1.0);
        p.set(0, 1, 2.0);
        p.set(1, 1, 3.0);
        p.set(0, 2, 4.0);
        p.set(1, 2, 5.0);
        p.set(2, 2, 6.0);

        assert_eq!(p.get(0, 0), Some(&1.0));
        assert_eq!(p.get(0, 1), Some(&2.0));
        assert_eq!(p.get(1, 1), Some(&3.0));
        assert_eq!(p.get(0, 2), Some(&4.0));
        assert_eq!(p.get(1, 2), Some(&5.0));
        assert_eq!(p.get(2, 2), Some(&6.0));
    }

    #[test]
    fn test_packed_lower_indexing() {
        let mut p: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Lower);

        assert_eq!(p.packed_index(0, 0), Some(0));
        assert_eq!(p.packed_index(1, 0), Some(1));
        assert_eq!(p.packed_index(2, 0), Some(2));
        assert_eq!(p.packed_index(1, 1), Some(3));
        assert_eq!(p.packed_index(2, 1), Some(4));
        assert_eq!(p.packed_index(2, 2), Some(5));

        assert_eq!(p.packed_index(0, 1), None);
        assert_eq!(p.packed_index(0, 2), None);
        assert_eq!(p.packed_index(1, 2), None);

        p.set(0, 0, 1.0);
        p.set(1, 0, 2.0);
        p.set(2, 0, 3.0);
        p.set(1, 1, 4.0);
        p.set(2, 1, 5.0);
        p.set(2, 2, 6.0);

        assert_eq!(p.get(0, 0), Some(&1.0));
        assert_eq!(p.get(1, 0), Some(&2.0));
        assert_eq!(p.get(2, 0), Some(&3.0));
        assert_eq!(p.get(1, 1), Some(&4.0));
        assert_eq!(p.get(2, 1), Some(&5.0));
        assert_eq!(p.get(2, 2), Some(&6.0));
    }

    /// For several sizes, walking columns in order and rows ascending must
    /// visit packed indices 0, 1, 2, ... with no gaps, covering `len()`.
    #[test]
    fn test_packed_index_is_dense_and_sequential() {
        for &kind in &[TriangularKind::Upper, TriangularKind::Lower] {
            for n in 1..6 {
                let p: PackedMat<f64> = PackedMat::zeros(n, kind);
                let mut expected = 0;
                for j in 0..n {
                    for i in 0..n {
                        let stored = match kind {
                            TriangularKind::Upper => i <= j,
                            TriangularKind::Lower => i >= j,
                        };
                        if stored {
                            assert_eq!(p.packed_index(i, j), Some(expected));
                            expected += 1;
                        } else {
                            assert_eq!(p.packed_index(i, j), None);
                        }
                    }
                }
                assert_eq!(expected, p.len());
            }
        }
    }

    #[test]
    fn test_out_of_range_is_none() {
        let mut p: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Lower);
        assert_eq!(p.packed_index(3, 0), None);
        assert_eq!(p.packed_index(0, 3), None);
        assert_eq!(p.get(3, 3), None);
        assert!(p.get_mut(5, 1).is_none());
    }

    #[test]
    #[should_panic(expected = "Element outside stored triangle")]
    fn test_set_outside_triangle_panics() {
        let mut p: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Upper);
        p.set(2, 0, 1.0);
    }

    #[test]
    #[should_panic]
    fn test_from_slice_wrong_len_panics() {
        let _ = PackedMat::<f64>::from_slice(3, TriangularKind::Upper, &[1.0; 5]);
    }

    #[test]
    fn test_packed_len() {
        assert_eq!(PackedMat::<f64>::packed_len(0), 0);
        assert_eq!(PackedMat::<f64>::packed_len(1), 1);
        assert_eq!(PackedMat::<f64>::packed_len(2), 3);
        assert_eq!(PackedMat::<f64>::packed_len(3), 6);
        assert_eq!(PackedMat::<f64>::packed_len(4), 10);
        assert_eq!(PackedMat::<f64>::packed_len(10), 55);
    }

    #[test]
    fn test_packed_to_dense() {
        let mut p: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Upper);
        p.set(0, 0, 1.0);
        p.set(0, 1, 2.0);
        p.set(1, 1, 3.0);
        p.set(0, 2, 4.0);
        p.set(1, 2, 5.0);
        p.set(2, 2, 6.0);

        let dense = p.to_dense();
        assert_eq!(dense[(0, 0)], 1.0);
        assert_eq!(dense[(0, 1)], 2.0);
        assert_eq!(dense[(1, 1)], 3.0);
        assert_eq!(dense[(0, 2)], 4.0);
        assert_eq!(dense[(1, 2)], 5.0);
        assert_eq!(dense[(2, 2)], 6.0);

        assert_eq!(dense[(1, 0)], 0.0);
        assert_eq!(dense[(2, 0)], 0.0);
        assert_eq!(dense[(2, 1)], 0.0);
    }

    #[test]
    fn test_packed_lower_to_dense() {
        let p: PackedMat<f64> =
            PackedMat::from_slice(3, TriangularKind::Lower, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let dense = p.to_dense();
        assert_eq!(dense[(0, 0)], 1.0);
        assert_eq!(dense[(1, 0)], 2.0);
        assert_eq!(dense[(2, 0)], 3.0);
        assert_eq!(dense[(1, 1)], 4.0);
        assert_eq!(dense[(2, 1)], 5.0);
        assert_eq!(dense[(2, 2)], 6.0);

        assert_eq!(dense[(0, 1)], 0.0);
        assert_eq!(dense[(0, 2)], 0.0);
        assert_eq!(dense[(1, 2)], 0.0);
    }

    #[test]
    fn test_packed_from_dense() {
        use crate::Mat;

        let dense = Mat::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);

        let upper = PackedMat::from_dense(&dense.as_ref(), TriangularKind::Upper);
        assert_eq!(upper.get(0, 0), Some(&1.0));
        assert_eq!(upper.get(0, 1), Some(&2.0));
        assert_eq!(upper.get(0, 2), Some(&3.0));
        assert_eq!(upper.get(1, 1), Some(&5.0));
        assert_eq!(upper.get(1, 2), Some(&6.0));
        assert_eq!(upper.get(2, 2), Some(&9.0));

        let lower = PackedMat::from_dense(&dense.as_ref(), TriangularKind::Lower);
        assert_eq!(lower.get(0, 0), Some(&1.0));
        assert_eq!(lower.get(1, 0), Some(&4.0));
        assert_eq!(lower.get(2, 0), Some(&7.0));
        assert_eq!(lower.get(1, 1), Some(&5.0));
        assert_eq!(lower.get(2, 1), Some(&8.0));
        assert_eq!(lower.get(2, 2), Some(&9.0));
    }

    #[test]
    fn test_packed_diagonal() {
        let mut p: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Upper);
        p.set(0, 0, 1.0);
        p.set(0, 1, 10.0);
        p.set(1, 1, 2.0);
        p.set(0, 2, 20.0);
        p.set(1, 2, 30.0);
        p.set(2, 2, 3.0);

        let diag = p.diagonal();
        assert_eq!(diag, vec![1.0, 2.0, 3.0]);

        p.set_diagonal(&[10.0, 20.0, 30.0]);
        let diag2 = p.diagonal();
        assert_eq!(diag2, vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_packed_fill_and_filled() {
        let mut p: PackedMat<f64> = PackedMat::filled(3, TriangularKind::Lower, 7.0);
        assert!(p.as_slice().iter().all(|&x| x == 7.0));

        p.fill(-1.0);
        assert_eq!(p.as_slice(), &[-1.0; 6]);
    }

    #[test]
    fn test_packed_transpose() {
        let mut upper: PackedMat<f64> = PackedMat::zeros(3, TriangularKind::Upper);
        upper.set(0, 0, 1.0);
        upper.set(0, 1, 2.0);
        upper.set(1, 1, 3.0);
        upper.set(0, 2, 4.0);
        upper.set(1, 2, 5.0);
        upper.set(2, 2, 6.0);

        let lower = upper.transpose();
        assert_eq!(lower.kind(), TriangularKind::Lower);

        assert_eq!(lower.get(0, 0), Some(&1.0));
        assert_eq!(lower.get(1, 0), Some(&2.0));
        assert_eq!(lower.get(1, 1), Some(&3.0));
        assert_eq!(lower.get(2, 0), Some(&4.0));
        assert_eq!(lower.get(2, 1), Some(&5.0));
        assert_eq!(lower.get(2, 2), Some(&6.0));
    }

    #[test]
    fn test_packed_transpose_lower_to_upper() {
        let lower: PackedMat<f64> =
            PackedMat::from_slice(3, TriangularKind::Lower, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let upper = lower.transpose();
        assert_eq!(upper.kind(), TriangularKind::Upper);
        assert_eq!(upper.get(0, 0), Some(&1.0));
        assert_eq!(upper.get(0, 1), Some(&2.0));
        assert_eq!(upper.get(0, 2), Some(&3.0));
        assert_eq!(upper.get(1, 1), Some(&4.0));
        assert_eq!(upper.get(1, 2), Some(&5.0));
        assert_eq!(upper.get(2, 2), Some(&6.0));

        // Transposing twice restores the original packed buffer.
        assert_eq!(upper.transpose().as_slice(), lower.as_slice());
    }

    #[test]
    fn test_packed_ref() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let pref = PackedRef::from_slice(&data, 3, TriangularKind::Upper);

        assert_eq!(pref.dim(), 3);
        assert_eq!(pref.get(0, 0), Some(&1.0));
        assert_eq!(pref.get(0, 1), Some(&2.0));
        assert_eq!(pref.get(1, 1), Some(&3.0));
        assert_eq!(pref.get(0, 2), Some(&4.0));
        assert_eq!(pref.get(1, 2), Some(&5.0));
        assert_eq!(pref.get(2, 2), Some(&6.0));
    }

    #[test]
    fn test_packed_mut() {
        let mut data = [0.0f64; 6];
        let mut pmut = PackedMut::from_slice(&mut data, 3, TriangularKind::Lower);

        pmut.set(0, 0, 1.0);
        pmut.set(1, 0, 2.0);
        pmut.set(2, 0, 3.0);
        pmut.set(1, 1, 4.0);
        pmut.set(2, 1, 5.0);
        pmut.set(2, 2, 6.0);

        assert_eq!(data, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_packed_mut_rb_and_get_mut() {
        let mut data = [0.0f64; 3];
        let mut pmut = PackedMut::from_slice(&mut data, 2, TriangularKind::Upper);

        pmut.set(0, 0, 1.0);
        *pmut.get_mut(0, 1).unwrap() = 2.0;
        pmut.rb_mut().set(1, 1, 3.0);
        assert_eq!(pmut.get(1, 0), None);

        let view = pmut.rb();
        assert_eq!(view.get(0, 1), Some(&2.0));
        assert_eq!(view.get(1, 1), Some(&3.0));

        assert_eq!(data, [1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_packed_scale() {
        let mut p: PackedMat<f64> = PackedMat::zeros(2, TriangularKind::Upper);
        p.set(0, 0, 1.0);
        p.set(0, 1, 2.0);
        p.set(1, 1, 3.0);

        p.scale(2.0);

        assert_eq!(p.get(0, 0), Some(&2.0));
        assert_eq!(p.get(0, 1), Some(&4.0));
        assert_eq!(p.get(1, 1), Some(&6.0));
    }
}