Skip to main content

rlx_ir/
shape.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tensor shapes with static and dynamic dimensions.
17//!
18//! Shapes are first-class in RLX IR — every node's output shape is known
19//! (or symbolically bounded) at graph construction time. This enables
20//! buffer size computation for memory planning.
21
22use crate::DType;
23use smallvec::SmallVec;
24
/// One axis extent of a tensor: a concrete size or a named runtime unknown.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Dim {
    /// Size fixed when the graph is built.
    Static(usize),
    /// Size only known at runtime. The payload is a symbol index: two
    /// `Dim::Dynamic(0)` entries, even in different shapes, denote the
    /// same unknown size.
    Dynamic(u32),
}

impl Dim {
    /// Returns the concrete size of a static dimension.
    ///
    /// # Panics
    ///
    /// Panics if the dimension is [`Dim::Dynamic`].
    pub fn unwrap_static(self) -> usize {
        let sym = match self {
            Dim::Static(extent) => return extent,
            Dim::Dynamic(sym) => sym,
        };
        panic!("expected static dim, got dynamic symbol {sym}")
    }

    /// Returns `true` for [`Dim::Static`], `false` for [`Dim::Dynamic`].
    pub fn is_static(self) -> bool {
        match self {
            Dim::Static(_) => true,
            Dim::Dynamic(_) => false,
        }
    }
}

/// A plain `usize` converts to a static dimension of that size.
impl From<usize> for Dim {
    fn from(n: usize) -> Self {
        Dim::Static(n)
    }
}

/// Static dims print as their size, dynamic dims as `?<symbol>`.
impl std::fmt::Display for Dim {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Dim::Static(extent) => write!(f, "{extent}"),
            Dim::Dynamic(sym) => write!(f, "?{sym}"),
        }
    }
}
63
64/// Tensor shape: ordered list of dimensions + element type.
65///
66/// SmallVec<[Dim; 4]> avoids heap allocation for up to 4D tensors (the common case).
67#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
68#[derive(Debug, Clone, PartialEq, Eq, Hash)]
69pub struct Shape {
70    dims: SmallVec<[Dim; 4]>,
71    dtype: DType,
72}
73
74impl Shape {
75    /// Create a shape from static dimensions.
76    pub fn new(dims: &[usize], dtype: DType) -> Self {
77        Self {
78            dims: dims.iter().map(|&d| Dim::Static(d)).collect(),
79            dtype,
80        }
81    }
82
83    /// Create a shape with mixed static/dynamic dimensions.
84    pub fn from_dims(dims: &[Dim], dtype: DType) -> Self {
85        Self {
86            dims: dims.into(),
87            dtype,
88        }
89    }
90
91    /// Scalar (0-dimensional).
92    pub fn scalar(dtype: DType) -> Self {
93        Self {
94            dims: SmallVec::new(),
95            dtype,
96        }
97    }
98
99    pub fn rank(&self) -> usize {
100        self.dims.len()
101    }
102    pub fn dtype(&self) -> DType {
103        self.dtype
104    }
105    pub fn dims(&self) -> &[Dim] {
106        &self.dims
107    }
108    pub fn dim(&self, i: usize) -> Dim {
109        self.dims.get(i).copied().unwrap_or_else(|| {
110            let dims: Vec<_> = self.dims.iter().map(|d| d.unwrap_static()).collect();
111            panic!(
112                "Shape::dim({i}) out of bounds for rank {} dims={dims:?}",
113                self.rank()
114            );
115        })
116    }
117
118    /// Set of dynamic dim symbols this shape references. Useful for
119    /// "what bindings does this graph need?" queries on inputs.
120    pub fn dynamic_symbols(&self) -> Vec<u32> {
121        let mut syms: Vec<u32> = self
122            .dims
123            .iter()
124            .filter_map(|d| match d {
125                Dim::Dynamic(s) => Some(*s),
126                _ => None,
127            })
128            .collect();
129        syms.sort();
130        syms.dedup();
131        syms
132    }
133
134    /// Specialize the shape against a binding (`symbol → static
135    /// size`). Unknown symbols stay [`Dim::Dynamic`]. Plan #54: the
136    /// step that takes a "compile once, run at any seq length" graph
137    /// and produces the runtime-specific concrete shape.
138    pub fn bind(&self, bindings: &DimBinding) -> Self {
139        let dims = self
140            .dims
141            .iter()
142            .map(|d| match d {
143                Dim::Dynamic(s) => match bindings.get(*s) {
144                    Some(n) => Dim::Static(n),
145                    None => *d,
146                },
147                _ => *d,
148            })
149            .collect();
150        Self {
151            dims,
152            dtype: self.dtype,
153        }
154    }
155
156    /// Total number of elements (only if all dims are static).
157    pub fn num_elements(&self) -> Option<usize> {
158        let mut total = 1usize;
159        for d in &self.dims {
160            match d {
161                Dim::Static(n) => total = total.checked_mul(*n)?,
162                Dim::Dynamic(_) => return None,
163            }
164        }
165        Some(total)
166    }
167
168    /// Total size in bytes (only if all dims are static).
169    pub fn size_bytes(&self) -> Option<usize> {
170        self.num_elements().map(|n| n * self.dtype.size_bytes())
171    }
172
173    /// True if all dimensions are statically known.
174    pub fn is_static(&self) -> bool {
175        self.dims.iter().all(|d| d.is_static())
176    }
177
178    /// Replace a dimension.
179    pub fn with_dim(mut self, axis: usize, dim: Dim) -> Self {
180        self.dims[axis] = dim;
181        self
182    }
183
184    /// Change dtype (for cast operations).
185    pub fn with_dtype(mut self, dtype: DType) -> Self {
186        self.dtype = dtype;
187        self
188    }
189
190    /// Numpy-style broadcast with another shape (fusion / lowering).
191    pub fn broadcast_with(&self, other: &Shape) -> Result<Shape, String> {
192        broadcast(self, other)
193    }
194}
195
196// ── Shape inference functions ────────────────────────────────────────────
197
198/// Numpy-style broadcast of two shapes. Returns the broadcast result.
199pub fn broadcast(a: &Shape, b: &Shape) -> Result<Shape, String> {
200    let max_rank = a.rank().max(b.rank());
201    let mut dims = SmallVec::new();
202    for i in 0..max_rank {
203        let ad = if i < max_rank - a.rank() {
204            Dim::Static(1)
205        } else {
206            a.dims[i - (max_rank - a.rank())]
207        };
208        let bd = if i < max_rank - b.rank() {
209            Dim::Static(1)
210        } else {
211            b.dims[i - (max_rank - b.rank())]
212        };
213        let d = broadcast_dim(ad, bd)?;
214        dims.push(d);
215    }
216    Ok(Shape {
217        dims,
218        dtype: a.dtype,
219    })
220}
221
222fn broadcast_dim(a: Dim, b: Dim) -> Result<Dim, String> {
223    match (a, b) {
224        (Dim::Static(1), d) | (d, Dim::Static(1)) => Ok(d),
225        (Dim::Static(x), Dim::Static(y)) if x == y => Ok(Dim::Static(x)),
226        (Dim::Static(x), Dim::Static(y)) => Err(format!("cannot broadcast {x} with {y}")),
227        (Dim::Dynamic(s), Dim::Dynamic(t)) if s == t => Ok(Dim::Dynamic(s)),
228        (Dim::Dynamic(_), _) | (_, Dim::Dynamic(_)) => Ok(a), // keep first dynamic
229    }
230}
231
232/// MatMul output shape: `[..,M,K] × [..,K,N] → [..,M,N]`.
233pub fn matmul_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
234    if lhs.rank() < 2 || rhs.rank() < 2 {
235        return Err(format!(
236            "matmul requires rank >= 2, got {} and {}",
237            lhs.rank(),
238            rhs.rank()
239        ));
240    }
241    let m = lhs.dims[lhs.rank() - 2];
242    let k1 = lhs.dims[lhs.rank() - 1];
243    let k2 = rhs.dims[rhs.rank() - 2];
244    let n = rhs.dims[rhs.rank() - 1];
245
246    // Verify K dimensions match
247    match (k1, k2) {
248        (Dim::Static(a), Dim::Static(b)) if a != b => {
249            return Err(format!("matmul K mismatch: {a} vs {b}"));
250        }
251        (Dim::Dynamic(s), Dim::Dynamic(t)) if s != t => {
252            return Err(format!("matmul K mismatch: ?{s} vs ?{t}"));
253        }
254        _ => {}
255    }
256
257    // Broadcast batch dimensions
258    let lhs_batch = &lhs.dims[..lhs.rank() - 2];
259    let rhs_batch = &rhs.dims[..rhs.rank() - 2];
260    let batch_a = Shape::from_dims(lhs_batch, lhs.dtype);
261    let batch_b = Shape::from_dims(rhs_batch, rhs.dtype);
262    let batch = if lhs_batch.is_empty() && rhs_batch.is_empty() {
263        SmallVec::new()
264    } else if lhs_batch.is_empty() {
265        rhs_batch.into()
266    } else if rhs_batch.is_empty() {
267        lhs_batch.into()
268    } else {
269        broadcast(&batch_a, &batch_b)?.dims.clone()
270    };
271
272    let mut dims = batch;
273    dims.push(m);
274    dims.push(n);
275    Ok(Shape {
276        dims,
277        dtype: lhs.dtype,
278    })
279}
280
281/// ONNX Expand: broadcast `input` to `target` (numpy-style).
282pub fn expand_shape(input: &Shape, target: &[i64]) -> Result<Shape, String> {
283    if target.iter().any(|&d| d < 0) {
284        return Err("expand target has negative dim".into());
285    }
286    let target_s = Shape::new(
287        &target.iter().map(|&d| d as usize).collect::<Vec<_>>(),
288        input.dtype(),
289    );
290    broadcast(input, &target_s)
291}
292
293/// Binary element-wise shape (broadcast).
294pub fn binary_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
295    broadcast(lhs, rhs)
296}
297
298/// Unary op: output = input shape.
299pub fn unary_shape(input: &Shape) -> Shape {
300    input.clone()
301}
302
303/// Cast: change dtype, keep shape.
304pub fn cast_shape(input: &Shape, to: DType) -> Shape {
305    input.clone().with_dtype(to)
306}
307
308/// Compare: broadcast + Bool dtype.
309pub fn compare_shape(lhs: &Shape, rhs: &Shape) -> Result<Shape, String> {
310    Ok(broadcast(lhs, rhs)?.with_dtype(DType::Bool))
311}
312
313/// Reduce along axes.
314pub fn reduce_shape(input: &Shape, axes: &[usize], keep_dim: bool) -> Result<Shape, String> {
315    let mut dims = SmallVec::new();
316    for (i, &d) in input.dims.iter().enumerate() {
317        if axes.contains(&i) {
318            if keep_dim {
319                dims.push(Dim::Static(1));
320            }
321        } else {
322            dims.push(d);
323        }
324    }
325    Ok(Shape {
326        dims,
327        dtype: input.dtype,
328    })
329}
330
331/// Softmax: preserves shape.
332pub fn softmax_shape(input: &Shape) -> Shape {
333    input.clone()
334}
335
336/// Transpose: permute dims.
337pub fn transpose_shape(input: &Shape, perm: &[usize]) -> Result<Shape, String> {
338    if perm.len() != input.rank() {
339        return Err(format!("perm len {} != rank {}", perm.len(), input.rank()));
340    }
341    let dims: SmallVec<[Dim; 4]> = perm.iter().map(|&i| input.dims[i]).collect();
342    Ok(Shape {
343        dims,
344        dtype: input.dtype,
345    })
346}
347
348/// Narrow: slice along one axis.
349pub fn narrow_shape(input: &Shape, axis: usize, len: usize) -> Result<Shape, String> {
350    if axis >= input.rank() {
351        return Err(format!("axis {axis} >= rank {}", input.rank()));
352    }
353    Ok(input.clone().with_dim(axis, Dim::Static(len)))
354}
355
356/// Concat along axis.
357pub fn concat_shape(inputs: &[&Shape], axis: usize) -> Result<Shape, String> {
358    if inputs.is_empty() {
359        return Err("concat: no inputs".into());
360    }
361    let base = inputs[0];
362    let mut static_sum = 0usize;
363    let mut dyn_sym: Option<u32> = None;
364    for s in inputs {
365        if s.rank() == 0 {
366            return Err("concat: input has rank 0".into());
367        }
368        if s.rank() != base.rank() {
369            return Err(format!(
370                "concat: rank mismatch {} vs {}",
371                s.rank(),
372                base.rank()
373            ));
374        }
375        let ax = axis.min(s.rank().saturating_sub(1));
376        match s.dims[ax] {
377            Dim::Static(n) => static_sum += n,
378            Dim::Dynamic(sym) => {
379                if let Some(prev) = dyn_sym {
380                    if prev != sym {
381                        return Err(format!(
382                            "concat: mismatched dynamic symbols {prev} vs {sym} on axis {axis}"
383                        ));
384                    }
385                }
386                dyn_sym = Some(sym);
387            }
388        }
389    }
390    let out_dim = match dyn_sym {
391        None => Dim::Static(static_sum),
392        Some(sym) if static_sum == 0 => Dim::Dynamic(sym),
393        Some(sym) => {
394            // Mixed static + dynamic (e.g. conv_state || qkv). After `bind_graph`,
395            // `sync_concat_shapes` recomputes from concrete input shapes.
396            let _ = static_sum;
397            Dim::Dynamic(sym)
398        }
399    };
400    let out_axis = axis.min(base.rank().saturating_sub(1));
401    Ok(base.clone().with_dim(out_axis, out_dim))
402}
403
404/// Gather (embedding lookup): table\[V,D\] + indices\[B,S\] → \[B,S,D\].
405pub fn gather_shape(table: &Shape, indices: &Shape, axis: usize) -> Result<Shape, String> {
406    if axis >= table.rank() {
407        return Err(format!("gather: axis {axis} >= rank {}", table.rank()));
408    }
409    let mut dims: SmallVec<[Dim; 4]> = indices.dims.clone();
410    for i in (axis + 1)..table.rank() {
411        dims.push(table.dims[i]);
412    }
413    Ok(Shape {
414        dims,
415        dtype: table.dtype,
416    })
417}
418
419/// Reshape with -1 wildcard support.
420pub fn reshape_shape(input: &Shape, new_shape: &[i64]) -> Result<Shape, String> {
421    let neg_count = new_shape.iter().filter(|&&d| d == -1).count();
422    if neg_count > 1 {
423        return Err("reshape: at most one -1".into());
424    }
425
426    if input.is_static() {
427        let total = input
428            .num_elements()
429            .ok_or_else(|| "reshape: input has dynamic dims".to_string())?;
430        let known_product: i64 = new_shape.iter().filter(|&&d| d != -1).product();
431        let mut dims = SmallVec::new();
432        for &d in new_shape {
433            if d == -1 {
434                let inferred = total as i64 / known_product;
435                dims.push(Dim::Static(inferred as usize));
436            } else if d < 0 {
437                return Err(format!("reshape: invalid dim {d}"));
438            } else {
439                dims.push(Dim::Static(d as usize));
440            }
441        }
442        return Ok(Shape {
443            dims,
444            dtype: input.dtype,
445        });
446    }
447
448    // Symbolic input: map `-1` to the sole dynamic symbol when unambiguous
449    // (qwen35 prefill with batch=1 and `sym::SEQ`), otherwise keep dynamic.
450    let dyn_syms = input.dynamic_symbols();
451    let neg_idx = new_shape.iter().position(|&d| d == -1);
452    let mut out_dims: SmallVec<[Dim; 4]> = SmallVec::new();
453    for (i, &d) in new_shape.iter().enumerate() {
454        if Some(i) == neg_idx {
455            continue;
456        }
457        if d < 0 {
458            return Err(format!("reshape: invalid dim {d}"));
459        }
460        out_dims.push(Dim::Static(d as usize));
461    }
462    if let Some(ni) = neg_idx {
463        let inferred = if dyn_syms.len() == 1 {
464            Dim::Dynamic(dyn_syms[0])
465        } else if dyn_syms.is_empty() {
466            return Err("reshape: cannot infer -1 on static input".into());
467        } else {
468            Dim::Dynamic(crate::dynamic::sym::ROWS)
469        };
470        out_dims.insert(ni, inferred);
471    }
472    Ok(Shape {
473        dims: out_dims,
474        dtype: input.dtype,
475    })
476}
477
478/// Flatten leading axes to `[∏leading, H]` — used by `FuseRmsNormReshape` and shape verify.
479pub fn leading_flatten_fused_shape(input: &Shape) -> Option<Shape> {
480    if input.rank() < 2 {
481        return None;
482    }
483    let Dim::Static(h) = input.dim(input.rank() - 1) else {
484        return None;
485    };
486    let leading = &input.dims()[..input.rank() - 1];
487    let lead_dim = if leading.iter().all(|d| d.is_static()) {
488        Dim::Static(leading.iter().map(|d| d.unwrap_static()).product::<usize>())
489    } else {
490        let mut syms: Vec<u32> = leading
491            .iter()
492            .filter_map(|d| match d {
493                Dim::Dynamic(s) => Some(*s),
494                _ => None,
495            })
496            .collect();
497        syms.sort();
498        syms.dedup();
499        match syms.len() {
500            0 => Dim::Static(leading.iter().map(|d| d.unwrap_static()).product::<usize>()),
501            1 => Dim::Dynamic(syms[0]),
502            _ => Dim::Dynamic(crate::dynamic::sym::ROWS),
503        }
504    };
505    Some(Shape::from_dims(&[lead_dim, Dim::Static(h)], input.dtype()))
506}
507
508/// Match `Reshape { new_shape }` after RmsNorm when fusing to a single op.
509pub fn leading_flatten_shape(input: &Shape, new_shape: &[i64]) -> Option<Shape> {
510    if new_shape.len() != 2 {
511        return None;
512    }
513    let flat = leading_flatten_fused_shape(input)?;
514    let Dim::Static(h) = input.dim(input.rank() - 1) else {
515        return None;
516    };
517    if new_shape[1] as usize != h {
518        return None;
519    }
520    match flat.dim(0) {
521        Dim::Static(lead) if new_shape[0] as usize == lead => Some(flat),
522        Dim::Dynamic(_) if new_shape[0] == -1 => Some(flat),
523        _ => None,
524    }
525}
526
527/// Attention: output shape = Q shape.
528pub fn attention_shape(q: &Shape) -> Shape {
529    q.clone()
530}
531
532impl std::fmt::Display for Shape {
533    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534        write!(f, "[")?;
535        for (i, d) in self.dims.iter().enumerate() {
536            if i > 0 {
537                write!(f, ", ")?;
538            }
539            write!(f, "{d}")?;
540        }
541        write!(f, "] {}", self.dtype)
542    }
543}
544
/// Spatial output size for NCHW `Op::Conv` / `conv2d`.
///
/// Computes `(in_size + 2·padding − dilation·(kernel − 1) − 1) / stride + 1`
/// with saturating subtraction, so a window larger than the padded input
/// yields 1 instead of underflowing.
///
/// # Panics
///
/// Panics if `stride` is 0 (division by zero).
pub fn conv2d_spatial_output(
    in_size: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    let padded = in_size + 2 * padding;
    // Distance between the first and last tap of the dilated kernel.
    let reach = dilation.saturating_mul(kernel.saturating_sub(1));
    let last_start = padded.saturating_sub(reach).saturating_sub(1);
    last_start / stride + 1
}
560
/// Spatial output size for NCHW `Op::ConvTranspose2d`.
///
/// Computes
/// `(in_size − 1)·stride + output_padding + dilation·(kernel − 1) + 1 − 2·padding`.
///
/// The subtractions saturate at zero, matching `conv2d_spatial_output`.
/// The `+ 1` is applied before `2·padding` is subtracted, so a partial sum
/// below `2·padding` no longer underflows when the true result is
/// non-negative. `in_size == 0` no longer underflows either.
pub fn conv_transpose2d_spatial_output(
    in_size: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    output_padding: usize,
) -> usize {
    let dil_k = dilation.saturating_mul(kernel.saturating_sub(1));
    let grown = in_size.saturating_sub(1) * stride + output_padding + dil_k + 1;
    grown.saturating_sub(2 * padding)
}
573
574/// Output shape for `conv2d` given NCHW `input` and weight `[C_out, C_in/g, kH, kW]`.
575pub fn conv2d_output_shape(
576    input: &Shape,
577    weight: &Shape,
578    kernel_size: [usize; 2],
579    stride: [usize; 2],
580    padding: [usize; 2],
581    dilation: [usize; 2],
582    groups: usize,
583) -> Result<Shape, String> {
584    if input.rank() != 4 || weight.rank() != 4 {
585        return Err("conv2d requires NCHW input and 4-D weight".into());
586    }
587    let n = input.dim(0);
588    let c_in = input.dim(1).unwrap_static();
589    let h = input.dim(2).unwrap_static();
590    let w = input.dim(3).unwrap_static();
591    let c_out = weight.dim(0).unwrap_static();
592    let w_cin = weight.dim(1).unwrap_static();
593    if w_cin * groups != c_in {
594        return Err(format!(
595            "conv2d weight C_in/g={w_cin} * groups={groups} != input C={c_in}"
596        ));
597    }
598    let h_out = conv2d_spatial_output(h, kernel_size[0], stride[0], padding[0], dilation[0]);
599    let w_out = conv2d_spatial_output(w, kernel_size[1], stride[1], padding[1], dilation[1]);
600    Ok(Shape::from_dims(
601        &[
602            n,
603            Dim::Static(c_out),
604            Dim::Static(h_out),
605            Dim::Static(w_out),
606        ],
607        input.dtype(),
608    ))
609}
610
611/// Output shape for NCHW `Op::Im2Col`: `[M, C·kH·kW]` with
612/// `M = N · H_out · W_out`. Dynamic batch maps to [`dynamic::sym::ROWS`].
613pub fn im2col_output_shape(
614    input: &Shape,
615    kernel_size: [usize; 2],
616    stride: [usize; 2],
617    padding: [usize; 2],
618    dilation: [usize; 2],
619) -> Result<Shape, String> {
620    if input.rank() != 4 {
621        return Err("im2col requires NCHW input".into());
622    }
623    let c_in = input.dim(1).unwrap_static();
624    let h = input.dim(2).unwrap_static();
625    let w = input.dim(3).unwrap_static();
626    let kh = kernel_size[0];
627    let kw = kernel_size[1];
628    let h_out = conv2d_spatial_output(h, kh, stride[0], padding[0], dilation[0]);
629    let w_out = conv2d_spatial_output(w, kw, stride[1], padding[1], dilation[1]);
630    let k = c_in * kh * kw;
631    let spatial = h_out * w_out;
632    let m = match input.dim(0) {
633        Dim::Static(n) => Dim::Static(n * spatial),
634        Dim::Dynamic(crate::dynamic::sym::BATCH) | Dim::Dynamic(crate::dynamic::sym::ROWS) => {
635            Dim::Dynamic(crate::dynamic::sym::ROWS)
636        }
637        Dim::Dynamic(_) => Dim::Dynamic(crate::dynamic::sym::ROWS),
638    };
639    Ok(Shape::from_dims(&[m, Dim::Static(k)], input.dtype()))
640}
641
642/// Output shape for `conv_transpose2d` (weight `[C_in, C_out/g, kH, kW]`).
643pub fn conv_transpose2d_output_shape(
644    input: &Shape,
645    weight: &Shape,
646    kernel_size: [usize; 2],
647    stride: [usize; 2],
648    padding: [usize; 2],
649    dilation: [usize; 2],
650    output_padding: [usize; 2],
651    groups: usize,
652) -> Result<Shape, String> {
653    if input.rank() != 4 || weight.rank() != 4 {
654        return Err("conv_transpose2d requires NCHW input and 4-D weight".into());
655    }
656    let n = input.dim(0).unwrap_static();
657    let c_in = input.dim(1).unwrap_static();
658    let h = input.dim(2).unwrap_static();
659    let w = input.dim(3).unwrap_static();
660    let w_cin = weight.dim(0).unwrap_static();
661    let c_out_per_g = weight.dim(1).unwrap_static();
662    if w_cin != c_in {
663        return Err(format!(
664            "conv_transpose2d weight C_in={w_cin} != input C={c_in}"
665        ));
666    }
667    let h_out = conv_transpose2d_spatial_output(
668        h,
669        kernel_size[0],
670        stride[0],
671        padding[0],
672        dilation[0],
673        output_padding[0],
674    );
675    let w_out = conv_transpose2d_spatial_output(
676        w,
677        kernel_size[1],
678        stride[1],
679        padding[1],
680        dilation[1],
681        output_padding[1],
682    );
683    Ok(Shape::new(
684        &[n, c_out_per_g * groups, h_out, w_out],
685        input.dtype(),
686    ))
687}
688
// Unit tests for shape construction, the shape-inference helpers, and
// dynamic-dimension binding. Fixtures reuse a `[4, 15, 384]` f32 tensor.
#[cfg(test)]
mod tests {
    use super::*;

    // A fully static shape reports rank, element count, byte size and its
    // `Display` form.
    #[test]
    fn static_shape() {
        let s = Shape::new(&[4, 15, 384], DType::F32);
        assert_eq!(s.rank(), 3);
        assert_eq!(s.num_elements(), Some(4 * 15 * 384));
        assert_eq!(s.size_bytes(), Some(4 * 15 * 384 * 4));
        assert!(s.is_static());
        assert_eq!(format!("{s}"), "[4, 15, 384] f32");
    }

    // ── Shape inference tests ────────────────────────────────

    // Broadcasting a shape with itself is the identity.
    #[test]
    fn broadcast_same() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let r = broadcast(&a, &a).unwrap();
        assert_eq!(r.dims(), a.dims());
    }

    // A rank-1 operand is right-aligned against the trailing dimension.
    #[test]
    fn broadcast_bias() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let b = Shape::new(&[384], DType::F32);
        let r = broadcast(&a, &b).unwrap();
        assert_eq!(r, Shape::new(&[4, 15, 384], DType::F32));
    }

    // A rank-0 operand broadcasts to the other operand's shape.
    #[test]
    fn broadcast_scalar() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let b = Shape::scalar(DType::F32);
        let r = broadcast(&a, &b).unwrap();
        assert_eq!(r, a);
    }

    // Unequal static extents (neither of them 1) are rejected.
    #[test]
    fn broadcast_mismatch() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let b = Shape::new(&[4, 15, 256], DType::F32);
        assert!(broadcast(&a, &b).is_err());
    }

    // Batched lhs times a rank-2 rhs keeps the lhs batch dimension.
    #[test]
    fn matmul_basic() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let b = Shape::new(&[384, 1536], DType::F32);
        let r = matmul_shape(&a, &b).unwrap();
        assert_eq!(r, Shape::new(&[4, 15, 1536], DType::F32));
    }

    // Both operands carry the same batch dimension.
    #[test]
    fn matmul_batched() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let b = Shape::new(&[4, 384, 1536], DType::F32);
        let r = matmul_shape(&a, &b).unwrap();
        assert_eq!(r, Shape::new(&[4, 15, 1536], DType::F32));
    }

    // Differing static contraction dimensions are rejected.
    #[test]
    fn matmul_k_mismatch() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let b = Shape::new(&[512, 1536], DType::F32);
        assert!(matmul_shape(&a, &b).is_err());
    }

    // With keep_dim the reduced axis stays as a unit dimension.
    #[test]
    fn reduce_keepdim() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let r = reduce_shape(&a, &[2], true).unwrap();
        assert_eq!(r, Shape::new(&[4, 15, 1], DType::F32));
    }

    // Without keep_dim the reduced axis is dropped.
    #[test]
    fn reduce_no_keepdim() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let r = reduce_shape(&a, &[2], false).unwrap();
        assert_eq!(r, Shape::new(&[4, 15], DType::F32));
    }

    // Static extents along the concat axis are summed.
    #[test]
    fn concat_basic() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let b = Shape::new(&[4, 15, 384], DType::F32);
        let r = concat_shape(&[&a, &b], 2).unwrap();
        assert_eq!(r, Shape::new(&[4, 15, 768], DType::F32));
    }

    // Embedding lookup: index shape followed by the table's trailing dims,
    // with the table's dtype.
    #[test]
    fn gather_embedding() {
        let table = Shape::new(&[30522, 384], DType::F32);
        let indices = Shape::new(&[4, 15], DType::I64);
        let r = gather_shape(&table, &indices, 0).unwrap();
        assert_eq!(
            r,
            Shape::from_dims(
                &[Dim::Static(4), Dim::Static(15), Dim::Static(384)],
                DType::F32
            )
        );
    }

    // The -1 wildcard is inferred from the static element count.
    #[test]
    fn reshape_with_neg1() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let r = reshape_shape(&a, &[60, -1]).unwrap();
        assert_eq!(r, Shape::new(&[60, 384], DType::F32));
    }

    // Swapping the last two axes.
    #[test]
    fn transpose_basic() {
        let a = Shape::new(&[4, 15, 384], DType::F32);
        let r = transpose_shape(&a, &[0, 2, 1]).unwrap();
        assert_eq!(r, Shape::new(&[4, 384, 15], DType::F32));
    }

    // Narrow replaces the extent of the chosen axis.
    #[test]
    fn narrow_basic() {
        let a = Shape::new(&[4, 15, 1152], DType::F32);
        let r = narrow_shape(&a, 2, 384).unwrap();
        assert_eq!(r, Shape::new(&[4, 15, 384], DType::F32));
    }

    // Comparisons keep the broadcast dims but produce Bool.
    #[test]
    fn compare_bool_output() {
        let a = Shape::new(&[4, 15], DType::F32);
        let b = Shape::new(&[4, 15], DType::F32);
        let r = compare_shape(&a, &b).unwrap();
        assert_eq!(r.dtype(), DType::Bool);
        assert_eq!(r.rank(), 2);
    }

    // ── Original tests ──────────────────────────────────────

    // A shape with dynamic dims has no element count and prints `?sym`.
    #[test]
    fn dynamic_shape() {
        let s = Shape::from_dims(
            &[Dim::Dynamic(0), Dim::Dynamic(1), Dim::Static(384)],
            DType::F32,
        );
        assert_eq!(s.rank(), 3);
        assert_eq!(s.num_elements(), None);
        assert!(!s.is_static());
        assert_eq!(format!("{s}"), "[?0, ?1, 384] f32");
    }

    // Symbols come back sorted and de-duplicated.
    #[test]
    fn dynamic_symbols_lists_distinct_dims() {
        let s = Shape::from_dims(
            &[
                Dim::Dynamic(1),
                Dim::Static(384),
                Dim::Dynamic(0),
                Dim::Dynamic(1),
            ],
            DType::F32,
        );
        assert_eq!(s.dynamic_symbols(), vec![0, 1]);
    }

    // Binding every symbol yields a fully static shape.
    #[test]
    fn bind_specializes_known_symbols() {
        let s = Shape::from_dims(
            &[Dim::Dynamic(0), Dim::Dynamic(1), Dim::Static(384)],
            DType::F32,
        );
        let mut b = DimBinding::new();
        b.set(0, 8);
        b.set(1, 64);
        let s2 = s.bind(&b);
        assert!(s2.is_static());
        assert_eq!(s2.num_elements(), Some(8 * 64 * 384));
    }

    // Symbols missing from the binding stay dynamic.
    #[test]
    fn bind_leaves_unknown_symbols_alone() {
        let s = Shape::from_dims(&[Dim::Dynamic(0), Dim::Dynamic(99)], DType::F32);
        let mut b = DimBinding::new();
        b.set(0, 4);
        let s2 = s.bind(&b);
        assert!(!s2.is_static()); // ?99 still dynamic
        assert_eq!(s2.dynamic_symbols(), vec![99]);
    }
}
876
/// Runtime assignment of concrete sizes to dynamic-dim symbols
/// (symbol index → size). Plan #54.
#[derive(Debug, Clone, Default)]
pub struct DimBinding {
    /// Bound size per symbol index.
    map: std::collections::HashMap<u32, usize>,
}

impl DimBinding {
    /// Creates an empty binding.
    pub fn new() -> Self {
        DimBinding {
            map: std::collections::HashMap::new(),
        }
    }

    /// Binds `symbol` to `size`, returning the size it was bound to before,
    /// if any.
    pub fn set(&mut self, symbol: u32, size: usize) -> Option<usize> {
        let previous = self.map.insert(symbol, size);
        previous
    }

    /// Looks up the size bound to `symbol`, or `None` if it is unbound.
    pub fn get(&self, symbol: u32) -> Option<usize> {
        let bound = self.map.get(&symbol)?;
        Some(*bound)
    }

    /// True when no symbol is bound.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Number of bound symbols.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Iterates over `(symbol, size)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (u32, usize)> + '_ {
        self.map.iter().map(|(symbol, size)| (*symbol, *size))
    }
}
903}