kn_graph/onnx/load.rs

use std::path::PathBuf;

use byteorder::{ByteOrder, LittleEndian};
use itertools::{Itertools, zip_eq};
use ndarray::{Axis, azip};
use prost::Message;

use crate::cpu::{cpu_flip, cpu_gather, cpu_slice};
use crate::dtype::{DBool, dispatch_dtensor, DScalar, DTensor, DType, IntoDScalar, map_dtensor_pair, Tensor};
use crate::graph::{
    BinaryOp, broadcast_shape_symmetric, broadcast_tensors_symmetric, ReduceOp, SliceRange, UnaryOp, Value,
};
pub use crate::graph::Graph;
use crate::onnx::external_data::ExternalDataLoader;
use crate::onnx::inputs::{Attributes, Inputs};
use crate::onnx::proto::{ModelProto, TensorProto, TypeProto};
use crate::onnx::proto::tensor_proto::DataLocation;
use crate::onnx::proto::tensor_proto::DataType;
use crate::onnx::proto::tensor_shape_proto::dimension;
use crate::onnx::proto::type_proto::Value as ProtoTypeValue;
use crate::onnx::result::{Node, OnnxError, OnnxResult, UnwrapProto};
use crate::onnx::store::Store;
use crate::onnx::typed_value::{OnnxValue, SignedSize};
use crate::shape;
use crate::shape::{DivResult, Shape, Size};

// TODO we should switch to taking an extra `HashMap<String, Size>` parameter,
//   so the user can decide which named axes match to what size or even the batch size

// TODO convert every possible panic to an error (even in the shape classes if possible)
//    things to grep for: unwrap|expect|assert|panic
//    introduce two main error kinds: "bug in file" and "unsupported"

pub type InputShaper = dyn Fn(&[OnnxDimValue], &str, usize) -> Option<Shape>;

#[derive(Debug, Clone)]
pub enum OnnxDimValue {
    Value(i64),
    Param(String),
}

// we use &dyn to avoid duplicate codegen of this large and non-critical function
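/// Build a [`Graph`] from the raw bytes of an ONNX model.
///
/// `external` resolves tensors stored outside of the model file and `input_shaper`
/// decides the [`Shape`] of each graph input from the dimensions declared in the file.
/// A minimal usage sketch, assuming `buf` holds the file contents, `external` is some
/// [`ExternalDataLoader`], and keeping every declared dimension as a fixed size is
/// acceptable (the shaper below is purely illustrative):
///
/// ```ignore
/// let shaper = |dims: &[OnnxDimValue], _name: &str, _index: usize| {
///     dims.iter()
///         .map(|d| match d {
///             OnnxDimValue::Value(v) => Some(Size::fixed(*v as usize)),
///             OnnxDimValue::Param(_) => None,
///         })
///         .collect::<Option<Vec<_>>>()
///         .map(Shape::new)
/// };
/// let graph = graph_from_onnx_bytes(&buf, &mut external, &shaper)?;
/// ```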
pub fn graph_from_onnx_bytes(buf: &[u8], external: &mut dyn ExternalDataLoader, input_shaper: &InputShaper) -> OnnxResult<Graph> {
    let model = load_model_proto(buf);
    let model_graph = model.graph.as_ref().unwrap_proto("model.graph")?;

    let mut graph = Graph::new();
    let mut nodes: Store<OnnxValue> = Store::default();

    // load initializer values (similar to constants but defined separately)
    for tensor in &model_graph.initializer {
        let value = define_tensor_data(&mut graph, &tensor.name, tensor, external)?;
        nodes.define(&tensor.name, OnnxValue::Value(value))
    }

    // load inputs
    let mut real_input_index = 0;
    for input in &model_graph.input {
        // initializers are allowed to re-appear in the inputs, so we skip them the second time
        if nodes.contains(&input.name) {
            continue;
        }

        let input_proto = input.r#type.as_ref().unwrap_proto("input.type")?;
        let (shape, dtype) = resolve_tensor_type(input_proto, &input.name, real_input_index, input_shaper)?;
        let value = graph.input(shape, dtype);
        nodes.define(&input.name, OnnxValue::Value(value));

        real_input_index += 1;
    }

    // clear newly defined values so we don't attribute them to the first node
    let _ = graph.take_new_values();

    // load nodes
    for node_proto in &model_graph.node {
        let node = Node {
            name: node_proto.name.as_str(),
            op_type: node_proto.op_type.as_str(),
        };

        let mut attrs = Attributes::from(node, &node_proto.attribute);
        let mut inputs = Inputs::from(node, &node_proto.input, &nodes)?;

        let values: Vec<OnnxValue> = visit_node(&mut graph, external, node, &mut inputs, &mut attrs)?;

        // set debug id for all newly created nodes to the current node name
        for value in graph.take_new_values() {
            graph.set_debug_id(value, node.name.to_owned())
        }

        // check that the value is only a size if necessary
        for value in &values {
            value.assert_valid();
        }

        // check that we used all attributes and inputs
        let leftover_attributes = attrs.leftover();
        if !leftover_attributes.is_empty() {
            return Err(OnnxError::LeftoverAttributes(node.to_owned(), leftover_attributes));
        }
        let leftover_inputs = inputs.leftover();
        if !leftover_inputs.is_empty() {
            return Err(OnnxError::LeftoverInputs(node.to_owned(), leftover_inputs));
        }

        // actually define the result values
        let output_names = &node_proto.output;
        assert_eq!(output_names.len(), values.len(), "Expected {:?} outputs, got {}", output_names, values.len());
        for (name, value) in zip_eq(output_names, values) {
            nodes.define(name, value);
        }
    }

    for output in &model_graph.output {
        let value_or_size = &nodes[output.name.as_str()];
        let value = value_or_size
            .unwrap_value()
            .ok_or(OnnxError::ExpectedNonBatchValue(output.name.clone()))?;
        graph.output(value);
    }

    Ok(graph)
}

fn visit_node(
    graph: &mut Graph,
    external: &mut dyn ExternalDataLoader,
    node: Node<&str>,
    inputs: &mut Inputs,
    attrs: &mut Attributes,
) -> OnnxResult<Vec<OnnxValue>> {
    let result_single = match node.op_type {
        "Conv" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();
            let filter = inputs.required(1)?.unwrap_value().unwrap();
            let bias_raw = inputs.optional(2).map(|v| v.unwrap_value().unwrap());

            let groups = attrs.maybe_take_int("group")?.unwrap_or(1);
            let kernel_shape = attrs.take_ints("kernel_shape")?;
            let conv_rank = kernel_shape.len();
            let strides = attrs.maybe_take_ints("strides")?
                .map_or(vec![1; conv_rank], |strides| strides.to_vec());
            let dilations = attrs.maybe_take_ints("dilations")?
                .map_or(vec![1; conv_rank], |dilations| dilations.to_vec());

            let auto_pad = attrs.maybe_take_string("auto_pad")?;

            let padding = match auto_pad {
                None | Some("NOTSET") => {
                    // custom padding
                    attrs.take_ints("pads")?.to_vec()
                }
                Some("SAME_UPPER") => {
                    // input and output same size, excess on upper side of dim
                    calculate_auto_padding(graph, conv_rank, input, filter, &strides, &dilations, true)?
                }
                Some("SAME_LOWER") => {
                    // input and output same size, excess on lower side of dim
                    calculate_auto_padding(graph, conv_rank, input, filter, &strides, &dilations, false)?
                }
                Some("VALID") => {
                    // no padding
                    vec![0; strides.len()]
                }
                Some(auto_pad) => return Err(OnnxError::InvalidAutoPadValue(node.to_owned(), auto_pad.to_owned()))
            };

            let filter_shape = graph[filter]
                .shape
                .unwrap_fixed("Convolution kernel shape must be fixed");

            // always add bias in the 2D conv view domain, so it's easier to fuse later on
            let bias = bias_raw.map(|bias| {
                let bias_size = graph[bias].shape.unwrap_1();
                let bias_view_shape = shape![1, bias_size, 1, 1];

                graph.view(bias, bias_view_shape)
            });

            assert_eq!(1, groups);

            let result = match conv_rank {
                1 => {
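                    // 1D convolutions are not supported directly: append a dummy spatial axis of
                    // size 1, run the 2D convolution, and drop that axis again afterwards,
                    // i.e. [N, C, L] -> [N, C, L, 1] -> conv2d -> [N, C, L', 1] -> [N, C, L']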
                    let kernel_size0 = unwrap_1(kernel_shape);
                    let [padding_0, padding_1] = unwrap_2(&padding);
                    let stride = unwrap_1(&strides);
                    let dilation = unwrap_1(&dilations);

                    let [_, _, kernel_size1] = filter_shape.unwrap_3();

                    assert_eq!(padding_0, padding_1);
                    assert!(dilation == 1 && stride == 1);
                    assert_eq!(kernel_size0, kernel_size1);

                    let input_extra = graph.view(input, graph[input].shape.clone().concat(&shape![1]));
                    let filter_extra = graph.view(filter, graph[filter].shape.clone().concat(&shape![1]));

                    let result_conv = graph.conv(input_extra, filter_extra, 1, 1, padding_0, 0);
                    let result_biased = bias.map_or(result_conv, |bias| graph.add(result_conv, bias));

                    let result_shape = graph[result_biased].shape.replace(3, shape![]);
                    let result = graph.view(result_biased, result_shape);

                    result
                }
                2 => {
                    let [kernel_h0, kernel_w0] = unwrap_2(kernel_shape);
                    let [padding_y0, padding_x0, padding_y1, padding_x1] = unwrap_4(&padding);
                    let [stride_y, stride_x] = unwrap_2(&strides);
                    let [dilation_y, dilation_x] = unwrap_2(&dilations);

                    let [_, _, kernel_h1, kernel_w1] = filter_shape.unwrap_4();

                    assert!(padding_y0 == padding_y1 && padding_x0 == padding_x1);
                    assert!(dilation_y == 1 && dilation_x == 1);
                    assert!(kernel_h1 == kernel_h0 && kernel_w1 == kernel_w0);

                    let result_conv = graph.conv(input, filter, stride_y, stride_x, padding_y0, padding_x0);
                    let result_biased = bias.map_or(result_conv, |bias| graph.add(result_conv, bias));

                    result_biased
                }
                rank => return Err(OnnxError::UnsupportedNdConvolution(node.to_owned(), rank)),
            };

            OnnxValue::Value(result)
        }
        "Clip" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();
            // these are optional since the older version of the operator used attributes instead
            let input_min = inputs.optional(1);
            let input_max = inputs.optional(2);

            let result = match (input_min, input_max) {
                (None, None) => {
                    let min = attrs.take_float("min")?;
                    let max = attrs.take_float("max")?;
                    graph.clamp::<f32>(input, min, max)
                }
                (Some(min), Some(max)) => {
                    let min = min.unwrap_value().unwrap();
                    let max = max.unwrap_value().unwrap();
                    assert_eq!(graph[min].shape, Shape::SCALAR);
                    assert_eq!(graph[max].shape, Shape::SCALAR);

                    let mid = graph.binary(BinaryOp::Min, input, max);
                    let result = graph.binary(BinaryOp::Max, mid, min);
                    result
                }
                _ => {
                    let message = "Clip must have either 1 or 3 inputs, got 2".to_owned();
                    return Err(OnnxError::InvalidOperationArgs(node.to_owned(), message));
                }
            };

            OnnxValue::Value(result)
        }
        "Abs" | "Neg" | "Sin" | "Cos" | "Exp" | "Log" | "Sqrt" | "Sigmoid" | "Relu" | "Tanh" | "Erf" | "Mish" | "Softplus" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();

            let result = match node.op_type {
                "Abs" => graph.unary(UnaryOp::Abs, input),
                "Neg" => graph.unary(UnaryOp::Neg, input),
                "Sin" => graph.unary(UnaryOp::Sin, input),
                "Cos" => graph.unary(UnaryOp::Cos, input),
                "Exp" => graph.unary(UnaryOp::Exp, input),
                "Log" => graph.unary(UnaryOp::Log, input),
                "Sqrt" => graph.unary(UnaryOp::Sqrt, input),
                "Sigmoid" => graph.unary(UnaryOp::Sigmoid, input),
                "Relu" => graph.relu(input),
                "Tanh" => graph.unary(UnaryOp::Tanh, input),
                "Erf" => graph.unary(UnaryOp::Erf, input),
274                "Softplus" => graph.unary(UnaryOp::Softplus, input),
                _ => unreachable!("missing {:?}", node.op_type),
            };

            OnnxValue::Value(result)
        }
        "Add" | "Sub" | "Mul" | "Div" | "Min" | "Max" | "Pow" => {
            let op = match node.op_type {
                "Add" => BinaryOp::Add,
                "Sub" => BinaryOp::Sub,
                "Mul" => BinaryOp::Mul,
                "Div" => BinaryOp::Div,
                "Min" => BinaryOp::Min,
                "Max" => BinaryOp::Max,
                "Pow" => BinaryOp::Pow,
                _ => unreachable!("missing {:?}", node.op_type),
            };

            let left = inputs.required(0)?;
            let right = inputs.required(1)?;

            if let (&OnnxValue::Value(left), &OnnxValue::Value(right)) = (left, right) {
                // keep values as values
                OnnxValue::Value(graph.binary(op, left, right))
            } else {
                // decay to shape
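                // when either side involves a symbolic size, evaluate the binary op
                // element-wise on the broadcast size tensors at load time instead of
                // emitting graph operations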
                let left = left.as_size(graph)?;
                let right = right.as_size(graph)?;

                let (left, right) = broadcast_tensors_symmetric(&left, &right);

                let result = azip!(&left, &right).map_collect(|&l, &r| {
                    eval_binary_op(op, l, r)
                        .unwrap_or_else(|| panic!("Operation {:?} failed between {:?} and {:?}", op, left, right))
                });

                // the batch size might have cancelled out!
                OnnxValue::new_size(result.into_shared(), graph)
            }
        }
        "Equal" => {
            let left = inputs.required(0)?;
            let right = inputs.required(1)?;

            let result = match (left, right) {
                (&OnnxValue::Value(left), &OnnxValue::Value(right)) => {
                    // subtract and cast to bool
                    // this automatically broadcasts correctly
                    let diff = graph.sub(left, right);
                    graph.unary(UnaryOp::ValueCast(DType::Bool), diff)
                }
                (OnnxValue::Size(left), OnnxValue::Size(right)) => {
                    // broadcast and compare
                    // TODO we consider batch and ints always not-equal, even though they theoretically could be
                    let (left, right) = broadcast_tensors_symmetric(&left, &right);

                    let result = azip!(left, right).map_collect(|l, r| DBool(l == r)).into_shared();
                    graph.constant_tensor(DTensor::Bool(result))
                }
                _ => {
                    // one contains batch, the other doesn't => they can't be equal
                    // return false broadcast to the right shape
                    let broadcast_shape = broadcast_shape_symmetric(&left.shape(graph), &right.shape(graph));
                    let scalar = graph.scalar(DBool(false));
                    graph.broadcast(scalar, broadcast_shape)
                }
            };

            OnnxValue::Value(result)
        }
        "Where" => {
            // TODO extend to non-consts and shapes
            let cond = inputs.required(0)?.unwrap_value().unwrap();
            let x = inputs.required(1)?.unwrap_value().unwrap();
            let y = inputs.required(2)?.unwrap_value().unwrap();

            let cond = graph.as_const(cond).unwrap();
            let cond = cond.unwrap_bool().unwrap();
            let x = graph.as_const(x).unwrap();
            let y = graph.as_const(y).unwrap();

            // TODO proper broadcasting
            assert_eq!(cond.shape(), x.shape(), "Where broadcasting not yet implemented");
            assert_eq!(cond.shape(), y.shape(), "Where broadcasting not yet implemented");

            let result = map_dtensor_pair!(x, y, |x, y| {
                azip!(cond, &x, &y)
                    .map_collect(|&DBool(c), &x, &y| if c { x } else { y })
                    .into_shared()
            });

            OnnxValue::Value(graph.constant_tensor(result))
        }
        "Flatten" => {
            let input = inputs.required(0)?;

            // figure out the axis
            let rel_axis = attrs.maybe_take_int("axis")?.unwrap_or(1);
            let axis = abs_axis(rel_axis, input.shape(graph).rank());

            // figure out new shape
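            // e.g. shape [N, 3, 4, 5] with axis=2 becomes [N, 3, 20],
            // while axis=0 becomes [1, N*3*4*5] because of the special case below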
            let kept_shape = &input.shape(graph).dims[..axis];
            let flat_shape = input.shape(graph).dims[axis..].iter().copied().product::<Size>();

            let mut new_shape = kept_shape.to_vec();
            new_shape.push(flat_shape);

            // strange special case in onnx spec, insert additional 1 axis
            if axis == 0 {
                new_shape.insert(0, Size::ONE);
            }
            let new_shape = Shape::new(new_shape);

            // apply view operation
            match input {
                &OnnxValue::Value(input) => OnnxValue::Value(graph.view(input, new_shape)),
                OnnxValue::Size(input) => {
                    OnnxValue::new_size(input.reshape(new_shape.unwrap_fixed("size shape").dims), graph)
                }
            }
        }
        "Gemm" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();
            let weight = inputs.required(1)?.unwrap_value().unwrap();
            let bias = inputs.optional(2).map(|v| v.unwrap_value().unwrap());

            let alpha = attrs.take_float("alpha")?;
            let beta = attrs.take_float("beta")?;
            let trans_b = attrs.take_int("transB")? != 0;

            assert_eq!(1.0, alpha);
            assert_eq!(1.0, beta);
            assert!(trans_b);

            let linear = graph.linear(input, weight);

            let result = if let Some(bias) = bias {
                let bias_len = graph[bias].shape.unwrap_1();
                let bias_view_shape = shape![1, bias_len];
                let bias_view = graph.view(bias, bias_view_shape);

                graph.add(linear, bias_view)
            } else {
                linear
            };

            OnnxValue::Value(result)
        }
        "MatMul" => {
            let left = inputs.required(0)?.unwrap_value().unwrap();
            let right = inputs.required(1)?.unwrap_value().unwrap();

            // TODO we're still missing support for 1D operand broadcasting, but that should be pretty rare
            let result = graph.mat_mul(left, right);
            OnnxValue::Value(result)
        }
        "Einsum" => {
            let inputs = inputs.take_all_variadic();
            let equation = attrs.take_string("equation")?;

            let equation_compact = equation.replace(' ', "");

            // TODO for now we hardcode some typical einsum operations, replace this with a general implementation
            //   look into "tensor primitives" and optimal "contractions"?
            match equation_compact.as_ref() {
                "bid,bjd->bij" => {
                    assert_eq!(inputs.len(), 2);
                    let left = inputs[0].unwrap_value().unwrap();
                    let right = inputs[1].unwrap_value().unwrap();

                    let right_transpose = graph.permute(right, vec![0, 2, 1]);
                    let result = graph.batched_mat_mul(left, right_transpose);
                    OnnxValue::Value(result)
                }
                "bij,bjd->bid" => {
                    assert_eq!(inputs.len(), 2);
                    let left = inputs[0].unwrap_value().unwrap();
                    let right = inputs[1].unwrap_value().unwrap();

                    let result = graph.batched_mat_mul(left, right);
                    OnnxValue::Value(result)
                }
                _ => panic!(
                    "Einsum with equation {:?} and inputs {:?} not yet supported",
                    equation, inputs
                ),
            }
        }
        // TODO ensure the optimizer can fuse the scale/eps/var and mean/bias operations
        "BatchNormalization" => {
            let input = inputs.required(0)?.unwrap_value().unwrap();

            let input_scale = inputs.required(1)?.unwrap_value().unwrap();
            let input_bias = inputs.required(2)?.unwrap_value().unwrap();
            let input_mean = inputs.required(3)?.unwrap_value().unwrap();
            let input_variance = inputs.required(4)?.unwrap_value().unwrap();

            let epsilon = attrs.take_float("epsilon")?;
            let _ = attrs.take_float("momentum")?;
            let spatial = attrs.maybe_take_int("spatial")?;
            assert!(
                spatial == None || spatial == Some(1),
                "non-spatial cases are not supported and have been deprecated since ONNX version 9"
            );

            // figure out the shapes
            let input_shape = &graph[input].shape;
            assert!(input_shape.rank() >= 2, "BN input must have at least rank 2");

            let channels = input_shape[1];
            let shape_vec = shape![channels];
            let shape_exp = input_shape.keep(1, Size::ONE);

            for param in [input_scale, input_bias, input_mean, input_variance] {
                assert_eq!(graph[param].shape, shape_vec);
            }

            // put everything into the graph
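            // computes y = (x - mean) / sqrt(var + eps) * scale + bias, with the
            // parameter vectors broadcast over the channel axis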
            let result = {
                let value_eps = graph.scalar(epsilon);

                let exp_scale = graph.view(input_scale, shape_exp.clone());
                let exp_bias = graph.view(input_bias, shape_exp.clone());
                let exp_mean = graph.view(input_mean, shape_exp.clone());
                let exp_variance = graph.view(input_variance, shape_exp);

                let div_squared = graph.add(exp_variance, value_eps);
                let div = graph.unary(UnaryOp::Sqrt, div_squared);

                let x = input;
                let x_mean = graph.sub(x, exp_mean);
                let x_div = graph.binary(BinaryOp::Div, x_mean, div);
                let x_scale = graph.mul(x_div, exp_scale);
                let x_bias = graph.add(x_scale, exp_bias);
                x_bias
            };

            OnnxValue::Value(result)
        }
        "InstanceNormalization" | "LayerNormalization" => {
            let (input, start_axis, epsilon, scale_bias) = match node.op_type {
                "InstanceNormalization" => {
                    let input = inputs.required(0)?.unwrap_value().unwrap();
                    let scale = inputs.required(1)?.unwrap_value().unwrap();
                    let bias = inputs.required(2)?.unwrap_value().unwrap();
                    let epsilon = attrs.take_float("epsilon")?;

                    let broadcast_shape = graph[input].shape.keep(1, Size::ONE);
                    let scale_broadcast = graph.view(scale, broadcast_shape.clone());
                    let bias_broadcast = graph.view(bias, broadcast_shape);

                    (input, 2, epsilon, Some((scale_broadcast, bias_broadcast)))
                },
                "LayerNormalization" => {
                    let input = inputs.required(0)?.unwrap_value().unwrap();
                    let scale = inputs.required(1)?.unwrap_value().unwrap();
                    let bias = inputs.required(2)?.unwrap_value().unwrap();
                    let axis = attrs.maybe_take_int("axis")?.unwrap_or(-1);
                    let epsilon = attrs.maybe_take_float("epsilon")?.unwrap_or(1e-05);

                    let input_shape = graph[input].shape.clone();
                    let scale_broadcast = graph.broadcast(scale, input_shape.clone());
                    let bias_broadcast = graph.broadcast(bias, input_shape);

                    let axis = abs_axis(axis, graph[input].shape.rank());

                    (input, axis, epsilon, Some((scale_broadcast, bias_broadcast)))
                },
                _ => unreachable!("missing {:?}", node.op_type)
            };

            let shape = graph[input].shape.clone();
            assert!(
                shape.rank() >= start_axis,
                "Input rank must be >= {} (the start axis), for the batch and channel axes, got {}",
                start_axis,
                shape
            );

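            // collapse all normalized axes into a single trailing axis so one layernorm
            // call can reduce over it, then view the result back to the original shape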
            let (shape_keep, shape_reduced) = shape.split(start_axis);
            let shape_flat = shape_keep.concat(&shape![shape_reduced.size()]);

            let input_flat = graph.view(input, shape_flat);
            let norm_flat = graph.layernorm(input_flat, start_axis, epsilon);
            let norm = graph.view(norm_flat, shape);

            let result = if let Some((scale, bias)) = scale_bias {
                let scaled = graph.mul(norm, scale);
                let result = graph.add(scaled, bias);
                result
            } else {
                norm
            };

            OnnxValue::Value(result)
        }
        "Constant" => {
            let tensor = attrs.take_tensor("value")?;
            let value = define_tensor_data(graph, node.name, tensor, external)?;
            OnnxValue::Value(value)
        }
        "ConstantOfShape" => {
            let shape = inputs.optional(0);

            let shape = match shape {
                None => Shape::SCALAR,
                Some(shape) => shape.as_shape(graph)?,
            };

            let value = match attrs.maybe_take_tensor("value")? {
                None => graph.scalar(0f32),
                Some(tensor) => define_tensor_data(graph, node.name, tensor, external)?,
            };

            // TODO force scalar value? spec is unclear
            assert_eq!(
                graph[value].shape.size(),
                Size::ONE,
                "value must be a one-element tensor"
            );

            let scalar = graph.view(value, Shape::SCALAR);
            let result = graph.broadcast(scalar, shape);
            OnnxValue::Value(result)
        }
        "Cast" => {
            let input = inputs.required(0)?;
            let data_type = DataType::try_from(attrs.take_int("to")? as i32).expect("Invalid data type");
            let dtype = resolve_dtype(data_type, node.name)?;

            match input {
                &OnnxValue::Value(value) => OnnxValue::Value(graph.unary(UnaryOp::ValueCast(dtype), value)),
                OnnxValue::Size(value) => {
                    // only allow no-op casts for now
                    assert_eq!(dtype, DType::I64);
                    OnnxValue::new_size(value.clone(), graph)
                }
            }
        }
        "Reshape" => {
            let input = inputs.required(0)?;
            let new_shape = inputs.required(1)?.as_signed_shape(graph)?;
            let allow_zero = attrs.maybe_take_bool("allowzero")?.unwrap_or(false);

            let old_shape = input.shape(graph);
            let output_shape = calculate_reshape_output_shape(&old_shape, &new_shape, allow_zero);

            match input {
                &OnnxValue::Value(input) => OnnxValue::Value(graph.view(input, output_shape)),
                OnnxValue::Size(input) => {
                    let result = input.reshape(output_shape.unwrap_fixed("reshape shape").dims.clone());
                    OnnxValue::new_size(result, graph)
                }
            }
        }
        "Expand" => {
            let input = inputs.required(0)?;
            let shape = inputs.required(1)?.as_shape(graph)?;

            // "Expand" is a symmetric broadcast, not just a directional one
            let result_shape = broadcast_shape_symmetric(&input.shape(&graph), &shape);

            match input {
                &OnnxValue::Value(input) => OnnxValue::Value(graph.broadcast(input, result_shape)),
                OnnxValue::Size(input) => {
                    let result_shape = result_shape.unwrap_fixed("expand shape").dims.clone();
                    let result = input.broadcast(result_shape).unwrap().to_shared();
                    OnnxValue::new_size(result, graph)
                }
            }
        }
645        "Unsqueeze" => {
646            let input = inputs.required(0)?;
647
648            let rel_axes = match inputs.optional(1) {
649                Some(rel_axes) => {
650                    let shape = rel_axes.as_signed_shape(graph)?;
651                    shape.iter().map(|d| d.unwrap_fixed().unwrap()).collect_vec()
652                }
653                None => attrs.take_ints("axes")?.to_vec(),
654            };
655
656            // calculate output shape
657            let input_shape = input.shape(graph);
658
659            let output_rank = input_shape.rank() + rel_axes.len();
660            let axes = rel_axes.iter().map(|&a| abs_axis(a, output_rank)).collect_vec();
661
662            assert!(
663                axes.iter().all_unique() && axes.iter().all(|&a| a < output_rank),
664                "Invalid axis {:?} for input rank {} in Unsqueeze",
665                axes,
666                input_shape.rank(),
667            );
668
669            let mut input_shape_left = input_shape.dims.iter().copied();
670            let output_dims = (0..output_rank)
671                .map(|i| {
672                    if axes.contains(&i) {
673                        Size::ONE
674                    } else {
675                        input_shape_left.next().unwrap()
676                    }
677                })
678                .collect_vec();
679            assert_eq!(input_shape_left.len(), 0);
680
681            let output_shape = Shape::new(output_dims);
682
683            // map value
684            match input {
685                &OnnxValue::Value(input) => OnnxValue::Value(graph.view(input, output_shape)),
686                OnnxValue::Size(input) => {
687                    let result_shape = output_shape.unwrap_fixed("unsqueeze shape").dims;
688                    let result = input.reshape(result_shape);
689                    OnnxValue::new_size(result, graph)
690                }
691            }
692        }
693        "Transpose" => {
694            let input = inputs.required(0)?;
695
696            let permutation = attrs.take_ints("perm")?;
697            let permutation = permutation.iter().map(|&x| x as usize).collect_vec();
698
699            match input {
700                &OnnxValue::Value(input) => OnnxValue::Value(graph.permute(input, permutation)),
701                OnnxValue::Size(input) => {
702                    let result = input.to_shared().permuted_axes(permutation);
703                    OnnxValue::new_size(result, graph)
704                }
705            }
706        }
707        "Gather" => {
708            let input = inputs.required(0)?;
709            let indices_raw = inputs.required(1)?;
710            let rel_axis = attrs.maybe_take_int("axis")?.unwrap_or(0);
711
712            let input_shape = input.shape(graph);
713            let axis = abs_axis(rel_axis, input_shape.rank());
714            let axis_size = input.shape(graph).dims[axis];
715
716            let indices = match indices_raw {
717                &OnnxValue::Value(indices) => {
718                    match graph.as_const(indices) {
719                        Some(indices) => {
720                            let dim = axis_size.unwrap_fixed("gather dim size");
721                            let indices = dispatch_dtensor!(indices, |T, ft, indices| {
722                                // this is super cursed but it seems to work
723                                let zero = T::from_dscalar(T::DTYPE.specials().zero).unwrap();
724                                let dim = T::from_dscalar(DScalar::U64(dim as u64).value_cast(T::DTYPE)).unwrap();
725
726                                ft(indices.mapv(|x| if x < zero { x + dim } else { x }).into_shared())
727                            });
728                            OnnxValue::Value(graph.constant_tensor(indices))
729                        }
730                        // TODO support dynamic negative indices, by properly remapping in the graph
731                        //   for now just hope for the best
732                        None => OnnxValue::Value(indices),
733                    }
734                }
735                OnnxValue::Size(indices) => {
736                    let indices = indices.mapv(|x| {
737                        if x.is_neg() {
738                            (x + axis_size).expect("gather negative index overflow")
739                        } else {
740                            x
741                        }
742                    });
743                    OnnxValue::new_size(indices.into_shared(), graph)
744                }
745            };
746
747            match input {
748                &OnnxValue::Value(input) => {
749                    let result = graph.gather(input, axis, indices.unwrap_value().unwrap());
750                    OnnxValue::Value(result)
751                }
752                OnnxValue::Size(input) => {
753                    let indices = graph.as_const(indices.unwrap_value().unwrap()).unwrap();
754
755                    // shape trickery to support multi-dim gathers
756                    let indices_flat = indices.reshape(vec![indices.len()]);
757                    let result_flat = cpu_gather(input, axis, indices_flat);
758                    let mut result_shape = input.shape().to_owned();
759                    result_shape.splice(axis..axis + 1, indices.shape().iter().copied());
760                    let result = result_flat.reshape(result_shape);
761
762                    OnnxValue::new_size(result, graph)
763                }
764            }
765        }
766        "Slice" => {
767            let get = |inputs: &mut Inputs, attrs: &mut Attributes, index: usize, name: &str| -> OnnxResult<_> {
768                match inputs.optional(index) {
769                    Some(value) => {
770                        let value = graph.as_const(value.unwrap_value().unwrap()).unwrap();
771
772                        assert_eq!(
773                            value.shape().len(),
774                            1,
775                            "Slice operand {} must be 1D const, got shape {:?}",
776                            name,
777                            value.shape()
778                        );
779
780                        let vec = match value {
781                            DTensor::I64(value) => value.iter().copied().collect_vec(),
782                            DTensor::I32(value) => value.iter().map(|&x| x as i64).collect_vec(),
783                            _ => panic!("Invalid slice operand type {:?}", value.dtype()),
784                        };
785
786                        Ok(Some(vec))
787                    }
788                    None => Ok(attrs.maybe_take_ints(name)?.map(|v| v.to_vec())),
789                }
790            };
791
792            let input = inputs.required(0)?;
793            let starts = get(inputs, attrs, 1, "starts")?.expect("Missing starts input and attribute");
794            let ends = get(inputs, attrs, 2, "ends")?.expect("Missing ends input and attribute");
795
796            let slice_rank = starts.len();
797            let axes = get(inputs, attrs, 3, "axes")?.unwrap_or_else(|| (0..slice_rank as i64).collect_vec());
798            let steps = get(inputs, attrs, 4, "steps")?.unwrap_or_else(|| vec![1; slice_rank]);
799
800            assert!(
801                slice_rank == ends.len() && slice_rank == axes.len() && slice_rank == steps.len(),
802                "Inconsistent axes count"
803            );
804            assert!(
805                axes.iter().all_unique(),
806                "Slice axis cannot be repeated, got {:?}",
807                axes
808            );
809
810            // TODO properly clamp ends (and follow all of the other slicing rules, there are a lot of them)
811            let input_shape = input.shape(graph);
812
813            (0..slice_rank).fold(input.clone(), |curr, i| {
814                let axis = abs_axis(axes[i], input_shape.rank());
815                let axis_size = input_shape[axis].try_unwrap_fixed();
816
817                let step = steps[i];
818                assert_ne!(step, 0, "Step cannot be 0");
819
820                if step > 0 {
821                    // allow slicing the entire batch axis which is effectively just skipping
822                    if axis_size.is_none() && starts[i] == 0 && ends[i] == i32::MAX as i64 {
823                        curr
824                    } else {
825                        let axis_size = input_shape[axis].unwrap_fixed("Slice axis size");
826                        let start = abs_axis(starts[i], axis_size);
827                        let end = abs_axis(ends[i], axis_size);
828
829                        let range = SliceRange::new(start, end, step as usize);
830
831                        // slice
832                        match curr {
833                            OnnxValue::Value(curr) => OnnxValue::Value(graph.slice(curr, axis, range)),
834                            OnnxValue::Size(curr) => OnnxValue::Size(cpu_slice(&curr, axis, range)),
835                        }
836                    }
837                } else {
838                    // TODO support all negative strides?
839                    assert!(
840                        starts[i] == -1 && ends[i] == i64::MIN && step == -1,
841                        "Only simple flip negative stride supported for now"
842                    );
843
844                    // flip
845                    match curr {
846                        OnnxValue::Value(curr) => OnnxValue::Value(graph.flip(curr, axis)),
847                        OnnxValue::Size(curr) => OnnxValue::Size(cpu_flip(&curr, axis)),
848                    }
849                }
850            })
851        }
852        "Concat" => {
853            let inputs = inputs.take_all_variadic();
854            assert!(!inputs.is_empty(), "Must concatenate at least one value");
855
856            let rank = inputs[0].shape(graph).rank();
857
858            let rel_axis = attrs.take_int("axis")?;
859            let axis = abs_axis(rel_axis, rank);
860
861            let any_shape = inputs.iter().any(|x| matches!(x, OnnxValue::Size(_)));
862
863            if any_shape {
864                let tensors: Vec<_> = inputs.iter().map(|x| x.as_size(graph)).try_collect()?;
865                let views: Vec<_> = tensors.iter().map(|x| x.view()).collect();
866                let result = ndarray::concatenate(Axis(axis), &views).unwrap().into_shared();
867                OnnxValue::new_size(result, graph)
868            } else {
869                let inputs = inputs.iter().map(|v| v.unwrap_value().unwrap()).collect();
870                let result = graph.concat(inputs, axis, None, None);
871                OnnxValue::Value(result)
872            }
873        }
874        "Split" => {
875            // TODO support "num_outputs" and "split" attribute/input
876            let input = inputs.required(0)?;
877            let split = inputs.optional(1).map(|s| s.as_shape(graph).unwrap());
878
879            let shape = input.shape(graph);
880            let axis = attrs.take_int("axis")?;
881            let axis = abs_axis(axis, shape.rank());
882
883            let ranges = match split {
884                None => {
885                    // two equally sized outputs by default
886                    let len = shape[axis].unwrap_fixed("Split axis length");
887                    let num_outputs = 2;
888                    let len_first = (len + num_outputs - 1) / num_outputs;
889                    vec![SliceRange::simple(0, len_first), SliceRange::simple(len_first, len)]
890                }
891                Some(split) => {
892                    split.unwrap_fixed("split lengths").dims.iter()
893                        .scan(0, |a, s| {
894                            let start = *a;
895                            let end = start + s;
896                            *a = end;
897                            Some(SliceRange::simple(start, end))
898                        })
899                        .collect_vec()
900                }
901            };
902
903            let result = match input {
904                &OnnxValue::Value(input) => {
905                    ranges.iter().map(|&r| OnnxValue::Value(graph.slice(input, axis, r))).collect_vec()
906                }
907                OnnxValue::Size(input) => {
908                    ranges.iter().map(|&r| OnnxValue::new_size(cpu_slice(&input, axis, r), graph)).collect_vec()
909                }
910            };
911
912            return Ok(result);
913        }
914        "Pad" => {
915            // operands
916            let input = inputs.required(0)?.unwrap_value().unwrap();
917            let pads = inputs.required(1)?.unwrap_value().unwrap();
918            let constant_value = inputs.optional(2);
919            let axes = inputs.optional(3);
920            let mode = attrs.maybe_take_string("mode")?.unwrap_or("constant");
921
922            // map operands
923            let input_shape = &graph[input].shape.clone();
924
925            let constant_value = constant_value
926                .map(|v| {
927                    graph
928                        .as_single_const(v.unwrap_value().unwrap())
929                        .unwrap()
930                        .unwrap_f32()
931                        .unwrap()
932                })
933                .unwrap_or(0.0);
934
935            let axes = match axes {
936                Some(axes) => {
937                    let axes = axes.as_signed_shape(graph)?;
938                    axes.iter()
939                        .map(|&i| abs_axis(i.unwrap_fixed().unwrap(), input_shape.rank()))
940                        .collect_vec()
941                }
942                None => (0..input_shape.rank()).collect_vec(),
943            };
944
945            let pads = graph.as_const(pads).unwrap();
946            let pads = pads.unwrap_i64().unwrap();
947            assert_eq!(pads.shape(), &[axes.len() * 2], "Pads and axes shape mismatch");
948            let pads = pads.iter().copied().collect_vec();
949
950            assert_eq!(mode, "constant", "Only 'constant' pad mode supported");
951
952            let constant = graph.scalar(constant_value);
953
954            // TODO consider adding a real pad operation instead of this concat workaround
955            let output = axes.iter().fold(input, |acc, &axis| {
956                let acc_shape = graph[acc].shape.clone();
957
958                let pad_left = pads[axis];
959                let pad_right = pads[axes.len() + axis];
960                assert!(pad_left >= 0 && pad_right >= 0, "Pad values cannot be negative");
961
962                let blocks = vec![
963                    graph.broadcast(constant, acc_shape.replace(axis, shape![pad_left as usize])),
964                    acc,
965                    graph.broadcast(constant, acc_shape.replace(axis, shape![pad_right as usize])),
966                ];
967                graph.concat(blocks, axis, None, None)
968            });
969
970            OnnxValue::Value(output)
971        }
972        "Shape" => {
973            let input = inputs.required(0)?;
974            let shape = input.shape(graph);
975            let dims = shape
976                .dims
977                .iter()
978                .map(|&d| SignedSize::from_size(d).unwrap())
979                .collect_vec();
980            OnnxValue::new_size(Tensor::from_shape_vec(vec![dims.len()], dims).unwrap(), graph)
981        }
982        "Identity" => {
983            let input = inputs.required(0)?;
984            input.clone()
985        }
986        "Softmax" => {
987            let input = inputs.required(0)?.unwrap_value().unwrap();
988
989            let shape = graph[input].shape.clone();
990            let axis = attrs.maybe_take_int("axis")?.unwrap_or(-1);
991            let axis = abs_axis(axis, shape.rank());
992
993            OnnxValue::Value(graph.softmax(input, axis))
994        }
995        "ReduceSum" | "ReduceMean" | "ReduceProd" | "ReduceMin" | "ReduceMax" => {
996            let op = match node.op_type {
997                "ReduceSum" => ReduceOp::Sum,
998                "ReduceMean" => ReduceOp::Mean,
999                "ReduceProd" => ReduceOp::Prod,
1000                "ReduceMin" => ReduceOp::Min,
1001                "ReduceMax" => ReduceOp::Max,
1002                _ => unreachable!("missing {:?}", node.op_type),
1003            };
1004
1005            let input = inputs.required(0)?.unwrap_value().unwrap();
1006            let input_shape = graph[input].shape.clone();
1007
1008            let axes = attrs.maybe_take_ints("axes")?.map_or_else(
1009                || (0..input_shape.rank()).collect_vec(),
1010                |axes| axes.iter().map(|&a| abs_axis(a, input_shape.rank())).collect_vec(),
1011            );
1012            let keep_dims = attrs.maybe_take_int("keepdims")?.unwrap_or(1) != 0;
1013
1014            let result_shape = if keep_dims {
1015                input_shape.replace_all(&axes, shape![1])
1016            } else {
1017                input_shape.replace_all(&axes, shape![])
1018            };
1019
1020            let result = graph.reduce(input, axes, op);
1021            let result_shaped = graph.view(result, result_shape);
1022
1023            OnnxValue::Value(result_shaped)
1024        }
1025        "MaxPool" | "AveragePool" => {
1026            let op = match node.op_type {
1027                "MaxPool" => ReduceOp::Max,
1028                "AveragePool" => ReduceOp::Mean,
1029                _ => unreachable!("missing {:?}", node.op_type),
1030            };
1031
1032            let input = inputs.required(0)?.unwrap_value().unwrap();
1033
1034            let strides = attrs.take_ints("strides")?;
1035            let kernel_shape = attrs.take_ints("kernel_shape")?;
1036            let pads = attrs.take_ints("pads")?;
1037            let ceil_mode = attrs.maybe_take_int("ceil_mode")?.unwrap_or(0) != 0;
1038            let auto_pad = attrs.maybe_take_string("auto_pad")?;
1039
1040            assert_eq!(strides, kernel_shape, "Real strides not supported yet");
1041            assert_eq!(pads, &vec![0; pads.len()], "Padding not supported yet");
1042            assert!(matches!(auto_pad, None | Some("NOTSET")), "Auto padding not supported yet");
1043
1044            // max pool the last N dimensions:
1045            // split each pooled axis into (input_size/kernel_size, kernel_size), then max pool over all kernel sizes
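            // e.g. [N, C, 8, 8] with kernel_shape [2, 2] is viewed as [N, C, 4, 2, 4, 2]
            // and then reduced over axes 3 and 5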
            let raw_input_shape = &graph[input].shape;
            let input_rank = raw_input_shape.rank();
            let kernel_rank = kernel_shape.len();

            let kept_rank = input_rank - kernel_rank;
            let (batch_shape, active_shape) = raw_input_shape.split(kept_rank);

            // calculate padding and reshaping
            let mut pad_amounts = vec![(0, 0); kept_rank];
            let mut reshape = batch_shape.dims.clone();
            let mut pooled_dims = vec![];
            let mut slices = vec![None; kept_rank];

            for i in 0..kernel_rank {
                let kernel_size = kernel_shape[i] as usize;
                let input_size = active_shape.dims[i];

                let div_rem = input_size.div_rem(kernel_size);
                let (left, pad, slice) = match div_rem {
                    DivResult::Exact(left) => {
                        (left, (0, 0), None)
                    }
                    DivResult::Remainder(rem) => {
                        if ceil_mode {
                            let pad = kernel_size - rem;
                            let left = ((input_size + pad).unwrap() / kernel_size).unwrap();
                            (left, (0, pad), None)
                        } else {
                            let left = ((input_size - rem).unwrap() / kernel_size).unwrap();
                            let end = left.unwrap_fixed("pool dim size") * kernel_size;
                            let slice = SliceRange::new(0, end, 1);
                            (left, (0, 0), Some(slice))
                        }
                    },
                    DivResult::Impossible => {
                        return Err(OnnxError::NonDividingPooling(node.to_owned(), raw_input_shape.clone(), kernel_shape.to_vec()));
                    }
                };

                pad_amounts.push(pad);
                reshape.push(left);
                pooled_dims.push(reshape.len());
                reshape.push(Size::fixed(kernel_size));
                slices.push(slice);
            }
            let reshape = Shape::new(reshape);

            let pad_value = op.identity(graph[input].dtype);

            // add to graph
            let pad_value = graph.scalar_dyn(pad_value);
            let padded = graph.pad(input, &pad_amounts, pad_value);
            let sliced = slices.iter().enumerate().fold(padded, |a, (i, &s)| {
                if let Some(s) = s {
                    graph.slice(a, i, s)
                } else {
                    a
                }
            });
            let reshaped = graph.view(sliced, reshape);
            let result = graph.reduce(reshaped, pooled_dims, op);

            OnnxValue::Value(result)
        }
        "GlobalMaxPool" | "GlobalAveragePool" => {
            let op = match node.op_type {
                "GlobalMaxPool" => ReduceOp::Max,
                "GlobalAveragePool" => ReduceOp::Mean,
                _ => unreachable!("missing {:?}", node.op_type),
            };

            let input = inputs.required(0)?.unwrap_value().unwrap();

            // pool over the spatial dimensions (everything after the batch and channel axes)
            let shape = &graph[input].shape;
            if shape.rank() < 2 {
                return Err(OnnxError::UnsupportedShape(node.to_owned(), shape.to_string()));
            }

            let axes = (2..shape.rank()).collect_vec();
            let result = graph.reduce(input, axes, op);

            OnnxValue::Value(result)
        }
        "Resize" => {
            // operands
            let input = inputs.required(0)?;
            let roi = inputs.optional(1);
            let scales = inputs.optional(2).map(|v| v.unwrap_value().unwrap());
            let sizes = inputs.optional(3);

            let _antialias = attrs.maybe_take_bool("antialias")?.unwrap_or(false);
            let axes = attrs.maybe_take_ints("axes")?;
            let _coordinate_transformation_mode = attrs
                .maybe_take_string("coordinate_transformation_mode")?
                .unwrap_or("half_pixel");
            let _cubic_coeff_a = attrs.take_float("cubic_coeff_a")?;
            let _exclude_outside = attrs.maybe_take_int("exclude_outside")?.unwrap_or(0);
            let _extrapolation_value = attrs.maybe_take_float("extrapolation_value")?.unwrap_or(0.0);
            let keep_aspect_ratio_policy = attrs
                .maybe_take_string("keep_aspect_ratio_policy")?
                .unwrap_or("stretch")
                .to_owned();
            let mode = attrs.maybe_take_string("mode")?.unwrap_or("nearest").to_owned();
            let nearest_mode = attrs
                .maybe_take_string("nearest_mode")?
                .unwrap_or("round_prefer_floor")
                .to_owned();

            // only one specific combination of attributes and operands is supported for now
            assert!(
                mode == "nearest"
                    && nearest_mode == "floor"
                    && roi.is_none()
                    && sizes.is_none()
                    && axes.is_none()
                    && keep_aspect_ratio_policy == "stretch",
                "The given resize operation is not supported"
            );

            let scales = graph
                .as_const(scales.expect("Resize requires scales for now"))
                .expect("Resize only supported with constant scales");

            let input_tensor = input.unwrap_value().unwrap();
            let input_shape = &graph[input_tensor].shape;
            let rank = input_shape.rank();

            assert_eq!(
                scales.shape(),
                &[rank],
                "Scales must be a vector whose length matches the input rank"
            );

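            // with mode=nearest, nearest_mode=floor and an integer scale, resizing an axis
            // is treated as repeating every element `scale` times along that axis
            // (the coordinate transformation mode is ignored)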
            let result = scales
                .unwrap_f32()
                .unwrap()
                .iter()
                .enumerate()
                .fold(input_tensor, |acc, (axis, &scale_f)| {
                    let scale = scale_f as usize;
                    assert_eq!(scale as f32, scale_f, "Only integer scales supported, got {:?}", scales);
                    graph.repeat_interleave(acc, axis, Size::fixed(scale))
                });

            OnnxValue::Value(result)
        }
        _ => {
            return Err(OnnxError::UnsupportedOperation(node.to_owned()));
        }
    };

    Ok(vec![result_single])
}

fn calculate_auto_padding(graph: &Graph, conv_rank: usize, input: Value, filter: Value, strides: &[i64], dilations: &[i64], up: bool) -> OnnxResult<Vec<i64>> {
    let (_, input_spatial_dims) = graph[input].shape.split(2);
    let input_spatial_dims = input_spatial_dims.unwrap_fixed("conv input spatial dims");

    let (_, filter_spatial_dims) = graph[filter].shape.split(2);
    let filter_spatial_dims = filter_spatial_dims.unwrap_fixed("conv filter spatial dims");

    let mut result = vec![];
    for i in 0..conv_rank {
        let (low, high) = split_padding(input_spatial_dims.dims[i] as i64, filter_spatial_dims.dims[i] as i64, strides[i], dilations[i], up);
        result.push(low);
        result.push(high);
    }
    Ok(result)
}

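// Split the total "SAME" padding required for one spatial axis into a (low, high) pair.
// With the formula as written, e.g. i=5, f=3, s=1, d=1 gives total = (5-1)*1 + 1 + 1*(3-1) - 5 = 2,
// split as (1, 1); an odd total puts the extra padding on the upper side when `up` is true.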
fn split_padding(i: i64, f: i64, s: i64, d: i64, up: bool) -> (i64, i64) {
    let total = (i - 1) * s + 1 + d * (f - 1) - i;

    let min = total / 2;
    let max = total - min;

    if up {
        (min, max)
    } else {
        (max, min)
    }
}

fn define_tensor_data(
    graph: &mut Graph,
    name: &str,
    tensor: &TensorProto,
    external: &mut dyn ExternalDataLoader,
) -> OnnxResult<Value> {
    let data_location = DataLocation::try_from(tensor.data_location).expect("Illegal data_location");

    // figure out the shape and type
    let dims = tensor.dims.iter().map(|&d| Size::fixed(d as usize)).collect_vec();
    let shape = Shape::new(dims);
    let size = shape.size().unwrap_fixed("Data tensor shape must be fixed");

    let data_type = DataType::try_from(tensor.data_type).expect("Illegal data type");
    let dtype = resolve_dtype(data_type, name)?;

    let length_guess = size * dtype.size().bytes();

    // load the data
    let raw_data_slot;
    let raw_data = match data_location {
        DataLocation::Default => {
1252            // the data is stored inline in the proto, use raw_data directly
1253            &tensor.raw_data
1254        }
1255        DataLocation::External => {
1256            // collect external data properties
1257            let mut location: Option<&str> = None;
1258            let mut offset: usize = 0;
1259            let mut length: Option<usize> = None;
1260
1261            for entry in &tensor.external_data {
1262                let key: &str = &entry.key;
1263                let value: &str = &entry.value;
1264
1265                match key {
1266                    "location" => location = Some(value),
1267                    "offset" => offset = value.parse().unwrap(),
1268                    "length" => length = Some(value.parse().unwrap()),
1269                    "hash" => {}
1270                    _ => panic!("Invalid external_data key: {} (value {})", key, value),
1271                }
1272            }
1273
1274            if let Some(length) = length {
1275                assert_eq!(length, length_guess, "External data length mismatch");
1276            }
1277
1278            // try loading from external source
1279            let location = location.expect("External data must have a location");
1280            raw_data_slot = external.load_external_data(&PathBuf::from(location), offset, length, length_guess)?;
1281
1282            if let Some(length) = length {
1283                assert_eq!(raw_data_slot.len(), length, "Raw data length mismatch");
1284            }
1285
1286            &raw_data_slot
1287        }
1288    };
1289
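    // `read_type!` builds the constant for one dtype: if the matching typed field is non-empty
    // it is used directly, otherwise the values are decoded from `raw_data`. The `None` variant
    // covers single-byte types where every raw byte is one element, the `$read` variant uses the
    // corresponding little-endian reader from `byteorder`.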
1290    macro_rules! read_type {
1291        (graph, $T:ty, $data:ident, None) => {{
1292            let data: Vec<$T> = if tensor.$data.is_empty() {
1293                raw_data.iter().map(|&x| x as $T).collect()
1294            } else {
1295                tensor.$data.iter().map(|&x| x as $T).collect()
1296            };
1297            graph.constant::<$T>(shape, data)
1298        }};
1299        (graph, $T:ty, $data:ident, $read:ident) => {{
1300            let data: Vec<$T> = if tensor.$data.is_empty() {
1301                let mut data = vec![Default::default(); size];
1302                LittleEndian::$read(raw_data, &mut data);
1303                data
1304            } else {
1305                tensor.$data.iter().map(|&x| x as $T).collect()
1306            };
1307            graph.constant::<$T>(shape, data)
1308        }};
1309    }
1310
1311    // careful: which proto field stores the data for each dtype is non-obvious, see the TensorProto docs
1312    let value = match dtype {
1313        DType::F32 => read_type!(graph, f32, float_data, read_f32_into),
1314        DType::F64 => read_type!(graph, f64, double_data, read_f64_into),
1315        DType::I8 => read_type!(graph, i8, int32_data, None),
1316        DType::I16 => read_type!(graph, i16, int32_data, read_i16_into),
1317        DType::I32 => read_type!(graph, i32, int32_data, read_i32_into),
1318        DType::I64 => read_type!(graph, i64, int64_data, read_i64_into),
1319        DType::U8 => read_type!(graph, u8, int32_data, None),
1320        DType::U16 => read_type!(graph, u16, int32_data, read_u16_into),
1321        DType::U32 => read_type!(graph, u32, uint64_data, read_u32_into),
1322        DType::U64 => read_type!(graph, u64, uint64_data, read_u64_into),
1323        DType::Bool => {
1324            let data: Vec<DBool> = if tensor.int32_data.is_empty() {
1325                raw_data.iter().map(|&x| DBool(x != 0)).collect_vec()
1326            } else {
1327                tensor.int32_data.iter().map(|&x| DBool(x != 0)).collect()
1328            };
1329            graph.constant::<DBool>(shape, data)
1330        }
1331    };
1332
1333    Ok(value)
1334}
1335
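/// Resolve an ONNX input `TypeProto` to a concrete `(Shape, DType)`.
///
/// The raw dimensions (fixed values or named parameters) are handed to the user-provided
/// `InputShaper`, which decides the final shape; if it returns `None`, loading fails with
/// `OnnxError::FailedToShapeInput`.
///
/// A minimal sketch of a shaper (illustrative only, assuming dynamic dims should become the
/// batch size):
///
/// ```ignore
/// let shaper = |dims: &[OnnxDimValue], _name: &str, _index: usize| {
///     let sizes = dims
///         .iter()
///         .map(|d| match d {
///             OnnxDimValue::Value(v) => Size::fixed(*v as usize),
///             // assumes a symbolic batch size constant like `Size::BATCH`
///             OnnxDimValue::Param(_) => Size::BATCH,
///         })
///         .collect();
///     Some(Shape::new(sizes))
/// };
/// ```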
1336fn resolve_tensor_type(ty: &TypeProto, name: &str, index: usize, input_shaper: &InputShaper) -> OnnxResult<(Shape, DType)> {
1337    let value = ty.value.as_ref().expect("TypeProto does not have a value set");
1338    let result = match value {
1339        ProtoTypeValue::TensorType(tensor) => {
1340            let data_type = DataType::try_from(tensor.elem_type).expect("Invalid data type");
1341            let dtype = resolve_dtype(data_type, name)?;
1342
1343            let dims = tensor
1344                .shape
1345                .as_ref()
1346                .expect("Tensor does not have shape set")
1347                .dim
1348                .iter()
1349                .map(|d| match *d.value.as_ref().expect("Missing value for dimension") {
1350                    dimension::Value::DimValue(value) => OnnxDimValue::Value(value),
1351                    dimension::Value::DimParam(ref param) => OnnxDimValue::Param(param.clone()),
1352                })
1353                .collect_vec();
1354
1355            let shape = input_shaper(&dims, name, index).ok_or_else(|| OnnxError::FailedToShapeInput(dims, name.to_owned(), index))?;
1356
1357            (shape, dtype)
1358        }
1359        _ => panic!("Unsupported value kind {:?}", value),
1360    };
1361    Ok(result)
1362}
1363
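/// Map an ONNX `DataType` to the corresponding `DType`, returning `UnsupportedType` for element
/// types this loader does not handle (strings, complex, f16/bf16 and the f8 variants).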
1364fn resolve_dtype(data_type: DataType, node: &str) -> OnnxResult<DType> {
1365    let dtype = match data_type {
1366        DataType::Float => DType::F32,
1367        DataType::Double => DType::F64,
1368        DataType::Uint8 => DType::U8,
1369        DataType::Int8 => DType::I8,
1370        DataType::Uint16 => DType::U16,
1371        DataType::Int16 => DType::I16,
1372        DataType::Int32 => DType::I32,
1373        DataType::Int64 => DType::I64,
1374        DataType::Uint32 => DType::U32,
1375        DataType::Uint64 => DType::U64,
1376        DataType::Bool => DType::Bool,
1377        DataType::Undefined
1378        | DataType::String
1379        | DataType::Complex64
1380        | DataType::Complex128
1381        | DataType::Float16
1382        | DataType::Bfloat16
1383        | DataType::Float8e4m3fn
1384        | DataType::Float8e4m3fnuz
1385        | DataType::Float8e5m2
1386        | DataType::Float8e5m2fnuz => return Err(OnnxError::UnsupportedType(node.to_owned(), data_type)),
1387    };
1388    Ok(dtype)
1389}
1390
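/// Convert a possibly-negative ONNX axis or index to an absolute one for the given rank.
///
/// `i64::MAX` is treated as a sentinel for the end of the axis, matching how open-ended `Slice`
/// bounds are commonly encoded.
///
/// Illustrative values: `abs_axis(-1, 4) == 3`, `abs_axis(2, 4) == 2`, `abs_axis(i64::MAX, 4) == 4`.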
1391fn abs_axis(axis: i64, rank: usize) -> usize {
1392    if axis == i64::MAX {
1393        rank
1394    } else if axis < 0 {
1395        rank - ((-axis) as usize)
1396    } else {
1397        axis as usize
1398    }
1399}
1400
1401#[track_caller]
1402fn unwrap_1(slice: &[i64]) -> usize {
1403    assert_eq!(slice.len(), 1, "Expected 1 element, got {:?}", slice);
1404    slice[0] as usize
1405}
1406
1407#[track_caller]
1408fn unwrap_2(slice: &[i64]) -> [usize; 2] {
1409    assert_eq!(slice.len(), 2, "Expected 2 elements, got {:?}", slice);
1410    [slice[0] as usize, slice[1] as usize]
1411}
1412
1413#[track_caller]
1414fn unwrap_4(slice: &[i64]) -> [usize; 4] {
1415    assert_eq!(slice.len(), 4, "Expected 4 elements, got {:?}", slice);
1416    [
1417        slice[0] as usize,
1418        slice[1] as usize,
1419        slice[2] as usize,
1420        slice[3] as usize,
1421    ]
1422}
1423
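/// Compute the output shape of an ONNX `Reshape`.
///
/// An entry of `0` copies the corresponding input dimension (unless `allow_zero` is set, in which
/// case it really means zero), and a single `-1` entry is inferred so the total element count is
/// preserved.
///
/// Worked example (illustrative only): reshaping `(2, 3, 4)` with raw shape `[-1, 4]` infers the
/// `-1` as `24 / 4 = 6`, giving `(6, 4)`.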
1424fn calculate_reshape_output_shape(old_shape: &Shape, new_shape_raw: &[SignedSize], allow_zero: bool) -> Shape {
1425    let old_size = old_shape.size();
1426
1427    let mut new_shape = vec![];
1428    let mut leftover_index = None;
1429    let mut leftover_size = old_size;
1430
1431    for (i, &signed_size) in new_shape_raw.iter().enumerate() {
1432        let size = match signed_size {
1433            SignedSize::ZERO => {
1434                if allow_zero {
1435                    Size::ZERO
1436                } else {
1437                    assert!(
1438                        i < old_shape.rank(),
1439                        "Cannot copy dim {} of output shape {:?}, not present in input {}",
1440                        i,
1441                        new_shape,
1442                        old_shape,
1443                    );
1444                    old_shape[i]
1445                }
1446            }
1447            SignedSize::NEG_ONE => {
1448                assert!(
1449                    leftover_index.is_none(),
1450                    "Reshape shape can only contain a single -1 value"
1451                );
1452                leftover_index = Some(i);
1453                new_shape.push(Size::ZERO);
1454                continue;
1455            }
1456            signed_size => signed_size
1457                .to_size()
1458                .unwrap_or_else(|_| panic!("Reshape size must be positive, 0 or -1, got {:?}", signed_size)),
1459        };
1460
1461        leftover_size = (leftover_size / size).unwrap_or_else(|| {
1462            panic!("Cannot reshape {} into {:?}", old_size, new_shape_raw);
1463        });
1464        new_shape.push(size);
1465    }
1466
1467    if let Some(leftover_index) = leftover_index {
1468        new_shape[leftover_index] = leftover_size;
1469    }
1470
1471    let shape = Shape::new(new_shape);
1472    assert_eq!(old_size, shape.size(), "Output and input sizes differ");
1473
1474    shape
1475}
1476
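/// Try to constant-fold a binary op on two symbolic sizes, returning `None` if the op is not
/// supported here or the result cannot be represented.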
1477fn eval_binary_op(op: BinaryOp, a: SignedSize, b: SignedSize) -> Option<SignedSize> {
1478    // TODO min/max?
1479    match op {
1480        BinaryOp::Add => a + b,
1481        BinaryOp::Sub => a - b,
1482        BinaryOp::Mul => Some(a * b),
1483        BinaryOp::Div => a.floor_div(b),
1484        _ => None,
1485    }
1486}
1487
1488fn load_model_proto(buf: &[u8]) -> ModelProto {
1489    let mut buf: &[u8] = buf;
1490    ModelProto::decode(&mut buf).unwrap()
1491}