use std::collections::HashMap;
/// A serialized machine-learning model format that this module knows about.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ModelFormat {
    /// Open Neural Network Exchange (`.onnx`).
    Onnx,
    /// PyTorch native serialization (`.pth`, `.pt`).
    PyTorch,
    /// TensorFlow protobuf / SavedModel (`.pb`).
    TensorFlow,
    /// Keras HDF5 (`.h5`).
    Keras,
    /// TensorFlow Lite flatbuffer (`.tflite`).
    TensorFlowLite,
    /// Apple Core ML (`.mlmodel`).
    CoreML,
    /// Caffe weights (`.caffemodel`).
    Caffe,
    /// Apache MXNet parameters (`.params`).
    MXNet,
}
/// A user-defined model format identified only by its name.
///
/// NOTE(review): nothing in this part of the file constructs or consumes this
/// type — `ExtendedModelFormat::Custom` stores a bare `String` instead. Confirm
/// whether it is used elsewhere or should be wired in.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CustomFormat {
    /// Display/lookup name of the custom format.
    pub name: String,
}
/// Either one of the built-in [`ModelFormat`]s or a custom format named by a string.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ExtendedModelFormat {
    /// A format from the built-in [`ModelFormat`] list.
    Standard(ModelFormat),
    /// Any other format, identified by its name.
    Custom(String),
}
impl ModelFormat {
    /// Every [`ModelFormat`] variant, in declaration order.
    ///
    /// Useful for iterating over all formats, e.g. to build lookup tables.
    pub const ALL: [ModelFormat; 8] = [
        ModelFormat::Onnx,
        ModelFormat::PyTorch,
        ModelFormat::TensorFlow,
        ModelFormat::Keras,
        ModelFormat::TensorFlowLite,
        ModelFormat::CoreML,
        ModelFormat::Caffe,
        ModelFormat::MXNet,
    ];

    /// Returns the format whose canonical extensions (see
    /// [`extensions`](Self::extensions)) include `ext`, or `None` if no format
    /// claims it.
    ///
    /// `ext` is the bare extension without a leading dot (`"onnx"`, not
    /// `".onnx"`). Matching is ASCII case-insensitive, so `"ONNX"` and `"Pt"`
    /// are accepted.
    pub fn from_extension(ext: &str) -> Option<Self> {
        // Derived from `extensions()` so the forward and reverse lookups can never
        // drift apart; `eq_ignore_ascii_case` avoids allocating a lowercased copy.
        Self::ALL.iter().copied().find(|format| {
            format
                .extensions()
                .iter()
                .any(|known| known.eq_ignore_ascii_case(ext))
        })
    }

    /// Canonical file extensions (without the leading dot) for this format.
    ///
    /// The slice is `'static`, so it may outlive the `ModelFormat` it came from.
    pub fn extensions(&self) -> &'static [&'static str] {
        match self {
            ModelFormat::Onnx => &["onnx"],
            ModelFormat::PyTorch => &["pth", "pt"],
            ModelFormat::TensorFlow => &["pb"],
            ModelFormat::Keras => &["h5"],
            ModelFormat::TensorFlowLite => &["tflite"],
            ModelFormat::CoreML => &["mlmodel"],
            ModelFormat::Caffe => &["caffemodel"],
            ModelFormat::MXNet => &["params"],
        }
    }

    /// Reports whether this format natively supports `feature`.
    ///
    /// Only ONNX, PyTorch and TensorFlow Lite advertise any features; every
    /// other (format, feature) pair yields `false`.
    pub fn supports_feature(&self, feature: FormatFeature) -> bool {
        matches!(
            (self, feature),
            (
                ModelFormat::Onnx,
                FormatFeature::GraphStructure
                    | FormatFeature::Metadata
                    | FormatFeature::Quantization
            ) | (
                ModelFormat::PyTorch,
                FormatFeature::DynamicShapes | FormatFeature::StateDict
            ) | (
                ModelFormat::TensorFlowLite,
                FormatFeature::Quantization | FormatFeature::MobileOptimized
            )
        )
    }

    /// One-line human-readable description of the format.
    ///
    /// The string is `'static`, so it may outlive the `ModelFormat` it came from.
    pub fn description(&self) -> &'static str {
        match self {
            ModelFormat::Onnx => "Open Neural Network Exchange format",
            ModelFormat::PyTorch => "PyTorch native format",
            ModelFormat::TensorFlow => "TensorFlow SavedModel format",
            ModelFormat::Keras => "Keras HDF5 format",
            ModelFormat::TensorFlowLite => "TensorFlow Lite format",
            ModelFormat::CoreML => "Apple CoreML format",
            ModelFormat::Caffe => "Caffe model format",
            ModelFormat::MXNet => "Apache MXNet format",
        }
    }
}
impl std::fmt::Display for ModelFormat {
    /// Writes the user-facing format name (`"ONNX"`, `"TensorFlow Lite"`, ...).
    ///
    /// Uses [`Formatter::pad`](std::fmt::Formatter::pad) so width, fill,
    /// alignment and precision flags (`{:>16}`, `{:.4}`) are honoured exactly as
    /// they are for `str`. Plain `{}` / `to_string()` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ModelFormat::Onnx => "ONNX",
            ModelFormat::PyTorch => "PyTorch",
            ModelFormat::TensorFlow => "TensorFlow",
            ModelFormat::Keras => "Keras",
            ModelFormat::TensorFlowLite => "TensorFlow Lite",
            ModelFormat::CoreML => "CoreML",
            ModelFormat::Caffe => "Caffe",
            ModelFormat::MXNet => "MXNet",
        };
        // `write!(f, "{}", name)` would start a fresh format spec and silently
        // drop the caller's padding/alignment options.
        f.pad(name)
    }
}
/// A capability a [`ModelFormat`] may offer, queried via
/// [`ModelFormat::supports_feature`].
///
/// NOTE(review): `CustomOperators`, `TrainingMode` and `InferenceOnly` are not
/// reported as supported by any format in `supports_feature`; confirm whether
/// they are consumed elsewhere.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FormatFeature {
    /// Stores an explicit computation graph.
    GraphStructure,
    /// Carries model metadata.
    Metadata,
    /// Can represent quantized weights.
    Quantization,
    /// Allows input shapes that vary at run time.
    DynamicShapes,
    /// Can be saved/loaded as a parameter state dictionary.
    StateDict,
    /// Tailored for mobile deployment.
    MobileOptimized,
    /// Allows user-defined operators.
    CustomOperators,
    /// Usable for training.
    TrainingMode,
    /// Usable for inference only.
    InferenceOnly,
}
/// Compression strategy requested by an optimization profile.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompressionType {
    /// Leave the model uncompressed.
    None,
    /// 8-bit integer representation (quantization).
    Int8,
    /// 16-bit integer representation (quantization).
    Int16,
    /// 16-bit floating-point representation.
    Float16,
    /// Pruned model (presumably weight pruning — confirm with the optimizer).
    Pruned,
    /// Distilled model (presumably knowledge distillation — confirm with the optimizer).
    Distilled,
}
/// Hardware or runtime environment a model is being prepared for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeploymentTarget {
    /// Server-class CPU.
    ServerCpu,
    /// Server-class GPU.
    ServerGpu,
    /// Mobile-device CPU.
    MobileCpu,
    /// Mobile-device GPU.
    MobileGpu,
    /// Edge device.
    Edge,
    /// WebAssembly runtime.
    WebAssembly,
    /// Embedded device.
    Embedded,
}
/// Directed compatibility scores for converting between [`ModelFormat`]s.
///
/// Scores lie in `0.0..=1.0`. Every format scores `1.0` against itself, and any
/// pair without a recorded score is treated as `0.0`. Scores are directional —
/// e.g. PyTorch → ONNX (0.9) differs from ONNX → PyTorch (0.8).
#[derive(Debug, Clone)]
pub struct FormatCompatibility {
    compatibility_map: HashMap<(ModelFormat, ModelFormat), f32>,
}

impl FormatCompatibility {
    /// A hop counts as usable in [`get_conversion_path`](Self::get_conversion_path)
    /// only when its score is strictly greater than this value.
    const USABLE_HOP_THRESHOLD: f32 = 0.5;

    /// Builds the built-in compatibility table.
    ///
    /// Every format, not only the commonly converted ones, gets an identity
    /// score of `1.0`. This keeps `get_compatibility(f, f)` consistent with
    /// `get_conversion_path(f, f)` returning the trivial path `[f]`.
    pub fn new() -> Self {
        const ALL_FORMATS: [ModelFormat; 8] = [
            ModelFormat::Onnx,
            ModelFormat::PyTorch,
            ModelFormat::TensorFlow,
            ModelFormat::Keras,
            ModelFormat::TensorFlowLite,
            ModelFormat::CoreML,
            ModelFormat::Caffe,
            ModelFormat::MXNet,
        ];
        // (from, to, score): conversion scores are not symmetric.
        const DIRECTED_SCORES: [(ModelFormat, ModelFormat, f32); 8] = [
            (ModelFormat::PyTorch, ModelFormat::Onnx, 0.9),
            (ModelFormat::Onnx, ModelFormat::PyTorch, 0.8),
            (ModelFormat::TensorFlow, ModelFormat::Onnx, 0.8),
            (ModelFormat::Keras, ModelFormat::TensorFlow, 0.95),
            (ModelFormat::PyTorch, ModelFormat::TensorFlow, 0.6),
            (ModelFormat::Onnx, ModelFormat::TensorFlowLite, 0.7),
            (ModelFormat::Caffe, ModelFormat::Onnx, 0.5),
            (ModelFormat::MXNet, ModelFormat::Onnx, 0.5),
        ];

        let identity = ALL_FORMATS
            .iter()
            .map(|&format| ((format, format), 1.0_f32));
        let directed = DIRECTED_SCORES
            .iter()
            .map(|&(from, to, score)| ((from, to), score));
        let compatibility_map: HashMap<(ModelFormat, ModelFormat), f32> =
            identity.chain(directed).collect();

        Self { compatibility_map }
    }

    /// Returns the directed score for converting `from` into `to`, or `0.0` when
    /// no score is recorded for that pair.
    pub fn get_compatibility(&self, from: ModelFormat, to: ModelFormat) -> f32 {
        self.compatibility_map
            .get(&(from, to))
            .copied()
            .unwrap_or(0.0)
    }

    /// Suggests a sequence of formats for converting `from` into `to`.
    ///
    /// In order of preference:
    /// 1. `[from]` when the formats are equal;
    /// 2. `[from, to]` when the direct hop is usable;
    /// 3. `[from, Onnx, to]` when both hops through ONNX are usable;
    /// 4. otherwise `[from, to]` as a fallback.
    ///
    /// A hop is usable when its score exceeds 0.5. The fallback in step 4 does
    /// **not** mean the conversion is supported. Callers that care should check
    /// [`get_compatibility`](Self::get_compatibility) for each hop.
    pub fn get_conversion_path(&self, from: ModelFormat, to: ModelFormat) -> Vec<ModelFormat> {
        let usable = |a: ModelFormat, b: ModelFormat| {
            self.get_compatibility(a, b) > Self::USABLE_HOP_THRESHOLD
        };

        if from == to {
            return vec![from];
        }
        if usable(from, to) {
            return vec![from, to];
        }

        // ONNX is the only pivot format considered; it is never used as a pivot
        // when it is already an endpoint.
        let pivot = ModelFormat::Onnx;
        if from != pivot && to != pivot && usable(from, pivot) && usable(pivot, to) {
            return vec![from, pivot, to];
        }

        vec![from, to]
    }
}

impl Default for FormatCompatibility {
    /// Same as [`FormatCompatibility::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Deployment constraints to optimize a model against.
#[derive(Clone, Debug)]
pub struct OptimizationProfile {
    /// Human-readable profile name (e.g. `"Mobile"`).
    pub name: String,
    /// Environment the model is being prepared for.
    pub target: DeploymentTarget,
    /// Compression strategy to apply.
    pub compression: CompressionType,
    /// Upper bound on model size, `None` for unbounded. The presets express it
    /// in bytes (`n * 1024 * 1024`), unlike `memory_limit_mb`.
    pub max_model_size: Option<usize>,
    /// Desired latency in milliseconds, if any.
    pub target_latency_ms: Option<f32>,
    /// Memory budget in megabytes, `None` for unbounded.
    pub memory_limit_mb: Option<usize>,
    /// Accuracy level to preserve (presumably a fraction of baseline accuracy;
    /// the presets use 0.90–0.99).
    pub preserve_accuracy: f32,
}

impl OptimizationProfile {
    /// Bytes per mebibyte; the presets state their size caps in these units.
    const BYTES_PER_MIB: usize = 1024 * 1024;

    /// GPU-server preset: no compression, no size or memory cap, 100 ms latency
    /// target, `preserve_accuracy` of 0.99.
    pub fn server() -> Self {
        Self {
            name: String::from("Server"),
            target: DeploymentTarget::ServerGpu,
            compression: CompressionType::None,
            max_model_size: None,
            target_latency_ms: Some(100.0),
            memory_limit_mb: None,
            preserve_accuracy: 0.99,
        }
    }

    /// Mobile-CPU preset: Int8, 50 MiB size cap, 50 ms latency, 100 MB memory,
    /// `preserve_accuracy` of 0.95.
    pub fn mobile() -> Self {
        Self {
            name: String::from("Mobile"),
            target: DeploymentTarget::MobileCpu,
            compression: CompressionType::Int8,
            max_model_size: Some(50 * Self::BYTES_PER_MIB),
            target_latency_ms: Some(50.0),
            memory_limit_mb: Some(100),
            preserve_accuracy: 0.95,
        }
    }

    /// Edge preset: Int8, 10 MiB size cap, 20 ms latency, 50 MB memory,
    /// `preserve_accuracy` of 0.90.
    pub fn edge() -> Self {
        Self {
            name: String::from("Edge"),
            target: DeploymentTarget::Edge,
            compression: CompressionType::Int8,
            max_model_size: Some(10 * Self::BYTES_PER_MIB),
            target_latency_ms: Some(20.0),
            memory_limit_mb: Some(50),
            preserve_accuracy: 0.90,
        }
    }

    /// WebAssembly preset: Float16, 5 MiB size cap, 100 ms latency, 100 MB
    /// memory, `preserve_accuracy` of 0.95.
    pub fn web() -> Self {
        Self {
            name: String::from("Web"),
            target: DeploymentTarget::WebAssembly,
            compression: CompressionType::Float16,
            max_model_size: Some(5 * Self::BYTES_PER_MIB),
            target_latency_ms: Some(100.0),
            memory_limit_mb: Some(100),
            preserve_accuracy: 0.95,
        }
    }
}
/// Quick magic-byte heuristics for recognising serialized model files.
///
/// Only the first few bytes are inspected; these are sniffing heuristics, not
/// parsers, so a positive result is a hint rather than a guarantee.
#[derive(Debug, Clone, Copy)]
pub struct FormatValidator;

impl FormatValidator {
    /// Confidence reported for PyTorch when only the `PK` (zip) header matches.
    /// Every zip archive starts with `PK`, so this is weaker evidence than the
    /// `0x80` header (scored 0.9).
    const PYTORCH_ZIP_CONFIDENCE: f32 = 0.6;

    /// Returns `true` if `data` looks like a file of `expected_format`.
    ///
    /// Only ONNX, PyTorch and TensorFlow have signature checks; every other
    /// format is reported as not recognised (`false`).
    pub fn validate_format(data: &[u8], expected_format: ModelFormat) -> bool {
        match expected_format {
            ModelFormat::Onnx => Self::validate_onnx(data),
            ModelFormat::PyTorch => Self::validate_pytorch(data),
            ModelFormat::TensorFlow => Self::validate_tensorflow(data),
            // Listed explicitly instead of `_` so a newly added variant forces a
            // decision here rather than silently failing validation.
            ModelFormat::Keras
            | ModelFormat::TensorFlowLite
            | ModelFormat::CoreML
            | ModelFormat::Caffe
            | ModelFormat::MXNet => false,
        }
    }

    /// ONNX heuristic: more than 8 bytes with a leading `0x08` byte.
    fn validate_onnx(data: &[u8]) -> bool {
        data.len() > 8 && data.starts_with(&[0x08])
    }

    /// PyTorch heuristic: either the `0x80` header or a `PK` (zip) header.
    fn validate_pytorch(data: &[u8]) -> bool {
        Self::has_pickle_header(data) || Self::has_zip_header(data)
    }

    /// More than 2 bytes starting with `0x80` (presumably the pickle `PROTO`
    /// opcode — confirm).
    fn has_pickle_header(data: &[u8]) -> bool {
        data.len() > 2 && data[0] == 0x80
    }

    /// More than 2 bytes starting with the ASCII bytes `PK`.
    fn has_zip_header(data: &[u8]) -> bool {
        data.len() > 2 && data.starts_with(b"PK")
    }

    /// TensorFlow heuristic: more than 10 bytes starting with `0x08 0x01` or with
    /// the ASCII text `TensorFlow`.
    fn validate_tensorflow(data: &[u8]) -> bool {
        data.len() > 10 && (data.starts_with(b"\x08\x01") || data.starts_with(b"TensorFlow"))
    }

    /// Scores how strongly `data` resembles each recognised format.
    ///
    /// Only formats with a positive signal appear in the map. Any format for
    /// which [`validate_format`](Self::validate_format) returns `true` gets a
    /// score. Several formats can match at once: data beginning with `0x08 0x01`
    /// and longer than 10 bytes passes both the ONNX and TensorFlow checks.
    pub fn get_format_confidence(data: &[u8]) -> HashMap<ModelFormat, f32> {
        let mut scores = HashMap::new();
        if Self::validate_onnx(data) {
            scores.insert(ModelFormat::Onnx, 0.8);
        }
        // Shares the predicates used by `validate_pytorch`, so zip-style PyTorch
        // files (accepted by `validate_format`) are no longer left unscored.
        if Self::has_pickle_header(data) {
            scores.insert(ModelFormat::PyTorch, 0.9);
        } else if Self::has_zip_header(data) {
            scores.insert(ModelFormat::PyTorch, Self::PYTORCH_ZIP_CONFIDENCE);
        }
        if Self::validate_tensorflow(data) {
            scores.insert(ModelFormat::TensorFlow, 0.7);
        }
        scores
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_from_extension() {
        let cases = [
            ("onnx", Some(ModelFormat::Onnx)),
            ("pth", Some(ModelFormat::PyTorch)),
            ("pt", Some(ModelFormat::PyTorch)),
            ("h5", Some(ModelFormat::Keras)),
            ("unknown", None),
        ];
        for &(ext, expected) in &cases {
            assert_eq!(
                ModelFormat::from_extension(ext),
                expected,
                "unexpected format for extension {:?}",
                ext
            );
        }
    }

    #[test]
    fn test_format_features() {
        let onnx = ModelFormat::Onnx;
        assert!(onnx.supports_feature(FormatFeature::GraphStructure));
        assert!(!onnx.supports_feature(FormatFeature::StateDict));
        assert!(ModelFormat::PyTorch.supports_feature(FormatFeature::DynamicShapes));
    }

    #[test]
    fn test_compatibility_matrix() {
        let compat = FormatCompatibility::new();
        let onnx_to_onnx = compat.get_compatibility(ModelFormat::Onnx, ModelFormat::Onnx);
        let pytorch_to_onnx = compat.get_compatibility(ModelFormat::PyTorch, ModelFormat::Onnx);
        let onnx_to_caffe = compat.get_compatibility(ModelFormat::Onnx, ModelFormat::Caffe);
        assert_eq!(onnx_to_onnx, 1.0);
        assert!(pytorch_to_onnx > 0.8);
        assert_eq!(onnx_to_caffe, 0.0);
    }

    #[test]
    fn test_conversion_path() {
        let compat = FormatCompatibility::new();
        assert_eq!(
            compat.get_conversion_path(ModelFormat::PyTorch, ModelFormat::Onnx),
            vec![ModelFormat::PyTorch, ModelFormat::Onnx]
        );
        let to_tflite =
            compat.get_conversion_path(ModelFormat::PyTorch, ModelFormat::TensorFlowLite);
        assert!(to_tflite.contains(&ModelFormat::Onnx));
    }

    #[test]
    fn test_optimization_profiles() {
        let mobile = OptimizationProfile::mobile();
        let server = OptimizationProfile::server();

        assert_eq!(mobile.target, DeploymentTarget::MobileCpu);
        assert_eq!(mobile.compression, CompressionType::Int8);
        assert!(mobile.max_model_size.is_some());

        assert_eq!(server.target, DeploymentTarget::ServerGpu);
        assert_eq!(server.compression, CompressionType::None);
        assert!(server.max_model_size.is_none());
    }

    #[test]
    fn test_format_validation() {
        let onnx_data: [u8; 9] = [0x08, 0x01, 0x12, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00];
        let pytorch_data: [u8; 3] = [0x80, 0x02, 0x00];
        let invalid_data: [u8; 2] = [0x00, 0x00];

        assert!(FormatValidator::validate_format(&onnx_data, ModelFormat::Onnx));
        assert!(FormatValidator::validate_format(&pytorch_data, ModelFormat::PyTorch));
        assert!(!FormatValidator::validate_format(&invalid_data, ModelFormat::Onnx));
    }
}