use std::cell::RefCell;
use std::collections::HashMap;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::ops::Range;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "mmap")]
use memmap2::Mmap;
use super::load_error::{LoadError, LoadErrorImpl};
use crate::constant_storage::ConstantStorage;
/// Location of a tensor's bytes in an external data file.
///
/// Describes the `length` bytes that start at byte `offset` of the file named
/// by `path`.
#[derive(Clone, Debug)]
pub struct DataLocation {
    /// Data file name as recorded in the model. The file-based loaders resolve
    /// it relative to the model's directory, and all loaders in this module
    /// reject paths that fail `is_allowed_external_data_path`.
    pub path: String,
    /// Byte offset of the first byte of the data within the file.
    pub offset: u64,
    /// Number of bytes of data.
    pub length: u64,
}
/// A byte range within a shared [`ConstantStorage`] (an owned buffer or a
/// memory mapping).
#[derive(Debug)]
pub struct DataSlice {
    /// Storage holding the bytes. It may be shared with other slices.
    pub storage: Arc<ConstantStorage>,
    /// Range of `storage.data()` covered by this slice.
    pub bytes: Range<usize>,
}

impl DataSlice {
    /// Return the bytes covered by this slice.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is not a valid range within the storage's data.
    pub fn data(&self) -> &[u8] {
        let whole = self.storage.data();
        &whole[self.bytes.start..self.bytes.end]
    }
}
/// Errors that can occur while locating or reading external tensor data.
#[derive(Debug)]
pub enum ExternalDataError {
    /// An I/O operation (open, seek, read, canonicalize, map) failed.
    IoError(std::io::Error),
    /// The requested data length is too large to be represented.
    InvalidLength,
    /// The given model path does not refer to a file.
    InvalidPath(PathBuf),
    /// External data loading is not supported in this configuration.
    NotSupported,
    /// The data source holds fewer bytes than the location requires.
    TooShort {
        /// Number of bytes that were required.
        required_len: usize,
        /// Number of bytes that were actually available.
        actual_len: usize,
    },
    /// The data path was rejected by the external data path policy.
    DisallowedPath(PathBuf),
}

impl std::fmt::Display for ExternalDataError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IoError(err) => write!(f, "io error: {err}"),
            Self::InvalidLength => f.write_str("invalid data length"),
            Self::InvalidPath(path) => write!(f, "invalid path \"{}\"", path.display()),
            Self::NotSupported => f.write_str("external data not supported"),
            Self::TooShort {
                required_len,
                actual_len,
            } => write!(
                f,
                "file too short. required {required_len} actual {actual_len}"
            ),
            Self::DisallowedPath(path) => write!(f, "disallowed path \"{}\"", path.display()),
        }
    }
}

impl std::error::Error for ExternalDataError {}

/// Lets `?` turn I/O failures into [`ExternalDataError::IoError`].
impl From<std::io::Error> for ExternalDataError {
    fn from(err: std::io::Error) -> Self {
        ExternalDataError::IoError(err)
    }
}
// Lets `?` propagate external data failures out of model loading, wrapped in
// the `ExternalDataError` variant of the loader's internal error type.
impl From<ExternalDataError> for LoadError {
    fn from(err: ExternalDataError) -> LoadError {
        let inner = LoadErrorImpl::ExternalDataError(Box::new(err));
        inner.into()
    }
}
/// Source of the external tensor data referenced by a model.
pub trait DataLoader {
    /// Load the bytes described by `location`.
    ///
    /// The implementations in this module fail if the path is disallowed, no
    /// data exists for the path, or the data is shorter than
    /// `location.offset + location.length`.
    fn load(&self, location: &DataLocation) -> Result<DataSlice, ExternalDataError>;
}
/// Return `true` if `path` is an acceptable external data file path.
///
/// Only a single plain file-name component is accepted: no directory parts,
/// `..`, roots or prefixes. Data paths therefore cannot point outside the
/// directory they are resolved against. The file extension must also begin
/// with `data` or `onnx_data`.
fn is_allowed_external_data_path(path: &Path) -> bool {
    let mut parts = path.components();
    let is_bare_file_name = matches!(
        (parts.next(), parts.next()),
        (Some(Component::Normal(_)), None)
    );
    if !is_bare_file_name {
        return false;
    }
    let extension = path.extension().and_then(|ext| ext.to_str());
    matches!(extension, Some(ext) if ext.starts_with("data") || ext.starts_with("onnx_data"))
}
/// Loads external data by reading files located next to the model file.
///
/// Open file handles are cached by data path, so repeated loads from the same
/// data file reuse one handle.
pub struct FileLoader {
    /// Directory containing the model. Data paths are resolved against it.
    dir_path: PathBuf,
    /// Data files opened so far, keyed by the data path from the model.
    files: RefCell<HashMap<PathBuf, File>>,
}

impl FileLoader {
    /// Create a loader for external data stored alongside `model_path`.
    ///
    /// # Errors
    ///
    /// Fails if `model_path` cannot be canonicalized (on non-wasm32 targets) or
    /// does not refer to a file.
    pub fn new(model_path: &Path) -> Result<Self, ExternalDataError> {
        let dir_path = dir_path_from_model_path(model_path)?;
        Ok(Self {
            dir_path,
            files: HashMap::new().into(),
        })
    }

    /// Read the bytes described by `location` into a newly allocated buffer.
    ///
    /// # Errors
    ///
    /// - [`ExternalDataError::NotSupported`] on big-endian targets.
    /// - [`ExternalDataError::InvalidLength`] if `location.length` exceeds
    ///   `isize::MAX`.
    /// - [`ExternalDataError::DisallowedPath`] if the data path is rejected.
    /// - [`ExternalDataError::IoError`] if the data file cannot be opened,
    ///   sought or read.
    /// - [`ExternalDataError::TooShort`] if the file ends before
    ///   `location.length` bytes have been read.
    fn read(&self, location: &DataLocation) -> Result<Vec<u8>, ExternalDataError> {
        if cfg!(target_endian = "big") {
            return Err(ExternalDataError::NotSupported);
        }
        if location.length > isize::MAX as u64 {
            return Err(ExternalDataError::InvalidLength);
        }
        let vec_len = location.length as usize;

        let mut files = self.files.borrow_mut();
        let file = get_or_open_file(&mut files, &self.dir_path, Path::new(&location.path))?;
        file.seek(SeekFrom::Start(location.offset))
            .map_err(ExternalDataError::IoError)?;

        // `location.length` comes from the model file and is not trusted.
        // Size the initial allocation by the bytes the file can actually
        // supply after `offset`. A bogus length then yields `TooShort` instead
        // of a huge up-front allocation, which can abort the process. If the
        // file size is unavailable, start empty and let the buffer grow. The
        // `as usize` cannot truncate: the value is at most
        // `location.length <= isize::MAX`.
        let available = file
            .metadata()
            .map_or(0, |meta| meta.len().saturating_sub(location.offset));
        let mut buf = Vec::with_capacity(available.min(location.length) as usize);

        const TMP_SIZE: usize = 8192;
        let mut tmp = [0u8; TMP_SIZE];
        let mut remaining = vec_len;
        while remaining > 0 {
            let tmp_size = remaining.min(TMP_SIZE);
            let n_read = read_fill(&mut *file, &mut tmp[..tmp_size])
                .map_err(ExternalDataError::IoError)?;
            buf.extend_from_slice(&tmp[..n_read]);
            remaining -= n_read;
            // `read_fill` returns fewer bytes than requested only at EOF.
            if n_read < tmp_size {
                break;
            }
        }

        if buf.len() != vec_len {
            return Err(ExternalDataError::TooShort {
                required_len: vec_len,
                actual_len: buf.len(),
            });
        }
        Ok(buf)
    }
}
/// Read from `src` until `buf` is full or EOF is reached.
///
/// Returns the number of bytes read. This is less than `buf.len()` only if EOF
/// was reached. Reads that fail with [`std::io::ErrorKind::Interrupted`] are
/// retried, as the [`Read`] documentation recommends. Any other error is
/// returned immediately.
fn read_fill<R: Read>(mut src: R, buf: &mut [u8]) -> std::io::Result<usize> {
    let mut total_read = 0;
    while total_read < buf.len() {
        match src.read(&mut buf[total_read..]) {
            Ok(0) => break,
            Ok(n) => total_read += n,
            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        }
    }
    Ok(total_read)
}
impl DataLoader for FileLoader {
    /// Read the requested bytes into an owned buffer and return a slice that
    /// covers the whole buffer.
    fn load(&self, location: &DataLocation) -> Result<DataSlice, ExternalDataError> {
        let buffer = self.read(location)?;
        let len = buffer.len();
        let storage = Arc::new(ConstantStorage::Buffer(buffer));
        Ok(DataSlice {
            storage,
            bytes: 0..len,
        })
    }
}
/// Return the cached handle for `data_path`, opening `dir_path/data_path` on
/// first use.
///
/// Fails with [`ExternalDataError::DisallowedPath`] if `data_path` is rejected
/// by `is_allowed_external_data_path`, or with [`ExternalDataError::IoError`]
/// if the file cannot be opened.
fn get_or_open_file<'a>(
    files: &'a mut HashMap<PathBuf, File>,
    dir_path: &Path,
    data_path: &Path,
) -> Result<&'a mut File, ExternalDataError> {
    if !is_allowed_external_data_path(data_path) {
        return Err(ExternalDataError::DisallowedPath(data_path.to_path_buf()));
    }
    if !files.contains_key(data_path) {
        let opened = File::open(dir_path.join(data_path)).map_err(ExternalDataError::IoError)?;
        files.insert(data_path.to_path_buf(), opened);
    }
    Ok(files
        .get_mut(data_path)
        .expect("entry was inserted above if missing"))
}
/// Return the directory containing the model file at `model_path`.
///
/// On targets other than wasm32 the path is canonicalized first. Fails with
/// [`ExternalDataError::InvalidPath`] if the resulting path is not a file, or
/// with an I/O error if canonicalization fails.
fn dir_path_from_model_path(model_path: &Path) -> Result<PathBuf, ExternalDataError> {
    let resolved = if cfg!(target_arch = "wasm32") {
        model_path.to_path_buf()
    } else {
        model_path.canonicalize()?
    };
    if !resolved.is_file() {
        return Err(ExternalDataError::InvalidPath(resolved));
    }
    // `is_file` succeeded, so the path is neither empty nor a root and
    // therefore has a parent.
    let parent = resolved.parent().expect("should have parent dir");
    Ok(parent.to_path_buf())
}
/// Loads external data by memory-mapping files located next to the model file.
///
/// Each data file is mapped once. The mapping is shared, via [`Arc`], by every
/// [`DataSlice`] loaded from it.
#[cfg(feature = "mmap")]
pub struct MmapLoader {
    /// Directory containing the model. Data paths are resolved against it.
    dir_path: PathBuf,
    /// Mappings of data files opened so far, keyed by data path.
    mmaps: RefCell<HashMap<PathBuf, Arc<ConstantStorage>>>,
}

#[cfg(feature = "mmap")]
impl MmapLoader {
    /// Create a loader for external data stored alongside `model_path`.
    ///
    /// # Safety
    ///
    /// Data files are memory-mapped with `memmap2::Mmap::map`, and the mapped
    /// bytes are exposed as `&[u8]` through [`DataSlice::data`]. The caller
    /// must ensure that no data file is modified or truncated, by this or any
    /// other process, while this loader or any [`DataSlice`] returned from it
    /// is alive.
    ///
    /// # Errors
    ///
    /// Fails if `model_path` cannot be canonicalized (on non-wasm32 targets) or
    /// does not refer to a file.
    pub unsafe fn new(model_path: &Path) -> Result<Self, ExternalDataError> {
        let dir_path = dir_path_from_model_path(model_path)?;
        Ok(Self {
            dir_path,
            mmaps: HashMap::new().into(),
        })
    }

    /// Return the shared mapping for `data_path`, mapping the file on first use.
    ///
    /// Fails with [`ExternalDataError::DisallowedPath`] if the path is rejected,
    /// or with [`ExternalDataError::IoError`] if opening or mapping fails.
    fn get_or_open_mmap(
        &self,
        data_path: &Path,
    ) -> Result<Arc<ConstantStorage>, ExternalDataError> {
        if !is_allowed_external_data_path(data_path) {
            return Err(ExternalDataError::DisallowedPath(data_path.into()));
        }
        let mut mmaps = self.mmaps.borrow_mut();
        if let Some(storage) = mmaps.get(data_path) {
            return Ok(Arc::clone(storage));
        }
        let file = File::open(self.dir_path.join(data_path)).map_err(ExternalDataError::IoError)?;
        // SAFETY: `MmapLoader::new` is `unsafe`. Its caller guarantees that
        // data files are not modified or truncated while the mappings created
        // here (and the slices sharing them) are alive.
        let mmap = unsafe { Mmap::map(&file) }?;
        let storage = Arc::new(ConstantStorage::Mmap(mmap));
        mmaps.insert(data_path.into(), Arc::clone(&storage));
        Ok(storage)
    }
}
#[cfg(feature = "mmap")]
impl DataLoader for MmapLoader {
    /// Return a slice of the (cached) mapping for `location.path`.
    ///
    /// No bytes are copied. The returned slice shares the mapping.
    fn load(&self, location: &DataLocation) -> Result<DataSlice, ExternalDataError> {
        let storage = self.get_or_open_mmap(Path::new(&location.path))?;
        let mapped_len = storage.data().len();
        // Saturating add: an out-of-range offset/length is reported as
        // `TooShort` rather than overflowing.
        let end_offset = location.offset.saturating_add(location.length);
        if end_offset > mapped_len as u64 {
            return Err(ExternalDataError::TooShort {
                required_len: end_offset as usize,
                actual_len: mapped_len,
            });
        }
        // Both bounds are <= `mapped_len` here, so the casts are lossless.
        let start = location.offset as usize;
        let end = start + location.length as usize;
        Ok(DataSlice {
            storage,
            bytes: start..end,
        })
    }
}
/// Loads external data from in-memory buffers keyed by data file path.
pub struct MemLoader(HashMap<String, Arc<ConstantStorage>>);

impl MemLoader {
    /// Create a loader that serves data from `map`. Keys are data paths as they
    /// appear in [`DataLocation::path`]; values hold each file's contents.
    pub fn new(map: HashMap<String, Arc<ConstantStorage>>) -> Self {
        MemLoader(map)
    }

    /// Create a loader from `(path, contents)` pairs, each stored as an owned
    /// buffer. When a path repeats, the last entry wins.
    #[cfg(test)]
    pub fn from_entries(entries: impl IntoIterator<Item = (String, Vec<u8>)>) -> Self {
        let mut map = HashMap::new();
        for (path, buf) in entries {
            map.insert(path, Arc::new(ConstantStorage::Buffer(buf)));
        }
        Self::new(map)
    }
}
impl DataLoader for MemLoader {
fn load(&self, location: &DataLocation) -> Result<DataSlice, ExternalDataError> {
if !is_allowed_external_data_path(Path::new(&location.path)) {
return Err(ExternalDataError::DisallowedPath(
location.path.clone().into(),
));
}
let Some(storage) = self.0.get(&location.path) else {
return Err(ExternalDataError::IoError(std::io::Error::new(
std::io::ErrorKind::NotFound,
"No such file or directory".to_string(),
)));
};
let end_offset = location.offset + location.length;
if end_offset > storage.data().len() as u64 {
return Err(ExternalDataError::TooShort {
required_len: end_offset as usize,
actual_len: storage.data().len(),
});
}
let bytes = (location.offset as usize)..end_offset as usize;
Ok(DataSlice {
storage: storage.clone(),
bytes,
})
}
}
// Tests for the external data loaders. The same case table in `test_loader`
// is run against `FileLoader`, `MmapLoader` (when the `mmap` feature is on) and
// `MemLoader`.
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::panic::RefUnwindSafe;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use super::{DataLoader, DataLocation, ExternalDataError, FileLoader, MemLoader};
use crate::constant_storage::ConstantStorage;
use rten_testing::TestCases;
// Directory for temporary test files. On wasm32 an empty (relative) path is
// used instead of `std::env::temp_dir()`; presumably the system temp dir is not
// usable there (TODO confirm).
fn temp_dir() -> PathBuf {
if cfg!(target_arch = "wasm32") {
PathBuf::new()
} else {
std::env::temp_dir()
}
}
// A file in `temp_dir()` that is written on creation and removed on drop.
struct TempFile {
path: PathBuf,
}
impl TempFile {
fn new(name: impl AsRef<Path>, content: &[u8]) -> std::io::Result<Self> {
let mut path = temp_dir();
path.push(name);
std::fs::write(&path, content)?;
Ok(Self { path })
}
fn path(&self) -> &Path {
&self.path
}
}
impl Drop for TempFile {
fn drop(&mut self) {
std::fs::remove_file(&self.path).expect("should remove file");
}
}
// Shared test body for all loaders.
//
// Creates an empty model file `{base_name}.onnx` and a data file
// `{base_name}.onnx.data` holding bytes 0..32. Each case builds a fresh loader
// from the model path via `make_loader` and checks the result of `load`.
// Expected errors are matched as substrings of the error's `Display` text.
// `base_name` must be unique per test so concurrently running tests do not
// share temp files.
fn test_loader<L: DataLoader>(
base_name: &str,
make_loader: impl Fn(&Path) -> Result<L, ExternalDataError> + RefUnwindSafe,
) {
let bytes: Vec<u8> = (0..32).collect();
let model_file = TempFile::new(format!("{base_name}.onnx"), &[]).unwrap();
let data_file = TempFile::new(format!("{base_name}.onnx.data"), &bytes).unwrap();
// Bare file name of the data file, as a model would record it.
let data_filename = data_file
.path()
.file_name()
.unwrap()
.to_string_lossy()
.to_string();
#[derive(Debug)]
struct Case {
location: DataLocation,
expected: Result<Vec<u8>, String>,
}
let cases = [
// A range in the middle of the data file.
Case {
location: DataLocation {
path: data_filename.clone(),
offset: 8,
length: 8,
},
expected: Ok(bytes[8..16].into()),
},
// The entire data file.
Case {
location: DataLocation {
path: data_filename.clone(),
offset: 0,
length: bytes.len() as u64,
},
expected: Ok(bytes.clone()),
},
// Empty path is rejected by the path policy.
Case {
location: DataLocation {
path: String::new(),
offset: 0,
length: 0,
},
expected: Err("disallowed path".into()),
},
// Paths escaping the model directory are rejected.
Case {
location: DataLocation {
path: "../foo.data".into(),
offset: 0,
length: 0,
},
expected: Err("disallowed path".into()),
},
// Files without a `data`/`onnx_data` extension are rejected.
Case {
location: DataLocation {
path: "not_a_data_file.md".into(),
offset: 0,
length: 0,
},
expected: Err("disallowed path".into()),
},
// Allowed name that does not exist.
Case {
location: DataLocation {
path: "file_does_not_exist.data".into(),
offset: 0,
length: 0,
},
expected: Err("No such file or directory".into()),
},
// Requests 36 bytes from a 32-byte file.
Case {
location: DataLocation {
path: data_filename,
offset: 0,
length: 36,
},
expected: Err("file too short".into()),
},
];
cases.test_each(|case| {
let loader = make_loader(model_file.path()).unwrap();
let data = loader.load(&case.location).map_err(|e| e.to_string());
match (&data, &case.expected) {
(Ok(actual), Ok(expected)) => assert_eq!(actual.data(), expected),
(Err(actual), Err(expected)) => assert!(
actual.contains(expected),
"{} does not contain {}",
actual,
expected
),
// Ok/Err mismatch: this assertion always fails, reporting which
// side succeeded.
(actual, expected) => assert_eq!(actual.is_ok(), expected.is_ok()),
}
});
}
#[test]
fn test_file_loader() {
test_loader("test_file_loader", FileLoader::new)
}
#[cfg(feature = "mmap")]
#[test]
fn test_mmap_loader() {
use super::MmapLoader;
// SAFETY: the temp data file is written once before any loader is created
// and is only removed when `test_loader` returns, after every loader and
// slice created from it has been dropped.
test_loader("test_mmap_loader", |model_path| unsafe {
MmapLoader::new(model_path)
})
}
#[test]
fn test_mem_loader() {
// The model path is ignored. The same 32 bytes are served from memory
// under the data file name that `test_loader` derives from the base name.
test_loader("test_mem_loader", |_model_path| {
let mut map = HashMap::new();
let buf = (0..32).collect();
let storage = Arc::new(ConstantStorage::Buffer(buf));
map.insert("test_mem_loader.onnx.data".to_string(), storage);
Ok(MemLoader::new(map))
})
}
}