1use crate::DType;
23use smallvec::SmallVec;
24
/// One tensor dimension: either a concrete size or a named symbolic size.
///
/// NOTE(review): variant order is part of the derived `Hash` and (with the
/// `serialize` feature) the serialized form; do not reorder.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Dim {
    /// A concrete size.
    Static(usize),
    /// A symbolic size identified by a symbol id. It is displayed as `?{id}` and
    /// becomes `Static` once `Shape::bind` is given a size for the symbol.
    Dynamic(u32),
}
35
36impl Dim {
37 pub fn unwrap_static(self) -> usize {
38 match self {
39 Self::Static(n) => n,
40 Self::Dynamic(s) => panic!("expected static dim, got dynamic symbol {s}"),
41 }
42 }
43
44 pub fn is_static(self) -> bool {
45 matches!(self, Self::Static(_))
46 }
47}
48
49impl From<usize> for Dim {
50 fn from(n: usize) -> Self {
51 Self::Static(n)
52 }
53}
54
55impl std::fmt::Display for Dim {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 match self {
58 Self::Static(n) => write!(f, "{n}"),
59 Self::Dynamic(s) => write!(f, "?{s}"),
60 }
61 }
62}
63
/// A tensor shape: an ordered list of dimensions (index 0 is the leading axis)
/// together with the element dtype.
///
/// NOTE(review): field layout is part of the serialized form under the
/// `serialize` feature.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape {
    /// Dimensions in axis order; stored inline (no heap allocation) up to rank 4.
    dims: SmallVec<[Dim; 4]>,
    /// Element type.
    dtype: DType,
}
73
74impl Shape {
75 pub fn new(dims: &[usize], dtype: DType) -> Self {
77 Self {
78 dims: dims.iter().map(|&d| Dim::Static(d)).collect(),
79 dtype,
80 }
81 }
82
83 pub fn from_dims(dims: &[Dim], dtype: DType) -> Self {
85 Self {
86 dims: dims.into(),
87 dtype,
88 }
89 }
90
91 pub fn scalar(dtype: DType) -> Self {
93 Self {
94 dims: SmallVec::new(),
95 dtype,
96 }
97 }
98
99 pub fn rank(&self) -> usize {
100 self.dims.len()
101 }
102 pub fn dtype(&self) -> DType {
103 self.dtype
104 }
105 pub fn dims(&self) -> &[Dim] {
106 &self.dims
107 }
108 pub fn dim(&self, i: usize) -> Dim {
109 self.dims.get(i).copied().unwrap_or_else(|| {
110 let dims: Vec<_> = self.dims.iter().map(|d| d.unwrap_static()).collect();
111 panic!(
112 "Shape::dim({i}) out of bounds for rank {} dims={dims:?}",
113 self.rank()
114 );
115 })
116 }
117
118 pub fn dynamic_symbols(&self) -> Vec<u32> {
121 let mut syms: Vec<u32> = self
122 .dims
123 .iter()
124 .filter_map(|d| match d {
125 Dim::Dynamic(s) => Some(*s),
126 _ => None,
127 })
128 .collect();
129 syms.sort();
130 syms.dedup();
131 syms
132 }
133
134 pub fn bind(&self, bindings: &DimBinding) -> Self {
139 let dims = self
140 .dims
141 .iter()
142 .map(|d| match d {
143 Dim::Dynamic(s) => match bindings.get(*s) {
144 Some(n) => Dim::Static(n),
145 None => *d,
146 },
147 _ => *d,
148 })
149 .collect();
150 Self {
151 dims,
152 dtype: self.dtype,
153 }
154 }
155
156 pub fn num_elements(&self) -> Option<usize> {
158 let mut total = 1usize;
159 for d in &self.dims {
160 match d {
161 Dim::Static(n) => total = total.checked_mul(*n)?,
162 Dim::Dynamic(_) => return None,
163 }
164 }
165 Some(total)
166 }
167
168 pub fn size_bytes(&self) -> Option<usize> {
170 self.num_elements().map(|n| n * self.dtype.size_bytes())
171 }
172
173 pub fn is_static(&self) -> bool {
175 self.dims.iter().all(|d| d.is_static())
176 }
177
178 pub fn with_dim(mut self, axis: usize, dim: Dim) -> Self {
180 self.dims[axis] = dim;
181 self
182 }
183
184 pub fn with_dtype(mut self, dtype: DType) -> Self {
186 self.dtype = dtype;
187 self
188 }
189
190 pub fn broadcast_with(&self, other: &Shape) -> Result<Shape, String> {
192 broadcast(self, other)
193 }
194}
195
196pub fn broadcast(a: &Shape, b: &Shape) -> Result<Shape, String> {
200 let max_rank = a.rank().max(b.rank());
201 let mut dims = SmallVec::new();
202 for i in 0..max_rank {
203 let ad = if i < max_rank - a.rank() {
204 Dim::Static(1)
205 } else {
206 a.dims[i - (max_rank - a.rank())]
207 };
208 let bd = if i < max_rank - b.rank() {
209 Dim::Static(1)
210 } else {
211 b.dims[i - (max_rank - b.rank())]
212 };
213 let d = broadcast_dim(ad, bd)?;
214 dims.push(d);
215 }
216 Ok(Shape {
217 dims,
218 dtype: a.dtype,
219 })
220}
221
222fn broadcast_dim(a: Dim, b: Dim) -> Result<Dim, String> {
223 match (a, b) {
224 (Dim::Static(1), d) | (d, Dim::Static(1)) => Ok(d),
225 (Dim::Static(x), Dim::Static(y)) if x == y => Ok(Dim::Static(x)),
226 (Dim::Static(x), Dim::Static(y)) => Err(format!("cannot broadcast {x} with {y}")),
227 (Dim::Dynamic(s), Dim::Dynamic(t)) if s == t => Ok(Dim::Dynamic(s)),
228 (Dim::Dynamic(_), _) | (_, Dim::Dynamic(_)) => Ok(a), }
230}
231
232pub fn matmul_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
234 if lhs.rank() < 2 || rhs.rank() < 2 {
235 return Err(format!(
236 "matmul requires rank >= 2, got {} and {}",
237 lhs.rank(),
238 rhs.rank()
239 ));
240 }
241 let m = lhs.dims[lhs.rank() - 2];
242 let k1 = lhs.dims[lhs.rank() - 1];
243 let k2 = rhs.dims[rhs.rank() - 2];
244 let n = rhs.dims[rhs.rank() - 1];
245
246 match (k1, k2) {
248 (Dim::Static(a), Dim::Static(b)) if a != b => {
249 return Err(format!("matmul K mismatch: {a} vs {b}"));
250 }
251 (Dim::Dynamic(s), Dim::Dynamic(t)) if s != t => {
252 return Err(format!("matmul K mismatch: ?{s} vs ?{t}"));
253 }
254 _ => {}
255 }
256
257 let lhs_batch = &lhs.dims[..lhs.rank() - 2];
259 let rhs_batch = &rhs.dims[..rhs.rank() - 2];
260 let batch_a = Shape::from_dims(lhs_batch, lhs.dtype);
261 let batch_b = Shape::from_dims(rhs_batch, rhs.dtype);
262 let batch = if lhs_batch.is_empty() && rhs_batch.is_empty() {
263 SmallVec::new()
264 } else if lhs_batch.is_empty() {
265 rhs_batch.into()
266 } else if rhs_batch.is_empty() {
267 lhs_batch.into()
268 } else {
269 broadcast(&batch_a, &batch_b)?.dims.clone()
270 };
271
272 let mut dims = batch;
273 dims.push(m);
274 dims.push(n);
275 Ok(Shape {
276 dims,
277 dtype: lhs.dtype,
278 })
279}
280
281pub fn expand_shape(input: &Shape, target: &[i64]) -> Result<Shape, String> {
283 if target.iter().any(|&d| d < 0) {
284 return Err("expand target has negative dim".into());
285 }
286 let target_s = Shape::new(
287 &target.iter().map(|&d| d as usize).collect::<Vec<_>>(),
288 input.dtype(),
289 );
290 broadcast(input, &target_s)
291}
292
293pub fn binary_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
295 broadcast(lhs, rhs)
296}
297
298pub fn unary_shape(input: &Shape) -> Shape {
300 input.clone()
301}
302
303pub fn cast_shape(input: &Shape, to: DType) -> Shape {
305 input.clone().with_dtype(to)
306}
307
308pub fn compare_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
310 Ok(broadcast(lhs, rhs)?.with_dtype(DType::Bool))
311}
312
313pub fn reduce_shape(input: &Shape, axes: &[usize], keep_dim: bool) -> Result<Shape, String> {
315 let mut dims = SmallVec::new();
316 for (i, &d) in input.dims.iter().enumerate() {
317 if axes.contains(&i) {
318 if keep_dim {
319 dims.push(Dim::Static(1));
320 }
321 } else {
322 dims.push(d);
323 }
324 }
325 Ok(Shape {
326 dims,
327 dtype: input.dtype,
328 })
329}
330
331pub fn softmax_shape(input: &Shape) -> Shape {
333 input.clone()
334}
335
336pub fn transpose_shape(input: &Shape, perm: &[usize]) -> Result<Shape, String> {
338 if perm.len() != input.rank() {
339 return Err(format!("perm len {} != rank {}", perm.len(), input.rank()));
340 }
341 let dims: SmallVec<[Dim; 4]> = perm.iter().map(|&i| input.dims[i]).collect();
342 Ok(Shape {
343 dims,
344 dtype: input.dtype,
345 })
346}
347
348pub fn narrow_shape(input: &Shape, axis: usize, len: usize) -> Result<Shape, String> {
350 if axis >= input.rank() {
351 return Err(format!("axis {axis} >= rank {}", input.rank()));
352 }
353 Ok(input.clone().with_dim(axis, Dim::Static(len)))
354}
355
356pub fn concat_shape(inputs: &[&Shape], axis: usize) -> Result<Shape, String> {
358 if inputs.is_empty() {
359 return Err("concat: no inputs".into());
360 }
361 let base = inputs[0];
362 let mut static_sum = 0usize;
363 let mut dyn_sym: Option<u32> = None;
364 for s in inputs {
365 if s.rank() == 0 {
366 return Err("concat: input has rank 0".into());
367 }
368 if s.rank() != base.rank() {
369 return Err(format!(
370 "concat: rank mismatch {} vs {}",
371 s.rank(),
372 base.rank()
373 ));
374 }
375 let ax = axis.min(s.rank().saturating_sub(1));
376 match s.dims[ax] {
377 Dim::Static(n) => static_sum += n,
378 Dim::Dynamic(sym) => {
379 if let Some(prev) = dyn_sym {
380 if prev != sym {
381 return Err(format!(
382 "concat: mismatched dynamic symbols {prev} vs {sym} on axis {axis}"
383 ));
384 }
385 }
386 dyn_sym = Some(sym);
387 }
388 }
389 }
390 let out_dim = match dyn_sym {
391 None => Dim::Static(static_sum),
392 Some(sym) if static_sum == 0 => Dim::Dynamic(sym),
393 Some(sym) => {
394 let _ = static_sum;
397 Dim::Dynamic(sym)
398 }
399 };
400 let out_axis = axis.min(base.rank().saturating_sub(1));
401 Ok(base.clone().with_dim(out_axis, out_dim))
402}
403
404pub fn gather_shape(table: &Shape, indices: &Shape, axis: usize) -> Result<Shape, String> {
406 if axis >= table.rank() {
407 return Err(format!("gather: axis {axis} >= rank {}", table.rank()));
408 }
409 let mut dims: SmallVec<[Dim; 4]> = indices.dims.clone();
410 for i in (axis + 1)..table.rank() {
411 dims.push(table.dims[i]);
412 }
413 Ok(Shape {
414 dims,
415 dtype: table.dtype,
416 })
417}
418
419pub fn reshape_shape(input: &Shape, new_shape: &[i64]) -> Result<Shape, String> {
421 let neg_count = new_shape.iter().filter(|&&d| d == -1).count();
422 if neg_count > 1 {
423 return Err("reshape: at most one -1".into());
424 }
425
426 if input.is_static() {
427 let total = input
428 .num_elements()
429 .ok_or_else(|| "reshape: input has dynamic dims".to_string())?;
430 let known_product: i64 = new_shape.iter().filter(|&&d| d != -1).product();
431 let mut dims = SmallVec::new();
432 for &d in new_shape {
433 if d == -1 {
434 let inferred = total as i64 / known_product;
435 dims.push(Dim::Static(inferred as usize));
436 } else if d < 0 {
437 return Err(format!("reshape: invalid dim {d}"));
438 } else {
439 dims.push(Dim::Static(d as usize));
440 }
441 }
442 return Ok(Shape {
443 dims,
444 dtype: input.dtype,
445 });
446 }
447
448 let dyn_syms = input.dynamic_symbols();
451 let neg_idx = new_shape.iter().position(|&d| d == -1);
452 let mut out_dims: SmallVec<[Dim; 4]> = SmallVec::new();
453 for (i, &d) in new_shape.iter().enumerate() {
454 if Some(i) == neg_idx {
455 continue;
456 }
457 if d < 0 {
458 return Err(format!("reshape: invalid dim {d}"));
459 }
460 out_dims.push(Dim::Static(d as usize));
461 }
462 if let Some(ni) = neg_idx {
463 let inferred = if dyn_syms.len() == 1 {
464 Dim::Dynamic(dyn_syms[0])
465 } else if dyn_syms.is_empty() {
466 return Err("reshape: cannot infer -1 on static input".into());
467 } else {
468 Dim::Dynamic(crate::dynamic::sym::ROWS)
469 };
470 out_dims.insert(ni, inferred);
471 }
472 Ok(Shape {
473 dims: out_dims,
474 dtype: input.dtype,
475 })
476}
477
478pub fn leading_flatten_fused_shape(input: &Shape) -> Option<Shape> {
480 if input.rank() < 2 {
481 return None;
482 }
483 let Dim::Static(h) = input.dim(input.rank() - 1) else {
484 return None;
485 };
486 let leading = &input.dims()[..input.rank() - 1];
487 let lead_dim = if leading.iter().all(|d| d.is_static()) {
488 Dim::Static(leading.iter().map(|d| d.unwrap_static()).product::<usize>())
489 } else {
490 let mut syms: Vec<u32> = leading
491 .iter()
492 .filter_map(|d| match d {
493 Dim::Dynamic(s) => Some(*s),
494 _ => None,
495 })
496 .collect();
497 syms.sort();
498 syms.dedup();
499 match syms.len() {
500 0 => Dim::Static(leading.iter().map(|d| d.unwrap_static()).product::<usize>()),
501 1 => Dim::Dynamic(syms[0]),
502 _ => Dim::Dynamic(crate::dynamic::sym::ROWS),
503 }
504 };
505 Some(Shape::from_dims(&[lead_dim, Dim::Static(h)], input.dtype()))
506}
507
508pub fn leading_flatten_shape(input: &Shape, new_shape: &[i64]) -> Option<Shape> {
510 if new_shape.len() != 2 {
511 return None;
512 }
513 let flat = leading_flatten_fused_shape(input)?;
514 let Dim::Static(h) = input.dim(input.rank() - 1) else {
515 return None;
516 };
517 if new_shape[1] as usize != h {
518 return None;
519 }
520 match flat.dim(0) {
521 Dim::Static(lead) if new_shape[0] as usize == lead => Some(flat),
522 Dim::Dynamic(_) if new_shape[0] == -1 => Some(flat),
523 _ => None,
524 }
525}
526
527pub fn attention_shape(q: &Shape) -> Shape {
529 q.clone()
530}
531
532impl std::fmt::Display for Shape {
533 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534 write!(f, "[")?;
535 for (i, d) in self.dims.iter().enumerate() {
536 if i > 0 {
537 write!(f, ", ")?;
538 }
539 write!(f, "{d}")?;
540 }
541 write!(f, "] {}", self.dtype)
542 }
543}
544
/// Output length of one spatial axis of a 2-D convolution.
///
/// Computes `floor((in + 2*padding - dilation*(kernel-1) - 1) / stride) + 1`,
/// which is the number of kernel positions that fit in the padded input.
///
/// Returns 0 when the dilated kernel is longer than the padded input. The
/// previous saturating arithmetic reported 1 in that case. Intermediate sums
/// saturate instead of overflowing.
///
/// # Panics
///
/// Panics (division by zero) if `stride == 0` and the kernel fits.
pub fn conv2d_spatial_output(
    in_size: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    // Number of input positions one dilated kernel application covers.
    let span = dilation
        .saturating_mul(kernel.saturating_sub(1))
        .saturating_add(1);
    let padded = in_size.saturating_add(padding.saturating_mul(2));
    if padded < span {
        return 0;
    }
    (padded - span) / stride + 1
}
560
/// Output length of one spatial axis of a 2-D transposed convolution.
///
/// Computes
/// `(in - 1)*stride - 2*padding + dilation*(kernel-1) + output_padding + 1`.
///
/// Returns 0 for an empty input, or when padding removes the whole output.
/// Previously those cases underflowed, panicking in debug builds and wrapping
/// in release builds. All other arithmetic saturates.
pub fn conv_transpose2d_spatial_output(
    in_size: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    output_padding: usize,
) -> usize {
    if in_size == 0 {
        return 0;
    }
    let dil_k = dilation.saturating_mul(kernel.saturating_sub(1));
    // Sum the positive terms first so that subtracting the padding can only
    // saturate at zero, never underflow.
    (in_size - 1)
        .saturating_mul(stride)
        .saturating_add(output_padding)
        .saturating_add(dil_k)
        .saturating_add(1)
        .saturating_sub(padding.saturating_mul(2))
}
573
574pub fn conv2d_output_shape(
576 input: &Shape,
577 weight: &Shape,
578 kernel_size: [usize; 2],
579 stride: [usize; 2],
580 padding: [usize; 2],
581 dilation: [usize; 2],
582 groups: usize,
583) -> Result<Shape, String> {
584 if input.rank() != 4 || weight.rank() != 4 {
585 return Err("conv2d requires NCHW input and 4-D weight".into());
586 }
587 let n = input.dim(0);
588 let c_in = input.dim(1).unwrap_static();
589 let h = input.dim(2).unwrap_static();
590 let w = input.dim(3).unwrap_static();
591 let c_out = weight.dim(0).unwrap_static();
592 let w_cin = weight.dim(1).unwrap_static();
593 if w_cin * groups != c_in {
594 return Err(format!(
595 "conv2d weight C_in/g={w_cin} * groups={groups} != input C={c_in}"
596 ));
597 }
598 let h_out = conv2d_spatial_output(h, kernel_size[0], stride[0], padding[0], dilation[0]);
599 let w_out = conv2d_spatial_output(w, kernel_size[1], stride[1], padding[1], dilation[1]);
600 Ok(Shape::from_dims(
601 &[
602 n,
603 Dim::Static(c_out),
604 Dim::Static(h_out),
605 Dim::Static(w_out),
606 ],
607 input.dtype(),
608 ))
609}
610
611pub fn im2col_output_shape(
614 input: &Shape,
615 kernel_size: [usize; 2],
616 stride: [usize; 2],
617 padding: [usize; 2],
618 dilation: [usize; 2],
619) -> Result<Shape, String> {
620 if input.rank() != 4 {
621 return Err("im2col requires NCHW input".into());
622 }
623 let c_in = input.dim(1).unwrap_static();
624 let h = input.dim(2).unwrap_static();
625 let w = input.dim(3).unwrap_static();
626 let kh = kernel_size[0];
627 let kw = kernel_size[1];
628 let h_out = conv2d_spatial_output(h, kh, stride[0], padding[0], dilation[0]);
629 let w_out = conv2d_spatial_output(w, kw, stride[1], padding[1], dilation[1]);
630 let k = c_in * kh * kw;
631 let spatial = h_out * w_out;
632 let m = match input.dim(0) {
633 Dim::Static(n) => Dim::Static(n * spatial),
634 Dim::Dynamic(crate::dynamic::sym::BATCH) | Dim::Dynamic(crate::dynamic::sym::ROWS) => {
635 Dim::Dynamic(crate::dynamic::sym::ROWS)
636 }
637 Dim::Dynamic(_) => Dim::Dynamic(crate::dynamic::sym::ROWS),
638 };
639 Ok(Shape::from_dims(&[m, Dim::Static(k)], input.dtype()))
640}
641
642pub fn conv_transpose2d_output_shape(
644 input: &Shape,
645 weight: &Shape,
646 kernel_size: [usize; 2],
647 stride: [usize; 2],
648 padding: [usize; 2],
649 dilation: [usize; 2],
650 output_padding: [usize; 2],
651 groups: usize,
652) -> Result<Shape, String> {
653 if input.rank() != 4 || weight.rank() != 4 {
654 return Err("conv_transpose2d requires NCHW input and 4-D weight".into());
655 }
656 let n = input.dim(0).unwrap_static();
657 let c_in = input.dim(1).unwrap_static();
658 let h = input.dim(2).unwrap_static();
659 let w = input.dim(3).unwrap_static();
660 let w_cin = weight.dim(0).unwrap_static();
661 let c_out_per_g = weight.dim(1).unwrap_static();
662 if w_cin != c_in {
663 return Err(format!(
664 "conv_transpose2d weight C_in={w_cin} != input C={c_in}"
665 ));
666 }
667 let h_out = conv_transpose2d_spatial_output(
668 h,
669 kernel_size[0],
670 stride[0],
671 padding[0],
672 dilation[0],
673 output_padding[0],
674 );
675 let w_out = conv_transpose2d_spatial_output(
676 w,
677 kernel_size[1],
678 stride[1],
679 padding[1],
680 dilation[1],
681 output_padding[1],
682 );
683 Ok(Shape::new(
684 &[n, c_out_per_g * groups, h_out, w_out],
685 input.dtype(),
686 ))
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully static `f32` shape; most cases below only vary the dims.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    #[test]
    fn static_shape() {
        let shape = f32_shape(&[4, 15, 384]);
        let elems = 4 * 15 * 384;
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), Some(elems));
        // f32 elements are 4 bytes each.
        assert_eq!(shape.size_bytes(), Some(elems * 4));
        assert!(shape.is_static());
        assert_eq!(shape.to_string(), "[4, 15, 384] f32");
    }

    #[test]
    fn broadcast_same() {
        let shape = f32_shape(&[4, 15, 384]);
        let out = broadcast(&shape, &shape).unwrap();
        assert_eq!(out.dims(), shape.dims());
    }

    #[test]
    fn broadcast_bias() {
        let out = broadcast(&f32_shape(&[4, 15, 384]), &f32_shape(&[384])).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 384]));
    }

    #[test]
    fn broadcast_scalar() {
        let shape = f32_shape(&[4, 15, 384]);
        let out = broadcast(&shape, &Shape::scalar(DType::F32)).unwrap();
        assert_eq!(out, shape);
    }

    #[test]
    fn broadcast_mismatch() {
        let result = broadcast(&f32_shape(&[4, 15, 384]), &f32_shape(&[4, 15, 256]));
        assert!(result.is_err());
    }

    #[test]
    fn matmul_basic() {
        let out = matmul_shape(&f32_shape(&[4, 15, 384]), &f32_shape(&[384, 1536])).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 1536]));
    }

    #[test]
    fn matmul_batched() {
        let out = matmul_shape(&f32_shape(&[4, 15, 384]), &f32_shape(&[4, 384, 1536])).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 1536]));
    }

    #[test]
    fn matmul_k_mismatch() {
        let result = matmul_shape(&f32_shape(&[4, 15, 384]), &f32_shape(&[512, 1536]));
        assert!(result.is_err());
    }

    #[test]
    fn reduce_keepdim() {
        let out = reduce_shape(&f32_shape(&[4, 15, 384]), &[2], true).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 1]));
    }

    #[test]
    fn reduce_no_keepdim() {
        let out = reduce_shape(&f32_shape(&[4, 15, 384]), &[2], false).unwrap();
        assert_eq!(out, f32_shape(&[4, 15]));
    }

    #[test]
    fn concat_basic() {
        let left = f32_shape(&[4, 15, 384]);
        let right = f32_shape(&[4, 15, 384]);
        let out = concat_shape(&[&left, &right], 2).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 768]));
    }

    #[test]
    fn gather_embedding() {
        let table = f32_shape(&[30522, 384]);
        let indices = Shape::new(&[4, 15], DType::I64);
        let out = gather_shape(&table, &indices, 0).unwrap();
        let expected = Shape::from_dims(
            &[Dim::Static(4), Dim::Static(15), Dim::Static(384)],
            DType::F32,
        );
        assert_eq!(out, expected);
    }

    #[test]
    fn reshape_with_neg1() {
        let out = reshape_shape(&f32_shape(&[4, 15, 384]), &[60, -1]).unwrap();
        assert_eq!(out, f32_shape(&[60, 384]));
    }

    #[test]
    fn transpose_basic() {
        let out = transpose_shape(&f32_shape(&[4, 15, 384]), &[0, 2, 1]).unwrap();
        assert_eq!(out, f32_shape(&[4, 384, 15]));
    }

    #[test]
    fn narrow_basic() {
        let out = narrow_shape(&f32_shape(&[4, 15, 1152]), 2, 384).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 384]));
    }

    #[test]
    fn compare_bool_output() {
        let out = compare_shape(&f32_shape(&[4, 15]), &f32_shape(&[4, 15])).unwrap();
        assert_eq!(out.dtype(), DType::Bool);
        assert_eq!(out.rank(), 2);
    }

    #[test]
    fn dynamic_shape() {
        let shape = Shape::from_dims(
            &[Dim::Dynamic(0), Dim::Dynamic(1), Dim::Static(384)],
            DType::F32,
        );
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), None);
        assert!(!shape.is_static());
        assert_eq!(shape.to_string(), "[?0, ?1, 384] f32");
    }

    #[test]
    fn dynamic_symbols_lists_distinct_dims() {
        let shape = Shape::from_dims(
            &[
                Dim::Dynamic(1),
                Dim::Static(384),
                Dim::Dynamic(0),
                Dim::Dynamic(1),
            ],
            DType::F32,
        );
        assert_eq!(shape.dynamic_symbols(), vec![0, 1]);
    }

    #[test]
    fn bind_specializes_known_symbols() {
        let shape = Shape::from_dims(
            &[Dim::Dynamic(0), Dim::Dynamic(1), Dim::Static(384)],
            DType::F32,
        );
        let mut binding = DimBinding::new();
        binding.set(0, 8);
        binding.set(1, 64);
        let bound = shape.bind(&binding);
        assert!(bound.is_static());
        assert_eq!(bound.num_elements(), Some(8 * 64 * 384));
    }

    #[test]
    fn bind_leaves_unknown_symbols_alone() {
        let shape = Shape::from_dims(&[Dim::Dynamic(0), Dim::Dynamic(99)], DType::F32);
        let mut binding = DimBinding::new();
        binding.set(0, 4);
        let bound = shape.bind(&binding);
        assert!(!bound.is_static());
        assert_eq!(bound.dynamic_symbols(), vec![99]);
    }
}
876
/// Concrete sizes for dynamic dimension symbols, consumed by `Shape::bind`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DimBinding {
    /// Symbol id -> concrete size.
    map: std::collections::HashMap<u32, usize>,
}

impl DimBinding {
    /// Creates an empty binding.
    pub fn new() -> Self {
        Self::default()
    }

    /// Binds `symbol` to `size`, returning the size it was previously bound to.
    pub fn set(&mut self, symbol: u32, size: usize) -> Option<usize> {
        self.map.insert(symbol, size)
    }

    /// Size bound to `symbol`, if any.
    pub fn get(&self, symbol: u32) -> Option<usize> {
        self.map.get(&symbol).copied()
    }

    /// `true` when no symbol is bound.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Number of bound symbols.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Iterates over `(symbol, size)` pairs in unspecified (hash) order.
    pub fn iter(&self) -> impl Iterator<Item = (u32, usize)> + '_ {
        self.map.iter().map(|(&s, &n)| (s, n))
    }
}

/// Builds a binding from `(symbol, size)` pairs. As with repeated `set` calls,
/// a later pair for the same symbol overwrites an earlier one.
impl std::iter::FromIterator<(u32, usize)> for DimBinding {
    fn from_iter<I: IntoIterator<Item = (u32, usize)>>(iter: I) -> Self {
        Self {
            map: iter.into_iter().collect(),
        }
    }
}