pub mod types;
pub use types::*;
use crate::{Result, TensorError};
use std::collections::HashMap;
use std::path::Path;
#[cfg(feature = "onnx")]
use prost::Message;
#[cfg(feature = "onnx")]
use prost_types::Any;
/// Converts ONNX model data (a file path or a byte buffer) into [`OnnxModel`] values.
///
/// NOTE(review): the visible import methods are placeholders that return an empty graph
/// rather than decoding real ONNX protobuf data.
pub struct OnnxImporter {
/// Import settings; `opset_version` is copied into the opset import of returned models.
config: OnnxConfig,
// Registry of per-op mapping handlers, presumably keyed by ONNX op type (confirm).
// Both constructors leave it empty and nothing in the visible code reads it, which is
// why `dead_code` is allowed here until the real importer consumes it.
#[allow(dead_code)]
op_mappings: HashMap<String, Box<dyn OnnxOpMapping>>,
}
/// Front end for exporting [`OnnxModel`] values to a file or to a byte buffer.
///
/// NOTE(review): the visible export methods are placeholders; they do not serialize the
/// model to ONNX protobuf.
pub struct OnnxExporter {
/// Export settings; `opset_version` is reported when exporting to a file.
config: OnnxConfig,
}
impl OnnxImporter {
    /// Creates an importer with `OnnxConfig::default()` and no registered op mappings.
    pub fn new() -> Self {
        Self::with_config(OnnxConfig::default())
    }

    /// Creates an importer with the given configuration and no registered op mappings.
    pub fn with_config(config: OnnxConfig) -> Self {
        Self {
            config,
            op_mappings: HashMap::new(),
        }
    }

    /// Imports an ONNX model from the file at `path`.
    ///
    /// NOTE(review): placeholder implementation. The file is never opened or parsed, so
    /// this succeeds even when `path` does not exist. It returns an empty graph and prints
    /// the path to stdout.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub fn import_from_file<P: AsRef<Path>>(&self, path: P) -> Result<OnnxModel> {
        let model = self.placeholder_model();
        println!("ONNX model imported from: {:?}", path.as_ref());
        Ok(model)
    }

    /// Imports an ONNX model from an in-memory byte buffer.
    ///
    /// NOTE(review): placeholder implementation. `bytes` is only checked for emptiness;
    /// its contents are not decoded. On success it returns an empty graph and prints the
    /// buffer length to stdout.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument [`TensorError`] when `bytes` is empty.
    pub fn import_from_bytes(&self, bytes: &[u8]) -> Result<OnnxModel> {
        if bytes.is_empty() {
            return Err(TensorError::invalid_argument("Empty ONNX data".to_string()));
        }
        let model = self.placeholder_model();
        println!("ONNX model imported from {} bytes", bytes.len());
        Ok(model)
    }

    /// Builds the empty stand-in model returned by every import entry point.
    ///
    /// Shared so the file and byte paths cannot drift apart. The only configuration it
    /// reads is `opset_version`, which becomes the default-domain opset import.
    fn placeholder_model(&self) -> OnnxModel {
        OnnxModel {
            graph: OnnxGraph {
                nodes: vec![],
                inputs: vec![],
                outputs: vec![],
                initializers: vec![],
                value_info: vec![],
                name: "imported_model".to_string(),
            },
            metadata: OnnxModelMetadata {
                description: "Imported ONNX model".to_string(),
                domain: "ai.onnx".to_string(),
                model_version: 1,
                metadata_props: HashMap::new(),
            },
            // Empty domain string = the default ONNX operator set.
            opset_imports: vec![OnnxOpsetImport {
                domain: "".to_string(),
                version: self.config.opset_version,
            }],
            producer_name: "TenfloweRS".to_string(),
            producer_version: "0.1.1".to_string(),
        }
    }
}
impl OnnxExporter {
pub fn new() -> Self {
Self {
config: OnnxConfig::default(),
}
}
pub fn with_config(config: OnnxConfig) -> Self {
Self { config }
}
pub fn export_to_file<P: AsRef<Path>>(&self, model: &OnnxModel, path: P) -> Result<()> {
println!(
"Exporting ONNX model '{}' to: {:?}",
model.graph.name,
path.as_ref()
);
println!("Model has {} nodes", model.graph.nodes.len());
println!("Opset version: {}", self.config.opset_version);
Ok(())
}
pub fn export_to_bytes(&self, model: &OnnxModel) -> Result<Vec<u8>> {
println!("Exporting ONNX model '{}' to bytes", model.graph.name);
println!("Model has {} nodes", model.graph.nodes.len());
Ok(vec![0u8; 1024]) }
}
impl Default for OnnxImporter {
fn default() -> Self {
Self::new()
}
}
impl Default for OnnxExporter {
fn default() -> Self {
Self::new()
}
}