kn_graph/onnx/
external_data.rs

1use crate::onnx::result::{OnnxError, OnnxResult, ToOnnxLoadResult};
2use itertools::Itertools;
3use rand::Rng;
4use std::fs::File;
5use std::io::{Read, Seek, SeekFrom};
6use std::path::{Component, Path, PathBuf};
7
8pub trait ExternalDataLoader {
9    fn load_external_data(
10        &mut self,
11        location: &Path,
12        offset: usize,
13        length: Option<usize>,
14        length_guess: usize,
15    ) -> OnnxResult<Vec<u8>>;
16}
17
/// Loader for models that are not allowed to reference external data.
///
/// This is a stateless marker type, so it is also `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoExternalData;
20
/// Loader that ignores the referenced data and produces bytes from the wrapped
/// random generator instead.
///
/// No trait bound is placed on `R` here: the `Rng` requirement belongs to the
/// `ExternalDataLoader` impl, so the wrapper itself can be named, stored and
/// cloned generically.
#[derive(Debug, Clone)]
pub struct DummyExternalData<R>(pub R);
23
/// Loader that reads external data from files located relative to the wrapped
/// base directory.
#[derive(Debug, Clone)]
pub struct PathExternalData(pub PathBuf);
26
27impl ExternalDataLoader for NoExternalData {
28    fn load_external_data(&mut self, location: &Path, _: usize, _: Option<usize>, _: usize) -> OnnxResult<Vec<u8>> {
29        panic!(
30            "External data not allowed, trying to read from '{}'",
31            location.display()
32        );
33    }
34}
35
36impl<R: Rng> ExternalDataLoader for DummyExternalData<R> {
37    fn load_external_data(&mut self, _: &Path, _: usize, _: Option<usize>, length_guess: usize) -> OnnxResult<Vec<u8>> {
38        Ok((0..length_guess).map(|_| self.0.gen()).collect_vec())
39    }
40}
41
42impl ExternalDataLoader for PathExternalData {
43    fn load_external_data(
44        &mut self,
45        location: &Path,
46        offset: usize,
47        length: Option<usize>,
48        _: usize,
49    ) -> OnnxResult<Vec<u8>> {
50        if !path_is_normal(location) {
51            return Err(OnnxError::NonNormalExternalDataPath(location.to_owned()));
52        }
53
54        let path = self.0.join(location);
55
56        let mut file = File::open(&path).unwrap_or_else(|_| panic!("Failed to open file {:?}", path));
57        file.seek(SeekFrom::Start(offset as u64)).to_onnx_result(&path)?;
58
59        let mut buffer = vec![];
60
61        if let Some(length) = length {
62            buffer.resize(length, 0);
63            file.read_exact(&mut buffer).to_onnx_result(&path)?;
64        } else {
65            file.read_to_end(&mut buffer).to_onnx_result(&path)?;
66        }
67
68        Ok(buffer)
69    }
70}
71
/// Return `true` if `path` is a non-empty relative path made up only of
/// normal components (no root, drive prefix, `.` or `..`).
///
/// An empty path is rejected: it has no components, so joining it onto a base
/// directory would just yield the base directory itself rather than a file
/// inside it.
fn path_is_normal(path: &Path) -> bool {
    let mut components = path.components().peekable();
    components.peek().is_some() && components.all(|c| matches!(c, Component::Normal(_)))
}