1#![allow(clippy::redundant_closure_call)]
3use crate::{Error, Result};
4
/// The dimensions of a tensor, one entry per axis; an empty list describes a scalar.
#[derive(PartialEq, Eq, Clone)]
pub struct Shape(Vec<usize>);

/// The shape of a scalar: no dimensions at all.
pub const SCALAR: Shape = Shape(Vec::new());
9
10impl std::fmt::Debug for Shape {
11 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
12 write!(f, "{:?}", &self.dims())
13 }
14}
15
16impl<const C: usize> From<&[usize; C]> for Shape {
17 fn from(dims: &[usize; C]) -> Self {
18 Self(dims.to_vec())
19 }
20}
21
22impl From<&[usize]> for Shape {
23 fn from(dims: &[usize]) -> Self {
24 Self(dims.to_vec())
25 }
26}
27
28impl From<&Shape> for Shape {
29 fn from(shape: &Shape) -> Self {
30 Self(shape.0.to_vec())
31 }
32}
33
34impl From<()> for Shape {
35 fn from(_: ()) -> Self {
36 Self(vec![])
37 }
38}
39
40impl From<usize> for Shape {
41 fn from(d1: usize) -> Self {
42 Self(vec![d1])
43 }
44}
45
46impl From<(usize,)> for Shape {
47 fn from(d1: (usize,)) -> Self {
48 Self(vec![d1.0])
49 }
50}
51
52impl From<(usize, usize)> for Shape {
53 fn from(d12: (usize, usize)) -> Self {
54 Self(vec![d12.0, d12.1])
55 }
56}
57
58impl From<(usize, usize, usize)> for Shape {
59 fn from(d123: (usize, usize, usize)) -> Self {
60 Self(vec![d123.0, d123.1, d123.2])
61 }
62}
63
64impl From<(usize, usize, usize, usize)> for Shape {
65 fn from(d1234: (usize, usize, usize, usize)) -> Self {
66 Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
67 }
68}
69
70impl From<(usize, usize, usize, usize, usize)> for Shape {
71 fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
72 Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
73 }
74}
75
76impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
77 fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
78 Self(vec![
79 d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
80 ])
81 }
82}
83
84impl From<Vec<usize>> for Shape {
85 fn from(dims: Vec<usize>) -> Self {
86 Self(dims)
87 }
88}
89
/// Generates a fixed-rank dimension accessor called `$fn_name`.
///
/// * `$cnt` — the exact rank the accessor accepts.
/// * `$dims` — a closure turning a slice of exactly `$cnt` dimensions into `$out_type`.
/// * `$out_type` — the value returned on success.
///
/// Four items are emitted:
/// * a free function over `&[usize]`;
/// * an inherent method on `Shape`;
/// * an inherent method on `crate::Tensor` that forwards to the tensor's shape;
/// * a `TryInto<$out_type>` impl for `Shape`.
///
/// All four fail with `Error::UnexpectedNumberOfDims` when the rank is not `$cnt`.
macro_rules! extract_dims {
    ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
        pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
            let got = dims.len();
            if got == $cnt {
                return Ok($dims(dims));
            }
            Err(Error::UnexpectedNumberOfDims {
                expected: $cnt,
                got,
                shape: Shape::from(dims),
            }
            .bt())
        }

        impl Shape {
            pub fn $fn_name(&self) -> Result<$out_type> {
                $fn_name(self.dims())
            }
        }

        impl crate::Tensor {
            pub fn $fn_name(&self) -> Result<$out_type> {
                let shape = self.shape();
                shape.$fn_name()
            }
        }

        impl std::convert::TryInto<$out_type> for Shape {
            type Error = crate::Error;
            fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
                Shape::$fn_name(&self)
            }
        }
    };
}
125
126impl Shape {
127 pub fn from_dims(dims: &[usize]) -> Self {
128 Self(dims.to_vec())
129 }
130
131 pub fn rank(&self) -> usize {
133 self.0.len()
134 }
135
136 pub fn into_dims(self) -> Vec<usize> {
137 self.0
138 }
139
140 pub fn dims(&self) -> &[usize] {
142 &self.0
143 }
144
145 pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
147 let dim = dim.to_index(self, "dim")?;
148 Ok(self.dims()[dim])
149 }
150
151 pub fn elem_count(&self) -> usize {
153 self.0.iter().product()
154 }
155
156 pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
159 let mut stride: Vec<_> = self
160 .0
161 .iter()
162 .rev()
163 .scan(1, |prod, u| {
164 let prod_pre_mult = *prod;
165 *prod *= u;
166 Some(prod_pre_mult)
167 })
168 .collect();
169 stride.reverse();
170 stride
171 }
172
173 pub fn is_contiguous(&self, stride: &[usize]) -> bool {
175 if self.0.len() != stride.len() {
176 return false;
177 }
178 let mut acc = 1;
179 for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
180 if dim > 1 && stride != acc {
181 return false;
182 }
183 acc *= dim;
184 }
185 true
186 }
187
188 pub fn is_fortran_contiguous(&self, stride: &[usize]) -> bool {
190 if self.0.len() != stride.len() {
191 return false;
192 }
193 let mut acc = 1;
194 for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
195 if dim > 1 && stride != acc {
196 return false;
197 }
198 acc *= dim;
199 }
200 true
201 }
202
203 pub fn extend(mut self, additional_dims: &[usize]) -> Self {
206 self.0.extend(additional_dims);
207 self
208 }
209
210 pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
213 let lhs = self;
214 let lhs_dims = lhs.dims();
215 let rhs_dims = rhs.dims();
216 let lhs_ndims = lhs_dims.len();
217 let rhs_ndims = rhs_dims.len();
218 let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
219 let mut bcast_dims = vec![0; bcast_ndims];
220 for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
221 let rev_idx = bcast_ndims - idx;
222 let l_value = if lhs_ndims < rev_idx {
223 1
224 } else {
225 lhs_dims[lhs_ndims - rev_idx]
226 };
227 let r_value = if rhs_ndims < rev_idx {
228 1
229 } else {
230 rhs_dims[rhs_ndims - rev_idx]
231 };
232 *bcast_value = if l_value == r_value {
233 l_value
234 } else if l_value == 1 {
235 r_value
236 } else if r_value == 1 {
237 l_value
238 } else {
239 Err(Error::ShapeMismatchBinaryOp {
240 lhs: lhs.clone(),
241 rhs: rhs.clone(),
242 op,
243 }
244 .bt())?
245 }
246 }
247 Ok(Shape::from(bcast_dims))
248 }
249
250 pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
251 let lhs = self;
252 let lhs_dims = lhs.dims();
253 let rhs_dims = rhs.dims();
254 if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
255 crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
256 }
257 let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
258 let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
259 if lhs_k != rhs_k {
260 crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
261 }
262
263 let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
264 let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
265 let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
266 let bcast_dims = bcast.dims();
267
268 let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
269 let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
270 Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
271 }
272}
273
274pub trait Dim {
275 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
276 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
277}
278
279impl Dim for usize {
280 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
281 let dim = *self;
282 if dim >= shape.dims().len() {
283 Err(Error::DimOutOfRange {
284 shape: shape.clone(),
285 dim: dim as i32,
286 op,
287 }
288 .bt())?
289 } else {
290 Ok(dim)
291 }
292 }
293
294 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
295 let dim = *self;
296 if dim > shape.dims().len() {
297 Err(Error::DimOutOfRange {
298 shape: shape.clone(),
299 dim: dim as i32,
300 op,
301 }
302 .bt())?
303 } else {
304 Ok(dim)
305 }
306 }
307}
308
309#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
310pub enum D {
311 Minus1,
312 Minus2,
313 Minus(usize),
314}
315
316impl D {
317 fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
318 let dim = match self {
319 Self::Minus1 => -1,
320 Self::Minus2 => -2,
321 Self::Minus(u) => -(*u as i32),
322 };
323 Error::DimOutOfRange {
324 shape: shape.clone(),
325 dim,
326 op,
327 }
328 .bt()
329 }
330}
331
332impl Dim for D {
333 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
334 let rank = shape.rank();
335 match self {
336 Self::Minus1 if rank >= 1 => Ok(rank - 1),
337 Self::Minus2 if rank >= 2 => Ok(rank - 2),
338 Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
339 _ => Err(self.out_of_range(shape, op)),
340 }
341 }
342
343 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
344 let rank = shape.rank();
345 match self {
346 Self::Minus1 => Ok(rank),
347 Self::Minus2 if rank >= 1 => Ok(rank - 1),
348 Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
349 _ => Err(self.out_of_range(shape, op)),
350 }
351 }
352}
353
354pub trait Dims: Sized {
355 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
356
357 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
358 let dims = self.to_indexes_internal(shape, op)?;
359 for (i, &dim) in dims.iter().enumerate() {
360 if dims[..i].contains(&dim) {
361 Err(Error::DuplicateDimIndex {
362 shape: shape.clone(),
363 dims: dims.clone(),
364 op,
365 }
366 .bt())?
367 }
368 if dim >= shape.rank() {
369 Err(Error::DimOutOfRange {
370 shape: shape.clone(),
371 dim: dim as i32,
372 op,
373 }
374 .bt())?
375 }
376 }
377 Ok(dims)
378 }
379}
380
381impl Dims for Vec<usize> {
382 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
383 Ok(self)
384 }
385}
386
387impl<const N: usize> Dims for [usize; N] {
388 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
389 Ok(self.to_vec())
390 }
391}
392
393impl Dims for &[usize] {
394 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
395 Ok(self.to_vec())
396 }
397}
398
399impl Dims for () {
400 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
401 Ok(vec![])
402 }
403}
404
405impl<D: Dim + Sized> Dims for D {
406 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
407 let dim = self.to_index(shape, op)?;
408 Ok(vec![dim])
409 }
410}
411
412impl<D: Dim> Dims for (D,) {
413 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
414 let dim = self.0.to_index(shape, op)?;
415 Ok(vec![dim])
416 }
417}
418
419impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
420 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
421 let d0 = self.0.to_index(shape, op)?;
422 let d1 = self.1.to_index(shape, op)?;
423 Ok(vec![d0, d1])
424 }
425}
426
427impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
428 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
429 let d0 = self.0.to_index(shape, op)?;
430 let d1 = self.1.to_index(shape, op)?;
431 let d2 = self.2.to_index(shape, op)?;
432 Ok(vec![d0, d1, d2])
433 }
434}
435
436impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
437 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
438 let d0 = self.0.to_index(shape, op)?;
439 let d1 = self.1.to_index(shape, op)?;
440 let d2 = self.2.to_index(shape, op)?;
441 let d3 = self.3.to_index(shape, op)?;
442 Ok(vec![d0, d1, d2, d3])
443 }
444}
445
446impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
447 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
448 let d0 = self.0.to_index(shape, op)?;
449 let d1 = self.1.to_index(shape, op)?;
450 let d2 = self.2.to_index(shape, op)?;
451 let d3 = self.3.to_index(shape, op)?;
452 let d4 = self.4.to_index(shape, op)?;
453 Ok(vec![d0, d1, d2, d3, d4])
454 }
455}
456
457impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
458 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
459 let d0 = self.0.to_index(shape, op)?;
460 let d1 = self.1.to_index(shape, op)?;
461 let d2 = self.2.to_index(shape, op)?;
462 let d3 = self.3.to_index(shape, op)?;
463 let d4 = self.4.to_index(shape, op)?;
464 let d5 = self.5.to_index(shape, op)?;
465 Ok(vec![d0, d1, d2, d3, d4, d5])
466 }
467}
468
469extract_dims!(dims0, 0, |_: &[usize]| (), ());
470extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
471extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
472extract_dims!(
473 dims3,
474 3,
475 |d: &[usize]| (d[0], d[1], d[2]),
476 (usize, usize, usize)
477);
478extract_dims!(
479 dims4,
480 4,
481 |d: &[usize]| (d[0], d[1], d[2], d[3]),
482 (usize, usize, usize, usize)
483);
484extract_dims!(
485 dims5,
486 5,
487 |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
488 (usize, usize, usize, usize, usize)
489);
490
491pub trait ShapeWithOneHole {
492 fn into_shape(self, el_count: usize) -> Result<Shape>;
493}
494
495impl<S: Into<Shape>> ShapeWithOneHole for S {
496 fn into_shape(self, _el_count: usize) -> Result<Shape> {
497 Ok(self.into())
498 }
499}
500
501impl ShapeWithOneHole for ((),) {
502 fn into_shape(self, el_count: usize) -> Result<Shape> {
503 Ok(el_count.into())
504 }
505}
506
507fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
508 if prod_d == 0 {
509 crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
510 }
511 if el_count % prod_d != 0 {
512 crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
513 }
514 Ok(el_count / prod_d)
515}
516
517impl ShapeWithOneHole for ((), usize) {
518 fn into_shape(self, el_count: usize) -> Result<Shape> {
519 let ((), d1) = self;
520 Ok((hole_size(el_count, d1, &self)?, d1).into())
521 }
522}
523
524impl ShapeWithOneHole for (usize, ()) {
525 fn into_shape(self, el_count: usize) -> Result<Shape> {
526 let (d1, ()) = self;
527 Ok((d1, hole_size(el_count, d1, &self)?).into())
528 }
529}
530
531impl ShapeWithOneHole for ((), usize, usize) {
532 fn into_shape(self, el_count: usize) -> Result<Shape> {
533 let ((), d1, d2) = self;
534 Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
535 }
536}
537
538impl ShapeWithOneHole for (usize, (), usize) {
539 fn into_shape(self, el_count: usize) -> Result<Shape> {
540 let (d1, (), d2) = self;
541 Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
542 }
543}
544
545impl ShapeWithOneHole for (usize, usize, ()) {
546 fn into_shape(self, el_count: usize) -> Result<Shape> {
547 let (d1, d2, ()) = self;
548 Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
549 }
550}
551
552impl ShapeWithOneHole for ((), usize, usize, usize) {
553 fn into_shape(self, el_count: usize) -> Result<Shape> {
554 let ((), d1, d2, d3) = self;
555 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
556 Ok((d, d1, d2, d3).into())
557 }
558}
559
560impl ShapeWithOneHole for (usize, (), usize, usize) {
561 fn into_shape(self, el_count: usize) -> Result<Shape> {
562 let (d1, (), d2, d3) = self;
563 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
564 Ok((d1, d, d2, d3).into())
565 }
566}
567
568impl ShapeWithOneHole for (usize, usize, (), usize) {
569 fn into_shape(self, el_count: usize) -> Result<Shape> {
570 let (d1, d2, (), d3) = self;
571 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
572 Ok((d1, d2, d, d3).into())
573 }
574}
575
576impl ShapeWithOneHole for (usize, usize, usize, ()) {
577 fn into_shape(self, el_count: usize) -> Result<Shape> {
578 let (d1, d2, d3, ()) = self;
579 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
580 Ok((d1, d2, d3, d).into())
581 }
582}
583
584impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
585 fn into_shape(self, el_count: usize) -> Result<Shape> {
586 let ((), d1, d2, d3, d4) = self;
587 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
588 Ok((d, d1, d2, d3, d4).into())
589 }
590}
591
592impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
593 fn into_shape(self, el_count: usize) -> Result<Shape> {
594 let (d1, (), d2, d3, d4) = self;
595 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
596 Ok((d1, d, d2, d3, d4).into())
597 }
598}
599
600impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
601 fn into_shape(self, el_count: usize) -> Result<Shape> {
602 let (d1, d2, (), d3, d4) = self;
603 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
604 Ok((d1, d2, d, d3, d4).into())
605 }
606}
607
608impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
609 fn into_shape(self, el_count: usize) -> Result<Shape> {
610 let (d1, d2, d3, (), d4) = self;
611 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
612 Ok((d1, d2, d3, d, d4).into())
613 }
614}
615
616impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
617 fn into_shape(self, el_count: usize) -> Result<Shape> {
618 let (d1, d2, d3, d4, ()) = self;
619 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
620 Ok((d1, d2, d3, d4, d).into())
621 }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stride() {
        let shape = Shape::from(());
        assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
        let shape = Shape::from(42);
        assert_eq!(shape.stride_contiguous(), [1]);
        let shape = Shape::from((42, 1337));
        assert_eq!(shape.stride_contiguous(), [1337, 1]);
        let shape = Shape::from((299, 792, 458));
        assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
    }

    #[test]
    fn elem_count_and_rank() {
        assert_eq!(Shape::from(()).elem_count(), 1);
        assert_eq!(Shape::from(()).rank(), 0);
        assert_eq!(Shape::from((2, 3, 4)).elem_count(), 24);
        assert_eq!(Shape::from((2, 0, 4)).elem_count(), 0);
    }

    #[test]
    fn contiguity() {
        let shape = Shape::from((2, 3, 4));
        assert!(shape.is_contiguous(&[12, 4, 1]));
        assert!(!shape.is_contiguous(&[1, 2, 6]));
        assert!(shape.is_fortran_contiguous(&[1, 2, 6]));
        assert!(!shape.is_fortran_contiguous(&[12, 4, 1]));
        // A stride slice of the wrong rank is never contiguous.
        assert!(!shape.is_contiguous(&[4, 1]));
        // Strides of size-1 dimensions are ignored.
        let shape = Shape::from((2, 1, 4));
        assert!(shape.is_contiguous(&[4, 100, 1]));
    }

    #[test]
    fn broadcast_binary() {
        let lhs = Shape::from((3, 1, 5));
        let rhs = Shape::from((4, 5));
        let out = lhs.broadcast_shape_binary_op(&rhs, "test").ok();
        assert_eq!(out, Some(Shape::from((3, 4, 5))));
        let bad = Shape::from((3, 2)).broadcast_shape_binary_op(&Shape::from((4, 2)), "test");
        assert!(bad.is_err());
    }

    #[test]
    fn broadcast_matmul() {
        let out = Shape::from((5, 2, 3))
            .broadcast_shape_matmul(&Shape::from((3, 4)))
            .ok();
        assert_eq!(
            out,
            Some((Shape::from((5, 2, 3)), Shape::from((5, 3, 4))))
        );
        // Inner dimensions differ.
        assert!(Shape::from((2, 3))
            .broadcast_shape_matmul(&Shape::from((4, 5)))
            .is_err());
        // Rank-1 operand.
        assert!(Shape::from(3)
            .broadcast_shape_matmul(&Shape::from((3, 4)))
            .is_err());
    }

    #[test]
    fn negative_dims() {
        let shape = Shape::from((2, 3, 4));
        assert_eq!(shape.dim(D::Minus1).ok(), Some(4));
        assert_eq!(shape.dim(D::Minus2).ok(), Some(3));
        assert_eq!(shape.dim(D::Minus(3)).ok(), Some(2));
        assert!(shape.dim(D::Minus(4)).is_err());
        assert!(shape.dim(D::Minus(0)).is_err());
        assert!(shape.dim(3usize).is_err());
    }

    #[test]
    fn dims_validation() {
        let shape = Shape::from((2, 3, 4));
        assert_eq!(
            (0usize, D::Minus1).to_indexes(&shape, "test").ok(),
            Some(vec![0, 2])
        );
        // `D::Minus2` resolves to 1 on a rank-3 shape: duplicate.
        assert!((1usize, D::Minus2).to_indexes(&shape, "test").is_err());
        // Plain indices are range-checked.
        assert!(vec![0usize, 3].to_indexes(&shape, "test").is_err());
    }

    #[test]
    fn one_hole() {
        assert_eq!(((), 3usize).into_shape(12).ok(), Some(Shape::from((4, 3))));
        assert_eq!(
            (2usize, (), 3usize).into_shape(12).ok(),
            Some(Shape::from((2, 2, 3)))
        );
        assert!(((), 5usize).into_shape(12).is_err());
        assert!(((), 0usize).into_shape(12).is_err());
    }

    #[test]
    fn fixed_rank_extractors() {
        let shape = Shape::from((2, 3));
        assert_eq!(shape.dims2().ok(), Some((2, 3)));
        assert!(shape.dims3().is_err());
        assert_eq!(Shape::from(()).dims0().ok(), Some(()));
    }
}