1use crate::ast::{DataType, Dimension, DynamicDimension, GraphJson};
4use crate::protos::onnx::{
5 tensor_shape_proto::dimension::Value as DimensionValue, type_proto::Value as TypeProtoValue,
6 ModelProto, TensorProto, TensorProto_DataType,
7};
8use prost::Message;
9use serde_json::Value as JsonValue;
10use std::collections::{BTreeMap, HashMap, HashSet};
11use std::fs;
12use std::path::Path;
13use thiserror::Error;
14use webnn_onnx_utils::{data_types as utils_data_types, identifiers};
15
/// Lowest opset version of the default ONNX domain (`""` / `ai.onnx`) that
/// `OnnxConverter::convert` accepts.
const MIN_SUPPORTED_OPSET: i64 = 11;
/// Highest opset version of the default ONNX domain (`""` / `ai.onnx`) that
/// `OnnxConverter::convert` accepts.
const MAX_SUPPORTED_OPSET: i64 = 18;
18
/// Errors reported while loading an ONNX model and converting it.
#[derive(Debug, Error)]
pub enum OnnxError {
    /// An I/O failure; converted automatically from `std::io::Error`.
    #[error("failed to read ONNX file: {0}")]
    IoError(#[from] std::io::Error),

    /// The protobuf payload could not be used; also raised when the model
    /// carries no graph.
    #[error("failed to parse ONNX protobuf: {0}")]
    ProtobufError(String),

    /// An opset import of the default domain lies outside
    /// `MIN_SUPPORTED_OPSET..=MAX_SUPPORTED_OPSET`.
    #[error("unsupported ONNX opset version {version} for domain '{domain}'")]
    UnsupportedOpset { domain: String, version: i64 },

    /// A node uses an operator the converter does not handle.
    #[error("unsupported operator: {op} (node: {node})")]
    UnsupportedOp { op: String, node: String },

    /// An operator is missing an attribute the converter needs.
    #[error("missing required attribute: {attr} in {op}")]
    MissingAttribute { attr: String, op: String },

    /// A tensor shape is unusable, e.g. an input dimension that is
    /// non-positive or symbolic and was not resolved by an override.
    #[error("invalid tensor shape: {0}")]
    InvalidShape(String),

    /// A data-type conversion failed in `webnn_onnx_utils`; converted
    /// automatically from its `ConversionError`.
    #[error("type conversion error: {0}")]
    TypeConversion(#[from] webnn_onnx_utils::error::ConversionError),

    /// Shape inference could not produce a result for the named node.
    #[error("shape inference failed for node: {0}")]
    ShapeInference(String),
}
45
/// Returns `name` as rewritten by `identifiers::sanitize_for_webnn`.
///
/// This is a thin wrapper; the exact rewriting rules live in
/// `webnn_onnx_utils` (presumably producing an identifier that is valid in a
/// WebNN graph — confirm against that crate).
pub fn sanitize_identifier(name: &str) -> String {
    identifiers::sanitize_for_webnn(name)
}
51
52pub(crate) fn map_onnx_data_type(onnx_type: i32) -> Result<DataType, OnnxError> {
54 if onnx_type == TensorProto_DataType::Bool as i32 {
55 return Ok(DataType::Uint8);
56 }
57
58 let utils_dtype = utils_data_types::onnx_to_webnn(onnx_type)?;
59 Ok(match utils_dtype {
60 utils_data_types::DataType::Float32 => DataType::Float32,
61 utils_data_types::DataType::Float16 => DataType::Float16,
62 utils_data_types::DataType::Int32 => DataType::Int32,
63 utils_data_types::DataType::Uint32 => DataType::Uint32,
64 utils_data_types::DataType::Int64 => DataType::Int64,
65 utils_data_types::DataType::Uint64 => DataType::Uint64,
66 utils_data_types::DataType::Int8 => DataType::Int8,
67 utils_data_types::DataType::Uint8 => DataType::Uint8,
68 })
69}
70
71fn infer_shape(
73 node: &crate::protos::onnx::NodeProto,
74 value_shapes: &HashMap<String, Vec<i64>>,
75 initializers: &HashMap<String, &TensorProto>,
76 const_values: &HashMap<String, Vec<i64>>,
77) -> Option<Vec<i64>> {
78 let op = node.op_type.as_str();
79
80 match op {
81 "Cast" | "Relu" | "Tanh" | "Sigmoid" | "Erf" | "Softmax" | "Gelu" | "Exp" | "Log"
83 | "Abs" | "Neg" | "Sqrt" | "LayerNormalization" | "Trilu" => {
84 let ins = node.input.as_slice();
85 if ins.is_empty() {
86 return None;
87 }
88 value_shapes.get(ins[0].as_str()).cloned()
89 }
90
91 "Add" | "Sub" | "Mul" | "Div" | "Pow" => {
93 let ins = node.input.as_slice();
94 if ins.len() < 2 {
95 return None;
96 }
97
98 let shape_a = value_shapes.get(ins[0].as_str());
99 let shape_b = value_shapes.get(ins[1].as_str());
100
101 match (shape_a, shape_b) {
102 (Some(a), Some(b)) => {
103 let rank = a.len().max(b.len());
104 let mut out_rev = Vec::with_capacity(rank);
105 for i in 0..rank {
106 let da = a.get(a.len().wrapping_sub(1 + i)).copied().unwrap_or(1);
107 let db = b.get(b.len().wrapping_sub(1 + i)).copied().unwrap_or(1);
108 if da == db || da == 1 {
109 out_rev.push(db);
110 } else if db == 1 {
111 out_rev.push(da);
112 } else {
113 return None;
114 }
115 }
116 out_rev.reverse();
117 Some(out_rev)
118 }
119 (Some(a), None) => Some(a.clone()),
120 (None, Some(b)) => Some(b.clone()),
121 (None, None) => None,
122 }
123 }
124
125 "MatMul" => {
127 let ins = node.input.as_slice();
128 if ins.len() < 2 {
129 return None;
130 }
131
132 let a_shape = value_shapes.get(ins[0].as_str())?;
133 let b_shape = value_shapes.get(ins[1].as_str())?;
134
135 if a_shape.len() >= 2 && b_shape.len() >= 2 {
137 let m = a_shape[a_shape.len() - 2];
138 let n = b_shape[b_shape.len() - 1];
139
140 if a_shape.len() == 2 && b_shape.len() == 2 {
142 return Some(vec![m, n]);
143 } else if a_shape.len() > 2 {
144 let mut result = a_shape[..a_shape.len() - 2].to_vec();
145 result.push(m);
146 result.push(n);
147 return Some(result);
148 }
149 }
150 None
151 }
152
153 "Transpose" => {
155 let ins = node.input.as_slice();
156 if ins.is_empty() {
157 return None;
158 }
159 let input_shape = value_shapes.get(ins[0].as_str())?;
160
161 let perm: Vec<usize> = node
163 .attribute
164 .as_slice()
165 .iter()
166 .find(|a| a.name.as_str() == "perm")
167 .map(|a| a.ints.iter().map(|&i| i as usize).collect::<Vec<usize>>())
168 .unwrap_or_else(|| (0..input_shape.len()).rev().collect());
169
170 Some(perm.iter().map(|&i| input_shape[i]).collect())
172 }
173
174 "ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin" => {
176 let ins = node.input.as_slice();
177 if ins.is_empty() {
178 return None;
179 }
180 let input_shape = value_shapes.get(ins[0].as_str())?;
181
182 let keepdims = node
184 .attribute
185 .as_slice()
186 .iter()
187 .find(|a| a.name.as_str() == "keepdims")
188 .and_then(|a| if a.i != 0 { Some(a.i != 0) } else { None })
189 .unwrap_or(true);
190
191 let axes: Vec<i64> = node
193 .attribute
194 .as_slice()
195 .iter()
196 .find(|a| a.name.as_str() == "axes")
197 .map(|a| a.ints.clone())
198 .unwrap_or_default();
199
200 if axes.is_empty() {
201 if keepdims {
203 Some(vec![1; input_shape.len()])
204 } else {
205 Some(vec![])
206 }
207 } else {
208 let mut output_shape = input_shape.clone();
210 for &axis in &axes {
211 let idx = if axis < 0 {
212 (input_shape.len() as i64 + axis) as usize
213 } else {
214 axis as usize
215 };
216 if idx < output_shape.len() {
217 if keepdims {
218 output_shape[idx] = 1;
219 } else {
220 output_shape[idx] = -1; }
222 }
223 }
224 if !keepdims {
225 output_shape.retain(|&d| d != -1);
226 }
227 Some(output_shape)
228 }
229 }
230
231 "Gemm" => {
233 let ins = node.input.as_slice();
234 if ins.len() < 2 {
235 return None;
236 }
237
238 let a_shape = value_shapes.get(ins[0].as_str())?;
239 let b_shape = value_shapes.get(ins[1].as_str())?;
240
241 if a_shape.len() != 2 || b_shape.len() != 2 {
242 return None;
243 }
244
245 let trans_a = node
247 .attribute
248 .as_slice()
249 .iter()
250 .find(|a| a.name.as_str() == "transA")
251 .and_then(|a| if a.i != 0 { Some(a.i != 0) } else { None })
252 .unwrap_or(false);
253
254 let trans_b = node
255 .attribute
256 .as_slice()
257 .iter()
258 .find(|a| a.name.as_str() == "transB")
259 .and_then(|a| if a.i != 0 { Some(a.i != 0) } else { None })
260 .unwrap_or(false);
261
262 let m = if trans_a { a_shape[1] } else { a_shape[0] };
263 let n = if trans_b { b_shape[0] } else { b_shape[1] };
264
265 Some(vec![m, n])
266 }
267
268 "Gather" => {
269 let ins = node.input.as_slice();
270 if ins.len() < 2 {
271 return None;
272 }
273
274 let data_shape = value_shapes.get(ins[0].as_str())?;
275 let indices_shape = value_shapes.get(ins[1].as_str())?;
276
277 let mut axis = node
278 .attribute
279 .as_slice()
280 .iter()
281 .find(|a| a.name.as_str() == "axis")
282 .and_then(|a| if a.i != 0 { Some(a.i) } else { None })
283 .unwrap_or(0);
284
285 if axis < 0 {
286 axis += data_shape.len() as i64;
287 }
288
289 let axis_usize = axis as usize;
290 if axis_usize > data_shape.len() {
291 return None;
292 }
293
294 let mut output = Vec::new();
295 output.extend_from_slice(&data_shape[..axis_usize]);
296 output.extend(indices_shape.iter().cloned());
297 if axis_usize < data_shape.len() {
298 output.extend_from_slice(&data_shape[axis_usize + 1..]);
299 }
300 Some(output)
301 }
302
303 "Unsqueeze" => {
304 let ins = node.input.as_slice();
305 if ins.is_empty() {
306 return None;
307 }
308
309 let input_shape = value_shapes.get(ins[0].as_str())?.clone();
310 let mut axes: Vec<i64> = node
311 .attribute
312 .as_slice()
313 .iter()
314 .find(|a| a.name.as_str() == "axes")
315 .map(|a| a.ints.clone())
316 .unwrap_or_default();
317
318 if axes.is_empty() {
319 return None;
320 }
321
322 axes.sort();
323 let mut output_shape = input_shape;
324 for axis in axes {
325 let idx = if axis < 0 {
326 (output_shape.len() as i64 + axis + 1) as usize
327 } else {
328 axis as usize
329 };
330 if idx <= output_shape.len() {
331 output_shape.insert(idx, 1);
332 }
333 }
334 Some(output_shape)
335 }
336
337 "Concat" => {
338 let mut shapes = Vec::new();
339 for inp in node.input.as_slice() {
340 let shape = value_shapes.get(inp.as_str())?;
341 shapes.push(shape.clone());
342 }
343
344 if shapes.is_empty() {
345 return None;
346 }
347
348 let mut axis = node
349 .attribute
350 .as_slice()
351 .iter()
352 .find(|a| a.name.as_str() == "axis")
353 .and_then(|a| if a.i != 0 { Some(a.i) } else { None })
354 .unwrap_or(0);
355
356 if axis < 0 {
357 axis += shapes[0].len() as i64;
358 }
359 let axis_usize = axis as usize;
360
361 let mut output = shapes[0].clone();
362 for shape in shapes.iter().skip(1) {
363 if shape.len() != output.len() || axis_usize >= shape.len() {
364 return None;
365 }
366 output[axis_usize] += shape[axis_usize];
367 }
368 Some(output)
369 }
370
371 "Reshape" => {
372 let ins = node.input.as_slice();
373 if ins.len() < 2 {
374 return None;
375 }
376
377 let input_shape = value_shapes.get(ins[0].as_str())?;
378 let shape_input = ins[1].as_str();
379 let mut target: Vec<i64> = if let Some(values) = const_values.get(shape_input) {
380 values.clone()
381 } else if let Some(shape_tensor) = initializers.get(shape_input) {
382 if !shape_tensor.raw_data.as_slice().is_empty() {
383 if shape_tensor.data_type == TensorProto_DataType::Int32 as i32 {
384 shape_tensor
385 .raw_data
386 .as_slice()
387 .chunks_exact(4)
388 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64)
389 .collect()
390 } else {
391 shape_tensor
392 .raw_data
393 .as_slice()
394 .chunks_exact(8)
395 .map(|c| {
396 i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
397 })
398 .collect()
399 }
400 } else if !shape_tensor.int64_data.as_slice().is_empty() {
401 shape_tensor.int64_data.as_slice().to_vec()
402 } else if !shape_tensor.int32_data.as_slice().is_empty() {
403 shape_tensor
404 .int32_data
405 .as_slice()
406 .iter()
407 .map(|&v| v as i64)
408 .collect()
409 } else {
410 Vec::new()
411 }
412 } else {
413 Vec::new()
414 };
415
416 if target.is_empty() {
417 return None;
418 }
419
420 if target.contains(&-1) {
421 let total_input: i64 = input_shape.iter().product();
422 let known: i64 = target.iter().filter(|&&d| d != -1).product();
423 if known == 0 || total_input % known != 0 {
424 return None;
425 }
426 if let Some(idx) = target.iter().position(|&d| d == -1) {
427 target[idx] = total_input / known;
428 }
429 }
430
431 Some(target)
432 }
433
434 "Slice" => {
435 let ins = node.input.as_slice();
436 if ins.is_empty() {
437 return None;
438 }
439
440 let input_shape = value_shapes.get(ins[0].as_str())?;
441
442 let read_ints = |name: Option<&String>| -> Option<Vec<i64>> {
443 if let Some(n) = name {
444 if let Some(v) = const_values.get(n) {
445 return Some(v.clone());
446 }
447 if let Some(t) = initializers.get(n) {
448 let raw = t.raw_data.as_slice();
449 if !raw.is_empty() {
450 if t.data_type == TensorProto_DataType::Int32 as i32 {
451 return Some(
452 raw.chunks_exact(4)
453 .map(|c| {
454 i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64
455 })
456 .collect(),
457 );
458 } else {
459 return Some(
460 raw.chunks_exact(8)
461 .map(|c| {
462 i64::from_le_bytes([
463 c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7],
464 ])
465 })
466 .collect(),
467 );
468 }
469 } else if !t.int64_data.as_slice().is_empty() {
470 return Some(t.int64_data.as_slice().to_vec());
471 } else if !t.int32_data.as_slice().is_empty() {
472 return Some(
473 t.int32_data.as_slice().iter().map(|&v| v as i64).collect(),
474 );
475 }
476 }
477 }
478 None
479 };
480
481 let starts = read_ints(ins.get(1))?;
482 let ends = read_ints(ins.get(2))?;
483 let axes =
484 read_ints(ins.get(3)).unwrap_or_else(|| (0..input_shape.len() as i64).collect());
485 let steps = read_ints(ins.get(4)).unwrap_or_else(|| vec![1; axes.len()]);
486
487 if axes.len() != starts.len() || axes.len() != ends.len() || axes.len() != steps.len() {
488 return None;
489 }
490
491 let mut output = input_shape.clone();
492 for i in 0..axes.len() {
493 let axis = if axes[i] < 0 {
494 (input_shape.len() as i64 + axes[i]) as usize
495 } else {
496 axes[i] as usize
497 };
498 if axis >= output.len() {
499 return None;
500 }
501
502 let step = steps[i];
503 if step != 1 {
504 return None;
505 }
506
507 let dim = input_shape[axis];
508 let mut start = starts[i];
509 let mut end = ends[i];
510
511 if start < 0 {
512 start += dim;
513 }
514 if end < 0 {
515 end += dim;
516 }
517
518 start = start.max(0);
519 end = end.min(dim);
520
521 if end < start {
522 output[axis] = 0;
523 } else {
524 output[axis] = end - start;
525 }
526 }
527
528 Some(output)
529 }
530
531 _ => None,
532 }
533}
534
/// Number of elements described by `shape`.
///
/// Returns `None` if any dimension is negative, if a dimension does not fit in
/// `usize`, or if the product overflows `usize`. A shape containing a zero
/// dimension (and no negative one) has zero elements; the empty shape (a
/// scalar) has one.
fn shape_numel(shape: &[i64]) -> Option<usize> {
    if shape.iter().any(|&d| d < 0) {
        return None;
    }
    // Decide the zero case first so that it cannot be masked by an overflow
    // among the remaining dimensions.
    if shape.contains(&0) {
        return Some(0);
    }
    shape
        .iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(usize::try_from(d).ok()?))
}
543
544fn const_shape_for_folding(
545 name: &str,
546 values: &[i64],
547 value_shapes: &HashMap<String, Vec<i64>>,
548) -> Vec<i64> {
549 if let Some(shape) = value_shapes.get(name) {
550 if shape_numel(shape) == Some(values.len()) {
551 return shape.clone();
552 }
553 }
554
555 if values.len() == 1 {
556 Vec::new()
557 } else {
558 vec![values.len() as i64]
559 }
560}
561
/// NumPy-style broadcast of two fully static shapes.
///
/// Dimensions are aligned from the trailing end and missing leading
/// dimensions count as 1. Returns `None` if any dimension is non-positive or
/// if two aligned dimensions differ and neither is 1.
fn broadcast_shape(shape_a: &[i64], shape_b: &[i64]) -> Option<Vec<i64>> {
    let rank = shape_a.len().max(shape_b.len());
    let mut lhs = shape_a.iter().rev().copied();
    let mut rhs = shape_b.iter().rev().copied();

    let mut merged = Vec::with_capacity(rank);
    for _ in 0..rank {
        let da = lhs.next().unwrap_or(1);
        let db = rhs.next().unwrap_or(1);
        if da <= 0 || db <= 0 {
            return None;
        }
        let dim = match (da, db) {
            (x, y) if x == y || x == 1 => y,
            (x, 1) => x,
            _ => return None,
        };
        merged.push(dim);
    }

    // Dimensions were collected trailing-first.
    merged.reverse();
    Some(merged)
}
588
/// Maps a row-major linear index in a broadcast output onto the matching
/// row-major linear index in one operand.
///
/// Axes are aligned from the trailing end; operand axes of size 1 always
/// contribute coordinate 0. A scalar operand (empty `in_shape`) maps to
/// index 0. Returns `None` if the operand has more axes than the output or if
/// any visited dimension is zero or negative.
fn linear_index_for_broadcast_operand(
    out_linear_idx: usize,
    out_shape: &[i64],
    in_shape: &[i64],
) -> Option<usize> {
    if in_shape.is_empty() {
        return Some(0);
    }
    if in_shape.len() > out_shape.len() {
        return None;
    }

    let mut remaining = out_linear_idx;
    let mut offset = 0usize;
    let mut stride = 1usize;
    let mut operand_dims = in_shape.iter().rev();

    // Walk the output axes innermost-first, peeling one coordinate per axis.
    for &raw_out_dim in out_shape.iter().rev() {
        let out_dim = usize::try_from(raw_out_dim).ok().filter(|&d| d != 0)?;
        let coord = remaining % out_dim;
        remaining /= out_dim;

        if let Some(&raw_in_dim) = operand_dims.next() {
            let in_dim = usize::try_from(raw_in_dim).ok().filter(|&d| d != 0)?;
            if in_dim != 1 {
                offset = offset.saturating_add(coord.saturating_mul(stride));
            }
            stride = stride.saturating_mul(in_dim);
        }
    }

    Some(offset)
}
631
632fn fold_binary_const_i64(
633 op_type: &str,
634 a_values: &[i64],
635 b_values: &[i64],
636 a_shape: &[i64],
637 b_shape: &[i64],
638) -> Option<(Vec<i64>, Vec<i64>)> {
639 let out_shape = broadcast_shape(a_shape, b_shape)?;
640 let out_numel = shape_numel(&out_shape)?;
641
642 let mut out_values = Vec::with_capacity(out_numel);
643 for out_idx in 0..out_numel {
644 let a_idx = linear_index_for_broadcast_operand(out_idx, &out_shape, a_shape)?;
645 let b_idx = linear_index_for_broadcast_operand(out_idx, &out_shape, b_shape)?;
646 let av = *a_values.get(a_idx)?;
647 let bv = *b_values.get(b_idx)?;
648 let v = match op_type {
649 "Add" => av + bv,
650 "Sub" => av - bv,
651 "Mul" => av * bv,
652 "Div" => {
653 if bv == 0 {
654 return None;
655 }
656 av / bv
657 }
658 "Equal" => {
659 if av == bv {
660 1
661 } else {
662 0
663 }
664 }
665 _ => return None,
666 };
667 out_values.push(v);
668 }
669
670 Some((out_values, out_shape))
671}
672
673fn value_shape_dims_for<'a>(
674 name: &str,
675 value_shape_dims: &'a HashMap<String, Vec<Dimension>>,
676) -> Option<&'a [Dimension]> {
677 let sanitized = sanitize_identifier(name);
678 let trimmed = name.trim_start_matches('/');
679 value_shape_dims
680 .get(name)
681 .or_else(|| value_shape_dims.get(&sanitized))
682 .or_else(|| value_shape_dims.get(trimmed))
683 .map(Vec::as_slice)
684}
685
686fn dims_contain_dynamic(dims: &[Dimension]) -> bool {
687 dims.iter().any(|d| matches!(d, Dimension::Dynamic(_)))
688}
689
/// Splits a dimension expression of the form `base + k` or `base - k` into
/// `(base, signed offset)`.
///
/// The split happens at the last `+` (tried first) or the last `-`, and only
/// when the text to the right parses as an integer. Anything else is returned
/// unchanged (trimmed) with an offset of 0, so names that merely contain a
/// dash, such as `seq-len`, are left intact.
pub(crate) fn parse_dynamic_dim_expr(dim_name: &str) -> (String, i64) {
    let expr = dim_name.trim();

    // `+` takes precedence over `-`, matching the order of the checks.
    for &(operator, sign) in &[('+', 1i64), ('-', -1i64)] {
        let parsed = expr.rsplit_once(operator).and_then(|(base, tail)| {
            tail.trim()
                .parse::<i64>()
                .ok()
                .map(|offset| (base.trim().to_string(), sign * offset))
        });
        if let Some(result) = parsed {
            return result;
        }
    }

    (expr.to_string(), 0)
}
704
/// Renders `base` with a signed integer offset as `base + k`, `base - k`, or
/// just `base` when the offset is zero.
///
/// The magnitude is taken with `unsigned_abs`, so `i64::MIN` is formatted
/// correctly instead of overflowing.
pub(crate) fn format_dynamic_dim_expr(base: &str, offset: i64) -> String {
    match offset {
        0 => base.to_string(),
        n if n > 0 => format!("{base} + {n}"),
        n => format!("{base} - {}", n.unsigned_abs()),
    }
}
714
/// Parses a sum/difference of symbols and integers (e.g. `a + b - 2`) into a
/// map of symbol coefficients plus an integer constant.
///
/// Every `+` and `-` is treated as an operator, so symbol names themselves
/// must not contain those characters. Symbols whose coefficients cancel out
/// are dropped from the map.
///
/// Returns `None` for an empty expression, an expression without any operand,
/// or when the constant or a coefficient would overflow `i64`.
fn parse_additive_dynamic_dim_expr(dim_name: &str) -> Option<(BTreeMap<String, i64>, i64)> {
    let expr = dim_name.trim();
    if expr.is_empty() {
        return None;
    }

    // Surround operators with spaces so the expression can be split on whitespace.
    let normalized = expr.replace('+', " + ").replace('-', " - ");
    let mut terms: BTreeMap<String, i64> = BTreeMap::new();
    let mut constant = 0i64;
    let mut sign = 1i64;
    let mut saw_term = false;

    for token in normalized.split_whitespace() {
        match token {
            "+" => sign = 1,
            "-" => sign = -1,
            operand => {
                saw_term = true;
                if let Ok(value) = operand.parse::<i64>() {
                    // The text comes from the model file, so overflow is reported
                    // rather than allowed to panic or wrap.
                    constant = constant.checked_add(sign.checked_mul(value)?)?;
                } else {
                    let coeff = terms.entry(operand.to_string()).or_insert(0);
                    *coeff = coeff.checked_add(sign)?;
                }
                // A sign applies to the next operand only.
                sign = 1;
            }
        }
    }

    if !saw_term {
        return None;
    }

    terms.retain(|_, coeff| *coeff != 0);
    Some((terms, constant))
}
750
/// Renders a coefficient map plus constant back into an additive expression.
///
/// A symbol with coefficient `k` is written `|k|` times (e.g. `a + a`), joined
/// with ` + ` or ` - ` according to its sign; a leading negative symbol is
/// written as `- a`. The constant, if non-zero, is appended last.
///
/// Returns `None` when nothing would be written, i.e. the constant is zero
/// and there are no symbols with a non-zero coefficient.
fn format_additive_dynamic_dim_expr(
    terms: &BTreeMap<String, i64>,
    constant: i64,
) -> Option<String> {
    let mut out = String::new();

    for (name, &coeff) in terms {
        // `unsigned_abs` cannot overflow, unlike `abs` on `i64::MIN`.
        for _ in 0..coeff.unsigned_abs() {
            if out.is_empty() {
                if coeff < 0 {
                    out.push_str("- ");
                }
            } else if coeff < 0 {
                out.push_str(" - ");
            } else {
                out.push_str(" + ");
            }
            out.push_str(name);
        }
    }

    if constant != 0 {
        if out.is_empty() {
            out.push_str(&constant.to_string());
        } else {
            out.push_str(if constant < 0 { " - " } else { " + " });
            out.push_str(&constant.unsigned_abs().to_string());
        }
    }

    if out.is_empty() {
        None
    } else {
        Some(out)
    }
}
791
/// Reports whether a dimension expression is either a bare name or a single
/// `name ± integer` form.
///
/// Empty expressions and anything containing `*` or `/` are rejected. When
/// the expression contains a `+` it is split at the last one; otherwise at the
/// last `-`, if any. The split is accepted only when the left side is
/// non-empty and the right side parses as an integer.
fn is_runtime_resolvable_dynamic_dim_expr(dim_name: &str) -> bool {
    let expr = dim_name.trim();
    if expr.is_empty() || expr.contains(&['*', '/'][..]) {
        return false;
    }

    // Only fall back to `-` when there is no `+` at all.
    let split = expr.rsplit_once('+').or_else(|| expr.rsplit_once('-'));
    match split {
        Some((base, offset)) => {
            !base.trim().is_empty() && offset.trim().parse::<i64>().is_ok()
        }
        None => true,
    }
}
805
806fn shift_dynamic_dimension(dim: &DynamicDimension, delta: i64) -> Option<DynamicDimension> {
807 let (base, offset) = parse_dynamic_dim_expr(&dim.name);
808 let name = format_dynamic_dim_expr(&base, offset.checked_add(delta)?);
809 let shifted_max = (dim.max_size as i64).checked_add(delta)?.max(0);
810 let max_size = u32::try_from(shifted_max).ok()?;
811 Some(DynamicDimension { name, max_size })
812}
813
814pub(crate) fn dynamic_scalar_dimension_for_value(
815 name: &str,
816 value_shape_dims: &HashMap<String, Vec<Dimension>>,
817) -> Option<DynamicDimension> {
818 let dims = value_shape_dims_for(name, value_shape_dims)?;
819 if dims.len() != 1 {
820 return None;
821 }
822 match &dims[0] {
823 Dimension::Dynamic(dim) => Some(dim.clone()),
824 Dimension::Static(_) => None,
825 }
826}
827
828fn dimension_vector_for_value(
829 name: &str,
830 const_values: &HashMap<String, Vec<i64>>,
831 value_shape_dims: &HashMap<String, Vec<Dimension>>,
832) -> Option<Vec<Dimension>> {
833 if let Some(dims) = value_shape_dims_for(name, value_shape_dims) {
834 return Some(dims.to_vec());
835 }
836 let values = const_values.get(name)?;
837 values
838 .iter()
839 .map(|&v| u32::try_from(v).ok().map(Dimension::Static))
840 .collect()
841}
842
843fn is_trivial_static_dimension_vector(dims: &[Dimension]) -> bool {
844 dims.len() <= 3 && dims.iter().all(|d| matches!(d, Dimension::Static(1)))
845}
846
847fn combine_binary_dimension(
848 op_type: &str,
849 dynamic: &DynamicDimension,
850 static_value: i64,
851 dynamic_on_lhs: bool,
852) -> Option<Dimension> {
853 match op_type {
854 "Add" => shift_dynamic_dimension(dynamic, static_value).map(Dimension::Dynamic),
855 "Sub" if dynamic_on_lhs => {
856 shift_dynamic_dimension(dynamic, -static_value).map(Dimension::Dynamic)
857 }
858 "Mul" if static_value == 0 => Some(Dimension::Static(0)),
859 "Mul" if static_value == 1 => Some(Dimension::Dynamic(dynamic.clone())),
860 "Mul" if static_value > 1 => Some(Dimension::Dynamic(DynamicDimension {
861 name: if dynamic_on_lhs {
862 format!("{} * {}", dynamic.name, static_value)
863 } else {
864 format!("{} * {}", static_value, dynamic.name)
865 },
866 max_size: dynamic.max_size.saturating_mul(static_value as u32),
867 })),
868 "Div" if dynamic_on_lhs && static_value == 1 => Some(Dimension::Dynamic(dynamic.clone())),
869 "Div" if dynamic_on_lhs && static_value > 1 => Some(Dimension::Dynamic(DynamicDimension {
870 name: format!("{} / {}", dynamic.name, static_value),
871 max_size: dynamic.max_size / (static_value as u32),
872 })),
873 _ => None,
874 }
875}
876
877fn combine_dynamic_dimensions(
878 op_type: &str,
879 lhs: &DynamicDimension,
880 rhs: &DynamicDimension,
881 lhs_value: i64,
882 rhs_value: i64,
883) -> Option<Dimension> {
884 match op_type {
885 "Add" | "Sub" => {
886 let (mut terms, mut constant) = parse_additive_dynamic_dim_expr(&lhs.name)?;
887 let (rhs_terms, rhs_constant) = parse_additive_dynamic_dim_expr(&rhs.name)?;
888 let rhs_sign = if op_type == "Add" { 1 } else { -1 };
889
890 for (name, coeff) in rhs_terms {
891 *terms.entry(name).or_insert(0) += rhs_sign * coeff;
892 }
893 constant += rhs_sign * rhs_constant;
894 terms.retain(|_, coeff| *coeff != 0);
895
896 let value = if op_type == "Add" {
897 lhs_value.checked_add(rhs_value)?
898 } else {
899 lhs_value.checked_sub(rhs_value)?
900 };
901 if terms.is_empty() {
902 return u32::try_from(value).ok().map(Dimension::Static);
903 }
904
905 let name = format_additive_dynamic_dim_expr(&terms, constant)?;
906 let max_size = u32::try_from(value).ok()?;
907 Some(Dimension::Dynamic(DynamicDimension { name, max_size }))
908 }
909 _ => None,
910 }
911}
912
913fn fold_binary_dynamic_dims(
914 op_type: &str,
915 a_values: &[i64],
916 b_values: &[i64],
917 a_shape: &[i64],
918 b_shape: &[i64],
919 a_dims: Option<&[Dimension]>,
920 b_dims: Option<&[Dimension]>,
921) -> Option<Vec<Dimension>> {
922 let out_shape = broadcast_shape(a_shape, b_shape)?;
923 let out_numel = shape_numel(&out_shape)?;
924 let mut out_dims = Vec::with_capacity(out_numel);
925 let mut has_dynamic = false;
926
927 for out_idx in 0..out_numel {
928 let a_idx = linear_index_for_broadcast_operand(out_idx, &out_shape, a_shape)?;
929 let b_idx = linear_index_for_broadcast_operand(out_idx, &out_shape, b_shape)?;
930 let av = *a_values.get(a_idx)?;
931 let bv = *b_values.get(b_idx)?;
932 let a_dim = a_dims.and_then(|dims| dims.get(a_idx));
933 let b_dim = b_dims.and_then(|dims| dims.get(b_idx));
934
935 let out_dim = match (a_dim, b_dim) {
936 (Some(Dimension::Dynamic(dynamic)), Some(Dimension::Static(_)))
937 | (Some(Dimension::Dynamic(dynamic)), None) => {
938 let dim = combine_binary_dimension(op_type, dynamic, bv, true)?;
939 has_dynamic |= matches!(dim, Dimension::Dynamic(_));
940 dim
941 }
942 (Some(Dimension::Static(_)), Some(Dimension::Dynamic(dynamic)))
943 | (None, Some(Dimension::Dynamic(dynamic))) => {
944 let dim = combine_binary_dimension(op_type, dynamic, av, false)?;
945 has_dynamic |= matches!(dim, Dimension::Dynamic(_));
946 dim
947 }
948 (Some(Dimension::Dynamic(a_dynamic)), Some(Dimension::Dynamic(b_dynamic))) => {
949 let dim = combine_dynamic_dimensions(op_type, a_dynamic, b_dynamic, av, bv)?;
950 has_dynamic |= matches!(dim, Dimension::Dynamic(_));
951 dim
952 }
953 _ => {
954 let value = match op_type {
955 "Add" => av + bv,
956 "Sub" => av - bv,
957 "Mul" => av * bv,
958 "Div" => {
959 if bv == 0 {
960 return None;
961 }
962 av / bv
963 }
964 _ => return None,
965 };
966 Dimension::Static(u32::try_from(value).ok()?)
967 }
968 };
969
970 out_dims.push(out_dim);
971 }
972
973 has_dynamic.then_some(out_dims)
974}
975
976pub(crate) fn dynamic_range_length_dimension(
977 start: i64,
978 delta: i64,
979 start_dim: Option<&DynamicDimension>,
980 limit: &DynamicDimension,
981) -> Option<DynamicDimension> {
982 if delta != 1 {
983 return None;
984 }
985
986 let (mut terms, mut constant) = parse_additive_dynamic_dim_expr(&limit.name)?;
987 if let Some(start_dim) = start_dim {
988 let (start_terms, start_constant) = parse_additive_dynamic_dim_expr(&start_dim.name)?;
989 for (name, coeff) in start_terms {
990 *terms.entry(name).or_insert(0) -= coeff;
991 }
992 constant -= start_constant;
993 } else {
994 constant -= start;
995 }
996 terms.retain(|_, coeff| *coeff != 0);
997 if terms.is_empty() {
998 return None;
999 }
1000
1001 let name = format_additive_dynamic_dim_expr(&terms, constant)?;
1002 if !is_runtime_resolvable_dynamic_dim_expr(&name) {
1003 return None;
1004 }
1005
1006 let max_size = u32::try_from((limit.max_size as i64).checked_sub(start)?).ok()?;
1007 Some(DynamicDimension { name, max_size })
1008}
1009
/// Options controlling an ONNX → WebNN conversion.
///
/// Only `free_dim_overrides` and `experimental_dynamic_inputs` are consumed in
/// the code visible alongside this type; the notes on the remaining fields are
/// inferred from their names and defaults — confirm against their users.
#[derive(Debug, Clone)]
pub struct ConvertOptions {
    /// Presumably whether weights are written to a separate file. Default: `true`.
    pub extract_weights: bool,
    /// Path of the converted graph. Default: `"output.webnn"`.
    pub output_path: String,
    /// Optional path of the weights file. Default: `"output.weights"`.
    pub weights_path: Option<String>,
    /// Optional path of the manifest file. Default: `"output.manifest.json"`.
    pub manifest_path: Option<String>,
    /// Concrete sizes for symbolic input dimensions, keyed by dimension name.
    pub free_dim_overrides: HashMap<String, u32>,
    /// Presumably enables graph optimization. Default: `false`.
    pub optimize: bool,
    /// Keep unresolved input dimensions dynamic instead of failing. Default: `false`.
    pub experimental_dynamic_inputs: bool,
}

impl Default for ConvertOptions {
    fn default() -> Self {
        ConvertOptions {
            extract_weights: true,
            output_path: String::from("output.webnn"),
            weights_path: Some(String::from("output.weights")),
            manifest_path: Some(String::from("output.manifest.json")),
            free_dim_overrides: HashMap::default(),
            optimize: false,
            experimental_dynamic_inputs: false,
        }
    }
}
1042
/// Element type and shape of a tensor value.
///
/// Both fields carry a leading underscore; nothing in the code visible here
/// reads them.
struct TensorInfo {
    _data_type: DataType,
    _shape: Vec<i64>,
}
1047
/// Converts a parsed ONNX model into a WebNN `GraphJson`.
pub struct OnnxConverter {
    // The ONNX model being converted.
    model: ModelProto,
    // The WebNN graph under construction; starts out empty apart from its name.
    graph: GraphJson,
    // Per-value tensor info; initialized empty and not read in the code visible here.
    _value_info: HashMap<String, TensorInfo>,
}
1054
1055impl OnnxConverter {
1056 pub fn new(model: ModelProto) -> Result<Self, OnnxError> {
1058 let graph_name = if let Some(graph) = &model.graph {
1059 if !graph.name.is_empty() {
1060 graph.name.as_str().to_string()
1061 } else {
1062 "graph".to_string()
1063 }
1064 } else {
1065 "graph".to_string()
1066 };
1067
1068 let graph = GraphJson {
1069 format: "webnn-graph-json".to_string(),
1070 version: 1,
1071 name: Some(graph_name),
1072 quantized: false,
1073 inputs: BTreeMap::new(),
1074 consts: BTreeMap::new(),
1075 nodes: Vec::new(),
1076 outputs: BTreeMap::new(),
1077 };
1078
1079 Ok(Self {
1080 model,
1081 graph,
1082 _value_info: HashMap::new(),
1083 })
1084 }
1085
1086 pub fn extract_metadata(&self) -> Result<(), OnnxError> {
1088 if self.model.graph.is_none() {
1089 return Err(OnnxError::ProtobufError(
1090 "Missing graph in model".to_string(),
1091 ));
1092 }
1093
1094 let graph = self.model.graph.as_ref().unwrap();
1095
1096 println!("Model name: {}", self.graph.name.as_ref().unwrap());
1098 println!("Inputs: {}", graph.input.as_slice().len());
1099 println!("Outputs: {}", graph.output.as_slice().len());
1100 println!("Nodes: {}", graph.node.as_slice().len());
1101 println!("Initializers: {}", graph.initializer.as_slice().len());
1102
1103 Ok(())
1104 }
1105
1106 pub fn convert(mut self, options: &ConvertOptions) -> Result<GraphJson, OnnxError> {
1108 if self.model.graph.is_none() {
1109 return Err(OnnxError::ProtobufError(
1110 "Missing graph in model".to_string(),
1111 ));
1112 }
1113
1114 for import in self.model.opset_import.as_slice() {
1116 let domain = import.domain.as_str();
1117 let version = import.version;
1118 let domain_name = if domain.is_empty() {
1119 "ai.onnx".to_string()
1120 } else {
1121 domain.to_string()
1122 };
1123
1124 if (domain.is_empty() || domain == "ai.onnx")
1125 && !(MIN_SUPPORTED_OPSET..=MAX_SUPPORTED_OPSET).contains(&version)
1126 {
1127 return Err(OnnxError::UnsupportedOpset {
1128 domain: domain_name,
1129 version,
1130 });
1131 }
1132 }
1133
1134 let onnx_graph = self.model.graph.as_ref().unwrap();
1135 let mut value_name_map: HashMap<String, String> = HashMap::new();
1136 let mut effective_overrides = options.free_dim_overrides.clone();
1137 let mut inference_overrides = effective_overrides.clone();
1138 let mut value_types: HashMap<String, DataType> = HashMap::new();
1139
1140 for meta in self.model.metadata_props.as_slice() {
1142 if meta
1143 .key
1144 .as_str()
1145 .eq_ignore_ascii_case("freedimensionoverrides")
1146 {
1147 if let Ok(json) = serde_json::from_str::<JsonValue>(meta.value.as_str()) {
1148 let obj = json
1149 .get("freeDimensionOverrides")
1150 .unwrap_or(&json)
1151 .as_object()
1152 .cloned();
1153 if let Some(map) = obj {
1154 for (name, value) in map {
1155 if let Some(v) = value.as_u64() {
1156 effective_overrides.entry(name.clone()).or_insert(v as u32);
1157 }
1158 }
1159 }
1160 }
1161 }
1162 }
1163
1164 let initializer_names: HashSet<String> = onnx_graph
1166 .initializer
1167 .as_slice()
1168 .iter()
1169 .map(|init| init.name.as_str().to_string())
1170 .collect();
1171
1172 let default_dynamic_max_size: u32 = 65_535;
1173 let default_inference_dim_values: HashMap<&str, u32> =
1174 HashMap::from([("batch_size", 1), ("batch", 1), ("n", 1), ("b", 1)]);
1175 let dynamic_max_for_dim = |name: &str| -> u32 {
1176 let lower = name.to_ascii_lowercase();
1177 if lower.contains("past")
1178 || lower.contains("seq")
1179 || lower.contains("length")
1180 || lower == "s"
1181 || lower == "t"
1182 {
1183 4096
1184 } else if lower.contains("batch") || lower == "b" || lower == "n" {
1185 8
1186 } else {
1187 default_dynamic_max_size
1188 }
1189 };
1190
1191 let resolve_dim_override =
1192 |dim_param: &str, overrides: &mut HashMap<String, u32>| -> Option<u32> {
1193 if let Some(v) = overrides.get(dim_param) {
1194 return Some(*v);
1195 }
1196
1197 let lower = dim_param.to_ascii_lowercase();
1198 if let Some(v) = overrides.get(&lower) {
1199 return Some(*v);
1200 }
1201 None
1202 };
1203 let resolve_dim_for_inference =
1204 |dim_param: &str, overrides: &mut HashMap<String, u32>| -> Option<u32> {
1205 if let Some(v) = resolve_dim_override(dim_param, overrides) {
1206 return Some(v);
1207 }
1208 let lower = dim_param.to_ascii_lowercase();
1209 if let Some(v) = default_inference_dim_values.get(lower.as_str()) {
1210 overrides.insert(dim_param.to_string(), *v);
1211 return Some(*v);
1212 }
1213 None
1214 };
1215
1216 for input in onnx_graph.input.as_slice() {
1217 let raw_name = input.name.as_str().to_string();
1218 let name = sanitize_identifier(&raw_name);
1219
1220 if initializer_names.contains(&raw_name) {
1222 continue;
1223 }
1224
1225 if let Some(type_proto) = &input.r#type {
1227 if let Some(TypeProtoValue::TensorType(tensor_type)) = &type_proto.value {
1228 let data_type = if tensor_type.elem_type != 0 {
1229 let onnx_type = tensor_type.elem_type;
1230 map_onnx_data_type(onnx_type)?
1231 } else {
1232 DataType::Float32 };
1234
1235 let shape = if let Some(shape_proto) = &tensor_type.shape {
1236 let mut resolved: Vec<Dimension> = Vec::new();
1237 for (idx, dim) in shape_proto.dim.iter().enumerate() {
1238 if let Some(dim_value) = &dim.value {
1239 match dim_value {
1240 DimensionValue::DimValue(v) => {
1241 if *v > 0 {
1242 resolved.push(Dimension::Static(*v as u32));
1243 } else if options.experimental_dynamic_inputs {
1244 resolved.push(Dimension::Dynamic(DynamicDimension {
1245 name: format!("{}_dim{}", name, idx),
1246 max_size: default_dynamic_max_size,
1247 }));
1248 } else {
1249 let dim_hint = format!("{}_dim{}", name, idx);
1250 return Err(OnnxError::InvalidShape(format!(
1251 "Input '{}' has non-positive dim value ({}) at index {}. \
1252Provide --override-dim {}=<value> or enable --experimental-dynamic-inputs.",
1253 raw_name,
1254 v,
1255 idx,
1256 dim_hint
1257 )));
1258 }
1259 }
1260 DimensionValue::DimParam(dim_param) => {
1261 if let Some(v) = resolve_dim_override(
1262 dim_param,
1263 &mut effective_overrides,
1264 ) {
1265 resolved.push(Dimension::Static(v));
1266 } else if options.experimental_dynamic_inputs {
1267 let max_size = dynamic_max_for_dim(dim_param);
1268 resolved.push(Dimension::Dynamic(DynamicDimension {
1269 name: dim_param.to_string(),
1270 max_size,
1271 }));
1272 } else if let Some(v) = resolve_dim_for_inference(
1273 dim_param,
1274 &mut inference_overrides,
1275 ) {
1276 effective_overrides
1277 .entry(dim_param.clone())
1278 .or_insert(v);
1279 resolved.push(Dimension::Static(v));
1280 } else {
1281 return Err(OnnxError::InvalidShape(format!(
1282 "Input '{}' has unresolved dynamic dimension '{}'. \
1283Provide --override-dim {}=<value> or enable --experimental-dynamic-inputs.",
1284 raw_name, dim_param, dim_param
1285 )));
1286 }
1287 }
1288 }
1289 } else if options.experimental_dynamic_inputs {
1290 resolved.push(Dimension::Dynamic(DynamicDimension {
1291 name: format!("{}_dim{}", name, idx),
1292 max_size: default_dynamic_max_size,
1293 }));
1294 } else {
1295 let dim_hint = format!("{}_dim{}", name, idx);
1296 return Err(OnnxError::InvalidShape(format!(
1297 "Input '{}' has unknown dimension at index {}. \
1298Provide --override-dim {}=<value> or enable --experimental-dynamic-inputs.",
1299 raw_name, idx, dim_hint
1300 )));
1301 }
1302 }
1303 resolved
1304 } else {
1305 return Err(OnnxError::InvalidShape(format!(
1306 "Input '{}' is missing shape information",
1307 raw_name
1308 )));
1309 };
1310
1311 if shape.is_empty() {
1312 continue;
1313 }
1314
1315 self.graph.inputs.insert(
1316 name.clone(),
1317 crate::ast::OperandDesc {
1318 data_type: data_type.clone(),
1319 shape,
1320 },
1321 );
1322
1323 value_name_map.insert(raw_name.clone(), name.clone());
1324 value_name_map.insert(name.clone(), name.clone());
1325 value_types.insert(raw_name.clone(), data_type.clone());
1326 value_types.insert(name.clone(), data_type);
1327 }
1328 }
1329 }
1330
1331 for initializer in onnx_graph.initializer.as_slice() {
1333 let name = sanitize_identifier(initializer.name.as_str());
1334 let raw_data = initializer.raw_data.as_slice();
1335
1336 let has_data = !raw_data.is_empty()
1338 || !initializer.float_data.as_slice().is_empty()
1339 || !initializer.int32_data.as_slice().is_empty()
1340 || !initializer.int64_data.as_slice().is_empty()
1341 || !initializer.double_data.as_slice().is_empty();
1342
1343 if !has_data {
1344 crate::debug_println!("Warning: Skipping initializer '{}' with no data", name);
1345 continue;
1346 }
1347
1348 let onnx_type = initializer.data_type;
1349 let data_type = map_onnx_data_type(onnx_type)?;
1350 let shape: Vec<u32> = initializer
1351 .dims
1352 .as_slice()
1353 .iter()
1354 .map(|d| *d as u32)
1355 .collect();
1356
1357 let init = if options.extract_weights {
1358 crate::ast::ConstInit::Weights {
1360 r#ref: sanitize_identifier(initializer.name.as_str()),
1361 }
1362 } else {
1363 let bytes = raw_data.to_vec();
1365 crate::ast::ConstInit::InlineBytes { bytes }
1366 };
1367
1368 self.graph
1369 .consts
1370 .entry(name.clone())
1371 .or_insert(crate::ast::ConstDecl {
1372 data_type: data_type.clone(),
1373 shape,
1374 init,
1375 });
1376
1377 value_name_map.insert(initializer.name.as_str().to_string(), name.clone());
1378 value_name_map.insert(name.clone(), name.clone());
1379 value_types.insert(initializer.name.as_str().to_string(), data_type.clone());
1380 value_types.insert(name, data_type);
1381 }
1382
1383 let registry = crate::onnx::ops::OpRegistry::new();
1385
1386 let mut initializers_map = std::collections::HashMap::new();
1388 for initializer in onnx_graph.initializer.as_slice() {
1389 let has_data = !initializer.raw_data.as_slice().is_empty()
1391 || !initializer.float_data.as_slice().is_empty()
1392 || !initializer.int32_data.as_slice().is_empty()
1393 || !initializer.int64_data.as_slice().is_empty()
1394 || !initializer.double_data.as_slice().is_empty();
1395
1396 if !has_data {
1397 continue;
1398 }
1399 initializers_map.insert(initializer.name.as_str().to_string(), initializer);
1400 }
1401
1402 let mut value_shapes = std::collections::HashMap::new();
1404 let mut value_shape_dims = std::collections::HashMap::new();
1405
1406 for (raw_name, mapped_name) in value_name_map.clone() {
1408 if initializer_names.contains(&raw_name) {
1409 continue;
1410 }
1411 if let Some(input) = onnx_graph
1412 .input
1413 .as_slice()
1414 .iter()
1415 .find(|i| i.name.as_str() == raw_name)
1416 {
1417 if let Some(type_proto) = &input.r#type {
1418 if let Some(TypeProtoValue::TensorType(tensor_type)) = &type_proto.value {
1419 if let Some(shape_proto) = &tensor_type.shape {
1420 let mut shape: Vec<i64> = Vec::new();
1421 let mut unknown = false;
1422 for dim in &shape_proto.dim {
1423 if let Some(dim_value) = &dim.value {
1424 match dim_value {
1425 DimensionValue::DimValue(v) => {
1426 if *v > 0 {
1427 shape.push(*v);
1428 } else if options.experimental_dynamic_inputs {
1429 shape.push(default_dynamic_max_size as i64);
1430 } else {
1431 unknown = true;
1432 break;
1433 }
1434 }
1435 DimensionValue::DimParam(dim_param) => {
1436 if let Some(v) = resolve_dim_for_inference(
1437 dim_param,
1438 &mut inference_overrides,
1439 ) {
1440 shape.push(v as i64);
1441 } else if options.experimental_dynamic_inputs {
1442 shape.push(dynamic_max_for_dim(dim_param) as i64);
1443 } else {
1444 unknown = true;
1445 break;
1446 }
1447 }
1448 }
1449 } else if options.experimental_dynamic_inputs {
1450 shape.push(default_dynamic_max_size as i64);
1451 } else {
1452 unknown = true;
1453 break;
1454 }
1455 }
1456 if !unknown && !shape.is_empty() {
1457 value_shapes.insert(raw_name.clone(), shape.clone());
1458 value_shapes.insert(mapped_name.clone(), shape);
1459 }
1460 let mut dims = Vec::new();
1461 for dim in &shape_proto.dim {
1462 if let Some(dim_value) = &dim.value {
1463 match dim_value {
1464 DimensionValue::DimValue(v) => {
1465 if *v > 0 {
1466 dims.push(crate::ast::Dimension::Static(*v as u32));
1467 }
1468 }
1469 DimensionValue::DimParam(dim_param) => {
1470 dims.push(crate::ast::Dimension::Dynamic(
1471 crate::ast::DynamicDimension {
1472 name: dim_param.clone(),
1473 max_size: dynamic_max_for_dim(dim_param),
1474 },
1475 ));
1476 }
1477 }
1478 }
1479 }
1480 if !dims.is_empty() {
1481 value_shape_dims.insert(raw_name.clone(), dims.clone());
1482 value_shape_dims.insert(mapped_name.clone(), dims);
1483 }
1484 }
1485 }
1486 }
1487 }
1488 }
1489
1490 for initializer in onnx_graph.initializer.as_slice() {
1492 let has_data = !initializer.raw_data.as_slice().is_empty()
1494 || !initializer.float_data.as_slice().is_empty()
1495 || !initializer.int32_data.as_slice().is_empty()
1496 || !initializer.int64_data.as_slice().is_empty()
1497 || !initializer.double_data.as_slice().is_empty();
1498
1499 if !has_data {
1500 continue;
1501 }
1502 let shape: Vec<i64> = initializer.dims.as_slice().to_vec();
1503 value_shapes.insert(initializer.name.as_str().to_string(), shape);
1504 let dims: Vec<crate::ast::Dimension> = initializer
1505 .dims
1506 .iter()
1507 .copied()
1508 .filter(|d| *d > 0)
1509 .map(|d| crate::ast::Dimension::Static(d as u32))
1510 .collect();
1511 if !dims.is_empty() {
1512 value_shape_dims.insert(initializer.name.as_str().to_string(), dims);
1513 }
1514 }
1515
1516 for value_info in onnx_graph.value_info.as_slice() {
1519 if let Some(type_proto) = &value_info.r#type {
1520 if let Some(TypeProtoValue::TensorType(tensor_type)) = &type_proto.value {
1521 if let Some(shape_proto) = &tensor_type.shape {
1522 let mut shape: Vec<i64> = Vec::new();
1523 let mut unknown = false;
1524
1525 for dim in &shape_proto.dim {
1526 if let Some(dim_value) = &dim.value {
1527 match dim_value {
1528 DimensionValue::DimValue(v) => {
1529 if *v > 0 {
1530 shape.push(*v);
1531 } else if options.experimental_dynamic_inputs {
1532 shape.push(default_dynamic_max_size as i64);
1533 } else {
1534 unknown = true;
1535 break;
1536 }
1537 }
1538 DimensionValue::DimParam(dim_param) => {
1539 if let Some(v) = resolve_dim_for_inference(
1540 dim_param,
1541 &mut inference_overrides,
1542 ) {
1543 shape.push(v as i64);
1544 } else if options.experimental_dynamic_inputs {
1545 shape.push(dynamic_max_for_dim(dim_param) as i64);
1546 } else {
1547 unknown = true;
1548 break;
1549 }
1550 }
1551 }
1552 } else if options.experimental_dynamic_inputs {
1553 shape.push(default_dynamic_max_size as i64);
1554 } else {
1555 unknown = true;
1556 break;
1557 }
1558 }
1559
1560 if !unknown && !shape.is_empty() && shape.iter().all(|&d| d > 0) {
1561 value_shapes.insert(value_info.name.as_str().to_string(), shape);
1562 }
1563 let mut dims = Vec::new();
1564 for dim in &shape_proto.dim {
1565 if let Some(dim_value) = &dim.value {
1566 match dim_value {
1567 DimensionValue::DimValue(v) => {
1568 if *v > 0 {
1569 dims.push(crate::ast::Dimension::Static(*v as u32));
1570 }
1571 }
1572 DimensionValue::DimParam(dim_param) => {
1573 dims.push(crate::ast::Dimension::Dynamic(
1574 crate::ast::DynamicDimension {
1575 name: dim_param.clone(),
1576 max_size: dynamic_max_for_dim(dim_param),
1577 },
1578 ));
1579 }
1580 }
1581 }
1582 }
1583 if !dims.is_empty() {
1584 value_shape_dims.insert(value_info.name.as_str().to_string(), dims);
1585 }
1586 }
1587 }
1588 }
1589 }
1590
1591 let mut const_values: HashMap<String, Vec<i64>> = HashMap::new();
1593 for (name, initializer) in &initializers_map {
1594 if initializer.data_type == TensorProto_DataType::Int64 as i32
1595 || initializer.data_type == TensorProto_DataType::Int32 as i32
1596 {
1597 let raw = initializer.raw_data.as_slice();
1598 let values = if !raw.is_empty() {
1599 if initializer.data_type == TensorProto_DataType::Int32 as i32 {
1600 raw.chunks_exact(4)
1601 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64)
1602 .collect()
1603 } else {
1604 raw.chunks_exact(8)
1605 .map(|c| {
1606 i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
1607 })
1608 .collect()
1609 }
1610 } else if !initializer.int64_data.as_slice().is_empty() {
1611 initializer.int64_data.as_slice().to_vec()
1612 } else if !initializer.int32_data.as_slice().is_empty() {
1613 initializer
1614 .int32_data
1615 .as_slice()
1616 .iter()
1617 .map(|&v| v as i64)
1618 .collect()
1619 } else {
1620 Vec::new()
1621 };
1622
1623 if !values.is_empty() {
1624 const_values.insert(name.clone(), values);
1625 }
1626 }
1627 }
1628
1629 for node in onnx_graph.node.as_slice() {
1630 if node.op_type.as_str() == "Constant" {
1631 if let Some(attr) = node
1632 .attribute
1633 .as_slice()
1634 .iter()
1635 .find(|a| a.name.as_str() == "value" && a.t.is_some())
1636 {
1637 let tensor = attr.t.as_ref().unwrap();
1638 if tensor.data_type == TensorProto_DataType::Int64 as i32
1639 || tensor.data_type == TensorProto_DataType::Int32 as i32
1640 {
1641 let raw = tensor.raw_data.as_slice();
1642 let values = if !raw.is_empty() {
1643 if tensor.data_type == TensorProto_DataType::Int32 as i32 {
1644 raw.chunks_exact(4)
1645 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as i64)
1646 .collect()
1647 } else {
1648 raw.chunks_exact(8)
1649 .map(|c| {
1650 i64::from_le_bytes([
1651 c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7],
1652 ])
1653 })
1654 .collect()
1655 }
1656 } else if !tensor.int64_data.as_slice().is_empty() {
1657 tensor.int64_data.as_slice().to_vec()
1658 } else if !tensor.int32_data.as_slice().is_empty() {
1659 tensor
1660 .int32_data
1661 .as_slice()
1662 .iter()
1663 .map(|&v| v as i64)
1664 .collect()
1665 } else {
1666 Vec::new()
1667 };
1668
1669 if let Some(out) = node.output.as_slice().first() {
1670 if !values.is_empty() {
1671 const_values.insert(out.to_string(), values);
1672 value_types.insert(out.to_string(), DataType::Int64);
1673 }
1674 }
1675 }
1676 }
1677 }
1678 }
1679
1680 let mut dynamic_inference_attempts: HashSet<String> = HashSet::new();
1683 loop {
1684 match crate::onnx::shape_inference::infer_static_shapes(
1685 &self.model,
1686 &inference_overrides,
1687 ) {
1688 Ok(inferred) => {
1689 for (k, v) in inferred.value_shapes {
1692 value_shapes.entry(k).or_insert(v);
1693 }
1694 for (k, v) in inferred.value_types {
1695 value_types.entry(k).or_insert(v);
1696 }
1697 for (k, v) in inferred.const_values {
1698 if k.contains("rotary") && k.contains("Where") {
1701 if let Some(old_val) = const_values.get(&k) {
1702 crate::debug_println!(
1703 "[CONVERT] Overwriting {} from {:?} to {:?}",
1704 k,
1705 old_val,
1706 v
1707 );
1708 } else {
1709 crate::debug_println!("[CONVERT] Inserting new {} = {:?}", k, v);
1710 }
1711 }
1712 const_values.insert(k, v);
1713 }
1714 break;
1715 }
1716 Err(crate::onnx::shape_inference::ShapeInferenceError::DynamicDim {
1717 input,
1718 dim,
1719 }) => {
1720 if options.experimental_dynamic_inputs
1721 && !dynamic_inference_attempts.contains(dim.as_str())
1722 {
1723 let fallback = dynamic_max_for_dim(&dim);
1724 inference_overrides.insert(dim.clone(), fallback);
1725 dynamic_inference_attempts.insert(dim.clone());
1726 crate::debug_println!(
1727 "[CONVERT] Retrying static shape inference with inferred override {}={} \
1728 (required by input '{}')",
1729 dim,
1730 fallback,
1731 input
1732 );
1733 continue;
1734 }
1735 crate::debug_println!(
1736 "[CONVERT] Skipping static shape inference due to unresolved dynamic dim '{}' on input '{}'",
1737 dim,
1738 input
1739 );
1740 break;
1741 }
1742 Err(e) => return Err(OnnxError::ShapeInference(e.to_string())),
1743 }
1744 }
1745
1746 for _ in 0..3 {
1748 if options.optimize {
1749 let max_iterations = 10;
1750 for iteration in 0..max_iterations {
1751 let initial_count = value_shapes.len();
1752
1753 for onnx_node in onnx_graph.node.as_slice() {
1754 let all_outputs_known = onnx_node
1755 .output
1756 .as_slice()
1757 .iter()
1758 .all(|out| value_shapes.contains_key(out.as_str()));
1759 if all_outputs_known {
1760 continue;
1761 }
1762
1763 if let Some(inferred) =
1764 infer_shape(onnx_node, &value_shapes, &initializers_map, &const_values)
1765 {
1766 if let Some(output_name) = onnx_node.output.as_slice().first() {
1767 if output_name.contains("layers_15_self_attn")
1769 && (output_name.contains("Reshape")
1770 || output_name.contains("Transpose"))
1771 {
1772 crate::debug_println!(
1773 "[SHAPE DEBUG] {} {} -> {:?}",
1774 onnx_node.op_type.as_str(),
1775 output_name,
1776 inferred
1777 );
1778 }
1779 value_shapes.insert(output_name.to_string(), inferred);
1781 }
1782 }
1783 }
1784
1785 if value_shapes.len() == initial_count {
1786 break;
1787 }
1788
1789 if iteration == max_iterations - 1 {
1790 crate::debug_println!(
1791 "Warning: Shape propagation reached max iterations ({}/{})",
1792 value_shapes.len(),
1793 onnx_graph.node.as_slice().len()
1794 );
1795 }
1796 }
1797 }
1798
1799 if let Some(ids_shape) = value_shapes.get("input_ids") {
1803 if ids_shape.len() == 2 {
1804 let (batch, seq) = (ids_shape[0], ids_shape[1]);
1805 let upgrades: Vec<(String, Vec<i64>)> = value_shapes
1806 .iter()
1807 .filter_map(|(k, v)| {
1808 if v.len() == 1 && v[0] > 1 {
1809 Some((k.clone(), vec![batch, seq, v[0]]))
1810 } else {
1811 None
1812 }
1813 })
1814 .collect();
1815 for (k, v) in upgrades {
1816 value_shapes.insert(k, v);
1817 }
1818 }
1819 }
1820
1821 crate::debug_println!(
1822 "[debug] layer_norm shape {:?}",
1823 value_shapes.get("/decoder/block.0/layer.0/layer_norm/Mul_1_output_0")
1824 );
1825 crate::debug_println!(
1826 "[debug] matmul q shape {:?}",
1827 value_shapes.get("/decoder/block.0/layer.0/SelfAttention/q/MatMul_output_0")
1828 );
1829 crate::debug_println!(
1830 "[debug] input_ids shape {:?}",
1831 value_shapes.get("input_ids")
1832 );
1833 crate::debug_println!(
1834 "[debug] ln div shape {:?}",
1835 value_shapes.get("/decoder/block.0/layer.0/layer_norm/Div_output_0")
1836 );
1837
1838 let consts_before = const_values.len();
1839
1840 if let Some(val) = const_values.get("/model/rotary_emb/Where_output_0") {
1842 crate::debug_println!("[PROP BEFORE] /model/rotary_emb/Where_output_0 = {:?}", val);
1843 }
1844
1845 for node in onnx_graph.node.as_slice() {
1847 let op_type = node.op_type.as_str();
1848 if op_type == "Shape" {
1849 if let (Some(inp), Some(out)) = (
1850 node.input.as_slice().first(),
1851 node.output.as_slice().first(),
1852 ) {
1853 let out = out.to_string();
1854 if let Some(shape) = value_shapes.get(inp).cloned() {
1855 if shape.iter().all(|d| *d > 0) {
1856 if options.experimental_dynamic_inputs {
1859 let inp_s = inp.to_string();
1860 if let Some(dims) = value_shape_dims.get(&inp_s).or_else(|| {
1861 value_shape_dims.get(&sanitize_identifier(&inp_s))
1862 }) {
1863 let out_dims: Vec<crate::ast::Dimension> = dims
1867 .iter()
1868 .map(|d| match d {
1869 crate::ast::Dimension::Dynamic(dd) => {
1870 crate::ast::Dimension::Dynamic(dd.clone())
1871 }
1872 crate::ast::Dimension::Static(v) => {
1873 crate::ast::Dimension::Static(*v)
1874 }
1875 })
1876 .collect();
1877 value_shape_dims.insert(out.clone(), out_dims);
1878 }
1879 }
1880 const_values.insert(out.clone(), shape.clone());
1881 let inferred_shape = vec![shape.len() as i64];
1882 value_shapes.insert(out.clone(), inferred_shape.clone());
1884 value_shapes.insert(sanitize_identifier(&out), inferred_shape);
1885 value_types.insert(out, DataType::Int64);
1886 }
1887 }
1888 }
1889 } else if op_type == "Gather" {
1890 if let (Some(data_name), Some(indices_name), Some(out)) = (
1891 node.input.as_slice().first(),
1892 node.input.as_slice().get(1),
1893 node.output.as_slice().first(),
1894 ) {
1895 if let (Some(data), Some(indices)) =
1896 (const_values.get(data_name), const_values.get(indices_name))
1897 {
1898 let axis = node
1899 .attribute
1900 .as_slice()
1901 .iter()
1902 .find(|a| a.name.as_str() == "axis" && a.i != 0)
1903 .map(|a| a.i)
1904 .unwrap_or(0);
1905
1906 if axis == 0 {
1907 let mut gathered = Vec::new();
1908 let mut gathered_dims = Vec::new();
1909 let data_dims = if options.experimental_dynamic_inputs {
1910 value_shape_dims
1911 .get(data_name)
1912 .or_else(|| {
1913 value_shape_dims.get(&sanitize_identifier(data_name))
1914 })
1915 .cloned()
1916 } else {
1917 None
1918 };
1919 for &idx in indices {
1920 let i = if idx < 0 {
1921 (data.len() as i64 + idx) as usize
1922 } else {
1923 idx as usize
1924 };
1925 if let Some(v) = data.get(i) {
1926 gathered.push(*v);
1927 if let Some(ref dd) = data_dims {
1928 if let Some(dim) = dd.get(i) {
1929 gathered_dims.push(dim.clone());
1930 }
1931 }
1932 }
1933 }
1934 if !gathered.is_empty() {
1935 if options.experimental_dynamic_inputs
1936 && gathered_dims.len() == gathered.len()
1937 && gathered_dims
1938 .iter()
1939 .any(|d| matches!(d, crate::ast::Dimension::Dynamic(_)))
1940 {
1941 value_shape_dims.insert(out.to_string(), gathered_dims);
1942 }
1943 const_values.insert(out.to_string(), gathered.clone());
1944 let out_shape = if gathered.len() == 1 {
1945 Vec::new()
1946 } else {
1947 vec![gathered.len() as i64]
1948 };
1949 value_shapes.insert(out.to_string(), out_shape.clone());
1951 value_shapes.insert(sanitize_identifier(out), out_shape);
1952 value_types.insert(out.to_string(), DataType::Int64);
1953 }
1954 }
1955 }
1956 }
1957 } else if matches!(op_type, "Add" | "Sub" | "Mul" | "Div") {
1958 if node.input.as_slice().len() >= 2 {
1959 if let (Some(a_name), Some(b_name), Some(out)) = (
1960 node.input.as_slice().first(),
1961 node.input.as_slice().get(1),
1962 node.output.as_slice().first(),
1963 ) {
1964 let a = const_values.get(a_name);
1965 let b = const_values.get(b_name);
1966 if let (Some(a), Some(b)) = (a, b) {
1967 let a_shape = const_shape_for_folding(a_name, a, &value_shapes);
1968 let b_shape = const_shape_for_folding(b_name, b, &value_shapes);
1969 if let Some((result_vals, out_shape)) =
1970 fold_binary_const_i64(op_type, a, b, &a_shape, &b_shape)
1971 {
1972 if options.experimental_dynamic_inputs {
1973 let a_dims =
1974 value_shape_dims_for(a_name, &value_shape_dims);
1975 let b_dims =
1976 value_shape_dims_for(b_name, &value_shape_dims);
1977 if let Some(out_dims) = fold_binary_dynamic_dims(
1978 op_type, a, b, &a_shape, &b_shape, a_dims, b_dims,
1979 ) {
1980 value_shape_dims.insert(out.to_string(), out_dims);
1981 }
1982 }
1983 const_values.insert(out.to_string(), result_vals.clone());
1984 value_shapes.insert(out.to_string(), out_shape.clone());
1986 value_shapes.insert(sanitize_identifier(out), out_shape);
1987 if let Some(dtype) = node
1988 .input
1989 .as_slice()
1990 .iter()
1991 .find_map(|i| value_types.get(i).cloned())
1992 {
1993 value_types.insert(out.to_string(), dtype);
1994 }
1995 }
1996 }
1997 }
1998 }
1999 } else if op_type == "Cast" || op_type == "Unsqueeze" || op_type == "Squeeze" {
2000 if let (Some(inp), Some(out)) = (
2001 node.input.as_slice().first(),
2002 node.output.as_slice().first(),
2003 ) {
2004 if let Some(vals) = const_values.get(inp).cloned() {
2005 if options.experimental_dynamic_inputs {
2007 if let Some(dims) = value_shape_dims
2008 .get(inp)
2009 .or_else(|| value_shape_dims.get(&sanitize_identifier(inp)))
2010 .cloned()
2011 {
2012 value_shape_dims.insert(out.to_string(), dims);
2013 }
2014 }
2015 const_values.insert(out.to_string(), vals.clone());
2016 let out_shape = if vals.len() == 1 {
2017 Vec::new()
2018 } else {
2019 vec![vals.len() as i64]
2020 };
2021 value_shapes.insert(out.to_string(), out_shape);
2023 if let Some(dtype) = value_types.get(inp).cloned() {
2024 value_types.insert(out.to_string(), dtype);
2025 }
2026 }
2027 }
2028 } else if op_type == "Range" {
2029 if node.input.as_slice().len() == 3 {
2030 if let (Some(start_name), Some(limit_name), Some(delta_name)) = (
2031 node.input.as_slice().first(),
2032 node.input.as_slice().get(1),
2033 node.input.as_slice().get(2),
2034 ) {
2035 if options.experimental_dynamic_inputs {
2036 let start_dim = dynamic_scalar_dimension_for_value(
2037 start_name,
2038 &value_shape_dims,
2039 );
2040 if let Some(limit_dim) = dynamic_scalar_dimension_for_value(
2041 limit_name,
2042 &value_shape_dims,
2043 ) {
2044 if let (Some(start_vals), Some(delta_vals), Some(out)) = (
2045 const_values.get(start_name),
2046 const_values.get(delta_name),
2047 node.output.as_slice().first(),
2048 ) {
2049 if !start_vals.is_empty() && !delta_vals.is_empty() {
2050 let start = start_vals[0];
2051 let delta = delta_vals[0];
2052 if let Some(range_dim) = dynamic_range_length_dimension(
2053 start,
2054 delta,
2055 start_dim.as_ref(),
2056 &limit_dim,
2057 ) {
2058 let out_shape = vec![range_dim.max_size as i64];
2059 value_shape_dims.insert(
2060 out.to_string(),
2061 vec![Dimension::Dynamic(range_dim.clone())],
2062 );
2063 value_shapes
2064 .insert(out.to_string(), out_shape.clone());
2065 value_shapes
2066 .insert(sanitize_identifier(out), out_shape);
2067 value_types
2068 .insert(out.to_string(), DataType::Int64);
2069 }
2070 }
2071 }
2072 continue;
2073 }
2074 }
2075
2076 if let (Some(start_vals), Some(limit_vals), Some(delta_vals)) = (
2078 const_values.get(start_name),
2079 const_values.get(limit_name),
2080 const_values.get(delta_name),
2081 ) {
2082 if !start_vals.is_empty()
2083 && !limit_vals.is_empty()
2084 && !delta_vals.is_empty()
2085 {
2086 let start = start_vals[0];
2087 let limit = limit_vals[0];
2088 let delta = delta_vals[0];
2089
2090 let mut range_vals = Vec::new();
2091 if delta > 0 {
2092 let mut current = start;
2093 while current < limit {
2094 range_vals.push(current);
2095 current += delta;
2096 }
2097 } else if delta < 0 {
2098 let mut current = start;
2099 while current > limit {
2100 range_vals.push(current);
2101 current += delta;
2102 }
2103 }
2104
2105 if let Some(out) = node.output.as_slice().first() {
2106 const_values.insert(out.to_string(), range_vals.clone());
2107 let out_shape = vec![range_vals.len() as i64];
2108 value_shapes.insert(out.to_string(), out_shape.clone());
2110 value_shapes.insert(sanitize_identifier(out), out_shape);
2111 value_types.insert(out.to_string(), DataType::Int64);
2112 }
2113 }
2114 }
2115 }
2116 }
2117 } else if op_type == "Concat" {
2118 if let Some(out) = node.output.as_slice().first() {
2120 let mut concatenated: Vec<i64> = Vec::new();
2121 let mut all_const = true;
2122 for inp in node.input.as_slice() {
2123 if let Some(vals) = const_values.get(inp) {
2124 concatenated.extend_from_slice(vals);
2125 } else {
2126 all_const = false;
2127 break;
2128 }
2129 }
2130
2131 let axis = node
2133 .attribute
2134 .as_slice()
2135 .iter()
2136 .find(|a| a.name.as_str() == "axis" && a.i != 0)
2137 .map(|a| a.i)
2138 .unwrap_or(0);
2139
2140 if all_const && (axis == 0 || axis == -1) {
2141 if out.contains("rotary") && out.contains("Where") {
2142 crate::debug_println!(
2143 "[CONCAT WRITE] Writing {} = {:?}",
2144 out,
2145 concatenated
2146 );
2147 }
2148 if options.experimental_dynamic_inputs {
2150 let mut concat_dims: Vec<crate::ast::Dimension> = Vec::new();
2151 let mut has_dynamic = false;
2152 for inp in node.input.as_slice() {
2153 let inp_s = inp.to_string();
2154 if let Some(dims) = value_shape_dims.get(&inp_s).or_else(|| {
2155 value_shape_dims.get(&sanitize_identifier(&inp_s))
2156 }) {
2157 for d in dims {
2158 if matches!(d, crate::ast::Dimension::Dynamic(_)) {
2159 has_dynamic = true;
2160 }
2161 concat_dims.push(d.clone());
2162 }
2163 } else if let Some(vals) = const_values.get(inp) {
2164 for v in vals {
2165 concat_dims
2166 .push(crate::ast::Dimension::Static(*v as u32));
2167 }
2168 }
2169 }
2170 if has_dynamic && concat_dims.len() == concatenated.len() {
2171 value_shape_dims.insert(out.to_string(), concat_dims);
2172 }
2173 }
2174 const_values.insert(out.to_string(), concatenated.clone());
2175 let out_shape = vec![concatenated.len() as i64];
2176 value_shapes.insert(out.to_string(), out_shape.clone());
2178 value_shapes.insert(sanitize_identifier(out), out_shape);
2179 value_types.insert(out.to_string(), DataType::Int64);
2180 }
2181 }
2182 } else if op_type == "ConstantOfShape" {
2183 if let Some(shape_name) = node.input.as_slice().first() {
2185 let dynamic_output_dims = if options.experimental_dynamic_inputs {
2186 value_shape_dims_for(shape_name, &value_shape_dims)
2187 .map(|dims| dims.to_vec())
2188 .filter(|dims| dims_contain_dynamic(dims))
2189 } else {
2190 None
2191 };
2192
2193 if let (Some(out), Some(dims)) =
2194 (node.output.as_slice().first(), dynamic_output_dims.as_ref())
2195 {
2196 value_shape_dims.insert(out.to_string(), dims.to_vec());
2197 const_values.remove(out.as_str());
2198 }
2199
2200 if let Some(shape_vals) = const_values.get(shape_name).cloned() {
2201 let mut fill_value = 0i64;
2203 for attr in node.attribute.as_slice() {
2204 if attr.name.as_str() == "value" {
2205 if let Some(value_tensor) = attr.t.as_ref() {
2206 if value_tensor.data_type
2207 == crate::protos::onnx::TensorProto_DataType::Int64
2208 as i32
2209 {
2210 let raw = value_tensor.raw_data.as_slice();
2211 if !raw.is_empty() && raw.len() >= 8 {
2212 fill_value = i64::from_le_bytes([
2213 raw[0], raw[1], raw[2], raw[3], raw[4], raw[5],
2214 raw[6], raw[7],
2215 ]);
2216 } else if !value_tensor.int64_data.as_slice().is_empty()
2217 {
2218 fill_value = value_tensor.int64_data.as_slice()[0];
2219 }
2220 }
2221 }
2222 }
2223 }
2224
2225 let numel = if shape_vals.is_empty() {
2227 1
2228 } else {
2229 shape_vals.iter().product::<i64>()
2230 };
2231
2232 if numel > 0 && numel < 1_000_000 {
2233 let filled_tensor = vec![fill_value; numel as usize];
2235 if let Some(out) = node.output.as_slice().first() {
2236 let should_keep_const = dynamic_output_dims
2237 .as_ref()
2238 .is_none_or(|dims| !dims_contain_dynamic(dims));
2239 if should_keep_const {
2240 const_values.insert(out.to_string(), filled_tensor);
2241 } else {
2242 const_values.remove(out.as_str());
2243 }
2244 value_shapes.insert(out.to_string(), shape_vals.clone());
2246 value_shapes
2247 .insert(sanitize_identifier(out), shape_vals.clone());
2248 value_types.insert(out.to_string(), DataType::Int64);
2249 }
2250 }
2251 }
2252 }
2253 } else if op_type == "Equal" {
2254 if node.input.as_slice().len() >= 2 {
2256 if let (Some(a_name), Some(b_name), Some(out)) = (
2257 node.input.as_slice().first(),
2258 node.input.as_slice().get(1),
2259 node.output.as_slice().first(),
2260 ) {
2261 let a = const_values.get(a_name);
2262 let b = const_values.get(b_name);
2263 if let (Some(a), Some(b)) = (a, b) {
2264 let a_shape = const_shape_for_folding(a_name, a, &value_shapes);
2265 let b_shape = const_shape_for_folding(b_name, b, &value_shapes);
2266 if let Some((result_vals, out_shape)) =
2267 fold_binary_const_i64("Equal", a, b, &a_shape, &b_shape)
2268 {
2269 const_values.insert(out.to_string(), result_vals.clone());
2270 value_shapes.insert(out.to_string(), out_shape.clone());
2272 value_shapes.insert(sanitize_identifier(out), out_shape);
2273 value_types.insert(out.to_string(), DataType::Int64);
2274 }
2275 }
2276 }
2277 }
2278 } else if op_type == "Where" {
2279 if options.experimental_dynamic_inputs && node.input.as_slice().len() >= 3 {
2280 if let Some(out) = node.output.as_slice().first() {
2281 let cond = const_values.get(node.input.as_slice()[0].as_str());
2282 let a_dims = dimension_vector_for_value(
2283 node.input.as_slice()[1].as_str(),
2284 &const_values,
2285 &value_shape_dims,
2286 );
2287 let b_dims = dimension_vector_for_value(
2288 node.input.as_slice()[2].as_str(),
2289 &const_values,
2290 &value_shape_dims,
2291 );
2292 let out_dims = if let (Some(cond), Some(a_dims), Some(b_dims)) =
2293 (cond, a_dims.as_ref(), b_dims.as_ref())
2294 {
2295 if cond.len() == 1 && a_dims.len() == b_dims.len() {
2296 Some(if cond[0] != 0 {
2297 a_dims.clone()
2298 } else {
2299 b_dims.clone()
2300 })
2301 } else if cond.len() == a_dims.len() && cond.len() == b_dims.len() {
2302 Some(
2303 cond.iter()
2304 .enumerate()
2305 .map(|(idx, c)| {
2306 if *c != 0 {
2307 a_dims[idx].clone()
2308 } else {
2309 b_dims[idx].clone()
2310 }
2311 })
2312 .collect(),
2313 )
2314 } else {
2315 None
2316 }
2317 } else if let (Some(a_dims), Some(b_dims)) =
2318 (a_dims.as_ref(), b_dims.as_ref())
2319 {
2320 let a_has_dynamic =
2321 a_dims.iter().any(|d| matches!(d, Dimension::Dynamic(_)));
2322 let b_has_dynamic =
2323 b_dims.iter().any(|d| matches!(d, Dimension::Dynamic(_)));
2324 if a_has_dynamic && !b_has_dynamic {
2325 Some(a_dims.clone())
2326 } else if b_has_dynamic && !a_has_dynamic {
2327 Some(b_dims.clone())
2328 } else if a_has_dynamic
2329 && b_has_dynamic
2330 && a_dims.len() == b_dims.len()
2331 {
2332 Some(
2333 a_dims
2334 .iter()
2335 .zip(b_dims.iter())
2336 .map(|(a_dim, b_dim)| match (a_dim, b_dim) {
2337 (Dimension::Dynamic(dim), _) => {
2338 Dimension::Dynamic(dim.clone())
2339 }
2340 (_, Dimension::Dynamic(dim)) => {
2341 Dimension::Dynamic(dim.clone())
2342 }
2343 (Dimension::Static(v), _) => Dimension::Static(*v),
2344 })
2345 .collect(),
2346 )
2347 } else {
2348 None
2349 }
2350 } else if let Some(a_dims) = a_dims.as_ref() {
2351 if a_dims.iter().any(|d| matches!(d, Dimension::Dynamic(_)))
2352 && !is_trivial_static_dimension_vector(a_dims)
2353 {
2354 Some(a_dims.clone())
2355 } else {
2356 None
2357 }
2358 } else if let Some(b_dims) = b_dims.as_ref() {
2359 if b_dims.iter().any(|d| matches!(d, Dimension::Dynamic(_)))
2360 && !is_trivial_static_dimension_vector(b_dims)
2361 {
2362 Some(b_dims.clone())
2363 } else {
2364 None
2365 }
2366 } else {
2367 None
2368 };
2369
2370 if let Some(out_dims) = out_dims {
2371 if out_dims.iter().any(|d| matches!(d, Dimension::Dynamic(_))) {
2372 value_shape_dims.insert(out.to_string(), out_dims);
2373 }
2374 }
2375 }
2376 }
2377 continue;
2380 }
2381 }
2382
2383 if const_values.len() == consts_before {
2384 break;
2385 }
2386
2387 if let Some(val) = const_values.get("/model/rotary_emb/Where_output_0") {
2389 crate::debug_println!("[PROP AFTER] /model/rotary_emb/Where_output_0 = {:?}", val);
2390 }
2391 }
2392
2393 if let Some(val) = const_values.get("/model/rotary_emb/Where_output_0") {
2395 crate::debug_println!("[NODE CONV] /model/rotary_emb/Where_output_0 = {:?}", val);
2396 }
2397 for onnx_node in onnx_graph.node.as_slice() {
2398 let outputs = onnx_node.output.as_slice();
2400 let has_dynamic_output_metadata = outputs.iter().any(|o| {
2401 value_shape_dims_for(o.as_str(), &value_shape_dims)
2402 .map(|dims| dims.iter().any(|d| matches!(d, Dimension::Dynamic(_))))
2403 .unwrap_or(false)
2404 });
2405 if !outputs.is_empty()
2406 && !has_dynamic_output_metadata
2407 && outputs
2408 .iter()
2409 .all(|o| const_values.contains_key(o.as_str()))
2410 {
2411 let all_scalar = outputs.iter().all(|o| {
2413 value_shapes
2414 .get(o.as_str())
2415 .map(|s| s.is_empty()) .unwrap_or_else(|| {
2417 const_values
2419 .get(o.as_str())
2420 .map(|v| v.len() == 1)
2421 .unwrap_or(false)
2422 })
2423 });
2424
2425 if all_scalar {
2427 for out in outputs {
2428 if let Some(values) = const_values.get(out) {
2429 let const_name = sanitize_identifier(out);
2430 let shape = value_shapes
2432 .get(out.as_str())
2433 .map(|s| s.iter().map(|&d| d as u32).collect())
2434 .unwrap_or_else(Vec::new);
2435
2436 let decl = crate::ast::ConstDecl {
2437 data_type: DataType::Int64,
2438 shape,
2439 init: crate::ast::ConstInit::InlineBytes {
2440 bytes: values[0].to_le_bytes().to_vec(),
2441 },
2442 };
2443
2444 if let Some(existing) = self.graph.consts.get(&const_name) {
2445 if existing != &decl {
2446 return Err(OnnxError::InvalidShape(format!(
2447 "Conflicting constant definitions for '{}'",
2448 const_name
2449 )));
2450 }
2451 } else {
2452 self.graph.consts.insert(const_name.clone(), decl);
2453 }
2454
2455 value_name_map.insert(out.to_string(), const_name.clone());
2456 value_name_map.insert(const_name.clone(), const_name.clone());
2457 value_types.insert(out.to_string(), DataType::Int64);
2458 value_types.insert(const_name, DataType::Int64);
2459 }
2460 }
2461 }
2462 for out in outputs {
2465 if let Some(values) = const_values.get(out) {
2466 let const_name = sanitize_identifier(out);
2467 let mut shape = value_shapes
2468 .get(out.as_str())
2469 .cloned()
2470 .unwrap_or_else(|| vec![values.len() as i64]);
2471 let declared_numel = shape
2472 .iter()
2473 .try_fold(1usize, |acc, d| usize::try_from(*d).ok().map(|v| acc * v));
2474 if declared_numel != Some(values.len()) {
2475 shape = vec![values.len() as i64];
2479 }
2480 let dtype = value_types
2481 .get(out.as_str())
2482 .cloned()
2483 .unwrap_or(DataType::Int64);
2484
2485 let mut bytes = Vec::with_capacity(values.len() * 8);
2487 for v in values {
2488 bytes.extend_from_slice(&v.to_le_bytes());
2489 }
2490
2491 let decl = crate::ast::ConstDecl {
2492 data_type: dtype.clone(),
2493 shape: shape.iter().map(|d| *d as u32).collect(),
2494 init: crate::ast::ConstInit::InlineBytes { bytes },
2495 };
2496
2497 let existing = self.graph.consts.get(&const_name).cloned();
2498 if existing.is_none() {
2499 self.graph.consts.insert(const_name.clone(), decl);
2500 }
2501
2502 value_name_map.insert(out.to_string(), const_name.clone());
2503 value_name_map.insert(const_name.clone(), const_name.clone());
2504 value_types.insert(out.to_string(), dtype.clone());
2505 value_types.insert(const_name, dtype);
2506 }
2507 }
2508 continue;
2509 }
2510
2511 let context = crate::onnx::ops::ConversionContext {
2512 initializers: &initializers_map,
2513 value_shapes: &value_shapes,
2514 value_shape_dims: &value_shape_dims,
2515 const_values: &const_values,
2516 value_ids: &value_name_map,
2517 value_types: &value_types,
2518 };
2519
2520 let converted = registry.convert_node(onnx_node, &context)?;
2521
2522 for (name, mut decl) in converted.consts {
2523 if let crate::ast::ConstInit::InlineBytes { bytes } = &decl.init {
2524 let elem_size = match decl.data_type {
2525 DataType::Float32 => 4,
2526 DataType::Float16 => 2,
2527 DataType::Int64 => 8,
2528 DataType::Uint64 => 8,
2529 DataType::Int32 => 4,
2530 DataType::Uint32 => 4,
2531 DataType::Int8 => 1,
2532 DataType::Uint8 => 1,
2533 DataType::Int4 | DataType::Uint4 => 0,
2534 };
2535 if elem_size > 0 {
2536 let declared_numel = decl
2537 .shape
2538 .iter()
2539 .try_fold(1usize, |acc, d| usize::try_from(*d).ok().map(|v| acc * v));
2540 let declared_bytes = declared_numel.map(|n| n * elem_size);
2541 if declared_bytes != Some(bytes.len()) && bytes.len() % elem_size == 0 {
2542 decl.shape = vec![(bytes.len() / elem_size) as u32];
2545 }
2546 }
2547 }
2548 let decl_dtype = decl.data_type.clone();
2549 if let Some(existing) = self.graph.consts.get(&name) {
2550 if existing != &decl {
2551 return Err(OnnxError::InvalidShape(format!(
2552 "Conflicting constant definitions for '{}'",
2553 name
2554 )));
2555 }
2556 } else {
2557 self.graph.consts.insert(name.clone(), decl);
2558 }
2559 value_name_map.insert(name.clone(), name.clone());
2560 value_types.insert(name.clone(), decl_dtype);
2561 }
2562
2563 for (onnx_out, webnn_id) in converted.output_mappings {
2564 value_name_map.insert(onnx_out.clone(), webnn_id.clone());
2565 value_name_map.insert(sanitize_identifier(&onnx_out), webnn_id.clone());
2566 }
2567
2568 for (onnx_out, dtype) in converted.output_types {
2569 if let Some(webnn_id) = value_name_map.get(&onnx_out).cloned() {
2570 value_types.insert(webnn_id, dtype);
2571 }
2572 }
2573
2574 if let Some(inferred_shape) =
2577 infer_shape(onnx_node, &value_shapes, &initializers_map, &const_values)
2578 {
2579 for output_name in onnx_node.output.as_slice() {
2580 value_shapes.insert(output_name.to_string(), inferred_shape.clone());
2582 value_shapes.insert(sanitize_identifier(output_name), inferred_shape.clone());
2583 }
2584 }
2585
2586 self.graph.nodes.extend(converted.nodes);
2587 }
2588
2589 for output in onnx_graph.output.as_slice() {
2591 let onnx_name = output.name.as_str();
2592 if let Some(mapped) = value_name_map.get(onnx_name) {
2593 self.graph
2594 .outputs
2595 .insert(sanitize_identifier(onnx_name), mapped.clone());
2596 } else {
2597 return Err(OnnxError::InvalidShape(format!(
2598 "No WebNN value found for ONNX output '{}'",
2599 onnx_name
2600 )));
2601 }
2602 }
2603
2604 let has_dynamic_inputs = self.graph.inputs.values().any(|operand| {
2605 operand
2606 .shape
2607 .iter()
2608 .any(|dim| matches!(dim, Dimension::Dynamic(_)))
2609 });
2610 self.graph.version = if has_dynamic_inputs { 2 } else { 1 };
2611
2612 Ok(self.graph)
2613 }
2614}
2615
2616pub fn convert_onnx<P: AsRef<Path>>(
2618 onnx_path: P,
2619 mut options: ConvertOptions,
2620) -> Result<GraphJson, OnnxError> {
2621 let onnx_path_ref = onnx_path.as_ref();
2623 let onnx_bytes = fs::read(onnx_path_ref)?;
2624
2625 let mut model: ModelProto =
2627 ModelProto::decode(&onnx_bytes[..]).map_err(|e| OnnxError::ProtobufError(e.to_string()))?;
2628
2629 if options.optimize {
2631 crate::debug_println!("Running constant folding...");
2632 let evaluators = crate::onnx::constant_folding::evaluators::get_evaluators();
2633 let nodes_folded =
2634 crate::onnx::constant_folding::fold_constants_in_model(&mut model, &evaluators)?;
2635 crate::debug_println!("Constant folding: {} nodes folded", nodes_folded);
2636 }
2637
2638 if options.free_dim_overrides.is_empty() {
2640 let mut sidecar = onnx_path_ref.to_path_buf();
2641 sidecar.set_extension("dims.json");
2642 if sidecar.exists() {
2643 let content = fs::read_to_string(&sidecar)?;
2644 if let Ok(json) = serde_json::from_str::<JsonValue>(&content) {
2645 if let Some(obj) = json
2646 .get("freeDimensionOverrides")
2647 .unwrap_or(&json)
2648 .as_object()
2649 {
2650 for (name, value) in obj {
2651 if let Some(v) = value.as_u64() {
2652 options
2653 .free_dim_overrides
2654 .entry(name.clone())
2655 .or_insert(v as u32);
2656 }
2657 }
2658 }
2659 }
2660 }
2661 }
2662
2663 let converter = OnnxConverter::new(model.clone())?;
2665
2666 converter.extract_metadata()?;
2668
2669 let mut graph = converter.convert(&options)?;
2671
2672 if options.extract_weights {
2674 if let (Some(weights_path), Some(manifest_path)) =
2675 (&options.weights_path, &options.manifest_path)
2676 {
2677 extract_weights_from_onnx(&model, &mut graph, weights_path, manifest_path)?;
2678 }
2679 }
2680
2681 Ok(graph)
2682}
2683
2684fn extract_weights_from_onnx(
2687 model: &ModelProto,
2688 graph: &mut GraphJson,
2689 weights_path: &str,
2690 manifest_path: &str,
2691) -> Result<(), OnnxError> {
2692 use crate::weights::{TensorEntry, WeightsManifest};
2693
2694 if model.graph.is_none() {
2695 return Err(OnnxError::ProtobufError(
2696 "Missing graph in model".to_string(),
2697 ));
2698 }
2699
2700 let onnx_graph = model.graph.as_ref().unwrap();
2701 let mut manifest = WeightsManifest {
2702 format: "wg-weights-manifest".to_string(),
2703 version: 1,
2704 endianness: "little".to_string(),
2705 tensors: BTreeMap::new(),
2706 };
2707
2708 let mut weights_data = Vec::new();
2709 let mut current_offset = 0u64;
2710
2711 for initializer in onnx_graph.initializer.as_slice() {
2713 let name = sanitize_identifier(initializer.name.as_str());
2714
2715 let onnx_type = initializer.data_type;
2717 let data_type = map_onnx_data_type(onnx_type)?;
2718
2719 let shape: Vec<u32> = initializer
2720 .dims
2721 .as_slice()
2722 .iter()
2723 .map(|d| *d as u32)
2724 .collect();
2725 let raw_data = initializer.raw_data.as_slice();
2726
2727 let bytes_to_write: Vec<u8> = if raw_data.is_empty() {
2729 let int64_data = initializer.int64_data.as_slice();
2731 let float_data = initializer.float_data.as_slice();
2732 let int32_data = initializer.int32_data.as_slice();
2733 let double_data = initializer.double_data.as_slice();
2734
2735 if !int64_data.is_empty() {
2736 int64_data.iter().flat_map(|&v| v.to_le_bytes()).collect()
2738 } else if !float_data.is_empty() {
2739 float_data.iter().flat_map(|&v| v.to_le_bytes()).collect()
2741 } else if !int32_data.is_empty() {
2742 int32_data.iter().flat_map(|&v| v.to_le_bytes()).collect()
2744 } else if !double_data.is_empty() {
2745 double_data.iter().flat_map(|&v| v.to_le_bytes()).collect()
2747 } else {
2748 crate::debug_println!("Warning: Skipping initializer '{}' with no data", name);
2750 continue;
2751 }
2752 } else {
2753 raw_data.to_vec()
2754 };
2755
2756 let byte_length = bytes_to_write.len() as u64;
2757
2758 manifest.tensors.insert(
2760 name,
2761 TensorEntry {
2762 data_type,
2763 shape,
2764 byte_offset: current_offset,
2765 byte_length,
2766 layout: None,
2767 },
2768 );
2769
2770 weights_data.extend_from_slice(&bytes_to_write);
2772 current_offset += byte_length;
2773 }
2774
2775 const INLINE_THRESHOLD: usize = 1024;
2778 for (name, decl) in graph.consts.iter_mut() {
2779 if let crate::ast::ConstInit::InlineBytes { bytes } = &decl.init {
2780 if bytes.len() > INLINE_THRESHOLD && !manifest.tensors.contains_key(name) {
2781 let byte_length = bytes.len() as u64;
2782 manifest.tensors.insert(
2783 name.clone(),
2784 TensorEntry {
2785 data_type: decl.data_type.clone(),
2786 shape: decl.shape.clone(),
2787 byte_offset: current_offset,
2788 byte_length,
2789 layout: None,
2790 },
2791 );
2792 weights_data.extend_from_slice(bytes);
2793 current_offset += byte_length;
2794 }
2795 }
2796 }
2797 for (name, decl) in graph.consts.iter_mut() {
2799 if let crate::ast::ConstInit::InlineBytes { bytes } = &decl.init {
2800 if bytes.len() > INLINE_THRESHOLD {
2801 decl.init = crate::ast::ConstInit::Weights {
2802 r#ref: name.clone(),
2803 };
2804 }
2805 }
2806 }
2807
2808 fs::write(weights_path, &weights_data)?;
2810
2811 let manifest_json = serde_json::to_string_pretty(&manifest)
2813 .map_err(|e| OnnxError::ProtobufError(e.to_string()))?;
2814 fs::write(manifest_path, manifest_json)?;
2815
2816 Ok(())
2817}
2818
2819#[cfg(test)]
2820mod tests {
2821 use super::*;
2822
2823 #[test]
2824 fn test_convert_options_default() {
2825 let options = ConvertOptions::default();
2826 assert!(options.extract_weights);
2827 assert_eq!(options.output_path, "output.webnn");
2828 }
2829
2830 #[test]
2831 fn test_sanitize_identifier_replaces_colons() {
2832 assert_eq!(sanitize_identifier("foo::bar"), "foo__bar");
2833 assert_eq!(sanitize_identifier("foo:bar"), "foo_bar");
2834 }
2835
2836 #[test]
2837 fn test_sanitize_identifier_replaces_dots() {
2838 assert_eq!(sanitize_identifier("encoder.block.0"), "encoder_block_0");
2839 assert_eq!(
2840 sanitize_identifier("model.layer.weight"),
2841 "model_layer_weight"
2842 );
2843 assert_eq!(sanitize_identifier("a.b.c"), "a_b_c");
2844 }
2845
2846 #[test]
2847 fn test_sanitize_identifier_replaces_combined() {
2848 assert_eq!(
2850 sanitize_identifier("module::class:method.field"),
2851 "module__class_method_field"
2852 );
2853 assert_eq!(
2854 sanitize_identifier("encoder.attention::output:dense"),
2855 "encoder_attention__output_dense"
2856 );
2857 }
2858
2859 #[test]
2860 fn test_sanitize_identifier_no_change() {
2861 assert_eq!(sanitize_identifier("simple_name"), "simple_name");
2863 assert_eq!(sanitize_identifier("CamelCase"), "CamelCase");
2864 assert_eq!(sanitize_identifier("name123"), "name123");
2865 }
2866
2867 #[test]
2868 fn test_inline_bytes_encoding_for_i64_values() {
2869 let values: Vec<i64> = vec![0, 1, 2, 3, 4];
2872 let mut bytes = Vec::with_capacity(values.len() * 8);
2873 for v in values {
2874 bytes.extend_from_slice(&v.to_le_bytes());
2875 }
2876
2877 assert_eq!(bytes.len(), 40); let first_bytes: [u8; 8] = bytes[0..8].try_into().unwrap();
2882 assert_eq!(i64::from_le_bytes(first_bytes), 0);
2883
2884 let last_bytes: [u8; 8] = bytes[32..40].try_into().unwrap();
2886 assert_eq!(i64::from_le_bytes(last_bytes), 4);
2887 }
2888
2889 #[test]
2890 fn test_inline_bytes_encoding_single_value() {
2891 let values: Vec<i64> = vec![42];
2893 let mut bytes = Vec::with_capacity(values.len() * 8);
2894 for v in values {
2895 bytes.extend_from_slice(&v.to_le_bytes());
2896 }
2897
2898 assert_eq!(bytes.len(), 8);
2899 let decoded: [u8; 8] = bytes.try_into().unwrap();
2900 assert_eq!(i64::from_le_bytes(decoded), 42);
2901 }
2902
2903 #[test]
2904 fn test_inline_bytes_encoding_negative_values() {
2905 let values: Vec<i64> = vec![5, 4, 3, 2, 1, 0, -1, -2];
2907 let mut bytes = Vec::with_capacity(values.len() * 8);
2908 for v in values {
2909 bytes.extend_from_slice(&v.to_le_bytes());
2910 }
2911
2912 assert_eq!(bytes.len(), 64); let neg_bytes: [u8; 8] = bytes[56..64].try_into().unwrap();
2916 assert_eq!(i64::from_le_bytes(neg_bytes), -2);
2917 }
2918
2919 #[test]
2920 fn test_inline_bytes_encoding_large_values() {
2921 let values: Vec<i64> = vec![i64::MAX, i64::MIN, 0];
2923 let mut bytes = Vec::with_capacity(values.len() * 8);
2924 for v in values {
2925 bytes.extend_from_slice(&v.to_le_bytes());
2926 }
2927
2928 assert_eq!(bytes.len(), 24);
2929
2930 let max_bytes: [u8; 8] = bytes[0..8].try_into().unwrap();
2932 assert_eq!(i64::from_le_bytes(max_bytes), i64::MAX);
2933
2934 let min_bytes: [u8; 8] = bytes[8..16].try_into().unwrap();
2936 assert_eq!(i64::from_le_bytes(min_bytes), i64::MIN);
2937 }
2938
2939 #[test]
2940 fn test_convert_preserves_dynamic_input_dim_without_override() {
2941 use crate::protos::onnx::{tensor_shape_proto, type_proto};
2942 use crate::protos::onnx::{GraphProto, ModelProto, TensorShapeProto, ValueInfoProto};
2943
2944 let dim_batch = tensor_shape_proto::Dimension {
2945 value: Some(tensor_shape_proto::dimension::Value::DimParam(
2946 "batch_size".to_string(),
2947 )),
2948 denotation: String::new(),
2949 };
2950 let dim_seq = tensor_shape_proto::Dimension {
2951 value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
2952 denotation: String::new(),
2953 };
2954 let shape = TensorShapeProto {
2955 dim: vec![dim_batch, dim_seq],
2956 };
2957
2958 let tensor_type = type_proto::Tensor {
2959 elem_type: TensorProto_DataType::Int64.into(),
2960 shape: Some(shape),
2961 };
2962 let type_proto = crate::protos::onnx::TypeProto {
2963 value: Some(type_proto::Value::TensorType(tensor_type)),
2964 denotation: String::new(),
2965 };
2966
2967 let input_vi = ValueInfoProto {
2968 name: "input_ids".to_string(),
2969 r#type: Some(type_proto.clone()),
2970 ..Default::default()
2971 };
2972 let output_vi = ValueInfoProto {
2973 name: "input_ids".to_string(),
2974 r#type: Some(type_proto),
2975 ..Default::default()
2976 };
2977
2978 let model = ModelProto {
2979 graph: Some(GraphProto {
2980 input: vec![input_vi],
2981 output: vec![output_vi],
2982 ..Default::default()
2983 }),
2984 ..Default::default()
2985 };
2986
2987 let converter = OnnxConverter::new(model).expect("converter");
2988 let graph = converter
2989 .convert(&ConvertOptions {
2990 experimental_dynamic_inputs: true,
2991 ..ConvertOptions::default()
2992 })
2993 .expect("convert");
2994
2995 let input = graph.inputs.get("input_ids").expect("input_ids input");
2996 assert_eq!(input.shape.len(), 2);
2997 assert!(matches!(
2998 &input.shape[0],
2999 Dimension::Dynamic(d) if d.name == "batch_size"
3000 ));
3001 assert!(matches!(&input.shape[1], Dimension::Static(1)));
3002 assert_eq!(graph.version, 2);
3003 }
3004
3005 #[test]
3006 fn test_convert_rejects_dynamic_input_dim_without_flag() {
3007 use crate::protos::onnx::{tensor_shape_proto, type_proto};
3008 use crate::protos::onnx::{GraphProto, ModelProto, TensorShapeProto, ValueInfoProto};
3009
3010 let dim_batch = tensor_shape_proto::Dimension {
3011 value: Some(tensor_shape_proto::dimension::Value::DimParam(
3012 "unknown_dim".to_string(),
3013 )),
3014 denotation: String::new(),
3015 };
3016 let dim_seq = tensor_shape_proto::Dimension {
3017 value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
3018 denotation: String::new(),
3019 };
3020 let shape = TensorShapeProto {
3021 dim: vec![dim_batch, dim_seq],
3022 };
3023
3024 let tensor_type = type_proto::Tensor {
3025 elem_type: TensorProto_DataType::Int64.into(),
3026 shape: Some(shape),
3027 };
3028 let type_proto = crate::protos::onnx::TypeProto {
3029 value: Some(type_proto::Value::TensorType(tensor_type)),
3030 denotation: String::new(),
3031 };
3032
3033 let input_vi = ValueInfoProto {
3034 name: "input_ids".to_string(),
3035 r#type: Some(type_proto.clone()),
3036 ..Default::default()
3037 };
3038 let output_vi = ValueInfoProto {
3039 name: "input_ids".to_string(),
3040 r#type: Some(type_proto),
3041 ..Default::default()
3042 };
3043
3044 let model = ModelProto {
3045 graph: Some(GraphProto {
3046 input: vec![input_vi],
3047 output: vec![output_vi],
3048 ..Default::default()
3049 }),
3050 ..Default::default()
3051 };
3052
3053 let converter = OnnxConverter::new(model).expect("converter");
3054 let err = converter
3055 .convert(&ConvertOptions::default())
3056 .expect_err("should require overrides or flag");
3057 let msg = err.to_string();
3058 assert!(msg.contains("override-dim"));
3059 assert!(msg.contains("experimental-dynamic-inputs"));
3060 }
3061
3062 #[test]
3063 fn test_convert_dynamic_shape_concat_reshape_path_with_experimental_flag() {
3064 use crate::protos::onnx::{tensor_shape_proto, type_proto};
3065 use crate::protos::onnx::{
3066 AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
3067 ValueInfoProto,
3068 };
3069
3070 let batch_dim = tensor_shape_proto::Dimension {
3071 value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
3072 denotation: String::new(),
3073 };
3074 let seq_dim = tensor_shape_proto::Dimension {
3075 value: Some(tensor_shape_proto::dimension::Value::DimParam(
3076 "sequence_length".to_string(),
3077 )),
3078 denotation: String::new(),
3079 };
3080 let hidden_dim = tensor_shape_proto::Dimension {
3081 value: Some(tensor_shape_proto::dimension::Value::DimValue(4)),
3082 denotation: String::new(),
3083 };
3084 let data_shape = TensorShapeProto {
3085 dim: vec![batch_dim, seq_dim, hidden_dim],
3086 };
3087
3088 let data_tensor_type = type_proto::Tensor {
3089 elem_type: TensorProto_DataType::Float.into(),
3090 shape: Some(data_shape),
3091 };
3092 let data_type_proto = crate::protos::onnx::TypeProto {
3093 value: Some(type_proto::Value::TensorType(data_tensor_type)),
3094 denotation: String::new(),
3095 };
3096
3097 let data_input = ValueInfoProto {
3098 name: "data".to_string(),
3099 r#type: Some(data_type_proto.clone()),
3100 ..Default::default()
3101 };
3102 let data_output = ValueInfoProto {
3103 name: "out".to_string(),
3104 r#type: Some(data_type_proto),
3105 ..Default::default()
3106 };
3107
3108 let idx0 = TensorProto {
3109 name: "idx0".to_string(),
3110 data_type: TensorProto_DataType::Int64 as i32,
3111 dims: vec![1],
3112 int64_data: vec![0],
3113 ..Default::default()
3114 };
3115 let idx1 = TensorProto {
3116 name: "idx1".to_string(),
3117 data_type: TensorProto_DataType::Int64 as i32,
3118 dims: vec![1],
3119 int64_data: vec![1],
3120 ..Default::default()
3121 };
3122 let last_dim = TensorProto {
3123 name: "last_dim".to_string(),
3124 data_type: TensorProto_DataType::Int64 as i32,
3125 dims: vec![1],
3126 int64_data: vec![4],
3127 ..Default::default()
3128 };
3129
3130 let shape_node = NodeProto {
3131 op_type: "Shape".to_string(),
3132 input: vec!["data".to_string()],
3133 output: vec!["shape_out".to_string()],
3134 ..Default::default()
3135 };
3136 let gather0 = NodeProto {
3137 op_type: "Gather".to_string(),
3138 input: vec!["shape_out".to_string(), "idx0".to_string()],
3139 output: vec!["dim0".to_string()],
3140 attribute: vec![AttributeProto {
3141 name: "axis".to_string(),
3142 i: 0,
3143 ..Default::default()
3144 }],
3145 ..Default::default()
3146 };
3147 let gather1 = NodeProto {
3148 op_type: "Gather".to_string(),
3149 input: vec!["shape_out".to_string(), "idx1".to_string()],
3150 output: vec!["dim1".to_string()],
3151 attribute: vec![AttributeProto {
3152 name: "axis".to_string(),
3153 i: 0,
3154 ..Default::default()
3155 }],
3156 ..Default::default()
3157 };
3158 let concat_shape = NodeProto {
3159 op_type: "Concat".to_string(),
3160 input: vec![
3161 "dim0".to_string(),
3162 "dim1".to_string(),
3163 "last_dim".to_string(),
3164 ],
3165 output: vec!["shape_for_reshape".to_string()],
3166 attribute: vec![AttributeProto {
3167 name: "axis".to_string(),
3168 i: 0,
3169 ..Default::default()
3170 }],
3171 ..Default::default()
3172 };
3173 let reshape = NodeProto {
3174 op_type: "Reshape".to_string(),
3175 input: vec!["data".to_string(), "shape_for_reshape".to_string()],
3176 output: vec!["out".to_string()],
3177 ..Default::default()
3178 };
3179
3180 let model = ModelProto {
3181 graph: Some(GraphProto {
3182 input: vec![data_input],
3183 output: vec![data_output],
3184 initializer: vec![idx0, idx1, last_dim],
3185 node: vec![shape_node, gather0, gather1, concat_shape, reshape],
3186 ..Default::default()
3187 }),
3188 ..Default::default()
3189 };
3190
3191 let converter = OnnxConverter::new(model).expect("converter");
3192 let graph = converter
3193 .convert(&ConvertOptions {
3194 optimize: true,
3195 experimental_dynamic_inputs: true,
3196 extract_weights: false,
3197 ..ConvertOptions::default()
3198 })
3199 .expect("dynamic reshape path should convert");
3200
3201 let reshape_node = graph
3202 .nodes
3203 .iter()
3204 .find(|n| n.op == "reshape")
3205 .expect("reshape node should exist");
3206 let shape = reshape_node
3207 .options
3208 .get("newShape")
3209 .and_then(|v| v.as_array())
3210 .expect("newShape should be an array");
3211 assert_eq!(shape.len(), 3);
3212 assert_eq!(shape[0].as_u64(), Some(1));
3213 assert_eq!(shape[2].as_u64(), Some(4));
3214 let dim1_ok = shape[1].as_u64().is_some_and(|v| v > 0)
3218 || shape[1].as_object().is_some_and(|o| {
3219 o.contains_key("name")
3220 && o.get("maxSize")
3221 .and_then(|v| v.as_u64())
3222 .is_some_and(|v| v > 0)
3223 });
3224 assert!(
3225 dim1_ok,
3226 "sequence dimension should be concretized or dynamic for lowering, got: {:?}",
3227 shape[1]
3228 );
3229 }
3230
3231 #[test]
3232 fn test_convert_reshape_shape_path_survives_add_broadcast() {
3233 use crate::protos::onnx::{tensor_shape_proto, type_proto};
3234 use crate::protos::onnx::{
3235 AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
3236 ValueInfoProto,
3237 };
3238
3239 let batch_dim = tensor_shape_proto::Dimension {
3240 value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
3241 denotation: String::new(),
3242 };
3243 let seq_dim = tensor_shape_proto::Dimension {
3244 value: Some(tensor_shape_proto::dimension::Value::DimValue(128)),
3245 denotation: String::new(),
3246 };
3247 let hidden_dim = tensor_shape_proto::Dimension {
3248 value: Some(tensor_shape_proto::dimension::Value::DimValue(4)),
3249 denotation: String::new(),
3250 };
3251 let data_shape = TensorShapeProto {
3252 dim: vec![batch_dim, seq_dim, hidden_dim],
3253 };
3254
3255 let data_tensor_type = type_proto::Tensor {
3256 elem_type: TensorProto_DataType::Float.into(),
3257 shape: Some(data_shape),
3258 };
3259 let data_type_proto = crate::protos::onnx::TypeProto {
3260 value: Some(type_proto::Value::TensorType(data_tensor_type)),
3261 denotation: String::new(),
3262 };
3263
3264 let data_input = ValueInfoProto {
3265 name: "data".to_string(),
3266 r#type: Some(data_type_proto.clone()),
3267 ..Default::default()
3268 };
3269 let data_output = ValueInfoProto {
3270 name: "out".to_string(),
3271 r#type: Some(data_type_proto),
3272 ..Default::default()
3273 };
3274
3275 let bias = TensorProto {
3276 name: "bias".to_string(),
3277 data_type: TensorProto_DataType::Float as i32,
3278 dims: vec![4],
3279 float_data: vec![0.0, 0.0, 0.0, 0.0],
3280 ..Default::default()
3281 };
3282 let idx0 = TensorProto {
3283 name: "idx0".to_string(),
3284 data_type: TensorProto_DataType::Int64 as i32,
3285 dims: vec![1],
3286 int64_data: vec![0],
3287 ..Default::default()
3288 };
3289 let idx1 = TensorProto {
3290 name: "idx1".to_string(),
3291 data_type: TensorProto_DataType::Int64 as i32,
3292 dims: vec![1],
3293 int64_data: vec![1],
3294 ..Default::default()
3295 };
3296 let last_dim = TensorProto {
3297 name: "last_dim".to_string(),
3298 data_type: TensorProto_DataType::Int64 as i32,
3299 dims: vec![1],
3300 int64_data: vec![4],
3301 ..Default::default()
3302 };
3303
3304 let add_node = NodeProto {
3305 op_type: "Add".to_string(),
3306 input: vec!["data".to_string(), "bias".to_string()],
3307 output: vec!["add_out".to_string()],
3308 ..Default::default()
3309 };
3310 let shape_node = NodeProto {
3311 op_type: "Shape".to_string(),
3312 input: vec!["add_out".to_string()],
3313 output: vec!["shape_out".to_string()],
3314 ..Default::default()
3315 };
3316 let gather0 = NodeProto {
3317 op_type: "Gather".to_string(),
3318 input: vec!["shape_out".to_string(), "idx0".to_string()],
3319 output: vec!["dim0".to_string()],
3320 attribute: vec![AttributeProto {
3321 name: "axis".to_string(),
3322 i: 0,
3323 ..Default::default()
3324 }],
3325 ..Default::default()
3326 };
3327 let gather1 = NodeProto {
3328 op_type: "Gather".to_string(),
3329 input: vec!["shape_out".to_string(), "idx1".to_string()],
3330 output: vec!["dim1".to_string()],
3331 attribute: vec![AttributeProto {
3332 name: "axis".to_string(),
3333 i: 0,
3334 ..Default::default()
3335 }],
3336 ..Default::default()
3337 };
3338 let concat_shape = NodeProto {
3339 op_type: "Concat".to_string(),
3340 input: vec![
3341 "dim0".to_string(),
3342 "dim1".to_string(),
3343 "last_dim".to_string(),
3344 ],
3345 output: vec!["shape_for_reshape".to_string()],
3346 attribute: vec![AttributeProto {
3347 name: "axis".to_string(),
3348 i: 0,
3349 ..Default::default()
3350 }],
3351 ..Default::default()
3352 };
3353 let reshape = NodeProto {
3354 op_type: "Reshape".to_string(),
3355 input: vec!["add_out".to_string(), "shape_for_reshape".to_string()],
3356 output: vec!["out".to_string()],
3357 ..Default::default()
3358 };
3359
3360 let model = ModelProto {
3361 graph: Some(GraphProto {
3362 input: vec![data_input],
3363 output: vec![data_output],
3364 initializer: vec![bias, idx0, idx1, last_dim],
3365 node: vec![
3366 add_node,
3367 shape_node,
3368 gather0,
3369 gather1,
3370 concat_shape,
3371 reshape,
3372 ],
3373 ..Default::default()
3374 }),
3375 ..Default::default()
3376 };
3377
3378 let converter = OnnxConverter::new(model).expect("converter");
3379 let graph = converter
3380 .convert(&ConvertOptions {
3381 optimize: true,
3382 extract_weights: false,
3383 ..ConvertOptions::default()
3384 })
3385 .expect("broadcasted shape path should convert");
3386
3387 let reshape_node = graph
3388 .nodes
3389 .iter()
3390 .find(|n| n.op == "reshape")
3391 .expect("reshape node should exist");
3392 assert_eq!(
3393 reshape_node.options.get("newShape"),
3394 Some(&serde_json::json!([1, 128, 4]))
3395 );
3396 }
3397
3398 #[test]
3399 fn test_convert_dynamic_range_lowers_to_slice_and_preserves_dynamic_reshape() {
3400 use crate::protos::onnx::{tensor_shape_proto, type_proto};
3401 use crate::protos::onnx::{
3402 AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
3403 ValueInfoProto,
3404 };
3405
3406 let seq_dim = tensor_shape_proto::Dimension {
3407 value: Some(tensor_shape_proto::dimension::Value::DimParam(
3408 "sequence_length".to_string(),
3409 )),
3410 denotation: String::new(),
3411 };
3412 let data_shape = TensorShapeProto { dim: vec![seq_dim] };
3413
3414 let data_tensor_type = type_proto::Tensor {
3415 elem_type: TensorProto_DataType::Float.into(),
3416 shape: Some(data_shape),
3417 };
3418 let data_type_proto = crate::protos::onnx::TypeProto {
3419 value: Some(type_proto::Value::TensorType(data_tensor_type)),
3420 denotation: String::new(),
3421 };
3422
3423 let data_input = ValueInfoProto {
3424 name: "data".to_string(),
3425 r#type: Some(data_type_proto),
3426 ..Default::default()
3427 };
3428 let output_vi = ValueInfoProto {
3429 name: "out".to_string(),
3430 ..Default::default()
3431 };
3432
3433 let idx0 = TensorProto {
3434 name: "idx0".to_string(),
3435 data_type: TensorProto_DataType::Int64 as i32,
3436 dims: vec![1],
3437 int64_data: vec![0],
3438 ..Default::default()
3439 };
3440 let zero = TensorProto {
3441 name: "zero".to_string(),
3442 data_type: TensorProto_DataType::Int64 as i32,
3443 dims: vec![],
3444 int64_data: vec![0],
3445 ..Default::default()
3446 };
3447 let one = TensorProto {
3448 name: "one".to_string(),
3449 data_type: TensorProto_DataType::Int64 as i32,
3450 dims: vec![],
3451 int64_data: vec![1],
3452 ..Default::default()
3453 };
3454
3455 let shape_node = NodeProto {
3456 op_type: "Shape".to_string(),
3457 input: vec!["data".to_string()],
3458 output: vec!["shape_out".to_string()],
3459 ..Default::default()
3460 };
3461 let gather = NodeProto {
3462 op_type: "Gather".to_string(),
3463 input: vec!["shape_out".to_string(), "idx0".to_string()],
3464 output: vec!["seq_len".to_string()],
3465 attribute: vec![AttributeProto {
3466 name: "axis".to_string(),
3467 i: 0,
3468 ..Default::default()
3469 }],
3470 ..Default::default()
3471 };
3472 let add_limit = NodeProto {
3473 op_type: "Add".to_string(),
3474 input: vec!["seq_len".to_string(), "one".to_string()],
3475 output: vec!["range_limit".to_string()],
3476 ..Default::default()
3477 };
3478 let range = NodeProto {
3479 op_type: "Range".to_string(),
3480 input: vec![
3481 "zero".to_string(),
3482 "range_limit".to_string(),
3483 "one".to_string(),
3484 ],
3485 output: vec!["range_out".to_string()],
3486 ..Default::default()
3487 };
3488 let concat_shape = NodeProto {
3489 op_type: "Concat".to_string(),
3490 input: vec!["range_limit".to_string(), "one".to_string()],
3491 output: vec!["shape_for_reshape".to_string()],
3492 attribute: vec![AttributeProto {
3493 name: "axis".to_string(),
3494 i: 0,
3495 ..Default::default()
3496 }],
3497 ..Default::default()
3498 };
3499 let reshape = NodeProto {
3500 op_type: "Reshape".to_string(),
3501 input: vec!["range_out".to_string(), "shape_for_reshape".to_string()],
3502 output: vec!["out".to_string()],
3503 ..Default::default()
3504 };
3505
3506 let model = ModelProto {
3507 graph: Some(GraphProto {
3508 input: vec![data_input],
3509 output: vec![output_vi],
3510 initializer: vec![idx0, zero, one],
3511 node: vec![shape_node, gather, add_limit, range, concat_shape, reshape],
3512 ..Default::default()
3513 }),
3514 ..Default::default()
3515 };
3516
3517 let converter = OnnxConverter::new(model).expect("converter");
3518 let graph = converter
3519 .convert(&ConvertOptions {
3520 optimize: true,
3521 experimental_dynamic_inputs: true,
3522 extract_weights: false,
3523 ..ConvertOptions::default()
3524 })
3525 .expect("dynamic range path should convert");
3526
3527 let slice_node = graph
3528 .nodes
3529 .iter()
3530 .find(|n| n.op == "slice")
3531 .expect("range should lower to slice");
3532 let slice_sizes = slice_node
3533 .options
3534 .get("sizes")
3535 .and_then(|v| v.as_array())
3536 .expect("slice sizes should exist");
3537 assert_eq!(slice_sizes.len(), 1);
3538 let dynamic_size = slice_sizes[0]
3539 .as_object()
3540 .expect("dynamic range size should be a dimension object");
3541 assert_eq!(
3542 dynamic_size.get("name").and_then(|v| v.as_str()),
3543 Some("sequence_length + 1")
3544 );
3545 assert_eq!(
3546 dynamic_size.get("maxSize").and_then(|v| v.as_u64()),
3547 Some(4097)
3548 );
3549
3550 let reshape_node = graph
3551 .nodes
3552 .iter()
3553 .find(|n| n.op == "reshape")
3554 .expect("reshape node should exist");
3555 let new_shape = reshape_node
3556 .options
3557 .get("newShape")
3558 .and_then(|v| v.as_array())
3559 .expect("reshape newShape should exist");
3560 assert_eq!(new_shape.len(), 2);
3561 assert_eq!(new_shape[1].as_u64(), Some(1));
3562 let reshape_dim0 = new_shape[0]
3563 .as_object()
3564 .expect("reshape dim 0 should stay dynamic");
3565 assert_eq!(
3566 reshape_dim0.get("name").and_then(|v| v.as_str()),
3567 Some("sequence_length + 1")
3568 );
3569 assert_eq!(
3570 reshape_dim0.get("maxSize").and_then(|v| v.as_u64()),
3571 Some(4097)
3572 );
3573 }
3574
3575 #[test]
3576 fn test_convert_dynamic_range_with_dynamic_start_lowers_to_slice_and_add() {
3577 use crate::protos::onnx::{tensor_shape_proto, type_proto};
3578 use crate::protos::onnx::{
3579 AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, TensorShapeProto,
3580 ValueInfoProto,
3581 };
3582
3583 let batch_dim = tensor_shape_proto::Dimension {
3584 value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
3585 denotation: String::new(),
3586 };
3587 let seq_dim = tensor_shape_proto::Dimension {
3588 value: Some(tensor_shape_proto::dimension::Value::DimParam(
3589 "sequence_length".to_string(),
3590 )),
3591 denotation: String::new(),
3592 };
3593 let past_dim = tensor_shape_proto::Dimension {
3594 value: Some(tensor_shape_proto::dimension::Value::DimParam(
3595 "past_sequence_length".to_string(),
3596 )),
3597 denotation: String::new(),
3598 };
3599 let heads_dim = tensor_shape_proto::Dimension {
3600 value: Some(tensor_shape_proto::dimension::Value::DimValue(3)),
3601 denotation: String::new(),
3602 };
3603 let head_dim = tensor_shape_proto::Dimension {
3604 value: Some(tensor_shape_proto::dimension::Value::DimValue(4)),
3605 denotation: String::new(),
3606 };
3607
3608 let ids_shape = TensorShapeProto {
3609 dim: vec![batch_dim.clone(), seq_dim.clone()],
3610 };
3611 let past_shape = TensorShapeProto {
3612 dim: vec![batch_dim, heads_dim, past_dim, head_dim],
3613 };
3614 let range_shape = TensorShapeProto {
3615 dim: vec![seq_dim.clone()],
3616 };
3617 let out_shape = TensorShapeProto {
3618 dim: vec![
3619 seq_dim,
3620 tensor_shape_proto::Dimension {
3621 value: Some(tensor_shape_proto::dimension::Value::DimValue(1)),
3622 denotation: String::new(),
3623 },
3624 ],
3625 };
3626
3627 let ids_tensor_type = type_proto::Tensor {
3628 elem_type: TensorProto_DataType::Int64.into(),
3629 shape: Some(ids_shape),
3630 };
3631 let past_tensor_type = type_proto::Tensor {
3632 elem_type: TensorProto_DataType::Float.into(),
3633 shape: Some(past_shape),
3634 };
3635 let range_tensor_type = type_proto::Tensor {
3636 elem_type: TensorProto_DataType::Int64.into(),
3637 shape: Some(range_shape),
3638 };
3639 let out_tensor_type = type_proto::Tensor {
3640 elem_type: TensorProto_DataType::Int64.into(),
3641 shape: Some(out_shape),
3642 };
3643
3644 let ids_input = ValueInfoProto {
3645 name: "ids".to_string(),
3646 r#type: Some(crate::protos::onnx::TypeProto {
3647 value: Some(type_proto::Value::TensorType(ids_tensor_type)),
3648 denotation: String::new(),
3649 }),
3650 ..Default::default()
3651 };
3652 let past_input = ValueInfoProto {
3653 name: "past".to_string(),
3654 r#type: Some(crate::protos::onnx::TypeProto {
3655 value: Some(type_proto::Value::TensorType(past_tensor_type)),
3656 denotation: String::new(),
3657 }),
3658 ..Default::default()
3659 };
3660 let range_vi = ValueInfoProto {
3661 name: "range_out".to_string(),
3662 r#type: Some(crate::protos::onnx::TypeProto {
3663 value: Some(type_proto::Value::TensorType(range_tensor_type)),
3664 denotation: String::new(),
3665 }),
3666 ..Default::default()
3667 };
3668 let out_vi = ValueInfoProto {
3669 name: "out".to_string(),
3670 r#type: Some(crate::protos::onnx::TypeProto {
3671 value: Some(type_proto::Value::TensorType(out_tensor_type)),
3672 denotation: String::new(),
3673 }),
3674 ..Default::default()
3675 };
3676
3677 let idx1 = TensorProto {
3678 name: "idx1".to_string(),
3679 data_type: TensorProto_DataType::Int64 as i32,
3680 dims: vec![1],
3681 int64_data: vec![1],
3682 ..Default::default()
3683 };
3684 let idx2 = TensorProto {
3685 name: "idx2".to_string(),
3686 data_type: TensorProto_DataType::Int64 as i32,
3687 dims: vec![1],
3688 int64_data: vec![2],
3689 ..Default::default()
3690 };
3691 let one = TensorProto {
3692 name: "one".to_string(),
3693 data_type: TensorProto_DataType::Int64 as i32,
3694 dims: vec![],
3695 int64_data: vec![1],
3696 ..Default::default()
3697 };
3698 let reshape_shape = TensorProto {
3699 name: "reshape_shape".to_string(),
3700 data_type: TensorProto_DataType::Int64 as i32,
3701 dims: vec![2],
3702 int64_data: vec![4096, 1],
3703 ..Default::default()
3704 };
3705
3706 let shape_past = NodeProto {
3707 op_type: "Shape".to_string(),
3708 input: vec!["past".to_string()],
3709 output: vec!["past_shape".to_string()],
3710 ..Default::default()
3711 };
3712 let gather_start = NodeProto {
3713 op_type: "Gather".to_string(),
3714 input: vec!["past_shape".to_string(), "idx2".to_string()],
3715 output: vec!["range_start".to_string()],
3716 attribute: vec![AttributeProto {
3717 name: "axis".to_string(),
3718 i: 0,
3719 ..Default::default()
3720 }],
3721 ..Default::default()
3722 };
3723 let shape_ids = NodeProto {
3724 op_type: "Shape".to_string(),
3725 input: vec!["ids".to_string()],
3726 output: vec!["ids_shape".to_string()],
3727 ..Default::default()
3728 };
3729 let gather_seq = NodeProto {
3730 op_type: "Gather".to_string(),
3731 input: vec!["ids_shape".to_string(), "idx1".to_string()],
3732 output: vec!["seq_len".to_string()],
3733 attribute: vec![AttributeProto {
3734 name: "axis".to_string(),
3735 i: 0,
3736 ..Default::default()
3737 }],
3738 ..Default::default()
3739 };
3740 let add_limit = NodeProto {
3741 op_type: "Add".to_string(),
3742 input: vec!["range_start".to_string(), "seq_len".to_string()],
3743 output: vec!["range_limit".to_string()],
3744 ..Default::default()
3745 };
3746 let range = NodeProto {
3747 op_type: "Range".to_string(),
3748 input: vec![
3749 "range_start".to_string(),
3750 "range_limit".to_string(),
3751 "one".to_string(),
3752 ],
3753 output: vec!["range_out".to_string()],
3754 ..Default::default()
3755 };
3756 let reshape = NodeProto {
3757 op_type: "Reshape".to_string(),
3758 input: vec!["range_out".to_string(), "reshape_shape".to_string()],
3759 output: vec!["out".to_string()],
3760 ..Default::default()
3761 };
3762
3763 let model = ModelProto {
3764 graph: Some(GraphProto {
3765 input: vec![ids_input, past_input],
3766 output: vec![out_vi.clone()],
3767 value_info: vec![range_vi, out_vi],
3768 initializer: vec![idx1, idx2, one, reshape_shape],
3769 node: vec![
3770 shape_past,
3771 gather_start,
3772 shape_ids,
3773 gather_seq,
3774 add_limit,
3775 range,
3776 reshape,
3777 ],
3778 ..Default::default()
3779 }),
3780 ..Default::default()
3781 };
3782
3783 let converter = OnnxConverter::new(model).expect("converter");
3784 let graph = converter
3785 .convert(&ConvertOptions {
3786 optimize: true,
3787 experimental_dynamic_inputs: true,
3788 extract_weights: false,
3789 ..ConvertOptions::default()
3790 })
3791 .expect("dynamic range with dynamic start should convert");
3792
3793 assert!(
3794 !graph.consts.contains_key("range_out"),
3795 "range output should stay runtime-computed"
3796 );
3797
3798 let slice_node = graph
3799 .nodes
3800 .iter()
3801 .find(|n| n.id == "range_out_slice" && n.op == "slice")
3802 .expect("range should lower to a slice");
3803 let slice_sizes = slice_node
3804 .options
3805 .get("sizes")
3806 .and_then(|v| v.as_array())
3807 .expect("slice sizes should exist");
3808 let dynamic_size = slice_sizes[0]
3809 .as_object()
3810 .expect("slice size should be dynamic");
3811 assert_eq!(
3812 dynamic_size.get("name").and_then(|v| v.as_str()),
3813 Some("sequence_length")
3814 );
3815 assert_eq!(
3816 dynamic_size.get("maxSize").and_then(|v| v.as_u64()),
3817 Some(4096)
3818 );
3819
3820 let add_node = graph
3821 .nodes
3822 .iter()
3823 .find(|n| n.id == "range_out" && n.op == "add")
3824 .expect("dynamic-start range should add the runtime start offset");
3825 assert_eq!(add_node.inputs.len(), 2);
3826 assert_eq!(add_node.inputs[0], "range_out_slice");
3827
3828 let reshape_node = graph
3829 .nodes
3830 .iter()
3831 .find(|n| n.op == "reshape")
3832 .expect("reshape node should exist");
3833 let new_shape = reshape_node
3834 .options
3835 .get("newShape")
3836 .and_then(|v| v.as_array())
3837 .expect("reshape newShape should exist");
3838 assert_eq!(new_shape.len(), 2);
3839 assert_eq!(new_shape[1].as_u64(), Some(1));
3840 let reshape_dim0 = new_shape[0]
3841 .as_object()
3842 .expect("reshape dim 0 should stay dynamic");
3843 assert_eq!(
3844 reshape_dim0.get("name").and_then(|v| v.as_str()),
3845 Some("sequence_length")
3846 );
3847 assert_eq!(
3848 reshape_dim0.get("maxSize").and_then(|v| v.as_u64()),
3849 Some(4096)
3850 );
3851 }
3852
3853 #[test]
3854 fn test_binary_const_folding_preserves_broadcast_shape() {
3855 let a = vec![-1];
3856 let b = [1, 2, 3, 4].repeat(128);
3857 let a_shape = Vec::<i64>::new();
3858 let b_shape = vec![1, 128, 4];
3859 let (out, out_shape) =
3860 fold_binary_const_i64("Mul", &a, &b, &a_shape, &b_shape).expect("broadcast fold");
3861 assert_eq!(out_shape, vec![1, 128, 4]);
3862 assert_eq!(out.len(), 512);
3863 assert_eq!(out[0], -1);
3864 assert_eq!(out[1], -2);
3865 assert_eq!(out[2], -3);
3866 assert_eq!(out[3], -4);
3867 }
3868
3869 #[test]
3870 fn test_convert_equal_broadcast_path_does_not_flatten_const_shape() {
3871 use crate::protos::onnx::{
3872 type_proto, AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto,
3873 };
3874
3875 let a = TensorProto {
3876 name: "shape_vec".to_string(),
3877 data_type: TensorProto_DataType::Int64 as i32,
3878 dims: vec![4],
3879 int64_data: vec![1, 128, 4, 8],
3880 ..Default::default()
3881 };
3882 let shape3 = TensorProto {
3883 name: "shape3".to_string(),
3884 data_type: TensorProto_DataType::Int64 as i32,
3885 dims: vec![3],
3886 int64_data: vec![1, 128, 4],
3887 ..Default::default()
3888 };
3889 let neg1 = TensorProto {
3890 name: "neg1".to_string(),
3891 data_type: TensorProto_DataType::Int64 as i32,
3892 dims: vec![],
3893 int64_data: vec![-1],
3894 ..Default::default()
3895 };
3896 let cos_fill = TensorProto {
3897 data_type: TensorProto_DataType::Int64 as i32,
3898 dims: vec![],
3899 int64_data: vec![1],
3900 ..Default::default()
3901 };
3902
3903 let cos = NodeProto {
3904 op_type: "ConstantOfShape".to_string(),
3905 input: vec!["shape3".to_string()],
3906 output: vec!["cos_out".to_string()],
3907 attribute: vec![AttributeProto {
3908 name: "value".to_string(),
3909 t: Some(cos_fill),
3910 ..Default::default()
3911 }],
3912 ..Default::default()
3913 };
3914 let mul = NodeProto {
3915 op_type: "Mul".to_string(),
3916 input: vec!["cos_out".to_string(), "neg1".to_string()],
3917 output: vec!["mul_out".to_string()],
3918 ..Default::default()
3919 };
3920 let eq = NodeProto {
3921 op_type: "Equal".to_string(),
3922 input: vec!["shape_vec".to_string(), "mul_out".to_string()],
3923 output: vec!["eq_out".to_string()],
3924 ..Default::default()
3925 };
3926
3927 let output_type = crate::protos::onnx::TypeProto {
3928 value: Some(type_proto::Value::TensorType(type_proto::Tensor {
3929 elem_type: TensorProto_DataType::Bool.into(),
3930 shape: None,
3931 })),
3932 denotation: String::new(),
3933 };
3934
3935 let model = ModelProto {
3936 graph: Some(GraphProto {
3937 initializer: vec![a, shape3, neg1],
3938 node: vec![cos, mul, eq],
3939 output: vec![crate::protos::onnx::ValueInfoProto {
3940 name: "eq_out".to_string(),
3941 r#type: Some(output_type),
3942 ..Default::default()
3943 }],
3944 ..Default::default()
3945 }),
3946 ..Default::default()
3947 };
3948
3949 let converter = OnnxConverter::new(model).expect("converter");
3950 let graph = converter
3951 .convert(&ConvertOptions {
3952 optimize: true,
3953 extract_weights: false,
3954 ..ConvertOptions::default()
3955 })
3956 .expect("convert");
3957
3958 let mul_const = graph.consts.get("mul_out").expect("mul_out const");
3959 assert_eq!(mul_const.shape, vec![1, 128, 4]);
3960 assert!(
3961 !graph.consts.contains_key("eq_out")
3962 || graph
3963 .consts
3964 .get("eq_out")
3965 .is_some_and(|decl| decl.shape == vec![1, 128, 4]),
3966 "eq_out constant must not be flattened"
3967 );
3968 }
3969}