Skip to main content

rlx_ir/
dynamic.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Dynamic / symbolic dimensions — compile once, specialize at runtime.
17//!
18//! Plan #54: graphs built with [`Dim::Dynamic`] symbols specialize via
19//! [`DimBinding`] before buffer planning and backend lowering.
20
21use std::collections::{BTreeSet, HashMap};
22
23use crate::shape::{Dim, DimBinding, Shape};
24use crate::{DType, Graph, Op};
25
/// Canonical ids for the dynamic dimension symbols shared by model builders.
///
/// Reusing one id across several shapes (e.g. `[?0, ?1, H]` and `[?0, ?1, 4]`)
/// ties those axes to a single `DimBinding` entry, so one batch/seq binding
/// specializes all of them.
pub mod sym {
    /// Leading batch axis.
    pub const BATCH: u32 = 0;
    /// Sequence axis (second axis in the `batch_seq*` shape helpers).
    pub const SEQ: u32 = 1;
    /// Product of leading axes, e.g. the rows left after flattening `batch * seq`.
    pub const ROWS: u32 = 2;
    /// Cached prefix length for decode KV (axis 1 of `past_k` / `past_v`).
    pub const PAST_SEQ: u32 = 3;
}
36
/// Hands out dynamic-dimension symbol ids by name for model builders.
///
/// Ids are assigned sequentially, starting at 0, the first time a name is seen;
/// asking again for the same name returns the same id.
///
/// NOTE(review): these ids live in the same numeric space as the constants in
/// the `sym` module (`BATCH` = 0, `SEQ` = 1, ...); confirm callers do not mix
/// both schemes in one graph.
#[derive(Debug, Clone, Default)]
pub struct DimEnv {
    /// Id the next previously-unseen name will receive.
    next: u32,
    /// Registered name -> allocated symbol id.
    names: HashMap<String, u32>,
}

impl DimEnv {
    /// Creates an empty environment (no names registered, next id is 0).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the symbol id registered for `name`, allocating a fresh one on
    /// first use.
    pub fn sym(&mut self, name: &str) -> u32 {
        match self.names.get(name) {
            Some(&existing) => existing,
            None => {
                let fresh = self.next;
                self.next += 1;
                self.names.insert(name.to_owned(), fresh);
                fresh
            }
        }
    }

    /// Reverse lookup: the name registered for `symbol`, if any.
    ///
    /// Scans every registered name, so the cost is linear in the number of names.
    pub fn name(&self, symbol: u32) -> Option<&str> {
        self.names
            .iter()
            .find(|&(_, &id)| id == symbol)
            .map(|(label, _)| label.as_str())
    }
}
66
67impl Shape {
68    /// `[batch, seq, hidden]` with symbolic leading axes.
69    pub fn batch_seq(batch: u32, seq: u32, hidden: usize, dtype: DType) -> Self {
70        Self::from_dims(
71            &[Dim::Dynamic(batch), Dim::Dynamic(seq), Dim::Static(hidden)],
72            dtype,
73        )
74    }
75
76    /// `[batch, seq]` matrix.
77    pub fn batch_seq_2d(batch: u32, seq: u32, dtype: DType) -> Self {
78        Self::from_dims(&[Dim::Dynamic(batch), Dim::Dynamic(seq)], dtype)
79    }
80
81    /// `[batch, seq, heads, head_dim]` attention layout.
82    pub fn batch_seq_heads(
83        batch: u32,
84        seq: u32,
85        heads: usize,
86        head_dim: usize,
87        dtype: DType,
88    ) -> Self {
89        Self::from_dims(
90            &[
91                Dim::Dynamic(batch),
92                Dim::Dynamic(seq),
93                Dim::Static(heads),
94                Dim::Static(head_dim),
95            ],
96            dtype,
97        )
98    }
99}
100
101impl DimBinding {
102    pub fn from_pairs(pairs: &[(u32, usize)]) -> Self {
103        let mut b = Self::new();
104        for &(sym, size) in pairs {
105            b.set(sym, size);
106        }
107        b
108    }
109
110    pub fn batch_seq(batch: usize, seq: usize) -> Self {
111        let mut b = Self::from_pairs(&[(sym::BATCH, batch), (sym::SEQ, seq)]);
112        if batch > 1 {
113            b.set(sym::ROWS, batch * seq);
114        }
115        b
116    }
117
118    pub fn batch_past_seq(batch: usize, past_seq: usize) -> Self {
119        Self::from_pairs(&[(sym::BATCH, batch), (sym::PAST_SEQ, past_seq)])
120    }
121}
122
123/// True if any node shape references a dynamic dimension.
124pub fn has_dynamic_dims(graph: &Graph) -> bool {
125    graph
126        .nodes()
127        .iter()
128        .any(|n| n.shape.dims().iter().any(|d| matches!(d, Dim::Dynamic(_))))
129}
130
131/// Collect all dynamic symbols referenced anywhere in the graph.
132pub fn collect_dynamic_symbols(graph: &Graph) -> Vec<u32> {
133    let mut syms = BTreeSet::new();
134    for node in graph.nodes() {
135        for s in node.shape.dynamic_symbols() {
136            syms.insert(s);
137        }
138    }
139    syms.into_iter().collect()
140}
141
142/// Specialize every node's shape against `bindings`.
143///
144/// Node ids are preserved (nodes are cloned in insertion order), so
145/// edges and outputs remain valid without remapping.
146pub fn bind_graph(graph: &Graph, bindings: &DimBinding) -> Graph {
147    let mut out = Graph::new(&graph.name);
148    for node in graph.nodes() {
149        let bound = node.shape.bind(bindings);
150        out.push_ext(
151            node.op.clone(),
152            node.inputs.clone(),
153            bound,
154            node.name.clone(),
155            node.origin.clone(),
156        );
157    }
158    out.set_outputs(graph.outputs.clone());
159    out
160}
161
162/// After [`bind_graph`], sync `Op::Reshape { new_shape }` with bound node shapes.
163pub fn sync_reshape_ops(graph: &mut Graph) {
164    use crate::Op;
165    for node in graph.nodes_mut() {
166        if let Op::Reshape { new_shape } = &mut node.op {
167            if node.shape.is_static() {
168                *new_shape = node
169                    .shape
170                    .dims()
171                    .iter()
172                    .map(|d| d.unwrap_static() as i64)
173                    .collect();
174            }
175        }
176    }
177}
178
179/// Recompute all inferrable output shapes after binding (propagates concat fixes).
180pub fn sync_graph_shapes(graph: &mut Graph) {
181    let nodes = graph.nodes().to_vec();
182    for node in &nodes {
183        if let Some(shape) = crate::infer_shape::infer_output_shape(graph, node) {
184            graph.node_mut(node.id).shape = shape;
185        }
186    }
187}
188
189/// Recompute `Op::Concat` output shapes from bound inputs (fixes mixed static+dynamic axes).
190pub fn sync_concat_shapes(graph: &mut Graph) {
191    use crate::Op;
192    let nodes = graph.nodes().to_vec();
193    for node in &nodes {
194        let Op::Concat { axis } = &node.op else {
195            continue;
196        };
197        let shapes: Vec<Shape> = node
198            .inputs
199            .iter()
200            .map(|&id| graph.node(id).shape.clone())
201            .collect();
202        let refs: Vec<&Shape> = shapes.iter().collect();
203        if let Ok(out) = crate::shape::concat_shape(&refs, *axis) {
204            graph.node_mut(node.id).shape = out;
205        }
206    }
207}
208
209/// Clamp `Op::Narrow` start indices after bind (template may bake in max_seq placeholders).
210pub fn sync_narrow_ops(graph: &mut Graph) {
211    use crate::Op;
212    let nodes = graph.nodes().to_vec();
213    for node in &nodes {
214        let Op::Narrow { axis, start, len } = &node.op else {
215            continue;
216        };
217        let in_shape = graph.node(node.inputs[0]).shape.clone();
218        if *axis >= in_shape.rank() || !in_shape.is_static() {
219            continue;
220        }
221        let ax_len = in_shape.dims()[*axis].unwrap_static();
222        if *start + *len > ax_len {
223            graph.node_mut(node.id).op = Op::Narrow {
224                axis: *axis,
225                start: ax_len.saturating_sub(*len),
226                len: *len,
227            };
228        }
229    }
230}
231
232/// After [`bind_graph`], sync `Op::Expand { target_shape }` with bound output shapes.
233pub fn sync_expand_ops(graph: &mut Graph) {
234    use crate::Op;
235    let nodes = graph.nodes().to_vec();
236    for node in &nodes {
237        let Op::Expand { .. } = &node.op else {
238            continue;
239        };
240        if !node.shape.is_static() {
241            continue;
242        }
243        let target: Vec<i64> = node
244            .shape
245            .dims()
246            .iter()
247            .map(|d| d.unwrap_static() as i64)
248            .collect();
249        graph.node_mut(node.id).op = Op::Expand {
250            target_shape: target,
251        };
252    }
253}
254
255/// Infer symbol sizes from runtime input element counts.
256///
257/// Each `Op::Input` may have at most one dynamic dimension; its size is
258/// `data_len / product(static_dims)`.
259pub fn infer_bindings_from_inputs(
260    graph: &Graph,
261    inputs: &[(&str, usize)],
262) -> Result<DimBinding, String> {
263    let by_name: HashMap<&str, usize> = inputs.iter().copied().collect();
264    let mut binding = DimBinding::new();
265    for node in graph.nodes() {
266        let Op::Input { name } = &node.op else {
267            continue;
268        };
269        let Some(&n_elems) = by_name.get(name.as_str()) else {
270            continue;
271        };
272        let mut static_prod: usize = 1;
273        let mut dynamic_sym: Option<u32> = None;
274        for d in node.shape.dims() {
275            match d {
276                Dim::Static(n) => static_prod *= *n,
277                Dim::Dynamic(sym) => {
278                    if dynamic_sym.is_some() {
279                        return Err(format!(
280                            "Input '{name}' has multiple dynamic dims; \
281                             pass an explicit DimBinding"
282                        ));
283                    }
284                    dynamic_sym = Some(*sym);
285                }
286            }
287        }
288        let Some(sym) = dynamic_sym else {
289            continue;
290        };
291        if static_prod == 0 {
292            return Err(format!("Input '{name}': static dim product is zero"));
293        }
294        if n_elems % static_prod != 0 {
295            return Err(format!(
296                "Input '{name}': len {n_elems} not divisible by static product {static_prod}"
297            ));
298        }
299        let size = n_elems / static_prod;
300        if let Some(prev) = binding.get(sym) {
301            if prev != size {
302                return Err(format!(
303                    "symbol {sym} bound to {prev} and {size} from different inputs"
304                ));
305            }
306        } else {
307            binding.set(sym, size);
308        }
309    }
310    complete_im2col_row_bindings(graph, &mut binding);
311    Ok(binding)
312}
313
314/// When `sym::BATCH` is bound on NCHW inputs feeding `Op::Im2Col`, derive
315/// `sym::ROWS = batch · H_out · W_out` for im2col/matmul reshape consumers.
316pub fn complete_im2col_row_bindings(graph: &Graph, binding: &mut DimBinding) {
317    let Some(batch) = binding.get(sym::BATCH) else {
318        return;
319    };
320    if binding.get(sym::ROWS).is_some() {
321        return;
322    }
323    for node in graph.nodes() {
324        let Op::Im2Col {
325            kernel_size,
326            stride,
327            padding,
328            dilation,
329        } = &node.op
330        else {
331            continue;
332        };
333        let x_shape = &graph.node(node.inputs[0]).shape;
334        if x_shape.rank() != 4 {
335            continue;
336        }
337        if !x_shape.dim(2).is_static() || !x_shape.dim(3).is_static() {
338            continue;
339        }
340        let h = x_shape.dim(2).unwrap_static();
341        let w = x_shape.dim(3).unwrap_static();
342        let kh = kernel_size.first().copied().unwrap_or(1);
343        let kw = kernel_size.get(1).copied().unwrap_or(1);
344        let sh = stride.first().copied().unwrap_or(1);
345        let sw = stride.get(1).copied().unwrap_or(1);
346        let ph = padding.first().copied().unwrap_or(0);
347        let pw = padding.get(1).copied().unwrap_or(0);
348        let dh = dilation.first().copied().unwrap_or(1);
349        let dw = dilation.get(1).copied().unwrap_or(1);
350        let h_out = crate::shape::conv2d_spatial_output(h, kh, sh, ph, dh);
351        let w_out = crate::shape::conv2d_spatial_output(w, kw, sw, pw, dw);
352        binding.set(sym::ROWS, batch * h_out * w_out);
353        return;
354    }
355}
356
357/// Infer bindings from f32 slice lengths (convenience for tests/runtime).
358pub fn infer_bindings_from_f32_inputs(
359    graph: &Graph,
360    inputs: &[(&str, &[f32])],
361) -> Result<DimBinding, String> {
362    infer_bindings_from_inputs(
363        graph,
364        &inputs
365            .iter()
366            .map(|(n, d)| (*n, d.len()))
367            .collect::<Vec<_>>(),
368    )
369}
370
371pub fn same_binding(a: &DimBinding, b: &DimBinding) -> bool {
372    if a.len() != b.len() {
373        return false;
374    }
375    a.iter().all(|(sym, size)| b.get(sym) == Some(size))
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infer::GraphExt;

    #[test]
    fn bind_graph_specializes_matmul() {
        let batch = sym::BATCH;
        let seq = sym::SEQ;
        let mut g = Graph::new("dyn");
        let x = g.input("x", Shape::batch_seq(batch, seq, 4, DType::F32));
        let w = g.param("w", Shape::new(&[4, 8], DType::F32));
        let y = g.mm(x, w);
        g.set_outputs(vec![y]);

        assert!(has_dynamic_dims(&g));
        let binding = DimBinding::batch_seq(2, 16);
        let bound = bind_graph(&g, &binding);
        assert!(!has_dynamic_dims(&bound));
        assert_eq!(
            bound.node(bound.outputs[0]).shape,
            Shape::new(&[2, 16, 8], DType::F32)
        );
    }

    #[test]
    fn infer_bindings_from_input_data() {
        let mut g = Graph::new("dyn");
        let x = g.input(
            "x",
            Shape::from_dims(
                &[Dim::Static(3), Dim::Dynamic(sym::SEQ), Dim::Static(64)],
                DType::F32,
            ),
        );
        g.set_outputs(vec![x]);

        let b = infer_bindings_from_f32_inputs(&g, &[("x", &vec![0.0f32; 3 * 128 * 64])])
            .expect("infer");
        assert_eq!(b.get(sym::SEQ), Some(128));
    }

    #[test]
    fn infer_bindings_sets_im2col_rows_from_batch() {
        let mut g = Graph::new("im2col_rows");
        let x = g.input(
            "x",
            Shape::from_dims(
                &[
                    Dim::Dynamic(sym::BATCH),
                    Dim::Static(1),
                    Dim::Static(4),
                    Dim::Static(4),
                ],
                DType::F32,
            ),
        );
        let _col = g.im2col(x, [3, 3], [1, 1], [1, 1], [1, 1]);
        g.set_outputs(vec![x]);
        let b = infer_bindings_from_f32_inputs(&g, &[("x", &[0.0f32; 2 * 16])]).expect("infer");
        assert_eq!(b.get(sym::BATCH), Some(2));
        assert_eq!(b.get(sym::ROWS), Some(2 * 4 * 4));
    }

    #[test]
    fn infer_bindings_rejects_ambiguous_or_indivisible_inputs() {
        // Two dynamic axes on one input cannot be solved from a single length.
        let mut g = Graph::new("two_dyn");
        let x = g.input("x", Shape::batch_seq(sym::BATCH, sym::SEQ, 4, DType::F32));
        g.set_outputs(vec![x]);
        assert!(infer_bindings_from_f32_inputs(&g, &[("x", &[0.0f32; 8])]).is_err());

        // 10 elements cannot be split into rows of 4.
        let mut g = Graph::new("indivisible");
        let x = g.input(
            "x",
            Shape::from_dims(&[Dim::Dynamic(sym::SEQ), Dim::Static(4)], DType::F32),
        );
        g.set_outputs(vec![x]);
        assert!(infer_bindings_from_f32_inputs(&g, &[("x", &[0.0f32; 10])]).is_err());
        let ok = infer_bindings_from_f32_inputs(&g, &[("x", &[0.0f32; 12])]).expect("infer");
        assert_eq!(ok.get(sym::SEQ), Some(3));
    }

    #[test]
    fn collect_dynamic_symbols_is_sorted_and_deduped() {
        let mut g = Graph::new("syms");
        let x = g.input("x", Shape::batch_seq(sym::BATCH, sym::SEQ, 4, DType::F32));
        let y = g.input("y", Shape::batch_seq_2d(sym::BATCH, sym::SEQ, DType::F32));
        g.set_outputs(vec![x, y]);
        assert_eq!(collect_dynamic_symbols(&g), vec![sym::BATCH, sym::SEQ]);

        let mut s = Graph::new("static");
        let z = s.input("z", Shape::new(&[2, 3], DType::F32));
        s.set_outputs(vec![z]);
        assert!(!has_dynamic_dims(&s));
        assert!(collect_dynamic_symbols(&s).is_empty());
    }

    #[test]
    fn dim_env_reuses_ids_per_name() {
        let mut env = DimEnv::new();
        let batch = env.sym("batch");
        let seq = env.sym("seq");
        assert_ne!(batch, seq);
        assert_eq!(env.sym("batch"), batch);
        assert_eq!(env.name(seq), Some("seq"));
        assert_eq!(env.name(seq + 100), None);
    }

    #[test]
    fn same_binding_compares_all_entries() {
        let a = DimBinding::batch_seq(2, 16);
        let b = DimBinding::from_pairs(&[(sym::BATCH, 2), (sym::SEQ, 16), (sym::ROWS, 32)]);
        assert!(same_binding(&a, &b));
        assert!(!same_binding(&a, &DimBinding::batch_seq(2, 8)));
        let fewer = DimBinding::from_pairs(&[(sym::BATCH, 2), (sym::SEQ, 16)]);
        assert!(!same_binding(&a, &fewer));
    }
}