use std::{collections::BTreeMap, time::Duration};
use chrono::{DateTime, Utc};
use rustrails_support::runtime;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use url::Url;
use crate::{
Blob,
service::StorageService,
urls::{SignedUrlError, sign_payload, verify_payload},
};
/// Failures that can occur while preparing or verifying a direct upload.
#[derive(Debug, Error)]
pub enum DirectUploadError {
    /// The token verified and decoded, but the verification time is past its
    /// `expires_at` claim.
    #[error("direct upload token has expired")]
    Expired,
    /// The token's claims could not be encoded or decoded, or the requested
    /// token lifetime could not be represented as an expiration instant.
    #[error("invalid direct upload token")]
    InvalidToken,
    /// Signing or signature verification failed; forwards the underlying
    /// [`SignedUrlError`] unchanged.
    #[error(transparent)]
    SignedUrl(#[from] SignedUrlError),
    /// The storage service reported an error, e.g. while producing the upload URL.
    #[error(transparent)]
    Storage(#[from] crate::service::StorageError),
}
/// Claims embedded (as JSON) in a signed direct-upload token.
///
/// `DirectUploadManager::prepare` builds these from a `Blob`, and
/// `DirectUploadManager::verify` decodes them back.
// NOTE: serde_json serializes struct fields in declaration order, so reordering
// fields changes the bytes of newly signed payloads.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DirectUploadTokenClaims {
    /// Id of the blob the upload is for.
    pub blob_id: uuid::Uuid,
    /// Storage key of the blob.
    pub key: String,
    /// Expected size of the uploaded content, in bytes.
    pub byte_size: u64,
    /// Checksum of the blob as recorded on the `Blob`.
    pub checksum: String,
    /// Name of the storage service recorded on the blob.
    pub service_name: String,
    /// Expiration as a Unix timestamp in whole seconds; `verify` rejects the
    /// token only when the verification time's second is strictly greater.
    pub expires_at: i64,
}
/// Everything a client needs to perform a direct upload, as produced by
/// `DirectUploadManager::prepare`.
#[derive(Debug, Clone)]
pub struct DirectUploadRequest {
    /// The blob being uploaded.
    pub blob: Blob,
    /// URL returned by the storage service for the blob's key.
    pub upload_url: Url,
    /// Checksum, byte-size and (when known) content-type headers describing the
    /// upload — presumably meant to be sent with the upload request; confirm with
    /// the consuming client.
    pub headers: BTreeMap<String, String>,
    /// Signed token carrying the `DirectUploadTokenClaims`.
    pub token: String,
    /// Instant after which the token is considered expired.
    pub expires_at: DateTime<Utc>,
}
/// Prepares and verifies direct-upload tokens signed with a shared secret.
///
/// `Debug` is implemented by hand so the secret never appears in logs or
/// panic messages.
#[derive(Clone)]
pub struct DirectUploadManager {
    /// Key material passed to `sign_payload` / `verify_payload`.
    secret: Vec<u8>,
}

impl std::fmt::Debug for DirectUploadManager {
    /// Formats the manager with the signing secret redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DirectUploadManager")
            .field("secret", &"<redacted>")
            .finish()
    }
}
impl DirectUploadManager {
    /// Creates a manager whose tokens are signed and verified with `secret`.
    #[must_use]
    pub fn new(secret: impl Into<Vec<u8>>) -> Self {
        Self {
            secret: secret.into(),
        }
    }

    /// Prepares a direct upload for `blob`.
    ///
    /// Signs a token carrying the blob's id, key, byte size, checksum, service
    /// name and expiration (Unix seconds). It then asks `service` for a URL for
    /// the blob's key valid for `expires_in`, and collects the headers that
    /// describe the upload.
    ///
    /// # Errors
    ///
    /// * [`DirectUploadError::InvalidToken`] if `expires_in` cannot be turned
    ///   into an expiration instant (too large to represent), or if the claims
    ///   fail to serialize.
    /// * [`DirectUploadError::SignedUrl`] if signing the payload fails.
    /// * [`DirectUploadError::Storage`] if the service fails to produce a URL.
    pub async fn prepare<S: StorageService + ?Sized>(
        &self,
        blob: Blob,
        service: &S,
        expires_in: Duration,
    ) -> Result<DirectUploadRequest, DirectUploadError> {
        let expires_at = Self::expiration_from_now(expires_in)?;
        let claims = DirectUploadTokenClaims {
            blob_id: blob.id(),
            key: blob.key().to_owned(),
            byte_size: blob.byte_size(),
            checksum: blob.checksum().to_owned(),
            service_name: blob.service_name().to_owned(),
            expires_at: expires_at.timestamp(),
        };
        let payload = serde_json::to_vec(&claims).map_err(|_| DirectUploadError::InvalidToken)?;
        let token = sign_payload(&self.secret, &payload)?;
        let upload_url = service.url(blob.key(), expires_in).await?;
        let headers = Self::upload_headers(&blob);
        Ok(DirectUploadRequest {
            blob,
            upload_url,
            headers,
            token,
            expires_at,
        })
    }

    /// Blocking variant of [`DirectUploadManager::prepare`], driven by
    /// `runtime::block_on`.
    ///
    /// # Errors
    ///
    /// Same as [`DirectUploadManager::prepare`].
    pub fn prepare_sync<S: StorageService + ?Sized>(
        &self,
        blob: Blob,
        service: &S,
        expires_in: Duration,
    ) -> Result<DirectUploadRequest, DirectUploadError> {
        runtime::block_on(self.prepare(blob, service, expires_in))
    }

    /// Verifies `token`'s signature and decodes its claims, treating `now` as
    /// the current time.
    ///
    /// Expiration is checked at one-second resolution: a token whose
    /// `expires_at` equals `now`'s Unix second is still accepted.
    ///
    /// # Errors
    ///
    /// * [`DirectUploadError::SignedUrl`] if the signature check fails.
    /// * [`DirectUploadError::InvalidToken`] if the payload is not valid claims JSON.
    /// * [`DirectUploadError::Expired`] if `now` is past the token's expiration.
    pub fn verify(
        &self,
        token: &str,
        now: DateTime<Utc>,
    ) -> Result<DirectUploadTokenClaims, DirectUploadError> {
        let payload = verify_payload(token, &self.secret)?;
        let claims: DirectUploadTokenClaims =
            serde_json::from_slice(&payload).map_err(|_| DirectUploadError::InvalidToken)?;
        if now.timestamp() > claims.expires_at {
            return Err(DirectUploadError::Expired);
        }
        Ok(claims)
    }

    /// Computes `now + expires_in` without panicking.
    ///
    /// `DateTime + Duration` panics on overflow, so a very large but
    /// convertible `expires_in` would otherwise crash the caller. Both the
    /// std→chrono conversion and the addition report `InvalidToken` instead.
    fn expiration_from_now(expires_in: Duration) -> Result<DateTime<Utc>, DirectUploadError> {
        let delta =
            chrono::Duration::from_std(expires_in).map_err(|_| DirectUploadError::InvalidToken)?;
        Utc::now()
            .checked_add_signed(delta)
            .ok_or(DirectUploadError::InvalidToken)
    }

    /// Builds the upload headers: checksum, byte size and, when the blob knows
    /// it, the content type.
    fn upload_headers(blob: &Blob) -> BTreeMap<String, String> {
        let mut headers = BTreeMap::new();
        headers.insert(
            "x-rustrails-checksum".to_owned(),
            blob.checksum().to_owned(),
        );
        headers.insert(
            "x-rustrails-byte-size".to_owned(),
            blob.byte_size().to_string(),
        );
        if let Some(content_type) = blob.content_type() {
            headers.insert("content-type".to_owned(), content_type.to_owned());
        }
        headers
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::{blob::Blob, service::memory::MemoryService, test_support::run_sync_test};

    /// Secret used by the default test manager.
    const SECRET: &[u8] = b"secret";
    /// Lifetime used by tests that do not exercise expiration.
    const ONE_MINUTE: Duration = Duration::from_secs(60);

    /// Manager signing with [`SECRET`].
    fn manager() -> DirectUploadManager {
        DirectUploadManager::new(SECRET.to_vec())
    }

    /// Small blob on the "memory" service with no explicit content type.
    fn blob() -> Blob {
        Blob::create(
            Bytes::from_static(b"hello"),
            "hello.txt",
            None,
            Default::default(),
            "memory",
        )
        .expect("blob should build")
    }

    #[tokio::test]
    async fn test_prepare_builds_upload_request() {
        let service = MemoryService::new("memory").expect("service should build");
        let request = manager()
            .prepare(blob(), &service, ONE_MINUTE)
            .await
            .expect("request should build");
        assert!(request.upload_url.as_str().contains("expires_in=60"));
        assert!(request.headers.contains_key("x-rustrails-checksum"));
    }

    #[test]
    fn test_prepare_sync_builds_upload_request() {
        run_sync_test(|| {
            let service = MemoryService::new("memory").expect("service should build");
            let request = manager()
                .prepare_sync(blob(), &service, ONE_MINUTE)
                .expect("request should build");
            assert!(request.upload_url.as_str().contains("expires_in=60"));
            assert!(request.headers.contains_key("x-rustrails-checksum"));
        });
    }

    #[tokio::test]
    async fn test_prepare_includes_byte_size_header() {
        let service = MemoryService::new("memory").expect("service should build");
        let blob = blob();
        let expected_byte_size = blob.byte_size().to_string();
        let request = manager()
            .prepare(blob, &service, ONE_MINUTE)
            .await
            .expect("request should build");
        assert_eq!(
            request
                .headers
                .get("x-rustrails-byte-size")
                .map(String::as_str),
            Some(expected_byte_size.as_str())
        );
    }

    #[tokio::test]
    async fn test_prepare_preserves_blob_reference() {
        let service = MemoryService::new("memory").expect("service should build");
        let blob = blob();
        let request = manager()
            .prepare(blob.clone(), &service, ONE_MINUTE)
            .await
            .expect("request should build");
        assert_eq!(request.blob.id(), blob.id());
    }

    #[tokio::test]
    async fn test_verify_round_trips_claims() {
        let manager = manager();
        let service = MemoryService::new("memory").expect("service should build");
        let request = manager
            .prepare(blob(), &service, ONE_MINUTE)
            .await
            .expect("request should build");
        let claims = manager
            .verify(&request.token, Utc::now())
            .expect("token should verify");
        assert_eq!(claims.key, request.blob.key());
        assert_eq!(claims.checksum, request.blob.checksum());
    }

    #[tokio::test]
    async fn test_prepare_round_trips_blob_metadata_service_name_and_expiration() {
        let manager = manager();
        let service = MemoryService::new("public").expect("service should build");
        let mut metadata = BTreeMap::new();
        metadata.insert("custom".to_owned(), serde_json::json!("value"));
        let blob = Blob::create_before_direct_upload(
            "direct-key",
            "hello.txt",
            6,
            "checksum",
            Some("text/plain"),
            metadata.clone(),
            "mirror",
        )
        .expect("blob should build");
        let blob_id = blob.id();
        let request = manager
            .prepare(blob, &service, ONE_MINUTE)
            .await
            .expect("request should build");
        let claims = manager
            .verify(&request.token, Utc::now())
            .expect("token should verify");
        assert_eq!(request.blob.metadata(), &metadata);
        assert_eq!(claims.blob_id, blob_id);
        assert_eq!(claims.service_name, "mirror");
        assert_eq!(claims.expires_at, request.expires_at.timestamp());
    }

    #[tokio::test]
    async fn test_verify_rejects_expired_token() {
        let manager = manager();
        let service = MemoryService::new("memory").expect("service should build");
        let request = manager
            .prepare(blob(), &service, Duration::from_secs(1))
            .await
            .expect("request should build");
        let future = Utc::now() + chrono::Duration::seconds(2);
        let error = manager
            .verify(&request.token, future)
            .expect_err("token should fail");
        assert!(matches!(error, DirectUploadError::Expired));
    }

    // Verification is synchronous, so no async runtime is needed here.
    #[test]
    fn test_verify_rejects_tampered_token() {
        let error = manager()
            .verify("tampered", Utc::now())
            .expect_err("token should fail");
        assert!(matches!(error, DirectUploadError::SignedUrl(_)));
    }

    #[tokio::test]
    async fn test_verify_rejects_token_signed_with_different_secret() {
        let other_manager = DirectUploadManager::new(b"other-secret".to_vec());
        let service = MemoryService::new("memory").expect("service should build");
        let request = manager()
            .prepare(blob(), &service, ONE_MINUTE)
            .await
            .expect("request should build");
        let error = other_manager
            .verify(&request.token, Utc::now())
            .expect_err("token should fail");
        assert!(matches!(error, DirectUploadError::SignedUrl(_)));
    }

    #[tokio::test]
    async fn test_verify_accepts_exact_expiration_timestamp() {
        let manager = manager();
        let service = MemoryService::new("memory").expect("service should build");
        let request = manager
            .prepare(blob(), &service, Duration::from_secs(1))
            .await
            .expect("request should build");
        let claims = manager
            .verify(&request.token, Utc::now())
            .expect("token should verify");
        let boundary = chrono::DateTime::<Utc>::from_timestamp(claims.expires_at, 0)
            .expect("timestamp should be valid");
        let boundary_claims = manager
            .verify(&request.token, boundary)
            .expect("boundary token should verify");
        assert_eq!(boundary_claims, claims);
    }

    #[tokio::test]
    async fn test_prepare_includes_content_type_header_when_known() {
        let service = MemoryService::new("memory").expect("service should build");
        let blob = Blob::create(
            Bytes::from_static(b"hello"),
            "hello.txt",
            Some("text/plain"),
            Default::default(),
            "memory",
        )
        .expect("blob should build");
        let request = manager()
            .prepare(blob, &service, ONE_MINUTE)
            .await
            .expect("request should build");
        assert_eq!(
            request.headers.get("content-type").map(String::as_str),
            Some("text/plain")
        );
    }

    #[tokio::test]
    async fn test_prepare_omits_content_type_header_when_unknown() {
        let service = MemoryService::new("memory").expect("service should build");
        let blob = Blob::create_before_direct_upload(
            "direct-key",
            "unknown_file",
            100,
            "checksum",
            None,
            BTreeMap::new(),
            "memory",
        )
        .expect("blob should build");
        let request = manager()
            .prepare(blob, &service, ONE_MINUTE)
            .await
            .expect("request should build");
        assert!(!request.headers.contains_key("content-type"));
    }
}