use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Content-addressed identifier for a blob, stored as a hex-encoded digest.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BlobHash(pub String);
impl BlobHash {
    /// Wraps an already hex-encoded digest; the hex content is not validated here.
    pub fn from_hex(hex: &str) -> Self {
        Self(hex.to_string())
    }
    /// Returns the full hex digest as a string slice.
    pub fn as_hex(&self) -> &str {
        &self.0
    }
}
impl fmt::Display for BlobHash {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Abbreviate long digests to their first 16 characters for log-friendly output.
        if self.0.len() > 16 {
            write!(f, "{}...", &self.0[..16])
        } else {
            write!(f, "{}", self.0)
        }
    }
}
/// Serializable reference to a blob: its content hash, size, and descriptive metadata.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BlobToken {
pub hash: BlobHash,
pub size_bytes: u64,
pub metadata: BlobMetadata,
}
impl BlobToken {
pub fn new(hash: BlobHash, size_bytes: u64, metadata: BlobMetadata) -> Self {
Self {
hash,
size_bytes,
metadata,
}
}
    /// Blobs smaller than 1 MiB are considered small.
    pub fn is_small(&self) -> bool {
        self.size_bytes < 1024 * 1024
    }
    /// Blobs larger than 100 MiB are considered large.
    pub fn is_large(&self) -> bool {
        self.size_bytes > 100 * 1024 * 1024
    }
}
/// Optional descriptive metadata attached to a blob.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct BlobMetadata {
pub name: Option<String>,
pub content_type: Option<String>,
pub custom: HashMap<String, String>,
}
impl BlobMetadata {
    /// Metadata with only a display name set.
    pub fn with_name(name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..Default::default()
        }
    }
    /// Metadata with a display name and a content type.
    pub fn with_name_and_type(name: impl Into<String>, content_type: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            content_type: Some(content_type.into()),
            custom: HashMap::new(),
        }
    }
    /// Adds a custom key/value pair, returning `self` for chaining.
    pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.custom.insert(key.into(), value.into());
        self
    }
}
/// Progress events emitted while a blob is fetched into local storage.
#[derive(Clone, Debug)]
pub enum BlobProgress {
Started {
total_bytes: u64,
},
Downloading {
downloaded_bytes: u64,
total_bytes: u64,
},
Completed {
local_path: PathBuf,
},
Failed {
error: String,
},
}
impl BlobProgress {
    /// Completion percentage in `[0.0, 100.0]`, or `None` if the transfer failed.
    /// A zero-byte download is reported as already complete.
    pub fn percentage(&self) -> Option<f64> {
match self {
BlobProgress::Started { .. } => Some(0.0),
BlobProgress::Downloading {
downloaded_bytes,
total_bytes,
} => {
if *total_bytes == 0 {
Some(100.0)
} else {
Some(*downloaded_bytes as f64 / *total_bytes as f64 * 100.0)
}
}
BlobProgress::Completed { .. } => Some(100.0),
BlobProgress::Failed { .. } => None,
}
}
pub fn is_complete(&self) -> bool {
matches!(self, BlobProgress::Completed { .. })
}
pub fn is_failed(&self) -> bool {
matches!(self, BlobProgress::Failed { .. })
}
}
/// A locally materialized blob: its token plus the path where the bytes live.
#[derive(Debug)]
pub struct BlobHandle {
pub token: BlobToken,
pub path: PathBuf,
}
impl BlobHandle {
    pub fn new(token: BlobToken, path: PathBuf) -> Self {
        Self { token, path }
    }
    /// Reads the entire blob into memory; prefer `open_read_stream` for large blobs.
    pub async fn read_to_vec(&self) -> Result<Vec<u8>> {
        tokio::fs::read(&self.path)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to read blob at {:?}: {}", self.path, e))
    }
    /// Opens the blob file for streaming reads.
    pub async fn open_read_stream(&self) -> Result<tokio::fs::File> {
        tokio::fs::File::open(&self.path)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to open blob stream at {:?}: {}", self.path, e))
    }
    /// Size in bytes, as recorded in the token.
    pub fn size(&self) -> u64 {
        self.token.size_bytes
    }
}
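/// Illustrative helper, not part of the handle API above: a minimal sketch of
/// streaming a blob's bytes into any async writer without loading the whole
/// file into memory. The function name is hypothetical.
pub async fn copy_blob_to<W>(handle: &BlobHandle, writer: &mut W) -> Result<u64>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    let mut file = handle.open_read_stream().await?;
    tokio::io::copy(&mut file, writer)
        .await
        .map_err(|e| anyhow::anyhow!("Failed to copy blob {}: {}", handle.token.hash, e))
}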
/// Content-addressed blob storage with a local cache. Implementations are meant
/// to be shared behind `Arc<dyn BlobStore>` (see `SharedBlobStore`), so the trait
/// must stay object-safe.
#[async_trait::async_trait]
pub trait BlobStore: Send + Sync {
    /// Ingests the file at `path` and returns a token describing the stored blob.
    async fn create_blob(&self, path: &Path, metadata: BlobMetadata) -> Result<BlobToken>;
    /// Ingests an in-memory buffer and returns its token.
    async fn create_blob_from_bytes(
        &self,
        data: &[u8],
        metadata: BlobMetadata,
    ) -> Result<BlobToken>;
    /// Fetches a blob into local storage, reporting progress through `progress`.
    /// The callback is boxed (rather than generic) so the trait remains object-safe.
    async fn fetch_blob(
        &self,
        token: &BlobToken,
        progress: Box<dyn FnMut(BlobProgress) + Send>,
    ) -> Result<BlobHandle>;
    /// Returns true if the blob's bytes are already present locally.
    fn blob_exists_locally(&self, hash: &BlobHash) -> bool;
    /// Returns the locally known token for a blob, if any.
    fn blob_info(&self, hash: &BlobHash) -> Option<BlobToken>;
    /// Removes a blob from local storage.
    async fn delete_blob(&self, hash: &BlobHash) -> Result<()>;
    /// Lists tokens for every blob currently stored locally.
    fn list_local_blobs(&self) -> Vec<BlobToken>;
    /// Default implementation buffers the whole stream in memory and delegates to
    /// `create_blob_from_bytes`; implementations may override it to stream to disk.
    async fn create_blob_from_stream(
        &self,
        stream: &mut (dyn tokio::io::AsyncRead + Send + Unpin),
        expected_size: Option<u64>,
        metadata: BlobMetadata,
    ) -> Result<BlobToken> {
        use tokio::io::AsyncReadExt;
        // Pre-allocate when the caller knows the size up front.
        let mut buf = match expected_size {
            Some(size) => Vec::with_capacity(size as usize),
            None => Vec::new(),
        };
        stream
            .read_to_end(&mut buf)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to read stream: {}", e))?;
        self.create_blob_from_bytes(&buf, metadata).await
    }
    /// Total bytes occupied by blobs in local storage.
    fn local_storage_bytes(&self) -> u64;
}
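/// Illustrative sketch, not required by the trait: one way a caller might drive
/// `fetch_blob` with a progress callback. The helper name and the stderr
/// reporting are placeholders for a real progress UI.
pub async fn fetch_blob_with_logging(
    store: &dyn BlobStore,
    token: &BlobToken,
) -> Result<BlobHandle> {
    let label = token
        .metadata
        .name
        .clone()
        .unwrap_or_else(|| token.hash.to_string());
    store
        .fetch_blob(
            token,
            Box::new(move |p: BlobProgress| {
                if let Some(pct) = p.percentage() {
                    eprintln!("fetching {}: {:.1}%", label, pct);
                }
            }),
        )
        .await
}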
/// Convenience helpers available on every `BlobStore` via a blanket impl.
#[async_trait::async_trait]
pub trait BlobStoreExt: BlobStore {
    /// Fetches a blob without progress reporting.
    async fn fetch_blob_simple(&self, token: &BlobToken) -> Result<BlobHandle> {
        self.fetch_blob(token, Box::new(|_| {})).await
    }
    /// Ensures the blob is present locally and returns the path to its bytes.
    async fn ensure_local(&self, token: &BlobToken) -> Result<PathBuf> {
        if self.blob_exists_locally(&token.hash) {
            // Already cached: reuse the locally recorded hash and size, but keep
            // the caller-supplied metadata.
            if let Some(info) = self.blob_info(&token.hash) {
                let handle = self
                    .fetch_blob_simple(&BlobToken {
                        hash: info.hash,
                        size_bytes: info.size_bytes,
                        metadata: token.metadata.clone(),
                    })
                    .await?;
                return Ok(handle.path);
            }
        }
        let handle = self.fetch_blob_simple(token).await?;
        Ok(handle.path)
    }
    /// Summarizes local storage: blob count, total bytes, and the largest blob.
    fn storage_summary(&self) -> BlobStorageSummary {
        let blobs = self.list_local_blobs();
        BlobStorageSummary {
            blob_count: blobs.len(),
            total_bytes: self.local_storage_bytes(),
            largest_blob: blobs.iter().map(|t| t.size_bytes).max(),
        }
    }
}
/// Aggregate view of what is currently held in local blob storage.
#[derive(Debug, Clone)]
pub struct BlobStorageSummary {
pub blob_count: usize,
pub total_bytes: u64,
pub largest_blob: Option<u64>,
}
// Blanket impl: every `BlobStore` automatically gets the `BlobStoreExt` helpers.
impl<T: BlobStore + ?Sized> BlobStoreExt for T {}
/// Shared, thread-safe handle to any blob store implementation.
pub type SharedBlobStore = Arc<dyn BlobStore>;
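/// Illustrative sketch: because of the blanket `BlobStoreExt` impl, the extension
/// helpers are callable directly on the shared trait object. The function name is
/// hypothetical.
pub fn describe_storage(store: &SharedBlobStore) -> String {
    let summary = store.storage_summary();
    format!(
        "{} blobs, {} bytes total",
        summary.blob_count, summary.total_bytes
    )
}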
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_blob_hash_display() {
let hash = BlobHash::from_hex("a7f8b3c4d5e6f7a8b9c0d1e2f3a4b5c6d7e8f9a0");
assert_eq!(format!("{}", hash), "a7f8b3c4d5e6f7a8...");
let short_hash = BlobHash::from_hex("abc123");
assert_eq!(format!("{}", short_hash), "abc123");
}
#[test]
fn test_blob_token_size_classification() {
        let small = BlobToken::new(
            BlobHash::from_hex("abc"),
            500 * 1024,
            BlobMetadata::default(),
        );
assert!(small.is_small());
assert!(!small.is_large());
        let large = BlobToken::new(
            BlobHash::from_hex("def"),
            200 * 1024 * 1024,
            BlobMetadata::default(),
        );
assert!(!large.is_small());
assert!(large.is_large());
}
#[test]
fn test_blob_metadata_builder() {
let meta = BlobMetadata::with_name("model.onnx")
.with_custom("version", "1.0")
.with_custom("precision", "fp16");
assert_eq!(meta.name, Some("model.onnx".to_string()));
assert_eq!(meta.custom.get("version"), Some(&"1.0".to_string()));
assert_eq!(meta.custom.get("precision"), Some(&"fp16".to_string()));
}
#[test]
fn test_blob_progress_percentage() {
let started = BlobProgress::Started { total_bytes: 1000 };
assert_eq!(started.percentage(), Some(0.0));
let downloading = BlobProgress::Downloading {
downloaded_bytes: 500,
total_bytes: 1000,
};
assert_eq!(downloading.percentage(), Some(50.0));
let completed = BlobProgress::Completed {
local_path: PathBuf::from("/tmp/blob"),
};
assert_eq!(completed.percentage(), Some(100.0));
let failed = BlobProgress::Failed {
error: "oops".to_string(),
};
assert_eq!(failed.percentage(), None);
}
#[test]
fn test_blob_token_serialization() {
let token = BlobToken::new(
BlobHash::from_hex("a7f8b3c4d5e6f7a8b9c0d1e2f3a4b5c6d7e8f9a0"),
1024 * 1024,
BlobMetadata::with_name_and_type("model.onnx", "application/onnx"),
);
let json = serde_json::to_string(&token).unwrap();
let parsed: BlobToken = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.hash, token.hash);
assert_eq!(parsed.size_bytes, token.size_bytes);
assert_eq!(parsed.metadata.name, token.metadata.name);
}
#[test]
fn test_blob_hash_as_hex() {
let hash = BlobHash::from_hex("deadbeef");
assert_eq!(hash.as_hex(), "deadbeef");
}
#[test]
fn test_blob_hash_display_short() {
let hash = BlobHash::from_hex("1234567890abcdef");
assert_eq!(format!("{}", hash), "1234567890abcdef");
}
#[test]
fn test_blob_hash_display_long() {
let hash = BlobHash::from_hex("1234567890abcdef0");
assert_eq!(format!("{}", hash), "1234567890abcdef...");
}
#[test]
fn test_blob_hash_equality() {
let h1 = BlobHash::from_hex("abc123");
let h2 = BlobHash::from_hex("abc123");
let h3 = BlobHash::from_hex("def456");
assert_eq!(h1, h2);
assert_ne!(h1, h3);
}
#[test]
fn test_blob_token_medium_size() {
        let medium = BlobToken::new(
            BlobHash::from_hex("abc"),
            50 * 1024 * 1024,
            BlobMetadata::default(),
        );
assert!(!medium.is_small());
assert!(!medium.is_large());
}
#[test]
fn test_blob_token_exact_boundary() {
let exactly_1mb = BlobToken::new(
BlobHash::from_hex("abc"),
1024 * 1024,
BlobMetadata::default(),
);
assert!(!exactly_1mb.is_small());
assert!(!exactly_1mb.is_large());
let exactly_100mb = BlobToken::new(
BlobHash::from_hex("abc"),
100 * 1024 * 1024,
BlobMetadata::default(),
);
assert!(!exactly_100mb.is_small());
assert!(!exactly_100mb.is_large());
}
#[test]
fn test_blob_metadata_default() {
let meta = BlobMetadata::default();
assert!(meta.name.is_none());
assert!(meta.content_type.is_none());
assert!(meta.custom.is_empty());
}
#[test]
fn test_blob_metadata_with_name() {
let meta = BlobMetadata::with_name("test.bin");
assert_eq!(meta.name, Some("test.bin".to_string()));
assert!(meta.content_type.is_none());
}
#[test]
fn test_blob_metadata_with_name_and_type() {
let meta = BlobMetadata::with_name_and_type("test.jpg", "image/jpeg");
assert_eq!(meta.name, Some("test.jpg".to_string()));
assert_eq!(meta.content_type, Some("image/jpeg".to_string()));
assert!(meta.custom.is_empty());
}
#[test]
fn test_blob_metadata_chained_custom_fields() {
let meta = BlobMetadata::with_name("model.onnx")
.with_custom("version", "1.0")
.with_custom("precision", "fp16")
.with_custom("framework", "pytorch");
assert_eq!(meta.custom.len(), 3);
assert_eq!(meta.custom.get("version"), Some(&"1.0".to_string()));
assert_eq!(meta.custom.get("framework"), Some(&"pytorch".to_string()));
}
#[test]
fn test_blob_progress_started_percentage() {
let p = BlobProgress::Started { total_bytes: 5000 };
assert_eq!(p.percentage(), Some(0.0));
assert!(!p.is_complete());
assert!(!p.is_failed());
}
#[test]
fn test_blob_progress_downloading_zero_total() {
let p = BlobProgress::Downloading {
downloaded_bytes: 0,
total_bytes: 0,
};
assert_eq!(p.percentage(), Some(100.0));
}
#[test]
fn test_blob_progress_downloading_partial() {
let p = BlobProgress::Downloading {
downloaded_bytes: 250,
total_bytes: 1000,
};
assert_eq!(p.percentage(), Some(25.0));
assert!(!p.is_complete());
assert!(!p.is_failed());
}
#[test]
fn test_blob_progress_completed() {
let p = BlobProgress::Completed {
local_path: PathBuf::from("/tmp/blob"),
};
assert_eq!(p.percentage(), Some(100.0));
assert!(p.is_complete());
assert!(!p.is_failed());
}
#[test]
fn test_blob_progress_failed() {
let p = BlobProgress::Failed {
error: "network error".to_string(),
};
assert_eq!(p.percentage(), None);
assert!(!p.is_complete());
assert!(p.is_failed());
}
#[test]
fn test_blob_handle_size() {
let token = BlobToken::new(BlobHash::from_hex("abc"), 42000, BlobMetadata::default());
let handle = BlobHandle::new(token, PathBuf::from("/tmp/blob"));
assert_eq!(handle.size(), 42000);
}
#[test]
fn test_blob_storage_summary_debug() {
let summary = BlobStorageSummary {
blob_count: 5,
total_bytes: 1024 * 1024,
largest_blob: Some(500_000),
};
let debug_str = format!("{:?}", summary);
assert!(debug_str.contains("blob_count"));
assert!(debug_str.contains("5"));
}
}