use super::blob_traits::{BlobHash, BlobMetadata, BlobProgress, BlobToken};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A serializable pointer to blob content, suitable for embedding inside a
/// document (e.g. a model-registry entry) instead of the blob bytes themselves.
///
/// Converted to/from the runtime [`BlobToken`] via the `From` impls below.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BlobReference {
    /// Hex-encoded content hash identifying the blob.
    pub hash: String,
    /// Size of the blob payload in bytes.
    pub size_bytes: u64,
    /// Descriptive metadata carried alongside the reference.
    pub metadata: BlobReferenceMetadata,
}
/// Descriptive metadata for a [`BlobReference`].
///
/// Mirrors the fields copied from `BlobMetadata` in the `From` conversions;
/// all fields are optional/empty-able and are omitted from JSON when unset.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct BlobReferenceMetadata {
    /// Human-readable name (e.g. an original filename), if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Content type of the blob (e.g. "application/onnx"), if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_type: Option<String>,
    /// Free-form key/value pairs; defaults to empty on deserialization and
    /// is skipped entirely in JSON output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub custom: HashMap<String, String>,
}
impl From<&BlobToken> for BlobReference {
fn from(token: &BlobToken) -> Self {
Self {
hash: token.hash.as_hex().to_string(),
size_bytes: token.size_bytes,
metadata: BlobReferenceMetadata {
name: token.metadata.name.clone(),
content_type: token.metadata.content_type.clone(),
custom: token.metadata.custom.clone(),
},
}
}
}
impl From<BlobReference> for BlobToken {
fn from(reference: BlobReference) -> Self {
Self {
hash: BlobHash::from_hex(&reference.hash),
size_bytes: reference.size_bytes,
metadata: BlobMetadata {
name: reference.metadata.name,
content_type: reference.metadata.content_type,
custom: reference.metadata.custom,
},
}
}
}
/// Bridges blob storage and document storage: persists [`BlobToken`]s as
/// references on named document fields and resolves them back.
///
/// NOTE(review): the generic `store_and_fetch` method makes this trait not
/// dyn-compatible (it cannot be used as `dyn BlobDocumentIntegration`) —
/// confirm that is intended before adding trait-object usage.
#[async_trait::async_trait]
pub trait BlobDocumentIntegration: Send + Sync {
    /// Stores `token` as the blob reference for `field` on document
    /// `doc_id` within `collection`.
    async fn store_blob_reference(
        &self,
        collection: &str,
        doc_id: &str,
        field: &str,
        token: &BlobToken,
    ) -> Result<()>;
    /// Retrieves the blob reference stored on `field`, or `None` if the
    /// field has no reference.
    async fn get_blob_reference(
        &self,
        collection: &str,
        doc_id: &str,
        field: &str,
    ) -> Result<Option<BlobToken>>;
    /// Removes the blob reference stored on `field`.
    async fn remove_blob_reference(
        &self,
        collection: &str,
        doc_id: &str,
        field: &str,
    ) -> Result<()>;
    /// Lists every blob reference on the document, keyed by field name.
    async fn list_blob_references(
        &self,
        collection: &str,
        doc_id: &str,
    ) -> Result<HashMap<String, BlobToken>>;
    /// Stores the reference and fetches the blob content to a local path,
    /// reporting download progress through `progress`.
    async fn store_and_fetch<F>(
        &self,
        collection: &str,
        doc_id: &str,
        field: &str,
        token: &BlobToken,
        progress: F,
    ) -> Result<std::path::PathBuf>
    where
        F: FnMut(BlobProgress) + Send + 'static;
}
/// One deployable variant of a model (e.g. a specific precision/backend
/// build), pointing at its weights via a [`BlobReference`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelVariantBlob {
    /// Reference to the variant's blob payload (e.g. the ONNX file).
    pub blob: BlobReference,
    /// Numeric precision label, e.g. "float32" (see tests).
    pub precision: String,
    /// Execution providers this variant supports, e.g. "CUDAExecutionProvider".
    pub execution_providers: Vec<String>,
    /// Minimum GPU memory required in GB, if known; omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_gpu_memory_gb: Option<f64>,
}
/// Registry entry describing a single model version and its variants.
///
/// The document id is `"{model_id}:{version}"` (see [`ModelRegistryDocument::doc_id`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelRegistryDocument {
    /// Logical model identifier, e.g. "target_recognition".
    pub model_id: String,
    /// Version string, e.g. "4.2.1".
    pub version: String,
    /// Variants keyed by variant id, e.g. "fp32_cuda".
    pub variants: HashMap<String, ModelVariantBlob>,
    /// Optional signing provenance; omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provenance: Option<ModelProvenance>,
    /// Optional human-readable description; omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
/// Signing provenance for a registry document.
///
/// NOTE(review): signature encoding/algorithm is not visible here — the
/// fields are opaque strings produced elsewhere; verify against the signer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelProvenance {
    /// Identity of the signer.
    pub signed_by: String,
    /// Signature over the document (format defined by the signer).
    pub signature: String,
    /// Timestamp of signing, if recorded; omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub signed_at: Option<String>,
}
impl ModelRegistryDocument {
pub fn new(model_id: &str, version: &str) -> Self {
Self {
model_id: model_id.to_string(),
version: version.to_string(),
variants: HashMap::new(),
provenance: None,
description: None,
}
}
pub fn add_variant(
&mut self,
variant_id: &str,
token: &BlobToken,
precision: &str,
execution_providers: Vec<String>,
min_gpu_memory_gb: Option<f64>,
) {
self.variants.insert(
variant_id.to_string(),
ModelVariantBlob {
blob: BlobReference::from(token),
precision: precision.to_string(),
execution_providers,
min_gpu_memory_gb,
},
);
}
pub fn doc_id(&self) -> String {
format!("{}:{}", self.model_id, self.version)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a token through `BlobReference` JSON and back, checking
    /// hash, size, and name survive.
    #[test]
    fn test_blob_reference_serialization() {
        let original = BlobToken {
            hash: BlobHash::from_hex("abc123def456"),
            size_bytes: 1024 * 1024,
            metadata: BlobMetadata {
                name: Some("test.onnx".to_string()),
                content_type: Some("application/onnx".to_string()),
                custom: HashMap::new(),
            },
        };

        let json = serde_json::to_string_pretty(&BlobReference::from(&original)).unwrap();
        for needle in ["abc123def456", "1048576", "test.onnx"] {
            assert!(json.contains(needle));
        }

        let parsed: BlobReference = serde_json::from_str(&json).unwrap();
        let restored = BlobToken::from(parsed);
        assert_eq!(restored.hash.as_hex(), original.hash.as_hex());
        assert_eq!(restored.size_bytes, original.size_bytes);
        assert_eq!(restored.metadata.name, original.metadata.name);
    }

    /// Builds a registry document with one variant and checks the doc id,
    /// variant key, and serialized content.
    #[test]
    fn test_model_registry_document() {
        let token = BlobToken {
            hash: BlobHash::from_hex("sha256:abc123"),
            size_bytes: 500_000_000,
            metadata: BlobMetadata::with_name("target_recognition_fp32.onnx"),
        };

        let mut doc = ModelRegistryDocument::new("target_recognition", "4.2.1");
        doc.add_variant(
            "fp32_cuda",
            &token,
            "float32",
            vec!["CUDAExecutionProvider".to_string()],
            Some(4.0),
        );

        assert_eq!(doc.doc_id(), "target_recognition:4.2.1");
        assert!(doc.variants.contains_key("fp32_cuda"));

        let serialized = serde_json::to_string_pretty(&doc).unwrap();
        assert!(serialized.contains("target_recognition"));
        assert!(serialized.contains("CUDAExecutionProvider"));
    }

    /// Verifies custom key/value metadata survives serialization.
    #[test]
    fn test_blob_reference_with_custom_metadata() {
        let custom: HashMap<String, String> = [
            ("training_date", "2025-01-15"),
            ("accuracy", "0.95"),
        ]
        .into_iter()
        .map(|(k, v)| (k.to_string(), v.to_string()))
        .collect();

        let token = BlobToken {
            hash: BlobHash::from_hex("deadbeef"),
            size_bytes: 100,
            metadata: BlobMetadata {
                name: Some("model.onnx".to_string()),
                content_type: None,
                custom,
            },
        };

        let json = serde_json::to_string(&BlobReference::from(&token)).unwrap();
        assert!(json.contains("training_date"));
        assert!(json.contains("accuracy"));

        let parsed: BlobReference = serde_json::from_str(&json).unwrap();
        assert_eq!(
            parsed.metadata.custom.get("training_date"),
            Some(&"2025-01-15".to_string())
        );
    }
}