1use crate::DType;
23use smallvec::SmallVec;
24
/// A single axis extent of a [`Shape`].
///
/// `Static` holds a size known when the shape is built. `Dynamic` holds a
/// symbol id instead of a size; shape functions in this module treat two
/// `Dynamic` axes with the same id as the same (unknown) size, and a
/// [`DimBinding`] can later substitute a concrete value for the symbol.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Dim {
    /// Concrete extent.
    Static(usize),
    /// Symbolic extent identified by a symbol id.
    Dynamic(u32),
}

impl Dim {
    /// Returns the concrete extent of a static axis.
    ///
    /// # Panics
    ///
    /// Panics if the axis is [`Dim::Dynamic`].
    pub fn unwrap_static(self) -> usize {
        match self {
            Self::Dynamic(sym) => panic!("expected static dim, got dynamic symbol {sym}"),
            Self::Static(size) => size,
        }
    }

    /// `true` when the extent is a concrete number.
    pub fn is_static(self) -> bool {
        !matches!(self, Self::Dynamic(_))
    }
}

impl From<usize> for Dim {
    /// Wraps a plain size as a static axis.
    fn from(size: usize) -> Self {
        Dim::Static(size)
    }
}

impl std::fmt::Display for Dim {
    /// Static axes print as the bare number; dynamic ones as `?<symbol>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Dim::Dynamic(sym) => write!(f, "?{sym}"),
            Dim::Static(size) => write!(f, "{size}"),
        }
    }
}
63
64#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
68#[derive(Debug, Clone, PartialEq, Eq, Hash)]
69pub struct Shape {
70 dims: SmallVec<[Dim; 4]>,
71 dtype: DType,
72}
73
74impl Shape {
75 pub fn new(dims: &[usize], dtype: DType) -> Self {
77 Self {
78 dims: dims.iter().map(|&d| Dim::Static(d)).collect(),
79 dtype,
80 }
81 }
82
83 pub fn from_dims(dims: &[Dim], dtype: DType) -> Self {
85 Self {
86 dims: dims.into(),
87 dtype,
88 }
89 }
90
91 pub fn scalar(dtype: DType) -> Self {
93 Self {
94 dims: SmallVec::new(),
95 dtype,
96 }
97 }
98
99 pub fn rank(&self) -> usize {
100 self.dims.len()
101 }
102 pub fn dtype(&self) -> DType {
103 self.dtype
104 }
105 pub fn dims(&self) -> &[Dim] {
106 &self.dims
107 }
108 pub fn dim(&self, i: usize) -> Dim {
109 self.dims[i]
110 }
111
112 pub fn dynamic_symbols(&self) -> Vec<u32> {
115 let mut syms: Vec<u32> = self
116 .dims
117 .iter()
118 .filter_map(|d| match d {
119 Dim::Dynamic(s) => Some(*s),
120 _ => None,
121 })
122 .collect();
123 syms.sort();
124 syms.dedup();
125 syms
126 }
127
128 pub fn bind(&self, bindings: &DimBinding) -> Self {
133 let dims = self
134 .dims
135 .iter()
136 .map(|d| match d {
137 Dim::Dynamic(s) => match bindings.get(*s) {
138 Some(n) => Dim::Static(n),
139 None => *d,
140 },
141 _ => *d,
142 })
143 .collect();
144 Self {
145 dims,
146 dtype: self.dtype,
147 }
148 }
149
150 pub fn num_elements(&self) -> Option<usize> {
152 let mut total = 1usize;
153 for d in &self.dims {
154 match d {
155 Dim::Static(n) => total = total.checked_mul(*n)?,
156 Dim::Dynamic(_) => return None,
157 }
158 }
159 Some(total)
160 }
161
162 pub fn size_bytes(&self) -> Option<usize> {
164 self.num_elements().map(|n| n * self.dtype.size_bytes())
165 }
166
167 pub fn is_static(&self) -> bool {
169 self.dims.iter().all(|d| d.is_static())
170 }
171
172 pub fn with_dim(mut self, axis: usize, dim: Dim) -> Self {
174 self.dims[axis] = dim;
175 self
176 }
177
178 pub fn with_dtype(mut self, dtype: DType) -> Self {
180 self.dtype = dtype;
181 self
182 }
183
184 pub fn broadcast_with(&self, other: &Shape) -> Result<Shape, String> {
186 broadcast(self, other)
187 }
188}
189
190pub fn broadcast(a: &Shape, b: &Shape) -> Result<Shape, String> {
194 let max_rank = a.rank().max(b.rank());
195 let mut dims = SmallVec::new();
196 for i in 0..max_rank {
197 let ad = if i < max_rank - a.rank() {
198 Dim::Static(1)
199 } else {
200 a.dims[i - (max_rank - a.rank())]
201 };
202 let bd = if i < max_rank - b.rank() {
203 Dim::Static(1)
204 } else {
205 b.dims[i - (max_rank - b.rank())]
206 };
207 let d = broadcast_dim(ad, bd)?;
208 dims.push(d);
209 }
210 Ok(Shape {
211 dims,
212 dtype: a.dtype,
213 })
214}
215
216fn broadcast_dim(a: Dim, b: Dim) -> Result<Dim, String> {
217 match (a, b) {
218 (Dim::Static(1), d) | (d, Dim::Static(1)) => Ok(d),
219 (Dim::Static(x), Dim::Static(y)) if x == y => Ok(Dim::Static(x)),
220 (Dim::Static(x), Dim::Static(y)) => Err(format!("cannot broadcast {x} with {y}")),
221 (Dim::Dynamic(s), Dim::Dynamic(t)) if s == t => Ok(Dim::Dynamic(s)),
222 (Dim::Dynamic(_), _) | (_, Dim::Dynamic(_)) => Ok(a), }
224}
225
226pub fn matmul_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
228 if lhs.rank() < 2 || rhs.rank() < 2 {
229 return Err(format!(
230 "matmul requires rank >= 2, got {} and {}",
231 lhs.rank(),
232 rhs.rank()
233 ));
234 }
235 let m = lhs.dims[lhs.rank() - 2];
236 let k1 = lhs.dims[lhs.rank() - 1];
237 let k2 = rhs.dims[rhs.rank() - 2];
238 let n = rhs.dims[rhs.rank() - 1];
239
240 match (k1, k2) {
242 (Dim::Static(a), Dim::Static(b)) if a != b => {
243 return Err(format!("matmul K mismatch: {a} vs {b}"));
244 }
245 (Dim::Dynamic(s), Dim::Dynamic(t)) if s != t => {
246 return Err(format!("matmul K mismatch: ?{s} vs ?{t}"));
247 }
248 _ => {}
249 }
250
251 let lhs_batch = &lhs.dims[..lhs.rank() - 2];
253 let rhs_batch = &rhs.dims[..rhs.rank() - 2];
254 let batch_a = Shape::from_dims(lhs_batch, lhs.dtype);
255 let batch_b = Shape::from_dims(rhs_batch, rhs.dtype);
256 let batch = if lhs_batch.is_empty() && rhs_batch.is_empty() {
257 SmallVec::new()
258 } else if lhs_batch.is_empty() {
259 rhs_batch.into()
260 } else if rhs_batch.is_empty() {
261 lhs_batch.into()
262 } else {
263 broadcast(&batch_a, &batch_b)?.dims.clone()
264 };
265
266 let mut dims = batch;
267 dims.push(m);
268 dims.push(n);
269 Ok(Shape {
270 dims,
271 dtype: lhs.dtype,
272 })
273}
274
275pub fn binary_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
277 broadcast(lhs, rhs)
278}
279
280pub fn unary_shape(input: &Shape) -> Shape {
282 input.clone()
283}
284
285pub fn cast_shape(input: &Shape, to: DType) -> Shape {
287 input.clone().with_dtype(to)
288}
289
290pub fn compare_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
292 Ok(broadcast(lhs, rhs)?.with_dtype(DType::Bool))
293}
294
295pub fn reduce_shape(input: &Shape, axes: &[usize], keep_dim: bool) -> Result<Shape, String> {
297 let mut dims = SmallVec::new();
298 for (i, &d) in input.dims.iter().enumerate() {
299 if axes.contains(&i) {
300 if keep_dim {
301 dims.push(Dim::Static(1));
302 }
303 } else {
304 dims.push(d);
305 }
306 }
307 Ok(Shape {
308 dims,
309 dtype: input.dtype,
310 })
311}
312
313pub fn softmax_shape(input: &Shape) -> Shape {
315 input.clone()
316}
317
318pub fn transpose_shape(input: &Shape, perm: &[usize]) -> Result<Shape, String> {
320 if perm.len() != input.rank() {
321 return Err(format!("perm len {} != rank {}", perm.len(), input.rank()));
322 }
323 let dims: SmallVec<[Dim; 4]> = perm.iter().map(|&i| input.dims[i]).collect();
324 Ok(Shape {
325 dims,
326 dtype: input.dtype,
327 })
328}
329
330pub fn narrow_shape(input: &Shape, axis: usize, len: usize) -> Result<Shape, String> {
332 if axis >= input.rank() {
333 return Err(format!("axis {axis} >= rank {}", input.rank()));
334 }
335 Ok(input.clone().with_dim(axis, Dim::Static(len)))
336}
337
338pub fn concat_shape(inputs: &[&Shape], axis: usize) -> Result<Shape, String> {
340 if inputs.is_empty() {
341 return Err("concat: no inputs".into());
342 }
343 let base = inputs[0];
344 let mut static_sum = 0usize;
345 let mut dyn_sym: Option<u32> = None;
346 for s in inputs {
347 if s.rank() != base.rank() {
348 return Err(format!(
349 "concat: rank mismatch {} vs {}",
350 s.rank(),
351 base.rank()
352 ));
353 }
354 match s.dims[axis] {
355 Dim::Static(n) => static_sum += n,
356 Dim::Dynamic(sym) => {
357 if let Some(prev) = dyn_sym {
358 if prev != sym {
359 return Err(format!(
360 "concat: mismatched dynamic symbols {prev} vs {sym} on axis {axis}"
361 ));
362 }
363 }
364 dyn_sym = Some(sym);
365 }
366 }
367 }
368 let out_dim = match dyn_sym {
369 None => Dim::Static(static_sum),
370 Some(sym) if static_sum == 0 => Dim::Dynamic(sym),
371 Some(sym) => {
372 let _ = static_sum;
375 Dim::Dynamic(sym)
376 }
377 };
378 Ok(base.clone().with_dim(axis, out_dim))
379}
380
381pub fn gather_shape(table: &Shape, indices: &Shape, axis: usize) -> Result<Shape, String> {
383 if axis >= table.rank() {
384 return Err(format!("gather: axis {axis} >= rank {}", table.rank()));
385 }
386 let mut dims: SmallVec<[Dim; 4]> = indices.dims.clone();
387 for i in (axis + 1)..table.rank() {
388 dims.push(table.dims[i]);
389 }
390 Ok(Shape {
391 dims,
392 dtype: table.dtype,
393 })
394}
395
396pub fn reshape_shape(input: &Shape, new_shape: &[i64]) -> Result<Shape, String> {
398 let neg_count = new_shape.iter().filter(|&&d| d == -1).count();
399 if neg_count > 1 {
400 return Err("reshape: at most one -1".into());
401 }
402
403 if input.is_static() {
404 let total = input
405 .num_elements()
406 .ok_or_else(|| "reshape: input has dynamic dims".to_string())?;
407 let known_product: i64 = new_shape.iter().filter(|&&d| d != -1).product();
408 let mut dims = SmallVec::new();
409 for &d in new_shape {
410 if d == -1 {
411 let inferred = total as i64 / known_product;
412 dims.push(Dim::Static(inferred as usize));
413 } else if d < 0 {
414 return Err(format!("reshape: invalid dim {d}"));
415 } else {
416 dims.push(Dim::Static(d as usize));
417 }
418 }
419 return Ok(Shape {
420 dims,
421 dtype: input.dtype,
422 });
423 }
424
425 let dyn_syms = input.dynamic_symbols();
428 let neg_idx = new_shape.iter().position(|&d| d == -1);
429 let mut out_dims: SmallVec<[Dim; 4]> = SmallVec::new();
430 for (i, &d) in new_shape.iter().enumerate() {
431 if Some(i) == neg_idx {
432 continue;
433 }
434 if d < 0 {
435 return Err(format!("reshape: invalid dim {d}"));
436 }
437 out_dims.push(Dim::Static(d as usize));
438 }
439 if let Some(ni) = neg_idx {
440 let inferred = if dyn_syms.len() == 1 {
441 Dim::Dynamic(dyn_syms[0])
442 } else if dyn_syms.is_empty() {
443 return Err("reshape: cannot infer -1 on static input".into());
444 } else {
445 Dim::Dynamic(crate::dynamic::sym::ROWS)
446 };
447 out_dims.insert(ni, inferred);
448 }
449 Ok(Shape {
450 dims: out_dims,
451 dtype: input.dtype,
452 })
453}
454
455pub fn leading_flatten_fused_shape(input: &Shape) -> Option<Shape> {
457 if input.rank() < 2 {
458 return None;
459 }
460 let Dim::Static(h) = input.dim(input.rank() - 1) else {
461 return None;
462 };
463 let leading = &input.dims()[..input.rank() - 1];
464 let lead_dim = if leading.iter().all(|d| d.is_static()) {
465 Dim::Static(leading.iter().map(|d| d.unwrap_static()).product::<usize>())
466 } else {
467 let mut syms: Vec<u32> = leading
468 .iter()
469 .filter_map(|d| match d {
470 Dim::Dynamic(s) => Some(*s),
471 _ => None,
472 })
473 .collect();
474 syms.sort();
475 syms.dedup();
476 match syms.len() {
477 0 => Dim::Static(leading.iter().map(|d| d.unwrap_static()).product::<usize>()),
478 1 => Dim::Dynamic(syms[0]),
479 _ => Dim::Dynamic(crate::dynamic::sym::ROWS),
480 }
481 };
482 Some(Shape::from_dims(&[lead_dim, Dim::Static(h)], input.dtype()))
483}
484
485pub fn leading_flatten_shape(input: &Shape, new_shape: &[i64]) -> Option<Shape> {
487 if new_shape.len() != 2 {
488 return None;
489 }
490 let flat = leading_flatten_fused_shape(input)?;
491 let Dim::Static(h) = input.dim(input.rank() - 1) else {
492 return None;
493 };
494 if new_shape[1] as usize != h {
495 return None;
496 }
497 match flat.dim(0) {
498 Dim::Static(lead) if new_shape[0] as usize == lead => Some(flat),
499 Dim::Dynamic(_) if new_shape[0] == -1 => Some(flat),
500 _ => None,
501 }
502}
503
504pub fn attention_shape(q: &Shape) -> Shape {
506 q.clone()
507}
508
509impl std::fmt::Display for Shape {
510 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
511 write!(f, "[")?;
512 for (i, d) in self.dims.iter().enumerate() {
513 if i > 0 {
514 write!(f, ", ")?;
515 }
516 write!(f, "{d}")?;
517 }
518 write!(f, "] {}", self.dtype)
519 }
520}
521
/// Output extent of one spatial axis of a 2-D convolution.
///
/// Computes `floor((in + 2*padding - dilation*(kernel-1) - 1) / stride) + 1`,
/// clamped to 0 when the dilated kernel does not fit in the padded input.
/// The previous saturating arithmetic reported 1 in that case.
///
/// # Panics
///
/// Panics if `stride == 0`.
pub fn conv2d_spatial_output(
    in_size: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    assert!(stride > 0, "conv2d: stride must be non-zero");
    // Span of input covered by one application of the dilated kernel.
    let receptive = dilation
        .saturating_mul(kernel.saturating_sub(1))
        .saturating_add(1);
    let padded = in_size.saturating_add(padding.saturating_mul(2));
    if padded < receptive {
        // The kernel cannot be placed even once.
        return 0;
    }
    (padded - receptive) / stride + 1
}
537
/// Output extent of one spatial axis of a 2-D transposed convolution.
///
/// Computes
/// `(in - 1)*stride - 2*padding + dilation*(kernel-1) + output_padding + 1`,
/// clamped to 0 instead of underflowing. Previously `in_size == 0` or a
/// large `padding` panicked in debug builds and wrapped in release builds.
pub fn conv_transpose2d_spatial_output(
    in_size: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    output_padding: usize,
) -> usize {
    let Some(last) = in_size.checked_sub(1) else {
        // NOTE(review): an empty input axis maps to an empty output axis;
        // confirm this matches the runtime kernels.
        return 0;
    };
    let dil_k = dilation.saturating_mul(kernel.saturating_sub(1));
    // Add every positive term first, so only the padding subtraction can
    // go below zero, and clamp that.
    let full = last
        .saturating_mul(stride)
        .saturating_add(output_padding)
        .saturating_add(dil_k)
        .saturating_add(1);
    full.saturating_sub(padding.saturating_mul(2))
}
550
551pub fn conv2d_output_shape(
553 input: &Shape,
554 weight: &Shape,
555 kernel_size: [usize; 2],
556 stride: [usize; 2],
557 padding: [usize; 2],
558 dilation: [usize; 2],
559 groups: usize,
560) -> Result<Shape, String> {
561 if input.rank() != 4 || weight.rank() != 4 {
562 return Err("conv2d requires NCHW input and 4-D weight".into());
563 }
564 let n = input.dim(0).unwrap_static();
565 let c_in = input.dim(1).unwrap_static();
566 let h = input.dim(2).unwrap_static();
567 let w = input.dim(3).unwrap_static();
568 let c_out = weight.dim(0).unwrap_static();
569 let w_cin = weight.dim(1).unwrap_static();
570 if w_cin * groups != c_in {
571 return Err(format!(
572 "conv2d weight C_in/g={w_cin} * groups={groups} != input C={c_in}"
573 ));
574 }
575 let h_out = conv2d_spatial_output(h, kernel_size[0], stride[0], padding[0], dilation[0]);
576 let w_out = conv2d_spatial_output(w, kernel_size[1], stride[1], padding[1], dilation[1]);
577 Ok(Shape::new(&[n, c_out, h_out, w_out], input.dtype()))
578}
579
580pub fn conv_transpose2d_output_shape(
582 input: &Shape,
583 weight: &Shape,
584 kernel_size: [usize; 2],
585 stride: [usize; 2],
586 padding: [usize; 2],
587 dilation: [usize; 2],
588 output_padding: [usize; 2],
589 groups: usize,
590) -> Result<Shape, String> {
591 if input.rank() != 4 || weight.rank() != 4 {
592 return Err("conv_transpose2d requires NCHW input and 4-D weight".into());
593 }
594 let n = input.dim(0).unwrap_static();
595 let c_in = input.dim(1).unwrap_static();
596 let h = input.dim(2).unwrap_static();
597 let w = input.dim(3).unwrap_static();
598 let w_cin = weight.dim(0).unwrap_static();
599 let c_out_per_g = weight.dim(1).unwrap_static();
600 if w_cin != c_in {
601 return Err(format!(
602 "conv_transpose2d weight C_in={w_cin} != input C={c_in}"
603 ));
604 }
605 let h_out = conv_transpose2d_spatial_output(
606 h,
607 kernel_size[0],
608 stride[0],
609 padding[0],
610 dilation[0],
611 output_padding[0],
612 );
613 let w_out = conv_transpose2d_spatial_output(
614 w,
615 kernel_size[1],
616 stride[1],
617 padding[1],
618 dilation[1],
619 output_padding[1],
620 );
621 Ok(Shape::new(
622 &[n, c_out_per_g * groups, h_out, w_out],
623 input.dtype(),
624 ))
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully static `f32` shape.
    fn f32s(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    /// The `[4, 15, 384]` activation shape shared by most cases below.
    fn activation() -> Shape {
        f32s(&[4, 15, 384])
    }

    #[test]
    fn static_shape() {
        let s = activation();
        let elems = 4 * 15 * 384;
        assert_eq!(s.rank(), 3);
        assert_eq!(s.num_elements(), Some(elems));
        assert_eq!(s.size_bytes(), Some(elems * 4));
        assert!(s.is_static());
        assert_eq!(s.to_string(), "[4, 15, 384] f32");
    }

    #[test]
    fn broadcast_same() {
        let a = activation();
        let out = broadcast(&a, &a).unwrap();
        assert_eq!(out.dims(), a.dims());
    }

    #[test]
    fn broadcast_bias() {
        let out = broadcast(&activation(), &f32s(&[384])).unwrap();
        assert_eq!(out, activation());
    }

    #[test]
    fn broadcast_scalar() {
        let a = activation();
        assert_eq!(broadcast(&a, &Shape::scalar(DType::F32)).unwrap(), a);
    }

    #[test]
    fn broadcast_mismatch() {
        assert!(broadcast(&activation(), &f32s(&[4, 15, 256])).is_err());
    }

    #[test]
    fn matmul_basic() {
        let out = matmul_shape(&activation(), &f32s(&[384, 1536])).unwrap();
        assert_eq!(out, f32s(&[4, 15, 1536]));
    }

    #[test]
    fn matmul_batched() {
        let out = matmul_shape(&activation(), &f32s(&[4, 384, 1536])).unwrap();
        assert_eq!(out, f32s(&[4, 15, 1536]));
    }

    #[test]
    fn matmul_k_mismatch() {
        assert!(matmul_shape(&activation(), &f32s(&[512, 1536])).is_err());
    }

    #[test]
    fn reduce_keepdim() {
        let out = reduce_shape(&activation(), &[2], true).unwrap();
        assert_eq!(out, f32s(&[4, 15, 1]));
    }

    #[test]
    fn reduce_no_keepdim() {
        let out = reduce_shape(&activation(), &[2], false).unwrap();
        assert_eq!(out, f32s(&[4, 15]));
    }

    #[test]
    fn concat_basic() {
        let (a, b) = (activation(), activation());
        assert_eq!(concat_shape(&[&a, &b], 2).unwrap(), f32s(&[4, 15, 768]));
    }

    #[test]
    fn gather_embedding() {
        let table = f32s(&[30522, 384]);
        let indices = Shape::new(&[4, 15], DType::I64);
        let expected = Shape::from_dims(
            &[Dim::Static(4), Dim::Static(15), Dim::Static(384)],
            DType::F32,
        );
        assert_eq!(gather_shape(&table, &indices, 0).unwrap(), expected);
    }

    #[test]
    fn reshape_with_neg1() {
        let out = reshape_shape(&activation(), &[60, -1]).unwrap();
        assert_eq!(out, f32s(&[60, 384]));
    }

    #[test]
    fn transpose_basic() {
        let out = transpose_shape(&activation(), &[0, 2, 1]).unwrap();
        assert_eq!(out, f32s(&[4, 384, 15]));
    }

    #[test]
    fn narrow_basic() {
        let fused = f32s(&[4, 15, 1152]);
        assert_eq!(narrow_shape(&fused, 2, 384).unwrap(), activation());
    }

    #[test]
    fn compare_bool_output() {
        let s = f32s(&[4, 15]);
        let out = compare_shape(&s, &s).unwrap();
        assert_eq!(out.dtype(), DType::Bool);
        assert_eq!(out.rank(), 2);
    }

    #[test]
    fn dynamic_shape() {
        let s = Shape::from_dims(
            &[Dim::Dynamic(0), Dim::Dynamic(1), Dim::Static(384)],
            DType::F32,
        );
        assert_eq!(s.rank(), 3);
        assert_eq!(s.num_elements(), None);
        assert!(!s.is_static());
        assert_eq!(s.to_string(), "[?0, ?1, 384] f32");
    }

    #[test]
    fn dynamic_symbols_lists_distinct_dims() {
        let s = Shape::from_dims(
            &[
                Dim::Dynamic(1),
                Dim::Static(384),
                Dim::Dynamic(0),
                Dim::Dynamic(1),
            ],
            DType::F32,
        );
        assert_eq!(s.dynamic_symbols(), vec![0, 1]);
    }

    #[test]
    fn bind_specializes_known_symbols() {
        let s = Shape::from_dims(
            &[Dim::Dynamic(0), Dim::Dynamic(1), Dim::Static(384)],
            DType::F32,
        );
        let mut binding = DimBinding::new();
        binding.set(0, 8);
        binding.set(1, 64);
        let bound = s.bind(&binding);
        assert!(bound.is_static());
        assert_eq!(bound.num_elements(), Some(8 * 64 * 384));
    }

    #[test]
    fn bind_leaves_unknown_symbols_alone() {
        let s = Shape::from_dims(&[Dim::Dynamic(0), Dim::Dynamic(99)], DType::F32);
        let mut binding = DimBinding::new();
        binding.set(0, 4);
        let bound = s.bind(&binding);
        assert!(!bound.is_static());
        assert_eq!(bound.dynamic_symbols(), vec![99]);
    }
}
814
/// Assignment of concrete sizes to dynamic-dimension symbols, consumed by
/// `Shape::bind`.
#[derive(Debug, Clone, Default)]
pub struct DimBinding {
    // symbol id -> bound size
    map: std::collections::HashMap<u32, usize>,
}

impl DimBinding {
    /// Creates a binding with no symbols assigned.
    pub fn new() -> Self {
        Self {
            map: std::collections::HashMap::new(),
        }
    }

    /// Binds `symbol` to `size` and returns the size it was previously
    /// bound to, if any.
    pub fn set(&mut self, symbol: u32, size: usize) -> Option<usize> {
        self.map.insert(symbol, size)
    }

    /// Size bound to `symbol`, if any.
    pub fn get(&self, symbol: u32) -> Option<usize> {
        self.map.get(&symbol).map(|size| *size)
    }

    /// `true` when no symbol is bound.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Number of bound symbols.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Iterates over `(symbol, size)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (u32, usize)> + '_ {
        self.map.iter().map(|(sym, size)| (*sym, *size))
    }
}