use crate::fx::{FxGraph, Node};
use petgraph::visit::EdgeRef;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use torsh_core::error::{Result, TorshError};
/// One model-zoo record: metadata, serialized graph, weights, and optional
/// training/evaluation details.
///
/// The serde derives make the whole record (de)serializable; `save_to_file`
/// and `load_from_file` store it as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelZooEntry {
/// Descriptive metadata. `metadata.id` is the key the registry indexes by and
/// the stem of the `<id>.json` file the registry writes.
pub metadata: ModelMetadata,
/// Serializable snapshot of the computation graph (built from an `FxGraph`
/// by `ModelZooEntry::new`).
pub graph: SerializedGraph,
/// Weight description and payload location.
pub weights: ModelWeights,
/// Training setup, if recorded; `ModelZooEntry::new` leaves this `None`.
pub training_config: Option<TrainingConfig>,
/// Evaluation metrics; `ModelZooEntry::new` starts from `ModelMetrics::default()`.
pub metrics: ModelMetrics,
/// Provenance details; `ModelZooEntry::new` starts from `ModelProvenance::default()`.
pub provenance: ModelProvenance,
}
/// Descriptive metadata for a model-zoo entry.
///
/// This is the part of an entry that `ModelZooRegistry` keeps in its in-memory
/// index. `ModelMetadataBuilder` is the in-file way to construct one.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
/// Identifier: used as the registry index key and as the stem of the
/// `<id>.json` file the registry reads and writes.
pub id: String,
/// Model name (free-form as far as this file is concerned).
pub name: String,
/// Version string; the builder defaults it to `"1.0.0"`.
pub version: String,
/// Author; the builder defaults it to `"unknown"`.
pub author: String,
/// Free-form description; the builder defaults it to an empty string.
pub description: String,
/// License label; the builder defaults it to `"MIT"`.
pub license: String,
/// Tags; `search_by_tags` matches an entry when any requested tag is present here.
pub tags: Vec<String>,
/// Task label; `search_by_task` compares it with exact string equality.
pub task: String,
/// One shape (list of dimensions) per model input, in the order they were added.
pub input_shapes: Vec<Vec<usize>>,
/// One shape (list of dimensions) per model output, in the order they were added.
pub output_shapes: Vec<Vec<usize>>,
/// The builder fills this with the crate's `CARGO_PKG_VERSION`.
pub framework_version: String,
/// The builder sets this to the current UTC time formatted as RFC 3339.
pub created_at: String,
/// The builder sets this to the same timestamp as `created_at`.
pub updated_at: String,
/// The builder sets this to 0; nothing in this file updates it.
/// NOTE(review): presumably the serialized model size — confirm with the producer.
pub size_bytes: u64,
/// Value `verify_integrity` compares against the lowercase-hex MD5 of the
/// entry's JSON. The builder leaves it empty.
pub checksum: String,
}
/// Serializable form of an `FxGraph`, as produced by
/// `ModelZooEntry::serialize_graph`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedGraph {
/// One record per graph node, in the order the source graph yielded them.
pub nodes: Vec<SerializedNode>,
/// Directed edges as `(source index, target index)` pairs.
pub edges: Vec<(usize, usize)>,
/// Indices of the graph's input nodes.
pub inputs: Vec<usize>,
/// Indices of the graph's output nodes.
pub outputs: Vec<usize>,
/// Free-form string metadata; `serialize_graph` leaves it empty.
pub graph_metadata: HashMap<String, String>,
}
/// Serializable description of a single graph node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedNode {
/// The node's index in the source graph (`idx.index()`); edges, inputs and
/// outputs refer to nodes by this number.
pub id: usize,
/// Kind tag produced by `node_type_string`: `input:<name>`, `output`,
/// `call:<name>`, `getattr:<target>:<attr>`, `conditional`, `loop` or `merge`.
pub node_type: String,
/// `serialize_graph` generates this as `node_<id>`.
pub name: String,
/// String parameters; `serialize_graph` leaves this empty.
pub params: HashMap<String, String>,
/// String metadata; `serialize_graph` leaves this empty.
pub metadata: HashMap<String, String>,
}
/// Description of a model's weights and where their payload lives.
///
/// Nothing in this file reads these fields; they are stored and serialized
/// as given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelWeights {
/// Declared storage format of the payload.
pub format: WeightFormat,
/// Where the payload is: embedded, an external path, or a remote URL.
pub data: WeightData,
/// Shape per entry. NOTE(review): presumably keyed by parameter/tensor name — confirm.
pub shapes: HashMap<String, Vec<usize>>,
/// Dtype label per entry. NOTE(review): presumably keyed like `shapes` — confirm.
pub dtypes: HashMap<String, String>,
/// Declared total parameter count (not computed or checked in this file).
pub total_params: u64,
/// Declared trainable parameter count (not computed or checked in this file).
pub trainable_params: u64,
}
/// Tag naming the storage format of a weight payload.
///
/// The variants are only recorded and serialized here; this file does not
/// parse any of the formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightFormat {
/// Payload declared as SafeTensors.
SafeTensors,
/// Payload declared as NumPy.
Numpy,
/// Payload declared as PyTorch.
PyTorch,
/// Payload declared as ONNX.
Onnx,
/// Any other format, identified by `format_name`.
Custom { format_name: String },
}
/// Location of a weight payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightData {
/// Payload carried inline as a string.
/// NOTE(review): the text encoding (e.g. base64) is not established in this file — confirm.
Embedded { data: String },
/// Payload stored at `path`.
/// NOTE(review): whether the path is absolute or relative to the registry is not shown here.
External { path: String },
/// Payload to be fetched from `url`, with an accompanying `checksum` string
/// (algorithm not specified in this file).
Remote { url: String, checksum: String },
}
/// Record of how a model was trained.
///
/// Purely descriptive: this file stores and serializes the values but never
/// interprets them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Optimizer label.
pub optimizer: String,
/// Learning rate.
pub learning_rate: f64,
/// Batch size.
pub batch_size: usize,
/// Number of epochs.
pub epochs: usize,
/// Loss function label.
pub loss_function: String,
/// Dataset the model was trained on.
pub dataset: DatasetInfo,
/// Augmentation labels, in the order given.
pub augmentations: Vec<String>,
/// Additional hyperparameters as string key/value pairs.
pub hyperparameters: HashMap<String, String>,
}
/// Descriptive summary of a dataset referenced by `TrainingConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetInfo {
/// Dataset name.
pub name: String,
/// Split label. NOTE(review): presumably values like "train"/"val" — not constrained here.
pub split: String,
/// Number of samples.
pub num_samples: usize,
/// Number of classes, when recorded.
pub num_classes: Option<usize>,
}
/// Evaluation metrics attached to a model-zoo entry.
///
/// Every scalar is optional and both maps may be empty; the `Default` impl in
/// this file produces exactly that all-empty state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetrics {
/// Accuracy, when recorded.
pub accuracy: Option<f64>,
/// Top-k accuracy values. NOTE(review): presumably keyed by k — confirm.
pub top_k_accuracy: HashMap<usize, f64>,
/// Loss, when recorded.
pub loss: Option<f64>,
/// F1 score, when recorded.
pub f1_score: Option<f64>,
/// Precision, when recorded.
pub precision: Option<f64>,
/// Recall, when recorded.
pub recall: Option<f64>,
/// Latency, when recorded; the field name indicates milliseconds.
pub latency_ms: Option<f64>,
/// Throughput, when recorded. NOTE(review): unit is not specified in this file.
pub throughput: Option<f64>,
/// Any further metrics, keyed by name.
pub custom_metrics: HashMap<String, f64>,
}
/// Provenance information for a model-zoo entry.
///
/// The `Default` impl in this file sets `training_dataset` and
/// `training_hardware` to `"unknown"`, `training_framework` to `"torsh-fx"`,
/// and every optional field to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProvenance {
/// Identifier of a base model, when recorded.
pub base_model: Option<String>,
/// Training dataset label.
pub training_dataset: String,
/// Training framework label.
pub training_framework: String,
/// Training hardware label.
pub training_hardware: String,
/// Training duration as free-form text, when recorded.
pub training_duration: Option<String>,
/// Random seed, when recorded.
pub random_seed: Option<u64>,
/// Code repository reference, when recorded.
pub code_repository: Option<String>,
/// Paper citation, when recorded.
pub paper_citation: Option<String>,
}
/// Directory-backed registry of model-zoo entries.
///
/// Entries live as `<id>.json` files directly under `base_path`; an in-memory
/// index of their metadata is rebuilt by `refresh_index`.
pub struct ModelZooRegistry {
// Directory that holds the `<id>.json` entry files.
base_path: PathBuf,
// Metadata of every indexed entry, keyed by `ModelMetadata::id`.
model_index: HashMap<String, ModelMetadata>,
// Repositories added through `add_remote_repository`, in insertion order.
remote_repos: Vec<RemoteRepository>,
}
/// Description of a remote model repository.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RemoteRepository {
/// Name that `download_from_remote` uses to look the repository up.
pub name: String,
/// Base URL; `download_from_remote` appends `/<model_id>` to it (without
/// doubling a trailing slash).
pub url: String,
/// Optional authentication token; not read anywhere in this file.
pub auth_token: Option<String>,
/// Mirror URLs; not read anywhere in this file.
pub mirrors: Vec<String>,
}
impl ModelZooEntry {
    /// Builds an entry from `metadata`, a live `FxGraph`, and `weights`.
    ///
    /// The graph is converted to its serializable form immediately.
    /// `training_config` starts as `None`; `metrics` and `provenance` start at
    /// their `Default` values.
    ///
    /// # Errors
    /// Propagates any error from graph serialization.
    pub fn new(metadata: ModelMetadata, graph: FxGraph, weights: ModelWeights) -> Result<Self> {
        let serialized_graph = Self::serialize_graph(&graph)?;
        Ok(Self {
            metadata,
            graph: serialized_graph,
            weights,
            training_config: None,
            metrics: ModelMetrics::default(),
            provenance: ModelProvenance::default(),
        })
    }

    /// Converts `graph` into a [`SerializedGraph`].
    ///
    /// Each node is recorded under its graph index with a generated name
    /// `node_<index>` and empty `params` / `metadata`; each edge becomes a
    /// `(source index, target index)` pair.
    fn serialize_graph(graph: &FxGraph) -> Result<SerializedGraph> {
        let mut nodes = Vec::new();
        for (idx, node) in graph.nodes() {
            nodes.push(SerializedNode {
                id: idx.index(),
                node_type: Self::node_type_string(node),
                name: format!("node_{}", idx.index()),
                params: HashMap::new(),
                metadata: HashMap::new(),
            });
        }

        let mut edges = Vec::new();
        for edge in graph.edges() {
            edges.push((edge.source().index(), edge.target().index()));
        }

        Ok(SerializedGraph {
            nodes,
            edges,
            inputs: graph.inputs().iter().map(|idx| idx.index()).collect(),
            outputs: graph.outputs().iter().map(|idx| idx.index()).collect(),
            graph_metadata: HashMap::new(),
        })
    }

    /// Returns the textual kind tag stored in `SerializedNode::node_type`.
    fn node_type_string(node: &Node) -> String {
        match node {
            Node::Input(name) => format!("input:{}", name),
            Node::Output => "output".to_string(),
            Node::Call(name, _) => format!("call:{}", name),
            Node::GetAttr { target, attr } => format!("getattr:{}:{}", target, attr),
            Node::Conditional { .. } => "conditional".to_string(),
            Node::Loop { .. } => "loop".to_string(),
            Node::Merge { .. } => "merge".to_string(),
        }
    }

    /// Writes the entry to `path` as pretty-printed JSON.
    ///
    /// # Errors
    /// `TorshError::SerializationError` if JSON encoding fails,
    /// `TorshError::IoError` if the file cannot be written.
    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let json = serde_json::to_string_pretty(self)
            .map_err(|e| TorshError::SerializationError(e.to_string()))?;
        fs::write(path, json).map_err(|e| TorshError::IoError(e.to_string()))
    }

    /// Reads an entry back from a JSON file at `path`.
    ///
    /// # Errors
    /// `TorshError::IoError` if the file cannot be read,
    /// `TorshError::SerializationError` if it is not a valid entry.
    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let json = fs::read_to_string(path).map_err(|e| TorshError::IoError(e.to_string()))?;
        serde_json::from_str(&json).map_err(|e| TorshError::SerializationError(e.to_string()))
    }

    /// Recomputes the checksum of the current contents and stores it in
    /// `metadata.checksum`, so that a later [`Self::verify_integrity`] call on
    /// unchanged data succeeds.
    ///
    /// # Errors
    /// `TorshError::SerializationError` if the entry cannot be serialized.
    pub fn update_checksum(&mut self) -> Result<()> {
        self.metadata.checksum = self.compute_checksum()?;
        Ok(())
    }

    /// Returns `Ok(true)` when the stored checksum matches the current
    /// contents and every index in the graph refers to an existing node;
    /// `Ok(false)` otherwise.
    ///
    /// # Errors
    /// `TorshError::SerializationError` if the entry cannot be serialized for
    /// hashing.
    pub fn verify_integrity(&self) -> Result<bool> {
        if self.compute_checksum()? != self.metadata.checksum {
            return Ok(false);
        }
        self.verify_graph_structure()
    }

    /// Computes the lowercase-hex MD5 digest of the entry's canonical JSON.
    ///
    /// Two things make the digest reproducible:
    /// * `metadata.checksum` is blanked before hashing. If it were included,
    ///   the stored value would have to be a hash of itself and verification
    ///   could never succeed.
    /// * Serialization goes through `serde_json::Value`, so map keys are
    ///   emitted in `serde_json::Map` order instead of `HashMap` iteration
    ///   order, which differs between map instances.
    ///
    /// NOTE(review): `serde_json::Map` is sorted only with serde_json's
    /// default features; if the crate enables `preserve_order`, key order
    /// falls back to insertion order — confirm the feature set.
    /// NOTE(review): MD5 detects accidental corruption only; it is not
    /// tamper-resistant.
    fn compute_checksum(&self) -> Result<String> {
        let mut canonical = serde_json::to_value(self)
            .map_err(|e| TorshError::SerializationError(e.to_string()))?;
        if let Some(stored) = canonical.pointer_mut("/metadata/checksum") {
            *stored = serde_json::Value::String(String::new());
        }
        Ok(format!(
            "{:x}",
            md5::compute(canonical.to_string().as_bytes())
        ))
    }

    /// Checks that every edge endpoint, input and output index is smaller than
    /// the number of serialized nodes.
    ///
    /// NOTE(review): this treats node ids as dense positions `0..nodes.len()`;
    /// confirm that the source graph never yields sparse indices.
    fn verify_graph_structure(&self) -> Result<bool> {
        let node_count = self.graph.nodes.len();
        let in_range = |idx: &usize| *idx < node_count;

        let edges_ok = self
            .graph
            .edges
            .iter()
            .all(|(src, dst)| in_range(src) && in_range(dst));

        Ok(edges_ok
            && self.graph.inputs.iter().all(in_range)
            && self.graph.outputs.iter().all(in_range))
    }
}
impl ModelZooRegistry {
    /// Opens (creating if necessary) a registry rooted at `base_path` and
    /// indexes the entries already stored there.
    ///
    /// # Errors
    /// `TorshError::IoError` if the directory cannot be created or listed.
    pub fn new<P: AsRef<Path>>(base_path: P) -> Result<Self> {
        let base_path = base_path.as_ref().to_path_buf();
        fs::create_dir_all(&base_path).map_err(|e| TorshError::IoError(e.to_string()))?;
        let mut registry = Self {
            base_path,
            model_index: HashMap::new(),
            remote_repos: Vec::new(),
        };
        registry.refresh_index()?;
        Ok(registry)
    }

    /// Rebuilds the in-memory index from the `*.json` files directly under
    /// the registry directory.
    ///
    /// Files that cannot be read or parsed as an entry are skipped
    /// (best-effort). The new index is built aside and only swapped in once
    /// the whole directory has been scanned, so a failure while listing the
    /// directory leaves the previous index intact.
    ///
    /// # Errors
    /// `TorshError::IoError` if the directory cannot be listed.
    pub fn refresh_index(&mut self) -> Result<()> {
        let mut fresh_index = HashMap::new();
        let listing =
            fs::read_dir(&self.base_path).map_err(|e| TorshError::IoError(e.to_string()))?;
        for dir_entry in listing {
            let path = dir_entry
                .map_err(|e| TorshError::IoError(e.to_string()))?
                .path();
            if path.extension().and_then(|s| s.to_str()) != Some("json") {
                continue;
            }
            // Unreadable or malformed files are deliberately ignored.
            if let Ok(model) = ModelZooEntry::load_from_file(&path) {
                fresh_index.insert(model.metadata.id.clone(), model.metadata);
            }
        }
        self.model_index = fresh_index;
        Ok(())
    }

    /// Maps a model id to its `<id>.json` path inside the registry directory.
    ///
    /// The id must be a single, non-empty file name. An id containing a path
    /// separator, or one that `Path` parses as absolute or prefixed, would
    /// make `join` resolve outside `base_path`. An empty id would produce
    /// `.json`, which has no extension and is never picked up by
    /// `refresh_index`.
    fn model_file_path(&self, model_id: &str) -> Result<PathBuf> {
        let filename = format!("{}.json", model_id);
        let mut parts = Path::new(&filename).components();
        let is_plain_file_name = !model_id.is_empty()
            && !model_id.chars().any(std::path::is_separator)
            && matches!(
                (parts.next(), parts.next()),
                (Some(std::path::Component::Normal(_)), None)
            );
        if !is_plain_file_name {
            return Err(TorshError::RuntimeError(format!(
                "Invalid model id {:?}: must be a non-empty file name without path separators",
                model_id
            )));
        }
        Ok(self.base_path.join(filename))
    }

    /// Saves `entry` as `<id>.json` in the registry directory and adds its
    /// metadata to the index, replacing any entry with the same id.
    ///
    /// # Errors
    /// `TorshError::RuntimeError` if the id is not a valid file name, or any
    /// error from writing the file.
    pub fn register_model(&mut self, entry: ModelZooEntry) -> Result<()> {
        let path = self.model_file_path(&entry.metadata.id)?;
        entry.save_to_file(&path)?;
        self.model_index
            .insert(entry.metadata.id.clone(), entry.metadata);
        Ok(())
    }

    /// Loads the full entry stored as `<model_id>.json`.
    ///
    /// # Errors
    /// `TorshError::RuntimeError` if the id is not a valid file name, or any
    /// error from reading/parsing the file.
    pub fn load_model(&self, model_id: &str) -> Result<ModelZooEntry> {
        ModelZooEntry::load_from_file(self.model_file_path(model_id)?)
    }

    /// Returns the metadata of every indexed model that carries at least one
    /// of `tags`. An empty `tags` slice matches nothing.
    pub fn search_by_tags(&self, tags: &[String]) -> Vec<&ModelMetadata> {
        self.model_index
            .values()
            .filter(|metadata| tags.iter().any(|tag| metadata.tags.contains(tag)))
            .collect()
    }

    /// Returns the metadata of every indexed model whose `task` equals `task`
    /// exactly.
    pub fn search_by_task(&self, task: &str) -> Vec<&ModelMetadata> {
        self.model_index
            .values()
            .filter(|metadata| metadata.task == task)
            .collect()
    }

    /// Returns the metadata of every indexed model, in unspecified order.
    pub fn list_models(&self) -> Vec<&ModelMetadata> {
        self.model_index.values().collect()
    }

    /// Appends `repo` to the list of known remote repositories.
    pub fn add_remote_repository(&mut self, repo: RemoteRepository) {
        self.remote_repos.push(repo);
    }

    /// Placeholder for remote downloads: no network access is performed.
    ///
    /// Looks up `repo_name`, derives the download URL and a suggested
    /// temp-file location, and reports both in the returned error.
    ///
    /// # Errors
    /// Always returns `TorshError::RuntimeError`: either because the
    /// repository is unknown, or with instructions for a manual download.
    pub fn download_from_remote(&mut self, model_id: &str, repo_name: &str) -> Result<()> {
        let repository = self
            .remote_repos
            .iter()
            .find(|r| r.name == repo_name)
            .ok_or_else(|| {
                TorshError::RuntimeError(format!("Repository {} not found", repo_name))
            })?;
        // Avoid a doubled slash when the base URL already ends with one.
        let download_url = if repository.url.ends_with('/') {
            format!("{}{}", repository.url, model_id)
        } else {
            format!("{}/{}", repository.url, model_id)
        };
        let temp_dir = std::env::temp_dir();
        let model_file = temp_dir.join(format!("{}_{}.torsh", repo_name, model_id));
        Err(TorshError::RuntimeError(format!(
            "Model download requires manual implementation. \
             Download URL: {} \
             Save to: {} \
             Then load using ModelZooRegistry::load_model()",
            download_url,
            model_file.display()
        )))
    }
}
impl Default for ModelMetrics {
    /// Returns metrics with nothing recorded: every optional scalar is `None`
    /// and both maps are empty.
    fn default() -> Self {
        let top_k_accuracy = HashMap::new();
        let custom_metrics = HashMap::new();
        Self {
            accuracy: Option::None,
            top_k_accuracy,
            loss: Option::None,
            f1_score: Option::None,
            precision: Option::None,
            recall: Option::None,
            latency_ms: Option::None,
            throughput: Option::None,
            custom_metrics,
        }
    }
}
impl Default for ModelProvenance {
    /// Returns placeholder provenance: dataset and hardware are `"unknown"`,
    /// the framework is `"torsh-fx"`, and every optional field is `None`.
    fn default() -> Self {
        let training_dataset = String::from("unknown");
        let training_framework = String::from("torsh-fx");
        let training_hardware = String::from("unknown");
        Self {
            base_model: Option::None,
            training_dataset,
            training_framework,
            training_hardware,
            training_duration: Option::None,
            random_seed: Option::None,
            code_repository: Option::None,
            paper_citation: Option::None,
        }
    }
}
/// Chainable builder for `ModelMetadata`.
///
/// `new` takes the id and name and fills the remaining fields with defaults;
/// each setter consumes and returns the builder; `build` produces the final
/// metadata.
pub struct ModelMetadataBuilder {
// Supplied to `new`.
id: String,
// Supplied to `new`.
name: String,
// Defaults to "1.0.0".
version: String,
// Defaults to "unknown".
author: String,
// Defaults to an empty string.
description: String,
// Defaults to "MIT".
license: String,
// Starts empty; grown by `add_tag`.
tags: Vec<String>,
// Defaults to "general".
task: String,
// Starts empty; grown by `add_input_shape`.
input_shapes: Vec<Vec<usize>>,
// Starts empty; grown by `add_output_shape`.
output_shapes: Vec<Vec<usize>>,
}
impl ModelMetadataBuilder {
    /// Starts a builder for the model `id` / `name`.
    ///
    /// Defaults: version `"1.0.0"`, author `"unknown"`, empty description,
    /// license `"MIT"`, task `"general"`, no tags and no shapes.
    pub fn new(id: String, name: String) -> Self {
        Self {
            id,
            name,
            version: String::from("1.0.0"),
            author: String::from("unknown"),
            description: String::new(),
            license: String::from("MIT"),
            tags: Vec::new(),
            task: String::from("general"),
            input_shapes: Vec::new(),
            output_shapes: Vec::new(),
        }
    }

    /// Replaces the version string.
    pub fn version(self, version: String) -> Self {
        Self { version, ..self }
    }

    /// Replaces the author.
    pub fn author(self, author: String) -> Self {
        Self { author, ..self }
    }

    /// Replaces the description.
    pub fn description(self, description: String) -> Self {
        Self {
            description,
            ..self
        }
    }

    /// Replaces the license label.
    pub fn license(self, license: String) -> Self {
        Self { license, ..self }
    }

    /// Appends one tag; call repeatedly to add several.
    pub fn add_tag(mut self, tag: String) -> Self {
        self.tags.push(tag);
        self
    }

    /// Replaces the task label.
    pub fn task(self, task: String) -> Self {
        Self { task, ..self }
    }

    /// Appends the shape of one more model input.
    pub fn add_input_shape(mut self, shape: Vec<usize>) -> Self {
        self.input_shapes.push(shape);
        self
    }

    /// Appends the shape of one more model output.
    pub fn add_output_shape(mut self, shape: Vec<usize>) -> Self {
        self.output_shapes.push(shape);
        self
    }

    /// Finishes the builder.
    ///
    /// `created_at` and `updated_at` both receive the current UTC time in
    /// RFC 3339 form, `framework_version` is this crate's package version,
    /// `size_bytes` is 0 and `checksum` is empty.
    pub fn build(self) -> ModelMetadata {
        let Self {
            id,
            name,
            version,
            author,
            description,
            license,
            tags,
            task,
            input_shapes,
            output_shapes,
        } = self;

        // One timestamp shared by both fields so a fresh record is consistent.
        let created_at = chrono::Utc::now().to_rfc3339();
        let updated_at = created_at.clone();

        ModelMetadata {
            id,
            name,
            version,
            author,
            description,
            license,
            tags,
            task,
            input_shapes,
            output_shapes,
            framework_version: String::from(env!("CARGO_PKG_VERSION")),
            created_at,
            updated_at,
            size_bytes: 0,
            checksum: String::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Values passed through the builder's setters show up in the built metadata.
    #[test]
    fn test_metadata_builder() {
        let builder =
            ModelMetadataBuilder::new("test-model-001".to_string(), "Test Model".to_string());
        let built = builder
            .version("1.0.0".to_string())
            .author("Test Author".to_string())
            .description("A test model".to_string())
            .add_tag("test".to_string())
            .add_tag("demo".to_string())
            .task("classification".to_string())
            .add_input_shape(vec![1, 3, 224, 224])
            .add_output_shape(vec![1, 1000])
            .build();

        assert_eq!(built.id, "test-model-001");
        assert_eq!(built.name, "Test Model");
        assert_eq!(built.version, "1.0.0");
        assert_eq!(built.tags.len(), 2);
    }

    /// An entry can be created from default metadata, an empty graph and
    /// empty embedded weights.
    #[test]
    fn test_model_zoo_entry_creation() {
        let empty_weights = ModelWeights {
            format: WeightFormat::SafeTensors,
            data: WeightData::Embedded {
                data: String::new(),
            },
            shapes: HashMap::new(),
            dtypes: HashMap::new(),
            total_params: 0,
            trainable_params: 0,
        };
        let meta = ModelMetadataBuilder::new("test-001".to_string(), "Test".to_string()).build();
        let empty_graph = FxGraph::new();

        let outcome = ModelZooEntry::new(meta, empty_graph, empty_weights);
        assert!(outcome.is_ok());
    }
}