uiua/
shape.rs

1use std::{
2    borrow::{Borrow, Cow},
3    fmt,
4    hash::Hash,
5    ops::{Deref, DerefMut, Index, RangeBounds},
6};
7
8use serde::*;
9use smallvec::SmallVec;
10
11/// Uiua's array shape type
12#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default, Serialize, Deserialize)]
13#[serde(transparent)]
14pub struct Shape {
15    dims: SmallVec<[usize; INLINE_DIMS]>,
16}
17const INLINE_DIMS: usize = 2;
18
19impl Shape {
20    /// A shape with no dimensions
21    pub const SCALAR: Self = Shape {
22        dims: SmallVec::new_const(),
23    };
24    /// An empty list shape
25    pub const EMPTY_LIST: Self = Shape {
26        dims: unsafe { SmallVec::from_const_with_len_unchecked([0; INLINE_DIMS], 1) },
27    };
28    /// Create a new scalar shape with the given capacity
29    pub fn with_capacity(capacity: usize) -> Self {
30        Shape {
31            dims: SmallVec::with_capacity(capacity),
32        }
33    }
34    /// Remove dimensions in the given range
35    pub fn drain(&mut self, range: impl RangeBounds<usize>) {
36        self.dims.drain(range);
37    }
38    /// Add a leading dimension
39    pub fn prepend(&mut self, dim: usize) {
40        self.dims.insert(0, dim);
41    }
42    /// Add a trailing dimension
43    pub fn push(&mut self, dim: usize) {
44        self.dims.push(dim);
45    }
46    /// Remove the last dimension
47    pub fn pop(&mut self) -> Option<usize> {
48        self.dims.pop()
49    }
50    /// Insert a dimension at the given index
51    pub fn insert(&mut self, index: usize, dim: usize) {
52        self.dims.insert(index, dim);
53    }
54    /// Get a mutable reference to the first dimension, setting it if empty
55    pub fn row_count_mut(&mut self) -> &mut usize {
56        if self.is_empty() {
57            self.push(1);
58        }
59        &mut self.dims[0]
60    }
61    /// Remove the dimension at the given index
62    pub fn remove(&mut self, index: usize) -> usize {
63        self.dims.remove(index)
64    }
65    /// Get the row count
66    #[inline(always)]
67    pub fn row_count(&self) -> usize {
68        self.dims.first().copied().unwrap_or(1)
69    }
70    /// Get the row length
71    pub fn row_len(&self) -> usize {
72        self.dims.iter().skip(1).product()
73    }
74    /// Get the row shape
75    pub fn row(&self) -> Shape {
76        let mut shape = self.clone();
77        shape.make_row();
78        shape
79    }
80    /// Get the row shape slice
81    pub fn row_slice(&self) -> &[usize] {
82        &self.dims[self.len().min(1)..]
83    }
84    /// Construct a subshape
85    pub fn subshape<R>(&self, range: R) -> Shape
86    where
87        [usize]: Index<R>,
88        Self: for<'a> From<&'a <[usize] as Index<R>>::Output>,
89    {
90        Shape::from(&self.dims.as_slice()[range])
91    }
92    /// Get the number of elements
93    pub fn elements(&self) -> usize {
94        self.iter().product()
95    }
96    /// Make the shape its row shape
97    pub fn make_row(&mut self) {
98        if !self.is_empty() {
99            self.dims.remove(0);
100        }
101    }
102    /// Make the shape 1-dimensional
103    pub fn deshape(&mut self) {
104        if self.len() != 1 {
105            *self = self.elements().into();
106        }
107    }
108    /// Add a 1-length dimension to the front of the array's shape
109    pub fn fix(&mut self) {
110        self.fix_depth(0);
111    }
112    pub(crate) fn fix_depth(&mut self, depth: usize) -> usize {
113        let depth = depth.min(self.len());
114        self.insert(depth, 1);
115        depth
116    }
117    /// Remove a 1-length dimension from the front of the array's shape
118    pub fn unfix(&mut self) -> Result<(), Cow<'static, str>> {
119        match self.unfix_inner() {
120            Some(1) => Ok(()),
121            Some(d) => Err(Cow::Owned(format!("Cannot unfix array with length {d}"))),
122            None if self.contains(&0) => Err("Cannot unfix empty array".into()),
123            None if self.is_empty() => Err("Cannot unfix scalar".into()),
124            None => Err(Cow::Owned(format!(
125                "Cannot unfix array with shape {self:?}"
126            ))),
127        }
128    }
129    /// Collapse the top two dimensions of the array's shape
130    pub fn undo_fix(&mut self) {
131        self.unfix_inner();
132    }
133    /// Unfix the shape
134    ///
135    /// Returns the first dimension
136    fn unfix_inner(&mut self) -> Option<usize> {
137        match &mut **self {
138            [1, ..] => Some(self.remove(0)),
139            [a, b, ..] => {
140                let new_first_dim = *a * *b;
141                *b = new_first_dim;
142                Some(self.remove(0))
143            }
144            _ => None,
145        }
146    }
147    /// Extend the shape with the given dimensions
148    pub fn extend_from_slice(&mut self, dims: &[usize]) {
149        self.dims.extend_from_slice(dims);
150    }
151    /// Split the shape at the given index
152    pub fn split_off(&mut self, at: usize) -> Self {
153        let (_, b) = self.dims.split_at(at);
154        let second = Shape::from(b);
155        self.dims.truncate(at);
156        second
157    }
158    /// Get a mutable reference to the dimensions
159    pub fn dims_mut(&mut self) -> &mut [usize] {
160        &mut self.dims
161    }
162    /// Truncate the shape
163    #[track_caller]
164    pub fn truncate(&mut self, len: usize) {
165        self.dims.truncate(len);
166    }
167    pub(crate) fn flat_to_dims(&self, flat: usize, index: &mut Vec<usize>) {
168        index.clear();
169        let mut flat = flat;
170        for &dim in self.dims.iter().rev() {
171            index.push(flat % dim);
172            flat /= dim;
173        }
174        index.reverse();
175    }
176    pub(crate) fn dims_to_flat(
177        &self,
178        index: impl IntoIterator<Item = impl Borrow<usize>>,
179    ) -> Option<usize> {
180        let mut flat = 0;
181        for (&dim, i) in self.dims.iter().zip(index) {
182            let i = *i.borrow();
183            if i >= dim {
184                return None;
185            }
186            flat = flat * dim + i;
187        }
188        Some(flat)
189    }
190    pub(crate) fn i_dims_to_flat(
191        &self,
192        index: impl IntoIterator<Item = impl Borrow<isize>>,
193    ) -> Option<usize> {
194        let mut flat = 0;
195        for (&dim, i) in self.dims.iter().zip(index) {
196            let i = *i.borrow();
197            if i < 0 || i >= dim as isize {
198                return None;
199            }
200            flat = flat * dim + i as usize;
201        }
202        Some(flat)
203    }
204}
205
206impl fmt::Debug for Shape {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        write!(f, "[")?;
209        for (i, dim) in self.dims.iter().enumerate() {
210            if i > 0 {
211                write!(f, " × ")?;
212            }
213            write!(f, "{dim}")?;
214        }
215        write!(f, "]")
216    }
217}
218
219impl fmt::Display for Shape {
220    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221        write!(f, "{self:?}")
222    }
223}
224
225impl From<usize> for Shape {
226    fn from(dim: usize) -> Self {
227        Self::from([dim])
228    }
229}
230
231impl From<&[usize]> for Shape {
232    fn from(dims: &[usize]) -> Self {
233        Self {
234            dims: dims.iter().copied().collect(),
235        }
236    }
237}
238
239impl From<Vec<usize>> for Shape {
240    fn from(dims: Vec<usize>) -> Self {
241        Self {
242            dims: SmallVec::from_vec(dims),
243        }
244    }
245}
246
247impl<const N: usize> From<[usize; N]> for Shape {
248    fn from(dims: [usize; N]) -> Self {
249        dims.as_slice().into()
250    }
251}
252
253impl Deref for Shape {
254    type Target = [usize];
255    fn deref(&self) -> &Self::Target {
256        &self.dims
257    }
258}
259
260impl DerefMut for Shape {
261    fn deref_mut(&mut self) -> &mut Self::Target {
262        &mut self.dims
263    }
264}
265
266impl IntoIterator for Shape {
267    type Item = usize;
268    type IntoIter = <SmallVec<[usize; INLINE_DIMS]> as IntoIterator>::IntoIter;
269    fn into_iter(self) -> Self::IntoIter {
270        self.dims.into_iter()
271    }
272}
273
274impl<'a> IntoIterator for &'a Shape {
275    type Item = &'a usize;
276    type IntoIter = <&'a [usize] as IntoIterator>::IntoIter;
277    fn into_iter(self) -> Self::IntoIter {
278        self.dims.iter()
279    }
280}
281
282impl FromIterator<usize> for Shape {
283    fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> Self {
284        Self {
285            dims: iter.into_iter().collect(),
286        }
287    }
288}
289
290impl Extend<usize> for Shape {
291    fn extend<I: IntoIterator<Item = usize>>(&mut self, iter: I) {
292        self.dims.extend(iter);
293    }
294}
295
296impl PartialEq<usize> for Shape {
297    fn eq(&self, other: &usize) -> bool {
298        self == [*other]
299    }
300}
301
302impl PartialEq<usize> for &Shape {
303    fn eq(&self, other: &usize) -> bool {
304        *self == [*other]
305    }
306}
307
308impl<const N: usize> PartialEq<[usize; N]> for Shape {
309    fn eq(&self, other: &[usize; N]) -> bool {
310        self == other.as_slice()
311    }
312}
313
314impl<const N: usize> PartialEq<[usize; N]> for &Shape {
315    fn eq(&self, other: &[usize; N]) -> bool {
316        *self == other.as_slice()
317    }
318}
319
320impl PartialEq<[usize]> for Shape {
321    fn eq(&self, other: &[usize]) -> bool {
322        self.dims.as_slice() == other
323    }
324}
325
326impl PartialEq<[usize]> for &Shape {
327    fn eq(&self, other: &[usize]) -> bool {
328        *self == other
329    }
330}
331
332impl PartialEq<&[usize]> for Shape {
333    fn eq(&self, other: &&[usize]) -> bool {
334        self.dims.as_slice() == *other
335    }
336}
337
338impl PartialEq<Shape> for &[usize] {
339    fn eq(&self, other: &Shape) -> bool {
340        other == self
341    }
342}
343
344impl PartialEq<Shape> for [usize] {
345    fn eq(&self, other: &Shape) -> bool {
346        other == self
347    }
348}