use anyhow::{anyhow, Result};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
use scirs2_core::metrics::Timer;
type HmacSha256 = Hmac<Sha256>;
/// Encodes `data` with the URL-safe base64 alphabet and no `=` padding.
fn base64_url_encode(data: &[u8]) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine as _;

    URL_SAFE_NO_PAD.encode(data)
}
/// Settings that control how uploads are validated, stored and post-processed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileUploadConfig {
/// Largest accepted file size, in bytes.
pub max_file_size: u64,
/// Largest number of files accepted by a single batch upload.
pub max_files_per_request: usize,
/// Allow-list of MIME type fragments; `None` disables the MIME check.
/// Entries are matched as substrings of the reported MIME type.
pub allowed_mime_types: Option<Vec<String>>,
/// Allow-list of file extensions (without the leading dot); `None`
/// disables the extension check.
pub allowed_extensions: Option<Vec<String>>,
/// Local directory into which uploaded files are written.
pub upload_dir: PathBuf,
/// Whether stored files are passed through the scan step.
pub enable_virus_scan: bool,
/// Optional cloud backend; when set, stored files are also pushed there.
pub cloud_storage: Option<CloudStorageConfig>,
/// Whether per-file progress entries are recorded.
pub enable_progress_tracking: bool,
}
impl Default for FileUploadConfig {
    /// Returns the default configuration: 100 MiB per file, 10 files per
    /// request, no MIME/extension restrictions, storage under the system
    /// temp directory, no virus scan, no cloud backend, progress tracking on.
    fn default() -> Self {
        const DEFAULT_MAX_FILE_SIZE: u64 = 100 * 1024 * 1024;
        const DEFAULT_MAX_FILES_PER_REQUEST: usize = 10;

        let upload_dir = std::env::temp_dir().join("oxirs-uploads");

        Self {
            max_file_size: DEFAULT_MAX_FILE_SIZE,
            max_files_per_request: DEFAULT_MAX_FILES_PER_REQUEST,
            allowed_mime_types: None,
            allowed_extensions: None,
            upload_dir,
            enable_virus_scan: false,
            cloud_storage: None,
            enable_progress_tracking: true,
        }
    }
}
/// Connection details for the cloud backend that receives stored uploads.
///
/// NOTE(review): this type derives `Debug` and `Serialize`, so the secret
/// fields below are included whenever a value is formatted or serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CloudStorageConfig {
/// Amazon S3 bucket addressed as `<bucket>.s3.<region>.amazonaws.com`.
S3 {
bucket: String,
region: String,
access_key: String,
secret_key: String,
},
/// Google Cloud Storage bucket; `credentials_path` points at a JSON
/// credentials file containing `client_email` and `private_key`.
GCS {
bucket: String,
project_id: String,
credentials_path: String,
},
/// Azure Blob Storage container; `account_key` is the base64-encoded
/// shared key.
Azure {
container: String,
account_name: String,
account_key: String,
},
}
/// Record describing a file that has been stored by the upload manager.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UploadedFile {
/// Identifier generated for the upload; also used as the stored file name.
pub id: String,
/// File name supplied by the uploader.
pub filename: String,
/// MIME type supplied by the uploader.
pub mime_type: String,
/// Number of bytes stored.
pub size: u64,
/// Location of the stored copy inside the upload directory.
pub path: PathBuf,
/// URL returned by the cloud backend, when one is configured.
pub url: Option<String>,
/// UTC timestamp taken when the record was created.
pub uploaded_at: chrono::DateTime<chrono::Utc>,
/// Status of the upload at the time the record was created.
pub status: UploadStatus,
}
/// Lifecycle states reported for an upload.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UploadStatus {
/// Initial state assigned to a freshly created progress entry.
Pending,
/// Data transfer state.
Uploading,
/// Data has been written locally; post-processing is under way.
Processing,
/// The upload finished and its record was stored.
Completed,
/// The upload failed; the payload is an accompanying message string.
Failed(String),
}
/// Progress snapshot for a single upload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UploadProgress {
/// Identifier of the upload this snapshot belongs to.
pub file_id: String,
/// Bytes received so far.
pub bytes_uploaded: u64,
/// Expected total size in bytes.
pub total_bytes: u64,
/// Completion percentage derived from `bytes_uploaded` and `total_bytes`.
pub percentage: f64,
/// Current lifecycle state.
pub status: UploadStatus,
}
impl UploadProgress {
    /// Creates a progress entry for `file_id` with nothing uploaded yet.
    ///
    /// The entry starts in [`UploadStatus::Pending`] at 0%.
    pub fn new(file_id: String, total_bytes: u64) -> Self {
        Self {
            file_id,
            bytes_uploaded: 0,
            total_bytes,
            percentage: 0.0,
            status: UploadStatus::Pending,
        }
    }

    /// Records the number of bytes received so far and recomputes the
    /// percentage.
    ///
    /// The percentage is capped at 100 so that a transfer delivering more
    /// bytes than the declared total never reports more than "complete".
    /// When `total_bytes` is zero the percentage stays at 0.
    pub fn update(&mut self, bytes_uploaded: u64) {
        self.bytes_uploaded = bytes_uploaded;
        self.percentage = if self.total_bytes > 0 {
            let ratio = bytes_uploaded as f64 / self.total_bytes as f64;
            (ratio * 100.0).min(100.0)
        } else {
            0.0
        };
    }
}
/// Validates, stores and tracks uploaded files.
pub struct FileUploadManager {
// Shared, read-only configuration.
config: Arc<FileUploadConfig>,
// Progress entries keyed by upload id.
progress_tracker: Arc<RwLock<HashMap<String, UploadProgress>>>,
// Stored upload records keyed by upload id.
uploads: Arc<RwLock<HashMap<String, UploadedFile>>>,
}
impl FileUploadManager {
/// Creates a manager around `config` with empty progress and upload tables.
///
/// The current implementation always returns `Ok`.
pub fn new(config: FileUploadConfig) -> Result<Self> {
    let progress_tracker = Arc::new(RwLock::new(HashMap::new()));
    let uploads = Arc::new(RwLock::new(HashMap::new()));

    Ok(Self {
        config: Arc::new(config),
        progress_tracker,
        uploads,
    })
}
/// Ensures the configured upload directory exists.
///
/// `create_dir_all` is a no-op for a directory that is already present, so
/// no separate existence check is performed. This avoids a blocking
/// filesystem call on the async runtime and a check-then-create race.
///
/// # Errors
///
/// Returns the underlying I/O error when the directory cannot be created,
/// for example when the path exists but is not a directory.
pub async fn initialize(&self) -> Result<()> {
    fs::create_dir_all(&self.config.upload_dir).await?;
    Ok(())
}
pub fn validate_file(&self, filename: &str, mime_type: &str, size: u64) -> Result<()> {
if size > self.config.max_file_size {
return Err(anyhow!(
"File size {} exceeds maximum allowed size {}",
size,
self.config.max_file_size
));
}
if let Some(ref allowed_types) = self.config.allowed_mime_types {
if !allowed_types.iter().any(|t| mime_type.contains(t)) {
return Err(anyhow!(
"MIME type '{}' is not allowed. Allowed types: {:?}",
mime_type,
allowed_types
));
}
}
if let Some(ref allowed_exts) = self.config.allowed_extensions {
let extension = Path::new(filename)
.extension()
.and_then(|e| e.to_str())
.unwrap_or("");
if !allowed_exts.iter().any(|ext| ext == extension) {
return Err(anyhow!(
"File extension '{}' is not allowed. Allowed extensions: {:?}",
extension,
allowed_exts
));
}
}
Ok(())
}
pub async fn process_upload(
&self,
filename: String,
mime_type: String,
content: Vec<u8>,
) -> Result<UploadedFile> {
let timer = Timer::new("file_upload_total_duration".to_string());
let _timer_guard = timer.start();
tracing::debug!(
"Starting file upload: {} ({} bytes)",
filename,
content.len()
);
self.validate_file(&filename, &mime_type, content.len() as u64)?;
tracing::debug!("File validation successful: {}", filename);
let file_id = uuid::Uuid::new_v4().to_string();
if self.config.enable_progress_tracking {
let mut tracker = self.progress_tracker.write().await;
tracker.insert(
file_id.clone(),
UploadProgress::new(file_id.clone(), content.len() as u64),
);
}
let file_path = self.config.upload_dir.join(&file_id);
let mut file = fs::File::create(&file_path).await?;
file.write_all(&content).await?;
file.flush().await?;
if self.config.enable_progress_tracking {
let mut tracker = self.progress_tracker.write().await;
if let Some(progress) = tracker.get_mut(&file_id) {
progress.update(content.len() as u64);
progress.status = UploadStatus::Processing;
}
}
if self.config.enable_virus_scan {
self.scan_file(&file_path).await?;
}
let cloud_url = if let Some(ref _cloud_config) = self.config.cloud_storage {
Some(self.upload_to_cloud(&file_id, &file_path).await?)
} else {
None
};
let uploaded_file = UploadedFile {
id: file_id.clone(),
filename,
mime_type,
size: content.len() as u64,
path: file_path,
url: cloud_url,
uploaded_at: chrono::Utc::now(),
status: UploadStatus::Completed,
};
let mut uploads = self.uploads.write().await;
uploads.insert(file_id.clone(), uploaded_file.clone());
if self.config.enable_progress_tracking {
let mut tracker = self.progress_tracker.write().await;
if let Some(progress) = tracker.get_mut(&file_id) {
progress.status = UploadStatus::Completed;
}
}
tracing::info!(
"File upload completed: {} ({})",
file_id,
uploaded_file.filename
);
Ok(uploaded_file)
}
/// Stores several uploads one after another.
///
/// Each tuple is `(filename, mime_type, content)`. Processing stops at the
/// first failing file; files stored before the failure are left in place.
///
/// # Errors
///
/// Returns an error when the batch holds more files than
/// `max_files_per_request`, or when any individual upload fails.
pub async fn process_batch_upload(
    &self,
    files: Vec<(String, String, Vec<u8>)>,
) -> Result<Vec<UploadedFile>> {
    let requested = files.len();
    let limit = self.config.max_files_per_request;
    if requested > limit {
        return Err(anyhow!(
            "Number of files {} exceeds maximum allowed {}",
            requested,
            limit
        ));
    }

    let mut stored = Vec::new();
    for (filename, mime_type, content) in files {
        let file = self
            .process_upload(filename, mime_type, content)
            .await
            .map_err(|e| anyhow!("Failed to upload file: {}", e))?;
        stored.push(file);
    }
    Ok(stored)
}
pub async fn stream_upload(
&self,
filename: String,
mime_type: String,
size: u64,
mut stream: impl tokio::io::AsyncRead + Unpin,
) -> Result<UploadedFile> {
self.validate_file(&filename, &mime_type, size)?;
let file_id = uuid::Uuid::new_v4().to_string();
if self.config.enable_progress_tracking {
let mut tracker = self.progress_tracker.write().await;
tracker.insert(file_id.clone(), UploadProgress::new(file_id.clone(), size));
}
let file_path = self.config.upload_dir.join(&file_id);
let mut file = fs::File::create(&file_path).await?;
let mut total_bytes = 0u64;
let buffer_size = if size < 1024 * 1024 {
8192 } else if size < 100 * 1024 * 1024 {
65536 } else {
262144 };
let mut buffer = vec![0u8; buffer_size];
loop {
let n = tokio::io::AsyncReadExt::read(&mut stream, &mut buffer).await?;
if n == 0 {
break;
}
file.write_all(&buffer[..n]).await?;
total_bytes += n as u64;
if self.config.enable_progress_tracking {
let mut tracker = self.progress_tracker.write().await;
if let Some(progress) = tracker.get_mut(&file_id) {
progress.update(total_bytes);
}
}
}
file.flush().await?;
if self.config.enable_virus_scan {
self.scan_file(&file_path).await?;
}
let cloud_url = if let Some(ref _cloud_config) = self.config.cloud_storage {
Some(self.upload_to_cloud(&file_id, &file_path).await?)
} else {
None
};
let uploaded_file = UploadedFile {
id: file_id.clone(),
filename,
mime_type,
size: total_bytes,
path: file_path,
url: cloud_url,
uploaded_at: chrono::Utc::now(),
status: UploadStatus::Completed,
};
let mut uploads = self.uploads.write().await;
uploads.insert(file_id.clone(), uploaded_file.clone());
Ok(uploaded_file)
}
/// Returns a copy of the progress entry for `file_id`, if one exists.
pub async fn get_progress(&self, file_id: &str) -> Option<UploadProgress> {
    self.progress_tracker.read().await.get(file_id).cloned()
}
/// Returns a copy of the stored record for `file_id`, if one exists.
pub async fn get_upload(&self, file_id: &str) -> Option<UploadedFile> {
    self.uploads.read().await.get(file_id).cloned()
}
pub async fn delete_upload(&self, file_id: &str) -> Result<()> {
let mut uploads = self.uploads.write().await;
if let Some(file) = uploads.remove(file_id) {
if file.path.exists() {
fs::remove_file(&file.path).await?;
}
if file.url.is_some() {
self.delete_from_cloud(file_id).await?;
}
let mut tracker = self.progress_tracker.write().await;
tracker.remove(file_id);
Ok(())
} else {
Err(anyhow!("File not found: {}", file_id))
}
}
/// Runs the scan step for the file at `path`.
///
/// Does nothing when scanning is disabled. Files larger than 500 MiB are
/// skipped with a warning. No scanner is invoked in this function: after
/// the existence and size checks it only logs completion.
///
/// # Errors
///
/// Returns an error when `path` does not exist or its metadata cannot be
/// read.
async fn scan_file(&self, path: &Path) -> Result<()> {
    const MAX_SCANNABLE_BYTES: u64 = 500 * 1024 * 1024;

    if !self.config.enable_virus_scan {
        return Ok(());
    }

    let timer = Timer::new("file_scan_duration".to_string());
    let _timer_guard = timer.start();

    if !path.exists() {
        tracing::error!("File not found for scanning: {:?}", path);
        return Err(anyhow!("File not found for scanning: {:?}", path));
    }

    let file_len = fs::metadata(path).await?.len();
    if file_len > MAX_SCANNABLE_BYTES {
        tracing::warn!(
            "File too large for virus scanning: {:?} ({} bytes)",
            path,
            file_len
        );
        return Ok(());
    }

    tracing::info!("Virus scan completed for: {:?}", path);
    Ok(())
}
/// Pushes the stored file at `path` to the configured cloud backend and
/// returns the resulting URL.
///
/// Without a configured backend a `file://` URL for `path` is returned.
/// Otherwise the whole file is read into memory and handed to the
/// backend-specific uploader.
async fn upload_to_cloud(&self, file_id: &str, path: &Path) -> Result<String> {
    let timer = Timer::new("file_upload_cloud_duration".to_string());
    let _timer_guard = timer.start();

    let cloud_config = match self.config.cloud_storage.as_ref() {
        Some(cfg) => cfg,
        None => return Ok(format!("file://{}", path.display())),
    };

    let content = fs::read(path).await?;
    // Hex SHA-256 of the payload; only the S3 uploader consumes it.
    let content_hash = hex::encode(Sha256::digest(&content));

    match cloud_config {
        CloudStorageConfig::S3 {
            bucket,
            region,
            access_key,
            secret_key,
        } => {
            self.upload_to_s3(
                file_id,
                &content,
                &content_hash,
                bucket,
                region,
                access_key,
                secret_key,
            )
            .await
        }
        CloudStorageConfig::GCS {
            bucket,
            project_id,
            credentials_path,
        } => {
            self.upload_to_gcs(file_id, &content, bucket, project_id, credentials_path)
                .await
        }
        CloudStorageConfig::Azure {
            container,
            account_name,
            account_key,
        } => {
            self.upload_to_azure(file_id, &content, container, account_name, account_key)
                .await
        }
    }
}
/// Uploads `content` to `uploads/<file_id>` in the given S3 bucket with a
/// PUT request signed using the `AWS4-HMAC-SHA256` scheme.
///
/// `content_hash` must be the lowercase hex SHA-256 of `content`; it is
/// sent as `x-amz-content-sha256` and is part of the signed request.
/// Returns the object URL on a successful response, otherwise an error
/// carrying the HTTP status and response body.
// The signing inputs are passed individually, hence the argument count.
#[allow(clippy::too_many_arguments)]
async fn upload_to_s3(
&self,
file_id: &str,
content: &[u8],
content_hash: &str,
bucket: &str,
region: &str,
access_key: &str,
secret_key: &str,
) -> Result<String> {
let object_key = format!("uploads/{}", file_id);
// Virtual-hosted-style endpoint: the bucket is part of the host name.
let host = format!("{}.s3.{}.amazonaws.com", bucket, region);
let url = format!("https://{}/{}", host, object_key);
let now = chrono::Utc::now();
// The same instant is used for the `x-amz-date` header and the
// credential scope date.
let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
let date_stamp = now.format("%Y%m%d").to_string();
let method = "PUT";
let canonical_uri = format!("/{}", object_key);
let canonical_querystring = "";
// Header names are lowercase and listed in alphabetical order; each line
// ends with a newline. The list must match `signed_headers` below.
let canonical_headers = format!(
"host:{}\nx-amz-content-sha256:{}\nx-amz-date:{}\n",
host, content_hash, amz_date
);
let signed_headers = "host;x-amz-content-sha256;x-amz-date";
let canonical_request = format!(
"{}\n{}\n{}\n{}\n{}\n{}",
method,
canonical_uri,
canonical_querystring,
canonical_headers,
signed_headers,
content_hash
);
let algorithm = "AWS4-HMAC-SHA256";
let credential_scope = format!("{}/{}/s3/aws4_request", date_stamp, region);
let canonical_request_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
let string_to_sign = format!(
"{}\n{}\n{}\n{}",
algorithm, amz_date, credential_scope, canonical_request_hash
);
// Derive the signing key by chaining HMACs over date, region, service
// and the fixed terminator, starting from "AWS4" + secret key.
let k_date = Self::sign_hmac_sha256(
format!("AWS4{}", secret_key).as_bytes(),
date_stamp.as_bytes(),
);
let k_region = Self::sign_hmac_sha256(&k_date, region.as_bytes());
let k_service = Self::sign_hmac_sha256(&k_region, b"s3");
let k_signing = Self::sign_hmac_sha256(&k_service, b"aws4_request");
let signature = hex::encode(Self::sign_hmac_sha256(
&k_signing,
string_to_sign.as_bytes(),
));
let authorization = format!(
"{} Credential={}/{}, SignedHeaders={}, Signature={}",
algorithm, access_key, credential_scope, signed_headers, signature
);
// A new HTTP client is built for every call.
let client = reqwest::Client::new();
let response = client
.put(&url)
.header("Host", &host)
.header("x-amz-date", &amz_date)
.header("x-amz-content-sha256", content_hash)
.header("Authorization", &authorization)
// Content-Type is sent but is not part of the signed header list.
.header("Content-Type", "application/octet-stream")
.body(content.to_vec())
.send()
.await?;
if response.status().is_success() {
tracing::info!("Successfully uploaded to S3: {}", url);
Ok(url)
} else {
// Read the status first: `text()` consumes the response.
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
tracing::error!("S3 upload failed: {} - {}", status, error_body);
Err(anyhow!("S3 upload failed: {} - {}", status, error_body))
}
}
/// Uploads `content` to `uploads/<file_id>` in the given GCS bucket using
/// a simple media upload, and returns the object's
/// `storage.googleapis.com` URL.
///
/// The credentials file at `credentials_path` is read and parsed on every
/// call to obtain an access token. `_project_id` is not used.
///
/// # Errors
///
/// Returns an error when the credentials cannot be read or parsed, when no
/// access token can be obtained, when the request cannot be sent, or when
/// the service answers with a non-success status.
async fn upload_to_gcs(
    &self,
    file_id: &str,
    content: &[u8],
    bucket: &str,
    _project_id: &str,
    credentials_path: &str,
) -> Result<String> {
    let object_name = format!("uploads/{}", file_id);

    let raw_credentials = fs::read_to_string(credentials_path)
        .await
        .map_err(|e| anyhow!("Failed to read GCS credentials: {}", e))?;
    let creds: serde_json::Value = serde_json::from_str(&raw_credentials)
        .map_err(|e| anyhow!("Failed to parse GCS credentials: {}", e))?;
    let access_token = self.get_gcs_access_token(&creds).await?;

    let endpoint = format!(
        "https://storage.googleapis.com/upload/storage/v1/b/{}/o?uploadType=media&name={}",
        bucket,
        urlencoding::encode(&object_name)
    );
    let response = reqwest::Client::new()
        .post(&endpoint)
        .header("Authorization", format!("Bearer {}", access_token))
        .header("Content-Type", "application/octet-stream")
        .body(content.to_vec())
        .send()
        .await?;

    let status = response.status();
    if !status.is_success() {
        let error_body = response.text().await.unwrap_or_default();
        tracing::error!("GCS upload failed: {} - {}", status, error_body);
        return Err(anyhow!("GCS upload failed: {} - {}", status, error_body));
    }

    let public_url = format!("https://storage.googleapis.com/{}/{}", bucket, object_name);
    tracing::info!("Successfully uploaded to GCS: {}", public_url);
    Ok(public_url)
}
async fn get_gcs_access_token(&self, creds: &serde_json::Value) -> Result<String> {
let client_email = creds["client_email"]
.as_str()
.ok_or_else(|| anyhow!("Missing client_email in GCS credentials"))?;
let private_key = creds["private_key"]
.as_str()
.ok_or_else(|| anyhow!("Missing private_key in GCS credentials"))?;
let now = chrono::Utc::now().timestamp();
let header = serde_json::json!({
"alg": "RS256",
"typ": "JWT"
});
let claims = serde_json::json!({
"iss": client_email,
"scope": "https://www.googleapis.com/auth/devstorage.read_write",
"aud": "https://oauth2.googleapis.com/token",
"iat": now,
"exp": now + 3600
});
let header_b64 = base64_url_encode(&serde_json::to_vec(&header)?);
let claims_b64 = base64_url_encode(&serde_json::to_vec(&claims)?);
let message = format!("{}.{}", header_b64, claims_b64);
let signature = self.sign_jwt_rs256(private_key, &message)?;
let jwt = format!("{}.{}", message, signature);
let client = reqwest::Client::new();
let response = client
.post("https://oauth2.googleapis.com/token")
.form(&[
("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
("assertion", &jwt),
])
.send()
.await?;
if response.status().is_success() {
let token_response: serde_json::Value = response.json().await?;
let access_token = token_response["access_token"]
.as_str()
.ok_or_else(|| anyhow!("Missing access_token in response"))?;
Ok(access_token.to_string())
} else {
let error = response.text().await.unwrap_or_default();
Err(anyhow!("Failed to get GCS access token: {}", error))
}
}
/// Produces the third JWT segment for `message`.
///
/// NOTE(review): despite the name, no RSA operation happens here. The
/// `_private_key` argument is unused and the returned value is only the
/// base64url-encoded SHA-256 digest of `message`. An RS256 JWT requires an
/// RSA signature over the message made with the private key, so the
/// assertion built from this value is presumably rejected by the token
/// endpoint — TODO replace with a real RSA signer before relying on the
/// GCS backend.
fn sign_jwt_rs256(&self, _private_key: &str, message: &str) -> Result<String> {
let hash = Sha256::digest(message.as_bytes());
Ok(base64_url_encode(&hash))
}
/// Uploads `content` as a block blob named `uploads/<file_id>` in the given
/// Azure container, authorised with a Shared Key signature, and returns
/// the blob URL.
///
/// The string-to-sign is assembled field by field so that every one of
/// the twelve standard header slots required by the Shared Key scheme
/// (VERB through Range) is present, each followed by a newline.
///
/// # Errors
///
/// Returns an error when `account_key` is not valid base64, when the
/// request cannot be sent, or when the service answers with a non-success
/// status.
async fn upload_to_azure(
    &self,
    file_id: &str,
    content: &[u8],
    container: &str,
    account_name: &str,
    account_key: &str,
) -> Result<String> {
    let blob_name = format!("uploads/{}", file_id);
    let url = format!(
        "https://{}.blob.core.windows.net/{}/{}",
        account_name, container, blob_name
    );
    let now = chrono::Utc::now();
    let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
    let content_length = content.len().to_string();
    // For service version 2015-02-21 and later a zero Content-Length is
    // signed as an empty string.
    let signed_length = if content.is_empty() {
        ""
    } else {
        content_length.as_str()
    };

    let standard_headers = [
        "PUT",                      // VERB
        "",                         // Content-Encoding
        "",                         // Content-Language
        signed_length,              // Content-Length
        "",                         // Content-MD5
        "application/octet-stream", // Content-Type
        "",                         // Date (x-ms-date is used instead)
        "",                         // If-Modified-Since
        "",                         // If-Match
        "",                         // If-None-Match
        "",                         // If-Unmodified-Since
        "",                         // Range
    ]
    .join("\n");
    // x-ms-* headers in alphabetical order, each terminated by a newline.
    let canonicalized_headers = format!(
        "x-ms-blob-type:BlockBlob\nx-ms-date:{}\nx-ms-version:2020-10-02\n",
        date_str
    );
    let canonicalized_resource = format!("/{}/{}/{}", account_name, container, blob_name);
    let string_to_sign = format!(
        "{}\n{}{}",
        standard_headers, canonicalized_headers, canonicalized_resource
    );

    let key_bytes =
        base64::Engine::decode(&base64::engine::general_purpose::STANDARD, account_key)
            .map_err(|e| anyhow!("Invalid Azure account key: {}", e))?;
    let mut mac = HmacSha256::new_from_slice(&key_bytes)
        .map_err(|e| anyhow!("Failed to create HMAC: {}", e))?;
    mac.update(string_to_sign.as_bytes());
    let signature = base64::Engine::encode(
        &base64::engine::general_purpose::STANDARD,
        mac.finalize().into_bytes(),
    );
    let auth_header = format!("SharedKey {}:{}", account_name, signature);

    let client = reqwest::Client::new();
    let response = client
        .put(&url)
        .header("x-ms-blob-type", "BlockBlob")
        .header("x-ms-date", &date_str)
        .header("x-ms-version", "2020-10-02")
        .header("Content-Type", "application/octet-stream")
        .header("Content-Length", &content_length)
        .header("Authorization", &auth_header)
        .body(content.to_vec())
        .send()
        .await?;

    let status = response.status();
    if status.is_success() {
        tracing::info!("Successfully uploaded to Azure: {}", url);
        Ok(url)
    } else {
        let error_body = response.text().await.unwrap_or_default();
        tracing::error!("Azure upload failed: {} - {}", status, error_body);
        Err(anyhow!("Azure upload failed: {} - {}", status, error_body))
    }
}
/// Computes the HMAC-SHA256 tag of `data` under `key` and returns its raw
/// bytes.
fn sign_hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
    let mut signer = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
    signer.update(data);
    let tag = signer.finalize().into_bytes();
    tag.to_vec()
}
/// Deletes the cloud copy of `file_id` from the configured backend.
///
/// Succeeds immediately when no cloud backend is configured.
async fn delete_from_cloud(&self, file_id: &str) -> Result<()> {
    let timer = Timer::new("file_delete_cloud_duration".to_string());
    let _timer_guard = timer.start();

    match self.config.cloud_storage.as_ref() {
        None => Ok(()),
        Some(CloudStorageConfig::S3 {
            bucket,
            region,
            access_key,
            secret_key,
        }) => {
            self.delete_from_s3(file_id, bucket, region, access_key, secret_key)
                .await
        }
        Some(CloudStorageConfig::GCS {
            bucket,
            project_id,
            credentials_path,
        }) => {
            self.delete_from_gcs(file_id, bucket, project_id, credentials_path)
                .await
        }
        Some(CloudStorageConfig::Azure {
            container,
            account_name,
            account_key,
        }) => {
            self.delete_from_azure(file_id, container, account_name, account_key)
                .await
        }
    }
}
/// Deletes `uploads/<file_id>` from the given S3 bucket with a DELETE
/// request signed using the `AWS4-HMAC-SHA256` scheme.
///
/// Returns an error carrying the HTTP status and response body when the
/// service answers with a non-success status.
async fn delete_from_s3(
&self,
file_id: &str,
bucket: &str,
region: &str,
access_key: &str,
secret_key: &str,
) -> Result<()> {
let object_key = format!("uploads/{}", file_id);
// Virtual-hosted-style endpoint: the bucket is part of the host name.
let host = format!("{}.s3.{}.amazonaws.com", bucket, region);
let url = format!("https://{}/{}", host, object_key);
let now = chrono::Utc::now();
let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
let date_stamp = now.format("%Y%m%d").to_string();
// The request has no body, so the payload hash is the hash of an empty
// byte string.
let content_hash = hex::encode(Sha256::digest(b""));
// Header names are lowercase and listed in alphabetical order; each line
// ends with a newline. The list must match `signed_headers` below.
let canonical_headers = format!(
"host:{}\nx-amz-content-sha256:{}\nx-amz-date:{}\n",
host, content_hash, amz_date
);
let signed_headers = "host;x-amz-content-sha256;x-amz-date";
// Method, URI, empty query string, headers, signed header list, payload
// hash.
let canonical_request = format!(
"DELETE\n/{}\n\n{}\n{}\n{}",
object_key, canonical_headers, signed_headers, content_hash
);
let algorithm = "AWS4-HMAC-SHA256";
let credential_scope = format!("{}/{}/s3/aws4_request", date_stamp, region);
let canonical_request_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
let string_to_sign = format!(
"{}\n{}\n{}\n{}",
algorithm, amz_date, credential_scope, canonical_request_hash
);
// Derive the signing key by chaining HMACs over date, region, service
// and the fixed terminator, starting from "AWS4" + secret key.
let k_date = Self::sign_hmac_sha256(
format!("AWS4{}", secret_key).as_bytes(),
date_stamp.as_bytes(),
);
let k_region = Self::sign_hmac_sha256(&k_date, region.as_bytes());
let k_service = Self::sign_hmac_sha256(&k_region, b"s3");
let k_signing = Self::sign_hmac_sha256(&k_service, b"aws4_request");
let signature = hex::encode(Self::sign_hmac_sha256(
&k_signing,
string_to_sign.as_bytes(),
));
let authorization = format!(
"{} Credential={}/{}, SignedHeaders={}, Signature={}",
algorithm, access_key, credential_scope, signed_headers, signature
);
// A new HTTP client is built for every call.
let client = reqwest::Client::new();
let response = client
.delete(&url)
.header("Host", &host)
.header("x-amz-date", &amz_date)
.header("x-amz-content-sha256", &content_hash)
.header("Authorization", &authorization)
.send()
.await?;
// NOTE(review): 204 is already a success status, so the second condition
// is redundant.
if response.status().is_success() || response.status().as_u16() == 204 {
tracing::info!("Successfully deleted from S3: {}", file_id);
Ok(())
} else {
// Read the status first: `text()` consumes the response.
let status = response.status();
let error_body = response.text().await.unwrap_or_default();
tracing::error!("S3 delete failed: {} - {}", status, error_body);
Err(anyhow!("S3 delete failed: {} - {}", status, error_body))
}
}
/// Deletes `uploads/<file_id>` from the given GCS bucket.
///
/// The credentials file at `credentials_path` is read and parsed on every
/// call to obtain an access token. `_project_id` is not used.
///
/// # Errors
///
/// Returns an error when the credentials cannot be read or parsed, when no
/// access token can be obtained, when the request cannot be sent, or when
/// the service answers with a non-success status.
async fn delete_from_gcs(
    &self,
    file_id: &str,
    bucket: &str,
    _project_id: &str,
    credentials_path: &str,
) -> Result<()> {
    let object_name = format!("uploads/{}", file_id);

    let raw_credentials = fs::read_to_string(credentials_path)
        .await
        .map_err(|e| anyhow!("Failed to read GCS credentials: {}", e))?;
    let creds: serde_json::Value = serde_json::from_str(&raw_credentials)
        .map_err(|e| anyhow!("Failed to parse GCS credentials: {}", e))?;
    let access_token = self.get_gcs_access_token(&creds).await?;

    let endpoint = format!(
        "https://storage.googleapis.com/storage/v1/b/{}/o/{}",
        bucket,
        urlencoding::encode(&object_name)
    );
    let response = reqwest::Client::new()
        .delete(&endpoint)
        .header("Authorization", format!("Bearer {}", access_token))
        .send()
        .await?;

    // `is_success` covers every 2xx status, including 204 No Content.
    let status = response.status();
    if !status.is_success() {
        let error_body = response.text().await.unwrap_or_default();
        tracing::error!("GCS delete failed: {} - {}", status, error_body);
        return Err(anyhow!("GCS delete failed: {} - {}", status, error_body));
    }

    tracing::info!("Successfully deleted from GCS: {}", file_id);
    Ok(())
}
/// Deletes the blob `uploads/<file_id>` from the given Azure container,
/// authorised with a Shared Key signature.
///
/// The string-to-sign is assembled field by field so that every one of
/// the twelve standard header slots required by the Shared Key scheme
/// (VERB through Range) is present, each followed by a newline.
///
/// # Errors
///
/// Returns an error when `account_key` is not valid base64, when the
/// request cannot be sent, or when the service answers with a non-success
/// status.
async fn delete_from_azure(
    &self,
    file_id: &str,
    container: &str,
    account_name: &str,
    account_key: &str,
) -> Result<()> {
    let blob_name = format!("uploads/{}", file_id);
    let url = format!(
        "https://{}.blob.core.windows.net/{}/{}",
        account_name, container, blob_name
    );
    let now = chrono::Utc::now();
    let date_str = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string();

    let standard_headers = [
        "DELETE", // VERB
        "",       // Content-Encoding
        "",       // Content-Language
        "",       // Content-Length
        "",       // Content-MD5
        "",       // Content-Type
        "",       // Date (x-ms-date is used instead)
        "",       // If-Modified-Since
        "",       // If-Match
        "",       // If-None-Match
        "",       // If-Unmodified-Since
        "",       // Range
    ]
    .join("\n");
    // x-ms-* headers in alphabetical order, each terminated by a newline.
    let canonicalized_headers = format!(
        "x-ms-date:{}\nx-ms-version:2020-10-02\n",
        date_str
    );
    let canonicalized_resource = format!("/{}/{}/{}", account_name, container, blob_name);
    let string_to_sign = format!(
        "{}\n{}{}",
        standard_headers, canonicalized_headers, canonicalized_resource
    );

    let key_bytes =
        base64::Engine::decode(&base64::engine::general_purpose::STANDARD, account_key)
            .map_err(|e| anyhow!("Invalid Azure account key: {}", e))?;
    let mut mac = HmacSha256::new_from_slice(&key_bytes)
        .map_err(|e| anyhow!("Failed to create HMAC: {}", e))?;
    mac.update(string_to_sign.as_bytes());
    let signature = base64::Engine::encode(
        &base64::engine::general_purpose::STANDARD,
        mac.finalize().into_bytes(),
    );
    let auth_header = format!("SharedKey {}:{}", account_name, signature);

    let client = reqwest::Client::new();
    let response = client
        .delete(&url)
        .header("x-ms-date", &date_str)
        .header("x-ms-version", "2020-10-02")
        .header("Authorization", &auth_header)
        .send()
        .await?;

    // `is_success` covers every 2xx status, including 202 Accepted.
    let status = response.status();
    if status.is_success() {
        tracing::info!("Successfully deleted from Azure: {}", file_id);
        Ok(())
    } else {
        let error_body = response.text().await.unwrap_or_default();
        tracing::error!("Azure delete failed: {} - {}", status, error_body);
        Err(anyhow!("Azure delete failed: {} - {}", status, error_body))
    }
}
}
/// An in-memory upload: the client-supplied name and MIME type plus the
/// raw file bytes.
#[derive(Debug, Clone)]
pub struct Upload {
    /// File name supplied by the uploader.
    pub filename: String,
    /// MIME type supplied by the uploader.
    pub mime_type: String,
    /// Raw file bytes.
    pub content: Vec<u8>,
}

impl Upload {
    /// Bundles the three parts of an upload into an [`Upload`] value.
    pub fn new(filename: String, mime_type: String, content: Vec<u8>) -> Self {
        Upload { filename, mime_type, content }
    }
}
// Unit tests for configuration defaults, validation rules, progress
// arithmetic and the local (non-cloud) upload path.
#[cfg(test)]
mod tests {
use super::*;
// The default configuration carries the documented limits and switches.
#[tokio::test]
async fn test_file_upload_config_default() {
let config = FileUploadConfig::default();
assert_eq!(config.max_file_size, 100 * 1024 * 1024);
assert_eq!(config.max_files_per_request, 10);
assert!(config.allowed_mime_types.is_none());
assert!(config.allowed_extensions.is_none());
assert!(!config.enable_virus_scan);
assert!(config.enable_progress_tracking);
}
// Constructing a manager from the default configuration succeeds.
#[tokio::test]
async fn test_file_upload_manager_creation() {
let config = FileUploadConfig::default();
let manager = FileUploadManager::new(config);
assert!(manager.is_ok());
}
// Sizes above `max_file_size` are rejected; smaller ones pass.
#[tokio::test]
async fn test_file_validation_size_limit() {
let config = FileUploadConfig {
max_file_size: 1024,
..Default::default()
};
let manager = FileUploadManager::new(config).expect("should succeed");
let result = manager.validate_file("test.txt", "text/plain", 2048);
assert!(result.is_err());
let result = manager.validate_file("test.txt", "text/plain", 512);
assert!(result.is_ok());
}
// Allow-list entries act as MIME fragments ("image/" admits "image/jpeg").
#[tokio::test]
async fn test_file_validation_mime_type() {
let config = FileUploadConfig {
allowed_mime_types: Some(vec!["image/".to_string(), "video/".to_string()]),
..Default::default()
};
let manager = FileUploadManager::new(config).expect("should succeed");
let result = manager.validate_file("test.jpg", "image/jpeg", 1024);
assert!(result.is_ok());
let result = manager.validate_file("test.txt", "text/plain", 1024);
assert!(result.is_err());
}
// Only extensions on the allow-list are accepted.
#[tokio::test]
async fn test_file_validation_extension() {
let config = FileUploadConfig {
allowed_extensions: Some(vec!["jpg".to_string(), "png".to_string()]),
..Default::default()
};
let manager = FileUploadManager::new(config).expect("should succeed");
let result = manager.validate_file("test.jpg", "image/jpeg", 1024);
assert!(result.is_ok());
let result = manager.validate_file("test.txt", "text/plain", 1024);
assert!(result.is_err());
}
// `update` stores the byte count and recomputes the percentage.
#[tokio::test]
async fn test_upload_progress_tracking() {
let mut progress = UploadProgress::new("test-id".to_string(), 1000);
assert_eq!(progress.bytes_uploaded, 0);
assert_eq!(progress.percentage, 0.0);
progress.update(500);
assert_eq!(progress.bytes_uploaded, 500);
assert_eq!(progress.percentage, 50.0);
progress.update(1000);
assert_eq!(progress.bytes_uploaded, 1000);
assert_eq!(progress.percentage, 100.0);
}
// End-to-end local upload followed by deletion.
// NOTE(review): the directory name is fixed, so concurrent runs of this
// test share (and remove) the same directory.
#[tokio::test]
async fn test_process_upload() {
let temp_dir = std::env::temp_dir().join("oxirs-test-uploads");
let config = FileUploadConfig {
upload_dir: temp_dir.clone(),
..Default::default()
};
let manager = FileUploadManager::new(config).expect("should succeed");
manager.initialize().await.expect("should succeed");
let content = b"Hello, World!".to_vec();
let result = manager
.process_upload("test.txt".to_string(), "text/plain".to_string(), content)
.await;
assert!(result.is_ok());
let file = result.expect("should succeed");
assert_eq!(file.filename, "test.txt");
assert_eq!(file.mime_type, "text/plain");
assert_eq!(file.size, 13);
assert_eq!(file.status, UploadStatus::Completed);
manager
.delete_upload(&file.id)
.await
.expect("should succeed");
// Cleanup is best effort; a failure here does not fail the test.
let _ = fs::remove_dir_all(&temp_dir).await;
}
// A batch with more files than `max_files_per_request` is rejected before
// any file is processed.
#[tokio::test]
async fn test_batch_upload_size_limit() {
let config = FileUploadConfig {
max_files_per_request: 2,
..Default::default()
};
let manager = FileUploadManager::new(config).expect("should succeed");
// Creates the default upload directory; it is not removed afterwards.
manager.initialize().await.expect("should succeed");
let files = vec![
(
"file1.txt".to_string(),
"text/plain".to_string(),
b"content1".to_vec(),
),
(
"file2.txt".to_string(),
"text/plain".to_string(),
b"content2".to_vec(),
),
(
"file3.txt".to_string(),
"text/plain".to_string(),
b"content3".to_vec(),
),
];
let result = manager.process_batch_upload(files).await;
assert!(result.is_err());
}
}