Skip to main content

rlx_ir/
shape.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tensor shapes with static and dynamic dimensions.
17//!
18//! Shapes are first-class in RLX IR — every node's output shape is known
19//! (or symbolically bounded) at graph construction time. This enables
20//! buffer size computation for memory planning.
21
22use crate::DType;
23use smallvec::SmallVec;
24
/// One axis extent of a tensor: either a fixed size or a named runtime unknown.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Dim {
    /// Extent fixed when the graph is built.
    Static(usize),
    /// Extent resolved only at runtime. The payload is a symbol index, so
    /// `Dim::Dynamic(0)` appearing in two shapes denotes the same unknown size.
    Dynamic(u32),
}

impl Dim {
    /// Returns the concrete extent of a static dim.
    ///
    /// # Panics
    /// Panics if the dim is [`Dim::Dynamic`].
    pub fn unwrap_static(self) -> usize {
        let sym = match self {
            Self::Static(size) => return size,
            Self::Dynamic(sym) => sym,
        };
        panic!("expected static dim, got dynamic symbol {sym}");
    }

    /// `true` for [`Dim::Static`], `false` for [`Dim::Dynamic`].
    pub fn is_static(self) -> bool {
        match self {
            Self::Static(_) => true,
            Self::Dynamic(_) => false,
        }
    }
}

/// A plain `usize` converts to a static dim of that size.
impl From<usize> for Dim {
    fn from(size: usize) -> Self {
        Dim::Static(size)
    }
}

/// Static dims print as the bare number, dynamic dims as `?<symbol>`.
impl std::fmt::Display for Dim {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Dim::Static(size) => write!(f, "{size}"),
            Dim::Dynamic(sym) => write!(f, "?{sym}"),
        }
    }
}
63
64/// Tensor shape: ordered list of dimensions + element type.
65///
66/// SmallVec<[Dim; 4]> avoids heap allocation for up to 4D tensors (the common case).
67#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
68#[derive(Debug, Clone, PartialEq, Eq, Hash)]
69pub struct Shape {
70    dims: SmallVec<[Dim; 4]>,
71    dtype: DType,
72}
73
74impl Shape {
75    /// Create a shape from static dimensions.
76    pub fn new(dims: &[usize], dtype: DType) -> Self {
77        Self {
78            dims: dims.iter().map(|&d| Dim::Static(d)).collect(),
79            dtype,
80        }
81    }
82
83    /// Create a shape with mixed static/dynamic dimensions.
84    pub fn from_dims(dims: &[Dim], dtype: DType) -> Self {
85        Self {
86            dims: dims.into(),
87            dtype,
88        }
89    }
90
91    /// Scalar (0-dimensional).
92    pub fn scalar(dtype: DType) -> Self {
93        Self {
94            dims: SmallVec::new(),
95            dtype,
96        }
97    }
98
99    pub fn rank(&self) -> usize {
100        self.dims.len()
101    }
102    pub fn dtype(&self) -> DType {
103        self.dtype
104    }
105    pub fn dims(&self) -> &[Dim] {
106        &self.dims
107    }
108    pub fn dim(&self, i: usize) -> Dim {
109        self.dims[i]
110    }
111
112    /// Set of dynamic dim symbols this shape references. Useful for
113    /// "what bindings does this graph need?" queries on inputs.
114    pub fn dynamic_symbols(&self) -> Vec<u32> {
115        let mut syms: Vec<u32> = self
116            .dims
117            .iter()
118            .filter_map(|d| match d {
119                Dim::Dynamic(s) => Some(*s),
120                _ => None,
121            })
122            .collect();
123        syms.sort();
124        syms.dedup();
125        syms
126    }
127
128    /// Specialize the shape against a binding (`symbol → static
129    /// size`). Unknown symbols stay [`Dim::Dynamic`]. Plan #54: the
130    /// step that takes a "compile once, run at any seq length" graph
131    /// and produces the runtime-specific concrete shape.
132    pub fn bind(&self, bindings: &DimBinding) -> Self {
133        let dims = self
134            .dims
135            .iter()
136            .map(|d| match d {
137                Dim::Dynamic(s) => match bindings.get(*s) {
138                    Some(n) => Dim::Static(n),
139                    None => *d,
140                },
141                _ => *d,
142            })
143            .collect();
144        Self {
145            dims,
146            dtype: self.dtype,
147        }
148    }
149
150    /// Total number of elements (only if all dims are static).
151    pub fn num_elements(&self) -> Option<usize> {
152        let mut total = 1usize;
153        for d in &self.dims {
154            match d {
155                Dim::Static(n) => total = total.checked_mul(*n)?,
156                Dim::Dynamic(_) => return None,
157            }
158        }
159        Some(total)
160    }
161
162    /// Total size in bytes (only if all dims are static).
163    pub fn size_bytes(&self) -> Option<usize> {
164        self.num_elements().map(|n| n * self.dtype.size_bytes())
165    }
166
167    /// True if all dimensions are statically known.
168    pub fn is_static(&self) -> bool {
169        self.dims.iter().all(|d| d.is_static())
170    }
171
172    /// Replace a dimension.
173    pub fn with_dim(mut self, axis: usize, dim: Dim) -> Self {
174        self.dims[axis] = dim;
175        self
176    }
177
178    /// Change dtype (for cast operations).
179    pub fn with_dtype(mut self, dtype: DType) -> Self {
180        self.dtype = dtype;
181        self
182    }
183
184    /// Numpy-style broadcast with another shape (fusion / lowering).
185    pub fn broadcast_with(&self, other: &Shape) -> Result<Shape, String> {
186        broadcast(self, other)
187    }
188}
189
190// ── Shape inference functions ────────────────────────────────────────────
191
192/// Numpy-style broadcast of two shapes. Returns the broadcast result.
193pub fn broadcast(a: &Shape, b: &Shape) -> Result<Shape, String> {
194    let max_rank = a.rank().max(b.rank());
195    let mut dims = SmallVec::new();
196    for i in 0..max_rank {
197        let ad = if i < max_rank - a.rank() {
198            Dim::Static(1)
199        } else {
200            a.dims[i - (max_rank - a.rank())]
201        };
202        let bd = if i < max_rank - b.rank() {
203            Dim::Static(1)
204        } else {
205            b.dims[i - (max_rank - b.rank())]
206        };
207        let d = broadcast_dim(ad, bd)?;
208        dims.push(d);
209    }
210    Ok(Shape {
211        dims,
212        dtype: a.dtype,
213    })
214}
215
216fn broadcast_dim(a: Dim, b: Dim) -> Result<Dim, String> {
217    match (a, b) {
218        (Dim::Static(1), d) | (d, Dim::Static(1)) => Ok(d),
219        (Dim::Static(x), Dim::Static(y)) if x == y => Ok(Dim::Static(x)),
220        (Dim::Static(x), Dim::Static(y)) => Err(format!("cannot broadcast {x} with {y}")),
221        (Dim::Dynamic(s), Dim::Dynamic(t)) if s == t => Ok(Dim::Dynamic(s)),
222        (Dim::Dynamic(_), _) | (_, Dim::Dynamic(_)) => Ok(a), // keep first dynamic
223    }
224}
225
226/// MatMul output shape: `[..,M,K] × [..,K,N] → [..,M,N]`.
227pub fn matmul_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
228    if lhs.rank() < 2 || rhs.rank() < 2 {
229        return Err(format!(
230            "matmul requires rank >= 2, got {} and {}",
231            lhs.rank(),
232            rhs.rank()
233        ));
234    }
235    let m = lhs.dims[lhs.rank() - 2];
236    let k1 = lhs.dims[lhs.rank() - 1];
237    let k2 = rhs.dims[rhs.rank() - 2];
238    let n = rhs.dims[rhs.rank() - 1];
239
240    // Verify K dimensions match
241    match (k1, k2) {
242        (Dim::Static(a), Dim::Static(b)) if a != b => {
243            return Err(format!("matmul K mismatch: {a} vs {b}"));
244        }
245        (Dim::Dynamic(s), Dim::Dynamic(t)) if s != t => {
246            return Err(format!("matmul K mismatch: ?{s} vs ?{t}"));
247        }
248        _ => {}
249    }
250
251    // Broadcast batch dimensions
252    let lhs_batch = &lhs.dims[..lhs.rank() - 2];
253    let rhs_batch = &rhs.dims[..rhs.rank() - 2];
254    let batch_a = Shape::from_dims(lhs_batch, lhs.dtype);
255    let batch_b = Shape::from_dims(rhs_batch, rhs.dtype);
256    let batch = if lhs_batch.is_empty() && rhs_batch.is_empty() {
257        SmallVec::new()
258    } else if lhs_batch.is_empty() {
259        rhs_batch.into()
260    } else if rhs_batch.is_empty() {
261        lhs_batch.into()
262    } else {
263        broadcast(&batch_a, &batch_b)?.dims.clone()
264    };
265
266    let mut dims = batch;
267    dims.push(m);
268    dims.push(n);
269    Ok(Shape {
270        dims,
271        dtype: lhs.dtype,
272    })
273}
274
275/// Binary element-wise shape (broadcast).
276pub fn binary_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
277    broadcast(lhs, rhs)
278}
279
280/// Unary op: output = input shape.
281pub fn unary_shape(input: &Shape) -> Shape {
282    input.clone()
283}
284
285/// Cast: change dtype, keep shape.
286pub fn cast_shape(input: &Shape, to: DType) -> Shape {
287    input.clone().with_dtype(to)
288}
289
290/// Compare: broadcast + Bool dtype.
291pub fn compare_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
292    Ok(broadcast(lhs, rhs)?.with_dtype(DType::Bool))
293}
294
295/// Reduce along axes.
296pub fn reduce_shape(input: &Shape, axes: &[usize], keep_dim: bool) -> Result<Shape, String> {
297    let mut dims = SmallVec::new();
298    for (i, &d) in input.dims.iter().enumerate() {
299        if axes.contains(&i) {
300            if keep_dim {
301                dims.push(Dim::Static(1));
302            }
303        } else {
304            dims.push(d);
305        }
306    }
307    Ok(Shape {
308        dims,
309        dtype: input.dtype,
310    })
311}
312
313/// Softmax: preserves shape.
314pub fn softmax_shape(input: &Shape) -> Shape {
315    input.clone()
316}
317
318/// Transpose: permute dims.
319pub fn transpose_shape(input: &Shape, perm: &[usize]) -> Result<Shape, String> {
320    if perm.len() != input.rank() {
321        return Err(format!("perm len {} != rank {}", perm.len(), input.rank()));
322    }
323    let dims: SmallVec<[Dim; 4]> = perm.iter().map(|&i| input.dims[i]).collect();
324    Ok(Shape {
325        dims,
326        dtype: input.dtype,
327    })
328}
329
330/// Narrow: slice along one axis.
331pub fn narrow_shape(input: &Shape, axis: usize, len: usize) -> Result<Shape, String> {
332    if axis >= input.rank() {
333        return Err(format!("axis {axis} >= rank {}", input.rank()));
334    }
335    Ok(input.clone().with_dim(axis, Dim::Static(len)))
336}
337
338/// Concat along axis.
339pub fn concat_shape(inputs: &[&Shape], axis: usize) -> Result<Shape, String> {
340    if inputs.is_empty() {
341        return Err("concat: no inputs".into());
342    }
343    let base = inputs[0];
344    let mut static_sum = 0usize;
345    let mut dyn_sym: Option<u32> = None;
346    for s in inputs {
347        if s.rank() != base.rank() {
348            return Err(format!(
349                "concat: rank mismatch {} vs {}",
350                s.rank(),
351                base.rank()
352            ));
353        }
354        match s.dims[axis] {
355            Dim::Static(n) => static_sum += n,
356            Dim::Dynamic(sym) => {
357                if let Some(prev) = dyn_sym {
358                    if prev != sym {
359                        return Err(format!(
360                            "concat: mismatched dynamic symbols {prev} vs {sym} on axis {axis}"
361                        ));
362                    }
363                }
364                dyn_sym = Some(sym);
365            }
366        }
367    }
368    let out_dim = match dyn_sym {
369        None => Dim::Static(static_sum),
370        Some(sym) if static_sum == 0 => Dim::Dynamic(sym),
371        Some(sym) => {
372            // Mixed static + dynamic (e.g. conv_state || qkv). After `bind_graph`,
373            // `sync_concat_shapes` recomputes from concrete input shapes.
374            let _ = static_sum;
375            Dim::Dynamic(sym)
376        }
377    };
378    Ok(base.clone().with_dim(axis, out_dim))
379}
380
381/// Gather (embedding lookup): table\[V,D\] + indices\[B,S\] → \[B,S,D\].
382pub fn gather_shape(table: &Shape, indices: &Shape, axis: usize) -> Result<Shape, String> {
383    if axis >= table.rank() {
384        return Err(format!("gather: axis {axis} >= rank {}", table.rank()));
385    }
386    let mut dims: SmallVec<[Dim; 4]> = indices.dims.clone();
387    for i in (axis + 1)..table.rank() {
388        dims.push(table.dims[i]);
389    }
390    Ok(Shape {
391        dims,
392        dtype: table.dtype,
393    })
394}
395
396/// Reshape with -1 wildcard support.
397pub fn reshape_shape(input: &Shape, new_shape: &[i64]) -> Result<Shape, String> {
398    let neg_count = new_shape.iter().filter(|&&d| d == -1).count();
399    if neg_count > 1 {
400        return Err("reshape: at most one -1".into());
401    }
402
403    if input.is_static() {
404        let total = input
405            .num_elements()
406            .ok_or_else(|| "reshape: input has dynamic dims".to_string())?;
407        let known_product: i64 = new_shape.iter().filter(|&&d| d != -1).product();
408        let mut dims = SmallVec::new();
409        for &d in new_shape {
410            if d == -1 {
411                let inferred = total as i64 / known_product;
412                dims.push(Dim::Static(inferred as usize));
413            } else if d < 0 {
414                return Err(format!("reshape: invalid dim {d}"));
415            } else {
416                dims.push(Dim::Static(d as usize));
417            }
418        }
419        return Ok(Shape {
420            dims,
421            dtype: input.dtype,
422        });
423    }
424
425    // Symbolic input: map `-1` to the sole dynamic symbol when unambiguous
426    // (qwen35 prefill with batch=1 and `sym::SEQ`), otherwise keep dynamic.
427    let dyn_syms = input.dynamic_symbols();
428    let neg_idx = new_shape.iter().position(|&d| d == -1);
429    let mut out_dims: SmallVec<[Dim; 4]> = SmallVec::new();
430    for (i, &d) in new_shape.iter().enumerate() {
431        if Some(i) == neg_idx {
432            continue;
433        }
434        if d < 0 {
435            return Err(format!("reshape: invalid dim {d}"));
436        }
437        out_dims.push(Dim::Static(d as usize));
438    }
439    if let Some(ni) = neg_idx {
440        let inferred = if dyn_syms.len() == 1 {
441            Dim::Dynamic(dyn_syms[0])
442        } else if dyn_syms.is_empty() {
443            return Err("reshape: cannot infer -1 on static input".into());
444        } else {
445            Dim::Dynamic(crate::dynamic::sym::ROWS)
446        };
447        out_dims.insert(ni, inferred);
448    }
449    Ok(Shape {
450        dims: out_dims,
451        dtype: input.dtype,
452    })
453}
454
455/// Flatten leading axes to `[∏leading, H]` — used by `FuseRmsNormReshape` and shape verify.
456pub fn leading_flatten_fused_shape(input: &Shape) -> Option<Shape> {
457    if input.rank() < 2 {
458        return None;
459    }
460    let Dim::Static(h) = input.dim(input.rank() - 1) else {
461        return None;
462    };
463    let leading = &input.dims()[..input.rank() - 1];
464    let lead_dim = if leading.iter().all(|d| d.is_static()) {
465        Dim::Static(leading.iter().map(|d| d.unwrap_static()).product::<usize>())
466    } else {
467        let mut syms: Vec<u32> = leading
468            .iter()
469            .filter_map(|d| match d {
470                Dim::Dynamic(s) => Some(*s),
471                _ => None,
472            })
473            .collect();
474        syms.sort();
475        syms.dedup();
476        match syms.len() {
477            0 => Dim::Static(leading.iter().map(|d| d.unwrap_static()).product::<usize>()),
478            1 => Dim::Dynamic(syms[0]),
479            _ => Dim::Dynamic(crate::dynamic::sym::ROWS),
480        }
481    };
482    Some(Shape::from_dims(&[lead_dim, Dim::Static(h)], input.dtype()))
483}
484
485/// Match `Reshape { new_shape }` after RmsNorm when fusing to a single op.
486pub fn leading_flatten_shape(input: &Shape, new_shape: &[i64]) -> Option<Shape> {
487    if new_shape.len() != 2 {
488        return None;
489    }
490    let flat = leading_flatten_fused_shape(input)?;
491    let Dim::Static(h) = input.dim(input.rank() - 1) else {
492        return None;
493    };
494    if new_shape[1] as usize != h {
495        return None;
496    }
497    match flat.dim(0) {
498        Dim::Static(lead) if new_shape[0] as usize == lead => Some(flat),
499        Dim::Dynamic(_) if new_shape[0] == -1 => Some(flat),
500        _ => None,
501    }
502}
503
504/// Attention: output shape = Q shape.
505pub fn attention_shape(q: &Shape) -> Shape {
506    q.clone()
507}
508
509impl std::fmt::Display for Shape {
510    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
511        write!(f, "[")?;
512        for (i, d) in self.dims.iter().enumerate() {
513            if i > 0 {
514                write!(f, ", ")?;
515            }
516            write!(f, "{d}")?;
517        }
518        write!(f, "] {}", self.dtype)
519    }
520}
521
/// Spatial output size for NCHW `Op::Conv` / `conv2d`.
///
/// Computes `floor((in_size + 2*padding - dilation*(kernel-1) - 1) / stride) + 1`.
///
/// Returns 0 when the dilated kernel does not fit inside the padded input.
/// The previous saturating form reported 1 in that case. Intermediate sums
/// saturate instead of overflowing.
///
/// # Panics
/// Panics if `stride == 0`.
pub fn conv2d_spatial_output(
    in_size: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    assert!(stride > 0, "conv2d: stride must be non-zero");
    // Extent covered by one dilated kernel window: dilation*(kernel-1) + 1.
    let window = dilation
        .saturating_mul(kernel.saturating_sub(1))
        .saturating_add(1);
    let padded = in_size.saturating_add(padding.saturating_mul(2));
    if window > padded {
        return 0;
    }
    (padded - window) / stride + 1
}
537
/// Spatial output size for NCHW `Op::ConvTranspose2d`.
///
/// Computes
/// `(in_size - 1)*stride - 2*padding + dilation*(kernel-1) + output_padding + 1`.
///
/// All positive terms are summed before the negative ones are subtracted,
/// so a valid configuration can no longer underflow part-way through. A
/// result that would be negative is clamped to 0, and an empty input
/// (`in_size == 0`) yields 0. Intermediate sums saturate instead of
/// overflowing.
pub fn conv_transpose2d_spatial_output(
    in_size: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    output_padding: usize,
) -> usize {
    if in_size == 0 {
        return 0;
    }
    let dil_k = dilation.saturating_mul(kernel.saturating_sub(1));
    // (in_size - 1) * stride == in_size * stride - stride; keep the
    // subtraction on the negative side.
    let positive = in_size
        .saturating_mul(stride)
        .saturating_add(output_padding)
        .saturating_add(dil_k)
        .saturating_add(1);
    let negative = stride.saturating_add(padding.saturating_mul(2));
    positive.saturating_sub(negative)
}
550
551/// Output shape for `conv2d` given NCHW `input` and weight `[C_out, C_in/g, kH, kW]`.
552pub fn conv2d_output_shape(
553    input: &Shape,
554    weight: &Shape,
555    kernel_size: [usize; 2],
556    stride: [usize; 2],
557    padding: [usize; 2],
558    dilation: [usize; 2],
559    groups: usize,
560) -> Result<Shape, String> {
561    if input.rank() != 4 || weight.rank() != 4 {
562        return Err("conv2d requires NCHW input and 4-D weight".into());
563    }
564    let n = input.dim(0).unwrap_static();
565    let c_in = input.dim(1).unwrap_static();
566    let h = input.dim(2).unwrap_static();
567    let w = input.dim(3).unwrap_static();
568    let c_out = weight.dim(0).unwrap_static();
569    let w_cin = weight.dim(1).unwrap_static();
570    if w_cin * groups != c_in {
571        return Err(format!(
572            "conv2d weight C_in/g={w_cin} * groups={groups} != input C={c_in}"
573        ));
574    }
575    let h_out = conv2d_spatial_output(h, kernel_size[0], stride[0], padding[0], dilation[0]);
576    let w_out = conv2d_spatial_output(w, kernel_size[1], stride[1], padding[1], dilation[1]);
577    Ok(Shape::new(&[n, c_out, h_out, w_out], input.dtype()))
578}
579
580/// Output shape for `conv_transpose2d` (weight `[C_in, C_out/g, kH, kW]`).
581pub fn conv_transpose2d_output_shape(
582    input: &Shape,
583    weight: &Shape,
584    kernel_size: [usize; 2],
585    stride: [usize; 2],
586    padding: [usize; 2],
587    dilation: [usize; 2],
588    output_padding: [usize; 2],
589    groups: usize,
590) -> Result<Shape, String> {
591    if input.rank() != 4 || weight.rank() != 4 {
592        return Err("conv_transpose2d requires NCHW input and 4-D weight".into());
593    }
594    let n = input.dim(0).unwrap_static();
595    let c_in = input.dim(1).unwrap_static();
596    let h = input.dim(2).unwrap_static();
597    let w = input.dim(3).unwrap_static();
598    let w_cin = weight.dim(0).unwrap_static();
599    let c_out_per_g = weight.dim(1).unwrap_static();
600    if w_cin != c_in {
601        return Err(format!(
602            "conv_transpose2d weight C_in={w_cin} != input C={c_in}"
603        ));
604    }
605    let h_out = conv_transpose2d_spatial_output(
606        h,
607        kernel_size[0],
608        stride[0],
609        padding[0],
610        dilation[0],
611        output_padding[0],
612    );
613    let w_out = conv_transpose2d_spatial_output(
614        w,
615        kernel_size[1],
616        stride[1],
617        padding[1],
618        dilation[1],
619        output_padding[1],
620    );
621    Ok(Shape::new(
622        &[n, c_out_per_g * groups, h_out, w_out],
623        input.dtype(),
624    ))
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an all-static `f32` shape.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    /// The `[4, 15, 384] f32` activation shape most tests start from.
    fn activations() -> Shape {
        f32_shape(&[4, 15, 384])
    }

    #[test]
    fn static_shape() {
        let shape = activations();
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), Some(4 * 15 * 384));
        assert_eq!(shape.size_bytes(), Some(4 * 15 * 384 * 4));
        assert!(shape.is_static());
        assert_eq!(shape.to_string(), "[4, 15, 384] f32");
    }

    // ── Shape inference tests ────────────────────────────────

    #[test]
    fn broadcast_same() {
        let shape = activations();
        let out = broadcast(&shape, &shape).unwrap();
        assert_eq!(out.dims(), shape.dims());
    }

    #[test]
    fn broadcast_bias() {
        let out = broadcast(&activations(), &f32_shape(&[384])).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 384]));
    }

    #[test]
    fn broadcast_scalar() {
        let shape = activations();
        let out = broadcast(&shape, &Shape::scalar(DType::F32)).unwrap();
        assert_eq!(out, shape);
    }

    #[test]
    fn broadcast_mismatch() {
        let other = f32_shape(&[4, 15, 256]);
        assert!(broadcast(&activations(), &other).is_err());
    }

    #[test]
    fn matmul_basic() {
        let weight = f32_shape(&[384, 1536]);
        let out = matmul_shape(&activations(), &weight).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 1536]));
    }

    #[test]
    fn matmul_batched() {
        let weight = f32_shape(&[4, 384, 1536]);
        let out = matmul_shape(&activations(), &weight).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 1536]));
    }

    #[test]
    fn matmul_k_mismatch() {
        let weight = f32_shape(&[512, 1536]);
        assert!(matmul_shape(&activations(), &weight).is_err());
    }

    #[test]
    fn reduce_keepdim() {
        let out = reduce_shape(&activations(), &[2], true).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 1]));
    }

    #[test]
    fn reduce_no_keepdim() {
        let out = reduce_shape(&activations(), &[2], false).unwrap();
        assert_eq!(out, f32_shape(&[4, 15]));
    }

    #[test]
    fn concat_basic() {
        let (left, right) = (activations(), activations());
        let out = concat_shape(&[&left, &right], 2).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 768]));
    }

    #[test]
    fn gather_embedding() {
        let table = f32_shape(&[30522, 384]);
        let indices = Shape::new(&[4, 15], DType::I64);
        let out = gather_shape(&table, &indices, 0).unwrap();
        let expected = Shape::from_dims(
            &[Dim::Static(4), Dim::Static(15), Dim::Static(384)],
            DType::F32,
        );
        assert_eq!(out, expected);
    }

    #[test]
    fn reshape_with_neg1() {
        let out = reshape_shape(&activations(), &[60, -1]).unwrap();
        assert_eq!(out, f32_shape(&[60, 384]));
    }

    #[test]
    fn transpose_basic() {
        let out = transpose_shape(&activations(), &[0, 2, 1]).unwrap();
        assert_eq!(out, f32_shape(&[4, 384, 15]));
    }

    #[test]
    fn narrow_basic() {
        let fused_qkv = f32_shape(&[4, 15, 1152]);
        let out = narrow_shape(&fused_qkv, 2, 384).unwrap();
        assert_eq!(out, f32_shape(&[4, 15, 384]));
    }

    #[test]
    fn compare_bool_output() {
        let (left, right) = (f32_shape(&[4, 15]), f32_shape(&[4, 15]));
        let out = compare_shape(&left, &right).unwrap();
        assert_eq!(out.dtype(), DType::Bool);
        assert_eq!(out.rank(), 2);
    }

    // ── Original tests ──────────────────────────────────────

    #[test]
    fn dynamic_shape() {
        let dims = [Dim::Dynamic(0), Dim::Dynamic(1), Dim::Static(384)];
        let shape = Shape::from_dims(&dims, DType::F32);
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), None);
        assert!(!shape.is_static());
        assert_eq!(shape.to_string(), "[?0, ?1, 384] f32");
    }

    #[test]
    fn dynamic_symbols_lists_distinct_dims() {
        let dims = [
            Dim::Dynamic(1),
            Dim::Static(384),
            Dim::Dynamic(0),
            Dim::Dynamic(1),
        ];
        let shape = Shape::from_dims(&dims, DType::F32);
        assert_eq!(shape.dynamic_symbols(), vec![0, 1]);
    }

    #[test]
    fn bind_specializes_known_symbols() {
        let dims = [Dim::Dynamic(0), Dim::Dynamic(1), Dim::Static(384)];
        let symbolic = Shape::from_dims(&dims, DType::F32);
        let mut binding = DimBinding::new();
        binding.set(0, 8);
        binding.set(1, 64);
        let bound = symbolic.bind(&binding);
        assert!(bound.is_static());
        assert_eq!(bound.num_elements(), Some(8 * 64 * 384));
    }

    #[test]
    fn bind_leaves_unknown_symbols_alone() {
        let symbolic = Shape::from_dims(&[Dim::Dynamic(0), Dim::Dynamic(99)], DType::F32);
        let mut binding = DimBinding::new();
        binding.set(0, 4);
        let bound = symbolic.bind(&binding);
        // Symbol 99 has no binding, so it must remain dynamic.
        assert!(!bound.is_static());
        assert_eq!(bound.dynamic_symbols(), vec![99]);
    }
}
814
/// Runtime binding table: maps a dynamic-dim symbol to its concrete size.
/// Plan #54.
#[derive(Debug, Clone, Default)]
pub struct DimBinding {
    map: std::collections::HashMap<u32, usize>,
}

impl DimBinding {
    /// Creates an empty binding table.
    pub fn new() -> Self {
        Self {
            map: std::collections::HashMap::new(),
        }
    }

    /// Binds `symbol` to `size`, returning the size it was previously bound
    /// to, if any.
    pub fn set(&mut self, symbol: u32, size: usize) -> Option<usize> {
        self.map.insert(symbol, size)
    }

    /// Looks up the size bound to `symbol`.
    pub fn get(&self, symbol: u32) -> Option<usize> {
        let size = self.map.get(&symbol)?;
        Some(*size)
    }

    /// `true` when no symbol is bound.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of bound symbols.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Iterates over `(symbol, size)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (u32, usize)> + '_ {
        self.map.iter().map(|(symbol, size)| (*symbol, *size))
    }
}