1use std::collections::{BTreeSet, HashMap};
22
23use crate::shape::{Dim, DimBinding, Shape};
24use crate::{DType, Graph, Op};
25
/// Well-known dynamic dimension symbols shared by graph builders and
/// [`DimBinding`] helpers in this module.
pub mod sym {
    /// Batch dimension.
    pub const BATCH: u32 = 0;
    /// Sequence dimension of the current step.
    pub const SEQ: u32 = 1;
    /// Flattened row count; `DimBinding::batch_seq` binds it to `batch * seq`
    /// (only when `batch > 1`).
    pub const ROWS: u32 = 2;
    /// Past sequence length, bound by `DimBinding::batch_past_seq`.
    pub const PAST_SEQ: u32 = 3;
}
36
/// Allocates stable `u32` symbols for named dynamic dimensions.
///
/// Each previously unseen name receives the next id, counting up from 0.
/// Asking for a name again returns the id it was first given.
///
/// NOTE(review): ids start at 0, the same range used by the fixed constants in
/// `sym` (`BATCH` = 0, `SEQ` = 1, ...). Symbols from a `DimEnv` and those
/// constants would alias if mixed in one graph — confirm this is intended.
#[derive(Debug, Clone, Default)]
pub struct DimEnv {
    // Id that the next unseen name will receive.
    next: u32,
    // Name -> symbol for every name allocated so far.
    names: HashMap<String, u32>,
}

impl DimEnv {
    /// Creates an empty environment; the first allocated symbol is 0.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the symbol for `name`, allocating a fresh one on first use.
    pub fn sym(&mut self, name: &str) -> u32 {
        match self.names.get(name).copied() {
            Some(existing) => existing,
            None => {
                let fresh = self.next;
                self.next += 1;
                self.names.insert(name.to_owned(), fresh);
                fresh
            }
        }
    }

    /// Looks up the name that was allocated `symbol`, if any.
    ///
    /// This is a linear scan over all names. Ids are never reused by
    /// [`DimEnv::sym`], so at most one name matches.
    pub fn name(&self, symbol: u32) -> Option<&str> {
        self.names
            .iter()
            .find(|&(_, &id)| id == symbol)
            .map(|(label, _)| label.as_str())
    }
}
66
67impl Shape {
68 pub fn batch_seq(batch: u32, seq: u32, hidden: usize, dtype: DType) -> Self {
70 Self::from_dims(
71 &[Dim::Dynamic(batch), Dim::Dynamic(seq), Dim::Static(hidden)],
72 dtype,
73 )
74 }
75
76 pub fn batch_seq_2d(batch: u32, seq: u32, dtype: DType) -> Self {
78 Self::from_dims(&[Dim::Dynamic(batch), Dim::Dynamic(seq)], dtype)
79 }
80
81 pub fn batch_seq_heads(
83 batch: u32,
84 seq: u32,
85 heads: usize,
86 head_dim: usize,
87 dtype: DType,
88 ) -> Self {
89 Self::from_dims(
90 &[
91 Dim::Dynamic(batch),
92 Dim::Dynamic(seq),
93 Dim::Static(heads),
94 Dim::Static(head_dim),
95 ],
96 dtype,
97 )
98 }
99}
100
101impl DimBinding {
102 pub fn from_pairs(pairs: &[(u32, usize)]) -> Self {
103 let mut b = Self::new();
104 for &(sym, size) in pairs {
105 b.set(sym, size);
106 }
107 b
108 }
109
110 pub fn batch_seq(batch: usize, seq: usize) -> Self {
111 let mut b = Self::from_pairs(&[(sym::BATCH, batch), (sym::SEQ, seq)]);
112 if batch > 1 {
113 b.set(sym::ROWS, batch * seq);
114 }
115 b
116 }
117
118 pub fn batch_past_seq(batch: usize, past_seq: usize) -> Self {
119 Self::from_pairs(&[(sym::BATCH, batch), (sym::PAST_SEQ, past_seq)])
120 }
121}
122
123pub fn has_dynamic_dims(graph: &Graph) -> bool {
125 graph
126 .nodes()
127 .iter()
128 .any(|n| n.shape.dims().iter().any(|d| matches!(d, Dim::Dynamic(_))))
129}
130
131pub fn collect_dynamic_symbols(graph: &Graph) -> Vec<u32> {
133 let mut syms = BTreeSet::new();
134 for node in graph.nodes() {
135 for s in node.shape.dynamic_symbols() {
136 syms.insert(s);
137 }
138 }
139 syms.into_iter().collect()
140}
141
142pub fn bind_graph(graph: &Graph, bindings: &DimBinding) -> Graph {
147 let mut out = Graph::new(&graph.name);
148 for node in graph.nodes() {
149 let bound = node.shape.bind(bindings);
150 out.push_ext(
151 node.op.clone(),
152 node.inputs.clone(),
153 bound,
154 node.name.clone(),
155 node.origin.clone(),
156 );
157 }
158 out.set_outputs(graph.outputs.clone());
159 out
160}
161
162pub fn sync_reshape_ops(graph: &mut Graph) {
164 use crate::Op;
165 for node in graph.nodes_mut() {
166 if let Op::Reshape { new_shape } = &mut node.op {
167 if node.shape.is_static() {
168 *new_shape = node
169 .shape
170 .dims()
171 .iter()
172 .map(|d| d.unwrap_static() as i64)
173 .collect();
174 }
175 }
176 }
177}
178
179pub fn sync_graph_shapes(graph: &mut Graph) {
181 let nodes = graph.nodes().to_vec();
182 for node in &nodes {
183 if let Some(shape) = crate::infer_shape::infer_output_shape(graph, node) {
184 graph.node_mut(node.id).shape = shape;
185 }
186 }
187}
188
189pub fn sync_concat_shapes(graph: &mut Graph) {
191 use crate::Op;
192 let nodes = graph.nodes().to_vec();
193 for node in &nodes {
194 let Op::Concat { axis } = &node.op else {
195 continue;
196 };
197 let shapes: Vec<Shape> = node
198 .inputs
199 .iter()
200 .map(|&id| graph.node(id).shape.clone())
201 .collect();
202 let refs: Vec<&Shape> = shapes.iter().collect();
203 if let Ok(out) = crate::shape::concat_shape(&refs, *axis) {
204 graph.node_mut(node.id).shape = out;
205 }
206 }
207}
208
209pub fn sync_narrow_ops(graph: &mut Graph) {
211 use crate::Op;
212 let nodes = graph.nodes().to_vec();
213 for node in &nodes {
214 let Op::Narrow { axis, start, len } = &node.op else {
215 continue;
216 };
217 let in_shape = graph.node(node.inputs[0]).shape.clone();
218 if *axis >= in_shape.rank() || !in_shape.is_static() {
219 continue;
220 }
221 let ax_len = in_shape.dims()[*axis].unwrap_static();
222 if *start + *len > ax_len {
223 graph.node_mut(node.id).op = Op::Narrow {
224 axis: *axis,
225 start: ax_len.saturating_sub(*len),
226 len: *len,
227 };
228 }
229 }
230}
231
232pub fn infer_bindings_from_inputs(
237 graph: &Graph,
238 inputs: &[(&str, usize)],
239) -> Result<DimBinding, String> {
240 let by_name: HashMap<&str, usize> = inputs.iter().copied().collect();
241 let mut binding = DimBinding::new();
242 for node in graph.nodes() {
243 let Op::Input { name } = &node.op else {
244 continue;
245 };
246 let Some(&n_elems) = by_name.get(name.as_str()) else {
247 continue;
248 };
249 let mut static_prod: usize = 1;
250 let mut dynamic_sym: Option<u32> = None;
251 for d in node.shape.dims() {
252 match d {
253 Dim::Static(n) => static_prod *= *n,
254 Dim::Dynamic(sym) => {
255 if dynamic_sym.is_some() {
256 return Err(format!(
257 "Input '{name}' has multiple dynamic dims; \
258 pass an explicit DimBinding"
259 ));
260 }
261 dynamic_sym = Some(*sym);
262 }
263 }
264 }
265 let Some(sym) = dynamic_sym else {
266 continue;
267 };
268 if static_prod == 0 {
269 return Err(format!("Input '{name}': static dim product is zero"));
270 }
271 if n_elems % static_prod != 0 {
272 return Err(format!(
273 "Input '{name}': len {n_elems} not divisible by static product {static_prod}"
274 ));
275 }
276 let size = n_elems / static_prod;
277 if let Some(prev) = binding.get(sym) {
278 if prev != size {
279 return Err(format!(
280 "symbol {sym} bound to {prev} and {size} from different inputs"
281 ));
282 }
283 } else {
284 binding.set(sym, size);
285 }
286 }
287 Ok(binding)
288}
289
290pub fn infer_bindings_from_f32_inputs(
292 graph: &Graph,
293 inputs: &[(&str, &[f32])],
294) -> Result<DimBinding, String> {
295 infer_bindings_from_inputs(
296 graph,
297 &inputs
298 .iter()
299 .map(|(n, d)| (*n, d.len()))
300 .collect::<Vec<_>>(),
301 )
302}
303
304pub fn same_binding(a: &DimBinding, b: &DimBinding) -> bool {
305 if a.len() != b.len() {
306 return false;
307 }
308 a.iter().all(|(sym, size)| b.get(sym) == Some(size))
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infer::GraphExt;

    /// Binding batch/seq on a `[B, S, 4] x [4, 8]` matmul should yield a fully
    /// static `[2, 16, 8]` output.
    #[test]
    fn bind_graph_specializes_matmul() {
        let mut graph = Graph::new("dyn");
        let activations = graph.input("x", Shape::batch_seq(sym::BATCH, sym::SEQ, 4, DType::F32));
        let weight = graph.param("w", Shape::new(&[4, 8], DType::F32));
        let product = graph.mm(activations, weight);
        graph.set_outputs(vec![product]);
        assert!(has_dynamic_dims(&graph), "input should carry symbolic dims");

        let specialized = bind_graph(&graph, &DimBinding::batch_seq(2, 16));
        assert!(!has_dynamic_dims(&specialized));

        let out_id = specialized.outputs[0];
        assert_eq!(
            specialized.node(out_id).shape,
            Shape::new(&[2, 16, 8], DType::F32)
        );
    }

    /// A `[3, SEQ, 64]` input with `3 * 128 * 64` elements implies `SEQ = 128`.
    #[test]
    fn infer_bindings_from_input_data() {
        let mut graph = Graph::new("dyn");
        let declared = Shape::from_dims(
            &[Dim::Static(3), Dim::Dynamic(sym::SEQ), Dim::Static(64)],
            DType::F32,
        );
        let x = graph.input("x", declared);
        graph.set_outputs(vec![x]);

        let data = vec![0.0f32; 3 * 128 * 64];
        let binding =
            infer_bindings_from_f32_inputs(&graph, &[("x", data.as_slice())]).expect("infer");
        assert_eq!(binding.get(sym::SEQ), Some(128));
    }
}