1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::ops::Deref;
4use std::path::Path;
5
6use crate::graph::Graph;
7use crate::onnx::{InputShaper, OnnxDimValue};
8use crate::onnx::external_data::{ExternalDataLoader, NoExternalData, PathExternalData};
9use crate::onnx::load::graph_from_onnx_bytes;
10use crate::onnx::result::{OnnxError, OnnxResult, ToOnnxLoadResult};
11use crate::shape::{Shape, Size};
12
13#[allow(missing_debug_implementations)]
41pub struct GraphLoader<'a> {
42 bytes: Cow<'a, [u8]>,
43 external: Box<dyn ExternalDataLoader>,
44
45 input_shaper_custom: Option<Box<InputShaper>>,
47 input_shape_overrides: Option<Vec<Option<Shape>>>,
48 named_axes: HashMap<String, Size>,
49}
50
51impl<'a> GraphLoader<'a> {
52 pub fn from_path(path: impl AsRef<Path>, allow_external: bool) -> OnnxResult<Self> {
53 let path = path.as_ref();
54 let bytes = std::fs::read(path).to_onnx_result(path)?;
55
56 let external: Box<dyn ExternalDataLoader> = if allow_external {
57 let parent = path
58 .parent()
59 .ok_or_else(|| OnnxError::MustHaveParentPath(path.to_owned()))?;
60 Box::new(PathExternalData(parent.to_owned()))
61 } else {
62 Box::new(NoExternalData)
63 };
64
65 Ok(GraphLoader {
66 bytes: Cow::Owned(bytes),
67 external,
68
69 input_shaper_custom: None,
70 input_shape_overrides: None,
71 named_axes: HashMap::new(),
72 })
73 }
74
75 pub fn from_bytes(bytes: &'a [u8]) -> Self {
76 GraphLoader {
77 bytes: Cow::Borrowed(bytes),
78 external: Box::new(NoExternalData),
79
80 input_shaper_custom: None,
81 input_shape_overrides: None,
82 named_axes: HashMap::new(),
83 }
84 }
85
86 pub fn set_external_data(&mut self, external: Box<dyn ExternalDataLoader>) {
87 self.external = external;
88 }
89
90 pub fn set_input_shaper_custom(&mut self, shaper: Box<InputShaper>) {
91 self.input_shaper_custom = Some(shaper);
92 }
93
94 pub fn force_input_shapes(&mut self, shapes: Vec<Option<Shape>>) {
95 self.input_shape_overrides = Some(shapes)
96 }
97
98 pub fn add_named_axis(&mut self, name: &str, value: Size) {
99 self.named_axes.insert(name.to_owned(), value);
100 }
101
102 pub fn load(self) -> OnnxResult<Graph> {
103 let mut external = self.external;
104
105 let input_shaper = move |dims: &[OnnxDimValue], name: &str, index| {
106 if let Some(input_shaper_custom) = &self.input_shaper_custom {
108 return input_shaper_custom(dims, name, index);
109 }
110 if let Some(input_shape_overrides) = &self.input_shape_overrides {
112 if index < input_shape_overrides.len() {
113 if let Some(shape) = &input_shape_overrides[index] {
114 return Some(shape.clone());
115 }
116 } else {
117 return None;
118 }
119 }
120 let mut new_dims = vec![];
122 for d in dims {
123 let d_new = match *d {
124 OnnxDimValue::Value(value) => Size::fixed(value as usize),
125 OnnxDimValue::Param(ref param) => self.named_axes.get(param)?.clone(),
126 };
127 new_dims.push(d_new);
128 }
129 Some(Shape::new(new_dims))
130 };
131
132 graph_from_onnx_bytes(self.bytes.deref(), external.as_mut(), &input_shaper)
133 }
134}