stonnx_api/utils/
mod.rs

1use anyhow::anyhow;
2use num::traits::AsPrimitive;
3use std::io::Read;
4use std::os::raw::c_uchar;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::{collections::HashMap, io, path::Path};
8
9use ndarray::{ArrayD, IxDyn};
10use protobuf::Enum;
11
12use crate::common::FileInputs;
13use crate::onnx::tensor_proto::DataType;
14use crate::onnx::{NodeProto, TensorProto, ValueInfoProto};
15use crate::onnxparser::onnx;
16use crate::{common::*, print_at_level};
17use half::{bf16, f16};
18
19/// Calculates the product of the elements of an iterator, returning 1 if the iterator is empty.
20pub fn shape_safe_product<
21    'a,
22    B: 'a + std::iter::Product<&'a B> + std::default::Default + Copy + 'static,
23    A: IntoIterator<Item = &'a B>,
24>(
25    shape: A,
26) -> B
27where
28    usize: AsPrimitive<B>,
29{
30    let mut piter = shape.into_iter().peekable();
31    if piter.peek().is_none() {
32        1_usize.as_()
33    } else {
34        piter.product()
35    }
36}
37
38/// Writes an ndarray to a file in the npy format, only if the verbosity level is set to Intermediate or above.
39pub fn log_array_to_file<A: ndarray_npy::WritableElement, D: ndarray::Dimension>(
40    operation: &str,
41    name: &str,
42    a: &ndarray::ArrayBase<ndarray::ViewRepr<&A>, D>,
43) -> BoxResult<()> {
44    let verbose_flag = VERBOSE.load(std::sync::atomic::Ordering::Relaxed);
45    if verbose_flag == VerbosityLevel::Intermediate as usize {
46        static mut COUNTER: usize = 0;
47        unsafe {
48            ndarray_npy::write_npy(
49                format!(
50                    "{}_intermediate_outputs/{}_{}.npy",
51                    operation, COUNTER, name
52                ),
53                a,
54            )?;
55            COUNTER += 1;
56        }
57    }
58    Ok(())
59}
60
#[macro_export]
/// Logs an ndarray to a file in the npy format, only if the verbosity level is set to Intermediate or above.
///
/// Two forms:
/// - `named_array_to_file!(op, array)` — uses the array variable's identifier
///   as the file name.
/// - `named_array_to_file!(op, array, name_expr)` — uses an explicit name
///   expression instead.
///
/// Both forms take a view of the array (the caller keeps ownership) and
/// `unwrap()` the write, panicking if the npy file cannot be written.
macro_rules! named_array_to_file {
    ($op:ident, $name:ident) => {{
        // Shadow the binding with a view so log_array_to_file can borrow it.
        let $name = $name.view();
        $crate::utils::log_array_to_file(stringify!($op), stringify!($name), &$name).unwrap();
    }};
    ($op:ident, $var:ident, $name:expr) => {{
        let $var = $var.view();
        $crate::utils::log_array_to_file(stringify!($op), &$name, &$var).unwrap();
    }};
}
73
#[macro_export]
/// Creates a directory for intermediate outputs, only if the verbosity level is set to Intermediate or above.
///
/// The directory is named `<$name>_intermediate_outputs`. An already-existing
/// directory is not an error; any other I/O failure makes the *surrounding*
/// function return an `anyhow` error, so this macro must be expanded inside a
/// function returning a compatible `Result`. `VERBOSE` and `anyhow!` must be
/// in scope at the call site.
macro_rules! create_intermediate_output_dir_for {
    ($name:ident) => {{
        use $crate::common::VerbosityLevel;
        let verbose_flag = VERBOSE.load(std::sync::atomic::Ordering::Relaxed);
        // Compare as usize: VERBOSE stores the level as an integer (see
        // log_array_to_file); comparing the bare enum against the loaded
        // usize would not type-check on expansion.
        if verbose_flag == VerbosityLevel::Intermediate as usize {
            match std::fs::create_dir(concat!(stringify!($name), "_intermediate_outputs")) {
                Ok(_) => {}
                Err(e) => {
                    if e.kind() != std::io::ErrorKind::AlreadyExists {
                        // Report the directory that actually failed (the old
                        // message named a hard-coded "rust_conv_outputs").
                        return Err(anyhow!(
                            "Error creating {} directory: {}",
                            concat!(stringify!($name), "_intermediate_outputs"),
                            e
                        ));
                    }
                }
            }
        }
    }};
}
92
#[derive(Debug, Clone)]
/// Represents an ONNX ValueInfo stripped down to the bare minimum.
pub struct ValueInfo {
    // Tensor name, or the UNKNOWN placeholder when absent from the proto.
    pub name: String,
    // Element type plus the declared dimensions (dim_value per dimension;
    // symbolic dims come through as the protobuf default).
    pub type_: (ValueType, Vec<i64>),
    // Doc string from the proto, or the UNKNOWN placeholder when absent.
    pub doc_string: String,
}
100
#[derive(Debug, Clone)]
/// Represents an output to the graph, with the ValueInfo and the data.
pub struct OutputInfo {
    // Static metadata (name, type, shape) for this graph output.
    pub valueinfo: ValueInfo,
    // The computed tensor; None until the output has been produced.
    pub data: Option<TensorType>,
}
107
108impl OutputInfo {
109    fn new(valueinfo: ValueInfo) -> Self {
110        Self {
111            valueinfo,
112            data: None,
113        }
114    }
115}
116
117impl ValueInfo {
118    /// Creates a new ValueInfo from an ONNX ValueInfoProto.
119    fn from_proto(proto: &ValueInfoProto) -> BoxResult<Self> {
120        if let Some(onnx::type_proto::Value::TensorType(tensor)) = &proto.type_.value {
121            let dt = onnx::tensor_proto::DataType::from_i32(tensor.elem_type.unwrap_or_default())
122                .unwrap_or_default();
123            Ok(Self {
124                name: proto
125                    .name
126                    .as_ref()
127                    .map_or_else(|| UNKNOWN.to_owned(), |v| v.clone()),
128                type_: (
129                    ValueType::new(dt)?,
130                    tensor.shape.dim.iter().map(|v| v.dim_value()).collect(),
131                ),
132                doc_string: proto
133                    .doc_string
134                    .as_ref()
135                    .map_or_else(|| UNKNOWN.to_owned(), |v| v.clone()),
136            })
137        } else {
138            todo!("ValueInfoProto type not supported: {:?}", proto.type_)
139        }
140    }
141}
142
143// FIXME: data in tensor may be external. Need to handle that.
144/// Creates a tensor from ONNX's TensorProto.
145pub fn make_tensor_from_proto(proto: &TensorProto) -> BoxResult<TensorType> {
146    let shape = &proto.dims;
147    if proto.data_location() != onnx::tensor_proto::DataLocation::DEFAULT {
148        return Err(anyhow!("External data location not supported"));
149    }
150    make_tensor(shape, proto, proto.data_type())
151}
152
153/// Gets the raw data from an ONNX TensorProto, returning a slice of bytes and the size of each element.
154fn get_raw_data(proto: &TensorProto) -> BoxResult<(&[u8], usize)> {
155    if let Some(ref raw_data) = proto.raw_data {
156        Ok((raw_data.as_slice(), 1))
157    } else if !proto.int32_data.is_empty() {
158        Ok((
159            bytemuck::try_cast_slice(proto.int32_data.as_slice()).map_err(|e| anyhow!(e))?,
160            4,
161        ))
162    } else if !proto.int64_data.is_empty() {
163        Ok((
164            bytemuck::try_cast_slice(proto.int64_data.as_slice()).map_err(|e| anyhow!(e))?,
165            8,
166        ))
167    } else if !proto.float_data.is_empty() {
168        Ok((
169            bytemuck::try_cast_slice(proto.float_data.as_slice()).map_err(|e| anyhow!(e))?,
170            4,
171        ))
172    } else if !proto.double_data.is_empty() {
173        Ok((
174            bytemuck::try_cast_slice(proto.double_data.as_slice()).map_err(|e| anyhow!(e))?,
175            8,
176        ))
177    } else if !proto.uint64_data.is_empty() {
178        Ok((
179            bytemuck::try_cast_slice(proto.uint64_data.as_slice()).map_err(|e| anyhow!(e))?,
180            8,
181        ))
182    } else {
183        Ok((&[], 0))
184    }
185}
186
/// Creates a tensor from ONNX's TensorProto, given the shape and the data type.
///
/// The payload comes from `get_raw_data` (raw bytes win over the typed
/// repeated fields; `origin_elem_size` records the stored element width).
/// The INT32/INT64/UINT64/DOUBLE/FLOAT arms bypass that and re-read the proto
/// directly so they can accept a scalar (shape `[]` with one element) and
/// zero-fill an empty payload.
///
/// # Panics
/// The fixed-width arms `assert_eq!` the decoded element count against the
/// shape product, so a malformed proto panics rather than erroring.
/// NOTE(review): `origin_elem_size` is 0 when the proto carries no data,
/// which would make those asserts divide by zero — presumably unreachable
/// for these dtypes, but worth confirming.
pub fn make_tensor(shape: &[i64], proto: &TensorProto, data_type: i32) -> BoxResult<TensorType> {
    let enum_dt = DataType::from_i32(data_type).unwrap_or_default();
    let shape = shape.iter().map(|v| *v as usize).collect::<Vec<usize>>();
    let (bytedata, origin_elem_size) = get_raw_data(proto)?;
    match enum_dt {
        DataType::UNDEFINED => Err(anyhow!("Undefined data type")),
        DataType::INT8 => match bytemuck::try_cast_slice::<u8, i8>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = if origin_elem_size == std::mem::size_of::<i8>() {
                    ArrayD::<i8>::from_shape_vec(IxDyn(&shape), data.to_vec())?
                } else {
                    // Elements were stored widened (e.g. int8 packed into
                    // int32_data): take one byte per stored element.
                    // NOTE(review): this picks the *first* byte of each group,
                    // i.e. the low byte only on little-endian layouts — TODO
                    // confirm this assumption.
                    ArrayD::<i8>::from_shape_vec(
                        IxDyn(&shape),
                        data.iter().step_by(origin_elem_size).copied().collect(),
                    )?
                };
                Ok(TensorType::I8(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // NOTE(review): unlike INT8, no step_by handling here, so the assert
        // only holds when the data came from raw_data — TODO confirm.
        DataType::INT16 => match bytemuck::try_cast_slice::<u8, i16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<i16>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::I16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::INT32 => {
            // Prefer raw_data (reinterpreted as i32), falling back to the
            // typed int32_data field.
            let data = if let Some(data) = &proto.raw_data {
                if data.is_empty() {
                    &[]
                } else {
                    match bytemuck::try_cast_slice::<u8, i32>(data) {
                        Ok(data) => data,
                        Err(e) => return Err(anyhow!(e)),
                    }
                }
            } else {
                proto.int32_data.as_slice()
            };
            let dlen = data.len();
            let slen = if !shape.is_empty() {
                shape_safe_product(&shape)
            } else {
                0
            };
            // A scalar is allowed: shape [] (slen == 0) with exactly one
            // element; any other count/shape mismatch is an error.
            if dlen != slen && (slen == 0 && dlen != 1) {
                return Err(anyhow!(
                    "Data length {} does not match shape length {}",
                    dlen,
                    slen
                ));
            }
            // No payload at all => zero-filled tensor of the requested shape.
            let a = if data.is_empty() {
                ArrayD::<i32>::zeros(IxDyn(&shape))
            } else {
                ArrayD::<i32>::from_shape_vec(IxDyn(&shape), data.to_vec())?
            };
            Ok(TensorType::I32(a))
        }
        DataType::INT64 => {
            // Same raw_data-first, scalar-tolerant, zero-fill logic as INT32.
            let data = if let Some(data) = &proto.raw_data {
                if data.is_empty() {
                    &[]
                } else {
                    match bytemuck::try_cast_slice::<u8, i64>(data) {
                        Ok(data) => data,
                        Err(e) => return Err(anyhow!(e)),
                    }
                }
            } else {
                proto.int64_data.as_slice()
            };
            let dlen = data.len();
            let slen = if !shape.is_empty() {
                shape_safe_product(&shape)
            } else {
                0
            };
            // Scalar allowed: shape [] with exactly one element.
            if dlen != slen && (slen == 0 && dlen != 1) {
                return Err(anyhow!(
                    "Data length {} does not match shape length {}",
                    dlen,
                    slen
                ));
            }
            let a = if data.is_empty() {
                ArrayD::<i64>::zeros(IxDyn(&shape))
            } else {
                ArrayD::<i64>::from_shape_vec(IxDyn(&shape), data.to_vec())?
            };
            Ok(TensorType::I64(a))
        }
        DataType::UINT8 => match bytemuck::try_cast_slice::<u8, u8>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = if origin_elem_size == std::mem::size_of::<u8>() {
                    ArrayD::<u8>::from_shape_vec(IxDyn(&shape), data.to_vec())?
                } else {
                    // Widened storage: take one byte per stored element (see
                    // the INT8 arm for the endianness caveat).
                    ArrayD::<u8>::from_shape_vec(
                        IxDyn(&shape),
                        data.iter().step_by(origin_elem_size).copied().collect(),
                    )?
                };
                Ok(TensorType::U8(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT16 => match bytemuck::try_cast_slice::<u8, u16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<u16>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT32 => match bytemuck::try_cast_slice::<u8, u32>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<u32>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U32(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT64 => {
            // Same raw_data-first, scalar-tolerant, zero-fill logic as INT32.
            let data = if let Some(data) = &proto.raw_data {
                if data.is_empty() {
                    &[]
                } else {
                    match bytemuck::try_cast_slice::<u8, u64>(data) {
                        Ok(data) => data,
                        Err(e) => return Err(anyhow!(e)),
                    }
                }
            } else {
                proto.uint64_data.as_slice()
            };
            let dlen = data.len();
            let slen = if !shape.is_empty() {
                shape_safe_product(&shape)
            } else {
                0
            };
            // Scalar allowed: shape [] with exactly one element.
            if dlen != slen && (slen == 0 && dlen != 1) {
                return Err(anyhow!(
                    "Data length {} does not match shape length {}",
                    dlen,
                    slen
                ));
            }
            let a = if data.is_empty() {
                ArrayD::<u64>::zeros(IxDyn(&shape))
            } else {
                ArrayD::<u64>::from_shape_vec(IxDyn(&shape), data.to_vec())?
            };
            Ok(TensorType::U64(a))
        }
        // f16 values are stored as their raw bit pattern in 16-bit units.
        DataType::FLOAT16 => match bytemuck::try_cast_slice::<u8, u16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<f16>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| f16::from_bits(*x)).collect(),
                )?;
                Ok(TensorType::F16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // NOTE(review): bf16 is decoded by casting the buffer to f32 and
        // narrowing, i.e. it assumes one full f32 per element was stored —
        // not packed 2-byte bf16 values. TODO confirm for raw_data inputs.
        DataType::BFLOAT16 => match bytemuck::try_cast_slice::<u8, f32>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<bf16>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| bf16::from_f32(*x)).collect(),
                )?;
                Ok(TensorType::BF16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::DOUBLE => {
            // Same raw_data-first, scalar-tolerant, zero-fill logic as INT32.
            let data = if let Some(data) = &proto.raw_data {
                if data.is_empty() {
                    &[]
                } else {
                    match bytemuck::try_cast_slice::<u8, f64>(data) {
                        Ok(data) => data,
                        Err(e) => return Err(anyhow!(e)),
                    }
                }
            } else {
                proto.double_data.as_slice()
            };
            let dlen = data.len();
            let slen = if !shape.is_empty() {
                shape_safe_product(&shape)
            } else {
                0
            };
            // Scalar allowed: shape [] with exactly one element.
            if dlen != slen && (slen == 0 && dlen != 1) {
                return Err(anyhow!(
                    "Data length {} does not match shape length {}",
                    dlen,
                    slen
                ));
            }
            let a = if data.is_empty() {
                ArrayD::<f64>::zeros(IxDyn(&shape))
            } else {
                ArrayD::<f64>::from_shape_vec(IxDyn(&shape), data.to_vec())?
            };
            Ok(TensorType::F64(a))
        }
        DataType::STRING => {
            // Strings live in their own repeated bytes field; invalid UTF-8 is
            // replaced lossily rather than rejected.
            let bytedata = &proto.string_data;
            let a = ArrayD::<String>::from_shape_vec(
                IxDyn(&shape),
                bytedata
                    .iter()
                    .map(|v| String::from_utf8_lossy(v.as_ref()).to_string())
                    .collect(),
            )?;
            Ok(TensorType::Str(a))
        }
        // Bools are stored one byte per value; any nonzero byte is true.
        DataType::BOOL => match bytemuck::try_cast_slice::<u8, c_uchar>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<bool>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| *x != 0).collect(),
                )?;
                Ok(TensorType::Bool(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // 8-bit float formats are not implemented; this panics by design.
        DataType::FLOAT8E4M3FN
        | DataType::FLOAT8E4M3FNUZ
        | DataType::FLOAT8E5M2FNUZ
        | DataType::FLOAT8E5M2 => {
            todo!("Data type {:?} not supported", enum_dt);
        }
        DataType::FLOAT => {
            // Same raw_data-first, scalar-tolerant, zero-fill logic as INT32.
            let data = if let Some(data) = &proto.raw_data {
                if data.is_empty() {
                    &[]
                } else {
                    match bytemuck::try_cast_slice::<u8, f32>(data) {
                        Ok(data) => data,
                        Err(e) => return Err(anyhow!(e)),
                    }
                }
            } else {
                proto.float_data.as_slice()
            };
            let dlen = data.len();
            let slen = if !shape.is_empty() {
                shape_safe_product(&shape)
            } else {
                0
            };
            // Scalar allowed: shape [] with exactly one element.
            if dlen != slen && (slen == 0 && dlen != 1) {
                return Err(anyhow!(
                    "Data length {} does not match shape length {}",
                    dlen,
                    slen
                ));
            }
            let a = if data.is_empty() {
                ArrayD::<f32>::zeros(IxDyn(&shape))
            } else {
                ArrayD::<f32>::from_shape_vec(IxDyn(&shape), data.to_vec())?
            };
            Ok(TensorType::F32(a))
        }
        // Complex values are stored as (re, im) pairs; Complex64Repr mirrors
        // that layout for the bytemuck cast.
        DataType::COMPLEX64 => match bytemuck::try_cast_slice::<u8, Complex64Repr>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<Complex64>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter()
                        .map(|v| Complex64::new(v._val[0], v._val[1]))
                        .collect(),
                )?;
                Ok(TensorType::C64(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::COMPLEX128 => match bytemuck::try_cast_slice::<u8, Complex128Repr>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len() / origin_elem_size, shape_safe_product(&shape));
                let a = ArrayD::<Complex128>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter()
                        .map(|v| Complex128::new(v._val[0], v._val[1]))
                        .collect(),
                )?;
                Ok(TensorType::C128(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
    }
}
501
/// Creates a tensor from the given shape, byte slice and data type.
///
/// Unlike `make_tensor`, the payload is a plain byte buffer with no proto
/// fallback fields, so most arms are a straight `bytemuck` cast plus an exact
/// element-count check. The INT64 and FLOAT arms additionally tolerate a
/// scalar (shape `[]`, one element) and zero-fill an empty buffer.
///
/// # Panics
/// Most arms `assert_eq!` the decoded element count against the shape
/// product, so a mismatched buffer panics rather than erroring.
pub fn make_tensor_from_raw(
    shape: &[i64],
    bytedata: &[u8],
    data_type: i32,
) -> BoxResult<TensorType> {
    let enum_dt = DataType::from_i32(data_type).unwrap_or_default();
    let shape = shape.iter().map(|v| *v as usize).collect::<Vec<usize>>();
    match enum_dt {
        DataType::UNDEFINED => Err(anyhow!("Undefined data type")),
        DataType::INT8 => match bytemuck::try_cast_slice::<u8, i8>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<i8>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::I8(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::INT16 => match bytemuck::try_cast_slice::<u8, i16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<i16>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::I16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::INT32 => match bytemuck::try_cast_slice::<u8, i32>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<i32>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::I32(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::INT64 => match bytemuck::try_cast_slice::<u8, i64>(bytedata) {
            Ok(data) => {
                let dlen = data.len();
                let slen = if !shape.is_empty() {
                    shape_safe_product(&shape)
                } else {
                    0
                };
                // A scalar is allowed: shape [] (slen == 0) with exactly one
                // element; any other count/shape mismatch is an error.
                if dlen != slen && (slen == 0 && dlen != 1) {
                    return Err(anyhow!(
                        "Data length {} does not match shape length {}",
                        dlen,
                        slen
                    ));
                }
                // Empty buffer => zero-filled tensor of the requested shape.
                let a = if data.is_empty() {
                    ArrayD::<i64>::zeros(IxDyn(&shape))
                } else {
                    ArrayD::<i64>::from_shape_vec(IxDyn(&shape), data.to_vec())?
                };
                Ok(TensorType::I64(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT8 => match bytemuck::try_cast_slice::<u8, u8>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<u8>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U8(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT16 => match bytemuck::try_cast_slice::<u8, u16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<u16>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT32 => match bytemuck::try_cast_slice::<u8, u32>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<u32>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U32(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::UINT64 => match bytemuck::try_cast_slice::<u8, u64>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<u64>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::U64(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // f16 values are stored as their raw bit pattern in 16-bit units.
        DataType::FLOAT16 => match bytemuck::try_cast_slice::<u8, u16>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<f16>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| f16::from_bits(*x)).collect(),
                )?;
                Ok(TensorType::F16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // NOTE(review): bf16 is decoded by casting the buffer to f32 and
        // narrowing, i.e. it assumes one full f32 per element was stored —
        // not packed 2-byte bf16 values. TODO confirm.
        DataType::BFLOAT16 => match bytemuck::try_cast_slice::<u8, f32>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<bf16>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| bf16::from_f32(*x)).collect(),
                )?;
                Ok(TensorType::BF16(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::DOUBLE => match bytemuck::try_cast_slice::<u8, f64>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<f64>::from_shape_vec(IxDyn(&shape), data.to_vec())?;
                Ok(TensorType::F64(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // Strings cannot be reconstructed from a flat byte buffer here.
        DataType::STRING => Err(anyhow!(
            "String data type not supported, use make_string_tensor()"
        )),
        // Bools are stored one byte per value; any nonzero byte is true.
        DataType::BOOL => match bytemuck::try_cast_slice::<u8, c_uchar>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<bool>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter().map(|x| *x != 0).collect(),
                )?;
                Ok(TensorType::Bool(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        // NOTE(review): the FLOAT8 formats are decoded as if they were plain
        // f32 buffers — this looks like a placeholder rather than a real
        // float8 decoder; confirm before relying on these dtypes.
        DataType::FLOAT
        | DataType::FLOAT8E4M3FN
        | DataType::FLOAT8E4M3FNUZ
        | DataType::FLOAT8E5M2FNUZ
        | DataType::FLOAT8E5M2 => match bytemuck::try_cast_slice::<u8, f32>(bytedata) {
            Ok(data) => {
                let dlen = data.len();
                let slen = if !shape.is_empty() {
                    shape_safe_product(&shape)
                } else {
                    0
                };
                // Scalar allowed: shape [] with exactly one element.
                if dlen != slen && (slen == 0 && dlen != 1) {
                    return Err(anyhow!(
                        "Data length {} does not match shape length {}",
                        dlen,
                        slen
                    ));
                }
                let a = if data.is_empty() {
                    ArrayD::<f32>::zeros(IxDyn(&shape))
                } else {
                    ArrayD::<f32>::from_shape_vec(IxDyn(&shape), data.to_vec())?
                };
                Ok(TensorType::F32(a))
            }
            Err(e) => {
                // The cast fails on misaligned buffers; fall back to copying
                // element by element. Assumes little-endian source bytes.
                eprintln!("Copying data of tensor as f32 because {}", e);
                let mut copied_data = vec![];
                for float_slice in bytedata.chunks_exact(std::mem::size_of::<f32>()) {
                    copied_data.push(f32::from_le_bytes(float_slice.try_into()?));
                }
                let a = ArrayD::<f32>::from_shape_vec(IxDyn(&shape), copied_data)?;
                Ok(TensorType::F32(a))
            }
        },
        // Complex values are stored as (re, im) pairs; the *Repr structs
        // mirror that layout for the bytemuck cast.
        DataType::COMPLEX64 => match bytemuck::try_cast_slice::<u8, Complex64Repr>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<Complex64>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter()
                        .map(|v| Complex64::new(v._val[0], v._val[1]))
                        .collect(),
                )?;
                Ok(TensorType::C64(a))
            }
            Err(e) => Err(anyhow!(e)),
        },
        DataType::COMPLEX128 => match bytemuck::try_cast_slice::<u8, Complex128Repr>(bytedata) {
            Ok(data) => {
                assert_eq!(data.len(), shape_safe_product(&shape));
                let a = ArrayD::<Complex128>::from_shape_vec(
                    IxDyn(&shape),
                    data.iter()
                        .map(|v| Complex128::new(v._val[0], v._val[1]))
                        .collect(),
                )?;
                Ok(TensorType::C128(a))
            }
            Err(e) => Err(anyhow!(e.to_string())),
        },
    }
}
704
705/// Creates the graph initializers from the ONNX graph.
706pub fn make_initializers(graph: &onnx::GraphProto) -> BoxResult<HashMap<String, TensorType>> {
707    let mut initializers: HashMap<String, TensorType> = HashMap::new();
708    for tensor in graph.initializer.iter() {
709        let tensor_name = tensor.name.as_ref().map_or(UNKNOWN, |v| v.as_str());
710        if !tensor.has_data_type() {
711            eprintln!("  Tensor: {} has no data type", tensor_name);
712        } else {
713            initializers.insert(tensor_name.to_string(), make_tensor_from_proto(tensor)?);
714        }
715    }
716    Ok(initializers)
717}
718
719/// Creates the graph inputs from the ONNX graph reading from external files.
720fn make_input_tensors_from_files(
721    graph: &onnx::GraphProto,
722    files: &[PathBuf],
723    mut initializers: HashMap<String, TensorType>,
724) -> BoxResult<HashMap<String, Arc<TensorType>>> {
725    let mut map = HashMap::new();
726    let mut external_inputs_map = HashMap::new();
727    for input in files.iter() {
728        let input_tensor = read_tensor(input)?;
729        external_inputs_map.insert(
730            input_tensor
731                .name
732                .as_ref()
733                .map_or_else(|| UNKNOWN.to_owned(), |v| v.clone()),
734            input_tensor,
735        );
736    }
737    for input in graph.input.iter() {
738        let input_name = input.name.as_ref().map_or(UNKNOWN, |v| v.as_str());
739        if let Some(input_from_file) = external_inputs_map.get(input_name) {
740            let tensor = make_tensor_from_proto(input_from_file)?;
741            print_at_level!(
742                VerbosityLevel::Informational,
743                "  Input {} from file has shape {:?} and type {:?}",
744                input_name,
745                tensor.shape(),
746                tensor.value_type()
747            );
748            map.insert(input_name.to_string(), Arc::new(tensor));
749        } else if let Some((_, init)) = initializers.remove_entry(input_name) {
750            print_at_level!(
751                VerbosityLevel::Informational,
752                "  Input {} from initializer has shape {:?} and type {:?}",
753                input_name,
754                init.shape(),
755                init.value_type()
756            );
757            map.insert(input_name.to_string(), Arc::new(init));
758        } else {
759            return Err(anyhow!(
760                "Input {} not found in inputs file or graph initializers",
761                input_name
762            ));
763        }
764    }
765    for (k, v) in initializers {
766        map.insert(k, Arc::new(v));
767    }
768    Ok(map)
769}
770
771/// Reads the expected outputs from external files.
772fn make_output_tensors_from_files(
773    graph: &onnx::GraphProto,
774    files: &[PathBuf],
775) -> BoxResult<HashMap<String, TensorType>> {
776    let mut map = HashMap::new();
777    let mut external_outputs_map = HashMap::new();
778    for output in files.iter() {
779        let ouput_tensor = read_tensor(output)?;
780        external_outputs_map.insert(
781            ouput_tensor
782                .name
783                .as_ref()
784                .map_or_else(|| UNKNOWN.to_owned(), |v| v.clone()),
785            ouput_tensor,
786        );
787    }
788    for output in graph.output.iter() {
789        let output_name = output.name.as_ref().map_or(UNKNOWN, |v| v.as_str());
790        if let Some(output_from_file) = external_outputs_map.get(output_name) {
791            map.insert(
792                output_name.to_string(),
793                make_tensor_from_proto(output_from_file)?,
794            );
795        } else {
796            return Err(anyhow!("Output {} not found in inputs file", output_name));
797        }
798    }
799    Ok(map)
800}
801
802/// Initializes the graph inputs from the ONNX graph with the input tensors from external files.
803pub fn initialize_nodes(
804    graph: &onnx::GraphProto,
805    fileinputs: &FileInputs,
806    initializers: HashMap<String, TensorType>,
807) -> BoxResult<HashMap<String, Arc<TensorType>>> {
808    if fileinputs.inputs.is_empty() {
809        return Ok(HashMap::new());
810    }
811    make_input_tensors_from_files(graph, &fileinputs.inputs, initializers)
812}
813
814/// Creates the *expected* graph outputs from the ONNX graph reading from external files.
815pub fn make_external_outputs(
816    graph: &onnx::GraphProto,
817    fileinputs: &FileInputs,
818) -> BoxResult<HashMap<String, TensorType>> {
819    if fileinputs.outputs.is_empty() {
820        return Ok(HashMap::new());
821    }
822    make_output_tensors_from_files(graph, &fileinputs.outputs)
823}
824
825/// Creates the graph outputs from the ONNX graph, without the data.
826pub fn make_graph_outputs(graph: &onnx::GraphProto) -> BoxResult<HashMap<String, OutputInfo>> {
827    let mut map = HashMap::new();
828    for output in graph.output.iter() {
829        let output_name = output.name.as_ref().map_or(UNKNOWN, |v| v.as_str());
830        map.insert(
831            output_name.to_string(),
832            OutputInfo::new(ValueInfo::from_proto(output)?),
833        );
834    }
835    Ok(map)
836}
837
838/// Reads an ONNX model in text format
839fn read_model_text(p: &Path) -> BoxResult<onnx::ModelProto> {
840    let file = std::fs::File::open(p)?;
841    let mut reader = io::BufReader::new(file);
842    let mut buf = String::new();
843    reader.read_to_string(&mut buf)?;
844    let model = protobuf::text_format::parse_from_str(&buf)?;
845    Ok(model)
846}
847
848/// Reads an ONNX model in binary format
849fn read_model_binary(p: &Path) -> BoxResult<onnx::ModelProto> {
850    let file = std::fs::File::open(p)?;
851    let mut reader = io::BufReader::new(file);
852    let model: onnx::ModelProto = protobuf::Message::parse_from_reader(&mut reader)?;
853    Ok(model)
854}
855
856/// Attempts to read an ONNX model in binary format, and if it fails, tries to read it in text format.
857pub fn read_model(p: &Path) -> BoxResult<onnx::ModelProto> {
858    print_at_level!(VerbosityLevel::Minimal, "Reading model from {}", p.display());
859    let merr = read_model_binary(p);
860    match merr {
861        Ok(m) => Ok(m),
862        Err(e) => {
863            eprintln!("Error reading binary model: {}", e);
864            read_model_text(p)
865        }
866    }
867}
868
869/// Reads an ONNX tensor in binary format from a file
870pub fn read_tensor(p: &Path) -> BoxResult<onnx::TensorProto> {
871    let file = std::fs::File::open(p)?;
872    let mut reader = io::BufReader::new(file);
873    let model: onnx::TensorProto = protobuf::Message::parse_from_reader(&mut reader)?;
874    Ok(model)
875}
876
/// Selects the opset version to use for the given target version and the opset
/// versions that the operator supports.
///
/// Returns the largest supported version that does not exceed `target_ver`,
/// or 0 when no such version exists. Non-positive entries are ignored, matching
/// the original accumulator (which started at 0 and only ever increased).
pub fn pick_opset_version(target_ver: i64, opset_versions: &[i64]) -> i64 {
    opset_versions
        .iter()
        .copied()
        .filter(|&v| v <= target_ver && v > 0)
        .max()
        .unwrap_or(0)
}
887
/// Stub for operators that are not implemented.
///
/// Matches the common operator signature (inputs, node, opset version, output
/// count) so it can stand in where a real operator function is expected, but it
/// never produces a result.
///
/// # Panics
///
/// Always panics via `todo!` with the message "operator not implemented".
pub fn operator_not_implemented(
    _inputs: &[&TensorType],
    _node: &NodeProto,
    _opset_version: i64,
    _output_len: usize,
) -> BoxResult<OperatorResult> {
    todo!("operator not implemented");
}