// kn_graph/onnx/loader.rs

1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::ops::Deref;
4use std::path::Path;
5
6use crate::graph::Graph;
7use crate::onnx::{InputShaper, OnnxDimValue};
8use crate::onnx::external_data::{ExternalDataLoader, NoExternalData, PathExternalData};
9use crate::onnx::load::graph_from_onnx_bytes;
10use crate::onnx::result::{OnnxError, OnnxResult, ToOnnxLoadResult};
11use crate::shape::{Shape, Size};
12
13/// Load an [ONNX](https://github.com/onnx/onnx/blob/main/docs/IR.md) graph.
14///
15/// Many loading settings are customizable:
16/// * the source, either from a path through [Self::from_path] or from bytes through [Self::from_bytes].
17/// * whether [external data](https://github.com/onnx/onnx/blob/main/docs/ExternalData.md) is allowed,
18///     through [Self::from_path] `allow_external` or [Self::set_external_data].
19/// * input shape overrides (in order of priority):
20///   * fully custom through [Self::set_input_shaper_custom]
21///   * specific input overrides through [Self::force_input_shapes]
22///   * named axes through [Self::add_named_axis]
23///
24/// A simple example:
25/// ```no_run
26/// # use kn_graph::graph::Graph;
27/// # use kn_graph::onnx::GraphLoader;
28/// # use kn_graph::shape;
29/// # use kn_graph::shape::Size;
30/// // load from a path, disallowing external data
31/// let mut loader = GraphLoader::from_path("model.onnx", false).unwrap();
32/// // set some named axes
33/// loader.add_named_axis("batch_size", Size::BATCH);
34/// loader.add_named_axis("sequence_length", Size::fixed(128));
35/// // override the third input shape
36/// loader.force_input_shapes(vec![None, None, Some(shape![1, Size::BATCH, 3])]);
37/// // load the graph
38/// let graph = loader.load().unwrap();
39/// ```
40#[allow(missing_debug_implementations)]
41pub struct GraphLoader<'a> {
42    bytes: Cow<'a, [u8]>,
43    external: Box<dyn ExternalDataLoader>,
44
45    // input shape overrides
46    input_shaper_custom: Option<Box<InputShaper>>,
47    input_shape_overrides: Option<Vec<Option<Shape>>>,
48    named_axes: HashMap<String, Size>,
49}
50
51impl<'a> GraphLoader<'a> {
52    pub fn from_path(path: impl AsRef<Path>, allow_external: bool) -> OnnxResult<Self> {
53        let path = path.as_ref();
54        let bytes = std::fs::read(path).to_onnx_result(path)?;
55
56        let external: Box<dyn ExternalDataLoader> = if allow_external {
57            let parent = path
58                .parent()
59                .ok_or_else(|| OnnxError::MustHaveParentPath(path.to_owned()))?;
60            Box::new(PathExternalData(parent.to_owned()))
61        } else {
62            Box::new(NoExternalData)
63        };
64
65        Ok(GraphLoader {
66            bytes: Cow::Owned(bytes),
67            external,
68
69            input_shaper_custom: None,
70            input_shape_overrides: None,
71            named_axes: HashMap::new(),
72        })
73    }
74
75    pub fn from_bytes(bytes: &'a [u8]) -> Self {
76        GraphLoader {
77            bytes: Cow::Borrowed(bytes),
78            external: Box::new(NoExternalData),
79
80            input_shaper_custom: None,
81            input_shape_overrides: None,
82            named_axes: HashMap::new(),
83        }
84    }
85
86    pub fn set_external_data(&mut self, external: Box<dyn ExternalDataLoader>) {
87        self.external = external;
88    }
89
90    pub fn set_input_shaper_custom(&mut self, shaper: Box<InputShaper>) {
91        self.input_shaper_custom = Some(shaper);
92    }
93
94    pub fn force_input_shapes(&mut self, shapes: Vec<Option<Shape>>) {
95        self.input_shape_overrides = Some(shapes)
96    }
97
98    pub fn add_named_axis(&mut self, name: &str, value: Size) {
99        self.named_axes.insert(name.to_owned(), value);
100    }
101
102    pub fn load(self) -> OnnxResult<Graph> {
103        let mut external = self.external;
104
105        let input_shaper = move |dims: &[OnnxDimValue], name: &str, index| {
106            // first try custom shaper
107            if let Some(input_shaper_custom) = &self.input_shaper_custom {
108                return input_shaper_custom(dims, name, index);
109            }
110            // then shape overrides
111            if let Some(input_shape_overrides) = &self.input_shape_overrides {
112                if index < input_shape_overrides.len() {
113                    if let Some(shape) = &input_shape_overrides[index] {
114                        return Some(shape.clone());
115                    }
116                } else {
117                    return None;
118                }
119            }
120            // finally try basic resolution using named axes
121            let mut new_dims = vec![];
122            for d in dims {
123                let d_new = match *d {
124                    OnnxDimValue::Value(value) => Size::fixed(value as usize),
125                    OnnxDimValue::Param(ref param) => self.named_axes.get(param)?.clone(),
126                };
127                new_dims.push(d_new);
128            }
129            Some(Shape::new(new_dims))
130        };
131
132        graph_from_onnx_bytes(self.bytes.deref(), external.as_mut(), &input_shaper)
133    }
134}