1use std::collections::{BTreeSet, HashMap};
22
23use crate::shape::{Dim, DimBinding, Shape};
24use crate::{DType, Graph, Op};
25
/// Fixed symbol ids for the dynamic dimensions this module knows about.
///
/// These are the ids used by the `DimBinding` helpers below and by
/// `complete_im2col_row_bindings`.
pub mod sym {
    /// Batch dimension.
    pub const BATCH: u32 = 0;
    /// Sequence-length dimension.
    pub const SEQ: u32 = 1;
    /// Row count. `DimBinding::batch_seq` binds it to `batch * seq` (only when
    /// `batch > 1`), and `complete_im2col_row_bindings` binds it to
    /// `batch * h_out * w_out` of the first eligible `Im2Col`.
    pub const ROWS: u32 = 2;
    /// Past sequence length, bound by `DimBinding::batch_past_seq`.
    pub const PAST_SEQ: u32 = 3;
}
36
/// Interns dimension names into `u32` symbol ids.
///
/// Ids are handed out sequentially, starting at 0, in the order names are first
/// requested; asking for the same name again returns the same id.
///
/// NOTE(review): ids start at 0, i.e. in the same numeric space as the fixed ids in
/// `sym` (`sym::BATCH == 0`, ...). Confirm the two are never mixed within one graph.
#[derive(Debug, Clone, Default)]
pub struct DimEnv {
    /// Id that the next previously-unseen name will receive.
    next: u32,
    /// Name -> symbol id.
    names: HashMap<String, u32>,
}

impl DimEnv {
    /// Creates an empty environment (first allocated id is 0).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the id for `name`, allocating the next sequential id if the name is new.
    pub fn sym(&mut self, name: &str) -> u32 {
        match self.names.get(name).copied() {
            Some(existing) => existing,
            None => {
                let fresh = self.next;
                self.next = fresh + 1;
                self.names.insert(name.to_owned(), fresh);
                fresh
            }
        }
    }

    /// Reverse lookup: the name that was interned as `symbol`, if any.
    ///
    /// Linear in the number of interned names.
    pub fn name(&self, symbol: u32) -> Option<&str> {
        self.names
            .iter()
            .find(|&(_, &id)| id == symbol)
            .map(|(label, _)| label.as_str())
    }
}
66
67impl Shape {
68 pub fn batch_seq(batch: u32, seq: u32, hidden: usize, dtype: DType) -> Self {
70 Self::from_dims(
71 &[Dim::Dynamic(batch), Dim::Dynamic(seq), Dim::Static(hidden)],
72 dtype,
73 )
74 }
75
76 pub fn batch_seq_2d(batch: u32, seq: u32, dtype: DType) -> Self {
78 Self::from_dims(&[Dim::Dynamic(batch), Dim::Dynamic(seq)], dtype)
79 }
80
81 pub fn batch_seq_heads(
83 batch: u32,
84 seq: u32,
85 heads: usize,
86 head_dim: usize,
87 dtype: DType,
88 ) -> Self {
89 Self::from_dims(
90 &[
91 Dim::Dynamic(batch),
92 Dim::Dynamic(seq),
93 Dim::Static(heads),
94 Dim::Static(head_dim),
95 ],
96 dtype,
97 )
98 }
99}
100
101impl DimBinding {
102 pub fn from_pairs(pairs: &[(u32, usize)]) -> Self {
103 let mut b = Self::new();
104 for &(sym, size) in pairs {
105 b.set(sym, size);
106 }
107 b
108 }
109
110 pub fn batch_seq(batch: usize, seq: usize) -> Self {
111 let mut b = Self::from_pairs(&[(sym::BATCH, batch), (sym::SEQ, seq)]);
112 if batch > 1 {
113 b.set(sym::ROWS, batch * seq);
114 }
115 b
116 }
117
118 pub fn batch_past_seq(batch: usize, past_seq: usize) -> Self {
119 Self::from_pairs(&[(sym::BATCH, batch), (sym::PAST_SEQ, past_seq)])
120 }
121}
122
123pub fn has_dynamic_dims(graph: &Graph) -> bool {
125 graph
126 .nodes()
127 .iter()
128 .any(|n| n.shape.dims().iter().any(|d| matches!(d, Dim::Dynamic(_))))
129}
130
131pub fn collect_dynamic_symbols(graph: &Graph) -> Vec<u32> {
133 let mut syms = BTreeSet::new();
134 for node in graph.nodes() {
135 for s in node.shape.dynamic_symbols() {
136 syms.insert(s);
137 }
138 }
139 syms.into_iter().collect()
140}
141
142pub fn bind_graph(graph: &Graph, bindings: &DimBinding) -> Graph {
147 let mut out = Graph::new(&graph.name);
148 for node in graph.nodes() {
149 let bound = node.shape.bind(bindings);
150 out.push_ext(
151 node.op.clone(),
152 node.inputs.clone(),
153 bound,
154 node.name.clone(),
155 node.origin.clone(),
156 );
157 }
158 out.set_outputs(graph.outputs.clone());
159 out
160}
161
162pub fn sync_reshape_ops(graph: &mut Graph) {
164 use crate::Op;
165 for node in graph.nodes_mut() {
166 if let Op::Reshape { new_shape } = &mut node.op {
167 if node.shape.is_static() {
168 *new_shape = node
169 .shape
170 .dims()
171 .iter()
172 .map(|d| d.unwrap_static() as i64)
173 .collect();
174 }
175 }
176 }
177}
178
179pub fn sync_graph_shapes(graph: &mut Graph) {
181 let nodes = graph.nodes().to_vec();
182 for node in &nodes {
183 if let Some(shape) = crate::infer_shape::infer_output_shape(graph, node) {
184 graph.node_mut(node.id).shape = shape;
185 }
186 }
187}
188
189pub fn sync_concat_shapes(graph: &mut Graph) {
191 use crate::Op;
192 let nodes = graph.nodes().to_vec();
193 for node in &nodes {
194 let Op::Concat { axis } = &node.op else {
195 continue;
196 };
197 let shapes: Vec<Shape> = node
198 .inputs
199 .iter()
200 .map(|&id| graph.node(id).shape.clone())
201 .collect();
202 let refs: Vec<&Shape> = shapes.iter().collect();
203 if let Ok(out) = crate::shape::concat_shape(&refs, *axis) {
204 graph.node_mut(node.id).shape = out;
205 }
206 }
207}
208
209pub fn sync_narrow_ops(graph: &mut Graph) {
211 use crate::Op;
212 let nodes = graph.nodes().to_vec();
213 for node in &nodes {
214 let Op::Narrow { axis, start, len } = &node.op else {
215 continue;
216 };
217 let in_shape = graph.node(node.inputs[0]).shape.clone();
218 if *axis >= in_shape.rank() || !in_shape.is_static() {
219 continue;
220 }
221 let ax_len = in_shape.dims()[*axis].unwrap_static();
222 if *start + *len > ax_len {
223 graph.node_mut(node.id).op = Op::Narrow {
224 axis: *axis,
225 start: ax_len.saturating_sub(*len),
226 len: *len,
227 };
228 }
229 }
230}
231
232pub fn sync_expand_ops(graph: &mut Graph) {
234 use crate::Op;
235 let nodes = graph.nodes().to_vec();
236 for node in &nodes {
237 let Op::Expand { .. } = &node.op else {
238 continue;
239 };
240 if !node.shape.is_static() {
241 continue;
242 }
243 let target: Vec<i64> = node
244 .shape
245 .dims()
246 .iter()
247 .map(|d| d.unwrap_static() as i64)
248 .collect();
249 graph.node_mut(node.id).op = Op::Expand {
250 target_shape: target,
251 };
252 }
253}
254
255pub fn infer_bindings_from_inputs(
260 graph: &Graph,
261 inputs: &[(&str, usize)],
262) -> Result<DimBinding, String> {
263 let by_name: HashMap<&str, usize> = inputs.iter().copied().collect();
264 let mut binding = DimBinding::new();
265 for node in graph.nodes() {
266 let Op::Input { name } = &node.op else {
267 continue;
268 };
269 let Some(&n_elems) = by_name.get(name.as_str()) else {
270 continue;
271 };
272 let mut static_prod: usize = 1;
273 let mut dynamic_sym: Option<u32> = None;
274 for d in node.shape.dims() {
275 match d {
276 Dim::Static(n) => static_prod *= *n,
277 Dim::Dynamic(sym) => {
278 if dynamic_sym.is_some() {
279 return Err(format!(
280 "Input '{name}' has multiple dynamic dims; \
281 pass an explicit DimBinding"
282 ));
283 }
284 dynamic_sym = Some(*sym);
285 }
286 }
287 }
288 let Some(sym) = dynamic_sym else {
289 continue;
290 };
291 if static_prod == 0 {
292 return Err(format!("Input '{name}': static dim product is zero"));
293 }
294 if n_elems % static_prod != 0 {
295 return Err(format!(
296 "Input '{name}': len {n_elems} not divisible by static product {static_prod}"
297 ));
298 }
299 let size = n_elems / static_prod;
300 if let Some(prev) = binding.get(sym) {
301 if prev != size {
302 return Err(format!(
303 "symbol {sym} bound to {prev} and {size} from different inputs"
304 ));
305 }
306 } else {
307 binding.set(sym, size);
308 }
309 }
310 complete_im2col_row_bindings(graph, &mut binding);
311 Ok(binding)
312}
313
314pub fn complete_im2col_row_bindings(graph: &Graph, binding: &mut DimBinding) {
317 let Some(batch) = binding.get(sym::BATCH) else {
318 return;
319 };
320 if binding.get(sym::ROWS).is_some() {
321 return;
322 }
323 for node in graph.nodes() {
324 let Op::Im2Col {
325 kernel_size,
326 stride,
327 padding,
328 dilation,
329 } = &node.op
330 else {
331 continue;
332 };
333 let x_shape = &graph.node(node.inputs[0]).shape;
334 if x_shape.rank() != 4 {
335 continue;
336 }
337 if !x_shape.dim(2).is_static() || !x_shape.dim(3).is_static() {
338 continue;
339 }
340 let h = x_shape.dim(2).unwrap_static();
341 let w = x_shape.dim(3).unwrap_static();
342 let kh = kernel_size.first().copied().unwrap_or(1);
343 let kw = kernel_size.get(1).copied().unwrap_or(1);
344 let sh = stride.first().copied().unwrap_or(1);
345 let sw = stride.get(1).copied().unwrap_or(1);
346 let ph = padding.first().copied().unwrap_or(0);
347 let pw = padding.get(1).copied().unwrap_or(0);
348 let dh = dilation.first().copied().unwrap_or(1);
349 let dw = dilation.get(1).copied().unwrap_or(1);
350 let h_out = crate::shape::conv2d_spatial_output(h, kh, sh, ph, dh);
351 let w_out = crate::shape::conv2d_spatial_output(w, kw, sw, pw, dw);
352 binding.set(sym::ROWS, batch * h_out * w_out);
353 return;
354 }
355}
356
357pub fn infer_bindings_from_f32_inputs(
359 graph: &Graph,
360 inputs: &[(&str, &[f32])],
361) -> Result<DimBinding, String> {
362 infer_bindings_from_inputs(
363 graph,
364 &inputs
365 .iter()
366 .map(|(n, d)| (*n, d.len()))
367 .collect::<Vec<_>>(),
368 )
369}
370
371pub fn same_binding(a: &DimBinding, b: &DimBinding) -> bool {
372 if a.len() != b.len() {
373 return false;
374 }
375 a.iter().all(|(sym, size)| b.get(sym) == Some(size))
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infer::GraphExt;

    #[test]
    fn bind_graph_specializes_matmul() {
        let batch = sym::BATCH;
        let seq = sym::SEQ;
        let mut g = Graph::new("dyn");
        let x = g.input("x", Shape::batch_seq(batch, seq, 4, DType::F32));
        let w = g.param("w", Shape::new(&[4, 8], DType::F32));
        let y = g.mm(x, w);
        g.set_outputs(vec![y]);

        assert!(has_dynamic_dims(&g));
        let binding = DimBinding::batch_seq(2, 16);
        let bound = bind_graph(&g, &binding);
        assert!(!has_dynamic_dims(&bound));
        assert_eq!(
            bound.node(bound.outputs[0]).shape,
            Shape::new(&[2, 16, 8], DType::F32)
        );
    }

    #[test]
    fn infer_bindings_from_input_data() {
        let mut g = Graph::new("dyn");
        let x = g.input(
            "x",
            Shape::from_dims(
                &[Dim::Static(3), Dim::Dynamic(sym::SEQ), Dim::Static(64)],
                DType::F32,
            ),
        );
        g.set_outputs(vec![x]);

        let b = infer_bindings_from_f32_inputs(&g, &[("x", &vec![0.0f32; 3 * 128 * 64])])
            .expect("infer");
        assert_eq!(b.get(sym::SEQ), Some(128));
    }

    #[test]
    fn infer_bindings_sets_im2col_rows_from_batch() {
        let mut g = Graph::new("im2col_rows");
        let x = g.input(
            "x",
            Shape::from_dims(
                &[
                    Dim::Dynamic(sym::BATCH),
                    Dim::Static(1),
                    Dim::Static(4),
                    Dim::Static(4),
                ],
                DType::F32,
            ),
        );
        let col = g.im2col(x, [3, 3], [1, 1], [1, 1], [1, 1]);
        // Output the im2col node so the op under test is part of the graph result
        // rather than a dangling node.
        g.set_outputs(vec![col]);
        let b = infer_bindings_from_f32_inputs(&g, &[("x", &[0.0f32; 2 * 16])]).expect("infer");
        assert_eq!(b.get(sym::BATCH), Some(2));
        assert_eq!(b.get(sym::ROWS), Some(2 * 4 * 4));
    }

    #[test]
    fn infer_bindings_rejects_non_divisible_len() {
        let mut g = Graph::new("dyn");
        let x = g.input(
            "x",
            Shape::from_dims(
                &[Dim::Static(3), Dim::Dynamic(sym::SEQ), Dim::Static(64)],
                DType::F32,
            ),
        );
        g.set_outputs(vec![x]);

        // 3 * 64 * 2 + 1 elements cannot be split into whole [3, SEQ, 64] rows.
        let err = infer_bindings_from_f32_inputs(&g, &[("x", &[0.0f32; 3 * 64 * 2 + 1])])
            .err()
            .expect("length not divisible by the static product must be rejected");
        assert!(err.contains("not divisible"), "unexpected error: {err}");
    }

    #[test]
    fn infer_bindings_rejects_multiple_dynamic_dims() {
        let mut g = Graph::new("dyn");
        let x = g.input("x", Shape::batch_seq(sym::BATCH, sym::SEQ, 4, DType::F32));
        g.set_outputs(vec![x]);

        let err = infer_bindings_from_f32_inputs(&g, &[("x", &[0.0f32; 8])])
            .err()
            .expect("two dynamic dims cannot be inferred from one length");
        assert!(err.contains("multiple dynamic dims"), "unexpected error: {err}");
    }

    #[test]
    fn infer_bindings_rejects_conflicting_sizes() {
        let mut g = Graph::new("dyn");
        let x = g.input(
            "x",
            Shape::from_dims(&[Dim::Dynamic(sym::SEQ), Dim::Static(4)], DType::F32),
        );
        let y = g.input(
            "y",
            Shape::from_dims(&[Dim::Dynamic(sym::SEQ), Dim::Static(2)], DType::F32),
        );
        g.set_outputs(vec![x, y]);

        // x implies SEQ = 8 / 4 = 2, y implies SEQ = 6 / 2 = 3.
        let err = infer_bindings_from_f32_inputs(
            &g,
            &[("x", &[0.0f32; 8][..]), ("y", &[0.0f32; 6][..])],
        )
        .err()
        .expect("inconsistent sizes for one symbol must be rejected");
        assert!(err.contains("from different inputs"), "unexpected error: {err}");
    }
}