use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use tokio::fs;
use uuid::Uuid;
/// Abstraction over artifact persistence backends (e.g. filesystem, in-memory).
///
/// Implementations must be shareable across tasks (`Send + Sync`).
#[async_trait]
pub trait ModelStorage: Send + Sync {
    /// Persists the given artifacts and returns their ids in input order.
    async fn store_artifacts(&self, artifacts: &[Artifact]) -> Result<Vec<Uuid>>;
    /// Fetches a single artifact by id; `Ok(None)` when it does not exist.
    async fn get_artifact(&self, artifact_id: Uuid) -> Result<Option<Artifact>>;
    /// Deletes the listed artifacts; ids that are already gone are not an
    /// error in the provided implementations.
    async fn delete_artifacts(&self, artifact_ids: &[Uuid]) -> Result<()>;
    /// Moves a version's artifacts out of live storage into an archive area.
    async fn archive_version(&self, version_id: Uuid) -> Result<()>;
    /// Permanently removes a version's artifacts.
    async fn delete_version(&self, version_id: Uuid) -> Result<()>;
    /// Lists the artifacts associated with a version.
    async fn list_artifacts(&self, version_id: Uuid) -> Result<Vec<Artifact>>;
}
/// Category of file stored alongside a model version.
///
/// Drives the default file extension and fallback MIME type (see the
/// inherent impl and `Artifact::detect_mime_type`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ArtifactType {
    /// Model weights.
    Model,
    /// Model configuration.
    Config,
    /// Tokenizer definition.
    Tokenizer,
    /// Vocabulary file.
    Vocabulary,
    /// Training checkpoint.
    Checkpoint,
    /// Optimizer state for resuming training.
    OptimizerState,
    /// Architecture description.
    Architecture,
    /// Preprocessing configuration.
    Preprocessing,
    /// Evaluation/metrics output.
    Metrics,
    /// Human-readable documentation.
    Documentation,
    /// Escape hatch for user-defined artifact kinds.
    Custom(String),
}
impl ArtifactType {
    /// File extension conventionally used when this artifact kind is
    /// written to disk.
    pub fn default_extension(&self) -> &'static str {
        // Arms grouped by resulting extension rather than listed per variant.
        match self {
            ArtifactType::Config
            | ArtifactType::Tokenizer
            | ArtifactType::Architecture
            | ArtifactType::Preprocessing
            | ArtifactType::Metrics => "json",
            ArtifactType::Vocabulary => "txt",
            ArtifactType::Checkpoint => "ckpt",
            ArtifactType::Documentation => "md",
            ArtifactType::Model | ArtifactType::OptimizerState | ArtifactType::Custom(_) => "bin",
        }
    }

    /// True for the artifact kinds a deployment cannot run without
    /// (the model weights and its configuration).
    pub fn is_required_for_deployment(&self) -> bool {
        matches!(self, ArtifactType::Config | ArtifactType::Model)
    }
}
/// A single stored file belonging to a model version, carrying both its
/// raw bytes and descriptive metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Artifact {
    /// Randomly generated v4 UUID (assigned in `Artifact::new`).
    pub id: Uuid,
    /// Category of the file (weights, config, ...).
    pub artifact_type: ArtifactType,
    /// Path the artifact was created from; its extension drives MIME detection.
    pub file_path: PathBuf,
    /// Length of `content` in bytes.
    pub size_bytes: u64,
    /// Lowercase-hex SHA-256 of `content` (see `compute_hash`).
    pub content_hash: String,
    /// MIME type guessed from the path extension or the artifact type.
    pub mime_type: String,
    /// Raw file bytes.
    pub content: Vec<u8>,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Free-form extra attributes.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl Artifact {
    /// Builds an artifact from in-memory bytes, assigning a fresh v4 UUID,
    /// hashing the content, and guessing a MIME type from the path
    /// extension (falling back to the artifact type).
    pub fn new(artifact_type: ArtifactType, file_path: PathBuf, content: Vec<u8>) -> Self {
        let size_bytes = content.len() as u64;
        let content_hash = Self::compute_hash(&content);
        let mime_type = Self::detect_mime_type(&file_path, &artifact_type);
        Self {
            id: Uuid::new_v4(),
            artifact_type,
            file_path,
            size_bytes,
            content_hash,
            mime_type,
            content,
            created_at: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    /// Reads `file_path` asynchronously and wraps the bytes in a new artifact.
    pub async fn from_file(artifact_type: ArtifactType, file_path: PathBuf) -> Result<Self> {
        let bytes = fs::read(&file_path).await?;
        Ok(Self::new(artifact_type, file_path, bytes))
    }

    /// Builder-style helper that attaches a single metadata entry.
    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// SHA-256 of the raw bytes, rendered as lowercase hex.
    fn compute_hash(content: &[u8]) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(content);
        format!("{:x}", hasher.finalize())
    }

    /// Guesses a MIME type: the file extension wins; when there is no
    /// extension, fall back to a default per artifact type.
    fn detect_mime_type(file_path: &Path, artifact_type: &ArtifactType) -> String {
        let by_extension = file_path
            .extension()
            .and_then(|s| s.to_str())
            .map(|ext| match ext.to_lowercase().as_str() {
                "json" => "application/json",
                "bin" | "pt" | "pth" => "application/octet-stream",
                "txt" => "text/plain",
                "md" => "text/markdown",
                "yaml" | "yml" => "application/x-yaml",
                _ => "application/octet-stream",
            });
        match by_extension {
            Some(mime) => mime.to_string(),
            None => match artifact_type {
                ArtifactType::Config
                | ArtifactType::Tokenizer
                | ArtifactType::Architecture
                | ArtifactType::Preprocessing
                | ArtifactType::Metrics => "application/json".to_string(),
                ArtifactType::Documentation => "text/markdown".to_string(),
                ArtifactType::Vocabulary => "text/plain".to_string(),
                _ => "application/octet-stream".to_string(),
            },
        }
    }

    /// True when the stored hash still matches the current content bytes.
    pub fn verify_integrity(&self) -> bool {
        self.content_hash == Self::compute_hash(&self.content)
    }

    /// Extension of `file_path`, if it has a UTF-8 one.
    pub fn file_extension(&self) -> Option<&str> {
        self.file_path.extension().and_then(|ext| ext.to_str())
    }
}
/// Disk-backed [`ModelStorage`] that shards artifacts under `base_path`
/// and keeps a read-through in-memory metadata cache.
pub struct FileSystemStorage {
    /// Root directory for live artifacts.
    base_path: PathBuf,
    /// Destination for archived artifacts (`base_path/archive`, see `new`).
    archive_path: PathBuf,
    /// Artifacts stored or loaded during this process's lifetime; not
    /// repopulated from disk on startup.
    metadata_cache: tokio::sync::RwLock<HashMap<Uuid, Artifact>>,
}
impl FileSystemStorage {
    /// Creates a storage rooted at `base_path`; archived artifacts live
    /// under `base_path/archive`. Call [`Self::initialize`] before first use
    /// to create the directories.
    pub fn new(base_path: PathBuf) -> Self {
        let archive_path = base_path.join("archive");
        Self {
            base_path,
            archive_path,
            metadata_cache: tokio::sync::RwLock::new(HashMap::new()),
        }
    }

    /// Creates the base and archive directories if they are missing.
    pub async fn initialize(&self) -> Result<()> {
        fs::create_dir_all(&self.base_path).await?;
        fs::create_dir_all(&self.archive_path).await?;
        Ok(())
    }

    /// Extensionless path for a live artifact, sharded by the first two
    /// characters of its UUID to keep directory fan-out bounded.
    fn get_artifact_path(&self, artifact_id: Uuid) -> PathBuf {
        let id_str = artifact_id.to_string();
        // UUID strings are ASCII hex-and-dashes, so a byte slice is safe.
        let prefix = &id_str[0..2];
        self.base_path.join("artifacts").join(prefix).join(&id_str)
    }

    /// Archive-side counterpart of [`Self::get_artifact_path`].
    fn get_archive_path(&self, artifact_id: Uuid) -> PathBuf {
        let id_str = artifact_id.to_string();
        let prefix = &id_str[0..2];
        self.archive_path.join("artifacts").join(prefix).join(&id_str)
    }

    /// Writes the `.meta` JSON sidecar and refreshes the in-memory cache.
    ///
    /// The artifact's raw bytes are stripped before serialization: the
    /// `.bin` file written by `store_artifacts` is the authoritative copy,
    /// and serializing `content` here would duplicate the entire payload
    /// as a JSON byte array inside the sidecar.
    async fn store_metadata(&self, artifact: &Artifact) -> Result<()> {
        let metadata_path = self.get_artifact_path(artifact.id).with_extension("meta");
        if let Some(parent) = metadata_path.parent() {
            fs::create_dir_all(parent).await?;
        }
        // Content-free copy for the sidecar; `load_metadata` restores the
        // bytes from the `.bin` file.
        let mut meta_only = artifact.clone();
        meta_only.content = Vec::new();
        let metadata_json = serde_json::to_string_pretty(&meta_only)?;
        fs::write(metadata_path, metadata_json).await?;
        // The cache keeps the full artifact (including content) for fast reads.
        self.metadata_cache.write().await.insert(artifact.id, artifact.clone());
        Ok(())
    }

    /// Loads an artifact (metadata plus content), consulting the cache
    /// first. Returns `Ok(None)` when no `.meta` sidecar exists.
    async fn load_metadata(&self, artifact_id: Uuid) -> Result<Option<Artifact>> {
        if let Some(artifact) = self.metadata_cache.read().await.get(&artifact_id) {
            return Ok(Some(artifact.clone()));
        }
        let metadata_path = self.get_artifact_path(artifact_id).with_extension("meta");
        // Async existence probe; `Path::exists()` would issue a blocking
        // syscall on the executor thread. Like `exists()`, any stat error
        // is treated as "not present".
        if fs::metadata(&metadata_path).await.is_err() {
            return Ok(None);
        }
        let metadata_json = fs::read_to_string(metadata_path).await?;
        let mut artifact: Artifact = serde_json::from_str(&metadata_json)?;
        let content_path = self.get_artifact_path(artifact_id).with_extension("bin");
        if fs::metadata(&content_path).await.is_ok() {
            artifact.content = fs::read(content_path).await?;
        }
        self.metadata_cache.write().await.insert(artifact_id, artifact.clone());
        Ok(Some(artifact))
    }
}
#[async_trait]
impl ModelStorage for FileSystemStorage {
    /// Writes each artifact's raw bytes to `<shard>/<id>.bin`, then its
    /// JSON sidecar via `store_metadata`, returning the ids in input order.
    async fn store_artifacts(&self, artifacts: &[Artifact]) -> Result<Vec<Uuid>> {
        let mut artifact_ids = Vec::new();
        for artifact in artifacts {
            let content_path = self.get_artifact_path(artifact.id).with_extension("bin");
            if let Some(parent) = content_path.parent() {
                fs::create_dir_all(parent).await?;
            }
            fs::write(&content_path, &artifact.content).await?;
            self.store_metadata(artifact).await?;
            artifact_ids.push(artifact.id);
            tracing::debug!("Stored artifact {} at {:?}", artifact.id, content_path);
        }
        Ok(artifact_ids)
    }

    /// Fetches one artifact (cache first, then disk); `Ok(None)` if unknown.
    async fn get_artifact(&self, artifact_id: Uuid) -> Result<Option<Artifact>> {
        self.load_metadata(artifact_id).await
    }

    /// Removes the `.bin` and `.meta` files (when present) and evicts the
    /// cache entry for every id. Already-missing files are not an error.
    async fn delete_artifacts(&self, artifact_ids: &[Uuid]) -> Result<()> {
        for &artifact_id in artifact_ids {
            let content_path = self.get_artifact_path(artifact_id).with_extension("bin");
            let metadata_path = self.get_artifact_path(artifact_id).with_extension("meta");
            // NOTE(review): `Path::exists()` is a blocking syscall inside an
            // async fn — consider an async probe (e.g. `tokio::fs::try_exists`).
            if content_path.exists() {
                fs::remove_file(content_path).await?;
            }
            if metadata_path.exists() {
                fs::remove_file(metadata_path).await?;
            }
            self.metadata_cache.write().await.remove(&artifact_id);
            tracing::debug!("Deleted artifact {}", artifact_id);
        }
        Ok(())
    }

    /// Moves a version's artifacts from the live tree to the archive tree
    /// via `fs::rename`, evicting each from the cache.
    ///
    /// NOTE(review): `rename` fails across filesystems; this relies on
    /// `archive_path` living under `base_path` (see `new`) — confirm if the
    /// archive location ever becomes configurable.
    async fn archive_version(&self, version_id: Uuid) -> Result<()> {
        let artifacts = self.list_artifacts(version_id).await?;
        for artifact in artifacts {
            let src_content = self.get_artifact_path(artifact.id).with_extension("bin");
            let src_metadata = self.get_artifact_path(artifact.id).with_extension("meta");
            let dst_content = self.get_archive_path(artifact.id).with_extension("bin");
            let dst_metadata = self.get_archive_path(artifact.id).with_extension("meta");
            if let Some(parent) = dst_content.parent() {
                fs::create_dir_all(parent).await?;
            }
            if src_content.exists() {
                fs::rename(src_content, dst_content).await?;
            }
            if src_metadata.exists() {
                fs::rename(src_metadata, dst_metadata).await?;
            }
            self.metadata_cache.write().await.remove(&artifact.id);
        }
        tracing::info!("Archived version {}", version_id);
        Ok(())
    }

    /// Deletes every artifact reported by `list_artifacts` for this version.
    async fn delete_version(&self, version_id: Uuid) -> Result<()> {
        let artifacts = self.list_artifacts(version_id).await?;
        let artifact_ids: Vec<Uuid> = artifacts.iter().map(|a| a.id).collect();
        self.delete_artifacts(&artifact_ids).await?;
        tracing::info!("Deleted version {}", version_id);
        Ok(())
    }

    /// Lists artifacts currently held in the in-memory metadata cache.
    ///
    /// NOTE(review): `_version_id` is ignored and nothing is scanned from
    /// disk — after a process restart the cache is empty, so this (and the
    /// `archive_version`/`delete_version` methods built on it) sees no
    /// artifacts until they are re-stored or re-loaded. Confirm whether a
    /// disk scan / version index is intended here.
    async fn list_artifacts(&self, _version_id: Uuid) -> Result<Vec<Artifact>> {
        let cache = self.metadata_cache.read().await;
        Ok(cache.values().cloned().collect())
    }
}
/// Volatile [`ModelStorage`] backed by two in-process maps; all data is
/// lost on drop. Useful for tests.
pub struct InMemoryStorage {
    /// Live artifacts keyed by id.
    artifacts: tokio::sync::RwLock<HashMap<Uuid, Artifact>>,
    /// Artifacts moved out of the live set by `archive_version`.
    archived: tokio::sync::RwLock<HashMap<Uuid, Artifact>>,
}
impl InMemoryStorage {
pub fn new() -> Self {
Self {
artifacts: tokio::sync::RwLock::new(HashMap::new()),
archived: tokio::sync::RwLock::new(HashMap::new()),
}
}
}
#[async_trait]
impl ModelStorage for InMemoryStorage {
    /// Inserts clones of the artifacts into the live map, returning their
    /// ids in input order.
    async fn store_artifacts(&self, artifacts: &[Artifact]) -> Result<Vec<Uuid>> {
        let mut live = self.artifacts.write().await;
        let ids = artifacts
            .iter()
            .map(|artifact| {
                live.insert(artifact.id, artifact.clone());
                artifact.id
            })
            .collect();
        Ok(ids)
    }

    /// Returns a clone of the artifact, or `None` when the id is unknown.
    async fn get_artifact(&self, artifact_id: Uuid) -> Result<Option<Artifact>> {
        Ok(self.artifacts.read().await.get(&artifact_id).cloned())
    }

    /// Drops every listed id from the live map; unknown ids are ignored.
    async fn delete_artifacts(&self, artifact_ids: &[Uuid]) -> Result<()> {
        let mut live = self.artifacts.write().await;
        for id in artifact_ids {
            live.remove(id);
        }
        Ok(())
    }

    /// Moves the listed artifacts from the live map into the archive map.
    async fn archive_version(&self, version_id: Uuid) -> Result<()> {
        let to_move = self.list_artifacts(version_id).await?;
        let mut live = self.artifacts.write().await;
        let mut archive = self.archived.write().await;
        for entry in to_move {
            if let Some(owned) = live.remove(&entry.id) {
                archive.insert(owned.id, owned);
            }
        }
        Ok(())
    }

    /// Deletes every artifact reported by `list_artifacts` for this version.
    async fn delete_version(&self, version_id: Uuid) -> Result<()> {
        let ids: Vec<Uuid> = self
            .list_artifacts(version_id)
            .await?
            .iter()
            .map(|artifact| artifact.id)
            .collect();
        self.delete_artifacts(&ids).await
    }

    /// Lists all live artifacts; `_version_id` is ignored (no version index).
    async fn list_artifacts(&self, _version_id: Uuid) -> Result<Vec<Artifact>> {
        Ok(self.artifacts.read().await.values().cloned().collect())
    }
}
impl Default for InMemoryStorage {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A freshly built artifact hashes its content and passes integrity checks.
    #[tokio::test]
    async fn test_artifact_creation() {
        let payload = b"test model data".to_vec();
        let artifact =
            Artifact::new(ArtifactType::Model, PathBuf::from("model.bin"), payload.clone());
        assert_eq!(artifact.artifact_type, ArtifactType::Model);
        assert_eq!(artifact.content, payload);
        assert_eq!(artifact.size_bytes, payload.len() as u64);
        assert!(!artifact.content_hash.is_empty());
        assert!(artifact.verify_integrity());
    }

    /// Round-trips an artifact through the filesystem backend, then deletes it.
    #[tokio::test]
    async fn test_filesystem_storage() {
        let temp_dir = TempDir::new().expect("temp file creation failed");
        let storage = FileSystemStorage::new(temp_dir.path().to_path_buf());
        storage.initialize().await.expect("async operation failed");

        let artifact = Artifact::new(
            ArtifactType::Model,
            PathBuf::from("test_model.bin"),
            b"test content".to_vec(),
        );
        let stored_ids = storage
            .store_artifacts(std::slice::from_ref(&artifact))
            .await
            .expect("async operation failed");
        assert_eq!(stored_ids.len(), 1);
        assert_eq!(stored_ids[0], artifact.id);

        let fetched = storage
            .get_artifact(artifact.id)
            .await
            .expect("async operation failed");
        assert!(fetched.is_some());
        let fetched = fetched.expect("operation failed in test");
        assert_eq!(fetched.content, artifact.content);
        assert_eq!(fetched.content_hash, artifact.content_hash);

        storage
            .delete_artifacts(&[artifact.id])
            .await
            .expect("async operation failed");
        let gone = storage
            .get_artifact(artifact.id)
            .await
            .expect("async operation failed");
        assert!(gone.is_none());
    }

    /// Round-trips an artifact through the in-memory backend.
    #[tokio::test]
    async fn test_inmemory_storage() {
        let storage = InMemoryStorage::new();
        let artifact =
            Artifact::new(ArtifactType::Config, PathBuf::from("config.json"), b"{}".to_vec());
        let stored_ids = storage
            .store_artifacts(std::slice::from_ref(&artifact))
            .await
            .expect("async operation failed");
        assert_eq!(stored_ids[0], artifact.id);
        let fetched = storage
            .get_artifact(artifact.id)
            .await
            .expect("async operation failed");
        assert!(fetched.is_some());
        assert_eq!(
            fetched.expect("operation failed in test").content,
            artifact.content
        );
    }

    /// Extension and deployment-requirement helpers on `ArtifactType`.
    #[test]
    fn test_artifact_types() {
        assert_eq!(ArtifactType::Model.default_extension(), "bin");
        assert_eq!(ArtifactType::Config.default_extension(), "json");
        assert!(ArtifactType::Model.is_required_for_deployment());
        assert!(ArtifactType::Config.is_required_for_deployment());
        assert!(!ArtifactType::Documentation.is_required_for_deployment());
    }

    /// MIME type is derived from the file extension when one is present.
    #[test]
    fn test_mime_type_detection() {
        let json_artifact =
            Artifact::new(ArtifactType::Config, PathBuf::from("config.json"), b"{}".to_vec());
        assert_eq!(json_artifact.mime_type, "application/json");

        let bin_artifact = Artifact::new(
            ArtifactType::Model,
            PathBuf::from("model.bin"),
            b"binary data".to_vec(),
        );
        assert_eq!(bin_artifact.mime_type, "application/octet-stream");
    }
}