qudit_core/array/
tensor.rs

1//! Implements the tensor struct and associated methods for the Openqudit library.
2
3use faer::{MatMut, MatRef, RowMut, RowRef};
4use std::fmt::{self, Debug, Display, Formatter};
5use std::ptr::NonNull;
6
7use super::check_bounds;
8use crate::memory::{Memorable, MemoryBuffer, alloc_zeroed_memory};
9
/// Computes the flat (linear) offset addressed by a multi-dimensional index.
///
/// The result is the dot product of `indices` and `strides`: each index is
/// multiplied by the stride of its axis and the products are added up. For
/// `D == 0` the result is `0`.
#[inline(always)]
fn calculate_flat_index<const D: usize>(indices: &[usize; D], strides: &[usize; D]) -> usize {
    indices
        .iter()
        .zip(strides)
        .map(|(&index, &stride)| index * stride)
        .sum()
}
19
20/// A tensor struct that holds data in an aligned memory buffer.
21pub struct Tensor<C: Memorable, const D: usize> {
22    /// The data buffer containing the tensor elements.
23    data: MemoryBuffer<C>,
24    /// The dimensions of the tensor (size of each axis).
25    dims: [usize; D],
26    /// The strides for each dimension.
27    strides: [usize; D],
28}
29
30impl<C: Memorable, const D: usize> Tensor<C, D> {
31    /// Creates a new tensor from a memory buffer with specified dimensions and strides.
32    ///
33    /// # Arguments
34    ///
35    /// * `data` - The memory buffer containing the tensor data
36    /// * `dims` - Array specifying the size of each dimension
37    /// * `strides` - Array specifying the stride for each dimension
38    ///
39    /// # Panics
40    ///
41    /// * If any dimension or stride is zero
42    /// * If the data buffer is not large enough for the specified dimensions and strides
43    pub fn new(data: MemoryBuffer<C>, dims: [usize; D], strides: [usize; D]) -> Self {
44        assert!(
45            dims.iter().all(|&d| d != 0),
46            "Cannot have a zero-length dimension."
47        );
48        assert!(
49            strides.iter().all(|&d| d != 0),
50            "Cannot have a zero-length stride."
51        );
52
53        let mut max_element = [0; D];
54        for (i, d) in dims.iter().enumerate() {
55            max_element[i] = d - 1;
56        }
57        let max_flat_index = calculate_flat_index(&max_element, &strides);
58
59        assert!(
60            data.len() >= max_flat_index,
61            "Data buffer is not large enough."
62        );
63
64        Self {
65            data,
66            dims,
67            strides,
68        }
69    }
70
71    /// Creates a new tensor with all elements initialized to zero,
72    /// with specified shape.
73    ///
74    /// # Arguments
75    ///
76    /// * `dims` - A slice of `usize` containing the size of each dimension.
77    ///
78    /// # Returns
79    ///
80    /// * An new tensor with specified shape, filled with zeros.
81    ///
82    /// # Panics
83    ///
84    /// * If the length of `dims` is not equal to the number of
85    ///   dimensions of the tensor.
86    ///
87    /// # Examples
88    /// ```
89    /// # use qudit_core::array::Tensor;
90    ///
91    /// let test_tensor = Tensor::<f64, 2>::zeros([3, 4]);
92    ///
93    /// for i in 0..3 {
94    ///     for j in 0..4 {
95    ///         assert_eq!(test_tensor.get(&[i, j]), &0.0);
96    ///     }
97    /// }
98    /// ```
99    pub fn zeros(dims: [usize; D]) -> Self {
100        let strides = super::calc_continuous_strides(&dims);
101        let data = alloc_zeroed_memory::<C>(strides[0] * dims[0]);
102        Self::new(data, dims, strides)
103    }
104
105    /// Returns a reference to the dimensions of the tensor.
106    pub fn dims(&self) -> &[usize; D] {
107        &self.dims
108    }
109
110    /// Returns a reference to the strides of the tensor.
111    pub fn strides(&self) -> &[usize; D] {
112        &self.strides
113    }
114
115    /// Returns the rank (number of dimensions) of the tensor.
116    pub fn rank(&self) -> usize {
117        D
118    }
119
120    /// Returns the total number of elements in the tensor.
121    pub fn num_elements(&self) -> usize {
122        self.dims.iter().product()
123    }
124
125    /// Returns a raw pointer to the tensor's data.
126    pub fn as_ptr(&self) -> *const C {
127        self.data.as_ptr()
128    }
129
130    /// Returns a mutable raw pointer to the tensor's data.
131    pub fn as_ptr_mut(&mut self) -> *mut C {
132        self.data.as_mut_ptr()
133    }
134
135    /// Returns an immutable reference to the tensor.
136    pub fn as_ref(&self) -> TensorRef<'_, C, D> {
137        unsafe { TensorRef::from_raw_parts(self.data.as_ptr(), self.dims, self.strides) }
138    }
139
140    /// Returns a mutable reference to the tensor.
141    pub fn as_mut(&mut self) -> TensorMut<'_, C, D> {
142        unsafe { TensorMut::from_raw_parts(self.data.as_mut_ptr(), self.dims, self.strides) }
143    }
144
145    /// Returns a reference to an element at the given indices.
146    ///
147    /// # Panics
148    ///
149    /// Panics if the indices are out of bounds.
150    pub fn get(&self, indices: &[usize; D]) -> &C {
151        check_bounds(indices, &self.dims);
152        // Safety: bounds are checked by `check_bounds`
153        unsafe { self.get_unchecked(indices) }
154    }
155
156    /// Returns a mutable reference to an element at the given indices.
157    ///
158    /// # Panics
159    ///
160    /// Panics if the indices are out of bounds.
161    pub fn get_mut(&mut self, indices: &[usize; D]) -> &mut C {
162        check_bounds(indices, &self.dims);
163        // Safety: bounds are checked by `check_bounds`
164        unsafe { self.get_mut_unchecked(indices) }
165    }
166
167    /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
168    ///
169    /// # Safety
170    ///
171    /// Calling this method with out-of-bounds `indices` is undefined behavior.
172    #[inline(always)]
173    pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
174        unsafe { &*self.ptr_at(indices) }
175    }
176
177    /// Returns a mutable reference to an element at the given indices, without performing bounds checks.
178    ///
179    /// # Safety
180    ///
181    /// Calling this method with out-of-bounds `indices` is undefined behavior.
182    #[inline(always)]
183    pub unsafe fn get_mut_unchecked(&mut self, indices: &[usize; D]) -> &mut C {
184        unsafe { &mut *self.ptr_at_mut(indices) }
185    }
186
187    /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
188    ///
189    /// # Safety
190    ///
191    /// Calling this method with out-of-bounds `indices` is undefined behavior.
192    #[inline(always)]
193    pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
194        unsafe {
195            let flat_idx = calculate_flat_index(indices, &self.strides);
196            self.as_ptr().add(flat_idx)
197        }
198    }
199
200    /// Returns a mutable raw pointer to an element at the given indices, without performing bounds checks.
201    ///
202    /// # Safety
203    ///
204    /// Calling this method with out-of-bounds `indices` is undefined behavior.
205    #[inline(always)]
206    pub unsafe fn ptr_at_mut(&mut self, indices: &[usize; D]) -> *mut C {
207        unsafe {
208            let flat_idx = calculate_flat_index(indices, &self.strides);
209            self.as_ptr_mut().add(flat_idx)
210        }
211    }
212
213    /// Creates a new `Tensor` from a flat `Vec` and its dimensions.
214    ///
215    /// This is a convenience constructor that automatically converts the `Vec`
216    /// into a `MemoryBuffer` and then calls the `new` constructor.
217    ///
218    /// # Panics
219    /// Panics if the total number of elements implied by `dimensions`
220    /// (product of all dimension sizes) does not match the length of the `data_vec`.
221    ///
222    /// # Examples
223    /// ```
224    /// # use qudit_core::array::Tensor;
225    /// let tensor_from_slice = Tensor::from_slice(&vec![10, 20, 30, 40], [2, 2]);
226    /// assert_eq!(tensor_from_slice.dims(), &[2, 2]);
227    /// assert_eq!(tensor_from_slice.strides(), &[2, 1]);
228    /// ```
229    pub fn from_slice(slice: &[C], dims: [usize; D]) -> Self {
230        let strides = super::calc_continuous_strides(&dims);
231        Self::from_slice_with_strides(slice, dims, strides)
232    }
233
234    /// Creates a new `Tensor` from a slice of data, explicit dimensions, and strides.
235    ///
236    /// This constructor allows for creating tensors with custom stride patterns,
237    /// which can be useful for representing views or sub-tensors of larger data
238    /// structures without copying the underlying data.
239    ///
240    /// # Panics
241    /// Panics if:
242    /// - The `dimensions` and `strides` arrays do not have the same number of elements as `D`.
243    /// - The total number of elements implied by `dimensions` and `strides` (i.e., the
244    ///   maximum flat index + 1) exceeds the length of the `slice`.
245    /// - Any stride is zero unless its corresponding dimension is also zero.
246    ///
247    /// # Arguments
248    /// * `slice` - The underlying data slice.
249    /// * `dimensions` - An array of `usize` defining the size of each dimension.
250    /// * `strides` - An array of `usize` defining the stride for each dimension.
251    ///
252    /// # Examples
253    /// ```
254    /// # use qudit_core::array::Tensor;
255    /// // Create a 2x3 tensor from a slice with custom strides
256    /// let data = vec![1, 2, 3, 4, 5, 6];
257    /// let tensor = Tensor::from_slice_with_strides(
258    ///     &data,
259    ///     [2, 3], // 2 rows, 3 columns
260    ///     [3, 1], // Stride for rows is 3 elements, for columns is 1 element
261    /// );
262    /// assert_eq!(tensor.dims(), &[2, 3]);
263    /// assert_eq!(tensor.strides(), &[3, 1]);
264    /// assert_eq!(tensor.get(&[0, 0]), &1);
265    /// assert_eq!(tensor.get(&[0, 1]), &2);
266    /// assert_eq!(tensor.get(&[1, 0]), &4);
267    ///
268    /// // Creating a column vector view from a larger matrix's data
269    /// let matrix_data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // A 3x3 matrix's data
270    /// // View the second column (elements 2, 5, 8) as a 3x1 tensor
271    /// let column_view = Tensor::from_slice_with_strides(
272    ///     &matrix_data,
273    ///     [3, 1], // 3 rows, 1 column
274    ///     [3, 1], // Stride to next row is 3, stride to next column is 1 (but only 1 column)
275    /// );
276    /// assert_eq!(column_view.dims(), &[3, 1]);
277    /// assert_eq!(column_view.strides(), &[3, 1]);
278    /// // Note: This example is slightly misleading as the slice itself doesn't change for the column.
279    /// // A more accurate example for strides would involve a sub-view that skips elements.
280    /// // This specific case would be more typical for `TensorRef`.
281    /// ```
282    pub fn from_slice_with_strides(slice: &[C], dims: [usize; D], strides: [usize; D]) -> Self {
283        let data = MemoryBuffer::from_slice(64, slice);
284        Self::new(data, dims, strides)
285    }
286
287    /// Creates a new tensor with all elements initialized to zero,
288    /// with specified shape and strides.
289    ///
290    /// # Arguments
291    ///
292    /// * `dims` - A slice of `usize` containing the size of each dimension
293    /// * `strides` - A slice of `usize` containing the stride for each dimension.
294    ///
295    /// # Returns
296    ///
297    /// * A new tensor with specified shape and strides, filled with zeros.
298    ///
299    /// # Panics
300    ///
301    /// * If the length of `dims` or `strides` is not equal to the number of
302    ///   dimensions of the tensor.
303    /// * If the size of any dimension is zero but the corresponding stride is non-zero.
304    /// * If the size of any dimension is non-zero but the corresponding stride is zero.
305    ///
306    /// # Examples
307    /// ```
308    /// # use qudit_core::array::Tensor;
309    ///
310    /// let test_tensor = Tensor::<f64, 2>::zeros_with_strides(&[3, 4], &[4, 1]);
311    ///
312    /// for i in 0..3 {
313    ///     for j in 0..4 {
314    ///         assert_eq!(test_tensor.get(&[i, j]), &0.0);
315    ///     }
316    /// }
317    /// ```
318    pub fn zeros_with_strides(dims: &[usize; D], strides: &[usize; D]) -> Self {
319        let data = alloc_zeroed_memory::<C>(strides[0] * dims[0]);
320        Self::new(data, *dims, *strides)
321    }
322}
323
324impl<C: Memorable, const D: usize> std::ops::Index<[usize; D]> for Tensor<C, D> {
325    type Output = C;
326
327    fn index(&self, indices: [usize; D]) -> &Self::Output {
328        self.get(&indices)
329    }
330}
331
332impl<C: Memorable, const D: usize> std::ops::IndexMut<[usize; D]> for Tensor<C, D> {
333    fn index_mut(&mut self, indices: [usize; D]) -> &mut Self::Output {
334        self.get_mut(&indices)
335    }
336}
337
/// Recursive formatting helper that prints strided tensor data as nested,
/// bracketed lists, one nesting level per dimension.
struct TensorDataDebugHelper<'a, C: Display> {
    /// Pointer to the first element of the tensor data.
    data_ptr: *const C,
    /// Size of each dimension.
    dimensions: &'a [usize],
    /// Stride of each dimension, in elements.
    strides: &'a [usize],
    /// Dimension currently being printed; equals `dimensions.len()` once a
    /// single element has been reached.
    current_dim_idx: usize,
    /// Flat offset from `data_ptr` accumulated over the dimensions fixed so far.
    current_flat_offset: usize,
}

impl<'a, C: Display> Debug for TensorDataDebugHelper<'a, C> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let depth = self.current_dim_idx;
        let rank = self.dimensions.len();

        // Every dimension has been fixed: print the single element found at
        // the accumulated offset.
        if depth == rank {
            // SAFETY: the offset is built from the dimensions and strides the
            // helper was created with; whoever constructs the helper is
            // expected to pass a pointer valid for that whole range.
            let element = unsafe { &*self.data_ptr.add(self.current_flat_offset) };
            return write!(f, "{}", element);
        }

        let extent = self.dimensions[depth];
        let stride = self.strides[depth];
        let innermost = depth + 1 == rank;
        let indent = "\t".repeat(depth);

        // Innermost rows stay on a single line; outer levels open on their own line.
        if innermost {
            write!(f, "{}[", indent)?;
        } else {
            writeln!(f, "{}[", indent)?;
        }

        for i in 0..extent {
            let child = TensorDataDebugHelper {
                data_ptr: self.data_ptr,
                dimensions: self.dimensions,
                strides: self.strides,
                current_dim_idx: depth + 1,
                current_flat_offset: self.current_flat_offset + i * stride,
            };
            write!(f, "{:?}", child)?;

            // Elements of an innermost row are comma separated, without a
            // trailing separator after the last one.
            if innermost && i + 1 != extent {
                write!(f, ", ")?;
            }
        }

        // The innermost closing bracket follows the elements directly; outer
        // closing brackets are indented to line up with their opening bracket.
        if innermost {
            writeln!(f, "],")
        } else {
            writeln!(f, "{}],", indent)
        }
    }
}
399
400impl<C: Display + Debug + Memorable, const D: usize> Debug for Tensor<C, D> {
401    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
402        f.debug_struct("Tensor")
403            .field("dimensions", &self.dims)
404            .field("strides", &self.strides)
405            .field(
406                "data",
407                &TensorDataDebugHelper {
408                    data_ptr: self.data.as_ptr(), // Pointer to the start of the data buffer
409                    dimensions: &self.dims,
410                    strides: &self.strides,
411                    current_dim_idx: 0, // Start formatting from the first dimension (index 0)
412                    current_flat_offset: 0, // Start from offset 0 in the flat data buffer
413                },
414            )
415            .finish()
416    }
417}
418
419impl<'a, C: Display + Debug + Memorable, const D: usize> Debug for TensorRef<'a, C, D> {
420    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
421        f.debug_struct("TensorRef")
422            .field("dimensions", &self.dims)
423            .field("strides", &self.strides)
424            .field(
425                "data",
426                &TensorDataDebugHelper {
427                    data_ptr: self.data.as_ptr(), // `self.data` is already a `*const C` for `TensorRef`
428                    dimensions: &self.dims,
429                    strides: &self.strides,
430                    current_dim_idx: 0,
431                    current_flat_offset: 0,
432                },
433            )
434            .finish()
435    }
436}
437
438// TODO: add iterators for subtensors
439
440/// An immutable view into tensor data.
441///
442/// This struct provides read-only access to tensor data without owning the underlying memory.
443/// It holds a reference to the data along with dimension and stride information.
444#[derive(Clone, Copy)]
445pub struct TensorRef<'a, C: Memorable, const D: usize> {
446    /// Non-null pointer to the tensor data.
447    data: NonNull<C>,
448    /// The dimensions of the tensor (size of each axis).
449    dims: [usize; D],
450    /// The strides for each dimension.
451    strides: [usize; D],
452    /// Phantom data to enforce lifetime constraints.
453    __marker: std::marker::PhantomData<&'a C>,
454}
455
456impl<'a, C: Memorable, const D: usize> TensorRef<'a, C, D> {
457    /// Creates a `TensorRef` from pointers to tensor data, dimensions, and strides.
458    ///
459    /// # Safety
460    ///
461    /// The caller must ensure:
462    /// - The pointers are valid and non-null.
463    /// - For each unit, the entire memory region addressed by the tensor is
464    ///   within a single allocation.
465    /// - The memory is accessible by the pointer.
466    /// - No mutable aliasing occurs. No mutable references to the tensor data
467    ///   exist when the `MatVecRef` is alive.
468    pub unsafe fn from_raw_parts(data: *const C, dims: [usize; D], strides: [usize; D]) -> Self {
469        // SAFETY: The pointer is never used in an mutable context.
470        let ptr = unsafe { NonNull::new_unchecked(data as *mut C) };
471
472        Self {
473            data: ptr,
474            dims,
475            strides,
476            __marker: std::marker::PhantomData,
477        }
478    }
479
480    /// Returns a reference to the dimensions of the tensor.
481    pub fn dims(&self) -> &[usize; D] {
482        &self.dims
483    }
484
485    /// Returns a reference to the strides of the tensor.
486    pub fn strides(&self) -> &[usize; D] {
487        &self.strides
488    }
489
490    /// Returns the rank (number of dimensions) of the tensor.
491    pub fn rank(&self) -> usize {
492        D
493    }
494
495    /// Returns the total number of elements in the tensor.
496    pub fn num_elements(&self) -> usize {
497        self.dims.iter().product()
498    }
499
500    /// Returns a raw pointer to the tensor's data.
501    pub fn as_ptr(&self) -> *const C {
502        self.data.as_ptr()
503    }
504
505    /// Returns a reference to an element at the given indices.
506    ///
507    /// # Panics
508    ///
509    /// Panics if the indices are out of bounds.
510    pub fn get(&self, indices: &[usize; D]) -> &C {
511        check_bounds(indices, &self.dims);
512        // Safety: bounds are checked by `check_bounds`
513        unsafe { self.get_unchecked(indices) }
514    }
515
516    /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
517    ///
518    /// # Safety
519    ///
520    /// Calling this method with out-of-bounds `indices` is undefined behavior.
521    #[inline(always)]
522    pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
523        unsafe { &*self.ptr_at(indices) }
524    }
525
526    /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
527    ///
528    /// # Safety
529    ///
530    /// Calling this method with out-of-bounds `indices` is undefined behavior.
531    #[inline(always)]
532    pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
533        unsafe {
534            let flat_idx = calculate_flat_index(indices, &self.strides);
535            self.as_ptr().add(flat_idx)
536        }
537    }
538
539    /// Creates an owned `Tensor` by copying the data from this `TensorRef`.
540    pub fn to_owned(self) -> Tensor<C, D> {
541        let mut max_element = [0; D];
542        for (i, d) in self.dims.iter().enumerate() {
543            max_element[i] = d - 1;
544        }
545        let max_flat_index = calculate_flat_index(&max_element, &self.strides);
546        // Safety: Memory is nonnull and shared throughout max_flat_index.
547        // Slice is copied from and dropped immediately.
548        unsafe {
549            let slice = std::slice::from_raw_parts(self.data.as_ptr(), max_flat_index + 1);
550            Tensor::from_slice_with_strides(slice, self.dims, self.strides)
551        }
552    }
553}
554
555impl<'a, C: Memorable, const D: usize> std::ops::Index<[usize; D]> for TensorRef<'a, C, D> {
556    type Output = C;
557
558    fn index(&self, indices: [usize; D]) -> &Self::Output {
559        self.get(&indices)
560    }
561}
562
563/// A mutable view into tensor data.
564///
565/// This struct provides read-write access to tensor data without owning the underlying memory.
566/// It holds a mutable reference to the data along with dimension and stride information.
567pub struct TensorMut<'a, C: Memorable, const D: usize> {
568    /// Non-null pointer to the tensor data.
569    data: NonNull<C>,
570    /// The dimensions of the tensor (size of each axis).
571    dims: [usize; D],
572    /// The strides for each dimension.
573    strides: [usize; D],
574    /// Phantom data to enforce lifetime constraints.
575    __marker: std::marker::PhantomData<&'a mut C>,
576}
577
578impl<'a, C: Memorable, const D: usize> TensorMut<'a, C, D> {
579    /// Creates a new `SymSqTensorMut` from raw parts.
580    ///
581    /// # Safety
582    ///
583    /// The caller must ensure that `data` points to a valid memory block of `C` elements,
584    /// and that `dims` and `strides` accurately describe the layout of the tensor
585    /// within that memory block. The `data` pointer must be valid for the lifetime `'a`
586    /// and that it is safe to mutate the data.
587    pub unsafe fn from_raw_parts(data: *mut C, dims: [usize; D], strides: [usize; D]) -> Self {
588        unsafe {
589            Self {
590                data: NonNull::new_unchecked(data),
591                dims,
592                strides,
593                __marker: std::marker::PhantomData,
594            }
595        }
596    }
597
598    /// Returns a reference to the dimensions of the tensor.
599    pub fn dims(&self) -> &[usize; D] {
600        &self.dims
601    }
602
603    /// Returns a reference to the strides of the tensor.
604    pub fn strides(&self) -> &[usize; D] {
605        &self.strides
606    }
607
608    /// Returns the rank (number of dimensions) of the tensor.
609    pub fn rank(&self) -> usize {
610        D
611    }
612
613    /// Returns the total number of elements in the tensor.
614    pub fn num_elements(&self) -> usize {
615        self.dims.iter().product()
616    }
617
618    /// Returns a mutable raw pointer to the tensor's data.
619    pub fn as_ptr(&self) -> *const C {
620        self.data.as_ptr() as *const C
621    }
622
623    /// Returns a mutable raw pointer to the tensor's data.
624    pub fn as_ptr_mut(&mut self) -> *mut C {
625        self.data.as_ptr()
626    }
627
628    /// Returns a reference to an element at the given indices.
629    ///
630    /// # Panics
631    ///
632    /// Panics if the indices are out of bounds.
633    pub fn get(&self, indices: &[usize; D]) -> &C {
634        check_bounds(indices, &self.dims);
635        // Safety: bounds are checked by `check_bounds`
636        unsafe { self.get_unchecked(indices) }
637    }
638
639    /// Returns a mutable reference to an element at the given indices.
640    ///
641    /// # Panics
642    ///
643    /// Panics if the indices are out of bounds.
644    pub fn get_mut(&mut self, indices: &[usize; D]) -> &mut C {
645        check_bounds(indices, &self.dims);
646        // Safety: bounds are checked by `check_bounds`
647        unsafe { self.get_mut_unchecked(indices) }
648    }
649
650    /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
651    ///
652    /// # Safety
653    ///
654    /// Calling this method with out-of-bounds `indices` is undefined behavior.
655    #[inline(always)]
656    pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
657        unsafe { &*self.ptr_at(indices) }
658    }
659
660    /// Returns a mutable reference to an element at the given indices, without performing bounds checks.
661    ///
662    /// # Safety
663    ///
664    /// Calling this method with out-of-bounds `indices` is undefined behavior.
665    #[inline(always)]
666    pub unsafe fn get_mut_unchecked(&mut self, indices: &[usize; D]) -> &mut C {
667        unsafe { &mut *self.ptr_at_mut(indices) }
668    }
669
670    /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
671    ///
672    /// # Safety
673    ///
674    /// Calling this method with out-of-bounds `indices` is undefined behavior.
675    #[inline(always)]
676    pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
677        unsafe {
678            let flat_idx = calculate_flat_index(indices, &self.strides);
679            self.as_ptr().add(flat_idx)
680        }
681    }
682
683    /// Returns a mutable raw pointer to an element at the given indices, without performing bounds checks.
684    ///
685    /// # Safety
686    ///
687    /// Calling this method with out-of-bounds `indices` is undefined behavior.
688    #[inline(always)]
689    pub unsafe fn ptr_at_mut(&mut self, indices: &[usize; D]) -> *mut C {
690        unsafe {
691            let flat_idx = calculate_flat_index(indices, &self.strides);
692            self.as_ptr_mut().add(flat_idx)
693        }
694    }
695}
696
697impl<'a, C: Memorable, const D: usize> std::ops::Index<[usize; D]> for TensorMut<'a, C, D> {
698    type Output = C;
699
700    fn index(&self, indices: [usize; D]) -> &Self::Output {
701        self.get(&indices)
702    }
703}
704
705impl<'a, C: Memorable, const D: usize> std::ops::IndexMut<[usize; D]> for TensorMut<'a, C, D> {
706    fn index_mut(&mut self, indices: [usize; D]) -> &mut Self::Output {
707        self.get_mut(&indices)
708    }
709}
710
711// TODO add some documentation plus a todo tag on relevant rust issues
712impl<C: Memorable> Tensor<C, 4> {
713    /// Returns an immutable view of a 3D subtensor at the given index.
714    pub fn subtensor_ref(&self, m: usize) -> TensorRef<'_, C, 3> {
715        check_bounds(&[m, 0, 0, 0], &self.dims);
716        // Safety: bounds have been checked.
717        unsafe { self.subtensor_ref_unchecked(m) }
718    }
719
720    /// Returns a mutable view of a 3D subtensor at the given index.
721    pub fn subtensor_mut(&mut self, m: usize) -> TensorMut<'_, C, 3> {
722        check_bounds(&[m, 0, 0, 0], &self.dims);
723        // Safety: bounds have been checked.
724        unsafe { self.subtensor_mut_unchecked(m) }
725    }
726
727    #[inline(always)]
728    /// Returns an immutable view of a 3D subtensor at the given index without bounds checking.
729    ///
730    /// # Safety
731    ///
732    /// Caller must ensure that m is within bounds.
733    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> TensorRef<'_, C, 3> {
734        unsafe {
735            TensorRef::from_raw_parts(
736                self.ptr_at(&[m, 0, 0, 0]),
737                [self.dims[1], self.dims[2], self.dims[3]],
738                [self.strides[1], self.strides[2], self.strides[3]],
739            )
740        }
741    }
742
743    #[inline(always)]
744    /// Returns a mutable view of a 3D subtensor at the given index without bounds checking.
745    ///
746    /// # Safety
747    ///
748    /// Caller must ensure that m is within bounds.
749    pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> TensorMut<'_, C, 3> {
750        unsafe {
751            TensorMut::from_raw_parts(
752                self.ptr_at_mut(&[m, 0, 0, 0]),
753                [self.dims[1], self.dims[2], self.dims[3]],
754                [self.strides[1], self.strides[2], self.strides[3]],
755            )
756        }
757    }
758}
759
760impl<C: Memorable> Tensor<C, 3> {
761    /// Returns an immutable matrix view of a 2D subtensor at the given index.
762    pub fn subtensor_ref(&self, m: usize) -> MatRef<'_, C> {
763        check_bounds(&[m, 0, 0], &self.dims);
764        // Safety: bounds have been checked.
765        unsafe { self.subtensor_ref_unchecked(m) }
766    }
767
768    /// Returns a mutable matrix view of a 2D subtensor at the given index.
769    pub fn subtensor_mut(&mut self, m: usize) -> MatMut<'_, C> {
770        check_bounds(&[m, 0, 0], &self.dims);
771        // Safety: bounds have been checked.
772        unsafe { self.subtensor_mut_unchecked(m) }
773    }
774
775    #[inline(always)]
776    /// Returns an immutable matrix view of a 2D subtensor at the given index without bounds checking.
777    ///
778    /// # Safety
779    ///
780    /// Caller must ensure that m is within bounds.
781    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> MatRef<'_, C> {
782        unsafe {
783            MatRef::from_raw_parts(
784                self.ptr_at(&[m, 0, 0]),
785                self.dims[1],
786                self.dims[2],
787                self.strides[1] as isize,
788                self.strides[2] as isize,
789            )
790        }
791    }
792
793    #[inline(always)]
794    /// Returns a mutable matrix view of a 2D subtensor at the given index without bounds checking.
795    ///
796    /// # Safety
797    ///
798    /// Caller must ensure that m is within bounds.
799    pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> MatMut<'_, C> {
800        unsafe {
801            MatMut::from_raw_parts_mut(
802                self.ptr_at_mut(&[m, 0, 0]),
803                self.dims[1],
804                self.dims[2],
805                self.strides[1] as isize,
806                self.strides[2] as isize,
807            )
808        }
809    }
810}
811
812impl<C: Memorable> Tensor<C, 2> {
813    /// Returns an immutable row view of a 1D subtensor at the given index.
814    pub fn subtensor_ref(&self, m: usize) -> RowRef<'_, C> {
815        check_bounds(&[m, 0], &self.dims);
816        // Safety: bounds have been checked.
817        unsafe { self.subtensor_ref_unchecked(m) }
818    }
819
820    /// Returns a mutable row view of a 1D subtensor at the given index.
821    pub fn subtensor_mut(&mut self, m: usize) -> RowMut<'_, C> {
822        check_bounds(&[m, 0], &self.dims);
823        // Safety: bounds have been checked.
824        unsafe { self.subtensor_mut_unchecked(m) }
825    }
826
827    #[inline(always)]
828    /// Returns an immutable row view of a 1D subtensor at the given index without bounds checking.
829    ///
830    /// # Safety
831    ///
832    /// Caller must ensure that m is within bounds.
833    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> RowRef<'_, C> {
834        unsafe {
835            RowRef::from_raw_parts(self.ptr_at(&[m, 0]), self.dims[1], self.strides[1] as isize)
836        }
837    }
838
839    #[inline(always)]
840    /// Returns a mutable row view of a 1D subtensor at the given index without bounds checking.
841    ///
842    /// # Safety
843    ///
844    /// Caller must ensure that m is within bounds.
845    pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> RowMut<'_, C> {
846        unsafe {
847            RowMut::from_raw_parts_mut(
848                self.ptr_at_mut(&[m, 0]),
849                self.dims[1],
850                self.strides[1] as isize,
851            )
852        }
853    }
854}
855
856impl<C: Memorable> Tensor<C, 1> {
857    /// Returns an immutable reference to an element at the given index.
858    pub fn subtensor_ref(&self, m: usize) -> &C {
859        self.get(&[m])
860    }
861
862    /// Returns a mutable reference to an element at the given index.
863    pub fn subtensor_mut(&mut self, m: usize) -> &mut C {
864        self.get_mut(&[m])
865    }
866
867    #[inline(always)]
868    /// Returns an immutable reference to an element at the given index without bounds checking.
869    ///
870    /// # Safety
871    ///
872    /// Caller must ensure that m is within bounds.
873    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> &C {
874        unsafe { self.get_unchecked(&[m]) }
875    }
876
877    #[inline(always)]
878    /// Returns a mutable reference to an element at the given index without bounds checking.
879    ///
880    /// # Safety
881    ///
882    /// Caller must ensure that m is within bounds.
883    pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> &mut C {
884        unsafe { self.get_mut_unchecked(&[m]) }
885    }
886}
887
888impl<'a, C: Memorable> TensorRef<'a, C, 4> {
889    /// Returns an immutable view of a 3D subtensor at the given index.
890    pub fn subtensor_ref(&self, m: usize) -> TensorRef<'a, C, 3> {
891        check_bounds(&[m, 0, 0, 0], &self.dims);
892        // Safety: bounds have been checked.
893        unsafe { self.subtensor_ref_unchecked(m) }
894    }
895
896    #[inline(always)]
897    /// Returns an immutable view of a 3D subtensor at the given index without bounds checking.
898    ///
899    /// # Safety
900    ///
901    /// Caller must ensure that m is within bounds.
902    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> TensorRef<'a, C, 3> {
903        unsafe {
904            TensorRef::from_raw_parts(
905                self.ptr_at(&[m, 0, 0, 0]),
906                [self.dims[1], self.dims[2], self.dims[3]],
907                [self.strides[1], self.strides[2], self.strides[3]],
908            )
909        }
910    }
911}
912
913impl<'a, C: Memorable> TensorRef<'a, C, 3> {
914    /// Returns an immutable matrix view of a 2D subtensor at the given index.
915    pub fn subtensor_ref(&self, m: usize) -> MatRef<'a, C> {
916        check_bounds(&[m, 0, 0], &self.dims);
917        // Safety: bounds have been checked.
918        unsafe { self.subtensor_ref_unchecked(m) }
919    }
920
921    #[inline(always)]
922    /// Returns an immutable matrix view of a 2D subtensor at the given index without bounds checking.
923    ///
924    /// # Safety
925    ///
926    /// Caller must ensure that m is within bounds.
927    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> MatRef<'a, C> {
928        unsafe {
929            MatRef::from_raw_parts(
930                self.ptr_at(&[m, 0, 0]),
931                self.dims[1],
932                self.dims[2],
933                self.strides[1] as isize,
934                self.strides[2] as isize,
935            )
936        }
937    }
938}
939
940impl<'a, C: Memorable> TensorRef<'a, C, 2> {
941    /// Returns an immutable row view of a 1D subtensor at the given index.
942    pub fn subtensor_ref(&self, m: usize) -> RowRef<'a, C> {
943        check_bounds(&[m, 0], &self.dims);
944        // Safety: bounds have been checked.
945        unsafe { self.subtensor_ref_unchecked(m) }
946    }
947
948    #[inline(always)]
949    /// Returns an immutable row view of a 1D subtensor at the given index without bounds checking.
950    ///
951    /// # Safety
952    ///
953    /// Caller must ensure that m is within bounds.
954    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> RowRef<'a, C> {
955        unsafe {
956            RowRef::from_raw_parts(self.ptr_at(&[m, 0]), self.dims[1], self.strides[1] as isize)
957        }
958    }
959}
960
961impl<'a, C: Memorable> TensorRef<'a, C, 1> {
962    /// Returns an immutable reference to an element at the given index.
963    pub fn subtensor_ref(&self, m: usize) -> &C {
964        self.get(&[m])
965    }
966
967    #[inline(always)]
968    /// Returns an immutable reference to an element at the given index without bounds checking.
969    ///
970    /// # Safety
971    ///
972    /// Caller must ensure that m is within bounds.
973    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> &C {
974        unsafe { self.get_unchecked(&[m]) }
975    }
976}
977
978impl<'a, C: Memorable> TensorMut<'a, C, 4> {
979    /// Returns an immutable view of a 3D subtensor at the given index.
980    pub fn subtensor_ref(&self, m: usize) -> TensorRef<'a, C, 3> {
981        check_bounds(&[m, 0, 0, 0], &self.dims);
982        // Safety: bounds have been checked.
983        unsafe { self.subtensor_ref_unchecked(m) }
984    }
985
986    /// Returns a mutable view of a 3D subtensor at the given index.
987    pub fn subtensor_mut(&mut self, m: usize) -> TensorMut<'a, C, 3> {
988        check_bounds(&[m, 0, 0, 0], &self.dims);
989        // Safety: bounds have been checked.
990        unsafe { self.subtensor_mut_unchecked(m) }
991    }
992
993    #[inline(always)]
994    /// Returns an immutable view of a 3D subtensor at the given index without bounds checking.
995    ///
996    /// # Safety
997    ///
998    /// Caller must ensure that m is within bounds.
999    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> TensorRef<'a, C, 3> {
1000        unsafe {
1001            TensorRef::from_raw_parts(
1002                self.ptr_at(&[m, 0, 0, 0]),
1003                [self.dims[1], self.dims[2], self.dims[3]],
1004                [self.strides[1], self.strides[2], self.strides[3]],
1005            )
1006        }
1007    }
1008
1009    #[inline(always)]
1010    /// Returns a mutable view of a 3D subtensor at the given index without bounds checking.
1011    ///
1012    /// # Safety
1013    ///
1014    /// Caller must ensure that m is within bounds.
1015    pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> TensorMut<'a, C, 3> {
1016        unsafe {
1017            TensorMut::from_raw_parts(
1018                self.ptr_at_mut(&[m, 0, 0, 0]),
1019                [self.dims[1], self.dims[2], self.dims[3]],
1020                [self.strides[1], self.strides[2], self.strides[3]],
1021            )
1022        }
1023    }
1024}
1025
1026impl<'a, C: Memorable> TensorMut<'a, C, 3> {
1027    /// Returns an immutable matrix view of a 2D subtensor at the given index.
1028    pub fn subtensor_ref(&self, m: usize) -> MatRef<'a, C> {
1029        check_bounds(&[m, 0, 0], &self.dims);
1030        // Safety: bounds have been checked.
1031        unsafe { self.subtensor_ref_unchecked(m) }
1032    }
1033
1034    /// Returns a mutable matrix view of a 2D subtensor at the given index.
1035    pub fn subtensor_mut(&mut self, m: usize) -> MatMut<'a, C> {
1036        check_bounds(&[m, 0, 0], &self.dims);
1037        // Safety: bounds have been checked.
1038        unsafe { self.subtensor_mut_unchecked(m) }
1039    }
1040
1041    #[inline(always)]
1042    /// Returns an immutable matrix view of a 2D subtensor at the given index without bounds checking.
1043    ///
1044    /// # Safety
1045    ///
1046    /// Caller must ensure that m is within bounds.
1047    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> MatRef<'a, C> {
1048        unsafe {
1049            MatRef::from_raw_parts(
1050                self.ptr_at(&[m, 0, 0]),
1051                self.dims[1],
1052                self.dims[2],
1053                self.strides[1] as isize,
1054                self.strides[2] as isize,
1055            )
1056        }
1057    }
1058
1059    #[inline(always)]
1060    /// Returns a mutable matrix view of a 2D subtensor at the given index without bounds checking.
1061    ///
1062    /// # Safety
1063    ///
1064    /// Caller must ensure that m is within bounds.
1065    pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> MatMut<'a, C> {
1066        unsafe {
1067            MatMut::from_raw_parts_mut(
1068                self.ptr_at_mut(&[m, 0, 0]),
1069                self.dims[1],
1070                self.dims[2],
1071                self.strides[1] as isize,
1072                self.strides[2] as isize,
1073            )
1074        }
1075    }
1076}
1077
1078impl<'a, C: Memorable> TensorMut<'a, C, 2> {
1079    /// Returns an immutable row view of a 1D subtensor at the given index.
1080    pub fn subtensor_ref(&self, m: usize) -> RowRef<'a, C> {
1081        check_bounds(&[m, 0], &self.dims);
1082        // Safety: bounds have been checked.
1083        unsafe { self.subtensor_ref_unchecked(m) }
1084    }
1085
1086    /// Returns a mutable row view of a 1D subtensor at the given index.
1087    pub fn subtensor_mut(&mut self, m: usize) -> RowMut<'a, C> {
1088        check_bounds(&[m, 0], &self.dims);
1089        // Safety: bounds have been checked.
1090        unsafe { self.subtensor_mut_unchecked(m) }
1091    }
1092
1093    #[inline(always)]
1094    /// Returns an immutable row view of a 1D subtensor at the given index without bounds checking.
1095    ///
1096    /// # Safety
1097    ///
1098    /// Caller must ensure that m is within bounds.
1099    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> RowRef<'a, C> {
1100        unsafe {
1101            RowRef::from_raw_parts(self.ptr_at(&[m, 0]), self.dims[1], self.strides[1] as isize)
1102        }
1103    }
1104
1105    #[inline(always)]
1106    /// Returns a mutable row view of a 1D subtensor at the given index without bounds checking.
1107    ///
1108    /// # Safety
1109    ///
1110    /// Caller must ensure that m is within bounds.
1111    pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> RowMut<'a, C> {
1112        unsafe {
1113            RowMut::from_raw_parts_mut(
1114                self.ptr_at_mut(&[m, 0]),
1115                self.dims[1],
1116                self.strides[1] as isize,
1117            )
1118        }
1119    }
1120}
1121
1122impl<'a, C: Memorable> TensorMut<'a, C, 1> {
1123    /// Returns an immutable reference to an element at the given index.
1124    pub fn subtensor_ref(&self, m: usize) -> &C {
1125        self.get(&[m])
1126    }
1127
1128    /// Returns a mutable reference to an element at the given index.
1129    pub fn subtensor_mut(&mut self, m: usize) -> &mut C {
1130        self.get_mut(&[m])
1131    }
1132
1133    #[inline(always)]
1134    /// Returns an immutable reference to an element at the given index without bounds checking.
1135    ///
1136    /// # Safety
1137    ///
1138    /// Caller must ensure that m is within bounds.
1139    pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> &C {
1140        unsafe { self.get_unchecked(&[m]) }
1141    }
1142
1143    #[inline(always)]
1144    /// Returns a mutable reference to an element at the given index without bounds checking.
1145    ///
1146    /// # Safety
1147    ///
1148    /// Caller must ensure that m is within bounds.
1149    pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> &mut C {
1150        unsafe { self.get_mut_unchecked(&[m]) }
1151    }
1152}