Skip to main content

rlx_ir/
dynamic.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Dynamic / symbolic dimensions — compile once, specialize at runtime.
//!
//! Plan #54: graphs built with [`Dim::Dynamic`] symbols are specialized via a
//! [`DimBinding`] before buffer planning and backend lowering.
20
21use std::collections::{BTreeSet, HashMap};
22
23use crate::shape::{Dim, DimBinding, Shape};
24use crate::{DType, Graph, Op};
25
/// Well-known dynamic dimension symbols shared by model builders.
///
/// Reuse the same id across shapes so `[?0, ?1, H]` and `[?0, ?1, 4]` share
/// batch/seq bindings.
pub mod sym {
    /// Batch symbol; `DimBinding::batch_seq` binds it to the batch size.
    pub const BATCH: u32 = 0;
    /// Sequence symbol; `DimBinding::batch_seq` binds it to the sequence length.
    pub const SEQ: u32 = 1;
    /// Product of leading axes (e.g. `batch * seq` flatten).
    pub const ROWS: u32 = 2;
    /// Cached prefix length for decode KV (past_k / past_v axis 1).
    pub const PAST_SEQ: u32 = 3;
}
36
/// Allocator for named dynamic dimension symbols used by model builders.
///
/// Each distinct name receives the next id, counting up from 0. A name keeps
/// that id for the life of the environment.
///
/// NOTE(review): ids start at 0, the same range as the well-known [`sym`]
/// constants (`sym::BATCH == 0`). Mixing `DimEnv` ids with `sym::*` ids in one
/// graph can therefore alias symbols — confirm callers use only one scheme.
#[derive(Debug, Clone, Default)]
pub struct DimEnv {
    /// Id handed out to the next previously-unseen name.
    next: u32,
    /// Name → allocated symbol id.
    names: HashMap<String, u32>,
}

impl DimEnv {
    /// Create an empty environment; the first allocated symbol is 0.
    pub fn new() -> Self {
        Self::default()
    }

    /// Return the symbol id for `name`, allocating a fresh id on first use.
    pub fn sym(&mut self, name: &str) -> u32 {
        match self.names.get(name) {
            Some(&existing) => existing,
            None => {
                let fresh = self.next;
                self.next += 1;
                self.names.insert(name.to_owned(), fresh);
                fresh
            }
        }
    }

    /// Reverse lookup: the name that was allocated `symbol`, if any.
    ///
    /// This scans every registered name, so its cost is linear in their number.
    pub fn name(&self, symbol: u32) -> Option<&str> {
        for (label, &id) in &self.names {
            if id == symbol {
                return Some(label.as_str());
            }
        }
        None
    }
}
66
67impl Shape {
68    /// `[batch, seq, hidden]` with symbolic leading axes.
69    pub fn batch_seq(batch: u32, seq: u32, hidden: usize, dtype: DType) -> Self {
70        Self::from_dims(
71            &[Dim::Dynamic(batch), Dim::Dynamic(seq), Dim::Static(hidden)],
72            dtype,
73        )
74    }
75
76    /// `[batch, seq]` matrix.
77    pub fn batch_seq_2d(batch: u32, seq: u32, dtype: DType) -> Self {
78        Self::from_dims(&[Dim::Dynamic(batch), Dim::Dynamic(seq)], dtype)
79    }
80
81    /// `[batch, seq, heads, head_dim]` attention layout.
82    pub fn batch_seq_heads(
83        batch: u32,
84        seq: u32,
85        heads: usize,
86        head_dim: usize,
87        dtype: DType,
88    ) -> Self {
89        Self::from_dims(
90            &[
91                Dim::Dynamic(batch),
92                Dim::Dynamic(seq),
93                Dim::Static(heads),
94                Dim::Static(head_dim),
95            ],
96            dtype,
97        )
98    }
99}
100
101impl DimBinding {
102    pub fn from_pairs(pairs: &[(u32, usize)]) -> Self {
103        let mut b = Self::new();
104        for &(sym, size) in pairs {
105            b.set(sym, size);
106        }
107        b
108    }
109
110    pub fn batch_seq(batch: usize, seq: usize) -> Self {
111        let mut b = Self::from_pairs(&[(sym::BATCH, batch), (sym::SEQ, seq)]);
112        if batch > 1 {
113            b.set(sym::ROWS, batch * seq);
114        }
115        b
116    }
117
118    pub fn batch_past_seq(batch: usize, past_seq: usize) -> Self {
119        Self::from_pairs(&[(sym::BATCH, batch), (sym::PAST_SEQ, past_seq)])
120    }
121}
122
123/// True if any node shape references a dynamic dimension.
124pub fn has_dynamic_dims(graph: &Graph) -> bool {
125    graph
126        .nodes()
127        .iter()
128        .any(|n| n.shape.dims().iter().any(|d| matches!(d, Dim::Dynamic(_))))
129}
130
131/// Collect all dynamic symbols referenced anywhere in the graph.
132pub fn collect_dynamic_symbols(graph: &Graph) -> Vec<u32> {
133    let mut syms = BTreeSet::new();
134    for node in graph.nodes() {
135        for s in node.shape.dynamic_symbols() {
136            syms.insert(s);
137        }
138    }
139    syms.into_iter().collect()
140}
141
142/// Specialize every node's shape against `bindings`.
143///
144/// Node ids are preserved (nodes are cloned in insertion order), so
145/// edges and outputs remain valid without remapping.
146pub fn bind_graph(graph: &Graph, bindings: &DimBinding) -> Graph {
147    let mut out = Graph::new(&graph.name);
148    for node in graph.nodes() {
149        let bound = node.shape.bind(bindings);
150        out.push_ext(
151            node.op.clone(),
152            node.inputs.clone(),
153            bound,
154            node.name.clone(),
155            node.origin.clone(),
156        );
157    }
158    out.set_outputs(graph.outputs.clone());
159    out
160}
161
162/// After [`bind_graph`], sync `Op::Reshape { new_shape }` with bound node shapes.
163pub fn sync_reshape_ops(graph: &mut Graph) {
164    use crate::Op;
165    for node in graph.nodes_mut() {
166        if let Op::Reshape { new_shape } = &mut node.op {
167            if node.shape.is_static() {
168                *new_shape = node
169                    .shape
170                    .dims()
171                    .iter()
172                    .map(|d| d.unwrap_static() as i64)
173                    .collect();
174            }
175        }
176    }
177}
178
179/// Recompute all inferrable output shapes after binding (propagates concat fixes).
180pub fn sync_graph_shapes(graph: &mut Graph) {
181    let nodes = graph.nodes().to_vec();
182    for node in &nodes {
183        if let Some(shape) = crate::infer_shape::infer_output_shape(graph, node) {
184            graph.node_mut(node.id).shape = shape;
185        }
186    }
187}
188
189/// Recompute `Op::Concat` output shapes from bound inputs (fixes mixed static+dynamic axes).
190pub fn sync_concat_shapes(graph: &mut Graph) {
191    use crate::Op;
192    let nodes = graph.nodes().to_vec();
193    for node in &nodes {
194        let Op::Concat { axis } = &node.op else {
195            continue;
196        };
197        let shapes: Vec<Shape> = node
198            .inputs
199            .iter()
200            .map(|&id| graph.node(id).shape.clone())
201            .collect();
202        let refs: Vec<&Shape> = shapes.iter().collect();
203        if let Ok(out) = crate::shape::concat_shape(&refs, *axis) {
204            graph.node_mut(node.id).shape = out;
205        }
206    }
207}
208
209/// Clamp `Op::Narrow` start indices after bind (template may bake in max_seq placeholders).
210pub fn sync_narrow_ops(graph: &mut Graph) {
211    use crate::Op;
212    let nodes = graph.nodes().to_vec();
213    for node in &nodes {
214        let Op::Narrow { axis, start, len } = &node.op else {
215            continue;
216        };
217        let in_shape = graph.node(node.inputs[0]).shape.clone();
218        if *axis >= in_shape.rank() || !in_shape.is_static() {
219            continue;
220        }
221        let ax_len = in_shape.dims()[*axis].unwrap_static();
222        if *start + *len > ax_len {
223            graph.node_mut(node.id).op = Op::Narrow {
224                axis: *axis,
225                start: ax_len.saturating_sub(*len),
226                len: *len,
227            };
228        }
229    }
230}
231
232/// Infer symbol sizes from runtime input element counts.
233///
234/// Each `Op::Input` may have at most one dynamic dimension; its size is
235/// `data_len / product(static_dims)`.
236pub fn infer_bindings_from_inputs(
237    graph: &Graph,
238    inputs: &[(&str, usize)],
239) -> Result<DimBinding, String> {
240    let by_name: HashMap<&str, usize> = inputs.iter().copied().collect();
241    let mut binding = DimBinding::new();
242    for node in graph.nodes() {
243        let Op::Input { name } = &node.op else {
244            continue;
245        };
246        let Some(&n_elems) = by_name.get(name.as_str()) else {
247            continue;
248        };
249        let mut static_prod: usize = 1;
250        let mut dynamic_sym: Option<u32> = None;
251        for d in node.shape.dims() {
252            match d {
253                Dim::Static(n) => static_prod *= *n,
254                Dim::Dynamic(sym) => {
255                    if dynamic_sym.is_some() {
256                        return Err(format!(
257                            "Input '{name}' has multiple dynamic dims; \
258                             pass an explicit DimBinding"
259                        ));
260                    }
261                    dynamic_sym = Some(*sym);
262                }
263            }
264        }
265        let Some(sym) = dynamic_sym else {
266            continue;
267        };
268        if static_prod == 0 {
269            return Err(format!("Input '{name}': static dim product is zero"));
270        }
271        if n_elems % static_prod != 0 {
272            return Err(format!(
273                "Input '{name}': len {n_elems} not divisible by static product {static_prod}"
274            ));
275        }
276        let size = n_elems / static_prod;
277        if let Some(prev) = binding.get(sym) {
278            if prev != size {
279                return Err(format!(
280                    "symbol {sym} bound to {prev} and {size} from different inputs"
281                ));
282            }
283        } else {
284            binding.set(sym, size);
285        }
286    }
287    Ok(binding)
288}
289
290/// Infer bindings from f32 slice lengths (convenience for tests/runtime).
291pub fn infer_bindings_from_f32_inputs(
292    graph: &Graph,
293    inputs: &[(&str, &[f32])],
294) -> Result<DimBinding, String> {
295    infer_bindings_from_inputs(
296        graph,
297        &inputs
298            .iter()
299            .map(|(n, d)| (*n, d.len()))
300            .collect::<Vec<_>>(),
301    )
302}
303
304pub fn same_binding(a: &DimBinding, b: &DimBinding) -> bool {
305    if a.len() != b.len() {
306        return false;
307    }
308    a.iter().all(|(sym, size)| b.get(sym) == Some(size))
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infer::GraphExt;

    /// A matmul over `[?batch, ?seq, 4] x [4, 8]` becomes fully static once bound.
    #[test]
    fn bind_graph_specializes_matmul() {
        let mut graph = Graph::new("dyn");
        let activations = graph.input("x", Shape::batch_seq(sym::BATCH, sym::SEQ, 4, DType::F32));
        let weights = graph.param("w", Shape::new(&[4, 8], DType::F32));
        let product = graph.mm(activations, weights);
        graph.set_outputs(vec![product]);

        assert!(has_dynamic_dims(&graph));

        let specialized = bind_graph(&graph, &DimBinding::batch_seq(2, 16));
        assert!(!has_dynamic_dims(&specialized));

        let output_id = specialized.outputs[0];
        assert_eq!(
            specialized.node(output_id).shape,
            Shape::new(&[2, 16, 8], DType::F32)
        );
    }

    /// A single dynamic axis is recovered from the flat element count.
    #[test]
    fn infer_bindings_from_input_data() {
        let mut graph = Graph::new("dyn");
        let dims = [Dim::Static(3), Dim::Dynamic(sym::SEQ), Dim::Static(64)];
        let input = graph.input("x", Shape::from_dims(&dims, DType::F32));
        graph.set_outputs(vec![input]);

        let data = vec![0.0f32; 3 * 128 * 64];
        let binding =
            infer_bindings_from_f32_inputs(&graph, &[("x", data.as_slice())]).expect("infer");
        assert_eq!(binding.get(sym::SEQ), Some(128));
    }
}