qudit_core/array/symsq.rs
1use std::ptr::NonNull;
2// TODO: update faer imports to crate imports
3// TODO: Make helper methods for debug and display that extracts shared functionality from Tensor
4// TODO: add basic derives for clone, PartialEq, Debug and Display
5// TODO: Use strong typing where it makes sense
6// TODO: Add helpful, useful, succinct documentation with examples.
7use super::check_bounds;
8use faer::{MatMut, MatRef, RowMut, RowRef};
9
10use crate::{
11 array::TensorMut,
12 array::TensorRef,
13 memory::{Memorable, MemoryBuffer, alloc_zeroed_memory},
14};
15
/// Maps a pair of leading (symmetric) indices to their packed storage position.
///
/// Only the upper triangle (diagonal included) of the `n x n` grid formed by the two
/// leading axes is stored. It is flattened column by column into a vector of length
/// `n * (n + 1) / 2`. For `i <= j` the packed position is `k = j * (j + 1) / 2 + i`.
/// For `i > j` the two arguments are swapped first, so `(i, j)` and `(j, i)` always map
/// to the same slot.
///
/// For example, with `n = 3` the packed order is
/// `(0,0), (0,1), (1,1), (0,2), (1,2), (2,2)`.
#[inline(always)]
fn coords_to_index(i: usize, j: usize) -> usize {
    // Normalise to the upper triangle: `row <= col`.
    let (row, col) = if i > j { (j, i) } else { (i, j) };
    col * (col + 1) / 2 + row
}
40
41#[inline(always)]
42fn calculate_flat_index<const D: usize>(indices: &[usize; D], strides: &[usize; D]) -> usize {
43 let mut flat_idx = coords_to_index(indices[0], indices[1]) * strides[1];
44 for i in 2..D {
45 flat_idx += indices[i] * strides[i];
46 }
47 flat_idx
48}
49
50// Schwarz's Theorem is satisfied for quantum tensor networks
51//
52// TODO: when const generics can appear in const expressions, this can be rewritten better
53/// A tensor with D dimensions, where the first two dimensions are equal and symmetric.
54pub struct SymSqTensor<C: Memorable, const D: usize> {
55 data: MemoryBuffer<C>,
56 dims: [usize; D],
57 strides: [usize; D],
58}
59
60impl<C: Memorable, const D: usize> SymSqTensor<C, D> {
61 /// Creates a new symmetric square tensor with the given data, dimensions, and strides.
62 pub fn new(data: MemoryBuffer<C>, dims: [usize; D], strides: [usize; D]) -> Self {
63 assert!(
64 D >= 2,
65 "Symmetric square tensors must have 2 or more dimensions."
66 );
67 assert!(
68 dims[0] == dims[1],
69 "Symmetric square tensors must be square in their two major dimensions."
70 );
71 assert!(
72 strides[0] == strides[1] * dims[1],
73 "Symmetric square tensors must be continuous across their two major dimensions."
74 );
75 assert!(
76 dims.iter().all(|&d| d != 0),
77 "Cannot have a zero-length dimension."
78 );
79 assert!(
80 strides.iter().all(|&d| d != 0),
81 "Cannot have a zero-length stride."
82 );
83
84 let mut max_element = [0; D];
85 for (i, d) in dims.iter().enumerate() {
86 max_element[i] = d - 1;
87 }
88 let max_flat_index = calculate_flat_index(&max_element, &strides);
89
90 assert!(
91 data.len() >= max_flat_index,
92 "Data buffer is not large enough."
93 );
94
95 Self {
96 data,
97 dims,
98 strides,
99 }
100 }
101
102 /// Creates a new symmetric square tensor filled with zeros.
103 pub fn zeros(dims: [usize; D]) -> Self {
104 let strides = super::calc_continuous_strides(&dims);
105 let data = alloc_zeroed_memory::<C>(strides[0] * dims[0]);
106 Self::new(data, dims, strides)
107 }
108
109 /// Returns a reference to the dimensions of the tensor.
110 pub fn dims(&self) -> &[usize; D] {
111 &self.dims
112 }
113
114 /// Returns a reference to the strides of the tensor.
115 pub fn strides(&self) -> &[usize; D] {
116 &self.strides
117 }
118
119 /// Returns the rank (number of dimensions) of the tensor.
120 pub fn rank(&self) -> usize {
121 D
122 }
123
124 /// Returns the total number of elements in the tensor.
125 pub fn num_elements(&self) -> usize {
126 self.dims.iter().product()
127 }
128
129 /// Returns a raw pointer to the tensor's data.
130 pub fn as_ptr(&self) -> *const C {
131 self.data.as_ptr()
132 }
133
134 /// Returns a mutable raw pointer to the tensor's data.
135 pub fn as_ptr_mut(&mut self) -> *mut C {
136 self.data.as_mut_ptr()
137 }
138
139 /// Returns an immutable reference to the tensor.
140 pub fn as_ref(&self) -> SymSqTensorRef<'_, C, D> {
141 unsafe { SymSqTensorRef::from_raw_parts(self.data.as_ptr(), self.dims, self.strides) }
142 }
143
144 /// Returns a mutable reference to the tensor.
145 pub fn as_mut(&mut self) -> SymSqTensorMut<'_, C, D> {
146 unsafe { SymSqTensorMut::from_raw_parts(self.data.as_mut_ptr(), self.dims, self.strides) }
147 }
148
149 /// Returns a reference to an element at the given indices.
150 ///
151 /// # Panics
152 ///
153 /// Panics if the indices are out of bounds.
154 pub fn get(&self, indices: &[usize; D]) -> &C {
155 check_bounds(indices, &self.dims);
156 // Safety: bounds are checked by `check_bounds`
157 unsafe { self.get_unchecked(indices) }
158 }
159
160 /// Returns a mutable reference to an element at the given indices.
161 ///
162 /// # Panics
163 ///
164 /// Panics if the indices are out of bounds.
165 pub fn get_mut(&mut self, indices: &[usize; D]) -> &mut C {
166 check_bounds(indices, &self.dims);
167 // Safety: bounds are checked by `check_bounds`
168 unsafe { self.get_mut_unchecked(indices) }
169 }
170
171 /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
172 ///
173 /// # Safety
174 ///
175 /// Calling this method with out-of-bounds `indices` is undefined behavior.
176 #[inline(always)]
177 pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
178 unsafe { &*self.ptr_at(indices) }
179 }
180
181 /// Returns a mutable reference to an element at the given indices, without performing bounds checks.
182 ///
183 /// # Safety
184 ///
185 /// Calling this method with out-of-bounds `indices` is undefined behavior.
186 #[inline(always)]
187 pub unsafe fn get_mut_unchecked(&mut self, indices: &[usize; D]) -> &mut C {
188 unsafe { &mut *self.ptr_at_mut(indices) }
189 }
190
191 /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
192 ///
193 /// # Safety
194 ///
195 /// Calling this method with out-of-bounds `indices` is undefined behavior.
196 #[inline(always)]
197 pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
198 unsafe {
199 let flat_idx = calculate_flat_index(indices, &self.strides);
200 self.as_ptr().add(flat_idx)
201 }
202 }
203
204 /// Returns a mutable raw pointer to an element at the given indices, without performing bounds checks.
205 ///
206 /// # Safety
207 ///
208 /// Calling this method with out-of-bounds `indices` is undefined behavior.
209 #[inline(always)]
210 pub unsafe fn ptr_at_mut(&mut self, indices: &[usize; D]) -> *mut C {
211 unsafe {
212 let flat_idx = calculate_flat_index(indices, &self.strides);
213 self.as_ptr_mut().add(flat_idx)
214 }
215 }
216}
217
218impl<C: Memorable, const D: usize> std::ops::Index<[usize; D]> for SymSqTensor<C, D> {
219 type Output = C;
220
221 fn index(&self, indices: [usize; D]) -> &Self::Output {
222 self.get(&indices)
223 }
224}
225
226impl<C: Memorable, const D: usize> std::ops::IndexMut<[usize; D]> for SymSqTensor<C, D> {
227 fn index_mut(&mut self, indices: [usize; D]) -> &mut Self::Output {
228 self.get_mut(&indices)
229 }
230}
231
232#[derive(Clone, Copy)]
233/// An immutable reference to a symmetric square tensor.
234pub struct SymSqTensorRef<'a, C: Memorable, const D: usize> {
235 data: NonNull<C>,
236 dims: [usize; D],
237 strides: [usize; D],
238 __marker: std::marker::PhantomData<&'a C>,
239}
240
241impl<'a, C: Memorable, const D: usize> SymSqTensorRef<'a, C, D> {
242 /// Creates a new `SymSqTensorRef` from raw parts.
243 ///
244 /// # Safety
245 ///
246 /// The caller must ensure that `data` points to a valid memory block of `C` elements,
247 /// and that `dims` and `strides` accurately describe the layout of the tensor
248 /// within that memory block. The `data` pointer must be valid for the lifetime `'a`.
249 pub unsafe fn from_raw_parts(data: *const C, dims: [usize; D], strides: [usize; D]) -> Self {
250 unsafe {
251 // SAFETY: The pointer is never used in an mutable context.
252 let mut_ptr = data as *mut C;
253
254 Self {
255 data: NonNull::new_unchecked(mut_ptr),
256 dims,
257 strides,
258 __marker: std::marker::PhantomData,
259 }
260 }
261 }
262
263 /// Returns a reference to the dimensions of the tensor.
264 pub fn dims(&self) -> &[usize; D] {
265 &self.dims
266 }
267
268 /// Returns a reference to the strides of the tensor.
269 pub fn strides(&self) -> &[usize; D] {
270 &self.strides
271 }
272
273 /// Returns the rank (number of dimensions) of the tensor.
274 pub fn rank(&self) -> usize {
275 D
276 }
277
278 /// Returns the total number of elements in the tensor.
279 pub fn num_elements(&self) -> usize {
280 self.dims.iter().product()
281 }
282
283 /// Returns a raw pointer to the tensor's data.
284 pub fn as_ptr(&self) -> *const C {
285 self.data.as_ptr()
286 }
287
288 /// Returns a reference to an element at the given indices.
289 ///
290 /// # Panics
291 ///
292 /// Panics if the indices are out of bounds.
293 pub fn get(&self, indices: &[usize; D]) -> &C {
294 check_bounds(indices, &self.dims);
295 // Safety: bounds are checked by `check_bounds`
296 unsafe { self.get_unchecked(indices) }
297 }
298
299 /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
300 ///
301 /// # Safety
302 ///
303 /// Calling this method with out-of-bounds `indices` is undefined behavior.
304 #[inline(always)]
305 pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
306 unsafe { &*self.ptr_at(indices) }
307 }
308
309 /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
310 ///
311 /// # Safety
312 ///
313 /// Calling this method with out-of-bounds `indices` is undefined behavior.
314 #[inline(always)]
315 pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
316 unsafe {
317 let flat_idx = calculate_flat_index(indices, &self.strides);
318 self.as_ptr().add(flat_idx)
319 }
320 }
321}
322
323impl<'a, C: Memorable, const D: usize> std::ops::Index<[usize; D]> for SymSqTensorRef<'a, C, D> {
324 type Output = C;
325
326 fn index(&self, indices: [usize; D]) -> &Self::Output {
327 self.get(&indices)
328 }
329}
330
331/// A mutable reference to a symmetric square tensor.
332pub struct SymSqTensorMut<'a, C: Memorable, const D: usize> {
333 data: NonNull<C>,
334 dims: [usize; D],
335 strides: [usize; D],
336 __marker: std::marker::PhantomData<&'a mut C>,
337}
338
339impl<'a, C: Memorable, const D: usize> SymSqTensorMut<'a, C, D> {
340 /// Creates a new `SymSqTensorMut` from raw parts.
341 ///
342 /// # Safety
343 ///
344 /// The caller must ensure that `data` points to a valid memory block of `C` elements,
345 /// and that `dims` and `strides` accurately describe the layout of the tensor
346 /// within that memory block. The `data` pointer must be valid for the lifetime `'a`
347 /// and that it is safe to mutate the data.
348 pub unsafe fn from_raw_parts(data: *mut C, dims: [usize; D], strides: [usize; D]) -> Self {
349 unsafe {
350 Self {
351 data: NonNull::new_unchecked(data),
352 dims,
353 strides,
354 __marker: std::marker::PhantomData,
355 }
356 }
357 }
358
359 /// Returns a reference to the dimensions of the tensor.
360 pub fn dims(&self) -> &[usize; D] {
361 &self.dims
362 }
363
364 /// Returns a reference to the strides of the tensor.
365 pub fn strides(&self) -> &[usize; D] {
366 &self.strides
367 }
368
369 /// Returns the rank (number of dimensions) of the tensor.
370 pub fn rank(&self) -> usize {
371 D
372 }
373
374 /// Returns the total number of elements in the tensor.
375 pub fn num_elements(&self) -> usize {
376 self.dims.iter().product()
377 }
378
379 /// Returns a mutable raw pointer to the tensor's data.
380 pub fn as_ptr(&self) -> *const C {
381 self.data.as_ptr() as *const C
382 }
383
384 /// Returns a mutable raw pointer to the tensor's data.
385 pub fn as_ptr_mut(&mut self) -> *mut C {
386 self.data.as_ptr()
387 }
388
389 /// Returns a reference to an element at the given indices.
390 ///
391 /// # Panics
392 ///
393 /// Panics if the indices are out of bounds.
394 pub fn get(&self, indices: &[usize; D]) -> &C {
395 check_bounds(indices, &self.dims);
396 // Safety: bounds are checked by `check_bounds`
397 unsafe { self.get_unchecked(indices) }
398 }
399
400 /// Returns a mutable reference to an element at the given indices.
401 ///
402 /// # Panics
403 ///
404 /// Panics if the indices are out of bounds.
405 pub fn get_mut(&mut self, indices: &[usize; D]) -> &mut C {
406 check_bounds(indices, &self.dims);
407 // Safety: bounds are checked by `check_bounds`
408 unsafe { self.get_mut_unchecked(indices) }
409 }
410
411 /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
412 ///
413 /// # Safety
414 ///
415 /// Calling this method with out-of-bounds `indices` is undefined behavior.
416 #[inline(always)]
417 pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
418 unsafe { &*self.ptr_at(indices) }
419 }
420
421 /// Returns a mutable reference to an element at the given indices, without performing bounds checks.
422 ///
423 /// # Safety
424 ///
425 /// Calling this method with out-of-bounds `indices` is undefined behavior.
426 #[inline(always)]
427 pub unsafe fn get_mut_unchecked(&mut self, indices: &[usize; D]) -> &mut C {
428 unsafe { &mut *self.ptr_at_mut(indices) }
429 }
430
431 /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
432 ///
433 /// # Safety
434 ///
435 /// Calling this method with out-of-bounds `indices` is undefined behavior.
436 #[inline(always)]
437 pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
438 unsafe {
439 let flat_idx = calculate_flat_index(indices, &self.strides);
440 self.as_ptr().add(flat_idx)
441 }
442 }
443
444 /// Returns a mutable raw pointer to an element at the given indices, without performing bounds checks.
445 ///
446 /// # Safety
447 ///
448 /// Calling this method with out-of-bounds `indices` is undefined behavior.
449 #[inline(always)]
450 pub unsafe fn ptr_at_mut(&mut self, indices: &[usize; D]) -> *mut C {
451 unsafe {
452 let flat_idx = calculate_flat_index(indices, &self.strides);
453 self.as_ptr_mut().add(flat_idx)
454 }
455 }
456}
457
458impl<'a, C: Memorable, const D: usize> std::ops::Index<[usize; D]> for SymSqTensorMut<'a, C, D> {
459 type Output = C;
460
461 fn index(&self, indices: [usize; D]) -> &Self::Output {
462 self.get(&indices)
463 }
464}
465
466impl<'a, C: Memorable, const D: usize> std::ops::IndexMut<[usize; D]> for SymSqTensorMut<'a, C, D> {
467 fn index_mut(&mut self, indices: [usize; D]) -> &mut Self::Output {
468 self.get_mut(&indices)
469 }
470}
471
472// TODO add some documentation plus a todo tag on relevant rust issues (const generic expressions)
473impl<C: Memorable> SymSqTensor<C, 5> {
474 /// Returns an immutable reference to the 3D subtensor at the given matrix indices.
475 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> TensorRef<'_, C, 3> {
476 check_bounds(&[m1, m2, 0, 0, 0], &self.dims);
477 // Safety: bounds have been checked.
478 unsafe { self.subtensor_ref_unchecked(m1, m2) }
479 }
480
481 /// Returns a mutable reference to the 3D subtensor at the given matrix indices.
482 pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> TensorMut<'_, C, 3> {
483 check_bounds(&[m1, m2, 0, 0, 0], &self.dims);
484 // Safety: bounds have been checked.
485 unsafe { self.subtensor_mut_unchecked(m1, m2) }
486 }
487
488 /// Returns an immutable reference to the 3D subtensor at the given matrix indices without bounds checking.
489 ///
490 /// # Safety
491 ///
492 /// Caller should ensure that m1 and m2 are in bounds.
493 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> TensorRef<'_, C, 3> {
494 unsafe {
495 TensorRef::from_raw_parts(
496 self.ptr_at(&[m1, m2, 0, 0, 0]),
497 [self.dims[2], self.dims[3], self.dims[4]],
498 [self.strides[2], self.strides[3], self.strides[4]],
499 )
500 }
501 }
502
503 /// Returns a mutable reference to the 3D subtensor at the given matrix indices without bounds checking.
504 ///
505 /// # Safety
506 ///
507 /// Caller should ensure that m1 and m2 are in bounds.
508 pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> TensorMut<'_, C, 3> {
509 unsafe {
510 TensorMut::from_raw_parts(
511 self.ptr_at_mut(&[m1, m2, 0, 0, 0]),
512 [self.dims[2], self.dims[3], self.dims[4]],
513 [self.strides[2], self.strides[3], self.strides[4]],
514 )
515 }
516 }
517}
518
519impl<C: Memorable> SymSqTensor<C, 4> {
520 /// Returns an immutable matrix reference to the subtensor at the given indices.
521 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> MatRef<'_, C> {
522 check_bounds(&[m1, m2, 0, 0], &self.dims);
523 // Safety: bounds have been checked.
524 unsafe { self.subtensor_ref_unchecked(m1, m2) }
525 }
526
527 /// Returns a mutable matrix reference to the subtensor at the given indices.
528 pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> MatMut<'_, C> {
529 check_bounds(&[m1, m2, 0, 0], &self.dims);
530 // Safety: bounds have been checked.
531 unsafe { self.subtensor_mut_unchecked(m1, m2) }
532 }
533
534 /// Returns an immutable matrix reference to the subtensor at the given indices without bounds checking.
535 ///
536 /// # Safety
537 ///
538 /// Caller should ensure that m1 and m2 are in bounds.
539 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> MatRef<'_, C> {
540 unsafe {
541 MatRef::from_raw_parts(
542 self.ptr_at(&[m1, m2, 0, 0]),
543 self.dims[2],
544 self.dims[3],
545 self.strides[2] as isize,
546 self.strides[3] as isize,
547 )
548 }
549 }
550
551 /// Returns a mutable matrix reference to the subtensor at the given indices without bounds checking.
552 ///
553 /// # Safety
554 ///
555 /// Caller should ensure that m1 and m2 are in bounds.
556 pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> MatMut<'_, C> {
557 unsafe {
558 MatMut::from_raw_parts_mut(
559 self.ptr_at_mut(&[m1, m2, 0, 0]),
560 self.dims[2],
561 self.dims[3],
562 self.strides[2] as isize,
563 self.strides[3] as isize,
564 )
565 }
566 }
567}
568
569impl<C: Memorable> SymSqTensor<C, 3> {
570 /// Returns an immutable row reference to the subtensor at the given indices.
571 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> RowRef<'_, C> {
572 check_bounds(&[m1, m2, 0], &self.dims);
573 // Safety: bounds have been checked.
574 unsafe { self.subtensor_ref_unchecked(m1, m2) }
575 }
576
577 /// Returns a mutable row reference to the subtensor at the given indices.
578 pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> RowMut<'_, C> {
579 check_bounds(&[m1, m2, 0], &self.dims);
580 // Safety: bounds have been checked.
581 unsafe { self.subtensor_mut_unchecked(m1, m2) }
582 }
583
584 /// Returns an immutable row reference to the subtensor at the given indices without bounds checking.
585 ///
586 /// # Safety
587 ///
588 /// Caller should ensure that m1 and m2 are in bounds.
589 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> RowRef<'_, C> {
590 unsafe {
591 RowRef::from_raw_parts(
592 self.ptr_at(&[m1, m2, 0]),
593 self.dims[2],
594 self.strides[2] as isize,
595 )
596 }
597 }
598
599 /// Returns a mutable row reference to the subtensor at the given indices without bounds checking.
600 ///
601 /// # Safety
602 ///
603 /// Caller should ensure that m1 and m2 are in bounds.
604 pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> RowMut<'_, C> {
605 unsafe {
606 RowMut::from_raw_parts_mut(
607 self.ptr_at_mut(&[m1, m2, 0]),
608 self.dims[2],
609 self.strides[2] as isize,
610 )
611 }
612 }
613}
614
615impl<C: Memorable> SymSqTensor<C, 2> {
616 /// Returns an immutable reference to the element at the given indices.
617 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> &C {
618 self.get(&[m1, m2])
619 }
620
621 /// Returns a mutable reference to the element at the given indices.
622 pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> &mut C {
623 self.get_mut(&[m1, m2])
624 }
625
626 /// Returns an immutable reference to the element at the given indices without bounds checking.
627 ///
628 /// # Safety
629 ///
630 /// Caller should ensure that m1 and m2 are in bounds.
631 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> &C {
632 unsafe { self.get_unchecked(&[m1, m2]) }
633 }
634
635 /// Returns a mutable reference to the element at the given indices without bounds checking.
636 ///
637 /// # Safety
638 ///
639 /// Caller should ensure that m1 and m2 are in bounds.
640 pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> &mut C {
641 unsafe { self.get_mut_unchecked(&[m1, m2]) }
642 }
643}
644
645impl<'a, C: Memorable> SymSqTensorRef<'a, C, 5> {
646 /// Returns an immutable reference to the 3D subtensor at the given matrix indices.
647 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> TensorRef<'a, C, 3> {
648 check_bounds(&[m1, m2, 0, 0, 0], &self.dims);
649 // Safety: bounds have been checked.
650 unsafe { self.subtensor_ref_unchecked(m1, m2) }
651 }
652
653 /// Returns an immutable reference to the 3D subtensor at the given matrix indices without bounds checking.
654 ///
655 /// # Safety
656 ///
657 /// Caller should ensure that m1 and m2 are in bounds.
658 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> TensorRef<'a, C, 3> {
659 unsafe {
660 TensorRef::from_raw_parts(
661 self.ptr_at(&[m1, m2, 0, 0, 0]),
662 [self.dims[2], self.dims[3], self.dims[4]],
663 [self.strides[2], self.strides[3], self.strides[4]],
664 )
665 }
666 }
667}
668
669impl<'a, C: Memorable> SymSqTensorRef<'a, C, 4> {
670 /// Returns an immutable matrix reference to the subtensor at the given indices.
671 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> MatRef<'a, C> {
672 check_bounds(&[m1, m2, 0, 0], &self.dims);
673 // Safety: bounds have been checked.
674 unsafe { self.subtensor_ref_unchecked(m1, m2) }
675 }
676
677 /// Returns an immutable matrix reference to the subtensor at the given indices without bounds checking.
678 ///
679 /// # Safety
680 ///
681 /// Caller should ensure that m1 and m2 are in bounds.
682 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> MatRef<'a, C> {
683 unsafe {
684 MatRef::from_raw_parts(
685 self.ptr_at(&[m1, m2, 0, 0]),
686 self.dims[2],
687 self.dims[3],
688 self.strides[2] as isize,
689 self.strides[3] as isize,
690 )
691 }
692 }
693}
694
695impl<'a, C: Memorable> SymSqTensorRef<'a, C, 3> {
696 /// Returns an immutable row reference to the subtensor at the given indices.
697 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> RowRef<'a, C> {
698 check_bounds(&[m1, m2, 0], &self.dims);
699 // Safety: bounds have been checked.
700 unsafe { self.subtensor_ref_unchecked(m1, m2) }
701 }
702
703 /// Returns an immutable row reference to the subtensor at the given indices without bounds checking.
704 ///
705 /// # Safety
706 ///
707 /// Caller should ensure that m1 and m2 are in bounds.
708 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> RowRef<'a, C> {
709 unsafe {
710 RowRef::from_raw_parts(
711 self.ptr_at(&[m1, m2, 0]),
712 self.dims[2],
713 self.strides[2] as isize,
714 )
715 }
716 }
717}
718
719impl<'a, C: Memorable> SymSqTensorRef<'a, C, 2> {
720 /// Returns an immutable reference to the element at the given indices.
721 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> &C {
722 self.get(&[m1, m2])
723 }
724
725 /// Returns an immutable reference to the element at the given indices without bounds checking.
726 ///
727 /// # Safety
728 ///
729 /// Caller should ensure that m1 and m2 are in bounds.
730 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> &C {
731 unsafe { self.get_unchecked(&[m1, m2]) }
732 }
733}
734
735impl<'a, C: Memorable> SymSqTensorMut<'a, C, 5> {
736 /// Returns an immutable reference to the 3D subtensor at the given matrix indices.
737 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> TensorRef<'a, C, 3> {
738 check_bounds(&[m1, m2, 0, 0, 0], &self.dims);
739 // Safety: bounds have been checked.
740 unsafe { self.subtensor_ref_unchecked(m1, m2) }
741 }
742
743 /// Returns a mutable reference to the 3D subtensor at the given matrix indices.
744 pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> TensorMut<'a, C, 3> {
745 check_bounds(&[m1, m2, 0, 0, 0], &self.dims);
746 // Safety: bounds have been checked.
747 unsafe { self.subtensor_mut_unchecked(m1, m2) }
748 }
749
750 /// Returns an immutable reference to the 3D subtensor at the given matrix indices without bounds checking.
751 ///
752 /// # Safety
753 ///
754 /// Caller should ensure that m1 and m2 are in bounds.
755 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> TensorRef<'a, C, 3> {
756 unsafe {
757 TensorRef::from_raw_parts(
758 self.ptr_at(&[m1, m2, 0, 0, 0]),
759 [self.dims[2], self.dims[3], self.dims[4]],
760 [self.strides[2], self.strides[3], self.strides[4]],
761 )
762 }
763 }
764
765 /// Returns a mutable reference to the 3D subtensor at the given matrix indices without bounds checking.
766 ///
767 /// # Safety
768 ///
769 /// Caller should ensure that m1 and m2 are in bounds.
770 pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> TensorMut<'a, C, 3> {
771 unsafe {
772 TensorMut::from_raw_parts(
773 self.ptr_at_mut(&[m1, m2, 0, 0, 0]),
774 [self.dims[2], self.dims[3], self.dims[4]],
775 [self.strides[2], self.strides[3], self.strides[4]],
776 )
777 }
778 }
779}
780
781impl<'a, C: Memorable> SymSqTensorMut<'a, C, 4> {
782 /// Returns an immutable matrix reference to the subtensor at the given indices.
783 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> MatRef<'a, C> {
784 check_bounds(&[m1, m2, 0, 0], &self.dims);
785 // Safety: bounds have been checked.
786 unsafe { self.subtensor_ref_unchecked(m1, m2) }
787 }
788
789 /// Returns a mutable matrix reference to the subtensor at the given indices.
790 pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> MatMut<'a, C> {
791 check_bounds(&[m1, m2, 0, 0], &self.dims);
792 // Safety: bounds have been checked.
793 unsafe { self.subtensor_mut_unchecked(m1, m2) }
794 }
795
796 /// Returns an immutable matrix reference to the subtensor at the given indices without bounds checking.
797 ///
798 /// # Safety
799 ///
800 /// Caller should ensure that m1 and m2 are in bounds.
801 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> MatRef<'a, C> {
802 unsafe {
803 MatRef::from_raw_parts(
804 self.ptr_at(&[m1, m2, 0, 0]),
805 self.dims[2],
806 self.dims[3],
807 self.strides[2] as isize,
808 self.strides[3] as isize,
809 )
810 }
811 }
812
813 /// Returns a mutable matrix reference to the subtensor at the given indices without bounds checking.
814 ///
815 /// # Safety
816 ///
817 /// Caller should ensure that m1 and m2 are in bounds.
818 pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> MatMut<'a, C> {
819 unsafe {
820 MatMut::from_raw_parts_mut(
821 self.ptr_at_mut(&[m1, m2, 0, 0]),
822 self.dims[2],
823 self.dims[3],
824 self.strides[2] as isize,
825 self.strides[3] as isize,
826 )
827 }
828 }
829}
830
831impl<'a, C: Memorable> SymSqTensorMut<'a, C, 3> {
832 /// Returns an immutable row reference to the subtensor at the given indices.
833 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> RowRef<'a, C> {
834 check_bounds(&[m1, m2, 0], &self.dims);
835 // Safety: bounds have been checked.
836 unsafe { self.subtensor_ref_unchecked(m1, m2) }
837 }
838
839 /// Returns a mutable row reference to the subtensor at the given indices.
840 pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> RowMut<'a, C> {
841 check_bounds(&[m1, m2, 0], &self.dims);
842 // Safety: bounds have been checked.
843 unsafe { self.subtensor_mut_unchecked(m1, m2) }
844 }
845
846 /// Returns an immutable row reference to the subtensor at the given indices without bounds checking.
847 ///
848 /// # Safety
849 ///
850 /// Caller should ensure that m1 and m2 are in bounds.
851 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> RowRef<'a, C> {
852 unsafe {
853 RowRef::from_raw_parts(
854 self.ptr_at(&[m1, m2, 0]),
855 self.dims[2],
856 self.strides[2] as isize,
857 )
858 }
859 }
860
861 /// Returns a mutable row reference to the subtensor at the given indices without bounds checking.
862 ///
863 /// # Safety
864 ///
865 /// Caller should ensure that m1 and m2 are in bounds.
866 pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> RowMut<'a, C> {
867 unsafe {
868 RowMut::from_raw_parts_mut(
869 self.ptr_at_mut(&[m1, m2, 0]),
870 self.dims[2],
871 self.strides[2] as isize,
872 )
873 }
874 }
875}
876
877impl<'a, C: Memorable> SymSqTensorMut<'a, C, 2> {
878 /// Returns an immutable reference to the element at the given indices.
879 pub fn subtensor_ref(&self, m1: usize, m2: usize) -> &C {
880 self.get(&[m1, m2])
881 }
882
883 /// Returns a mutable reference to the element at the given indices.
884 pub fn subtensor_mut(&mut self, m1: usize, m2: usize) -> &mut C {
885 self.get_mut(&[m1, m2])
886 }
887
888 /// Returns an immutable reference to the element at the given indices without bounds checking.
889 ///
890 /// # Safety
891 ///
892 /// Caller should ensure that m1 and m2 are in bounds.
893 pub unsafe fn subtensor_ref_unchecked(&self, m1: usize, m2: usize) -> &C {
894 unsafe { self.get_unchecked(&[m1, m2]) }
895 }
896
897 /// Returns a mutable reference to the element at the given indices without bounds checking.
898 ///
899 /// # Safety
900 ///
901 /// Caller should ensure that m1 and m2 are in bounds.
902 pub unsafe fn subtensor_mut_unchecked(&mut self, m1: usize, m2: usize) -> &mut C {
903 unsafe { self.get_mut_unchecked(&[m1, m2]) }
904 }
905}