1use std::convert::TryInto;
2use std::fmt::{Debug, Display, Formatter};
3use std::ops::ControlFlow;
4
5use itertools::Itertools;
6
// NOTE(review): each item may be preceded by an optional `*`, which the matcher accepts but the
// expansion never uses — confirm its intended purpose before relying on it.
/// Builds a `Shape` from a comma-separated list of dimensions.
///
/// Each item is converted with `Size::from`, so it may be a `Size` or a `usize`
/// (for example `shape![Size::BATCH, 3]`). A trailing comma is accepted.
/// Expands to `$crate::shape::Shape::new(vec![...])`.
#[macro_export]
macro_rules! shape {
    [$($(*)? $value:expr),* $(,)?] => {
        $crate::shape::Shape::new(vec![$($crate::shape::Size::from($value)),*])
    };
}
13
/// A shape whose dimensions are symbolic [`Size`]s, each of which may depend on the batch size.
///
/// `Debug` and `Display` are implemented by hand and print the dimensions as `(a x b x ...)`.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct Shape {
    /// The dimensions, in order; the rank is `dims.len()`.
    pub dims: Vec<Size>,
}
30
/// A single dimension of the form `fixed_factor * B^batch_exp`, where `B` is the batch size
/// (see [`Size::eval`]).
///
/// [`Size::new`] normalizes any zero `fixed_factor` to `batch_exp == 0`, so zero has a single
/// representation ([`Size::ZERO`]).
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct Size {
    /// Power of the batch size; `0` means the size does not depend on the batch size.
    batch_exp: u32,
    /// Constant multiplier applied to `B^batch_exp`.
    fixed_factor: usize,
}
39
/// A shape whose dimensions are plain numbers, e.g. the result of [`Shape::eval`] for a
/// specific batch size.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ConcreteShape {
    /// The dimensions, in order; the rank is `dims.len()`.
    pub dims: Vec<usize>,
}
45
46impl Shape {
49 pub const SCALAR: Shape = Shape { dims: vec![] };
50
51 pub fn new(dims: Vec<Size>) -> Shape {
52 Shape { dims }
53 }
54
55 pub fn single(size: Size) -> Shape {
56 Shape { dims: vec![size] }
57 }
58
59 pub fn fixed(dims: &[usize]) -> Shape {
60 let dims = dims.iter().map(|&d| Size::fixed(d)).collect_vec();
61 Shape { dims }
62 }
63
64 pub fn ones(rank: usize) -> Shape {
65 Shape::new(vec![Size::ONE; rank])
66 }
67
68 pub fn zeros(rank: usize) -> Shape {
69 Shape::new(vec![Size::ZERO; rank])
70 }
71
72 pub fn rank(&self) -> usize {
73 self.dims.len()
74 }
75
76 pub fn assert_has_axis(&self, axis: usize) {
77 assert!(axis < self.rank(), "Axis {} out of bounds for {:?}", axis, self);
78 }
79
80 pub fn as_fixed(&self) -> Option<ConcreteShape> {
81 self.dims
82 .iter()
83 .map(|d| d.try_unwrap_fixed().ok_or(()))
84 .try_collect()
85 .ok()
86 .map(ConcreteShape::new)
87 }
88
89 pub fn unwrap_fixed(&self, what: &str) -> ConcreteShape {
90 let dims = self.dims.iter().map(|d| d.unwrap_fixed(what)).collect_vec();
91 ConcreteShape { dims }
92 }
93
94 pub fn eval(&self, batch_size: usize) -> ConcreteShape {
95 let dims = self.dims.iter().map(|d| d.eval(batch_size)).collect_vec();
96 ConcreteShape { dims }
97 }
98
99 pub fn size(&self) -> Size {
100 self.dims.iter().copied().product()
101 }
102
103 pub fn unwrap_1(&self) -> Size {
104 assert_eq!(1, self.dims.len(), "Expected rank 1 shape");
105 self.dims[0]
106 }
107
108 pub fn unwrap_2(&self) -> [Size; 2] {
109 self.dims
110 .as_slice()
111 .try_into()
112 .unwrap_or_else(|_| panic!("Expected rank 2 shape, got {:?}", self))
113 }
114
115 pub fn unwrap_3(&self) -> [Size; 3] {
116 self.dims
117 .as_slice()
118 .try_into()
119 .unwrap_or_else(|_| panic!("Expected rank 3 shape, got {:?}", self))
120 }
121
122 pub fn unwrap_4(&self) -> [Size; 4] {
123 self.dims
124 .as_slice()
125 .try_into()
126 .unwrap_or_else(|_| panic!("Expected rank 4 shape, got {:?}", self))
127 }
128
129 pub fn concat(mut self, other: &Shape) -> Shape {
130 self.dims.extend_from_slice(&other.dims);
131 self
132 }
133
134 pub fn batched(&self) -> Shape {
135 shape![Size::BATCH].concat(self)
136 }
137
138 pub fn replace(&self, axis: usize, replacement: Shape) -> Shape {
140 self.replace_all(&[axis], replacement)
141 }
142
143 pub fn replace_all(&self, axes: &[usize], replacement: Shape) -> Shape {
144 assert!(axes.iter().all_unique(), "Axes must be unique, got {:?}", axes);
146
147 for &axis in axes {
148 self.assert_has_axis(axis);
149 }
150
151 let mut dims = vec![];
153 for i in 0..self.rank() {
154 if axes.contains(&i) {
155 dims.extend_from_slice(&replacement.dims);
156 } else {
157 dims.push(self[i])
158 }
159 }
160
161 Shape::new(dims)
162 }
163
164 pub fn keep(&self, axis: usize, rest: Size) -> Shape {
166 self.assert_has_axis(axis);
167
168 let mut dims = self.dims.clone();
169 for i in 0..self.rank() {
170 if i != axis {
171 dims[i] = rest;
172 }
173 }
174 Shape::new(dims)
175 }
176
177 pub fn repeat_unary(&self, axis: usize, new_size: Size) -> Shape {
178 self.assert_has_axis(axis);
179
180 assert_eq!(
181 self.dims[axis],
182 Size::ONE,
183 "Repeated axis {} must have length 1 for {:?}",
184 axis,
185 self
186 );
187
188 let mut dims = self.dims.clone();
189 dims[axis] = new_size;
190 Shape::new(dims)
191 }
192
193 pub fn insert(&self, axis: usize, size: Size) -> Shape {
194 assert!(
195 axis <= self.rank(),
196 "Axis {} out of bounds for inserting into {:?}",
197 axis,
198 self
199 );
200
201 let mut dims = self.dims.clone();
202 dims.insert(axis, size);
203 Shape::new(dims)
204 }
205
206 pub fn split(&self, index: usize) -> (Shape, Shape) {
207 assert!(
208 index <= self.rank(),
209 "Split index {} out of bounds for {:?}",
210 index,
211 self
212 );
213
214 let body = self.dims[..index].to_vec();
215 let tail = self.dims[index..].to_vec();
216
217 (Shape::new(body), Shape::new(tail))
218 }
219}
220
221impl From<usize> for Size {
222 fn from(fixed_factor: usize) -> Self {
223 Size::fixed(fixed_factor)
224 }
225}
226
/// Outcome of [`Size::div_rem`].
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum DivResult {
    /// The division is exact; holds the quotient.
    Exact(Size),
    /// The batch exponents allow division, but the fixed factors leave this remainder.
    Remainder(usize),
    /// The quotient cannot be expressed as a `Size`, e.g. because the divisor has a larger
    /// batch exponent than the dividend.
    Impossible,
}
233
234impl Size {
235 pub const ZERO: Size = Size::new(0, 0);
236 pub const ONE: Size = Size::new(0, 1);
237 pub const BATCH: Size = Size::new(1, 1);
238
239 pub const fn new(batch_exp: u32, fixed_factor: usize) -> Size {
240 if fixed_factor == 0 {
241 Size {
242 batch_exp: 0,
243 fixed_factor: 0,
244 }
245 } else {
246 Size {
247 batch_exp,
248 fixed_factor,
249 }
250 }
251 }
252
253 pub const fn fixed(size: usize) -> Size {
254 Size {
255 batch_exp: 0,
256 fixed_factor: size,
257 }
258 }
259
260 pub const fn is_zero(&self) -> bool {
261 matches!(
262 self,
263 Size {
264 batch_exp: 0,
265 fixed_factor: 0
266 }
267 )
268 }
269
270 pub const fn components_factor_exp(self) -> (usize, u32) {
271 (self.fixed_factor, self.batch_exp)
272 }
273
274 pub fn eval(self, batch_size: usize) -> usize {
275 batch_size.pow(self.batch_exp) * self.fixed_factor
276 }
277
278 pub fn try_unwrap_fixed(self) -> Option<usize> {
279 if self.batch_exp == 0 {
280 Some(self.fixed_factor)
281 } else {
282 None
283 }
284 }
285
286 #[track_caller]
287 pub fn unwrap_fixed(self, what: &str) -> usize {
288 assert_eq!(0, self.batch_exp, "{} must be fixed, but got size {:?}", what, self);
289 self.fixed_factor
290 }
291
292 pub fn floor_div(self, rhs: Self) -> Option<Self> {
293 if self.batch_exp < rhs.batch_exp {
294 None
295 } else {
296 Some(Size::new(
297 self.batch_exp - rhs.batch_exp,
298 self.fixed_factor / rhs.fixed_factor,
299 ))
300 }
301 }
302
303 pub fn div_rem(self, rhs: impl Into<Size>) -> DivResult {
304 let rhs = rhs.into();
305 let fixed_rem = self.fixed_factor % rhs.fixed_factor;
306 if self.batch_exp < rhs.batch_exp {
307 DivResult::Impossible
308 } else if fixed_rem != 0 {
309 DivResult::Remainder(fixed_rem)
310 } else {
311 DivResult::Exact(Size::new(
312 self.batch_exp - rhs.batch_exp,
313 self.fixed_factor / rhs.fixed_factor,
314 ))
315 }
316 }
317}
318
319impl ConcreteShape {
320 pub fn new(dims: Vec<usize>) -> Self {
321 ConcreteShape { dims }
322 }
323
324 pub fn rank(&self) -> usize {
325 self.dims.len()
326 }
327
328 pub fn size(&self) -> usize {
329 self.dims.iter().product()
330 }
331
332 pub fn unwrap_2(&self) -> [usize; 2] {
333 self.dims.as_slice().try_into().expect("Expected rank 2 shape")
334 }
335
336 pub fn unwrap_3(&self) -> [usize; 3] {
337 self.dims.as_slice().try_into().expect("Expected rank 2 shape")
338 }
339
340 pub fn unwrap_4(&self) -> [usize; 4] {
341 self.dims.as_slice().try_into().expect("Expected rank 4 shape")
342 }
343}
344
/// Reason why [`infer_batch_size`] / [`infer_batch_size_dims`] rejected a set of shapes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShapeMismatch {
    /// The expected and actual dimension lists have different total lengths.
    DifferentLength,
    /// A fixed (batch-independent) expected dimension differs from the actual value.
    ConstantMismatch,
    /// A batch-dependent dimension disagrees with the batch size inferred from an earlier one.
    BatchConflict,
    /// No batch size was found that reproduces the actual value of a batch-dependent dimension.
    ImpossibleBatchValue,
}
352
353pub fn infer_batch_size(expected: &[Shape], actual: &[ConcreteShape]) -> Result<Option<usize>, ShapeMismatch> {
361 infer_batch_size_dims(
362 expected.iter().flat_map(|s| s.dims.iter().copied()),
363 actual.iter().flat_map(|s| s.dims.iter().copied()),
364 )
365}
366
367pub fn infer_batch_size_dims(
369 expected: impl IntoIterator<Item=Size>,
370 actuals: impl IntoIterator<Item=usize>,
371) -> Result<Option<usize>, ShapeMismatch> {
372 let mut shapes = expected.into_iter();
373 let mut actuals = actuals.into_iter();
374
375 let mut batch_size = None;
376
377 loop {
378 let (expected, actual) = match (shapes.next(), actuals.next()) {
379 (Some(shape), Some(actual)) => (shape, actual),
380 (None, None) => return Ok(batch_size),
381 _ => return Err(ShapeMismatch::DifferentLength),
382 };
383
384 let (factor, exp) = expected.components_factor_exp();
385
386 if exp == 0 {
387 if actual != factor {
389 return Err(ShapeMismatch::ConstantMismatch);
390 }
391 } else {
392 if let Some(batch_size) = batch_size {
394 if actual != expected.eval(batch_size) {
396 return Err(ShapeMismatch::BatchConflict);
397 }
398 } else {
399 let batch_size_approx = (actual as f64 / factor as f64).powf(1.0 / exp as f64) as usize;
401
402 let deltas = [0, 1, -1, 2, -2];
403 let batch_size_exact = deltas
404 .iter()
405 .find_map(|&delta| {
406 let cand = batch_size_approx.checked_add_signed(delta).unwrap();
407 if factor * cand.pow(exp) == actual {
408 Some(cand)
409 } else {
410 None
411 }
412 })
413 .ok_or(ShapeMismatch::ImpossibleBatchValue)?;
414
415 batch_size = Some(batch_size_exact);
416 }
417 }
418 }
419}
420
421impl<R: Into<Size>> std::ops::Add<R> for Size {
422 type Output = Option<Size>;
423
424 fn add(self, rhs: R) -> Self::Output {
425 let rhs = rhs.into();
426 if self == Size::ZERO {
427 return Some(rhs);
428 }
429 if rhs == Size::ZERO {
430 return Some(self);
431 }
432 if self.batch_exp != rhs.batch_exp {
433 return None;
434 }
435
436 Some(Size::new(self.batch_exp, self.fixed_factor + rhs.fixed_factor))
437 }
438}
439
440impl<R: Into<Size>> std::ops::Sub<R> for Size {
441 type Output = Option<Size>;
442
443 fn sub(self, rhs: R) -> Self::Output {
444 let rhs = rhs.into();
445 if rhs == Size::ZERO {
446 return Some(self);
447 }
448
449 if self.batch_exp != rhs.batch_exp || self.fixed_factor < rhs.fixed_factor {
450 return None;
451 }
452
453 Some(Size::new(self.batch_exp, self.fixed_factor - rhs.fixed_factor))
454 }
455}
456
457impl<R: Into<Size>> std::ops::Mul<R> for Size {
458 type Output = Size;
459
460 fn mul(self, rhs: R) -> Self::Output {
461 let rhs = rhs.into();
462 Size::new(self.batch_exp + rhs.batch_exp, self.fixed_factor * rhs.fixed_factor)
463 }
464}
465
466impl<R: Into<Size>> std::ops::Div<R> for Size {
467 type Output = Option<Size>;
468
469 fn div(self, rhs: R) -> Self::Output {
470 match self.div_rem(rhs) {
471 DivResult::Exact(s) => Some(s),
472 DivResult::Remainder(_) | DivResult::Impossible => None,
473 }
474 }
475}
476
477impl<R: Into<Size>> std::ops::Rem<R> for Size {
478 type Output = Option<usize>;
479
480 fn rem(self, rhs: R) -> Self::Output {
481 match self.div_rem(rhs) {
482 DivResult::Exact(_) => Some(0),
483 DivResult::Remainder(r) => Some(r),
484 DivResult::Impossible => None,
485 }
486 }
487}
488
489impl std::iter::Sum<Size> for Option<Size> {
490 fn sum<I: Iterator<Item = Size>>(mut iter: I) -> Self {
491 let result = iter.try_fold(Size::ZERO, |a, s| match a + s {
492 Some(v) => ControlFlow::Continue(v),
493 None => ControlFlow::Break(()),
494 });
495
496 match result {
497 ControlFlow::Continue(v) => Some(v),
498 ControlFlow::Break(()) => None,
499 }
500 }
501}
502
503impl std::iter::Product for Size {
504 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
505 iter.fold(Size::fixed(1), |a, s| a * s)
506 }
507}
508
509impl std::ops::Index<usize> for Shape {
510 type Output = Size;
511
512 fn index(&self, axis: usize) -> &Self::Output {
513 self.assert_has_axis(axis);
514 &self.dims[axis]
515 }
516}
517
518impl Debug for Shape {
519 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
520 write!(f, "Shape{}", self)
521 }
522}
523
524impl Debug for Size {
525 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
526 write!(f, "Size({})", self)
527 }
528}
529
530impl Display for Shape {
531 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
532 fmt_shape_impl(f, &self.dims)
533 }
534}
535
536impl Display for Size {
537 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
538 match (self.fixed_factor, self.batch_exp) {
539 (a, 0) => write!(f, "{}", a),
540 (1, 1) => write!(f, "B"),
541 (a, 1) => write!(f, "{}B", a),
542 (1, b) => write!(f, "B^{}", b),
543 (a, b) => write!(f, "{}B^{}", a, b),
544 }
545 }
546}
547
548impl Display for ConcreteShape {
549 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
550 fmt_shape_impl(f, &self.dims)
551 }
552}
553
/// Writes `dims` as a parenthesized, ` x `-separated list, e.g. `(2 x B x 3)`; empty gives `()`.
fn fmt_shape_impl(f: &mut Formatter, dims: &[impl Display]) -> Result<(), std::fmt::Error> {
    write!(f, "(")?;
    for (index, dim) in dims.iter().enumerate() {
        if index > 0 {
            write!(f, " x ")?;
        }
        write!(f, "{}", dim)?;
    }
    write!(f, ")")
}