1use std::fmt;
10
11use itertools::izip;
12use serde::Deserialize;
13use serde::Serialize;
14
15use crate::DimSliceIterator;
16use crate::Slice;
17use crate::SliceError;
18use crate::selection::Selection;
19
20#[derive(Debug, thiserror::Error)]
23pub enum ShapeError {
24 #[error("label slice dimension mismatch: {labels_dim} != {slice_dim}")]
25 DimSliceMismatch { labels_dim: usize, slice_dim: usize },
26
27 #[error("invalid labels `{labels:?}`")]
28 InvalidLabels { labels: Vec<String> },
29
30 #[error("empty range {range}")]
31 EmptyRange { range: Range },
32
33 #[error("out of range {range} for dimension {dim} of size {size}")]
34 OutOfRange {
35 range: Range,
36 dim: String,
37 size: usize,
38 },
39
40 #[error("selection `{expr}` exceeds dimensionality {num_dim}")]
41 SelectionTooDeep { expr: Selection, num_dim: usize },
42
43 #[error("dynamic selection `{expr}`")]
44 SelectionDynamic { expr: Selection },
45
46 #[error("{index} out of range for dimension {dim} of size {size}")]
47 IndexOutOfRange {
48 index: usize,
49 dim: String,
50 size: usize,
51 },
52
53 #[error(transparent)]
54 SliceError(#[from] SliceError),
55}
56
57#[derive(Clone, Deserialize, Serialize, PartialEq, Hash, Debug)]
59pub struct Shape {
60 labels: Vec<String>,
62 slice: Slice,
64}
65
66impl Shape {
67 pub fn new(labels: Vec<String>, slice: Slice) -> Result<Self, ShapeError> {
74 if labels.len() != slice.num_dim() {
75 return Err(ShapeError::DimSliceMismatch {
76 labels_dim: labels.len(),
77 slice_dim: slice.num_dim(),
78 });
79 }
80 Ok(Self { labels, slice })
81 }
82
83 pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
116 let dim = self.dim(label)?;
117 let range: Range = range.into();
118 if range.is_empty() {
119 return Err(ShapeError::EmptyRange { range });
120 }
121
122 let mut offset = self.slice.offset();
123 let mut sizes = self.slice.sizes().to_vec();
124 let mut strides = self.slice.strides().to_vec();
125
126 let (begin, end, stride) = range.resolve(sizes[dim]);
127 if begin >= sizes[dim] {
128 return Err(ShapeError::OutOfRange {
129 range,
130 dim: label.to_string(),
131 size: sizes[dim],
132 });
133 }
134
135 offset += begin * strides[dim];
136 sizes[dim] = (end - begin) / stride;
137 strides[dim] *= stride;
138
139 Ok(Self {
140 labels: self.labels.clone(),
141 slice: Slice::new(offset, sizes, strides).expect("cannot create invalid slice"),
142 })
143 }
144
145 pub fn select_iter(&self, dims: usize) -> Result<SelectIterator, ShapeError> {
159 let num_dims = self.slice().num_dim();
160 if dims == 0 || dims >= num_dims {
161 return Err(ShapeError::SliceError(SliceError::IndexOutOfRange {
162 index: dims,
163 total: num_dims,
164 }));
165 }
166
167 Ok(SelectIterator {
168 shape: self,
169 iter: self.slice().dim_iter(dims),
170 })
171 }
172
173 pub fn index(&self, indices: Vec<(String, usize)>) -> Result<Shape, ShapeError> {
177 let mut offset = self.slice.offset();
178 let mut names = Vec::new();
179 let mut sizes = Vec::new();
180 let mut strides = Vec::new();
181 let mut used_indices_count = 0;
182 let slice = self.slice();
183 for (dim, size, stride) in izip!(self.labels.iter(), slice.sizes(), slice.strides()) {
184 if let Some(index) = indices
185 .iter()
186 .find_map(|(name, index)| if *name == *dim { Some(index) } else { None })
187 {
188 if *index >= *size {
189 return Err(ShapeError::IndexOutOfRange {
190 index: *index,
191 dim: dim.clone(),
192 size: *size,
193 });
194 }
195 offset += index * stride;
196 used_indices_count += 1;
197 } else {
198 names.push(dim.clone());
199 sizes.push(*size);
200 strides.push(*stride);
201 }
202 }
203 if used_indices_count != indices.len() {
204 let unused_indices = indices
205 .iter()
206 .filter(|(key, _)| !self.labels.contains(key))
207 .map(|(key, _)| key.clone())
208 .collect();
209 return Err(ShapeError::InvalidLabels {
210 labels: unused_indices,
211 });
212 }
213 let slice = Slice::new(offset, sizes, strides)?;
214 Shape::new(names, slice)
215 }
216
217 pub fn labels(&self) -> &[String] {
219 &self.labels
220 }
221
222 pub fn slice(&self) -> &Slice {
224 &self.slice
225 }
226
227 pub fn coordinates(&self, rank: usize) -> Result<Vec<(String, usize)>, ShapeError> {
229 let coords = self.slice.coordinates(rank)?;
230 Ok(coords
231 .iter()
232 .zip(self.labels.iter())
233 .map(|(i, l)| (l.to_string(), *i))
234 .collect())
235 }
236
237 fn dim(&self, label: &str) -> Result<usize, ShapeError> {
238 self.labels
239 .iter()
240 .position(|l| l == label)
241 .ok_or_else(|| ShapeError::InvalidLabels {
242 labels: vec![label.to_string()],
243 })
244 }
245
246 pub fn unity() -> Shape {
248 Shape::new(vec![], Slice::new(0, vec![], vec![]).expect("unity")).expect("unity")
249 }
250}
251
252pub struct SelectIterator<'a> {
279 shape: &'a Shape,
280 iter: DimSliceIterator<'a>,
281}
282
283impl<'a> Iterator for SelectIterator<'a> {
284 type Item = Shape;
285
286 fn next(&mut self) -> Option<Self::Item> {
287 let pos = self.iter.next()?;
288 let mut shape = self.shape.clone();
289 for (dim, index) in pos.iter().enumerate() {
290 shape = shape.select(&self.shape.labels()[dim], *index).unwrap();
291 }
292 Some(shape)
293 }
294}
295
296impl fmt::Display for Shape {
297 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298 write!(f, "{{")?;
302 for dim in 0..self.labels.len() {
303 write!(f, "{}={}", self.labels[dim], self.slice.sizes()[dim])?;
304 if dim < self.labels.len() - 1 {
305 write!(f, ",")?;
306 }
307 }
308 write!(f, "}}")
309 }
310}
311
/// Builds a row-major [`Shape`] from `label = size` pairs. Each label is the
/// identifier's text, for example `shape!(host = 2, gpu = 8)` has labels
/// `host` and `gpu`, sizes `[2, 8]` and strides `[8, 1]`. A trailing comma is
/// accepted.
///
/// Panics (via `unwrap`) if `Shape::new` rejects the labels and slice.
#[macro_export]
macro_rules! shape {
    ( $( $label:ident = $size:expr_2021 ),* $(,)? ) => {{
        // Labels carry no side effects; sizes are evaluated in argument order.
        let labels = vec![$( stringify!($label).to_string() ),*];
        let sizes = vec![$( $size ),*];
        $crate::shape::Shape::new(labels, $crate::Slice::new_row_major(sizes)).unwrap()
    }};
}
337
/// Applies `Shape::select` to one or more `label = range` pairs, left to
/// right, stopping at the first error.
///
/// `select!(s, gpu = 3..5, host = 1)` evaluates to
/// `s.select("gpu", 3..5).and_then(|s| s.select("host", 1))`. Later range
/// expressions are only evaluated if the earlier selections succeed.
#[macro_export]
macro_rules! select {
    ($shape:ident, $label:ident = $range:expr_2021 $(, $labels:ident = $ranges:expr_2021)*) => {
        $shape.select(stringify!($label), $range)
            $( .and_then(|shape| shape.select(stringify!($labels), $ranges)) )*
    };
}
359
360#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
363pub struct Range(pub usize, pub Option<usize>, pub usize);
364
365impl Range {
366 pub(crate) fn resolve(&self, size: usize) -> (usize, usize, usize) {
367 match self {
368 Range(begin, Some(end), stride) => (*begin, std::cmp::min(size, *end), *stride),
369 Range(begin, None, stride) => (*begin, size, *stride),
370 }
371 }
372
373 pub(crate) fn is_empty(&self) -> bool {
374 matches!(self, Range(begin, Some(end), _) if end <= begin)
375 }
376}
377
378impl fmt::Display for Range {
379 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380 match self {
381 Range(begin, None, stride) => write!(f, "{}::{}", begin, stride),
382 Range(begin, Some(end), stride) => write!(f, "{}:{}:{}", begin, end, stride),
383 }
384 }
385}
386
387impl From<std::ops::Range<usize>> for Range {
388 fn from(r: std::ops::Range<usize>) -> Self {
389 Self(r.start, Some(r.end), 1)
390 }
391}
392
393impl From<std::ops::RangeInclusive<usize>> for Range {
394 fn from(r: std::ops::RangeInclusive<usize>) -> Self {
395 Self(*r.start(), Some(*r.end() + 1), 1)
396 }
397}
398
399impl From<std::ops::RangeFrom<usize>> for Range {
400 fn from(r: std::ops::RangeFrom<usize>) -> Self {
401 Self(r.start, None, 1)
402 }
403}
404
405impl From<usize> for Range {
406 fn from(idx: usize) -> Self {
407 Self(idx, Some(idx + 1), 1)
408 }
409}
410
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;

    /// Collects the ranks addressed by `shape`'s slice, in iteration order.
    fn ranks(shape: &Shape) -> Vec<usize> {
        shape.slice().iter().collect()
    }

    /// Builds a `Shape` with the given labels over an explicit slice layout.
    fn labeled(labels: &[&str], offset: usize, sizes: Vec<usize>, strides: Vec<usize>) -> Shape {
        let labels = labels.iter().map(|label| label.to_string()).collect();
        Shape::new(labels, Slice::new(offset, sizes, strides).unwrap()).unwrap()
    }

    #[test]
    fn test_basic() {
        let s = shape!(host = 2, gpu = 8);
        let expected_labels = ["host", "gpu"].map(String::from);
        assert_eq!(s.labels(), &expected_labels);

        let slice = s.slice();
        assert_eq!(slice.offset(), 0);
        assert_eq!(slice.sizes(), &[2, 8]);
        assert_eq!(slice.strides(), &[8, 1]);

        assert_eq!(format!("{s}"), "{host=2,gpu=8}");
    }

    #[test]
    fn test_select() {
        // Row-major host x gpu layout: host h, gpu g lives at rank 8 * h + g.
        let s = shape!(host = 2, gpu = 8);

        assert_eq!(ranks(&s), (0..16).collect::<Vec<usize>>());

        assert_eq!(
            ranks(&select!(s, host = 1).unwrap()),
            (8..16).collect::<Vec<usize>>()
        );

        assert_eq!(
            ranks(&select!(s, gpu = 2..).unwrap()),
            (2..8).chain(8 + 2..8 + 8).collect::<Vec<usize>>()
        );

        assert_eq!(
            ranks(&select!(s, gpu = 3..5).unwrap()),
            (3..5).chain(8 + 3..8 + 5).collect::<Vec<usize>>()
        );

        assert_eq!(
            ranks(&select!(s, gpu = 3..5, host = 1).unwrap()),
            (8 + 3..8 + 5).collect::<Vec<usize>>()
        );
    }

    #[test]
    fn test_select_iter() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        let selections: Vec<Shape> = s.select_iter(2).unwrap().collect();

        for selection in &selections[..4] {
            assert_eq!(selection.slice().sizes(), &[1, 1, 8]);
        }

        let combos: [(usize, usize); 4] = [(0, 0), (0, 1), (1, 0), (1, 1)];
        let expected: Vec<Shape> = combos
            .into_iter()
            .map(|(r, h)| select!(s, replica = r, host = h).unwrap())
            .collect();
        assert_eq!(selections, expected);
    }

    #[test]
    fn test_coordinates() {
        let s = shape!(host = 2, gpu = 8);

        for (rank, host, gpu) in [(0, 0, 0), (1, 0, 1), (8, 1, 0), (9, 1, 1)] {
            assert_eq!(
                s.coordinates(rank).unwrap(),
                vec![("host".to_string(), host), ("gpu".to_string(), gpu)]
            );
        }

        assert_matches!(
            s.coordinates(16).unwrap_err(),
            ShapeError::SliceError(SliceError::ValueNotInSlice { value: 16 })
        );
    }

    #[test]
    fn test_select_bad() {
        let s = shape!(host = 2, gpu = 8);

        let empty = select!(s, gpu = 1..1).unwrap_err();
        assert_matches!(
            empty,
            ShapeError::EmptyRange {
                range: Range(1, Some(1), 1)
            }
        );

        let out_of_range = select!(s, gpu = 8).unwrap_err();
        assert_matches!(
            out_of_range,
            ShapeError::OutOfRange {
                range: Range(8, Some(9), 1),
                dim,
                size: 8,
            } if dim == "gpu"
        );
    }

    #[test]
    fn test_shape_index() {
        let (n_hosts, n_gpus) = (5, 7);

        let s = shape!(host = n_hosts, gpu = n_gpus);
        // Fixing host=0 leaves the gpu row at rank 0.
        assert_eq!(
            s.index(vec![("host".to_string(), 0)]).unwrap(),
            labeled(&["gpu"], 0, vec![n_gpus], vec![1])
        );

        // Fixing gpu=offset leaves a host column starting at that gpu.
        let offset = 1;
        assert_eq!(
            s.index(vec![("gpu".to_string(), offset)]).unwrap(),
            labeled(&["host"], offset, vec![n_hosts], vec![n_gpus])
        );

        // Fixing the middle dimension keeps the outer and inner ones.
        let n_zone = 2;
        let s = shape!(zone = n_zone, host = n_hosts, gpu = n_gpus);
        let offset = 3;
        assert_eq!(
            s.index(vec![("host".to_string(), offset)]).unwrap(),
            labeled(
                &["zone", "gpu"],
                offset * n_gpus,
                vec![n_zone, n_gpus],
                vec![n_hosts * n_gpus, 1]
            )
        );

        // An index equal to the size, or an unknown label, is rejected.
        let single = shape!(gpu = n_gpus);
        assert!(single.index(vec![("gpu".to_string(), n_gpus)]).is_err());
        assert!(
            single
                .index(vec![("non-exist-dim".to_string(), 0)])
                .is_err()
        );
    }
}