nd_slice/
nd_slice_mut.rs

//! `NdSliceMut` wraps `&mut [T]` to represent a mutable n-dimensional array.
2
3use std::ops::{Index, IndexMut};
4use std::slice;
5use crate::addressing::{Order, address};
6use crate::errors::ShapeError;
7
8#[derive(Debug)]
9/// `NdSliceMut` wraps `&mut [T]` to represent a mutable n-dimensional array
10///
11/// ```
12/// # use nd_slice::{NdSliceMut, Order};
13/// let mut arr = [7, 2, 3, 4, 5, 8];
14/// let mut n = NdSliceMut::new(&mut arr, [2, 3], Order::RowMajor).unwrap();
15/// n[[0, 0]] = 1;
16/// n[[1, 2]] = 6;
17/// assert_eq!(n[[0, 0]], 1);
18/// assert_eq!(n[[1, 2]], 6);
19///
20/// let mut arr = [9, 2, 3, 4, 5, 6, 7, 10];
21/// let mut n = NdSliceMut::new(&mut arr, [2, 2, 2], Order::RowMajor).unwrap();
22/// n[[0, 0, 0]] = 1;
23/// n[[1, 1, 1]] = 8;
24/// assert_eq!(n[[0, 0, 0]], 1);
25/// assert_eq!(n[[1, 1, 1]], 8);
26/// ```
27///
28/// If the slice doesn't have enough elements to represent an array of that shape, it will
29/// return an `Err(ShapeError)`.
30///
31/// ```should_panic
32/// # use nd_slice::{NdSliceMut, Order};
33/// let n = NdSliceMut::new(&mut [1, 2, 3, 4, 5, 6], [2, 2], Order::RowMajor).unwrap(); // more elements
34/// let n = NdSliceMut::new(&mut [1, 2, 3, 4, 5, 6], [2, 4], Order::RowMajor).unwrap(); // less elements
35/// ```
36pub struct NdSliceMut<'s, T, const N: usize> {
37    slice: &'s mut [T],
38    shape: [usize; N],
39    order: Order,
40}
41
/// Result returned by every `NdSliceMut` constructor: the view on success, or a `ShapeError`
/// built from the rejected slice and shape on failure.
type ConstructionResult<'s, T, const N: usize> = Result<NdSliceMut<'s, T, N>, ShapeError<'s, T, N>>;
43
44impl<'s, T, const N: usize> NdSliceMut<'s, T, N> {
45    /// Creates a new `NdSliceMut` with the specified ordering from a given slice and the expected shape
46    pub fn new(slice: &'s mut [T], shape: [usize; N], order: Order) -> ConstructionResult<'s, T, N> {
47        if slice.len() == shape.iter().fold(1, |acc, &x| acc * x) {
48            Ok(Self { slice, shape, order })
49        } else {
50            Err(ShapeError::new(slice, shape))
51        }
52    }
53
54    /// Creates a new `NdSliceMut` with the specified ordering from a raw pointer, it's length and the expected shape
55    pub unsafe fn from_ptr(ptr: *mut T, len: usize, shape: [usize; N], order: Order) -> ConstructionResult<'s, T, N> {
56        NdSliceMut::new(slice::from_raw_parts_mut(ptr, len), shape, order)
57    }
58
59    /// Creates a new `NdSliceMut` with row-major ordering from a given slice and the expected shape
60    pub fn new_row_ordered(slice: &'s mut [T], shape: [usize; N]) -> ConstructionResult<'s, T, N> {
61        NdSliceMut::new(slice, shape, Order::RowMajor)
62    }
63
64    /// Creates a new `NdSliceMut` with row-major ordering from a raw pointer, it's length and the expected shape
65    pub unsafe fn row_ordered_from_ptr(ptr: *mut T, len: usize, shape: [usize; N]) -> ConstructionResult<'s, T, N> {
66        NdSliceMut::from_ptr(ptr, len, shape, Order::RowMajor)
67    }
68
69    /// Creates a new `NdSliceMut` with column-major ordering from a given slice and the expected shape
70    pub fn col_ordered(slice: &'s mut [T], shape: [usize; N]) -> ConstructionResult<'s, T, N> {
71        NdSliceMut::new(slice, shape, Order::ColumnMajor)
72    }
73
74    /// Creates a new `NdSliceMut` with column-major ordering from a raw pointer, it's length and the expected shape.
75    pub unsafe fn col_ordered_from_ptr(ptr: *mut T, len: usize, shape: [usize; N]) -> ConstructionResult<'s, T, N> {
76        NdSliceMut::from_ptr(ptr, len, shape, Order::ColumnMajor)
77    }
78}
79
80impl<T, const N: usize> Index<[usize; N]> for NdSliceMut<'_, T, N> {
81    type Output = T;
82
83    fn index(&self, index: [usize; N]) -> &Self::Output {
84        &self.slice[address(&self.order, &self.shape, &index)]
85    }
86}
87
88impl<T, const N: usize> IndexMut<[usize; N]> for NdSliceMut<'_, T, N> {
89    fn index_mut(&mut self, index: [usize; N]) -> &mut Self::Output {
90        &mut self.slice[address(&self.order, &self.shape, &index)]
91    }
92}