use std::path::PathBuf;

use byteorder::{ByteOrder, LittleEndian};
use itertools::{Itertools, zip_eq};
use ndarray::{Axis, azip};
use prost::Message;

use crate::cpu::{cpu_flip, cpu_gather, cpu_slice};
use crate::dtype::{DBool, dispatch_dtensor, DScalar, DTensor, DType, IntoDScalar, map_dtensor_pair, Tensor};
use crate::graph::{
    BinaryOp, broadcast_shape_symmetric, broadcast_tensors_symmetric, ReduceOp, SliceRange, UnaryOp, Value,
};
pub use crate::graph::Graph;
use crate::onnx::external_data::ExternalDataLoader;
use crate::onnx::inputs::{Attributes, Inputs};
use crate::onnx::proto::{ModelProto, TensorProto, TypeProto};
use crate::onnx::proto::tensor_proto::DataLocation;
use crate::onnx::proto::tensor_proto::DataType;
use crate::onnx::proto::tensor_shape_proto::dimension;
use crate::onnx::proto::type_proto::Value as ProtoTypeValue;
use crate::onnx::result::{Node, OnnxError, OnnxResult, UnwrapProto};
use crate::onnx::store::Store;
use crate::onnx::typed_value::{OnnxValue, SignedSize};
use crate::shape;
use crate::shape::{DivResult, Shape, Size};

pub type InputShaper = dyn Fn(&[OnnxDimValue], &str, usize) -> Option<Shape>;

#[derive(Debug, Clone)]
pub enum OnnxDimValue {
    Value(i64),
    Param(String),
}

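/// Build a [`Graph`] from the serialized bytes of an ONNX `ModelProto`.
///
/// Initializers become constants, the remaining model inputs become graph inputs (with shapes
/// decided by `input_shaper`), nodes are visited in order, and the model outputs become the
/// graph outputs.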
pub fn graph_from_onnx_bytes(
    buf: &[u8],
    external: &mut dyn ExternalDataLoader,
    input_shaper: &InputShaper,
) -> OnnxResult<Graph> {
    let model = load_model_proto(buf);
    let model_graph = model.graph.as_ref().unwrap_proto("model.graph")?;

    let mut graph = Graph::new();
    let mut nodes: Store<OnnxValue> = Store::default();

    for tensor in &model_graph.initializer {
        let value = define_tensor_data(&mut graph, &tensor.name, tensor, external)?;
        nodes.define(&tensor.name, OnnxValue::Value(value));
    }

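    // Initializers can also appear in the input list; those were already defined above.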
    let mut real_input_index = 0;
    for input in &model_graph.input {
        if nodes.contains(&input.name) {
            continue;
        }

        let input_proto = input.r#type.as_ref().unwrap_proto("input.type")?;
        let (shape, dtype) = resolve_tensor_type(input_proto, &input.name, real_input_index, input_shaper)?;
        let value = graph.input(shape, dtype);
        nodes.define(&input.name, OnnxValue::Value(value));

        real_input_index += 1;
    }

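    // Clear the list of newly created values, so the debug ids set below only cover the
    // values created by each individual node.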
    let _ = graph.take_new_values();

    for node_proto in &model_graph.node {
        let node = Node {
            name: node_proto.name.as_str(),
            op_type: node_proto.op_type.as_str(),
        };

        let mut attrs = Attributes::from(node, &node_proto.attribute);
        let mut inputs = Inputs::from(node, &node_proto.input, &nodes)?;

        let values: Vec<OnnxValue> = visit_node(&mut graph, external, node, &mut inputs, &mut attrs)?;

        for value in graph.take_new_values() {
            graph.set_debug_id(value, node.name.to_owned());
        }

        for value in &values {
            value.assert_valid();
        }

        let leftover_attributes = attrs.leftover();
        if !leftover_attributes.is_empty() {
            return Err(OnnxError::LeftoverAttributes(node.to_owned(), leftover_attributes));
        }
        let leftover_inputs = inputs.leftover();
        if !leftover_inputs.is_empty() {
            return Err(OnnxError::LeftoverInputs(node.to_owned(), leftover_inputs));
        }

        let output_names = &node_proto.output;
        assert_eq!(
            output_names.len(),
            values.len(),
            "Expected {:?} outputs, got {}",
            output_names,
            values.len()
        );
        for (name, value) in zip_eq(output_names, values) {
            nodes.define(name, value);
        }
    }

    for output in &model_graph.output {
        let value_or_size = &nodes[output.name.as_str()];
        let value = value_or_size
            .unwrap_value()
            .ok_or(OnnxError::ExpectedNonBatchValue(output.name.clone()))?;
        graph.output(value);
    }

    Ok(graph)
}

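/// Visit a single ONNX node and map it onto graph operations.
///
/// Returns the output values of the node; every supported operation except `Split` produces
/// exactly one output.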
fn visit_node(
    graph: &mut Graph,
    external: &mut dyn ExternalDataLoader,
    node: Node<&str>,
    inputs: &mut Inputs,
    attrs: &mut Attributes,
) -> OnnxResult<Vec<OnnxValue>> {
    let result_single = match node.op_type {
        "Conv" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();
            let filter = inputs.required(1)?.unwrap_value().unwrap();
            let bias_raw = inputs.optional(2).map(|v| v.unwrap_value().unwrap());

            let groups = attrs.maybe_take_int("group")?.unwrap_or(1);
            let kernel_shape = attrs.take_ints("kernel_shape")?;
            let conv_rank = kernel_shape.len();
            let strides = attrs
                .maybe_take_ints("strides")?
                .map_or(vec![1; conv_rank], |strides| strides.to_vec());
            let dilations = attrs
                .maybe_take_ints("dilations")?
                .map_or(vec![1; conv_rank], |dilations| dilations.to_vec());

            let auto_pad = attrs.maybe_take_string("auto_pad")?;

            let padding = match auto_pad {
                None | Some("NOTSET") => attrs.take_ints("pads")?.to_vec(),
                Some("SAME_UPPER") => {
                    calculate_auto_padding(graph, conv_rank, input, filter, &strides, &dilations, true)?
                }
                Some("SAME_LOWER") => {
                    calculate_auto_padding(graph, conv_rank, input, filter, &strides, &dilations, false)?
                }
                // "pads" contains a begin and an end value per spatial axis.
                Some("VALID") => vec![0; 2 * conv_rank],
                Some(auto_pad) => return Err(OnnxError::InvalidAutoPadValue(node.to_owned(), auto_pad.to_owned())),
            };

            let filter_shape = graph[filter]
                .shape
                .unwrap_fixed("Convolution kernel shape must be fixed");

            let bias = bias_raw.map(|bias| {
                let bias_size = graph[bias].shape.unwrap_1();
                let bias_view_shape = shape![1, bias_size, 1, 1];

                graph.view(bias, bias_view_shape)
            });

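            // Grouped convolutions are not supported yet.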
            assert_eq!(1, groups);

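            // A 1D convolution is implemented by appending a dummy spatial axis of size 1,
            // running a 2D convolution, and removing the extra axis again.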
            let result = match conv_rank {
                1 => {
                    let kernel_size0 = unwrap_1(kernel_shape);
                    let [padding_0, padding_1] = unwrap_2(&padding);
                    let stride = unwrap_1(&strides);
                    let dilation = unwrap_1(&dilations);

                    let [_, _, kernel_size1] = filter_shape.unwrap_3();

                    assert_eq!(padding_0, padding_1);
                    assert!(dilation == 1 && stride == 1);
                    assert_eq!(kernel_size0, kernel_size1);

                    let input_extra = graph.view(input, graph[input].shape.clone().concat(&shape![1]));
                    let filter_extra = graph.view(filter, graph[filter].shape.clone().concat(&shape![1]));

                    let result_conv = graph.conv(input_extra, filter_extra, 1, 1, padding_0, 0);
                    let result_biased = bias.map_or(result_conv, |bias| graph.add(result_conv, bias));

                    let result_shape = graph[result_biased].shape.replace(3, shape![]);
                    graph.view(result_biased, result_shape)
                }
                2 => {
                    let [kernel_h0, kernel_w0] = unwrap_2(kernel_shape);
                    let [padding_y0, padding_x0, padding_y1, padding_x1] = unwrap_4(&padding);
                    let [stride_y, stride_x] = unwrap_2(&strides);
                    let [dilation_y, dilation_x] = unwrap_2(&dilations);

                    let [_, _, kernel_h1, kernel_w1] = filter_shape.unwrap_4();

                    assert!(padding_y0 == padding_y1 && padding_x0 == padding_x1);
                    assert!(dilation_y == 1 && dilation_x == 1);
                    assert!(kernel_h1 == kernel_h0 && kernel_w1 == kernel_w0);

                    let result_conv = graph.conv(input, filter, stride_y, stride_x, padding_y0, padding_x0);
                    bias.map_or(result_conv, |bias| graph.add(result_conv, bias))
                }
                rank => return Err(OnnxError::UnsupportedNdConvolution(node.to_owned(), rank)),
            };

            OnnxValue::Value(result)
        }
        "Clip" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();
            let input_min = inputs.optional(1);
            let input_max = inputs.optional(2);

            let result = match (input_min, input_max) {
                (None, None) => {
                    let min = attrs.take_float("min")?;
                    let max = attrs.take_float("max")?;
                    graph.clamp::<f32>(input, min, max)
                }
                (Some(min), Some(max)) => {
                    let min = min.unwrap_value().unwrap();
                    let max = max.unwrap_value().unwrap();
                    assert_eq!(graph[min].shape, Shape::SCALAR);
                    assert_eq!(graph[max].shape, Shape::SCALAR);

                    let mid = graph.binary(BinaryOp::Min, input, max);
                    graph.binary(BinaryOp::Max, mid, min)
                }
                _ => {
                    let message = "Clip must have either 1 or 3 inputs, got 2".to_owned();
                    return Err(OnnxError::InvalidOperationArgs(node.to_owned(), message));
                }
            };

            OnnxValue::Value(result)
        }
        "Abs" | "Neg" | "Sin" | "Cos" | "Exp" | "Log" | "Sqrt" | "Sigmoid" | "Relu" | "Tanh" | "Erf" | "Mish"
        | "Softplus" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();

            let result = match node.op_type {
                "Abs" => graph.unary(UnaryOp::Abs, input),
                "Neg" => graph.unary(UnaryOp::Neg, input),
                "Sin" => graph.unary(UnaryOp::Sin, input),
                "Cos" => graph.unary(UnaryOp::Cos, input),
                "Exp" => graph.unary(UnaryOp::Exp, input),
                "Log" => graph.unary(UnaryOp::Log, input),
                "Sqrt" => graph.unary(UnaryOp::Sqrt, input),
                "Sigmoid" => graph.unary(UnaryOp::Sigmoid, input),
                "Relu" => graph.relu(input),
                "Tanh" => graph.unary(UnaryOp::Tanh, input),
                "Erf" => graph.unary(UnaryOp::Erf, input),
                // Mish(x) = x * tanh(softplus(x)). Without this arm, "Mish" would be accepted
                // by the outer match but fall into the unreachable branch below.
                "Mish" => {
                    let softplus = graph.unary(UnaryOp::Softplus, input);
                    let tanh = graph.unary(UnaryOp::Tanh, softplus);
                    graph.mul(input, tanh)
                }
                "Softplus" => graph.unary(UnaryOp::Softplus, input),
                _ => unreachable!("missing {:?}", node.op_type),
            };

            OnnxValue::Value(result)
        }
        "Add" | "Sub" | "Mul" | "Div" | "Min" | "Max" | "Pow" => {
            let op = match node.op_type {
                "Add" => BinaryOp::Add,
                "Sub" => BinaryOp::Sub,
                "Mul" => BinaryOp::Mul,
                "Div" => BinaryOp::Div,
                "Min" => BinaryOp::Min,
                "Max" => BinaryOp::Max,
                "Pow" => BinaryOp::Pow,
                _ => unreachable!("missing {:?}", node.op_type),
            };

            let left = inputs.required(0)?;
            let right = inputs.required(1)?;

            if let (&OnnxValue::Value(left), &OnnxValue::Value(right)) = (left, right) {
                OnnxValue::Value(graph.binary(op, left, right))
            } else {
                let left = left.as_size(graph)?;
                let right = right.as_size(graph)?;

                let (left, right) = broadcast_tensors_symmetric(&left, &right);

                let result = azip!(&left, &right).map_collect(|&l, &r| {
                    eval_binary_op(op, l, r)
                        .unwrap_or_else(|| panic!("Operation {:?} failed between {:?} and {:?}", op, left, right))
                });

                OnnxValue::new_size(result.into_shared(), graph)
            }
        }
        "Equal" => {
            let left = inputs.required(0)?;
            let right = inputs.required(1)?;

            let result = match (left, right) {
                (&OnnxValue::Value(left), &OnnxValue::Value(right)) => {
                    let diff = graph.sub(left, right);
                    graph.unary(UnaryOp::ValueCast(DType::Bool), diff)
                }
                (OnnxValue::Size(left), OnnxValue::Size(right)) => {
                    let (left, right) = broadcast_tensors_symmetric(&left, &right);

                    let result = azip!(left, right).map_collect(|l, r| DBool(l == r)).into_shared();
                    graph.constant_tensor(DTensor::Bool(result))
                }
                _ => {
                    let broadcast_shape = broadcast_shape_symmetric(&left.shape(graph), &right.shape(graph));
                    let scalar = graph.scalar(DBool(false));
                    graph.broadcast(scalar, broadcast_shape)
                }
            };

            OnnxValue::Value(result)
        }
        "Where" => {
            let cond = inputs.required(0)?.unwrap_value().unwrap();
            let x = inputs.required(1)?.unwrap_value().unwrap();
            let y = inputs.required(2)?.unwrap_value().unwrap();

            let cond = graph.as_const(cond).unwrap();
            let cond = cond.unwrap_bool().unwrap();
            let x = graph.as_const(x).unwrap();
            let y = graph.as_const(y).unwrap();

            assert_eq!(cond.shape(), x.shape(), "Where broadcasting not yet implemented");
            assert_eq!(cond.shape(), y.shape(), "Where broadcasting not yet implemented");

            let result = map_dtensor_pair!(x, y, |x, y| {
                azip!(cond, &x, &y)
                    .map_collect(|&DBool(c), &x, &y| if c { x } else { y })
                    .into_shared()
            });

            OnnxValue::Value(graph.constant_tensor(result))
        }
        "Flatten" => {
            let input = inputs.required(0)?;

            let rel_axis = attrs.maybe_take_int("axis")?.unwrap_or(1);
            let axis = abs_axis(rel_axis, input.shape(graph).rank());

            let kept_shape = &input.shape(graph).dims[..axis];
            let flat_shape = input.shape(graph).dims[axis..].iter().copied().product::<Size>();

            let mut new_shape = kept_shape.to_vec();
            new_shape.push(flat_shape);

            if axis == 0 {
                new_shape.insert(0, Size::ONE);
            }
            let new_shape = Shape::new(new_shape);

            match input {
                &OnnxValue::Value(input) => OnnxValue::Value(graph.view(input, new_shape)),
                OnnxValue::Size(input) => {
                    OnnxValue::new_size(input.reshape(new_shape.unwrap_fixed("size shape").dims), graph)
                }
            }
        }
        "Gemm" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();
            let weight = inputs.required(1)?.unwrap_value().unwrap();
            let bias = inputs.optional(2).map(|v| v.unwrap_value().unwrap());

            let alpha = attrs.take_float("alpha")?;
            let beta = attrs.take_float("beta")?;
            let trans_b = attrs.take_int("transB")? != 0;

            assert_eq!(1.0, alpha);
            assert_eq!(1.0, beta);
            assert!(trans_b);

            let linear = graph.linear(input, weight);

            let result = if let Some(bias) = bias {
                let bias_len = graph[bias].shape.unwrap_1();
                let bias_view_shape = shape![1, bias_len];
                let bias_view = graph.view(bias, bias_view_shape);

                graph.add(linear, bias_view)
            } else {
                linear
            };

            OnnxValue::Value(result)
        }
        "MatMul" => {
            let left = inputs.required(0)?.unwrap_value().unwrap();
            let right = inputs.required(1)?.unwrap_value().unwrap();

            let result = graph.mat_mul(left, right);
            OnnxValue::Value(result)
        }
        "Einsum" => {
            let inputs = inputs.take_all_variadic();
            let equation = attrs.take_string("equation")?;

            let equation_compact = equation.replace(' ', "");

            match equation_compact.as_ref() {
                "bid,bjd->bij" => {
                    assert_eq!(inputs.len(), 2);
                    let left = inputs[0].unwrap_value().unwrap();
                    let right = inputs[1].unwrap_value().unwrap();

                    let right_transpose = graph.permute(right, vec![0, 2, 1]);
                    let result = graph.batched_mat_mul(left, right_transpose);
                    OnnxValue::Value(result)
                }
                "bij,bjd->bid" => {
                    assert_eq!(inputs.len(), 2);
                    let left = inputs[0].unwrap_value().unwrap();
                    let right = inputs[1].unwrap_value().unwrap();

                    let result = graph.batched_mat_mul(left, right);
                    OnnxValue::Value(result)
                }
                _ => panic!(
                    "Einsum with equation {:?} and inputs {:?} not yet supported",
                    equation, inputs
                ),
            }
        }
        "BatchNormalization" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();

            let input_scale = inputs.required(1)?.unwrap_value().unwrap();
            let input_bias = inputs.required(2)?.unwrap_value().unwrap();
            let input_mean = inputs.required(3)?.unwrap_value().unwrap();
            let input_variance = inputs.required(4)?.unwrap_value().unwrap();

            let epsilon = attrs.take_float("epsilon")?;
            let _ = attrs.take_float("momentum")?;
            let spatial = attrs.maybe_take_int("spatial")?;
            assert!(
                spatial.is_none() || spatial == Some(1),
                "non-spatial cases are not supported and have been deprecated since ONNX version 9"
            );

            let input_shape = &graph[input].shape;
            assert!(input_shape.rank() >= 2, "BatchNormalization input must have rank at least 2");

            let channels = input_shape[1];
            let shape_vec = shape![channels];
            let shape_exp = input_shape.keep(1, Size::ONE);

            for param in [input_scale, input_bias, input_mean, input_variance] {
                assert_eq!(graph[param].shape, shape_vec);
            }

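            // y = (x - mean) / sqrt(variance + epsilon) * scale + bias, with the per-channel
            // parameters broadcast along the channel axis.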
            let result = {
                let value_eps = graph.scalar(epsilon);

                let exp_scale = graph.view(input_scale, shape_exp.clone());
                let exp_bias = graph.view(input_bias, shape_exp.clone());
                let exp_mean = graph.view(input_mean, shape_exp.clone());
                let exp_variance = graph.view(input_variance, shape_exp);

                let div_squared = graph.add(exp_variance, value_eps);
                let div = graph.unary(UnaryOp::Sqrt, div_squared);

                let x = input;
                let x_mean = graph.sub(x, exp_mean);
                let x_div = graph.binary(BinaryOp::Div, x_mean, div);
                let x_scale = graph.mul(x_div, exp_scale);
                graph.add(x_scale, exp_bias)
            };

            OnnxValue::Value(result)
        }
        "InstanceNormalization" | "LayerNormalization" => {
            let (input, start_axis, epsilon, scale_bias) = match node.op_type {
                "InstanceNormalization" => {
                    let input = inputs.required(0)?.unwrap_value().unwrap();
                    let scale = inputs.required(1)?.unwrap_value().unwrap();
                    let bias = inputs.required(2)?.unwrap_value().unwrap();
                    let epsilon = attrs.take_float("epsilon")?;

                    let broadcast_shape = graph[input].shape.keep(1, Size::ONE);
                    let scale_broadcast = graph.view(scale, broadcast_shape.clone());
                    let bias_broadcast = graph.view(bias, broadcast_shape);

                    (input, 2, epsilon, Some((scale_broadcast, bias_broadcast)))
                }
                "LayerNormalization" => {
                    let input = inputs.required(0)?.unwrap_value().unwrap();
                    let scale = inputs.required(1)?.unwrap_value().unwrap();
                    let bias = inputs.required(2)?.unwrap_value().unwrap();
                    let axis = attrs.maybe_take_int("axis")?.unwrap_or(-1);
                    let epsilon = attrs.maybe_take_float("epsilon")?.unwrap_or(1e-5);

                    let input_shape = graph[input].shape.clone();
                    let scale_broadcast = graph.broadcast(scale, input_shape.clone());
                    let bias_broadcast = graph.broadcast(bias, input_shape);

                    let axis = abs_axis(axis, graph[input].shape.rank());

                    (input, axis, epsilon, Some((scale_broadcast, bias_broadcast)))
                }
                _ => unreachable!("missing {:?}", node.op_type),
            };

            let shape = graph[input].shape.clone();
            assert!(
                shape.rank() >= start_axis,
                "Input rank must be at least the start axis {}, got shape {}",
                start_axis,
                shape
            );

            let (shape_keep, shape_reduced) = shape.split(start_axis);
            let shape_flat = shape_keep.concat(&shape![shape_reduced.size()]);

            let input_flat = graph.view(input, shape_flat);
            let norm_flat = graph.layernorm(input_flat, start_axis, epsilon);
            let norm = graph.view(norm_flat, shape);

            let result = if let Some((scale, bias)) = scale_bias {
                let scaled = graph.mul(norm, scale);
                graph.add(scaled, bias)
            } else {
                norm
            };

            OnnxValue::Value(result)
        }
        "Constant" => {
            let tensor = attrs.take_tensor("value")?;
            let value = define_tensor_data(graph, node.name, tensor, external)?;
            OnnxValue::Value(value)
        }
        "ConstantOfShape" => {
            let shape = inputs.optional(0);

            let shape = match shape {
                None => Shape::SCALAR,
                Some(shape) => shape.as_shape(graph)?,
            };

            let value = match attrs.maybe_take_tensor("value")? {
                None => graph.scalar(0f32),
                Some(tensor) => define_tensor_data(graph, node.name, tensor, external)?,
            };

            assert_eq!(
                graph[value].shape.size(),
                Size::ONE,
                "value must be a one-element tensor"
            );

            let scalar = graph.view(value, Shape::SCALAR);
            let result = graph.broadcast(scalar, shape);
            OnnxValue::Value(result)
        }
        "Cast" => {
            let input = inputs.required(0)?;
            let data_type = DataType::try_from(attrs.take_int("to")? as i32).expect("Invalid data type");
            let dtype = resolve_dtype(data_type, node.name)?;

            match input {
                &OnnxValue::Value(value) => OnnxValue::Value(graph.unary(UnaryOp::ValueCast(dtype), value)),
                OnnxValue::Size(value) => {
                    assert_eq!(dtype, DType::I64);
                    OnnxValue::new_size(value.clone(), graph)
                }
            }
        }
        "Reshape" => {
            let input = inputs.required(0)?;
            let new_shape = inputs.required(1)?.as_signed_shape(graph)?;
            let allow_zero = attrs.maybe_take_bool("allowzero")?.unwrap_or(false);

            let old_shape = input.shape(graph);
            let output_shape = calculate_reshape_output_shape(&old_shape, &new_shape, allow_zero);

            match input {
                &OnnxValue::Value(input) => OnnxValue::Value(graph.view(input, output_shape)),
                OnnxValue::Size(input) => {
                    let result = input.reshape(output_shape.unwrap_fixed("reshape shape").dims.clone());
                    OnnxValue::new_size(result, graph)
                }
            }
        }
        "Expand" => {
            let input = inputs.required(0)?;
            let shape = inputs.required(1)?.as_shape(graph)?;

            let result_shape = broadcast_shape_symmetric(&input.shape(graph), &shape);

            match input {
                &OnnxValue::Value(input) => OnnxValue::Value(graph.broadcast(input, result_shape)),
                OnnxValue::Size(input) => {
                    let result_shape = result_shape.unwrap_fixed("expand shape").dims.clone();
                    let result = input.broadcast(result_shape).unwrap().to_shared();
                    OnnxValue::new_size(result, graph)
                }
            }
        }
        "Unsqueeze" => {
            let input = inputs.required(0)?;

            let rel_axes = match inputs.optional(1) {
                Some(rel_axes) => {
                    let shape = rel_axes.as_signed_shape(graph)?;
                    shape.iter().map(|d| d.unwrap_fixed().unwrap()).collect_vec()
                }
                None => attrs.take_ints("axes")?.to_vec(),
            };

            let input_shape = input.shape(graph);

            let output_rank = input_shape.rank() + rel_axes.len();
            let axes = rel_axes.iter().map(|&a| abs_axis(a, output_rank)).collect_vec();

            assert!(
                axes.iter().all_unique() && axes.iter().all(|&a| a < output_rank),
                "Invalid axes {:?} for output rank {} in Unsqueeze",
                axes,
                output_rank,
            );

            let mut input_shape_left = input_shape.dims.iter().copied();
            let output_dims = (0..output_rank)
                .map(|i| {
                    if axes.contains(&i) {
                        Size::ONE
                    } else {
                        input_shape_left.next().unwrap()
                    }
                })
                .collect_vec();
            assert_eq!(input_shape_left.len(), 0);

            let output_shape = Shape::new(output_dims);

            match input {
                &OnnxValue::Value(input) => OnnxValue::Value(graph.view(input, output_shape)),
                OnnxValue::Size(input) => {
                    let result_shape = output_shape.unwrap_fixed("unsqueeze shape").dims;
                    let result = input.reshape(result_shape);
                    OnnxValue::new_size(result, graph)
                }
            }
        }
        "Transpose" => {
            let input = inputs.required(0)?;

            let permutation = attrs.take_ints("perm")?;
            let permutation = permutation.iter().map(|&x| x as usize).collect_vec();

            match input {
                &OnnxValue::Value(input) => OnnxValue::Value(graph.permute(input, permutation)),
                OnnxValue::Size(input) => {
                    let result = input.to_shared().permuted_axes(permutation);
                    OnnxValue::new_size(result, graph)
                }
            }
        }
        "Gather" => {
            let input = inputs.required(0)?;
            let indices_raw = inputs.required(1)?;
            let rel_axis = attrs.maybe_take_int("axis")?.unwrap_or(0);

            let input_shape = input.shape(graph);
            let axis = abs_axis(rel_axis, input_shape.rank());
            let axis_size = input_shape.dims[axis];

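            // ONNX allows negative gather indices; wrap them around by adding the axis size.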
            let indices = match indices_raw {
                &OnnxValue::Value(indices) => match graph.as_const(indices) {
                    Some(indices) => {
                        let dim = axis_size.unwrap_fixed("gather dim size");
                        let indices = dispatch_dtensor!(indices, |T, ft, indices| {
                            let zero = T::from_dscalar(T::DTYPE.specials().zero).unwrap();
                            let dim = T::from_dscalar(DScalar::U64(dim as u64).value_cast(T::DTYPE)).unwrap();

                            ft(indices.mapv(|x| if x < zero { x + dim } else { x }).into_shared())
                        });
                        OnnxValue::Value(graph.constant_tensor(indices))
                    }
                    None => OnnxValue::Value(indices),
                },
                OnnxValue::Size(indices) => {
                    let indices = indices.mapv(|x| {
                        if x.is_neg() {
                            (x + axis_size).expect("gather negative index overflow")
                        } else {
                            x
                        }
                    });
                    OnnxValue::new_size(indices.into_shared(), graph)
                }
            };

            match input {
                &OnnxValue::Value(input) => {
                    let result = graph.gather(input, axis, indices.unwrap_value().unwrap());
                    OnnxValue::Value(result)
                }
                OnnxValue::Size(input) => {
                    let indices = graph.as_const(indices.unwrap_value().unwrap()).unwrap();

                    let indices_flat = indices.reshape(vec![indices.len()]);
                    let result_flat = cpu_gather(input, axis, indices_flat);
                    let mut result_shape = input.shape().to_owned();
                    result_shape.splice(axis..axis + 1, indices.shape().iter().copied());
                    let result = result_flat.reshape(result_shape);

                    OnnxValue::new_size(result, graph)
                }
            }
        }
        "Slice" => {
            let get = |inputs: &mut Inputs, attrs: &mut Attributes, index: usize, name: &str| -> OnnxResult<_> {
                match inputs.optional(index) {
                    Some(value) => {
                        let value = graph.as_const(value.unwrap_value().unwrap()).unwrap();

                        assert_eq!(
                            value.shape().len(),
                            1,
                            "Slice operand {} must be a 1D constant, got shape {:?}",
                            name,
                            value.shape()
                        );

                        let vec = match value {
                            DTensor::I64(value) => value.iter().copied().collect_vec(),
                            DTensor::I32(value) => value.iter().map(|&x| x as i64).collect_vec(),
                            _ => panic!("Invalid slice operand type {:?}", value.dtype()),
                        };

                        Ok(Some(vec))
                    }
                    None => Ok(attrs.maybe_take_ints(name)?.map(|v| v.to_vec())),
                }
            };

            let input = inputs.required(0)?;
            let starts = get(inputs, attrs, 1, "starts")?.expect("Missing starts input and attribute");
            let ends = get(inputs, attrs, 2, "ends")?.expect("Missing ends input and attribute");

            let slice_rank = starts.len();
            let axes = get(inputs, attrs, 3, "axes")?.unwrap_or_else(|| (0..slice_rank as i64).collect_vec());
            let steps = get(inputs, attrs, 4, "steps")?.unwrap_or_else(|| vec![1; slice_rank]);

            assert!(
                slice_rank == ends.len() && slice_rank == axes.len() && slice_rank == steps.len(),
                "Inconsistent axes count"
            );
            assert!(
                axes.iter().all_unique(),
                "Slice axis cannot be repeated, got {:?}",
                axes
            );

            let input_shape = input.shape(graph);

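            // Apply the slice one axis at a time, folding over the requested axes.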
            (0..slice_rank).fold(input.clone(), |curr, i| {
                let axis = abs_axis(axes[i], input_shape.rank());
                let axis_size = input_shape[axis].try_unwrap_fixed();

                let step = steps[i];
                assert_ne!(step, 0, "Slice step cannot be 0");

                if step > 0 {
                    if axis_size.is_none() && starts[i] == 0 && ends[i] == i32::MAX as i64 {
                        curr
                    } else {
                        let axis_size = input_shape[axis].unwrap_fixed("Slice axis size");
                        let start = abs_axis(starts[i], axis_size);
                        let end = abs_axis(ends[i], axis_size);

                        let range = SliceRange::new(start, end, step as usize);

                        match curr {
                            OnnxValue::Value(curr) => OnnxValue::Value(graph.slice(curr, axis, range)),
                            OnnxValue::Size(curr) => OnnxValue::Size(cpu_slice(&curr, axis, range)),
                        }
                    }
                } else {
                    assert!(
                        starts[i] == -1 && ends[i] == i64::MIN && step == -1,
                        "Only a simple full flip is supported for negative steps"
                    );

                    match curr {
                        OnnxValue::Value(curr) => OnnxValue::Value(graph.flip(curr, axis)),
                        OnnxValue::Size(curr) => OnnxValue::Size(cpu_flip(&curr, axis)),
                    }
                }
            })
        }
        "Concat" => {
            let inputs = inputs.take_all_variadic();
            assert!(!inputs.is_empty(), "Must concatenate at least one value");

            let rank = inputs[0].shape(graph).rank();

            let rel_axis = attrs.take_int("axis")?;
            let axis = abs_axis(rel_axis, rank);

            let any_size = inputs.iter().any(|x| matches!(x, OnnxValue::Size(_)));

            if any_size {
                let tensors: Vec<_> = inputs.iter().map(|x| x.as_size(graph)).try_collect()?;
                let views: Vec<_> = tensors.iter().map(|x| x.view()).collect();
                let result = ndarray::concatenate(Axis(axis), &views).unwrap().into_shared();
                OnnxValue::new_size(result, graph)
            } else {
                let inputs = inputs.iter().map(|v| v.unwrap_value().unwrap()).collect();
                let result = graph.concat(inputs, axis, None, None);
                OnnxValue::Value(result)
            }
        }
        "Split" => {
            let input = inputs.required(0)?;
            let split = inputs.optional(1).map(|s| s.as_shape(graph).unwrap());

            let shape = input.shape(graph);
            let axis = attrs.take_int("axis")?;
            let axis = abs_axis(axis, shape.rank());

            let ranges = match split {
                None => {
                    let len = shape[axis].unwrap_fixed("Split axis length");
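                    // Without an explicit split input, assume two outputs of (nearly) equal length.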
                    let num_outputs = 2;
                    let len_first = (len + num_outputs - 1) / num_outputs;
                    vec![SliceRange::simple(0, len_first), SliceRange::simple(len_first, len)]
                }
                Some(split) => {
                    split
                        .unwrap_fixed("split lengths")
                        .dims
                        .iter()
                        .scan(0, |a, s| {
                            let start = *a;
                            let end = start + s;
                            *a = end;
                            Some(SliceRange::simple(start, end))
                        })
                        .collect_vec()
                }
            };

            let result = match input {
                &OnnxValue::Value(input) => ranges
                    .iter()
                    .map(|&r| OnnxValue::Value(graph.slice(input, axis, r)))
                    .collect_vec(),
                OnnxValue::Size(input) => ranges
                    .iter()
                    .map(|&r| OnnxValue::new_size(cpu_slice(&input, axis, r), graph))
                    .collect_vec(),
            };

            return Ok(result);
        }
        "Pad" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();
            let pads = inputs.required(1)?.unwrap_value().unwrap();
            let constant_value = inputs.optional(2);
            let axes = inputs.optional(3);
            let mode = attrs.maybe_take_string("mode")?.unwrap_or("constant");

            let input_shape = graph[input].shape.clone();

            let constant_value = constant_value
                .map(|v| {
                    graph
                        .as_single_const(v.unwrap_value().unwrap())
                        .unwrap()
                        .unwrap_f32()
                        .unwrap()
                })
                .unwrap_or(0.0);

            let axes = match axes {
                Some(axes) => {
                    let axes = axes.as_signed_shape(graph)?;
                    axes.iter()
                        .map(|&i| abs_axis(i.unwrap_fixed().unwrap(), input_shape.rank()))
                        .collect_vec()
                }
                None => (0..input_shape.rank()).collect_vec(),
            };

            let pads = graph.as_const(pads).unwrap();
            let pads = pads.unwrap_i64().unwrap();
            assert_eq!(pads.shape(), &[axes.len() * 2], "Pads and axes shape mismatch");
            let pads = pads.iter().copied().collect_vec();

            assert_eq!(mode, "constant", "Only 'constant' pad mode supported");

            let constant = graph.scalar(constant_value);

            let output = axes.iter().enumerate().fold(input, |acc, (i, &axis)| {
                let acc_shape = graph[acc].shape.clone();

                // The pads tensor is [begin_0, .., begin_n, end_0, .., end_n], indexed by
                // position in `axes` rather than by the axis value itself.
                let pad_left = pads[i];
                let pad_right = pads[axes.len() + i];
                assert!(pad_left >= 0 && pad_right >= 0, "Pad values cannot be negative");

                let blocks = vec![
                    graph.broadcast(constant, acc_shape.replace(axis, shape![pad_left as usize])),
                    acc,
                    graph.broadcast(constant, acc_shape.replace(axis, shape![pad_right as usize])),
                ];
                graph.concat(blocks, axis, None, None)
            });

            OnnxValue::Value(output)
        }
        "Shape" => {
            let input = inputs.required(0)?;
            let shape = input.shape(graph);
            let dims = shape
                .dims
                .iter()
                .map(|&d| SignedSize::from_size(d).unwrap())
                .collect_vec();
            OnnxValue::new_size(Tensor::from_shape_vec(vec![dims.len()], dims).unwrap(), graph)
        }
        "Identity" => {
            let input = inputs.required(0)?;
            input.clone()
        }
        "Softmax" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();

            let shape = graph[input].shape.clone();
            let axis = attrs.maybe_take_int("axis")?.unwrap_or(-1);
            let axis = abs_axis(axis, shape.rank());

            OnnxValue::Value(graph.softmax(input, axis))
        }
        "ReduceSum" | "ReduceMean" | "ReduceProd" | "ReduceMin" | "ReduceMax" => {
            let op = match node.op_type {
                "ReduceSum" => ReduceOp::Sum,
                "ReduceMean" => ReduceOp::Mean,
                "ReduceProd" => ReduceOp::Prod,
                "ReduceMin" => ReduceOp::Min,
                "ReduceMax" => ReduceOp::Max,
                _ => unreachable!("missing {:?}", node.op_type),
            };

            let input = inputs.required(0)?.unwrap_value().unwrap();
            let input_shape = graph[input].shape.clone();

            let axes = attrs.maybe_take_ints("axes")?.map_or_else(
                || (0..input_shape.rank()).collect_vec(),
                |axes| axes.iter().map(|&a| abs_axis(a, input_shape.rank())).collect_vec(),
            );
            let keep_dims = attrs.maybe_take_int("keepdims")?.unwrap_or(1) != 0;

            let result_shape = if keep_dims {
                input_shape.replace_all(&axes, shape![1])
            } else {
                input_shape.replace_all(&axes, shape![])
            };

            let result = graph.reduce(input, axes, op);
            let result_shaped = graph.view(result, result_shape);

            OnnxValue::Value(result_shaped)
        }
        "MaxPool" | "AveragePool" => {
            let op = match node.op_type {
                "MaxPool" => ReduceOp::Max,
                "AveragePool" => ReduceOp::Mean,
                _ => unreachable!("missing {:?}", node.op_type),
            };

            let input = inputs.required(0)?.unwrap_value().unwrap();

            let strides = attrs.take_ints("strides")?;
            let kernel_shape = attrs.take_ints("kernel_shape")?;
            let pads = attrs.take_ints("pads")?;
            let ceil_mode = attrs.maybe_take_int("ceil_mode")?.unwrap_or(0) != 0;
            let auto_pad = attrs.maybe_take_string("auto_pad")?;

            assert_eq!(strides, kernel_shape, "Real strides not supported yet");
            assert_eq!(pads, &vec![0; pads.len()], "Padding not supported yet");
            assert!(matches!(auto_pad, None | Some("NOTSET")), "Auto padding not supported yet");

            let raw_input_shape = &graph[input].shape;
            let input_rank = raw_input_shape.rank();
            let kernel_rank = kernel_shape.len();

            let kept_rank = input_rank - kernel_rank;
            let (batch_shape, active_shape) = raw_input_shape.split(kept_rank);
            let mut pad_amounts = vec![(0, 0); kept_rank];
            let mut reshape = batch_shape.dims.clone();
            let mut pooled_dims = vec![];
            let mut slices = vec![None; kept_rank];

            for i in 0..kernel_rank {
                let kernel_size = kernel_shape[i] as usize;
                let input_size = active_shape.dims[i];

                let div_rem = input_size.div_rem(kernel_size);
                let (left, pad, slice) = match div_rem {
                    DivResult::Exact(left) => (left, (0, 0), None),
                    DivResult::Remainder(rem) => {
                        if ceil_mode {
                            let pad = kernel_size - rem;
                            let left = ((input_size + pad).unwrap() / kernel_size).unwrap();
                            (left, (0, pad), None)
                        } else {
                            let left = ((input_size - rem).unwrap() / kernel_size).unwrap();
                            let end = left.unwrap_fixed("pool dim size") * kernel_size;
                            let slice = SliceRange::new(0, end, 1);
                            (left, (0, 0), Some(slice))
                        }
                    }
                    DivResult::Impossible => {
                        return Err(OnnxError::NonDividingPooling(
                            node.to_owned(),
                            raw_input_shape.clone(),
                            kernel_shape.to_vec(),
                        ));
                    }
                };

                pad_amounts.push(pad);
                reshape.push(left);
                pooled_dims.push(reshape.len());
                reshape.push(Size::fixed(kernel_size));
                slices.push(slice);
            }
            let reshape = Shape::new(reshape);

            let pad_value = op.identity(graph[input].dtype);
            let pad_value = graph.scalar_dyn(pad_value);
            let padded = graph.pad(input, &pad_amounts, pad_value);
            let sliced = slices.iter().enumerate().fold(padded, |a, (i, &s)| {
                if let Some(s) = s {
                    graph.slice(a, i, s)
                } else {
                    a
                }
            });
            let reshaped = graph.view(sliced, reshape);
            let result = graph.reduce(reshaped, pooled_dims, op);

            OnnxValue::Value(result)
        }
        "GlobalMaxPool" | "GlobalAveragePool" => {
            let op = match node.op_type {
                "GlobalMaxPool" => ReduceOp::Max,
                "GlobalAveragePool" => ReduceOp::Mean,
                _ => unreachable!("missing {:?}", node.op_type),
            };

            let input = inputs.required(0)?.unwrap_value().unwrap();

            let shape = &graph[input].shape;
            if shape.rank() < 2 {
                return Err(OnnxError::UnsupportedShape(node.to_owned(), shape.to_string()));
            }

            let axes = (2..shape.rank()).collect_vec();
            let result = graph.reduce(input, axes, op);

            OnnxValue::Value(result)
        }
        "Resize" => {
            let input = inputs.required(0)?;
            let roi = inputs.optional(1);
            let scales = inputs.optional(2).map(|v| v.unwrap_value().unwrap());
            let sizes = inputs.optional(3);

            let _antialias = attrs.maybe_take_bool("antialias")?.unwrap_or(false);
            let axes = attrs.maybe_take_ints("axes")?;
            let _coordinate_transformation_mode = attrs
                .maybe_take_string("coordinate_transformation_mode")?
                .unwrap_or("half_pixel");
            let _cubic_coeff_a = attrs.maybe_take_float("cubic_coeff_a")?;
            let _exclude_outside = attrs.maybe_take_int("exclude_outside")?.unwrap_or(0);
            let _extrapolation_value = attrs.maybe_take_float("extrapolation_value")?.unwrap_or(0.0);
            let keep_aspect_ratio_policy = attrs
                .maybe_take_string("keep_aspect_ratio_policy")?
                .unwrap_or("stretch")
                .to_owned();
            let mode = attrs.maybe_take_string("mode")?.unwrap_or("nearest").to_owned();
            let nearest_mode = attrs
                .maybe_take_string("nearest_mode")?
                .unwrap_or("round_prefer_floor")
                .to_owned();

            assert!(
                mode == "nearest"
                    && nearest_mode == "floor"
                    && roi.is_none()
                    && sizes.is_none()
                    && axes.is_none()
                    && keep_aspect_ratio_policy == "stretch",
                "The given resize operation is not supported"
            );

            let scales = graph
                .as_const(scales.expect("Resize requires scales for now"))
                .expect("Resize only supported with constant scales");

            let input_tensor = input.unwrap_value().unwrap();
            let input_shape = &graph[input_tensor].shape;
            let rank = input_shape.rank();

            assert_eq!(
                scales.shape(),
                &[rank],
                "Scales must be a vector with length equal to the input rank"
            );

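            // Nearest-floor resizing with integer scales is just repeating elements along each axis.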
            let result = scales
                .unwrap_f32()
                .unwrap()
                .iter()
                .enumerate()
                .fold(input_tensor, |acc, (axis, &scale_f)| {
                    let scale = scale_f as usize;
                    assert_eq!(scale as f32, scale_f, "Only integer scales supported, got {:?}", scales);
                    graph.repeat_interleave(acc, axis, Size::fixed(scale))
                });

            OnnxValue::Value(result)
        }
        _ => {
            return Err(OnnxError::UnsupportedOperation(node.to_owned()));
        }
    };

    Ok(vec![result_single])
}

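/// Compute the explicit padding for the `SAME_UPPER`/`SAME_LOWER` auto-pad modes,
/// in the ONNX `pads` layout `[begin_0, .., begin_n, end_0, .., end_n]`.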
fn calculate_auto_padding(
    graph: &Graph,
    conv_rank: usize,
    input: Value,
    filter: Value,
    strides: &[i64],
    dilations: &[i64],
    up: bool,
) -> OnnxResult<Vec<i64>> {
    let (_, input_spatial_dims) = graph[input].shape.split(2);
    let input_spatial_dims = input_spatial_dims.unwrap_fixed("conv input spatial dims");

    let (_, filter_spatial_dims) = graph[filter].shape.split(2);
    let filter_spatial_dims = filter_spatial_dims.unwrap_fixed("conv filter spatial dims");

    // ONNX lays "pads" out as all begin values followed by all end values,
    // not as (begin, end) pairs interleaved per axis.
    let mut begins = vec![];
    let mut ends = vec![];
    for i in 0..conv_rank {
        let (low, high) = split_padding(
            input_spatial_dims.dims[i] as i64,
            filter_spatial_dims.dims[i] as i64,
            strides[i],
            dilations[i],
            up,
        );
        begins.push(low);
        ends.push(high);
    }
    begins.extend(ends);
    Ok(begins)
}

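/// Split the total "SAME" padding for a single axis into `(begin, end)`, where `i`, `f`, `s`,
/// and `d` are the input size, filter size, stride, and dilation. For `SAME_UPPER` the extra
/// element of an odd total goes at the end, for `SAME_LOWER` at the beginning.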
fn split_padding(i: i64, f: i64, s: i64, d: i64, up: bool) -> (i64, i64) {
    // Standard ONNX "SAME" formula: output = ceil(i / s), so the total padding is
    // (output - 1) * s + (f - 1) * d + 1 - i, clamped to be non-negative.
    let out = (i + s - 1) / s;
    let total = ((out - 1) * s + (f - 1) * d + 1 - i).max(0);

    let min = total / 2;
    let max = total - min;

    if up {
        (min, max)
    } else {
        (max, min)
    }
}

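/// Define a graph constant for a `TensorProto`, reading the elements from the typed proto
/// fields, the inline little-endian `raw_data`, or an external data file.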
fn define_tensor_data(
    graph: &mut Graph,
    name: &str,
    tensor: &TensorProto,
    external: &mut dyn ExternalDataLoader,
) -> OnnxResult<Value> {
    let data_location = DataLocation::try_from(tensor.data_location).expect("Illegal data_location");

    let dims = tensor.dims.iter().map(|&d| Size::fixed(d as usize)).collect_vec();
    let shape = Shape::new(dims);
    let size = shape.size().unwrap_fixed("Data tensor shape must be fixed");

    let data_type = DataType::try_from(tensor.data_type).expect("Illegal data type");
    let dtype = resolve_dtype(data_type, name)?;

    let length_guess = size * dtype.size().bytes();

    let raw_data_slot;
    let raw_data = match data_location {
        DataLocation::Default => &tensor.raw_data,
        DataLocation::External => {
            let mut location: Option<&str> = None;
            let mut offset: usize = 0;
            let mut length: Option<usize> = None;

            for entry in &tensor.external_data {
                let key: &str = &entry.key;
                let value: &str = &entry.value;

                match key {
                    "location" => location = Some(value),
                    "offset" => offset = value.parse().unwrap(),
                    "length" => length = Some(value.parse().unwrap()),
                    "hash" => {}
                    _ => panic!("Invalid external_data key: {} (value {})", key, value),
                }
            }

            if let Some(length) = length {
                assert_eq!(length, length_guess, "External data length mismatch");
            }

            let location = location.expect("External data must have a location");
            raw_data_slot = external.load_external_data(&PathBuf::from(location), offset, length, length_guess)?;

            if let Some(length) = length {
                assert_eq!(raw_data_slot.len(), length, "Raw data length mismatch");
            }

            &raw_data_slot
        }
    };

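    // Read the elements from the typed proto field if it is populated, otherwise decode the
    // little-endian raw bytes. `None` as the reader means bytes map directly to values.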
    macro_rules! read_type {
        (graph, $T:ty, $data:ident, None) => {{
            let data: Vec<$T> = if tensor.$data.is_empty() {
                raw_data.iter().map(|&x| x as $T).collect()
            } else {
                tensor.$data.iter().map(|&x| x as $T).collect()
            };
            graph.constant::<$T>(shape, data)
        }};
        (graph, $T:ty, $data:ident, $read:ident) => {{
            let data: Vec<$T> = if tensor.$data.is_empty() {
                let mut data = vec![Default::default(); size];
                LittleEndian::$read(raw_data, &mut data);
                data
            } else {
                tensor.$data.iter().map(|&x| x as $T).collect()
            };
            graph.constant::<$T>(shape, data)
        }};
    }

    let value = match dtype {
        DType::F32 => read_type!(graph, f32, float_data, read_f32_into),
        DType::F64 => read_type!(graph, f64, double_data, read_f64_into),
        DType::I8 => read_type!(graph, i8, int32_data, None),
        DType::I16 => read_type!(graph, i16, int32_data, read_i16_into),
        DType::I32 => read_type!(graph, i32, int32_data, read_i32_into),
        DType::I64 => read_type!(graph, i64, int64_data, read_i64_into),
        DType::U8 => read_type!(graph, u8, int32_data, None),
        DType::U16 => read_type!(graph, u16, int32_data, read_u16_into),
        DType::U32 => read_type!(graph, u32, uint64_data, read_u32_into),
        DType::U64 => read_type!(graph, u64, uint64_data, read_u64_into),
        DType::Bool => {
            let data: Vec<DBool> = if tensor.int32_data.is_empty() {
                raw_data.iter().map(|&x| DBool(x != 0)).collect_vec()
            } else {
                tensor.int32_data.iter().map(|&x| DBool(x != 0)).collect()
            };
            graph.constant::<DBool>(shape, data)
        }
    };

    Ok(value)
}

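/// Resolve the declared type of a model input to a concrete `(Shape, DType)`, using the
/// `input_shaper` to turn the declared dimensions into a shape.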
fn resolve_tensor_type(ty: &TypeProto, name: &str, index: usize, input_shaper: &InputShaper) -> OnnxResult<(Shape, DType)> {
    let value = ty.value.as_ref().expect("Type proto does not have a value set");
    let result = match value {
        ProtoTypeValue::TensorType(tensor) => {
            let data_type = DataType::try_from(tensor.elem_type).expect("Invalid data type");
            let dtype = resolve_dtype(data_type, name)?;

            let dims = tensor
                .shape
                .as_ref()
                .expect("Tensor does not have shape set")
                .dim
                .iter()
                .map(|d| match *d.value.as_ref().expect("Missing value for dimension") {
                    dimension::Value::DimValue(value) => OnnxDimValue::Value(value),
                    dimension::Value::DimParam(ref param) => OnnxDimValue::Param(param.clone()),
                })
                .collect_vec();

            let shape = input_shaper(&dims, name, index)
                .ok_or_else(|| OnnxError::FailedToShapeInput(dims, name.to_owned(), index))?;

            (shape, dtype)
        }
        _ => panic!("Unsupported value kind {:?}", value),
    };
    Ok(result)
}

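/// Map an ONNX `DataType` onto a supported [`DType`], reporting an error for types that are
/// not (yet) supported.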
fn resolve_dtype(data_type: DataType, node: &str) -> OnnxResult<DType> {
    let dtype = match data_type {
        DataType::Float => DType::F32,
        DataType::Double => DType::F64,
        DataType::Uint8 => DType::U8,
        DataType::Int8 => DType::I8,
        DataType::Uint16 => DType::U16,
        DataType::Int16 => DType::I16,
        DataType::Int32 => DType::I32,
        DataType::Int64 => DType::I64,
        DataType::Uint32 => DType::U32,
        DataType::Uint64 => DType::U64,
        DataType::Bool => DType::Bool,
        DataType::Undefined
        | DataType::String
        | DataType::Complex64
        | DataType::Complex128
        | DataType::Float16
        | DataType::Bfloat16
        | DataType::Float8e4m3fn
        | DataType::Float8e4m3fnuz
        | DataType::Float8e5m2
        | DataType::Float8e5m2fnuz => return Err(OnnxError::UnsupportedType(node.to_owned(), data_type)),
    };
    Ok(dtype)
}

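/// Convert a possibly negative ONNX axis to an absolute axis for the given rank.
/// `i64::MAX` maps to `rank` itself, matching its use as an "until the end" marker.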
fn abs_axis(axis: i64, rank: usize) -> usize {
    if axis == i64::MAX {
        rank
    } else if axis < 0 {
        rank - ((-axis) as usize)
    } else {
        axis as usize
    }
}

#[track_caller]
fn unwrap_1(slice: &[i64]) -> usize {
    assert_eq!(slice.len(), 1, "Expected 1 element, got {:?}", slice);
    slice[0] as usize
}

#[track_caller]
fn unwrap_2(slice: &[i64]) -> [usize; 2] {
    assert_eq!(slice.len(), 2, "Expected 2 elements, got {:?}", slice);
    [slice[0] as usize, slice[1] as usize]
}

#[track_caller]
fn unwrap_4(slice: &[i64]) -> [usize; 4] {
    assert_eq!(slice.len(), 4, "Expected 4 elements, got {:?}", slice);
    [
        slice[0] as usize,
        slice[1] as usize,
        slice[2] as usize,
        slice[3] as usize,
    ]
}

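/// Compute the output shape of a `Reshape`, resolving `0` entries (copy the input dim, unless
/// `allow_zero` is set) and a single `-1` entry (inferred from the remaining size).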
fn calculate_reshape_output_shape(old_shape: &Shape, new_shape_raw: &[SignedSize], allow_zero: bool) -> Shape {
    let old_size = old_shape.size();

    let mut new_shape = vec![];
    let mut leftover_index = None;
    let mut leftover_size = old_size;

    for (i, &signed_size) in new_shape_raw.iter().enumerate() {
        let size = match signed_size {
            SignedSize::ZERO => {
                if allow_zero {
                    Size::ZERO
                } else {
                    assert!(
                        i < old_shape.rank(),
                        "Cannot copy dim {} of output shape {:?}, not present in input {}",
                        i,
                        new_shape,
                        old_shape,
                    );
                    old_shape[i]
                }
            }
            SignedSize::NEG_ONE => {
                assert!(
                    leftover_index.is_none(),
                    "Reshape shape can only contain a single -1 value"
                );
                leftover_index = Some(i);
                new_shape.push(Size::ZERO);
                continue;
            }
            signed_size => signed_size
                .to_size()
                .unwrap_or_else(|_| panic!("Reshape size must be positive, 0 or -1, got {:?}", signed_size)),
        };

        leftover_size = (leftover_size / size).unwrap_or_else(|| {
            panic!("Cannot reshape {} into {:?}", old_size, new_shape_raw);
        });
        new_shape.push(size);
    }

    if let Some(leftover_index) = leftover_index {
        new_shape[leftover_index] = leftover_size;
    }

    let shape = Shape::new(new_shape);
    assert_eq!(old_size, shape.size(), "Output and input sizes differ");

    shape
}

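/// Evaluate a binary operation between two symbolic sizes at load time, returning `None` for
/// operations that cannot be evaluated symbolically.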
fn eval_binary_op(op: BinaryOp, a: SignedSize, b: SignedSize) -> Option<SignedSize> {
    match op {
        BinaryOp::Add => a + b,
        BinaryOp::Sub => a - b,
        BinaryOp::Mul => Some(a * b),
        BinaryOp::Div => a.floor_div(b),
        _ => None,
    }
}

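/// Decode a `ModelProto` from raw bytes, panicking if the buffer is not a valid model.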
fn load_model_proto(buf: &[u8]) -> ModelProto {
    let mut buf: &[u8] = buf;
    ModelProto::decode(&mut buf).unwrap()
}