qudit_core/array/
symsq.rs

1use std::ptr::NonNull;
2// TODO: update faer imports to crate imports
3// TODO: Make helper methods for debug and display that extracts shared functionality from Tensor
4// TODO: add basic derives for clone, PartialEq, Debug and Display
5// TODO: Use strong typing where it makes sense
6// TODO: Add helpful, useful, succinct documentation with examples.
7use super::check_bounds;
8use faer::{MatMut, MatRef, RowMut, RowRef};
9
10use crate::{
11    array::TensorMut,
12    array::TensorRef,
13    memory::{Memorable, MemoryBuffer, alloc_zeroed_memory},
14};
15
/// Maps a `(row, col)` coordinate of a symmetric square matrix to its position
/// in packed upper-triangular storage.
///
/// Only the upper triangle (diagonal included) is stored. It is flattened
/// column by column, so column `j` contributes the `j + 1` entries
/// `(0, j), (1, j), ..., (j, j)`. For an `N x N` matrix the packed vector has
/// length `N * (N + 1) / 2`, and an upper-triangular coordinate `(i, j)` with
/// `i <= j` lands at
///
/// ```math
/// k = \frac{j (j + 1)}{2} + i
/// ```
///
/// Lower-triangular coordinates (`i > j`) are mirrored onto `(j, i)`, so
/// `coords_to_index(i, j) == coords_to_index(j, i)` for every `i` and `j`.
///
/// For example, with `N = 3` the packed order is
/// `(0,0), (0,1), (1,1), (0,2), (1,2), (2,2)`, mapping to `0, 1, 2, 3, 4, 5`.
#[inline(always)]
fn coords_to_index(i: usize, j: usize) -> usize {
    // Mirror lower-triangular coordinates onto the stored upper triangle.
    let (row, col) = if i > j { (j, i) } else { (i, j) };
    col * (col + 1) / 2 + row
}
40
41#[inline(always)]
42fn calculate_flat_index<const D: usize>(indices: &[usize; D], strides: &[usize; D]) -> usize {
43    let mut flat_idx = coords_to_index(indices[0], indices[1]) * strides[1];
44    for i in 2..D {
45        flat_idx += indices[i] * strides[i];
46    }
47    flat_idx
48}
49
50// Schwarz's Theorem is satisfied for quantum tensor networks
51//
52// TODO: when const generics can appear in const expressions, this can be rewritten better
53/// A tensor with D dimensions, where the first two dimensions are equal and symmetric.
54pub struct SymSqTensor<C: Memorable, const D: usize> {
55    data: MemoryBuffer<C>,
56    dims: [usize; D],
57    strides: [usize; D],
58}
59
60impl<C: Memorable, const D: usize> SymSqTensor<C, D> {
61    /// Creates a new symmetric square tensor with the given data, dimensions, and strides.
62    pub fn new(data: MemoryBuffer<C>, dims: [usize; D], strides: [usize; D]) -> Self {
63        assert!(
64            D >= 2,
65            "Symmetric square tensors must have 2 or more dimensions."
66        );
67        assert!(
68            dims[0] == dims[1],
69            "Symmetric square tensors must be square in their two major dimensions."
70        );
71        assert!(
72            strides[0] == strides[1] * dims[1],
73            "Symmetric square tensors must be continuous across their two major dimensions."
74        );
75        assert!(
76            dims.iter().all(|&d| d != 0),
77            "Cannot have a zero-length dimension."
78        );
79        assert!(
80            strides.iter().all(|&d| d != 0),
81            "Cannot have a zero-length stride."
82        );
83
84        let mut max_element = [0; D];
85        for (i, d) in dims.iter().enumerate() {
86            max_element[i] = d - 1;
87        }
88        let max_flat_index = calculate_flat_index(&max_element, &strides);
89
90        assert!(
91            data.len() >= max_flat_index,
92            "Data buffer is not large enough."
93        );
94
95        Self {
96            data,
97            dims,
98            strides,
99        }
100    }
101
102    /// Creates a new symmetric square tensor filled with zeros.
103    pub fn zeros(dims: [usize; D]) -> Self {
104        let strides = super::calc_continuous_strides(&dims);
105        let data = alloc_zeroed_memory::<C>(strides[0] * dims[0]);
106        Self::new(data, dims, strides)
107    }
108
109    /// Returns a reference to the dimensions of the tensor.
110    pub fn dims(&self) -> &[usize; D] {
111        &self.dims
112    }
113
114    /// Returns a reference to the strides of the tensor.
115    pub fn strides(&self) -> &[usize; D] {
116        &self.strides
117    }
118
119    /// Returns the rank (number of dimensions) of the tensor.
120    pub fn rank(&self) -> usize {
121        D
122    }
123
124    /// Returns the total number of elements in the tensor.
125    pub fn num_elements(&self) -> usize {
126        self.dims.iter().product()
127    }
128
129    /// Returns a raw pointer to the tensor's data.
130    pub fn as_ptr(&self) -> *const C {
131        self.data.as_ptr()
132    }
133
134    /// Returns a mutable raw pointer to the tensor's data.
135    pub fn as_ptr_mut(&mut self) -> *mut C {
136        self.data.as_mut_ptr()
137    }
138
139    /// Returns an immutable reference to the tensor.
140    pub fn as_ref(&self) -> SymSqTensorRef<'_, C, D> {
141        unsafe { SymSqTensorRef::from_raw_parts(self.data.as_ptr(), self.dims, self.strides) }
142    }
143
144    /// Returns a mutable reference to the tensor.
145    pub fn as_mut(&mut self) -> SymSqTensorMut<'_, C, D> {
146        unsafe { SymSqTensorMut::from_raw_parts(self.data.as_mut_ptr(), self.dims, self.strides) }
147    }
148
149    /// Returns a reference to an element at the given indices.
150    ///
151    /// # Panics
152    ///
153    /// Panics if the indices are out of bounds.
154    pub fn get(&self, indices: &[usize; D]) -> &C {
155        check_bounds(indices, &self.dims);
156        // Safety: bounds are checked by `check_bounds`
157        unsafe { self.get_unchecked(indices) }
158    }
159
160    /// Returns a mutable reference to an element at the given indices.
161    ///
162    /// # Panics
163    ///
164    /// Panics if the indices are out of bounds.
165    pub fn get_mut(&mut self, indices: &[usize; D]) -> &mut C {
166        check_bounds(indices, &self.dims);
167        // Safety: bounds are checked by `check_bounds`
168        unsafe { self.get_mut_unchecked(indices) }
169    }
170
171    /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
172    ///
173    /// # Safety
174    ///
175    /// Calling this method with out-of-bounds `indices` is undefined behavior.
176    #[inline(always)]
177    pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
178        unsafe { &*self.ptr_at(indices) }
179    }
180
181    /// Returns a mutable reference to an element at the given indices, without performing bounds checks.
182    ///
183    /// # Safety
184    ///
185    /// Calling this method with out-of-bounds `indices` is undefined behavior.
186    #[inline(always)]
187    pub unsafe fn get_mut_unchecked(&mut self, indices: &[usize; D]) -> &mut C {
188        unsafe { &mut *self.ptr_at_mut(indices) }
189    }
190
191    /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
192    ///
193    /// # Safety
194    ///
195    /// Calling this method with out-of-bounds `indices` is undefined behavior.
196    #[inline(always)]
197    pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
198        unsafe {
199            let flat_idx = calculate_flat_index(indices, &self.strides);
200            self.as_ptr().add(flat_idx)
201        }
202    }
203
204    /// Returns a mutable raw pointer to an element at the given indices, without performing bounds checks.
205    ///
206    /// # Safety
207    ///
208    /// Calling this method with out-of-bounds `indices` is undefined behavior.
209    #[inline(always)]
210    pub unsafe fn ptr_at_mut(&mut self, indices: &[usize; D]) -> *mut C {
211        unsafe {
212            let flat_idx = calculate_flat_index(indices, &self.strides);
213            self.as_ptr_mut().add(flat_idx)
214        }
215    }
216}
217
218impl<C: Memorable, const D: usize> std::ops::Index<[usize; D]> for SymSqTensor<C, D> {
219    type Output = C;
220
221    fn index(&self, indices: [usize; D]) -> &Self::Output {
222        self.get(&indices)
223    }
224}
225
226impl<C: Memorable, const D: usize> std::ops::IndexMut<[usize; D]> for SymSqTensor<C, D> {
227    fn index_mut(&mut self, indices: [usize; D]) -> &mut Self::Output {
228        self.get_mut(&indices)
229    }
230}
231
232#[derive(Clone, Copy)]
233/// An immutable reference to a symmetric square tensor.
234pub struct SymSqTensorRef<'a, C: Memorable, const D: usize> {
235    data: NonNull<C>,
236    dims: [usize; D],
237    strides: [usize; D],
238    __marker: std::marker::PhantomData<&'a C>,
239}
240
241impl<'a, C: Memorable, const D: usize> SymSqTensorRef<'a, C, D> {
242    /// Creates a new `SymSqTensorRef` from raw parts.
243    ///
244    /// # Safety
245    ///
246    /// The caller must ensure that `data` points to a valid memory block of `C` elements,
247    /// and that `dims` and `strides` accurately describe the layout of the tensor
248    /// within that memory block. The `data` pointer must be valid for the lifetime `'a`.
249    pub unsafe fn from_raw_parts(data: *const C, dims: [usize; D], strides: [usize; D]) -> Self {
250        unsafe {
251            // SAFETY: The pointer is never used in an mutable context.
252            let mut_ptr = data as *mut C;
253
254            Self {
255                data: NonNull::new_unchecked(mut_ptr),
256                dims,
257                strides,
258                __marker: std::marker::PhantomData,
259            }
260        }
261    }
262
263    /// Returns a reference to the dimensions of the tensor.
264    pub fn dims(&self) -> &[usize; D] {
265        &self.dims
266    }
267
268    /// Returns a reference to the strides of the tensor.
269    pub fn strides(&self) -> &[usize; D] {
270        &self.strides
271    }
272
273    /// Returns the rank (number of dimensions) of the tensor.
274    pub fn rank(&self) -> usize {
275        D
276    }
277
278    /// Returns the total number of elements in the tensor.
279    pub fn num_elements(&self) -> usize {
280        self.dims.iter().product()
281    }
282
283    /// Returns a raw pointer to the tensor's data.
284    pub fn as_ptr(&self) -> *const C {
285        self.data.as_ptr()
286    }
287
288    /// Returns a reference to an element at the given indices.
289    ///
290    /// # Panics
291    ///
292    /// Panics if the indices are out of bounds.
293    pub fn get(&self, indices: &[usize; D]) -> &C {
294        check_bounds(indices, &self.dims);
295        // Safety: bounds are checked by `check_bounds`
296        unsafe { self.get_unchecked(indices) }
297    }
298
299    /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
300    ///
301    /// # Safety
302    ///
303    /// Calling this method with out-of-bounds `indices` is undefined behavior.
304    #[inline(always)]
305    pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
306        unsafe { &*self.ptr_at(indices) }
307    }
308
309    /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
310    ///
311    /// # Safety
312    ///
313    /// Calling this method with out-of-bounds `indices` is undefined behavior.
314    #[inline(always)]
315    pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
316        unsafe {
317            let flat_idx = calculate_flat_index(indices, &self.strides);
318            self.as_ptr().add(flat_idx)
319        }
320    }
321}
322
323impl<'a, C: Memorable, const D: usize> std::ops::Index<[usize; D]> for SymSqTensorRef<'a, C, D> {
324    type Output = C;
325
326    fn index(&self, indices: [usize; D]) -> &Self::Output {
327        self.get(&indices)
328    }
329}
330
331/// A mutable reference to a symmetric square tensor.
332pub struct SymSqTensorMut<'a, C: Memorable, const D: usize> {
333    data: NonNull<C>,
334    dims: [usize; D],
335    strides: [usize; D],
336    __marker: std::marker::PhantomData<&'a mut C>,
337}
338
339impl<'a, C: Memorable, const D: usize> SymSqTensorMut<'a, C, D> {
340    /// Creates a new `SymSqTensorMut` from raw parts.
341    ///
342    /// # Safety
343    ///
344    /// The caller must ensure that `data` points to a valid memory block of `C` elements,
345    /// and that `dims` and `strides` accurately describe the layout of the tensor
346    /// within that memory block. The `data` pointer must be valid for the lifetime `'a`
347    /// and that it is safe to mutate the data.
348    pub unsafe fn from_raw_parts(data: *mut C, dims: [usize; D], strides: [usize; D]) -> Self {
349        unsafe {
350            Self {
351                data: NonNull::new_unchecked(data),
352                dims,
353                strides,
354                __marker: std::marker::PhantomData,
355            }
356        }
357    }
358
359    /// Returns a reference to the dimensions of the tensor.
360    pub fn dims(&self) -> &[usize; D] {
361        &self.dims
362    }
363
364    /// Returns a reference to the strides of the tensor.
365    pub fn strides(&self) -> &[usize; D] {
366        &self.strides
367    }
368
369    /// Returns the rank (number of dimensions) of the tensor.
370    pub fn rank(&self) -> usize {
371        D
372    }
373
374    /// Returns the total number of elements in the tensor.
375    pub fn num_elements(&self) -> usize {
376        self.dims.iter().product()
377    }
378
379    /// Returns a mutable raw pointer to the tensor's data.
380    pub fn as_ptr(&self) -> *const C {
381        self.data.as_ptr() as *const C
382    }
383
384    /// Returns a mutable raw pointer to the tensor's data.
385    pub fn as_ptr_mut(&mut self) -> *mut C {
386        self.data.as_ptr()
387    }
388
389    /// Returns a reference to an element at the given indices.
390    ///
391    /// # Panics
392    ///
393    /// Panics if the indices are out of bounds.
394    pub fn get(&self, indices: &[usize; D]) -> &C {
395        check_bounds(indices, &self.dims);
396        // Safety: bounds are checked by `check_bounds`
397        unsafe { self.get_unchecked(indices) }
398    }
399
400    /// Returns a mutable reference to an element at the given indices.
401    ///
402    /// # Panics
403    ///
404    /// Panics if the indices are out of bounds.
405    pub fn get_mut(&mut self, indices: &[usize; D]) -> &mut C {
406        check_bounds(indices, &self.dims);
407        // Safety: bounds are checked by `check_bounds`
408        unsafe { self.get_mut_unchecked(indices) }
409    }
410
411    /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
412    ///
413    /// # Safety
414    ///
415    /// Calling this method with out-of-bounds `indices` is undefined behavior.
416    #[inline(always)]
417    pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
418        unsafe { &*self.ptr_at(indices) }
419    }
420
421    /// Returns a mutable reference to an element at the given indices, without performing bounds checks.
422    ///
423    /// # Safety
424    ///
425    /// Calling this method with out-of-bounds `indices` is undefined behavior.
426    #[inline(always)]
427    pub unsafe fn get_mut_unchecked(&mut self, indices: &[usize; D]) -> &mut C {
428        unsafe { &mut *self.ptr_at_mut(indices) }
429    }
430
431    /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
432    ///
433    /// # Safety
434    ///
435    /// Calling this method with out-of-bounds `indices` is undefined behavior.
436    #[inline(always)]
437    pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
438        unsafe {
439            let flat_idx = calculate_flat_index(indices, &self.strides);
440            self.as_ptr().add(flat_idx)
441        }
442    }
443
444    /// Returns a mutable raw pointer to an element at the given indices, without performing bounds checks.
445    ///
446    /// # Safety
447    ///
448    /// Calling this method with out-of-bounds `indices` is undefined behavior.
449    #[inline(always)]
450    pub unsafe fn ptr_at_mut(&mut self, indices: &[usize; D]) -> *mut C {
451        unsafe {
452            let flat_idx = calculate_flat_index(indices, &self.strides);
453            self.as_ptr_mut().add(flat_idx)
454        }
455    }
456}
457
458impl<'a, C: Memorable, const D: usize> std::ops::Index<[usize; D]> for SymSqTensorMut<'a, C, D> {
459    type Output = C;
460
461    fn index(&self, indices: [usize; D]) -> &Self::Output {
462        self.get(&indices)
463    }
464}
465
466impl<'a, C: Memorable, const D: usize> std::ops::IndexMut<[usize; D]> for SymSqTensorMut<'a, C, D> {
467    fn index_mut(&mut self, indices: [usize; D]) -> &mut Self::Output {
468        self.get_mut(&indices)
469    }
470}
471
472// TODO add some documentation plus a todo tag on relevant rust issues (const generic expressions)
473impl<C: Memorable> SymSqTensor<C, 5> {
474    /// Returns an immutable reference to the 3D subtensor at the given matrix indices.
475    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> TensorRef<'_, C, 3> {
476        check_bounds(&[m1, m2, 0, 0, 0], &self.dims);
477        // Safety: bounds have been checked.
478        unsafe { self.subtensor_ref_unchecked(m1, m2) }
479    }
480
481    /// Returns a mutable reference to the 3D subtensor at the given matrix indices.
482    pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> TensorMut<'_, C, 3> {
483        check_bounds(&[m1, m2, 0, 0, 0], &self.dims);
484        // Safety: bounds have been checked.
485        unsafe { self.subtensor_mut_unchecked(m1, m2) }
486    }
487
488    /// Returns an immutable reference to the 3D subtensor at the given matrix indices without bounds checking.
489    ///
490    /// # Safety
491    ///
492    /// Caller should ensure that m1 and m2 are in bounds.
493    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> TensorRef<'_, C, 3> {
494        unsafe {
495            TensorRef::from_raw_parts(
496                self.ptr_at(&[m1, m2, 0, 0, 0]),
497                [self.dims[2], self.dims[3], self.dims[4]],
498                [self.strides[2], self.strides[3], self.strides[4]],
499            )
500        }
501    }
502
503    /// Returns a mutable reference to the 3D subtensor at the given matrix indices without bounds checking.
504    ///
505    /// # Safety
506    ///
507    /// Caller should ensure that m1 and m2 are in bounds.
508    pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> TensorMut<'_, C, 3> {
509        unsafe {
510            TensorMut::from_raw_parts(
511                self.ptr_at_mut(&[m1, m2, 0, 0, 0]),
512                [self.dims[2], self.dims[3], self.dims[4]],
513                [self.strides[2], self.strides[3], self.strides[4]],
514            )
515        }
516    }
517}
518
519impl<C: Memorable> SymSqTensor<C, 4> {
520    /// Returns an immutable matrix reference to the subtensor at the given indices.
521    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> MatRef<'_, C> {
522        check_bounds(&[m1, m2, 0, 0], &self.dims);
523        // Safety: bounds have been checked.
524        unsafe { self.subtensor_ref_unchecked(m1, m2) }
525    }
526
527    /// Returns a mutable matrix reference to the subtensor at the given indices.
528    pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> MatMut<'_, C> {
529        check_bounds(&[m1, m2, 0, 0], &self.dims);
530        // Safety: bounds have been checked.
531        unsafe { self.subtensor_mut_unchecked(m1, m2) }
532    }
533
534    /// Returns an immutable matrix reference to the subtensor at the given indices without bounds checking.
535    ///
536    /// # Safety
537    ///
538    /// Caller should ensure that m1 and m2 are in bounds.
539    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> MatRef<'_, C> {
540        unsafe {
541            MatRef::from_raw_parts(
542                self.ptr_at(&[m1, m2, 0, 0]),
543                self.dims[2],
544                self.dims[3],
545                self.strides[2] as isize,
546                self.strides[3] as isize,
547            )
548        }
549    }
550
551    /// Returns a mutable matrix reference to the subtensor at the given indices without bounds checking.
552    ///
553    /// # Safety
554    ///
555    /// Caller should ensure that m1 and m2 are in bounds.
556    pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> MatMut<'_, C> {
557        unsafe {
558            MatMut::from_raw_parts_mut(
559                self.ptr_at_mut(&[m1, m2, 0, 0]),
560                self.dims[2],
561                self.dims[3],
562                self.strides[2] as isize,
563                self.strides[3] as isize,
564            )
565        }
566    }
567}
568
569impl<C: Memorable> SymSqTensor<C, 3> {
570    /// Returns an immutable row reference to the subtensor at the given indices.
571    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> RowRef<'_, C> {
572        check_bounds(&[m1, m2, 0], &self.dims);
573        // Safety: bounds have been checked.
574        unsafe { self.subtensor_ref_unchecked(m1, m2) }
575    }
576
577    /// Returns a mutable row reference to the subtensor at the given indices.
578    pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> RowMut<'_, C> {
579        check_bounds(&[m1, m2, 0], &self.dims);
580        // Safety: bounds have been checked.
581        unsafe { self.subtensor_mut_unchecked(m1, m2) }
582    }
583
584    /// Returns an immutable row reference to the subtensor at the given indices without bounds checking.
585    ///
586    /// # Safety
587    ///
588    /// Caller should ensure that m1 and m2 are in bounds.
589    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> RowRef<'_, C> {
590        unsafe {
591            RowRef::from_raw_parts(
592                self.ptr_at(&[m1, m2, 0]),
593                self.dims[2],
594                self.strides[2] as isize,
595            )
596        }
597    }
598
599    /// Returns a mutable row reference to the subtensor at the given indices without bounds checking.
600    ///
601    /// # Safety
602    ///
603    /// Caller should ensure that m1 and m2 are in bounds.
604    pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> RowMut<'_, C> {
605        unsafe {
606            RowMut::from_raw_parts_mut(
607                self.ptr_at_mut(&[m1, m2, 0]),
608                self.dims[2],
609                self.strides[2] as isize,
610            )
611        }
612    }
613}
614
615impl<C: Memorable> SymSqTensor<C, 2> {
616    /// Returns an immutable reference to the element at the given indices.
617    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> &C {
618        self.get(&[m1, m2])
619    }
620
621    /// Returns a mutable reference to the element at the given indices.
622    pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> &mut C {
623        self.get_mut(&[m1, m2])
624    }
625
626    /// Returns an immutable reference to the element at the given indices without bounds checking.
627    ///
628    /// # Safety
629    ///
630    /// Caller should ensure that m1 and m2 are in bounds.
631    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> &C {
632        unsafe { self.get_unchecked(&[m1, m2]) }
633    }
634
635    /// Returns a mutable reference to the element at the given indices without bounds checking.
636    ///
637    /// # Safety
638    ///
639    /// Caller should ensure that m1 and m2 are in bounds.
640    pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> &mut C {
641        unsafe { self.get_mut_unchecked(&[m1, m2]) }
642    }
643}
644
645impl<'a, C: Memorable> SymSqTensorRef<'a, C, 5> {
646    /// Returns an immutable reference to the 3D subtensor at the given matrix indices.
647    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> TensorRef<'a, C, 3> {
648        check_bounds(&[m1, m2, 0, 0, 0], &self.dims);
649        // Safety: bounds have been checked.
650        unsafe { self.subtensor_ref_unchecked(m1, m2) }
651    }
652
653    /// Returns an immutable reference to the 3D subtensor at the given matrix indices without bounds checking.
654    ///
655    /// # Safety
656    ///
657    /// Caller should ensure that m1 and m2 are in bounds.
658    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> TensorRef<'a, C, 3> {
659        unsafe {
660            TensorRef::from_raw_parts(
661                self.ptr_at(&[m1, m2, 0, 0, 0]),
662                [self.dims[2], self.dims[3], self.dims[4]],
663                [self.strides[2], self.strides[3], self.strides[4]],
664            )
665        }
666    }
667}
668
669impl<'a, C: Memorable> SymSqTensorRef<'a, C, 4> {
670    /// Returns an immutable matrix reference to the subtensor at the given indices.
671    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> MatRef<'a, C> {
672        check_bounds(&[m1, m2, 0, 0], &self.dims);
673        // Safety: bounds have been checked.
674        unsafe { self.subtensor_ref_unchecked(m1, m2) }
675    }
676
677    /// Returns an immutable matrix reference to the subtensor at the given indices without bounds checking.
678    ///
679    /// # Safety
680    ///
681    /// Caller should ensure that m1 and m2 are in bounds.
682    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> MatRef<'a, C> {
683        unsafe {
684            MatRef::from_raw_parts(
685                self.ptr_at(&[m1, m2, 0, 0]),
686                self.dims[2],
687                self.dims[3],
688                self.strides[2] as isize,
689                self.strides[3] as isize,
690            )
691        }
692    }
693}
694
695impl<'a, C: Memorable> SymSqTensorRef<'a, C, 3> {
696    /// Returns an immutable row reference to the subtensor at the given indices.
697    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> RowRef<'a, C> {
698        check_bounds(&[m1, m2, 0], &self.dims);
699        // Safety: bounds have been checked.
700        unsafe { self.subtensor_ref_unchecked(m1, m2) }
701    }
702
703    /// Returns an immutable row reference to the subtensor at the given indices without bounds checking.
704    ///
705    /// # Safety
706    ///
707    /// Caller should ensure that m1 and m2 are in bounds.
708    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> RowRef<'a, C> {
709        unsafe {
710            RowRef::from_raw_parts(
711                self.ptr_at(&[m1, m2, 0]),
712                self.dims[2],
713                self.strides[2] as isize,
714            )
715        }
716    }
717}
718
719impl<'a, C: Memorable> SymSqTensorRef<'a, C, 2> {
720    /// Returns an immutable reference to the element at the given indices.
721    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> &C {
722        self.get(&[m1, m2])
723    }
724
725    /// Returns an immutable reference to the element at the given indices without bounds checking.
726    ///
727    /// # Safety
728    ///
729    /// Caller should ensure that m1 and m2 are in bounds.
730    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> &C {
731        unsafe { self.get_unchecked(&[m1, m2]) }
732    }
733}
734
735impl<'a, C: Memorable> SymSqTensorMut<'a, C, 5> {
736    /// Returns an immutable reference to the 3D subtensor at the given matrix indices.
737    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> TensorRef<'a, C, 3> {
738        check_bounds(&[m1, m2, 0, 0, 0], &self.dims);
739        // Safety: bounds have been checked.
740        unsafe { self.subtensor_ref_unchecked(m1, m2) }
741    }
742
743    /// Returns a mutable reference to the 3D subtensor at the given matrix indices.
744    pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> TensorMut<'a, C, 3> {
745        check_bounds(&[m1, m2, 0, 0, 0], &self.dims);
746        // Safety: bounds have been checked.
747        unsafe { self.subtensor_mut_unchecked(m1, m2) }
748    }
749
750    /// Returns an immutable reference to the 3D subtensor at the given matrix indices without bounds checking.
751    ///
752    /// # Safety
753    ///
754    /// Caller should ensure that m1 and m2 are in bounds.
755    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> TensorRef<'a, C, 3> {
756        unsafe {
757            TensorRef::from_raw_parts(
758                self.ptr_at(&[m1, m2, 0, 0, 0]),
759                [self.dims[2], self.dims[3], self.dims[4]],
760                [self.strides[2], self.strides[3], self.strides[4]],
761            )
762        }
763    }
764
765    /// Returns a mutable reference to the 3D subtensor at the given matrix indices without bounds checking.
766    ///
767    /// # Safety
768    ///
769    /// Caller should ensure that m1 and m2 are in bounds.
770    pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> TensorMut<'a, C, 3> {
771        unsafe {
772            TensorMut::from_raw_parts(
773                self.ptr_at_mut(&[m1, m2, 0, 0, 0]),
774                [self.dims[2], self.dims[3], self.dims[4]],
775                [self.strides[2], self.strides[3], self.strides[4]],
776            )
777        }
778    }
779}
780
781impl<'a, C: Memorable> SymSqTensorMut<'a, C, 4> {
782    /// Returns an immutable matrix reference to the subtensor at the given indices.
783    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> MatRef<'a, C> {
784        check_bounds(&[m1, m2, 0, 0], &self.dims);
785        // Safety: bounds have been checked.
786        unsafe { self.subtensor_ref_unchecked(m1, m2) }
787    }
788
789    /// Returns a mutable matrix reference to the subtensor at the given indices.
790    pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> MatMut<'a, C> {
791        check_bounds(&[m1, m2, 0, 0], &self.dims);
792        // Safety: bounds have been checked.
793        unsafe { self.subtensor_mut_unchecked(m1, m2) }
794    }
795
796    /// Returns an immutable matrix reference to the subtensor at the given indices without bounds checking.
797    ///
798    /// # Safety
799    ///
800    /// Caller should ensure that m1 and m2 are in bounds.
801    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> MatRef<'a, C> {
802        unsafe {
803            MatRef::from_raw_parts(
804                self.ptr_at(&[m1, m2, 0, 0]),
805                self.dims[2],
806                self.dims[3],
807                self.strides[2] as isize,
808                self.strides[3] as isize,
809            )
810        }
811    }
812
813    /// Returns a mutable matrix reference to the subtensor at the given indices without bounds checking.
814    ///
815    /// # Safety
816    ///
817    /// Caller should ensure that m1 and m2 are in bounds.
818    pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> MatMut<'a, C> {
819        unsafe {
820            MatMut::from_raw_parts_mut(
821                self.ptr_at_mut(&[m1, m2, 0, 0]),
822                self.dims[2],
823                self.dims[3],
824                self.strides[2] as isize,
825                self.strides[3] as isize,
826            )
827        }
828    }
829}
830
831impl<'a, C: Memorable> SymSqTensorMut<'a, C, 3> {
832    /// Returns an immutable row reference to the subtensor at the given indices.
833    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> RowRef<'a, C> {
834        check_bounds(&[m1, m2, 0], &self.dims);
835        // Safety: bounds have been checked.
836        unsafe { self.subtensor_ref_unchecked(m1, m2) }
837    }
838
839    /// Returns a mutable row reference to the subtensor at the given indices.
840    pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> RowMut<'a, C> {
841        check_bounds(&[m1, m2, 0], &self.dims);
842        // Safety: bounds have been checked.
843        unsafe { self.subtensor_mut_unchecked(m1, m2) }
844    }
845
846    /// Returns an immutable row reference to the subtensor at the given indices without bounds checking.
847    ///
848    /// # Safety
849    ///
850    /// Caller should ensure that m1 and m2 are in bounds.
851    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> RowRef<'a, C> {
852        unsafe {
853            RowRef::from_raw_parts(
854                self.ptr_at(&[m1, m2, 0]),
855                self.dims[2],
856                self.strides[2] as isize,
857            )
858        }
859    }
860
861    /// Returns a mutable row reference to the subtensor at the given indices without bounds checking.
862    ///
863    /// # Safety
864    ///
865    /// Caller should ensure that m1 and m2 are in bounds.
866    pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> RowMut<'a, C> {
867        unsafe {
868            RowMut::from_raw_parts_mut(
869                self.ptr_at_mut(&[m1, m2, 0]),
870                self.dims[2],
871                self.strides[2] as isize,
872            )
873        }
874    }
875}
876
877impl<'a, C: Memorable> SymSqTensorMut<'a, C, 2> {
878    /// Returns an immutable reference to the element at the given indices.
879    pub fn subtensor_ref(&self, m1: usize, m2: usize) -> &C {
880        self.get(&[m1, m2])
881    }
882
883    /// Returns a mutable reference to the element at the given indices.
884    pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> &mut C {
885        self.get_mut(&[m1, m2])
886    }
887
888    /// Returns an immutable reference to the element at the given indices without bounds checking.
889    ///
890    /// # Safety
891    ///
892    /// Caller should ensure that m1 and m2 are in bounds.
893    pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> &C {
894        unsafe { self.get_unchecked(&[m1, m2]) }
895    }
896
897    /// Returns a mutable reference to the element at the given indices without bounds checking.
898    ///
899    /// # Safety
900    ///
901    /// Caller should ensure that m1 and m2 are in bounds.
902    pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> &mut C {
903        unsafe { self.get_mut_unchecked(&[m1, m2]) }
904    }
905}