qudit_core/array/tensor.rs
1//! Implements the tensor struct and associated methods for the Openqudit library.
2
3use faer::{MatMut, MatRef, RowMut, RowRef};
4use std::fmt::{self, Debug, Display, Formatter};
5use std::ptr::NonNull;
6
7use super::check_bounds;
8use crate::memory::{Memorable, MemoryBuffer, alloc_zeroed_memory};
9
/// Maps a multi-dimensional index to its flat element offset.
///
/// The offset is the dot product of `indices` with `strides`, i.e.
/// `sum(indices[i] * strides[i])`.
#[inline(always)]
fn calculate_flat_index<const D: usize>(indices: &[usize; D], strides: &[usize; D]) -> usize {
    indices
        .iter()
        .zip(strides.iter())
        .map(|(&index, &stride)| index * stride)
        .sum()
}
19
20/// A tensor struct that holds data in an aligned memory buffer.
21pub struct Tensor<C: Memorable, const D: usize> {
22 /// The data buffer containing the tensor elements.
23 data: MemoryBuffer<C>,
24 /// The dimensions of the tensor (size of each axis).
25 dims: [usize; D],
26 /// The strides for each dimension.
27 strides: [usize; D],
28}
29
30impl<C: Memorable, const D: usize> Tensor<C, D> {
31 /// Creates a new tensor from a memory buffer with specified dimensions and strides.
32 ///
33 /// # Arguments
34 ///
35 /// * `data` - The memory buffer containing the tensor data
36 /// * `dims` - Array specifying the size of each dimension
37 /// * `strides` - Array specifying the stride for each dimension
38 ///
39 /// # Panics
40 ///
41 /// * If any dimension or stride is zero
42 /// * If the data buffer is not large enough for the specified dimensions and strides
43 pub fn new(data: MemoryBuffer<C>, dims: [usize; D], strides: [usize; D]) -> Self {
44 assert!(
45 dims.iter().all(|&d| d != 0),
46 "Cannot have a zero-length dimension."
47 );
48 assert!(
49 strides.iter().all(|&d| d != 0),
50 "Cannot have a zero-length stride."
51 );
52
53 let mut max_element = [0; D];
54 for (i, d) in dims.iter().enumerate() {
55 max_element[i] = d - 1;
56 }
57 let max_flat_index = calculate_flat_index(&max_element, &strides);
58
59 assert!(
60 data.len() >= max_flat_index,
61 "Data buffer is not large enough."
62 );
63
64 Self {
65 data,
66 dims,
67 strides,
68 }
69 }
70
71 /// Creates a new tensor with all elements initialized to zero,
72 /// with specified shape.
73 ///
74 /// # Arguments
75 ///
76 /// * `dims` - A slice of `usize` containing the size of each dimension.
77 ///
78 /// # Returns
79 ///
80 /// * An new tensor with specified shape, filled with zeros.
81 ///
82 /// # Panics
83 ///
84 /// * If the length of `dims` is not equal to the number of
85 /// dimensions of the tensor.
86 ///
87 /// # Examples
88 /// ```
89 /// # use qudit_core::array::Tensor;
90 ///
91 /// let test_tensor = Tensor::<f64, 2>::zeros([3, 4]);
92 ///
93 /// for i in 0..3 {
94 /// for j in 0..4 {
95 /// assert_eq!(test_tensor.get(&[i, j]), &0.0);
96 /// }
97 /// }
98 /// ```
99 pub fn zeros(dims: [usize; D]) -> Self {
100 let strides = super::calc_continuous_strides(&dims);
101 let data = alloc_zeroed_memory::<C>(strides[0] * dims[0]);
102 Self::new(data, dims, strides)
103 }
104
105 /// Returns a reference to the dimensions of the tensor.
106 pub fn dims(&self) -> &[usize; D] {
107 &self.dims
108 }
109
110 /// Returns a reference to the strides of the tensor.
111 pub fn strides(&self) -> &[usize; D] {
112 &self.strides
113 }
114
115 /// Returns the rank (number of dimensions) of the tensor.
116 pub fn rank(&self) -> usize {
117 D
118 }
119
120 /// Returns the total number of elements in the tensor.
121 pub fn num_elements(&self) -> usize {
122 self.dims.iter().product()
123 }
124
125 /// Returns a raw pointer to the tensor's data.
126 pub fn as_ptr(&self) -> *const C {
127 self.data.as_ptr()
128 }
129
130 /// Returns a mutable raw pointer to the tensor's data.
131 pub fn as_ptr_mut(&mut self) -> *mut C {
132 self.data.as_mut_ptr()
133 }
134
135 /// Returns an immutable reference to the tensor.
136 pub fn as_ref(&self) -> TensorRef<'_, C, D> {
137 unsafe { TensorRef::from_raw_parts(self.data.as_ptr(), self.dims, self.strides) }
138 }
139
140 /// Returns a mutable reference to the tensor.
141 pub fn as_mut(&mut self) -> TensorMut<'_, C, D> {
142 unsafe { TensorMut::from_raw_parts(self.data.as_mut_ptr(), self.dims, self.strides) }
143 }
144
145 /// Returns a reference to an element at the given indices.
146 ///
147 /// # Panics
148 ///
149 /// Panics if the indices are out of bounds.
150 pub fn get(&self, indices: &[usize; D]) -> &C {
151 check_bounds(indices, &self.dims);
152 // Safety: bounds are checked by `check_bounds`
153 unsafe { self.get_unchecked(indices) }
154 }
155
156 /// Returns a mutable reference to an element at the given indices.
157 ///
158 /// # Panics
159 ///
160 /// Panics if the indices are out of bounds.
161 pub fn get_mut(&mut self, indices: &[usize; D]) -> &mut C {
162 check_bounds(indices, &self.dims);
163 // Safety: bounds are checked by `check_bounds`
164 unsafe { self.get_mut_unchecked(indices) }
165 }
166
167 /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
168 ///
169 /// # Safety
170 ///
171 /// Calling this method with out-of-bounds `indices` is undefined behavior.
172 #[inline(always)]
173 pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
174 unsafe { &*self.ptr_at(indices) }
175 }
176
177 /// Returns a mutable reference to an element at the given indices, without performing bounds checks.
178 ///
179 /// # Safety
180 ///
181 /// Calling this method with out-of-bounds `indices` is undefined behavior.
182 #[inline(always)]
183 pub unsafe fn get_mut_unchecked(&mut self, indices: &[usize; D]) -> &mut C {
184 unsafe { &mut *self.ptr_at_mut(indices) }
185 }
186
187 /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
188 ///
189 /// # Safety
190 ///
191 /// Calling this method with out-of-bounds `indices` is undefined behavior.
192 #[inline(always)]
193 pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
194 unsafe {
195 let flat_idx = calculate_flat_index(indices, &self.strides);
196 self.as_ptr().add(flat_idx)
197 }
198 }
199
200 /// Returns a mutable raw pointer to an element at the given indices, without performing bounds checks.
201 ///
202 /// # Safety
203 ///
204 /// Calling this method with out-of-bounds `indices` is undefined behavior.
205 #[inline(always)]
206 pub unsafe fn ptr_at_mut(&mut self, indices: &[usize; D]) -> *mut C {
207 unsafe {
208 let flat_idx = calculate_flat_index(indices, &self.strides);
209 self.as_ptr_mut().add(flat_idx)
210 }
211 }
212
213 /// Creates a new `Tensor` from a flat `Vec` and its dimensions.
214 ///
215 /// This is a convenience constructor that automatically converts the `Vec`
216 /// into a `MemoryBuffer` and then calls the `new` constructor.
217 ///
218 /// # Panics
219 /// Panics if the total number of elements implied by `dimensions`
220 /// (product of all dimension sizes) does not match the length of the `data_vec`.
221 ///
222 /// # Examples
223 /// ```
224 /// # use qudit_core::array::Tensor;
225 /// let tensor_from_slice = Tensor::from_slice(&vec![10, 20, 30, 40], [2, 2]);
226 /// assert_eq!(tensor_from_slice.dims(), &[2, 2]);
227 /// assert_eq!(tensor_from_slice.strides(), &[2, 1]);
228 /// ```
229 pub fn from_slice(slice: &[C], dims: [usize; D]) -> Self {
230 let strides = super::calc_continuous_strides(&dims);
231 Self::from_slice_with_strides(slice, dims, strides)
232 }
233
234 /// Creates a new `Tensor` from a slice of data, explicit dimensions, and strides.
235 ///
236 /// This constructor allows for creating tensors with custom stride patterns,
237 /// which can be useful for representing views or sub-tensors of larger data
238 /// structures without copying the underlying data.
239 ///
240 /// # Panics
241 /// Panics if:
242 /// - The `dimensions` and `strides` arrays do not have the same number of elements as `D`.
243 /// - The total number of elements implied by `dimensions` and `strides` (i.e., the
244 /// maximum flat index + 1) exceeds the length of the `slice`.
245 /// - Any stride is zero unless its corresponding dimension is also zero.
246 ///
247 /// # Arguments
248 /// * `slice` - The underlying data slice.
249 /// * `dimensions` - An array of `usize` defining the size of each dimension.
250 /// * `strides` - An array of `usize` defining the stride for each dimension.
251 ///
252 /// # Examples
253 /// ```
254 /// # use qudit_core::array::Tensor;
255 /// // Create a 2x3 tensor from a slice with custom strides
256 /// let data = vec![1, 2, 3, 4, 5, 6];
257 /// let tensor = Tensor::from_slice_with_strides(
258 /// &data,
259 /// [2, 3], // 2 rows, 3 columns
260 /// [3, 1], // Stride for rows is 3 elements, for columns is 1 element
261 /// );
262 /// assert_eq!(tensor.dims(), &[2, 3]);
263 /// assert_eq!(tensor.strides(), &[3, 1]);
264 /// assert_eq!(tensor.get(&[0, 0]), &1);
265 /// assert_eq!(tensor.get(&[0, 1]), &2);
266 /// assert_eq!(tensor.get(&[1, 0]), &4);
267 ///
268 /// // Creating a column vector view from a larger matrix's data
269 /// let matrix_data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; // A 3x3 matrix's data
270 /// // View the second column (elements 2, 5, 8) as a 3x1 tensor
271 /// let column_view = Tensor::from_slice_with_strides(
272 /// &matrix_data,
273 /// [3, 1], // 3 rows, 1 column
274 /// [3, 1], // Stride to next row is 3, stride to next column is 1 (but only 1 column)
275 /// );
276 /// assert_eq!(column_view.dims(), &[3, 1]);
277 /// assert_eq!(column_view.strides(), &[3, 1]);
278 /// // Note: This example is slightly misleading as the slice itself doesn't change for the column.
279 /// // A more accurate example for strides would involve a sub-view that skips elements.
280 /// // This specific case would be more typical for `TensorRef`.
281 /// ```
282 pub fn from_slice_with_strides(slice: &[C], dims: [usize; D], strides: [usize; D]) -> Self {
283 let data = MemoryBuffer::from_slice(64, slice);
284 Self::new(data, dims, strides)
285 }
286
287 /// Creates a new tensor with all elements initialized to zero,
288 /// with specified shape and strides.
289 ///
290 /// # Arguments
291 ///
292 /// * `dims` - A slice of `usize` containing the size of each dimension
293 /// * `strides` - A slice of `usize` containing the stride for each dimension.
294 ///
295 /// # Returns
296 ///
297 /// * A new tensor with specified shape and strides, filled with zeros.
298 ///
299 /// # Panics
300 ///
301 /// * If the length of `dims` or `strides` is not equal to the number of
302 /// dimensions of the tensor.
303 /// * If the size of any dimension is zero but the corresponding stride is non-zero.
304 /// * If the size of any dimension is non-zero but the corresponding stride is zero.
305 ///
306 /// # Examples
307 /// ```
308 /// # use qudit_core::array::Tensor;
309 ///
310 /// let test_tensor = Tensor::<f64, 2>::zeros_with_strides(&[3, 4], &[4, 1]);
311 ///
312 /// for i in 0..3 {
313 /// for j in 0..4 {
314 /// assert_eq!(test_tensor.get(&[i, j]), &0.0);
315 /// }
316 /// }
317 /// ```
318 pub fn zeros_with_strides(dims: &[usize; D], strides: &[usize; D]) -> Self {
319 let data = alloc_zeroed_memory::<C>(strides[0] * dims[0]);
320 Self::new(data, *dims, *strides)
321 }
322}
323
324impl<C: Memorable, const D: usize> std::ops::Index<[usize; D]> for Tensor<C, D> {
325 type Output = C;
326
327 fn index(&self, indices: [usize; D]) -> &Self::Output {
328 self.get(&indices)
329 }
330}
331
332impl<C: Memorable, const D: usize> std::ops::IndexMut<[usize; D]> for Tensor<C, D> {
333 fn index_mut(&mut self, indices: [usize; D]) -> &mut Self::Output {
334 self.get_mut(&indices)
335 }
336}
337
/// Helper that recursively formats tensor elements as a nested, bracketed
/// multi-dimensional array (used by the `Debug` impls of the tensor types).
///
/// Each nesting level handles one axis; inner levels are indented with one tab
/// per level, and the innermost axis is printed inline as `[a, b, c],`.
struct TensorDataDebugHelper<'a, C> {
    /// Pointer to the element at flat offset 0 of the tensor being printed.
    data_ptr: *const C,
    /// Size of each axis.
    dimensions: &'a [usize],
    /// Stride, in elements, of each axis.
    strides: &'a [usize],
    /// Axis printed by this helper; equal to `dimensions.len()` at a single element.
    current_dim_idx: usize,
    /// Flat element offset of the sub-tensor printed by this helper.
    current_flat_offset: usize,
}

impl<C> TensorDataDebugHelper<'_, C> {
    /// Writes one tab per nesting level without allocating an indent string.
    fn write_indent(&self, f: &mut Formatter<'_>) -> fmt::Result {
        (0..self.current_dim_idx).try_for_each(|_| f.write_str("\t"))
    }
}

impl<C: Display> Debug for TensorDataDebugHelper<'_, C> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let rank = self.dimensions.len();

        // Base case: every axis has been fixed, so the offset names a single
        // element. Print its value directly.
        if self.current_dim_idx == rank {
            // SAFETY: `current_flat_offset` is built from in-range indices and
            // the tensor's strides. It is assumed to lie within the allocation
            // that `Tensor` / `TensorRef` guarantee for their data.
            let element = unsafe { &*self.data_ptr.add(self.current_flat_offset) };
            return write!(f, "{}", element);
        }

        // Recursive case: print this axis as a bracketed list of sub-tensors.
        let dim_size = self.dimensions[self.current_dim_idx];
        let dim_stride = self.strides[self.current_dim_idx];
        // The innermost axis prints its elements inline on one line; outer
        // axes put each (self-indenting) child on its own line.
        let innermost = self.current_dim_idx + 1 == rank;

        self.write_indent(f)?;
        f.write_str(if innermost { "[" } else { "[\n" })?;

        for i in 0..dim_size {
            if innermost && i > 0 {
                f.write_str(", ")?;
            }
            let child = TensorDataDebugHelper {
                data_ptr: self.data_ptr,
                dimensions: self.dimensions,
                strides: self.strides,
                current_dim_idx: self.current_dim_idx + 1,
                current_flat_offset: self.current_flat_offset + i * dim_stride,
            };
            Debug::fmt(&child, f)?;
        }

        // The innermost closing bracket stays on the element line; outer ones
        // are re-indented to match their opening bracket.
        if !innermost {
            self.write_indent(f)?;
        }
        f.write_str("],\n")
    }
}
399
400impl<C: Display + Debug + Memorable, const D: usize> Debug for Tensor<C, D> {
401 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
402 f.debug_struct("Tensor")
403 .field("dimensions", &self.dims)
404 .field("strides", &self.strides)
405 .field(
406 "data",
407 &TensorDataDebugHelper {
408 data_ptr: self.data.as_ptr(), // Pointer to the start of the data buffer
409 dimensions: &self.dims,
410 strides: &self.strides,
411 current_dim_idx: 0, // Start formatting from the first dimension (index 0)
412 current_flat_offset: 0, // Start from offset 0 in the flat data buffer
413 },
414 )
415 .finish()
416 }
417}
418
419impl<'a, C: Display + Debug + Memorable, const D: usize> Debug for TensorRef<'a, C, D> {
420 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
421 f.debug_struct("TensorRef")
422 .field("dimensions", &self.dims)
423 .field("strides", &self.strides)
424 .field(
425 "data",
426 &TensorDataDebugHelper {
427 data_ptr: self.data.as_ptr(), // `self.data` is already a `*const C` for `TensorRef`
428 dimensions: &self.dims,
429 strides: &self.strides,
430 current_dim_idx: 0,
431 current_flat_offset: 0,
432 },
433 )
434 .finish()
435 }
436}
437
438// TODO: add iterators for subtensors
439
440/// An immutable view into tensor data.
441///
442/// This struct provides read-only access to tensor data without owning the underlying memory.
443/// It holds a reference to the data along with dimension and stride information.
444#[derive(Clone, Copy)]
445pub struct TensorRef<'a, C: Memorable, const D: usize> {
446 /// Non-null pointer to the tensor data.
447 data: NonNull<C>,
448 /// The dimensions of the tensor (size of each axis).
449 dims: [usize; D],
450 /// The strides for each dimension.
451 strides: [usize; D],
452 /// Phantom data to enforce lifetime constraints.
453 __marker: std::marker::PhantomData<&'a C>,
454}
455
456impl<'a, C: Memorable, const D: usize> TensorRef<'a, C, D> {
457 /// Creates a `TensorRef` from pointers to tensor data, dimensions, and strides.
458 ///
459 /// # Safety
460 ///
461 /// The caller must ensure:
462 /// - The pointers are valid and non-null.
463 /// - For each unit, the entire memory region addressed by the tensor is
464 /// within a single allocation.
465 /// - The memory is accessible by the pointer.
466 /// - No mutable aliasing occurs. No mutable references to the tensor data
467 /// exist when the `MatVecRef` is alive.
468 pub unsafe fn from_raw_parts(data: *const C, dims: [usize; D], strides: [usize; D]) -> Self {
469 // SAFETY: The pointer is never used in an mutable context.
470 let ptr = unsafe { NonNull::new_unchecked(data as *mut C) };
471
472 Self {
473 data: ptr,
474 dims,
475 strides,
476 __marker: std::marker::PhantomData,
477 }
478 }
479
480 /// Returns a reference to the dimensions of the tensor.
481 pub fn dims(&self) -> &[usize; D] {
482 &self.dims
483 }
484
485 /// Returns a reference to the strides of the tensor.
486 pub fn strides(&self) -> &[usize; D] {
487 &self.strides
488 }
489
490 /// Returns the rank (number of dimensions) of the tensor.
491 pub fn rank(&self) -> usize {
492 D
493 }
494
495 /// Returns the total number of elements in the tensor.
496 pub fn num_elements(&self) -> usize {
497 self.dims.iter().product()
498 }
499
500 /// Returns a raw pointer to the tensor's data.
501 pub fn as_ptr(&self) -> *const C {
502 self.data.as_ptr()
503 }
504
505 /// Returns a reference to an element at the given indices.
506 ///
507 /// # Panics
508 ///
509 /// Panics if the indices are out of bounds.
510 pub fn get(&self, indices: &[usize; D]) -> &C {
511 check_bounds(indices, &self.dims);
512 // Safety: bounds are checked by `check_bounds`
513 unsafe { self.get_unchecked(indices) }
514 }
515
516 /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
517 ///
518 /// # Safety
519 ///
520 /// Calling this method with out-of-bounds `indices` is undefined behavior.
521 #[inline(always)]
522 pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
523 unsafe { &*self.ptr_at(indices) }
524 }
525
526 /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
527 ///
528 /// # Safety
529 ///
530 /// Calling this method with out-of-bounds `indices` is undefined behavior.
531 #[inline(always)]
532 pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
533 unsafe {
534 let flat_idx = calculate_flat_index(indices, &self.strides);
535 self.as_ptr().add(flat_idx)
536 }
537 }
538
539 /// Creates an owned `Tensor` by copying the data from this `TensorRef`.
540 pub fn to_owned(self) -> Tensor<C, D> {
541 let mut max_element = [0; D];
542 for (i, d) in self.dims.iter().enumerate() {
543 max_element[i] = d - 1;
544 }
545 let max_flat_index = calculate_flat_index(&max_element, &self.strides);
546 // Safety: Memory is nonnull and shared throughout max_flat_index.
547 // Slice is copied from and dropped immediately.
548 unsafe {
549 let slice = std::slice::from_raw_parts(self.data.as_ptr(), max_flat_index + 1);
550 Tensor::from_slice_with_strides(slice, self.dims, self.strides)
551 }
552 }
553}
554
555impl<'a, C: Memorable, const D: usize> std::ops::Index<[usize; D]> for TensorRef<'a, C, D> {
556 type Output = C;
557
558 fn index(&self, indices: [usize; D]) -> &Self::Output {
559 self.get(&indices)
560 }
561}
562
563/// A mutable view into tensor data.
564///
565/// This struct provides read-write access to tensor data without owning the underlying memory.
566/// It holds a mutable reference to the data along with dimension and stride information.
567pub struct TensorMut<'a, C: Memorable, const D: usize> {
568 /// Non-null pointer to the tensor data.
569 data: NonNull<C>,
570 /// The dimensions of the tensor (size of each axis).
571 dims: [usize; D],
572 /// The strides for each dimension.
573 strides: [usize; D],
574 /// Phantom data to enforce lifetime constraints.
575 __marker: std::marker::PhantomData<&'a mut C>,
576}
577
578impl<'a, C: Memorable, const D: usize> TensorMut<'a, C, D> {
579 /// Creates a new `SymSqTensorMut` from raw parts.
580 ///
581 /// # Safety
582 ///
583 /// The caller must ensure that `data` points to a valid memory block of `C` elements,
584 /// and that `dims` and `strides` accurately describe the layout of the tensor
585 /// within that memory block. The `data` pointer must be valid for the lifetime `'a`
586 /// and that it is safe to mutate the data.
587 pub unsafe fn from_raw_parts(data: *mut C, dims: [usize; D], strides: [usize; D]) -> Self {
588 unsafe {
589 Self {
590 data: NonNull::new_unchecked(data),
591 dims,
592 strides,
593 __marker: std::marker::PhantomData,
594 }
595 }
596 }
597
598 /// Returns a reference to the dimensions of the tensor.
599 pub fn dims(&self) -> &[usize; D] {
600 &self.dims
601 }
602
603 /// Returns a reference to the strides of the tensor.
604 pub fn strides(&self) -> &[usize; D] {
605 &self.strides
606 }
607
608 /// Returns the rank (number of dimensions) of the tensor.
609 pub fn rank(&self) -> usize {
610 D
611 }
612
613 /// Returns the total number of elements in the tensor.
614 pub fn num_elements(&self) -> usize {
615 self.dims.iter().product()
616 }
617
618 /// Returns a mutable raw pointer to the tensor's data.
619 pub fn as_ptr(&self) -> *const C {
620 self.data.as_ptr() as *const C
621 }
622
623 /// Returns a mutable raw pointer to the tensor's data.
624 pub fn as_ptr_mut(&mut self) -> *mut C {
625 self.data.as_ptr()
626 }
627
628 /// Returns a reference to an element at the given indices.
629 ///
630 /// # Panics
631 ///
632 /// Panics if the indices are out of bounds.
633 pub fn get(&self, indices: &[usize; D]) -> &C {
634 check_bounds(indices, &self.dims);
635 // Safety: bounds are checked by `check_bounds`
636 unsafe { self.get_unchecked(indices) }
637 }
638
639 /// Returns a mutable reference to an element at the given indices.
640 ///
641 /// # Panics
642 ///
643 /// Panics if the indices are out of bounds.
644 pub fn get_mut(&mut self, indices: &[usize; D]) -> &mut C {
645 check_bounds(indices, &self.dims);
646 // Safety: bounds are checked by `check_bounds`
647 unsafe { self.get_mut_unchecked(indices) }
648 }
649
650 /// Returns an immutable reference to an element at the given indices, without performing bounds checks.
651 ///
652 /// # Safety
653 ///
654 /// Calling this method with out-of-bounds `indices` is undefined behavior.
655 #[inline(always)]
656 pub unsafe fn get_unchecked(&self, indices: &[usize; D]) -> &C {
657 unsafe { &*self.ptr_at(indices) }
658 }
659
660 /// Returns a mutable reference to an element at the given indices, without performing bounds checks.
661 ///
662 /// # Safety
663 ///
664 /// Calling this method with out-of-bounds `indices` is undefined behavior.
665 #[inline(always)]
666 pub unsafe fn get_mut_unchecked(&mut self, indices: &[usize; D]) -> &mut C {
667 unsafe { &mut *self.ptr_at_mut(indices) }
668 }
669
670 /// Returns a raw pointer to an element at the given indices, without performing bounds checks.
671 ///
672 /// # Safety
673 ///
674 /// Calling this method with out-of-bounds `indices` is undefined behavior.
675 #[inline(always)]
676 pub unsafe fn ptr_at(&self, indices: &[usize; D]) -> *const C {
677 unsafe {
678 let flat_idx = calculate_flat_index(indices, &self.strides);
679 self.as_ptr().add(flat_idx)
680 }
681 }
682
683 /// Returns a mutable raw pointer to an element at the given indices, without performing bounds checks.
684 ///
685 /// # Safety
686 ///
687 /// Calling this method with out-of-bounds `indices` is undefined behavior.
688 #[inline(always)]
689 pub unsafe fn ptr_at_mut(&mut self, indices: &[usize; D]) -> *mut C {
690 unsafe {
691 let flat_idx = calculate_flat_index(indices, &self.strides);
692 self.as_ptr_mut().add(flat_idx)
693 }
694 }
695}
696
697impl<'a, C: Memorable, const D: usize> std::ops::Index<[usize; D]> for TensorMut<'a, C, D> {
698 type Output = C;
699
700 fn index(&self, indices: [usize; D]) -> &Self::Output {
701 self.get(&indices)
702 }
703}
704
705impl<'a, C: Memorable, const D: usize> std::ops::IndexMut<[usize; D]> for TensorMut<'a, C, D> {
706 fn index_mut(&mut self, indices: [usize; D]) -> &mut Self::Output {
707 self.get_mut(&indices)
708 }
709}
710
711// TODO add some documentation plus a todo tag on relevant rust issues
712impl<C: Memorable> Tensor<C, 4> {
713 /// Returns an immutable view of a 3D subtensor at the given index.
714 pub fn subtensor_ref(&self, m: usize) -> TensorRef<'_, C, 3> {
715 check_bounds(&[m, 0, 0, 0], &self.dims);
716 // Safety: bounds have been checked.
717 unsafe { self.subtensor_ref_unchecked(m) }
718 }
719
720 /// Returns a mutable view of a 3D subtensor at the given index.
721 pub fn subtensor_mut(&mut self, m: usize) -> TensorMut<'_, C, 3> {
722 check_bounds(&[m, 0, 0, 0], &self.dims);
723 // Safety: bounds have been checked.
724 unsafe { self.subtensor_mut_unchecked(m) }
725 }
726
727 #[inline(always)]
728 /// Returns an immutable view of a 3D subtensor at the given index without bounds checking.
729 ///
730 /// # Safety
731 ///
732 /// Caller must ensure that m is within bounds.
733 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> TensorRef<'_, C, 3> {
734 unsafe {
735 TensorRef::from_raw_parts(
736 self.ptr_at(&[m, 0, 0, 0]),
737 [self.dims[1], self.dims[2], self.dims[3]],
738 [self.strides[1], self.strides[2], self.strides[3]],
739 )
740 }
741 }
742
743 #[inline(always)]
744 /// Returns a mutable view of a 3D subtensor at the given index without bounds checking.
745 ///
746 /// # Safety
747 ///
748 /// Caller must ensure that m is within bounds.
749 pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> TensorMut<'_, C, 3> {
750 unsafe {
751 TensorMut::from_raw_parts(
752 self.ptr_at_mut(&[m, 0, 0, 0]),
753 [self.dims[1], self.dims[2], self.dims[3]],
754 [self.strides[1], self.strides[2], self.strides[3]],
755 )
756 }
757 }
758}
759
760impl<C: Memorable> Tensor<C, 3> {
761 /// Returns an immutable matrix view of a 2D subtensor at the given index.
762 pub fn subtensor_ref(&self, m: usize) -> MatRef<'_, C> {
763 check_bounds(&[m, 0, 0], &self.dims);
764 // Safety: bounds have been checked.
765 unsafe { self.subtensor_ref_unchecked(m) }
766 }
767
768 /// Returns a mutable matrix view of a 2D subtensor at the given index.
769 pub fn subtensor_mut(&mut self, m: usize) -> MatMut<'_, C> {
770 check_bounds(&[m, 0, 0], &self.dims);
771 // Safety: bounds have been checked.
772 unsafe { self.subtensor_mut_unchecked(m) }
773 }
774
775 #[inline(always)]
776 /// Returns an immutable matrix view of a 2D subtensor at the given index without bounds checking.
777 ///
778 /// # Safety
779 ///
780 /// Caller must ensure that m is within bounds.
781 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> MatRef<'_, C> {
782 unsafe {
783 MatRef::from_raw_parts(
784 self.ptr_at(&[m, 0, 0]),
785 self.dims[1],
786 self.dims[2],
787 self.strides[1] as isize,
788 self.strides[2] as isize,
789 )
790 }
791 }
792
793 #[inline(always)]
794 /// Returns a mutable matrix view of a 2D subtensor at the given index without bounds checking.
795 ///
796 /// # Safety
797 ///
798 /// Caller must ensure that m is within bounds.
799 pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> MatMut<'_, C> {
800 unsafe {
801 MatMut::from_raw_parts_mut(
802 self.ptr_at_mut(&[m, 0, 0]),
803 self.dims[1],
804 self.dims[2],
805 self.strides[1] as isize,
806 self.strides[2] as isize,
807 )
808 }
809 }
810}
811
812impl<C: Memorable> Tensor<C, 2> {
813 /// Returns an immutable row view of a 1D subtensor at the given index.
814 pub fn subtensor_ref(&self, m: usize) -> RowRef<'_, C> {
815 check_bounds(&[m, 0], &self.dims);
816 // Safety: bounds have been checked.
817 unsafe { self.subtensor_ref_unchecked(m) }
818 }
819
820 /// Returns a mutable row view of a 1D subtensor at the given index.
821 pub fn subtensor_mut(&mut self, m: usize) -> RowMut<'_, C> {
822 check_bounds(&[m, 0], &self.dims);
823 // Safety: bounds have been checked.
824 unsafe { self.subtensor_mut_unchecked(m) }
825 }
826
827 #[inline(always)]
828 /// Returns an immutable row view of a 1D subtensor at the given index without bounds checking.
829 ///
830 /// # Safety
831 ///
832 /// Caller must ensure that m is within bounds.
833 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> RowRef<'_, C> {
834 unsafe {
835 RowRef::from_raw_parts(self.ptr_at(&[m, 0]), self.dims[1], self.strides[1] as isize)
836 }
837 }
838
839 #[inline(always)]
840 /// Returns a mutable row view of a 1D subtensor at the given index without bounds checking.
841 ///
842 /// # Safety
843 ///
844 /// Caller must ensure that m is within bounds.
845 pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> RowMut<'_, C> {
846 unsafe {
847 RowMut::from_raw_parts_mut(
848 self.ptr_at_mut(&[m, 0]),
849 self.dims[1],
850 self.strides[1] as isize,
851 )
852 }
853 }
854}
855
856impl<C: Memorable> Tensor<C, 1> {
857 /// Returns an immutable reference to an element at the given index.
858 pub fn subtensor_ref(&self, m: usize) -> &C {
859 self.get(&[m])
860 }
861
862 /// Returns a mutable reference to an element at the given index.
863 pub fn subtensor_mut(&mut self, m: usize) -> &mut C {
864 self.get_mut(&[m])
865 }
866
867 #[inline(always)]
868 /// Returns an immutable reference to an element at the given index without bounds checking.
869 ///
870 /// # Safety
871 ///
872 /// Caller must ensure that m is within bounds.
873 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> &C {
874 unsafe { self.get_unchecked(&[m]) }
875 }
876
877 #[inline(always)]
878 /// Returns a mutable reference to an element at the given index without bounds checking.
879 ///
880 /// # Safety
881 ///
882 /// Caller must ensure that m is within bounds.
883 pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> &mut C {
884 unsafe { self.get_mut_unchecked(&[m]) }
885 }
886}
887
888impl<'a, C: Memorable> TensorRef<'a, C, 4> {
889 /// Returns an immutable view of a 3D subtensor at the given index.
890 pub fn subtensor_ref(&self, m: usize) -> TensorRef<'a, C, 3> {
891 check_bounds(&[m, 0, 0, 0], &self.dims);
892 // Safety: bounds have been checked.
893 unsafe { self.subtensor_ref_unchecked(m) }
894 }
895
896 #[inline(always)]
897 /// Returns an immutable view of a 3D subtensor at the given index without bounds checking.
898 ///
899 /// # Safety
900 ///
901 /// Caller must ensure that m is within bounds.
902 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> TensorRef<'a, C, 3> {
903 unsafe {
904 TensorRef::from_raw_parts(
905 self.ptr_at(&[m, 0, 0, 0]),
906 [self.dims[1], self.dims[2], self.dims[3]],
907 [self.strides[1], self.strides[2], self.strides[3]],
908 )
909 }
910 }
911}
912
913impl<'a, C: Memorable> TensorRef<'a, C, 3> {
914 /// Returns an immutable matrix view of a 2D subtensor at the given index.
915 pub fn subtensor_ref(&self, m: usize) -> MatRef<'a, C> {
916 check_bounds(&[m, 0, 0], &self.dims);
917 // Safety: bounds have been checked.
918 unsafe { self.subtensor_ref_unchecked(m) }
919 }
920
921 #[inline(always)]
922 /// Returns an immutable matrix view of a 2D subtensor at the given index without bounds checking.
923 ///
924 /// # Safety
925 ///
926 /// Caller must ensure that m is within bounds.
927 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> MatRef<'a, C> {
928 unsafe {
929 MatRef::from_raw_parts(
930 self.ptr_at(&[m, 0, 0]),
931 self.dims[1],
932 self.dims[2],
933 self.strides[1] as isize,
934 self.strides[2] as isize,
935 )
936 }
937 }
938}
939
940impl<'a, C: Memorable> TensorRef<'a, C, 2> {
941 /// Returns an immutable row view of a 1D subtensor at the given index.
942 pub fn subtensor_ref(&self, m: usize) -> RowRef<'a, C> {
943 check_bounds(&[m, 0], &self.dims);
944 // Safety: bounds have been checked.
945 unsafe { self.subtensor_ref_unchecked(m) }
946 }
947
948 #[inline(always)]
949 /// Returns an immutable row view of a 1D subtensor at the given index without bounds checking.
950 ///
951 /// # Safety
952 ///
953 /// Caller must ensure that m is within bounds.
954 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> RowRef<'a, C> {
955 unsafe {
956 RowRef::from_raw_parts(self.ptr_at(&[m, 0]), self.dims[1], self.strides[1] as isize)
957 }
958 }
959}
960
961impl<'a, C: Memorable> TensorRef<'a, C, 1> {
962 /// Returns an immutable reference to an element at the given index.
963 pub fn subtensor_ref(&self, m: usize) -> &C {
964 self.get(&[m])
965 }
966
967 #[inline(always)]
968 /// Returns an immutable reference to an element at the given index without bounds checking.
969 ///
970 /// # Safety
971 ///
972 /// Caller must ensure that m is within bounds.
973 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> &C {
974 unsafe { self.get_unchecked(&[m]) }
975 }
976}
977
978impl<'a, C: Memorable> TensorMut<'a, C, 4> {
979 /// Returns an immutable view of a 3D subtensor at the given index.
980 pub fn subtensor_ref(&self, m: usize) -> TensorRef<'a, C, 3> {
981 check_bounds(&[m, 0, 0, 0], &self.dims);
982 // Safety: bounds have been checked.
983 unsafe { self.subtensor_ref_unchecked(m) }
984 }
985
986 /// Returns a mutable view of a 3D subtensor at the given index.
987 pub fn subtensor_mut(&mut self, m: usize) -> TensorMut<'a, C, 3> {
988 check_bounds(&[m, 0, 0, 0], &self.dims);
989 // Safety: bounds have been checked.
990 unsafe { self.subtensor_mut_unchecked(m) }
991 }
992
993 #[inline(always)]
994 /// Returns an immutable view of a 3D subtensor at the given index without bounds checking.
995 ///
996 /// # Safety
997 ///
998 /// Caller must ensure that m is within bounds.
999 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> TensorRef<'a, C, 3> {
1000 unsafe {
1001 TensorRef::from_raw_parts(
1002 self.ptr_at(&[m, 0, 0, 0]),
1003 [self.dims[1], self.dims[2], self.dims[3]],
1004 [self.strides[1], self.strides[2], self.strides[3]],
1005 )
1006 }
1007 }
1008
1009 #[inline(always)]
1010 /// Returns a mutable view of a 3D subtensor at the given index without bounds checking.
1011 ///
1012 /// # Safety
1013 ///
1014 /// Caller must ensure that m is within bounds.
1015 pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> TensorMut<'a, C, 3> {
1016 unsafe {
1017 TensorMut::from_raw_parts(
1018 self.ptr_at_mut(&[m, 0, 0, 0]),
1019 [self.dims[1], self.dims[2], self.dims[3]],
1020 [self.strides[1], self.strides[2], self.strides[3]],
1021 )
1022 }
1023 }
1024}
1025
1026impl<'a, C: Memorable> TensorMut<'a, C, 3> {
1027 /// Returns an immutable matrix view of a 2D subtensor at the given index.
1028 pub fn subtensor_ref(&self, m: usize) -> MatRef<'a, C> {
1029 check_bounds(&[m, 0, 0], &self.dims);
1030 // Safety: bounds have been checked.
1031 unsafe { self.subtensor_ref_unchecked(m) }
1032 }
1033
1034 /// Returns a mutable matrix view of a 2D subtensor at the given index.
1035 pub fn subtensor_mut(&mut self, m: usize) -> MatMut<'a, C> {
1036 check_bounds(&[m, 0, 0], &self.dims);
1037 // Safety: bounds have been checked.
1038 unsafe { self.subtensor_mut_unchecked(m) }
1039 }
1040
1041 #[inline(always)]
1042 /// Returns an immutable matrix view of a 2D subtensor at the given index without bounds checking.
1043 ///
1044 /// # Safety
1045 ///
1046 /// Caller must ensure that m is within bounds.
1047 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> MatRef<'a, C> {
1048 unsafe {
1049 MatRef::from_raw_parts(
1050 self.ptr_at(&[m, 0, 0]),
1051 self.dims[1],
1052 self.dims[2],
1053 self.strides[1] as isize,
1054 self.strides[2] as isize,
1055 )
1056 }
1057 }
1058
1059 #[inline(always)]
1060 /// Returns a mutable matrix view of a 2D subtensor at the given index without bounds checking.
1061 ///
1062 /// # Safety
1063 ///
1064 /// Caller must ensure that m is within bounds.
1065 pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> MatMut<'a, C> {
1066 unsafe {
1067 MatMut::from_raw_parts_mut(
1068 self.ptr_at_mut(&[m, 0, 0]),
1069 self.dims[1],
1070 self.dims[2],
1071 self.strides[1] as isize,
1072 self.strides[2] as isize,
1073 )
1074 }
1075 }
1076}
1077
1078impl<'a, C: Memorable> TensorMut<'a, C, 2> {
1079 /// Returns an immutable row view of a 1D subtensor at the given index.
1080 pub fn subtensor_ref(&self, m: usize) -> RowRef<'a, C> {
1081 check_bounds(&[m, 0], &self.dims);
1082 // Safety: bounds have been checked.
1083 unsafe { self.subtensor_ref_unchecked(m) }
1084 }
1085
1086 /// Returns a mutable row view of a 1D subtensor at the given index.
1087 pub fn subtensor_mut(&mut self, m: usize) -> RowMut<'a, C> {
1088 check_bounds(&[m, 0], &self.dims);
1089 // Safety: bounds have been checked.
1090 unsafe { self.subtensor_mut_unchecked(m) }
1091 }
1092
1093 #[inline(always)]
1094 /// Returns an immutable row view of a 1D subtensor at the given index without bounds checking.
1095 ///
1096 /// # Safety
1097 ///
1098 /// Caller must ensure that m is within bounds.
1099 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> RowRef<'a, C> {
1100 unsafe {
1101 RowRef::from_raw_parts(self.ptr_at(&[m, 0]), self.dims[1], self.strides[1] as isize)
1102 }
1103 }
1104
1105 #[inline(always)]
1106 /// Returns a mutable row view of a 1D subtensor at the given index without bounds checking.
1107 ///
1108 /// # Safety
1109 ///
1110 /// Caller must ensure that m is within bounds.
1111 pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> RowMut<'a, C> {
1112 unsafe {
1113 RowMut::from_raw_parts_mut(
1114 self.ptr_at_mut(&[m, 0]),
1115 self.dims[1],
1116 self.strides[1] as isize,
1117 )
1118 }
1119 }
1120}
1121
1122impl<'a, C: Memorable> TensorMut<'a, C, 1> {
1123 /// Returns an immutable reference to an element at the given index.
1124 pub fn subtensor_ref(&self, m: usize) -> &C {
1125 self.get(&[m])
1126 }
1127
1128 /// Returns a mutable reference to an element at the given index.
1129 pub fn subtensor_mut(&mut self, m: usize) -> &mut C {
1130 self.get_mut(&[m])
1131 }
1132
1133 #[inline(always)]
1134 /// Returns an immutable reference to an element at the given index without bounds checking.
1135 ///
1136 /// # Safety
1137 ///
1138 /// Caller must ensure that m is within bounds.
1139 pub unsafe fn subtensor_ref_unchecked(&self, m: usize) -> &C {
1140 unsafe { self.get_unchecked(&[m]) }
1141 }
1142
1143 #[inline(always)]
1144 /// Returns a mutable reference to an element at the given index without bounds checking.
1145 ///
1146 /// # Safety
1147 ///
1148 /// Caller must ensure that m is within bounds.
1149 pub unsafe fn subtensor_mut_unchecked(&mut self, m: usize) -> &mut C {
1150 unsafe { self.get_mut_unchecked(&[m]) }
1151 }
1152}