use std::collections::HashMap;
use std::path::{Path, PathBuf};
use prost::Message;
use crate::tensor::DType;
use super::error::{OnnxError, OnnxResult};
use super::proto;
/// A decoded ONNX model together with the directory it was loaded from.
///
/// Values are created by [`OnnxFile::open`], which reads the file, decodes it
/// into a `ModelProto`, and records the parent directory of the path.
pub struct OnnxFile {
/// The decoded top-level ONNX `ModelProto` message.
pub model: proto::ModelProto,
/// Parent directory of the path given to `open` (or `.` when the path has
/// no parent).
// NOTE(review): presumably this is the directory external tensor data is
// resolved against (see `extract_tensor_bytes`' `base_dir` argument) — confirm
// against callers.
pub base_dir: PathBuf,
}
/// Owned snapshot of the descriptive, model-level fields of an ONNX model.
///
/// Every field is a plain owned value, so the snapshot can be cloned and
/// compared for equality independently of the model it was taken from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OnnxMetadata {
    /// Value of the model's `ir_version` field.
    pub ir_version: i64,
    /// Value of the model's `producer_name` field.
    pub producer_name: String,
    /// Value of the model's `producer_version` field.
    pub producer_version: String,
    /// Value of the model's `domain` field.
    pub domain: String,
    /// Value of the model's `model_version` field.
    pub model_version: i64,
    /// Value of the model's `doc_string` field.
    pub doc_string: String,
    /// `(domain, version)` pairs, one per opset import, in model order.
    pub opset_imports: Vec<(String, i64)>,
}
/// Lightweight description of a tensor: its name, shape and element type,
/// without the tensor payload.
///
/// The type is plain owned data, so it can be cloned and compared for
/// equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OnnxTensorInfo {
    /// Tensor name.
    pub name: String,
    /// Tensor dimensions.
    pub dims: Vec<i64>,
    /// Raw ONNX data-type code (see `onnx_dtype_name` for the readable name).
    pub data_type: i32,
    /// Number of elements implied by `dims`.
    pub n_elements: usize,
}
impl OnnxFile {
pub fn open<P: AsRef<Path>>(path: P) -> OnnxResult<Self> {
let path = path.as_ref();
let base_dir = path.parent().unwrap_or(Path::new(".")).to_path_buf();
let data = std::fs::read(path)?;
let model = proto::ModelProto::decode(data.as_slice())?;
Ok(Self { model, base_dir })
}
pub fn metadata(&self) -> OnnxMetadata {
let opset_imports = self
.model
.opset_import
.iter()
.map(|op| (op.domain.clone(), op.version))
.collect();
OnnxMetadata {
ir_version: self.model.ir_version,
producer_name: self.model.producer_name.clone(),
producer_version: self.model.producer_version.clone(),
domain: self.model.domain.clone(),
model_version: self.model.model_version,
doc_string: self.model.doc_string.clone(),
opset_imports,
}
}
pub fn graph(&self) -> OnnxResult<&proto::GraphProto> {
self.model.graph.as_ref().ok_or(OnnxError::MissingGraph)
}
pub fn initializers(&self) -> OnnxResult<HashMap<&str, &proto::TensorProto>> {
let graph = self.graph()?;
let mut map = HashMap::new();
for init in &graph.initializer {
map.insert(init.name.as_str(), init);
}
Ok(map)
}
pub fn tensor_infos(&self) -> OnnxResult<Vec<OnnxTensorInfo>> {
let graph = self.graph()?;
Ok(graph
.initializer
.iter()
.map(|t| {
let n_elements = t.dims.iter().map(|&d| d as usize).product::<usize>();
OnnxTensorInfo {
name: t.name.clone(),
dims: t.dims.clone(),
data_type: t.data_type,
n_elements,
}
})
.collect())
}
pub fn get_initializer(&self, name: &str) -> OnnxResult<&proto::TensorProto> {
let graph = self.graph()?;
graph
.initializer
.iter()
.find(|t| t.name == name)
.ok_or_else(|| OnnxError::MissingTensor(name.to_string()))
}
pub fn nodes(&self) -> OnnxResult<&[proto::NodeProto]> {
Ok(&self.graph()?.node)
}
pub fn inputs(&self) -> OnnxResult<Vec<(&str, Vec<i64>)>> {
let graph = self.graph()?;
Ok(graph
.input
.iter()
.map(|inp| {
let dims = inp
.r#type
.as_ref()
.and_then(|t| t.value.as_ref())
.and_then(|v| {
if let proto::type_proto::Value::TensorType(tt) = v {
tt.shape.as_ref().map(|s| {
s.dim
.iter()
.map(|d| match &d.value {
Some(
proto::tensor_shape_proto::dimension::Value::DimValue(
v,
),
) => *v,
_ => -1,
})
.collect()
})
} else {
None
}
})
.unwrap_or_default();
(inp.name.as_str(), dims)
})
.collect())
}
pub fn outputs(&self) -> OnnxResult<Vec<&str>> {
let graph = self.graph()?;
Ok(graph.output.iter().map(|o| o.name.as_str()).collect())
}
}
/// Returns the payload of `tensor` as little-endian bytes.
///
/// Sources are tried in this order:
/// 1. external data, when `data_location` is 1 (`EXTERNAL`) and
///    `external_data` is non-empty, resolved relative to `base_dir`;
/// 2. `raw_data`, when non-empty, copied as-is;
/// 3. the typed repeated field that matches `data_type`.
///
/// For the typed fields, ONNX stores 8- and 16-bit element types in
/// `int32_data`, one element per entry. Each entry is narrowed to the
/// element width: FLOAT16/BFLOAT16 entries hold the raw bit pattern, and
/// integer entries keep their low bits.
///
/// # Errors
///
/// Returns [`OnnxError::UnsupportedDataType`] when the data has to come from
/// a typed field and `data_type` is not one of FLOAT (1), UINT8 (2),
/// INT8 (3), UINT16 (4), INT16 (5), INT32 (6), INT64 (7), FLOAT16 (10),
/// DOUBLE (11) or BFLOAT16 (16). Errors from reading external data are
/// propagated.
pub fn extract_tensor_bytes(tensor: &proto::TensorProto, base_dir: &Path) -> OnnxResult<Vec<u8>> {
    if tensor.data_location == 1 && !tensor.external_data.is_empty() {
        return extract_external_tensor_bytes(tensor, base_dir);
    }
    if !tensor.raw_data.is_empty() {
        return Ok(tensor.raw_data.clone());
    }

    let bytes = match tensor.data_type {
        // FLOAT
        1 => {
            let mut out = Vec::with_capacity(tensor.float_data.len() * 4);
            for &val in &tensor.float_data {
                out.extend_from_slice(&val.to_le_bytes());
            }
            out
        }
        // UINT8 / INT8: one byte per `int32_data` entry.
        2 | 3 => tensor.int32_data.iter().map(|&v| v as u8).collect(),
        // UINT16 / INT16 / FLOAT16 / BFLOAT16: low 16 bits of each entry.
        4 | 5 | 10 | 16 => {
            let mut out = Vec::with_capacity(tensor.int32_data.len() * 2);
            for &val in &tensor.int32_data {
                out.extend_from_slice(&(val as u16).to_le_bytes());
            }
            out
        }
        // INT32
        6 => {
            let mut out = Vec::with_capacity(tensor.int32_data.len() * 4);
            for &val in &tensor.int32_data {
                out.extend_from_slice(&val.to_le_bytes());
            }
            out
        }
        // INT64
        7 => {
            let mut out = Vec::with_capacity(tensor.int64_data.len() * 8);
            for &val in &tensor.int64_data {
                out.extend_from_slice(&val.to_le_bytes());
            }
            out
        }
        // DOUBLE
        11 => {
            let mut out = Vec::with_capacity(tensor.double_data.len() * 8);
            for &val in &tensor.double_data {
                out.extend_from_slice(&val.to_le_bytes());
            }
            out
        }
        other => return Err(OnnxError::UnsupportedDataType(other)),
    };
    Ok(bytes)
}
/// Reads the payload of an externally stored tensor.
///
/// The `external_data` key/value entries are interpreted as follows:
/// - `location` (required): path of the data file, relative to `base_dir`;
/// - `offset` (optional, default 0): byte position where the data starts;
/// - `length` (optional): number of bytes to read. When absent, everything
///   from `offset` to the end of the file is returned.
///
/// Other keys are ignored.
///
/// # Errors
///
/// Returns [`OnnxError::Other`] when:
/// - `location` is missing, is absolute, or contains `..` (the model file is
///   untrusted input and must not be able to read outside `base_dir`);
/// - `offset` or `length` is not a valid unsigned integer;
/// - `length` is larger than what the file holds after `offset`, or does not
///   fit in memory on this platform;
/// - the file cannot be opened, inspected, positioned or read.
fn extract_external_tensor_bytes(
    tensor: &proto::TensorProto,
    base_dir: &Path,
) -> OnnxResult<Vec<u8>> {
    use std::io::{Read, Seek, SeekFrom};

    // A malformed number must be reported: falling back to 0 would silently
    // return bytes from the wrong place (or no bytes at all).
    let parse_u64 = |field: &str, raw: &str| -> OnnxResult<u64> {
        raw.parse::<u64>().map_err(|e| {
            OnnxError::Other(format!(
                "External tensor '{}' has invalid '{}' value '{}': {}",
                tensor.name, field, raw, e
            ))
        })
    };

    let mut location = None;
    let mut offset: u64 = 0;
    let mut length: Option<u64> = None;
    for kv in &tensor.external_data {
        match kv.key.as_str() {
            "location" => location = Some(kv.value.as_str()),
            "offset" => offset = parse_u64("offset", &kv.value)?,
            "length" => length = Some(parse_u64("length", &kv.value)?),
            _ => {}
        }
    }

    let filename = location.ok_or_else(|| {
        OnnxError::Other(format!(
            "External tensor '{}' missing 'location' field",
            tensor.name
        ))
    })?;

    // `Path::join` with an absolute path discards `base_dir`, and `..` walks
    // out of it, so only plain relative components are accepted.
    let relative = Path::new(filename);
    let stays_inside = !relative.is_absolute()
        && relative.components().all(|c| {
            matches!(
                c,
                std::path::Component::Normal(_) | std::path::Component::CurDir
            )
        });
    if !stays_inside {
        return Err(OnnxError::Other(format!(
            "External tensor '{}' has unsafe 'location' '{}': must be a relative path without '..'",
            tensor.name, filename
        )));
    }
    let file_path = base_dir.join(relative);

    let mut file = std::fs::File::open(&file_path).map_err(|e| {
        OnnxError::Other(format!(
            "Failed to open external data file '{}': {}",
            file_path.display(),
            e
        ))
    })?;
    file.seek(SeekFrom::Start(offset)).map_err(|e| {
        OnnxError::Other(format!(
            "Failed to seek in external data file '{}': {}",
            file_path.display(),
            e
        ))
    })?;

    let data = match length {
        Some(len) => {
            // Check the declared length against the file before allocating,
            // so a corrupt or hostile value cannot request a huge buffer.
            let file_len = file
                .metadata()
                .map_err(|e| {
                    OnnxError::Other(format!(
                        "Failed to inspect external data file '{}': {}",
                        file_path.display(),
                        e
                    ))
                })?
                .len();
            let available = file_len.saturating_sub(offset);
            if len > available {
                return Err(OnnxError::Other(format!(
                    "Failed to read {} bytes from external data file '{}': only {} bytes available at offset {}",
                    len,
                    file_path.display(),
                    available,
                    offset
                )));
            }
            let size = <usize as std::convert::TryFrom<u64>>::try_from(len).map_err(|e| {
                OnnxError::Other(format!(
                    "External tensor '{}' length {} does not fit in memory: {}",
                    tensor.name, len, e
                ))
            })?;
            let mut buf = vec![0u8; size];
            file.read_exact(&mut buf).map_err(|e| {
                OnnxError::Other(format!(
                    "Failed to read {} bytes from external data file '{}': {}",
                    len,
                    file_path.display(),
                    e
                ))
            })?;
            buf
        }
        None => {
            let mut buf = Vec::new();
            file.read_to_end(&mut buf).map_err(|e| {
                OnnxError::Other(format!(
                    "Failed to read external data file '{}': {}",
                    file_path.display(),
                    e
                ))
            })?;
            buf
        }
    };
    Ok(data)
}
/// Converts a raw ONNX data-type code into the crate's [`DType`].
///
/// # Errors
///
/// Returns [`OnnxError::UnsupportedDataType`] carrying the code when it has
/// no `DType` counterpart.
pub fn onnx_dtype_to_dtype(onnx_type: i32) -> OnnxResult<DType> {
    let dtype = match onnx_type {
        1 => DType::F32,
        2 => DType::U8,
        3 => DType::I8,
        5 => DType::I16,
        6 => DType::I32,
        7 => DType::I64,
        10 => DType::F16,
        11 => DType::F64,
        16 => DType::BF16,
        unsupported => return Err(OnnxError::UnsupportedDataType(unsupported)),
    };
    Ok(dtype)
}
/// Derives a tensor-name mapping from the structure of the graph.
///
/// The returned map contains two kinds of entries:
/// - the output of an `Identity` node whose first input is an initializer
///   maps to that initializer's name;
/// - an initializer whose name starts with `onnx::` and that feeds a
///   `MatMul` node maps to the weight name derived from the node's first
///   output by `matmul_output_to_hf_name`. If that output passes through a
///   `Cast` or `Identity` node straight into a graph output, the graph
///   output's name is used for the derivation instead.
///
/// # Errors
///
/// Fails when the model has no graph.
pub fn trace_graph_tensor_names(onnx: &OnnxFile) -> OnnxResult<HashMap<String, String>> {
    let graph = onnx.graph()?;

    let init_names: std::collections::HashSet<&str> = graph
        .initializer
        .iter()
        .map(|init| init.name.as_str())
        .collect();
    let graph_output_names: std::collections::HashSet<&str> = graph
        .output
        .iter()
        .map(|out| out.name.as_str())
        .collect();

    let mut name_map: HashMap<String, String> = HashMap::new();
    // first input of a pass-through node -> graph output it produces
    let mut output_aliases: HashMap<String, String> = HashMap::new();

    for node in &graph.node {
        let (src, dst) = match (node.input.first(), node.output.first()) {
            (Some(src), Some(dst)) => (src, dst),
            _ => continue,
        };
        let is_identity = node.op_type == "Identity";
        if is_identity && init_names.contains(src.as_str()) {
            name_map.insert(dst.clone(), src.clone());
        }
        let is_pass_through = is_identity || node.op_type == "Cast";
        if is_pass_through && graph_output_names.contains(dst.as_str()) {
            output_aliases.insert(src.clone(), dst.clone());
        }
    }

    for node in graph.node.iter().filter(|n| n.op_type == "MatMul") {
        let output_name = match node.output.first() {
            Some(name) => name,
            None => continue,
        };
        // Only the first initializer input is considered, and only when it
        // carries an exporter-generated `onnx::` name.
        let first_init_input = node
            .input
            .iter()
            .find(|inp| init_names.contains(inp.as_str()));
        let weight_name = match first_init_input {
            Some(name) if name.starts_with("onnx::") => name,
            _ => continue,
        };
        let effective_output = output_aliases
            .get(output_name.as_str())
            .map_or(output_name.as_str(), String::as_str);
        if let Some(hf_name) = matmul_output_to_hf_name(effective_output) {
            name_map.insert(weight_name.clone(), hf_name);
        }
    }

    Ok(name_map)
}
/// Derives a dotted `*.weight` name from the name of a `MatMul` output.
///
/// - `"logits"` maps to `"lm_head.weight"`.
/// - A name of the form `/a/b/MatMul…` maps to `"a.b.weight"`: the leading
///   `/` is dropped, everything from the last `/MatMul` onwards is cut off,
///   and the remaining `/` separators become `.`.
/// - Any other name (no leading `/`, or no `/MatMul` after it) yields `None`.
fn matmul_output_to_hf_name(output_name: &str) -> Option<String> {
    match output_name {
        "logits" => Some("lm_head.weight".to_string()),
        other => {
            let trimmed = other.strip_prefix('/')?;
            let cut = trimmed.rfind("/MatMul")?;
            let dotted = trimmed[..cut].replace('/', ".");
            Some(format!("{}.weight", dotted))
        }
    }
}
/// Returns the ONNX `TensorProto.DataType` enum name for a raw type code.
///
/// All codes from 0 (`UNDEFINED`) through 23 (`FLOAT4E2M1`) are covered,
/// including the complex, 8-bit float and 4-bit types. Any other value
/// yields `"UNKNOWN"`.
pub fn onnx_dtype_name(onnx_type: i32) -> &'static str {
    match onnx_type {
        0 => "UNDEFINED",
        1 => "FLOAT",
        2 => "UINT8",
        3 => "INT8",
        4 => "UINT16",
        5 => "INT16",
        6 => "INT32",
        7 => "INT64",
        8 => "STRING",
        9 => "BOOL",
        10 => "FLOAT16",
        11 => "DOUBLE",
        12 => "UINT32",
        13 => "UINT64",
        14 => "COMPLEX64",
        15 => "COMPLEX128",
        16 => "BFLOAT16",
        17 => "FLOAT8E4M3FN",
        18 => "FLOAT8E4M3FNUZ",
        19 => "FLOAT8E5M2",
        20 => "FLOAT8E5M2FNUZ",
        21 => "UINT4",
        22 => "INT4",
        23 => "FLOAT4E2M1",
        _ => "UNKNOWN",
    }
}