ndslice/
shape.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::fmt;
10
11use itertools::izip;
12use serde::Deserialize;
13use serde::Serialize;
14
15use crate::DimSliceIterator;
16use crate::Slice;
17use crate::SliceError;
18use crate::selection::Selection;
19
// `Shape::select` always retains every dimension, even one narrowed down to a
// single index; `Shape::index` is the operation that removes dimensions.
21
/// Errors returned when constructing, selecting from, or indexing a [`Shape`].
#[derive(Debug, thiserror::Error)]
pub enum ShapeError {
    /// The number of labels given to [`Shape::new`] differs from the number
    /// of dimensions of the slice.
    #[error("label slice dimension mismatch: {labels_dim} != {slice_dim}")]
    DimSliceMismatch { labels_dim: usize, slice_dim: usize },

    /// Labels that could not be used against the shape, for example labels
    /// that do not name any of its dimensions.
    #[error("invalid labels `{labels:?}`")]
    InvalidLabels { labels: Vec<String> },

    /// The range passed to [`Shape::select`] was rejected as empty.
    #[error("empty range {range}")]
    EmptyRange { range: Range },

    /// The range passed to [`Shape::select`] begins at or past the end of
    /// dimension `dim`, whose current size is `size`.
    #[error("out of range {range} for dimension {dim} of size {size}")]
    OutOfRange {
        range: Range,
        dim: String,
        size: usize,
    },

    /// A [`Selection`] expression reaches deeper than the `num_dim`
    /// dimensions available. NOTE(review): not raised in this file;
    /// presumably produced by selection evaluation elsewhere in the crate.
    #[error("selection `{expr}` exceeds dimensionality {num_dim}")]
    SelectionTooDeep { expr: Selection, num_dim: usize },

    /// A [`Selection`] expression is dynamic. NOTE(review): not raised in
    /// this file; presumably the caller required a static selection.
    #[error("dynamic selection `{expr}`")]
    SelectionDynamic { expr: Selection },

    /// An index passed to [`Shape::index`] is not smaller than the size of
    /// its dimension `dim`.
    #[error("{index} out of range for dimension {dim} of size {size}")]
    IndexOutOfRange {
        index: usize,
        dim: String,
        size: usize,
    },

    /// An error reported by the underlying [`Slice`].
    #[error(transparent)]
    SliceError(#[from] SliceError),
}
56
/// A shape is a [`Slice`] with labeled dimensions and a selection API.
///
/// [`Shape::new`] checks that there is exactly one label per slice
/// dimension. NOTE(review): the derived `Deserialize` fills both fields
/// directly and does not go through that check.
#[derive(Clone, Deserialize, Serialize, PartialEq, Hash, Debug)]
pub struct Shape {
    /// The labels for each dimension in slice, in dimension order.
    labels: Vec<String>,
    /// The slice itself, which describes the topology of the shape.
    slice: Slice,
}
65
66impl Shape {
67    /// Creates a new shape with the provided labels, which describe the
68    /// provided Slice.
69    ///
70    /// Shapes can also be constructed by way of the [`shape`] macro, which
71    /// creates a by-construction correct slice in row-major order given a set of
72    /// sized dimensions.
73    pub fn new(labels: Vec<String>, slice: Slice) -> Result<Self, ShapeError> {
74        if labels.len() != slice.num_dim() {
75            return Err(ShapeError::DimSliceMismatch {
76                labels_dim: labels.len(),
77                slice_dim: slice.num_dim(),
78            });
79        }
80        Ok(Self { labels, slice })
81    }
82
83    /// Restrict this shape along a named dimension using a [`Range`]. The
84    /// provided range must be nonempty.
85    //
86    /// A shape defines a "strided view" where a strided view is a
87    /// triple (`offset, `sizes`, `strides`). Each coordinate maps to
88    /// a flat memory index using the formula:
89    /// ``` text
90    ///     index = offset + ∑ i_k * strides[k]
91    /// ```
92    /// where `i_k` is the coordinate in dimension `k`.
93    ///
94    /// The `select(dim, range)` operation restricts the view to a
95    /// subrange along a single dimension. It refines the shape by
96    /// updating the `offset`, `sizes[dim]`, and `strides[dim]` to
97    /// describe a logically reindexed subregion:
98    ///
99    /// ```text
100    ///     offset       += begin x strides[dim]
101    ///     sizes[dim]    = (end - begin) / step
102    ///     strides[dim] *= step
103    /// ```
104    ///
105    /// This transformation preserves the strided layout and avoids
106    /// copying data. After `select`, the view behaves as if indexing
107    /// starts at zero in the selected dimension, with a new length
108    /// and stride. From the user's perspective, nothing changes —
109    /// indexing remains zero-based, and the resulting shape can be
110    /// used like any other. The transformation is internal: the
111    /// view's offset and stride absorb the selection logic.
112    ///
113    /// `select` is composable — it can be applied repeatedly, even on
114    /// the same dimension, to refine the view incrementally.
115    pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
116        let dim = self.dim(label)?;
117        let range: Range = range.into();
118        if range.is_empty() {
119            return Err(ShapeError::EmptyRange { range });
120        }
121
122        let mut offset = self.slice.offset();
123        let mut sizes = self.slice.sizes().to_vec();
124        let mut strides = self.slice.strides().to_vec();
125
126        let (begin, end, stride) = range.resolve(sizes[dim]);
127        if begin >= sizes[dim] {
128            return Err(ShapeError::OutOfRange {
129                range,
130                dim: label.to_string(),
131                size: sizes[dim],
132            });
133        }
134
135        offset += begin * strides[dim];
136        sizes[dim] = (end - begin) / stride;
137        strides[dim] *= stride;
138
139        Ok(Self {
140            labels: self.labels.clone(),
141            slice: Slice::new(offset, sizes, strides).expect("cannot create invalid slice"),
142        })
143    }
144
145    /// Produces an iterator over subshapes by fixing the first `dims`
146    /// dimensions.
147    ///
148    /// For a shape of rank `n`, this yields `∏ sizes[0..dims]`
149    /// subshapes, each with the first `dims` dimensions restricted to
150    /// size 1. The remaining dimensions are left unconstrained.
151    ///
152    /// This is useful for structured traversal of slices within a
153    /// multidimensional shape. See [`SelectIterator`] for details and
154    /// examples.
155    ///
156    /// # Errors
157    /// Returns an error if `dims == 0` or `dims >= self.rank()`.
158    pub fn select_iter(&self, dims: usize) -> Result<SelectIterator, ShapeError> {
159        let num_dims = self.slice().num_dim();
160        if dims == 0 || dims >= num_dims {
161            return Err(ShapeError::SliceError(SliceError::IndexOutOfRange {
162                index: dims,
163                total: num_dims,
164            }));
165        }
166
167        Ok(SelectIterator {
168            shape: self,
169            iter: self.slice().dim_iter(dims),
170        })
171    }
172
173    /// Sub-set this shape by select a particular row of the given indices
174    /// The resulting shape will no longer have dimensions for the given indices
175    /// Example shape.index(vec![("gpu", 3), ("host", 0)])
176    pub fn index(&self, indices: Vec<(String, usize)>) -> Result<Shape, ShapeError> {
177        let mut offset = self.slice.offset();
178        let mut names = Vec::new();
179        let mut sizes = Vec::new();
180        let mut strides = Vec::new();
181        let mut used_indices_count = 0;
182        let slice = self.slice();
183        for (dim, size, stride) in izip!(self.labels.iter(), slice.sizes(), slice.strides()) {
184            if let Some(index) = indices
185                .iter()
186                .find_map(|(name, index)| if *name == *dim { Some(index) } else { None })
187            {
188                if *index >= *size {
189                    return Err(ShapeError::IndexOutOfRange {
190                        index: *index,
191                        dim: dim.clone(),
192                        size: *size,
193                    });
194                }
195                offset += index * stride;
196                used_indices_count += 1;
197            } else {
198                names.push(dim.clone());
199                sizes.push(*size);
200                strides.push(*stride);
201            }
202        }
203        if used_indices_count != indices.len() {
204            let unused_indices = indices
205                .iter()
206                .filter(|(key, _)| !self.labels.contains(key))
207                .map(|(key, _)| key.clone())
208                .collect();
209            return Err(ShapeError::InvalidLabels {
210                labels: unused_indices,
211            });
212        }
213        let slice = Slice::new(offset, sizes, strides)?;
214        Shape::new(names, slice)
215    }
216
217    /// The per-dimension labels of this shape.
218    pub fn labels(&self) -> &[String] {
219        &self.labels
220    }
221
222    /// The slice describing the shape.
223    pub fn slice(&self) -> &Slice {
224        &self.slice
225    }
226
227    /// Return a set of labeled coordinates for the given rank.
228    pub fn coordinates(&self, rank: usize) -> Result<Vec<(String, usize)>, ShapeError> {
229        let coords = self.slice.coordinates(rank)?;
230        Ok(coords
231            .iter()
232            .zip(self.labels.iter())
233            .map(|(i, l)| (l.to_string(), *i))
234            .collect())
235    }
236
237    fn dim(&self, label: &str) -> Result<usize, ShapeError> {
238        self.labels
239            .iter()
240            .position(|l| l == label)
241            .ok_or_else(|| ShapeError::InvalidLabels {
242                labels: vec![label.to_string()],
243            })
244    }
245
246    /// Return the 0-dimensional single element shape
247    pub fn unity() -> Shape {
248        Shape::new(vec![], Slice::new(0, vec![], vec![]).expect("unity")).expect("unity")
249    }
250}
251
/// Iterator over subshapes obtained by fixing a prefix of dimensions.
///
/// This iterator is produced by [`Shape::select_iter`] (called with
/// argument `dims`), and yields one `Shape` per coordinate prefix in the
/// first `dims` dimensions.
///
/// For a shape of `n` dimensions, each yielded shape has:
/// - The first `dims` dimensions restricted to size 1 (i.e., fixed
///   via `select`)
/// - The remaining `n - dims` dimensions left unconstrained
///
/// This allows structured iteration over "slices" of the original
/// shape: for example with `n` = 3, `select_iter(1)` walks through 2D
/// planes, while `select_iter(2)` yields 1D subshapes.
///
/// # Example
/// ```ignore
/// let s = shape!(zone = 2, host = 2, gpu = 8);
/// let views: Vec<_> = s.select_iter(2).unwrap().collect();
/// assert_eq!(views.len(), 4);
/// assert_eq!(views[0].slice().sizes(), &[1, 1, 8]);
/// ```
/// The above example can be interpreted as: for each `(zone, host)`
/// pair, `select_iter(2)` yields a `Shape` describing the associated
/// row of GPUs — a view into the `[1, 1, 8]` subregion of the full
/// `[2, 2, 8]` shape.
pub struct SelectIterator<'a> {
    /// The shape being traversed; every yielded subshape starts from it.
    shape: &'a Shape,
    /// Produces the coordinate prefix (one entry per fixed dimension) of
    /// each subshape.
    iter: DimSliceIterator<'a>,
}
282
283impl<'a> Iterator for SelectIterator<'a> {
284    type Item = Shape;
285
286    fn next(&mut self) -> Option<Self::Item> {
287        let pos = self.iter.next()?;
288        let mut shape = self.shape.clone();
289        for (dim, index) in pos.iter().enumerate() {
290            shape = shape.select(&self.shape.labels()[dim], *index).unwrap();
291        }
292        Some(shape)
293    }
294}
295
296impl fmt::Display for Shape {
297    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298        // Just display the sizes of each dimension, for now.
299        // Once we have a selection algebra, we can provide a
300        // better Display implementation.
301        write!(f, "{{")?;
302        for dim in 0..self.labels.len() {
303            write!(f, "{}={}", self.labels[dim], self.slice.sizes()[dim])?;
304            if dim < self.labels.len() - 1 {
305                write!(f, ",")?;
306            }
307        }
308        write!(f, "}}")
309    }
310}
311
/// Construct a new shape with the given set of dimension-size pairs in row-major
/// order.
///
/// Each `label = size` pair becomes one dimension, in the order written;
/// the label is the identifier's text. The macro panics if the shape cannot
/// be built.
///
/// ```
/// let s = ndslice::shape!(host = 2, gpu = 8);
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.slice().sizes(), &[2, 8]);
/// assert_eq!(s.slice().strides(), &[8, 1]);
/// ```
#[macro_export]
macro_rules! shape {
    ( $( $label:ident = $size:expr_2021 ),* $(,)? ) => {{
        let labels = vec![$( stringify!($label).to_string() ),*];
        let sizes = vec![$( $size ),*];
        $crate::shape::Shape::new(labels, $crate::Slice::new_row_major(sizes)).unwrap()
    }};
}
337
/// Perform a sub-selection on the provided [`Shape`] object.
///
/// This macro chains `.select()` calls to apply multiple labeled
/// dimension restrictions in a fluent way, stopping at the first error.
/// The shape may be any expression (a variable, a method call, a
/// `shape!` invocation, ...); it is evaluated exactly once.
///
/// ```
/// let s = ndslice::shape!(host = 2, gpu = 8);
/// let s = ndslice::select!(s, host = 1, gpu = 4..).unwrap();
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.slice().sizes(), &[1, 4]);
/// ```
#[macro_export]
macro_rules! select {
    ($shape:expr_2021, $label:ident = $range:expr_2021) => {
        $shape.select(stringify!($label), $range)
    };

    ($shape:expr_2021, $label:ident = $range:expr_2021, $($labels:ident = $ranges:expr_2021),+) => {
        // Apply the first restriction, then recurse on the result so later
        // restrictions see the already-narrowed shape.
        $shape
            .select(stringify!($label), $range)
            .and_then(|shape| $crate::select!(shape, $($labels = $ranges),+))
    };
}
359
/// A range of indices, with a stride. Ranges are convertible from
/// native Rust ranges and from single indices.
///
/// The fields are `Range(begin, end, step)`: `begin` is the first index,
/// `end` is an exclusive upper bound (`None` means "through the end of the
/// dimension"), and `step` is the distance between consecutive selected
/// indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Range(pub usize, pub Option<usize>, pub usize);
364
365impl Range {
366    pub(crate) fn resolve(&self, size: usize) -> (usize, usize, usize) {
367        match self {
368            Range(begin, Some(end), stride) => (*begin, std::cmp::min(size, *end), *stride),
369            Range(begin, None, stride) => (*begin, size, *stride),
370        }
371    }
372
373    pub(crate) fn is_empty(&self) -> bool {
374        matches!(self, Range(begin, Some(end), _) if end <= begin)
375    }
376}
377
378impl fmt::Display for Range {
379    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380        match self {
381            Range(begin, None, stride) => write!(f, "{}::{}", begin, stride),
382            Range(begin, Some(end), stride) => write!(f, "{}:{}:{}", begin, end, stride),
383        }
384    }
385}
386
387impl From<std::ops::Range<usize>> for Range {
388    fn from(r: std::ops::Range<usize>) -> Self {
389        Self(r.start, Some(r.end), 1)
390    }
391}
392
393impl From<std::ops::RangeInclusive<usize>> for Range {
394    fn from(r: std::ops::RangeInclusive<usize>) -> Self {
395        Self(*r.start(), Some(*r.end() + 1), 1)
396    }
397}
398
399impl From<std::ops::RangeFrom<usize>> for Range {
400    fn from(r: std::ops::RangeFrom<usize>) -> Self {
401        Self(r.start, None, 1)
402    }
403}
404
405impl From<usize> for Range {
406    fn from(idx: usize) -> Self {
407        Self(idx, Some(idx + 1), 1)
408    }
409}
410
#[cfg(test)]
mod tests {
    // NOTE(review): `std::assert_matches` is unstable, so this module
    // presumably relies on a crate-level feature gate — confirm.
    use std::assert_matches::assert_matches;

    use super::*;

    /// A row-major shape built by `shape!` has the expected labels, layout
    /// and `Display` rendering.
    #[test]
    fn test_basic() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(&s.labels, &["host".to_string(), "gpu".to_string()]);
        assert_eq!(s.slice.offset(), 0);
        assert_eq!(s.slice.sizes(), &[2, 8]);
        assert_eq!(s.slice.strides(), &[8, 1]);

        assert_eq!(s.to_string(), "{host=2,gpu=8}");
    }

    /// `select!` narrows one or more dimensions; each case checks the flat
    /// indices enumerated by the resulting slice.
    #[test]
    fn test_select() {
        let s = shape!(host = 2, gpu = 8);

        // Unrestricted: all 16 indices in row-major order.
        assert_eq!(
            s.slice().iter().collect::<Vec<_>>(),
            &[
                0,
                1,
                2,
                3,
                4,
                5,
                6,
                7,
                8,
                8 + 1,
                8 + 2,
                8 + 3,
                8 + 4,
                8 + 5,
                8 + 6,
                8 + 7
            ]
        );

        // Fixing `host = 1` keeps only the second row of 8 gpus.
        assert_eq!(
            select!(s, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8, 8 + 1, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        // An open-ended range keeps gpus 2..8 in both rows.
        assert_eq!(
            select!(s, gpu = 2..)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[2, 3, 4, 5, 6, 7, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        // A bounded range keeps gpus 3 and 4 in both rows.
        assert_eq!(
            select!(s, gpu = 3..5)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[3, 4, 8 + 3, 8 + 4]
        );

        // Restrictions compose across dimensions.
        assert_eq!(
            select!(s, gpu = 3..5, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8 + 3, 8 + 4]
        );
    }

    /// `select_iter(2)` walks every (replica, host) pair in row-major order,
    /// matching the equivalent explicit `select!` calls.
    #[test]
    fn test_select_iter() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        let selections: Vec<_> = s.select_iter(2).unwrap().collect();
        assert_eq!(selections[0].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[1].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[2].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[3].slice().sizes(), &[1, 1, 8]);
        assert_eq!(
            selections,
            &[
                select!(s, replica = 0, host = 0).unwrap(),
                select!(s, replica = 0, host = 1).unwrap(),
                select!(s, replica = 1, host = 0).unwrap(),
                select!(s, replica = 1, host = 1).unwrap()
            ]
        );
    }

    /// Flat indices map back to labeled coordinates; an index outside the
    /// slice surfaces the underlying slice error.
    #[test]
    fn test_coordinates() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(
            s.coordinates(0).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(1).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 1)]
        );
        assert_eq!(
            s.coordinates(8).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(9).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 1)]
        );

        assert_matches!(
            s.coordinates(16).unwrap_err(),
            ShapeError::SliceError(SliceError::ValueNotInSlice { value: 16 })
        );
    }

    /// Empty and out-of-range selections are rejected with specific errors.
    #[test]
    fn test_select_bad() {
        let s = shape!(host = 2, gpu = 8);

        assert_matches!(
            select!(s, gpu = 1..1).unwrap_err(),
            ShapeError::EmptyRange {
                range: Range(1, Some(1), 1)
            },
        );

        assert_matches!(
            select!(s, gpu = 8).unwrap_err(),
            ShapeError::OutOfRange {
                range: Range(8, Some(9), 1),
                dim,
                size: 8,
            } if dim == "gpu",
        );
    }

    /// `index` removes the indexed dimensions and folds the fixed index into
    /// the slice offset; bad indices and unknown labels are errors.
    #[test]
    fn test_shape_index() {
        let n_hosts = 5;
        let n_gpus = 7;

        // Index first dim
        let s = shape!(host = n_hosts, gpu = n_gpus);
        assert_eq!(
            s.index(vec![("host".to_string(), 0)]).unwrap(),
            Shape::new(
                vec!["gpu".to_string()],
                Slice::new(0, vec![n_gpus], vec![1]).unwrap()
            )
            .unwrap()
        );

        // Index last dims
        let offset = 1;
        assert_eq!(
            s.index(vec![("gpu".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["host".to_string()],
                Slice::new(offset, vec![n_hosts], vec![n_gpus]).unwrap()
            )
            .unwrap()
        );

        // Index middle dim
        let n_zone = 2;
        let s = shape!(zone = n_zone, host = n_hosts, gpu = n_gpus);
        let offset = 3;
        assert_eq!(
            s.index(vec![("host".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["zone".to_string(), "gpu".to_string()],
                Slice::new(
                    offset * n_gpus,
                    vec![n_zone, n_gpus],
                    vec![n_hosts * n_gpus, 1]
                )
                .unwrap()
            )
            .unwrap()
        );

        // Out of range
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("gpu".to_string(), n_gpus)])
                .is_err()
        );
        // Invalid dim
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("non-exist-dim".to_string(), 0)])
                .is_err()
        );
    }
}