1#![allow(clippy::redundant_closure_call)]
3use crate::{Error, Result};
4
/// The dimension sizes of a tensor, e.g. `[2, 3]` for a 2x3 matrix.
///
/// An empty list of dimensions describes a scalar.
#[derive(Clone, Eq, PartialEq)]
pub struct Shape(Vec<usize>);

/// The shape of a scalar: rank 0, so its element count (the empty product) is 1.
pub const SCALAR: Shape = Shape(Vec::new());
9
10impl std::fmt::Debug for Shape {
11 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
12 write!(f, "{:?}", &self.dims())
13 }
14}
15
16impl<const C: usize> From<&[usize; C]> for Shape {
17 fn from(dims: &[usize; C]) -> Self {
18 Self(dims.to_vec())
19 }
20}
21
22impl From<&[usize]> for Shape {
23 fn from(dims: &[usize]) -> Self {
24 Self(dims.to_vec())
25 }
26}
27
28impl From<&Shape> for Shape {
29 fn from(shape: &Shape) -> Self {
30 Self(shape.0.to_vec())
31 }
32}
33
34impl From<()> for Shape {
35 fn from(_: ()) -> Self {
36 Self(vec![])
37 }
38}
39
40impl From<usize> for Shape {
41 fn from(d1: usize) -> Self {
42 Self(vec![d1])
43 }
44}
45
46macro_rules! impl_from_tuple {
47 ($tuple:ty, $($index:tt),+) => {
48 impl From<$tuple> for Shape {
49 fn from(d: $tuple) -> Self {
50 Self(vec![$(d.$index,)+])
51 }
52 }
53 }
54}
55
56impl_from_tuple!((usize,), 0);
57impl_from_tuple!((usize, usize), 0, 1);
58impl_from_tuple!((usize, usize, usize), 0, 1, 2);
59impl_from_tuple!((usize, usize, usize, usize), 0, 1, 2, 3);
60impl_from_tuple!((usize, usize, usize, usize, usize), 0, 1, 2, 3, 4);
61impl_from_tuple!((usize, usize, usize, usize, usize, usize), 0, 1, 2, 3, 4, 5);
62
63impl From<Vec<usize>> for Shape {
64 fn from(dims: Vec<usize>) -> Self {
65 Self(dims)
66 }
67}
68
// Generates the rank-checked `$fn_name` accessors for a fixed rank `$cnt`:
// - a free function `$fn_name(&[usize]) -> Result<$out_type>`,
// - inherent `Shape::$fn_name` and `Tensor::$fn_name` methods delegating to it,
// - a `TryFrom<Shape>` conversion into `$out_type`; the standard blanket impl then also
//   provides `TryInto<$out_type> for Shape`, so existing `shape.try_into()` callers keep working.
// `$dims` is a closure over a slice of exactly `$cnt` elements. It is only called after the
// length check, so its indexing cannot go out of bounds.
macro_rules! extract_dims {
    ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
        /// Returns `dims` as a fixed-arity value.
        ///
        /// # Errors
        /// Returns `Error::UnexpectedNumberOfDims` if `dims` does not have exactly the
        /// expected number of elements.
        pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
            if dims.len() != $cnt {
                Err(Error::UnexpectedNumberOfDims {
                    expected: $cnt,
                    got: dims.len(),
                    shape: Shape::from(dims),
                }
                .bt())
            } else {
                Ok($dims(dims))
            }
        }

        impl Shape {
            /// Returns this shape's dimensions as a fixed-arity value, failing if the rank
            /// does not match.
            pub fn $fn_name(&self) -> Result<$out_type> {
                $fn_name(self.0.as_slice())
            }
        }

        impl crate::Tensor {
            /// Returns this tensor's dimensions as a fixed-arity value, failing if the rank
            /// does not match.
            pub fn $fn_name(&self) -> Result<$out_type> {
                self.shape().$fn_name()
            }
        }

        // The std docs recommend implementing `TryFrom` rather than `TryInto`. Callers get both
        // `<$out_type>::try_from(shape)` and, via the blanket impl, `shape.try_into()`.
        impl std::convert::TryFrom<Shape> for $out_type {
            type Error = crate::Error;
            fn try_from(shape: Shape) -> std::result::Result<$out_type, Self::Error> {
                shape.$fn_name()
            }
        }
    };
}
104
105impl Shape {
106 pub fn from_dims(dims: &[usize]) -> Self {
107 Self(dims.to_vec())
108 }
109
110 pub fn rank(&self) -> usize {
112 self.0.len()
113 }
114
115 pub fn into_dims(self) -> Vec<usize> {
116 self.0
117 }
118
119 pub fn dims(&self) -> &[usize] {
121 &self.0
122 }
123
124 pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
126 let dim = dim.to_index(self, "dim")?;
127 Ok(self.dims()[dim])
128 }
129
130 pub fn elem_count(&self) -> usize {
132 self.0.iter().product()
133 }
134
135 pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
138 let mut stride: Vec<_> = self
139 .0
140 .iter()
141 .rev()
142 .scan(1, |prod, u| {
143 let prod_pre_mult = *prod;
144 *prod *= u;
145 Some(prod_pre_mult)
146 })
147 .collect();
148 stride.reverse();
149 stride
150 }
151
152 pub fn is_contiguous(&self, stride: &[usize]) -> bool {
154 if self.0.len() != stride.len() {
155 return false;
156 }
157 let mut acc = 1;
158 for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
159 if dim > 1 && stride != acc {
160 return false;
161 }
162 acc *= dim;
163 }
164 true
165 }
166
167 pub fn is_fortran_contiguous(&self, stride: &[usize]) -> bool {
169 if self.0.len() != stride.len() {
170 return false;
171 }
172 let mut acc = 1;
173 for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
174 if dim > 1 && stride != acc {
175 return false;
176 }
177 acc *= dim;
178 }
179 true
180 }
181
182 pub fn extend(mut self, additional_dims: &[usize]) -> Self {
185 self.0.extend(additional_dims);
186 self
187 }
188
189 pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
192 let lhs = self;
193 let lhs_dims = lhs.dims();
194 let rhs_dims = rhs.dims();
195 let lhs_ndims = lhs_dims.len();
196 let rhs_ndims = rhs_dims.len();
197 let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
198 let mut bcast_dims = vec![0; bcast_ndims];
199 for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
200 let rev_idx = bcast_ndims - idx;
201 let l_value = if lhs_ndims < rev_idx {
202 1
203 } else {
204 lhs_dims[lhs_ndims - rev_idx]
205 };
206 let r_value = if rhs_ndims < rev_idx {
207 1
208 } else {
209 rhs_dims[rhs_ndims - rev_idx]
210 };
211 *bcast_value = if l_value == r_value {
212 l_value
213 } else if l_value == 1 {
214 r_value
215 } else if r_value == 1 {
216 l_value
217 } else {
218 Err(Error::ShapeMismatchBinaryOp {
219 lhs: lhs.clone(),
220 rhs: rhs.clone(),
221 op,
222 }
223 .bt())?
224 }
225 }
226 Ok(Shape::from(bcast_dims))
227 }
228
229 pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
230 let lhs = self;
231 let lhs_dims = lhs.dims();
232 let rhs_dims = rhs.dims();
233 if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
234 crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
235 }
236 let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
237 let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
238 if lhs_k != rhs_k {
239 crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
240 }
241
242 let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
243 let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
244 let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
245 let bcast_dims = bcast.dims();
246
247 let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
248 let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
249 Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
250 }
251}
252
/// Something that names one dimension of a [`Shape`]: an absolute `usize` index, or a
/// relative index counted from the end such as [`D::Minus1`].
pub trait Dim {
    /// Resolves `self` to an index in `0..shape.rank()`.
    ///
    /// `op` names the calling operation and is carried in the error. The implementations in
    /// this module return `Error::DimOutOfRange` when the index does not exist.
    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
    /// Like `to_index`, except the valid range is `0..=shape.rank()`, so the position just
    /// past the last dimension is also accepted.
    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
}
257
258impl Dim for usize {
259 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
260 let dim = *self;
261 if dim >= shape.dims().len() {
262 Err(Error::DimOutOfRange {
263 shape: shape.clone(),
264 dim: dim as i32,
265 op,
266 }
267 .bt())?
268 } else {
269 Ok(dim)
270 }
271 }
272
273 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
274 let dim = *self;
275 if dim > shape.dims().len() {
276 Err(Error::DimOutOfRange {
277 shape: shape.clone(),
278 dim: dim as i32,
279 op,
280 }
281 .bt())?
282 } else {
283 Ok(dim)
284 }
285 }
286}
287
/// A dimension index counted from the end of a shape.
///
/// `Minus1` is the last dimension and `Minus2` the second to last. `Minus(n)` is the `n`-th
/// dimension from the end, so `Minus(1)` resolves like `Minus1`. `Minus(0)` is always
/// rejected as out of range.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum D {
    /// The last dimension (index `rank - 1`).
    Minus1,
    /// The second-to-last dimension (index `rank - 2`).
    Minus2,
    /// The `n`-th dimension from the end (index `rank - n`); `n` must be at least 1.
    Minus(usize),
}
294
295impl D {
296 fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
297 let dim = match self {
298 Self::Minus1 => -1,
299 Self::Minus2 => -2,
300 Self::Minus(u) => -(*u as i32),
301 };
302 Error::DimOutOfRange {
303 shape: shape.clone(),
304 dim,
305 op,
306 }
307 .bt()
308 }
309}
310
311impl Dim for D {
312 fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
313 let rank = shape.rank();
314 match self {
315 Self::Minus1 if rank >= 1 => Ok(rank - 1),
316 Self::Minus2 if rank >= 2 => Ok(rank - 2),
317 Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
318 _ => Err(self.out_of_range(shape, op)),
319 }
320 }
321
322 fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
323 let rank = shape.rank();
324 match self {
325 Self::Minus1 => Ok(rank),
326 Self::Minus2 if rank >= 1 => Ok(rank - 1),
327 Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
328 _ => Err(self.out_of_range(shape, op)),
329 }
330 }
331}
332
333pub trait Dims: Sized {
334 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
335
336 fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
337 let dims = self.to_indexes_internal(shape, op)?;
338 for (i, &dim) in dims.iter().enumerate() {
339 if dims[..i].contains(&dim) {
340 Err(Error::DuplicateDimIndex {
341 shape: shape.clone(),
342 dims: dims.clone(),
343 op,
344 }
345 .bt())?
346 }
347 if dim >= shape.rank() {
348 Err(Error::DimOutOfRange {
349 shape: shape.clone(),
350 dim: dim as i32,
351 op,
352 }
353 .bt())?
354 }
355 }
356 Ok(dims)
357 }
358}
359
360impl Dims for Vec<usize> {
361 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
362 Ok(self)
363 }
364}
365
366impl<const N: usize> Dims for [usize; N] {
367 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
368 Ok(self.to_vec())
369 }
370}
371
372impl Dims for &[usize] {
373 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
374 Ok(self.to_vec())
375 }
376}
377
378impl Dims for () {
379 fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
380 Ok(vec![])
381 }
382}
383
384impl<D: Dim + Sized> Dims for D {
385 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
386 let dim = self.to_index(shape, op)?;
387 Ok(vec![dim])
388 }
389}
390
391impl<D: Dim> Dims for (D,) {
392 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
393 let dim = self.0.to_index(shape, op)?;
394 Ok(vec![dim])
395 }
396}
397
398impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
399 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
400 let d0 = self.0.to_index(shape, op)?;
401 let d1 = self.1.to_index(shape, op)?;
402 Ok(vec![d0, d1])
403 }
404}
405
406impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
407 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
408 let d0 = self.0.to_index(shape, op)?;
409 let d1 = self.1.to_index(shape, op)?;
410 let d2 = self.2.to_index(shape, op)?;
411 Ok(vec![d0, d1, d2])
412 }
413}
414
415impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
416 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
417 let d0 = self.0.to_index(shape, op)?;
418 let d1 = self.1.to_index(shape, op)?;
419 let d2 = self.2.to_index(shape, op)?;
420 let d3 = self.3.to_index(shape, op)?;
421 Ok(vec![d0, d1, d2, d3])
422 }
423}
424
425impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
426 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
427 let d0 = self.0.to_index(shape, op)?;
428 let d1 = self.1.to_index(shape, op)?;
429 let d2 = self.2.to_index(shape, op)?;
430 let d3 = self.3.to_index(shape, op)?;
431 let d4 = self.4.to_index(shape, op)?;
432 Ok(vec![d0, d1, d2, d3, d4])
433 }
434}
435
436impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
437 fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
438 let d0 = self.0.to_index(shape, op)?;
439 let d1 = self.1.to_index(shape, op)?;
440 let d2 = self.2.to_index(shape, op)?;
441 let d3 = self.3.to_index(shape, op)?;
442 let d4 = self.4.to_index(shape, op)?;
443 let d5 = self.5.to_index(shape, op)?;
444 Ok(vec![d0, d1, d2, d3, d4, d5])
445 }
446}
447
// Rank-checked accessors `dims0`..`dims5`. Each one verifies that the shape has exactly that
// many dimensions and returns them as `()`, a single `usize`, or a tuple of `usize`s.
// `extract_dims!` calls these closures directly, which is presumably why this file allows
// `clippy::redundant_closure_call` at the top (TODO confirm).
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
    dims3,
    3,
    |d: &[usize]| (d[0], d[1], d[2]),
    (usize, usize, usize)
);
extract_dims!(
    dims4,
    4,
    |d: &[usize]| (d[0], d[1], d[2], d[3]),
    (usize, usize, usize, usize)
);
extract_dims!(
    dims5,
    5,
    |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
    (usize, usize, usize, usize, usize)
);
469
/// A shape specification that may leave one dimension unspecified (written as `()`).
///
/// The missing dimension is inferred from the total element count, e.g. `(2, ())` or
/// `((), 3, 4)`. Anything convertible into a [`Shape`] is also accepted as a spec with no hole.
pub trait ShapeWithOneHole {
    /// Resolves `self` into a concrete [`Shape`] for a tensor holding `el_count` elements.
    ///
    /// If there is a hole, it is set to `el_count` divided by the product of the other
    /// dimensions. An error is returned when that product is zero or does not divide
    /// `el_count`. Specs without a hole are converted as-is: `el_count` is ignored and no
    /// element-count check is performed here.
    fn into_shape(self, el_count: usize) -> Result<Shape>;
}
473
474impl<S: Into<Shape>> ShapeWithOneHole for S {
475 fn into_shape(self, _el_count: usize) -> Result<Shape> {
476 Ok(self.into())
477 }
478}
479
480impl ShapeWithOneHole for ((),) {
481 fn into_shape(self, el_count: usize) -> Result<Shape> {
482 Ok(el_count.into())
483 }
484}
485
486fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
487 if prod_d == 0 {
488 crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
489 }
490 if el_count % prod_d != 0 {
491 crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
492 }
493 Ok(el_count / prod_d)
494}
495
496impl ShapeWithOneHole for ((), usize) {
497 fn into_shape(self, el_count: usize) -> Result<Shape> {
498 let ((), d1) = self;
499 Ok((hole_size(el_count, d1, &self)?, d1).into())
500 }
501}
502
503impl ShapeWithOneHole for (usize, ()) {
504 fn into_shape(self, el_count: usize) -> Result<Shape> {
505 let (d1, ()) = self;
506 Ok((d1, hole_size(el_count, d1, &self)?).into())
507 }
508}
509
510impl ShapeWithOneHole for ((), usize, usize) {
511 fn into_shape(self, el_count: usize) -> Result<Shape> {
512 let ((), d1, d2) = self;
513 Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
514 }
515}
516
517impl ShapeWithOneHole for (usize, (), usize) {
518 fn into_shape(self, el_count: usize) -> Result<Shape> {
519 let (d1, (), d2) = self;
520 Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
521 }
522}
523
524impl ShapeWithOneHole for (usize, usize, ()) {
525 fn into_shape(self, el_count: usize) -> Result<Shape> {
526 let (d1, d2, ()) = self;
527 Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
528 }
529}
530
531impl ShapeWithOneHole for ((), usize, usize, usize) {
532 fn into_shape(self, el_count: usize) -> Result<Shape> {
533 let ((), d1, d2, d3) = self;
534 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
535 Ok((d, d1, d2, d3).into())
536 }
537}
538
539impl ShapeWithOneHole for (usize, (), usize, usize) {
540 fn into_shape(self, el_count: usize) -> Result<Shape> {
541 let (d1, (), d2, d3) = self;
542 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
543 Ok((d1, d, d2, d3).into())
544 }
545}
546
547impl ShapeWithOneHole for (usize, usize, (), usize) {
548 fn into_shape(self, el_count: usize) -> Result<Shape> {
549 let (d1, d2, (), d3) = self;
550 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
551 Ok((d1, d2, d, d3).into())
552 }
553}
554
555impl ShapeWithOneHole for (usize, usize, usize, ()) {
556 fn into_shape(self, el_count: usize) -> Result<Shape> {
557 let (d1, d2, d3, ()) = self;
558 let d = hole_size(el_count, d1 * d2 * d3, &self)?;
559 Ok((d1, d2, d3, d).into())
560 }
561}
562
563impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
564 fn into_shape(self, el_count: usize) -> Result<Shape> {
565 let ((), d1, d2, d3, d4) = self;
566 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
567 Ok((d, d1, d2, d3, d4).into())
568 }
569}
570
571impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
572 fn into_shape(self, el_count: usize) -> Result<Shape> {
573 let (d1, (), d2, d3, d4) = self;
574 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
575 Ok((d1, d, d2, d3, d4).into())
576 }
577}
578
579impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
580 fn into_shape(self, el_count: usize) -> Result<Shape> {
581 let (d1, d2, (), d3, d4) = self;
582 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
583 Ok((d1, d2, d, d3, d4).into())
584 }
585}
586
587impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
588 fn into_shape(self, el_count: usize) -> Result<Shape> {
589 let (d1, d2, d3, (), d4) = self;
590 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
591 Ok((d1, d2, d3, d, d4).into())
592 }
593}
594
595impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
596 fn into_shape(self, el_count: usize) -> Result<Shape> {
597 let (d1, d2, d3, d4, ()) = self;
598 let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
599 Ok((d1, d2, d3, d4, d).into())
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major strides for shapes of rank 0 through 3.
    #[test]
    fn stride() {
        let cases: [(Shape, Vec<usize>); 4] = [
            (Shape::from(()), vec![]),
            (Shape::from(42), vec![1]),
            (Shape::from((42, 1337)), vec![1337, 1]),
            (Shape::from((299, 792, 458)), vec![458 * 792, 458, 1]),
        ];
        for (shape, expected) in cases {
            assert_eq!(shape.stride_contiguous(), expected);
        }
    }

    /// Every tuple arity supported by `From` produces the same dims in the same order.
    #[test]
    fn test_from_tuple() {
        assert_eq!(Shape::from((2,)).dims(), &[2]);
        assert_eq!(Shape::from((2, 3)).dims(), &[2, 3]);
        assert_eq!(Shape::from((2, 3, 4)).dims(), &[2, 3, 4]);
        assert_eq!(Shape::from((2, 3, 4, 5)).dims(), &[2, 3, 4, 5]);
        assert_eq!(Shape::from((2, 3, 4, 5, 6)).dims(), &[2, 3, 4, 5, 6]);
        assert_eq!(Shape::from((2, 3, 4, 5, 6, 7)).dims(), &[2, 3, 4, 5, 6, 7]);
    }
}