use std::fs::File;
use std::path::{Path, PathBuf};
use memmap2::Mmap;
use oxionnx_proto::parser::parse_model;
use oxionnx_proto::types::{ModelProto, TensorProto};
use super::error::OnnxImportError;
/// ONNX `TensorProto.DataType` codes for the element types that `bytes_to_f32`
/// knows how to decode.
mod dtype_code {
/// 32-bit float; decoded with `f32::from_le_bytes`.
pub const FLOAT32: i32 = 1;
/// 16-bit half-precision float; decoded with `half::f16`.
pub const FLOAT16: i32 = 10;
/// 16-bit bfloat16; decoded with `half::bf16`.
pub const BFLOAT16: i32 = 16;
}
/// A parsed ONNX model together with the external-data sidecar files that have
/// been memory-mapped for it so far.
pub struct OnnxReader {
/// Path of the model file this reader was opened from.
pub onnx_path: PathBuf,
/// Parent directory of `onnx_path` (or `"."` when the path has no parent);
/// external-data `location` entries are joined onto this directory.
pub base_dir: PathBuf,
/// The model as returned by `parse_model`.
pub model: ModelProto,
/// Memory-mapped sidecar files, each stored with the path it was opened under.
/// Entries are added on first use and looked up by path afterwards.
sidecars: Vec<(PathBuf, Mmap)>,
}
impl OnnxReader {
    /// Reads the file at `onnx_path` into memory and parses it as an ONNX model.
    ///
    /// `base_dir` is set to the parent directory of `onnx_path`, or `"."` when the
    /// path has no parent. No sidecar file is opened here; sidecars are mapped on
    /// first use by [`OnnxReader::initializer_bytes`].
    ///
    /// # Errors
    ///
    /// * [`OnnxImportError::Io`] if the file cannot be read.
    /// * [`OnnxImportError::Parse`] if `parse_model` rejects the bytes.
    pub fn open(onnx_path: &Path) -> Result<Self, OnnxImportError> {
        let bytes = std::fs::read(onnx_path).map_err(|source| OnnxImportError::Io {
            path: onnx_path.to_path_buf(),
            source,
        })?;
        let model = parse_model(&bytes).map_err(|msg| OnnxImportError::Parse {
            path: onnx_path.to_path_buf(),
            msg,
        })?;
        let base_dir = onnx_path
            .parent()
            .map(Path::to_path_buf)
            .unwrap_or_else(|| PathBuf::from("."));
        Ok(Self {
            onnx_path: onnx_path.to_path_buf(),
            base_dir,
            model,
            sidecars: Vec::new(),
        })
    }

    /// Returns the first graph initializer whose name equals `name`, if any.
    pub fn find_initializer(&self, name: &str) -> Option<&TensorProto> {
        self.model
            .graph
            .initializers
            .iter()
            .find(|t| t.name == name)
    }

    /// Returns the raw bytes backing `tensor`.
    ///
    /// For a tensor that is not external (see [`is_external`]) this is simply
    /// `tensor.raw_data`. For an external tensor the `location`, `offset` and
    /// `length` entries of `tensor.external_data` are consulted, the sidecar file
    /// `base_dir/location` is memory-mapped on first use, and the slice
    /// `offset..offset + length` of that mapping is returned.
    ///
    /// `offset` defaults to 0 when the entry is absent. `length` must be present
    /// and non-zero.
    ///
    /// NOTE(review): because this method takes `&mut self` with the same lifetime
    /// as `tensor`, `tensor` cannot be a reference borrowed out of `self.model`.
    ///
    /// # Errors
    ///
    /// * [`OnnxImportError::MissingExternalEntry`] if `location` is absent, or if
    ///   `length` is absent or zero.
    /// * [`OnnxImportError::Other`] if `offset` or `length` is present but is not
    ///   a valid unsigned integer that fits in `usize`, if `offset + length`
    ///   overflows, or if the range lies outside the sidecar file.
    /// * [`OnnxImportError::Io`] if the sidecar cannot be opened or mapped.
    pub fn initializer_bytes<'a>(
        &'a mut self,
        tensor: &'a TensorProto,
    ) -> Result<&'a [u8], OnnxImportError> {
        if !is_external(tensor) {
            return Ok(tensor.raw_data.as_slice());
        }

        let location = external_entry(tensor, "location").ok_or_else(|| {
            OnnxImportError::MissingExternalEntry {
                tensor: tensor.name.clone(),
                key: "location",
            }
        })?;

        // A malformed value is an error rather than a silent fallback to 0:
        // defaulting a bad `offset` to 0 would hand back the wrong bytes.
        let offset = Self::external_usize(tensor, "offset")?.unwrap_or(0);
        let length = Self::external_usize(tensor, "length")?.unwrap_or(0);
        if length == 0 {
            return Err(OnnxImportError::MissingExternalEntry {
                tensor: tensor.name.clone(),
                key: "length",
            });
        }

        // NOTE(review): `location` comes from the model file and is joined
        // unchecked, so an absolute path or `..` components can point outside
        // `base_dir`. Consider validating it if models may be untrusted.
        let sidecar_path = self.base_dir.join(location);
        self.ensure_sidecar_mapped(&sidecar_path)?;
        let mmap = self
            .sidecars
            .iter()
            .find(|(p, _)| p == &sidecar_path)
            .map(|(_, m)| m)
            .ok_or_else(|| {
                OnnxImportError::Other(format!(
                    "internal: sidecar {} was not mapped after ensure_sidecar_mapped",
                    sidecar_path.display()
                ))
            })?;

        let end = offset.checked_add(length).ok_or_else(|| {
            OnnxImportError::Other(format!(
                "offset {offset} + length {length} overflows usize for tensor '{}'",
                tensor.name
            ))
        })?;
        if end > mmap.len() {
            return Err(OnnxImportError::Other(format!(
                "external-data range {offset}..{end} exceeds sidecar size {} for tensor '{}'",
                mmap.len(),
                tensor.name
            )));
        }
        Ok(&mmap[offset..end])
    }

    /// Parses the external-data entry `key` of `tensor` as a `usize`.
    ///
    /// Returns `Ok(None)` when the entry is absent and an error when it is
    /// present but not a valid unsigned integer representable as `usize`.
    fn external_usize(tensor: &TensorProto, key: &str) -> Result<Option<usize>, OnnxImportError> {
        match external_entry(tensor, key) {
            None => Ok(None),
            Some(raw) => raw.parse::<usize>().map(Some).map_err(|_| {
                OnnxImportError::Other(format!(
                    "external-data entry '{key}' of tensor '{}' is not a valid byte count: '{raw}'",
                    tensor.name
                ))
            }),
        }
    }

    /// Memory-maps the sidecar at `path` unless a mapping stored under the same
    /// path already exists.
    ///
    /// # Errors
    ///
    /// [`OnnxImportError::Io`] if the file cannot be opened or mapped.
    fn ensure_sidecar_mapped(&mut self, path: &Path) -> Result<(), OnnxImportError> {
        if self.sidecars.iter().any(|(p, _)| p == path) {
            return Ok(());
        }
        let file = File::open(path).map_err(|source| OnnxImportError::Io {
            path: path.to_path_buf(),
            source,
        })?;
        // SAFETY: the mapping is read-only and is stored in `self.sidecars`, which
        // is only ever appended to, so it outlives every slice handed out through
        // `&self`. Soundness still relies on no other process truncating or
        // rewriting the sidecar while it is mapped; that cannot be enforced here.
        let mmap = unsafe { Mmap::map(&file) }.map_err(|source| OnnxImportError::Io {
            path: path.to_path_buf(),
            source,
        })?;
        self.sidecars.push((path.to_path_buf(), mmap));
        Ok(())
    }
}
/// Reports whether `tensor` keeps its data in an external file.
///
/// A tensor counts as external when its `data_location` is 1 (`EXTERNAL` in the
/// ONNX `DataLocation` enum) or when it carries any `external_data` entries.
pub fn is_external(tensor: &TensorProto) -> bool {
    if tensor.data_location == 1 {
        return true;
    }
    !tensor.external_data.is_empty()
}
/// Looks up `key` among the `external_data` key/value pairs of `tensor`.
///
/// Returns the value of the first pair whose key equals `key`, or `None` when
/// no pair matches.
pub fn external_entry<'a>(tensor: &'a TensorProto, key: &str) -> Option<&'a str> {
    tensor
        .external_data
        .iter()
        .find_map(|(k, v)| (k == key).then(|| v.as_str()))
}
/// Decodes a little-endian byte buffer into `f32` values.
///
/// `data_type` selects the element encoding: `FLOAT32` (4 bytes per element),
/// `FLOAT16` or `BFLOAT16` (2 bytes per element, widened to `f32`).
/// `tensor_name` is only used in error messages.
///
/// # Errors
///
/// * [`OnnxImportError::UnsupportedDtype`] if `data_type` is none of the three
///   supported codes.
/// * [`OnnxImportError::Other`] if `bytes.len()` is not a multiple of the
///   element width. Such a buffer is truncated or corrupt, and decoding it
///   would silently drop the trailing bytes.
pub fn bytes_to_f32(
    bytes: &[u8],
    data_type: i32,
    tensor_name: &str,
) -> Result<Vec<f32>, OnnxImportError> {
    // Rejects buffers that do not split into whole elements of `width` bytes.
    let ensure_whole_elements = |width: usize| -> Result<(), OnnxImportError> {
        if bytes.len() % width == 0 {
            Ok(())
        } else {
            Err(OnnxImportError::Other(format!(
                "tensor '{tensor_name}' has {} bytes, not a multiple of the {width}-byte \
                 element size of dtype {data_type}",
                bytes.len()
            )))
        }
    };

    match data_type {
        dtype_code::FLOAT32 => {
            ensure_whole_elements(4)?;
            Ok(bytes
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect())
        }
        dtype_code::FLOAT16 => {
            ensure_whole_elements(2)?;
            Ok(bytes
                .chunks_exact(2)
                .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
                .collect())
        }
        dtype_code::BFLOAT16 => {
            ensure_whole_elements(2)?;
            Ok(bytes
                .chunks_exact(2)
                .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
                .collect())
        }
        other => Err(OnnxImportError::UnsupportedDtype {
            tensor: tensor_name.to_string(),
            dtype: other,
        }),
    }
}
/// Returns the integer value of the first attribute named `name`.
///
/// Yields `None` when no attribute has that name, or when the first match does
/// not have attribute type 2 (`INT` in the ONNX `AttributeProto.AttributeType`
/// enum). Later attributes with the same name are not considered.
///
/// `name` is only compared, never stored, so any string slice is accepted.
pub fn attr_int(attrs: &[oxionnx_proto::types::AttributeProto], name: &str) -> Option<i64> {
    attrs
        .iter()
        .find(|attr| attr.name == name)
        .filter(|attr| attr.value.attr_type == 2)
        .map(|attr| attr.value.i)
}
/// Searches for a `config.json` next to the model or in a nearby ancestor directory.
///
/// Up to three directories are probed, nearest first: the directory containing
/// `onnx_path`, its parent, and its grandparent. The first candidate that
/// exists is returned.
///
/// # Errors
///
/// [`OnnxImportError::ConfigJsonMissing`] when none of the probed directories
/// contains a `config.json`.
pub fn locate_config_json(onnx_path: &Path) -> Result<PathBuf, OnnxImportError> {
    onnx_path
        .ancestors()
        // `ancestors` starts with the path itself; the search begins at its parent.
        .skip(1)
        .take(3)
        .map(|dir| dir.join("config.json"))
        .find(|candidate| candidate.exists())
        .ok_or_else(|| OnnxImportError::ConfigJsonMissing {
            onnx_path: onnx_path.to_path_buf(),
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `f32` values survive a trip through their little-endian byte encoding.
    #[test]
    fn bytes_to_f32_f32_roundtrip() {
        let expected = [1.0_f32, -2.5, 0.125];
        let mut raw: Vec<u8> = Vec::new();
        for value in &expected {
            raw.extend_from_slice(&value.to_le_bytes());
        }
        let decoded = bytes_to_f32(&raw, dtype_code::FLOAT32, "t").expect("ok");
        assert_eq!(decoded, expected);
    }

    /// Half-precision values are widened to `f32` within a small tolerance.
    #[test]
    fn bytes_to_f32_f16_roundtrip() {
        let halves = [half::f16::from_f32(1.0), half::f16::from_f32(-0.5)];
        let mut raw: Vec<u8> = Vec::new();
        for value in &halves {
            raw.extend_from_slice(&value.to_le_bytes());
        }
        let decoded = bytes_to_f32(&raw, dtype_code::FLOAT16, "t").expect("ok");
        assert_eq!(decoded.len(), halves.len());
        assert!((decoded[0] - 1.0).abs() < 1e-4);
        assert!((decoded[1] + 0.5).abs() < 1e-4);
    }

    /// A dtype code outside the supported set is reported together with the tensor name.
    #[test]
    fn bytes_to_f32_unsupported_dtype_errors() {
        // 7 is not one of the codes in `dtype_code`.
        let err = bytes_to_f32(&[0, 0, 0, 0], 7, "t").unwrap_err();
        if let OnnxImportError::UnsupportedDtype { tensor, dtype } = &err {
            assert_eq!(tensor, "t");
            assert_eq!(*dtype, 7);
        } else {
            panic!("expected UnsupportedDtype, got {:?}", err);
        }
    }
}