tract_nnef/
resource.rs

1use std::path::Path;
2
3use crate::ast::QuantFormat;
4use crate::internal::*;
5use safetensors::SafeTensors;
6use tract_core::downcast_rs::{impl_downcast, DowncastSync};
7use tract_core::tract_data::itertools::Itertools;
8
/// Base name of the file holding the textual NNEF graph; matched by [`GraphNnefLoader`].
pub const GRAPH_NNEF_FILENAME: &str = "graph.nnef";
/// Base name of the quantization description file; matched by [`GraphQuantLoader`].
pub const GRAPH_QUANT_FILENAME: &str = "graph.quant";
11
12pub fn resource_path_to_id(path: impl AsRef<Path>) -> TractResult<String> {
13    let mut path = path.as_ref().to_path_buf();
14    path.set_extension("");
15    path.to_str()
16        .ok_or_else(|| format_err!("Badly encoded filename for path: {:?}", path))
17        .map(|s| s.to_string())
18}
19
20pub trait Resource: DowncastSync + std::fmt::Debug + Send + Sync {
21    /// Get value for a given key.
22    fn get(&self, _key: &str) -> TractResult<Value> {
23        bail!("No key access supported by this resource");
24    }
25
26    fn to_liquid_value(&self) -> Option<liquid::model::Value> {
27        None
28    }
29}
30
31impl_downcast!(sync Resource);
32
/// Recognizes and loads a [`Resource`] from a path and a reader over its content.
pub trait ResourceLoader: Send + Sync {
    /// Name of the resource loader.
    fn name(&self) -> StaticName;
    /// Try to load a resource given a path and its corresponding reader.
    /// None is returned if the path is not accepted by this loader; otherwise the
    /// resource identifier is returned together with the loaded resource.
    fn try_load(
        &self,
        path: &Path,
        reader: &mut dyn std::io::Read,
        framework: &Nnef,
    ) -> TractResult<Option<(String, Arc<dyn Resource>)>>;

    /// Boxes this loader as a `Box<dyn ResourceLoader>` trait object.
    fn into_boxed(self) -> Box<dyn ResourceLoader>
    where
        Self: Sized + 'static,
    {
        Box::new(self)
    }
}
52
/// Raw text content of a `graph.nnef` file, as read by [`GraphNnefLoader`].
#[derive(Debug)]
pub struct GraphNnef(pub String);
impl Resource for GraphNnef {}
56
57#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
58pub struct GraphNnefLoader;
59
60impl ResourceLoader for GraphNnefLoader {
61    fn name(&self) -> StaticName {
62        "GraphNnefLoader".into()
63    }
64
65    fn try_load(
66        &self,
67        path: &Path,
68        reader: &mut dyn std::io::Read,
69        _framework: &Nnef,
70    ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
71        if path.ends_with(GRAPH_NNEF_FILENAME) {
72            let mut text = String::new();
73            reader.read_to_string(&mut text)?;
74            Ok(Some((path.to_string_lossy().to_string(), Arc::new(GraphNnef(text)))))
75        } else {
76            Ok(None)
77        }
78    }
79}
80
// Plain tensors (e.g. produced by `DatLoader`) are stored directly as resources.
impl Resource for Tensor {}
82
83#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
84pub struct DatLoader;
85
86impl ResourceLoader for DatLoader {
87    fn name(&self) -> StaticName {
88        "DatLoader".into()
89    }
90
91    fn try_load(
92        &self,
93        path: &Path,
94        reader: &mut dyn std::io::Read,
95        _framework: &Nnef,
96    ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
97        if path.extension().map(|e| e == "dat").unwrap_or(false) {
98            let tensor = crate::tensors::read_tensor(reader)
99                .with_context(|| format!("Error while reading tensor {path:?}"))?;
100            Ok(Some((resource_path_to_id(path)?, Arc::new(tensor))))
101        } else {
102            Ok(None)
103        }
104    }
105}
106
// Quantization formats keyed by identifier name, as produced by `GraphQuantLoader`.
impl Resource for HashMap<String, QuantFormat> {}
108
109#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
110pub struct GraphQuantLoader;
111
112impl ResourceLoader for GraphQuantLoader {
113    fn name(&self) -> StaticName {
114        "GraphQuantLoader".into()
115    }
116
117    fn try_load(
118        &self,
119        path: &Path,
120        reader: &mut dyn std::io::Read,
121        _framework: &Nnef,
122    ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
123        if path.ends_with(GRAPH_QUANT_FILENAME) {
124            let mut text = String::new();
125            reader.read_to_string(&mut text)?;
126            let quant = crate::ast::quant::parse_quantization(&text)?;
127            let quant: HashMap<String, QuantFormat> =
128                quant.into_iter().map(|(k, v)| (k.0, v)).collect();
129            Ok(Some((path.to_str().unwrap().to_string(), Arc::new(quant))))
130        } else {
131            Ok(None)
132        }
133    }
134}
135
/// Loader for nested models stored as `.nnef.tgz` or `.nnef.tar` archives.
pub struct TypedModelLoader {
    /// When `true`, each loaded model goes through `into_optimized` before being stored.
    pub optimized_model: bool,
}

impl TypedModelLoader {
    /// Builds a loader; see [`TypedModelLoader::optimized_model`].
    pub fn new(optimized_model: bool) -> Self {
        TypedModelLoader { optimized_model }
    }
}
145
146impl ResourceLoader for TypedModelLoader {
147    fn name(&self) -> StaticName {
148        "TypedModelLoader".into()
149    }
150
151    fn try_load(
152        &self,
153        path: &Path,
154        reader: &mut dyn std::io::Read,
155        framework: &Nnef,
156    ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
157        const NNEF_TGZ: &str = ".nnef.tgz";
158        const NNEF_TAR: &str = ".nnef.tar";
159        let path_str = path.to_str().unwrap_or("");
160        if path_str.ends_with(NNEF_TGZ) || path_str.ends_with(NNEF_TAR) {
161            let model = if self.optimized_model {
162                framework.model_for_read(reader)?.into_optimized()?
163            } else {
164                framework.model_for_read(reader)?
165            };
166
167            let label = if path_str.ends_with(NNEF_TGZ) {
168                path.to_str()
169                    .ok_or_else(|| anyhow!("invalid model resource path"))?
170                    .trim_end_matches(NNEF_TGZ)
171            } else {
172                path.to_str()
173                    .ok_or_else(|| anyhow!("invalid model resource path"))?
174                    .trim_end_matches(NNEF_TAR)
175            };
176            Ok(Some((resource_path_to_id(label)?, Arc::new(TypedModelResource(model)))))
177        } else {
178            Ok(None)
179        }
180    }
181}
182
/// A nested [`TypedModel`] exposed as a resource, as produced by [`TypedModelLoader`].
#[derive(Debug, Clone)]
pub struct TypedModelResource(pub TypedModel);

impl Resource for TypedModelResource {}
187
/// Loader for `.safetensors` files; yields all tensors of the file as a single
/// `Vec<(String, Arc<Tensor>)>` resource.
pub struct SafeTensorsLoader;
189
190impl ResourceLoader for SafeTensorsLoader {
191    fn name(&self) -> StaticName {
192        "SafeTensorsLoader".into()
193    }
194
195    fn try_load(
196        &self,
197        path: &Path,
198        reader: &mut dyn std::io::Read,
199        _framework: &Nnef,
200    ) -> TractResult<Option<(String, Arc<dyn Resource>)>> {
201        if path.extension().is_some_and(|e| e == "safetensors") {
202            let mut buffer = vec![];
203            reader.read_to_end(&mut buffer)?;
204            let tensors: Vec<(String, Arc<Tensor>)> = SafeTensors::deserialize(&buffer)?
205                .tensors()
206                .into_iter()
207                .map(|(name, t)| {
208                    let dt = match t.dtype() {
209                        safetensors::Dtype::F32 => DatumType::F32,
210                        safetensors::Dtype::F16 => DatumType::F16,
211                        _ => panic!(),
212                    };
213                    let tensor = unsafe { Tensor::from_raw_dt(dt, t.shape(), t.data()).unwrap() };
214                    (name, tensor.into_arc_tensor())
215                })
216                .collect_vec();
217            return Ok(Some((path.to_string_lossy().to_string(), Arc::new(tensors))));
218        }
219        Ok(None)
220    }
221}
222
// Named tensors, as produced by `SafeTensorsLoader`.
impl Resource for Vec<(String, Arc<Tensor>)> {}