use crate::storage::{StorageBackend, StorageObject};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::SystemTime;
use torsh_core::error::{Result, TorshError};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct S3Config {
pub bucket: String,
pub region: String,
pub access_key_id: Option<String>,
pub secret_access_key: Option<String>,
pub endpoint: Option<String>,
pub server_side_encryption: bool,
pub storage_class: String,
pub multipart_threshold: usize,
pub multipart_chunk_size: usize,
}
impl Default for S3Config {
fn default() -> Self {
Self {
bucket: String::new(),
region: "us-east-1".to_string(),
access_key_id: None,
secret_access_key: None,
endpoint: None,
server_side_encryption: true,
storage_class: "STANDARD".to_string(),
multipart_threshold: 50 * 1024 * 1024, multipart_chunk_size: 10 * 1024 * 1024, }
}
}
pub struct MockS3Storage {
config: S3Config,
storage: Arc<Mutex<HashMap<String, Vec<u8>>>>,
metadata: Arc<Mutex<HashMap<String, StorageObject>>>,
}
impl MockS3Storage {
pub fn new(bucket: String) -> Self {
Self {
config: S3Config {
bucket,
..Default::default()
},
storage: Arc::new(Mutex::new(HashMap::new())),
metadata: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn with_config(config: S3Config) -> Self {
Self {
config,
storage: Arc::new(Mutex::new(HashMap::new())),
metadata: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn bucket(&self) -> &str {
&self.config.bucket
}
pub fn region(&self) -> &str {
&self.config.region
}
fn multipart_upload(&mut self, key: &str, data: &[u8]) -> Result<()> {
let chunk_size = self.config.multipart_chunk_size;
let num_parts = (data.len() + chunk_size - 1) / chunk_size;
for i in 0..num_parts {
let start = i * chunk_size;
let end = std::cmp::min(start + chunk_size, data.len());
let _part_data = &data[start..end];
}
self.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.insert(key.to_string(), data.to_vec());
Ok(())
}
}
impl StorageBackend for MockS3Storage {
fn put(&mut self, key: &str, data: &[u8]) -> Result<()> {
if data.len() > self.config.multipart_threshold {
return self.multipart_upload(key, data);
}
self.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.insert(key.to_string(), data.to_vec());
let metadata = StorageObject {
key: key.to_string(),
size: data.len() as u64,
last_modified: SystemTime::now(),
content_type: Some("application/octet-stream".to_string()),
etag: Some(format!("{:x}", md5::compute(data))),
metadata: HashMap::new(),
};
self.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.insert(key.to_string(), metadata);
Ok(())
}
fn get(&self, key: &str) -> Result<Vec<u8>> {
self.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.get(key)
.cloned()
.ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
}
fn delete(&mut self, key: &str) -> Result<()> {
self.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.remove(key);
self.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.remove(key);
Ok(())
}
fn exists(&self, key: &str) -> Result<bool> {
Ok(self
.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.contains_key(key))
}
fn list(&self, prefix: &str) -> Result<Vec<StorageObject>> {
let metadata = self
.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?;
Ok(metadata
.values()
.filter(|obj| obj.key.starts_with(prefix))
.cloned()
.collect())
}
fn get_metadata(&self, key: &str) -> Result<StorageObject> {
self.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.get(key)
.cloned()
.ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
}
fn backend_type(&self) -> &str {
"s3"
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GcsConfig {
pub bucket: String,
pub project_id: String,
pub service_account_key: Option<String>,
pub storage_class: String,
}
impl Default for GcsConfig {
fn default() -> Self {
Self {
bucket: String::new(),
project_id: String::new(),
service_account_key: None,
storage_class: "STANDARD".to_string(),
}
}
}
pub struct MockGcsStorage {
config: GcsConfig,
storage: Arc<Mutex<HashMap<String, Vec<u8>>>>,
metadata: Arc<Mutex<HashMap<String, StorageObject>>>,
}
impl MockGcsStorage {
pub fn new(bucket: String, project_id: String) -> Self {
Self {
config: GcsConfig {
bucket,
project_id,
..Default::default()
},
storage: Arc::new(Mutex::new(HashMap::new())),
metadata: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn with_config(config: GcsConfig) -> Self {
Self {
config,
storage: Arc::new(Mutex::new(HashMap::new())),
metadata: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn bucket(&self) -> &str {
&self.config.bucket
}
pub fn project_id(&self) -> &str {
&self.config.project_id
}
}
impl StorageBackend for MockGcsStorage {
fn put(&mut self, key: &str, data: &[u8]) -> Result<()> {
self.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.insert(key.to_string(), data.to_vec());
let metadata = StorageObject {
key: key.to_string(),
size: data.len() as u64,
last_modified: SystemTime::now(),
content_type: Some("application/octet-stream".to_string()),
etag: Some(format!("{:x}", md5::compute(data))),
metadata: HashMap::new(),
};
self.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.insert(key.to_string(), metadata);
Ok(())
}
fn get(&self, key: &str) -> Result<Vec<u8>> {
self.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.get(key)
.cloned()
.ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
}
fn delete(&mut self, key: &str) -> Result<()> {
self.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.remove(key);
self.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.remove(key);
Ok(())
}
fn exists(&self, key: &str) -> Result<bool> {
Ok(self
.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.contains_key(key))
}
fn list(&self, prefix: &str) -> Result<Vec<StorageObject>> {
let metadata = self
.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?;
Ok(metadata
.values()
.filter(|obj| obj.key.starts_with(prefix))
.cloned()
.collect())
}
fn get_metadata(&self, key: &str) -> Result<StorageObject> {
self.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.get(key)
.cloned()
.ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
}
fn backend_type(&self) -> &str {
"gcs"
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AzureConfig {
pub account_name: String,
pub container: String,
pub access_key: Option<String>,
pub sas_token: Option<String>,
pub access_tier: String,
}
impl Default for AzureConfig {
fn default() -> Self {
Self {
account_name: String::new(),
container: String::new(),
access_key: None,
sas_token: None,
access_tier: "Hot".to_string(),
}
}
}
pub struct MockAzureStorage {
config: AzureConfig,
storage: Arc<Mutex<HashMap<String, Vec<u8>>>>,
metadata: Arc<Mutex<HashMap<String, StorageObject>>>,
}
impl MockAzureStorage {
pub fn new(account_name: String, container: String) -> Self {
Self {
config: AzureConfig {
account_name,
container,
..Default::default()
},
storage: Arc::new(Mutex::new(HashMap::new())),
metadata: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn with_config(config: AzureConfig) -> Self {
Self {
config,
storage: Arc::new(Mutex::new(HashMap::new())),
metadata: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn account_name(&self) -> &str {
&self.config.account_name
}
pub fn container(&self) -> &str {
&self.config.container
}
}
impl StorageBackend for MockAzureStorage {
fn put(&mut self, key: &str, data: &[u8]) -> Result<()> {
self.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.insert(key.to_string(), data.to_vec());
let metadata = StorageObject {
key: key.to_string(),
size: data.len() as u64,
last_modified: SystemTime::now(),
content_type: Some("application/octet-stream".to_string()),
etag: Some(format!("{:x}", md5::compute(data))),
metadata: HashMap::new(),
};
self.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.insert(key.to_string(), metadata);
Ok(())
}
fn get(&self, key: &str) -> Result<Vec<u8>> {
self.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.get(key)
.cloned()
.ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
}
fn delete(&mut self, key: &str) -> Result<()> {
self.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.remove(key);
self.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.remove(key);
Ok(())
}
fn exists(&self, key: &str) -> Result<bool> {
Ok(self
.storage
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.contains_key(key))
}
fn list(&self, prefix: &str) -> Result<Vec<StorageObject>> {
let metadata = self
.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?;
Ok(metadata
.values()
.filter(|obj| obj.key.starts_with(prefix))
.cloned()
.collect())
}
fn get_metadata(&self, key: &str) -> Result<StorageObject> {
self.metadata
.lock()
.map_err(|e| TorshError::IoError(format!("Lock error: {}", e)))?
.get(key)
.cloned()
.ok_or_else(|| TorshError::InvalidArgument(format!("Key not found: {}", key)))
}
fn backend_type(&self) -> &str {
"azure"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mock_s3_storage() {
let mut storage = MockS3Storage::new("test-bucket".to_string());
assert_eq!(storage.backend_type(), "s3");
assert_eq!(storage.bucket(), "test-bucket");
let data = b"test data";
storage.put("test/key", data).unwrap();
let retrieved = storage.get("test/key").unwrap();
assert_eq!(retrieved, data);
assert!(storage.exists("test/key").unwrap());
assert!(!storage.exists("nonexistent").unwrap());
let metadata = storage.get_metadata("test/key").unwrap();
assert_eq!(metadata.size, data.len() as u64);
assert!(metadata.etag.is_some());
storage.delete("test/key").unwrap();
assert!(!storage.exists("test/key").unwrap());
}
#[test]
fn test_mock_s3_multipart_upload() {
let mut storage = MockS3Storage::new("test-bucket".to_string());
let large_data = vec![0u8; 60 * 1024 * 1024];
storage.put("test/large", &large_data).unwrap();
let retrieved = storage.get("test/large").unwrap();
assert_eq!(retrieved.len(), large_data.len());
}
#[test]
fn test_mock_s3_list() {
let mut storage = MockS3Storage::new("test-bucket".to_string());
storage.put("models/bert/v1.bin", b"data1").unwrap();
storage.put("models/bert/v2.bin", b"data2").unwrap();
storage.put("models/gpt/v1.bin", b"data3").unwrap();
let bert_models = storage.list("models/bert/").unwrap();
assert_eq!(bert_models.len(), 2);
let all_models = storage.list("models/").unwrap();
assert_eq!(all_models.len(), 3);
}
#[test]
fn test_mock_gcs_storage() {
let mut storage =
MockGcsStorage::new("test-bucket".to_string(), "test-project".to_string());
assert_eq!(storage.backend_type(), "gcs");
assert_eq!(storage.bucket(), "test-bucket");
assert_eq!(storage.project_id(), "test-project");
let data = b"test data";
storage.put("test/key", data).unwrap();
let retrieved = storage.get("test/key").unwrap();
assert_eq!(retrieved, data);
}
#[test]
fn test_mock_azure_storage() {
let mut storage =
MockAzureStorage::new("testaccount".to_string(), "testcontainer".to_string());
assert_eq!(storage.backend_type(), "azure");
assert_eq!(storage.account_name(), "testaccount");
assert_eq!(storage.container(), "testcontainer");
let data = b"test data";
storage.put("test/key", data).unwrap();
let retrieved = storage.get("test/key").unwrap();
assert_eq!(retrieved, data);
}
#[test]
fn test_s3_config_defaults() {
let config = S3Config::default();
assert_eq!(config.region, "us-east-1");
assert_eq!(config.storage_class, "STANDARD");
assert!(config.server_side_encryption);
assert_eq!(config.multipart_threshold, 50 * 1024 * 1024);
}
}