use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use tokio::sync::Mutex;
use crate::object_lock::LockMode;
use crate::tagging::TagSet;
/// Server-side encryption mode recorded for an in-progress multipart upload.
///
/// `Debug` is implemented by hand rather than derived so that the raw SSE-C
/// customer key never ends up in logs, assertion failures, or panic
/// messages. Every other variant (and the `key_md5` field) formats exactly
/// as the derived implementation would.
#[derive(Clone, PartialEq, Eq)]
pub enum MultipartSseMode {
    /// No server-side encryption.
    None,
    /// Store-managed encryption.
    /// NOTE(review): exact key-management semantics are defined elsewhere.
    SseS4,
    /// Customer-provided key (SSE-C).
    SseC {
        /// Raw 256-bit customer key. Secret: redacted from `Debug` output.
        key: [u8; 32],
        /// MD5 of the customer key (presumably of `key` — confirm at the
        /// call site that fills it). Not redacted: it does not reveal the key.
        key_md5: [u8; 16],
    },
    /// KMS-managed key, identified by `key_id`.
    SseKms {
        key_id: String,
    },
}

impl std::fmt::Debug for MultipartSseMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => f.write_str("None"),
            Self::SseS4 => f.write_str("SseS4"),
            // Never format the key bytes; `format_args!` prints the marker
            // without quotes, mirroring how a real field value would look.
            Self::SseC { key: _, key_md5 } => f
                .debug_struct("SseC")
                .field("key", &format_args!("<redacted>"))
                .field("key_md5", key_md5)
                .finish(),
            Self::SseKms { key_id } => f
                .debug_struct("SseKms")
                .field("key_id", key_id)
                .finish(),
        }
    }
}
/// Per-upload state stored in `MultipartStateStore`, keyed by upload id.
///
/// NOTE(review): presumably captured when the multipart upload is initiated
/// and read back when parts are uploaded / the upload is completed — confirm
/// against callers.
///
/// `Debug` is derived, so what gets printed for `sse` is whatever
/// `MultipartSseMode`'s own `Debug` implementation emits.
#[derive(Clone, Debug)]
pub struct MultipartUploadContext {
    /// Target bucket of the upload.
    pub bucket: String,
    /// Target object key within `bucket`.
    pub key: String,
    /// Encryption mode chosen for this upload.
    pub sse: MultipartSseMode,
    /// Optional tag set; presumably applied to the completed object — confirm.
    pub tags: Option<TagSet>,
    /// Optional object-lock mode.
    pub object_lock_mode: Option<LockMode>,
    /// Optional object-lock retain-until timestamp (UTC).
    pub object_lock_retain_until: Option<DateTime<Utc>>,
    /// Whether a legal hold was requested.
    pub object_lock_legal_hold: bool,
}
/// In-memory registry of multipart-upload contexts, plus per-object
/// completion locks.
pub struct MultipartStateStore {
    // Upload id -> context. Guarded by a std `RwLock`; the accessors in this
    // file hold the guard only for a single map operation.
    by_upload_id: RwLock<HashMap<String, MultipartUploadContext>>,
    // `(bucket, key)` -> async (tokio) mutex handed out by `completion_lock`;
    // presumably held while completing an upload to that object. Entries are
    // created on demand and removed by `prune_completion_locks` once no
    // caller holds a reference.
    completion_locks: DashMap<(String, String), Arc<Mutex<()>>>,
}
impl MultipartStateStore {
    /// Creates an empty store: no tracked uploads and no completion locks.
    #[must_use]
    pub fn new() -> Self {
        Self {
            by_upload_id: RwLock::new(HashMap::new()),
            completion_locks: DashMap::new(),
        }
    }

    /// Shared read access to the upload map.
    ///
    /// Every upload-map accessor goes through this helper or
    /// [`Self::write_map`], so the poisoning policy lives in one place.
    ///
    /// # Panics
    ///
    /// Panics if the `RwLock` was poisoned by a thread that panicked while
    /// holding it.
    fn read_map(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<String, MultipartUploadContext>> {
        self.by_upload_id
            .read()
            .expect("multipart-state by_upload_id RwLock poisoned")
    }

    /// Exclusive write access to the upload map.
    ///
    /// # Panics
    ///
    /// Panics if the `RwLock` is poisoned.
    fn write_map(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<String, MultipartUploadContext>> {
        self.by_upload_id
            .write()
            .expect("multipart-state by_upload_id RwLock poisoned")
    }

    /// Records `ctx` under `upload_id`, replacing any context already stored
    /// for that id.
    ///
    /// # Panics
    ///
    /// Panics if the internal `RwLock` is poisoned.
    pub fn put(&self, upload_id: &str, ctx: MultipartUploadContext) {
        self.write_map().insert(upload_id.to_owned(), ctx);
    }

    /// Returns a clone of the context stored for `upload_id`, or `None` if
    /// nothing is stored under that id.
    ///
    /// # Panics
    ///
    /// Panics if the internal `RwLock` is poisoned.
    #[must_use]
    pub fn get(&self, upload_id: &str) -> Option<MultipartUploadContext> {
        self.read_map().get(upload_id).cloned()
    }

    /// Forgets the context stored for `upload_id`; a no-op if there is none.
    ///
    /// # Panics
    ///
    /// Panics if the internal `RwLock` is poisoned.
    pub fn remove(&self, upload_id: &str) {
        self.write_map().remove(upload_id);
    }

    /// Returns the async mutex associated with `(bucket, key)`, creating it on
    /// first use.
    ///
    /// While the entry remains in the map, every call for the same pair
    /// returns the same `Arc`, so locking it excludes other callers locking
    /// the same pair. Different pairs get independent mutexes.
    ///
    /// The result is `#[must_use]`: dropping the `Arc` without locking it
    /// provides no exclusion at all.
    #[must_use]
    pub fn completion_lock(&self, bucket: &str, key: &str) -> Arc<Mutex<()>> {
        let k = (bucket.to_owned(), key.to_owned());
        self.completion_locks
            .entry(k)
            .or_insert_with(|| Arc::new(Mutex::new(())))
            .value()
            .clone()
    }

    /// Drops completion-lock entries that no caller currently references.
    ///
    /// An entry counts as unreferenced when its strong count is 1, i.e. only
    /// the map's own `Arc` is left. Entries still held elsewhere are kept. A
    /// later `completion_lock` call for a pruned pair creates a fresh mutex.
    pub fn prune_completion_locks(&self) {
        self.completion_locks
            .retain(|_, lock| Arc::strong_count(lock) > 1);
    }

    /// Test-only: number of `(bucket, key)` entries in the completion-lock map.
    #[cfg(test)]
    fn completion_locks_len(&self) -> usize {
        self.completion_locks.len()
    }

    /// Test-only: number of tracked upload contexts.
    #[cfg(test)]
    fn len(&self) -> usize {
        self.read_map().len()
    }
}
impl Default for MultipartStateStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Minimal context for `(bucket, key)` with every optional feature off.
    fn sample_ctx(bucket: &str, key: &str) -> MultipartUploadContext {
        MultipartUploadContext {
            bucket: bucket.to_owned(),
            key: key.to_owned(),
            sse: MultipartSseMode::None,
            tags: None,
            object_lock_mode: None,
            object_lock_retain_until: None,
            object_lock_legal_hold: false,
        }
    }

    #[test]
    fn put_get_remove_round_trip() {
        let store = MultipartStateStore::new();
        store.put("upload-001", sample_ctx("b", "k"));
        let got = store.get("upload-001").expect("entry must be present");
        assert_eq!(got.bucket, "b");
        assert_eq!(got.key, "k");
        assert_eq!(got.sse, MultipartSseMode::None);
        store.remove("upload-001");
        assert!(store.get("upload-001").is_none(), "entry must be gone");
    }

    #[test]
    fn missing_upload_id_is_benign() {
        let store = MultipartStateStore::new();
        assert!(store.get("never-put").is_none(), "unknown id must yield None");
        // Removing an absent id must not panic or disturb the map.
        store.remove("never-put");
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn put_overwrites_existing_entry() {
        let store = MultipartStateStore::new();
        store.put("u-dup", sample_ctx("b", "first"));
        store.put("u-dup", sample_ctx("b", "second"));
        let got = store.get("u-dup").expect("entry must be present");
        assert_eq!(got.key, "second", "later put must replace the earlier one");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn sse_c_key_bytes_round_trip() {
        let store = MultipartStateStore::new();
        let key = [0xa5u8; 32];
        let key_md5 = [0xb6u8; 16];
        let mut ctx = sample_ctx("b", "k");
        ctx.sse = MultipartSseMode::SseC { key, key_md5 };
        store.put("u-sse-c", ctx);
        let got = store.get("u-sse-c").expect("entry must be present");
        match got.sse {
            MultipartSseMode::SseC { key: k, key_md5: m } => {
                assert_eq!(k, key, "SSE-C key bytes must round-trip");
                assert_eq!(m, key_md5, "SSE-C MD5 must round-trip");
            }
            other => panic!("expected SseC variant, got {other:?}"),
        }
    }

    #[test]
    fn completion_lock_returns_same_arc_for_same_key() {
        let store = MultipartStateStore::new();
        let a = store.completion_lock("bucket-a", "key/x");
        let b = store.completion_lock("bucket-a", "key/x");
        assert!(
            Arc::ptr_eq(&a, &b),
            "completion_lock(same bucket, same key) must return identical Arc"
        );
    }

    // Plain `#[test]`: nothing is awaited and `tokio::sync::Mutex::try_lock`
    // is synchronous, so no async runtime is needed.
    #[test]
    fn completion_lock_distinct_keys_independent() {
        let store = MultipartStateStore::new();
        let a = store.completion_lock("bucket-a", "shared/key");
        let b = store.completion_lock("bucket-b", "shared/key");
        assert!(
            !Arc::ptr_eq(&a, &b),
            "completion_lock with different bucket must yield different Arc"
        );
        let guard_a = a.try_lock().expect("lock on bucket-a/shared/key must be free");
        let guard_b = b.try_lock().expect("lock on bucket-b/shared/key must be free");
        let a2 = store.completion_lock("bucket-a", "shared/key");
        assert!(
            Arc::ptr_eq(&a, &a2),
            "completion_lock for the same (bucket, key) must alias"
        );
        assert!(
            a2.try_lock().is_err(),
            "completion_lock alias must observe the held guard as contended"
        );
        drop(guard_a);
        drop(guard_b);
    }

    #[test]
    fn prune_completion_locks_removes_unreferenced() {
        let store = MultipartStateStore::new();
        {
            let _lock = store.completion_lock("b", "ephemeral");
        }
        assert_eq!(
            store.completion_locks_len(),
            1,
            "lock entry must be present immediately after acquire-drop"
        );
        store.prune_completion_locks();
        assert_eq!(
            store.completion_locks_len(),
            0,
            "prune must retire entries with strong_count == 1"
        );
        let held = store.completion_lock("b", "in-flight");
        store.prune_completion_locks();
        assert_eq!(
            store.completion_locks_len(),
            1,
            "prune must keep entries with outstanding Arc borrowers"
        );
        drop(held);
        store.prune_completion_locks();
        assert_eq!(
            store.completion_locks_len(),
            0,
            "prune must retire the entry once the borrower drops"
        );
    }

    #[test]
    fn concurrent_put_lookup_race_free() {
        let store = Arc::new(MultipartStateStore::new());
        let mut handles = Vec::new();
        for tid in 0..8u32 {
            let st = Arc::clone(&store);
            handles.push(thread::spawn(move || {
                for i in 0..250u32 {
                    let id = format!("u-{tid}-{i}");
                    let ctx = sample_ctx("b", &id);
                    st.put(&id, ctx);
                    let got = st.get(&id).expect("self-put must be visible");
                    assert_eq!(got.key, id);
                }
            }));
        }
        for h in handles {
            h.join().expect("worker thread panicked");
        }
        assert_eq!(store.len(), 8 * 250, "all puts must persist");
    }
}