kn_graph/onnx/
external_data.rs1use crate::onnx::result::{OnnxError, OnnxResult, ToOnnxLoadResult};
2use itertools::Itertools;
3use rand::Rng;
4use std::fs::File;
5use std::io::{Read, Seek, SeekFrom};
6use std::path::{Component, Path, PathBuf};
7
8pub trait ExternalDataLoader {
9 fn load_external_data(
10 &mut self,
11 location: &Path,
12 offset: usize,
13 length: Option<usize>,
14 length_guess: usize,
15 ) -> OnnxResult<Vec<u8>>;
16}
17
/// Loader that forbids external data: any attempt to load some panics.
#[derive(Debug)]
pub struct NoExternalData;
20
/// Loader that fabricates external data from the wrapped random number generator instead
/// of reading any file.
///
/// The generator type is deliberately unconstrained on the struct itself; the
/// `ExternalDataLoader` implementation is what requires `R: Rng`.
#[derive(Debug)]
pub struct DummyExternalData<R>(pub R);
23
/// Loader that reads external data from files located under the wrapped base directory.
#[derive(Debug)]
pub struct PathExternalData(pub PathBuf);
26
27impl ExternalDataLoader for NoExternalData {
28 fn load_external_data(&mut self, location: &Path, _: usize, _: Option<usize>, _: usize) -> OnnxResult<Vec<u8>> {
29 panic!(
30 "External data not allowed, trying to read from '{}'",
31 location.display()
32 );
33 }
34}
35
36impl<R: Rng> ExternalDataLoader for DummyExternalData<R> {
37 fn load_external_data(&mut self, _: &Path, _: usize, _: Option<usize>, length_guess: usize) -> OnnxResult<Vec<u8>> {
38 Ok((0..length_guess).map(|_| self.0.gen()).collect_vec())
39 }
40}
41
42impl ExternalDataLoader for PathExternalData {
43 fn load_external_data(
44 &mut self,
45 location: &Path,
46 offset: usize,
47 length: Option<usize>,
48 _: usize,
49 ) -> OnnxResult<Vec<u8>> {
50 if !path_is_normal(location) {
51 return Err(OnnxError::NonNormalExternalDataPath(location.to_owned()));
52 }
53
54 let path = self.0.join(location);
55
56 let mut file = File::open(&path).unwrap_or_else(|_| panic!("Failed to open file {:?}", path));
57 file.seek(SeekFrom::Start(offset as u64)).to_onnx_result(&path)?;
58
59 let mut buffer = vec![];
60
61 if let Some(length) = length {
62 buffer.resize(length, 0);
63 file.read_exact(&mut buffer).to_onnx_result(&path)?;
64 } else {
65 file.read_to_end(&mut buffer).to_onnx_result(&path)?;
66 }
67
68 Ok(buffer)
69 }
70}
71
/// Returns `true` when every component of `path` is a plain name segment.
///
/// Any of the following makes the path non-normal:
/// * a root directory;
/// * a Windows prefix;
/// * a `..` segment;
/// * a leading `.`.
///
/// Interior `.` segments are dropped by `Path::components`, so they do not count. An empty
/// path has no components and is considered normal.
fn path_is_normal(path: &Path) -> bool {
    let has_special_component = path
        .components()
        .any(|component| !matches!(component, Component::Normal(_)));
    !has_special_component
}