use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use tokio::sync::Mutex;
use zeroize::Zeroizing;
use crate::object_lock::LockMode;
use crate::tagging::TagSet;
/// Server-side-encryption mode recorded for an in-flight multipart upload.
///
/// Captured at upload creation so completion can enforce the same SSE
/// parameters across all parts.
#[derive(Clone, Debug)]
pub enum MultipartSseMode {
// No server-side encryption requested.
None,
// Server-managed key encryption.
// NOTE(review): variant is spelled `SseS4` — confirm this is intentional
// project naming and not a typo of the SSE-S3 analogue.
SseS4,
// Customer-provided key (SSE-C): raw 256-bit key plus the MD5 digest the
// client supplied for it. `Zeroizing` wipes the key bytes on drop.
SseC {
key: Zeroizing<[u8; 32]>,
key_md5: [u8; 16],
},
// KMS-managed encryption; `key_id` names the KMS key to use.
SseKms { key_id: String },
}
impl PartialEq for MultipartSseMode {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(MultipartSseMode::None, MultipartSseMode::None) => true,
(MultipartSseMode::SseS4, MultipartSseMode::SseS4) => true,
(
MultipartSseMode::SseC {
key: a,
key_md5: am,
},
MultipartSseMode::SseC {
key: b,
key_md5: bm,
},
) => a.as_slice() == b.as_slice() && am == bm,
(MultipartSseMode::SseKms { key_id: a }, MultipartSseMode::SseKms { key_id: b }) => {
a == b
}
_ => false,
}
}
}
impl Eq for MultipartSseMode {}
/// Per-upload metadata captured when a multipart upload is initiated and
/// replayed when it is completed.
#[derive(Clone, Debug)]
pub struct MultipartUploadContext {
// Destination bucket name.
pub bucket: String,
// Destination object key.
pub key: String,
// Server-side-encryption parameters fixed at upload creation.
pub sse: MultipartSseMode,
// Object tags to apply on completion, if any were supplied.
pub tags: Option<TagSet>,
// Object-lock retention mode to apply on completion, if requested.
pub object_lock_mode: Option<LockMode>,
// Retain-until timestamp accompanying `object_lock_mode`, if set.
pub object_lock_retain_until: Option<DateTime<Utc>>,
// Whether a legal hold is placed on the completed object.
pub object_lock_legal_hold: bool,
}
/// In-memory registry of in-flight multipart uploads, plus per-object
/// completion locks.
pub struct MultipartStateStore {
// upload_id -> (context, insertion time). The timestamp is set by `put`
// and drives `sweep_stale`'s age-based eviction.
by_upload_id: RwLock<HashMap<String, (MultipartUploadContext, DateTime<Utc>)>>,
// (bucket, key) -> shared mutex handed out by `completion_lock`;
// presumably used to serialize concurrent completions of the same object.
completion_locks: DashMap<(String, String), Arc<Mutex<()>>>,
}
impl MultipartStateStore {
    /// Creates an empty store with no tracked uploads and no completion locks.
    #[must_use]
    pub fn new() -> Self {
        Self {
            by_upload_id: RwLock::new(HashMap::new()),
            completion_locks: DashMap::new(),
        }
    }

    /// Registers (or replaces) the context for `upload_id`, stamping it with
    /// `Utc::now()`. That timestamp is what `sweep_stale` ages entries
    /// against.
    pub fn put(&self, upload_id: &str, ctx: MultipartUploadContext) {
        crate::lock_recovery::recover_write(&self.by_upload_id, "multipart_state.by_upload_id")
            .insert(upload_id.to_owned(), (ctx, Utc::now()));
    }

    /// Returns a clone of the stored context for `upload_id`, or `None` if
    /// the upload is unknown.
    #[must_use]
    pub fn get(&self, upload_id: &str) -> Option<MultipartUploadContext> {
        crate::lock_recovery::recover_read(&self.by_upload_id, "multipart_state.by_upload_id")
            .get(upload_id)
            .map(|(c, _)| c.clone())
    }

    /// Drops the context for `upload_id` (no-op when absent). For SSE-C
    /// contexts, dropping the entry zeroizes the key bytes via
    /// `Zeroizing`'s `Drop`.
    pub fn remove(&self, upload_id: &str) {
        crate::lock_recovery::recover_write(&self.by_upload_id, "multipart_state.by_upload_id")
            .remove(upload_id);
    }

    /// Removes every context inserted before `now - max_age` and returns how
    /// many entries were dropped.
    pub fn sweep_stale(&self, now: DateTime<Utc>, max_age: chrono::Duration) -> usize {
        let cutoff = now - max_age;
        let mut map =
            crate::lock_recovery::recover_write(&self.by_upload_id, "multipart_state.by_upload_id");
        // Single in-place `retain` pass instead of collecting stale keys into
        // a Vec (cloning each key) and removing them one by one: O(n), no
        // intermediate allocation.
        let before = map.len();
        map.retain(|_, (_, ts)| *ts >= cutoff);
        before - map.len()
    }

    /// Returns the shared mutex for `(bucket, key)`, creating it on first
    /// request. Repeated calls with the same pair alias the same `Arc`, so
    /// all callers contend on one lock per object.
    pub fn completion_lock(&self, bucket: &str, key: &str) -> Arc<Mutex<()>> {
        let k = (bucket.to_owned(), key.to_owned());
        self.completion_locks
            .entry(k)
            .or_insert_with(|| Arc::new(Mutex::new(())))
            .value()
            .clone()
    }

    /// Evicts completion locks with no outstanding borrowers — i.e. entries
    /// where the map's own `Arc` is the only strong reference left.
    pub fn prune_completion_locks(&self) {
        self.completion_locks
            .retain(|_, lock| Arc::strong_count(lock) > 1);
    }

    // Test-only: number of live completion-lock entries.
    #[cfg(test)]
    fn completion_locks_len(&self) -> usize {
        self.completion_locks.len()
    }

    // Test-only: number of tracked uploads.
    #[cfg(test)]
    fn len(&self) -> usize {
        crate::lock_recovery::recover_read(&self.by_upload_id, "multipart_state.by_upload_id").len()
    }
}
impl Default for MultipartStateStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Builds a minimal context: no SSE, no tags, no object-lock settings.
    fn sample_ctx(bucket: &str, key: &str) -> MultipartUploadContext {
        MultipartUploadContext {
            bucket: bucket.to_owned(),
            key: key.to_owned(),
            sse: MultipartSseMode::None,
            tags: None,
            object_lock_mode: None,
            object_lock_retain_until: None,
            object_lock_legal_hold: false,
        }
    }

    #[test]
    fn put_get_remove_round_trip() {
        let store = MultipartStateStore::new();
        let ctx = sample_ctx("b", "k");
        store.put("upload-001", ctx.clone());
        let got = store.get("upload-001").expect("entry must be present");
        assert_eq!(got.bucket, "b");
        assert_eq!(got.key, "k");
        assert_eq!(got.sse, MultipartSseMode::None);
        store.remove("upload-001");
        assert!(store.get("upload-001").is_none(), "entry must be gone");
    }

    #[test]
    fn sse_c_key_bytes_round_trip() {
        let store = MultipartStateStore::new();
        let key = [0xa5u8; 32];
        let key_md5 = [0xb6u8; 16];
        let mut ctx = sample_ctx("b", "k");
        ctx.sse = MultipartSseMode::SseC {
            key: Zeroizing::new(key),
            key_md5,
        };
        store.put("u-sse-c", ctx);
        let got = store.get("u-sse-c").expect("entry must be present");
        match got.sse {
            MultipartSseMode::SseC { key: k, key_md5: m } => {
                assert_eq!(*k, key, "SSE-C key bytes must round-trip");
                assert_eq!(m, key_md5, "SSE-C MD5 must round-trip");
            }
            other => panic!("expected SseC variant, got {other:?}"),
        }
    }

    #[test]
    fn sse_c_key_zeroized_on_remove() {
        let store = MultipartStateStore::new();
        let key = [0x77u8; 32];
        let key_md5 = [0x33u8; 16];
        let mut ctx = sample_ctx("b", "k");
        ctx.sse = MultipartSseMode::SseC {
            key: Zeroizing::new(key),
            key_md5,
        };
        store.put("u-zero", ctx);
        let got = store.get("u-zero").expect("entry present");
        match &got.sse {
            MultipartSseMode::SseC { key: k, .. } => {
                // Deref coercion must yield the raw array with intact bytes.
                let _deref: &[u8; 32] = k;
                assert_eq!(**k, key);
            }
            other => panic!("expected SseC, got {other:?}"),
        }
        drop(got);
        store.remove("u-zero");
        assert!(
            store.get("u-zero").is_none(),
            "removed entry must be gone (its Zeroizing<[u8;32]> ran Drop and wiped the key)"
        );
    }

    #[test]
    fn sweep_stale_drops_old_contexts() {
        let store = MultipartStateStore::new();
        store.put("u-1", sample_ctx("b", "k1"));
        store.put("u-2", sample_ctx("b", "k2"));
        store.put("u-3", sample_ctx("b", "k3"));
        assert_eq!(store.len(), 3, "all three entries inserted");
        // Sweep from 25 h in the future with a 24 h TTL: everything is stale.
        let future = Utc::now() + chrono::Duration::hours(25);
        let swept = store.sweep_stale(future, chrono::Duration::hours(24));
        assert_eq!(swept, 3, "all three entries are older than 24 h cutoff");
        assert_eq!(store.len(), 0, "store must be empty after sweep");
    }

    #[test]
    fn sweep_stale_keeps_recent_contexts() {
        let store = MultipartStateStore::new();
        store.put("u-fresh", sample_ctx("b", "k"));
        let near_future = Utc::now() + chrono::Duration::hours(1);
        let swept = store.sweep_stale(near_future, chrono::Duration::hours(24));
        assert_eq!(swept, 0, "1 h-old entry must NOT be swept under 24 h TTL");
        assert!(store.get("u-fresh").is_some(), "fresh entry must remain");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn sweep_stale_count_returns_correct() {
        let store = MultipartStateStore::new();
        store.put("old-1", sample_ctx("b", "k1"));
        store.put("old-2", sample_ctx("b", "k2"));
        // Sleep on both sides of the marker so the pre-/post-marker insert
        // timestamps are strictly ordered around it even on a coarse clock.
        thread::sleep(std::time::Duration::from_millis(10));
        let fresh_marker = Utc::now();
        thread::sleep(std::time::Duration::from_millis(10));
        store.put("fresh", sample_ctx("b", "k3"));
        // Sweeping at marker+24h with a 24 h TTL puts the cutoff exactly at
        // the marker: only the two pre-marker entries qualify.
        let sweep_at = fresh_marker + chrono::Duration::hours(24);
        let swept = store.sweep_stale(sweep_at, chrono::Duration::hours(24));
        assert_eq!(swept, 2, "exactly the two pre-marker entries must sweep");
        assert!(store.get("fresh").is_some(), "post-marker entry survives");
        assert!(store.get("old-1").is_none(), "old-1 must be gone");
        assert!(store.get("old-2").is_none(), "old-2 must be gone");
    }

    #[test]
    fn completion_lock_returns_same_arc_for_same_key() {
        let store = MultipartStateStore::new();
        let a = store.completion_lock("bucket-a", "key/x");
        let b = store.completion_lock("bucket-a", "key/x");
        assert!(
            Arc::ptr_eq(&a, &b),
            "completion_lock(same bucket, same key) must return identical Arc"
        );
    }

    #[tokio::test]
    async fn completion_lock_distinct_keys_independent() {
        let store = MultipartStateStore::new();
        let a = store.completion_lock("bucket-a", "shared/key");
        let b = store.completion_lock("bucket-b", "shared/key");
        assert!(
            !Arc::ptr_eq(&a, &b),
            "completion_lock with different bucket must yield different Arc"
        );
        let guard_a = a
            .try_lock()
            .expect("lock on bucket-a/shared/key must be free");
        let guard_b = b
            .try_lock()
            .expect("lock on bucket-b/shared/key must be free");
        // A later lookup for the same (bucket, key) must alias the held lock
        // and therefore observe it as contended.
        let a2 = store.completion_lock("bucket-a", "shared/key");
        assert!(
            Arc::ptr_eq(&a, &a2),
            "completion_lock for the same (bucket, key) must alias"
        );
        assert!(
            a2.try_lock().is_err(),
            "completion_lock alias must observe the held guard as contended"
        );
        drop(guard_a);
        drop(guard_b);
    }

    #[test]
    fn prune_completion_locks_removes_unreferenced() {
        let store = MultipartStateStore::new();
        {
            let _lock = store.completion_lock("b", "ephemeral");
        }
        assert_eq!(
            store.completion_locks_len(),
            1,
            "lock entry must be present immediately after acquire-drop"
        );
        store.prune_completion_locks();
        assert_eq!(
            store.completion_locks_len(),
            0,
            "prune must retire entries with strong_count == 1"
        );
        let held = store.completion_lock("b", "in-flight");
        store.prune_completion_locks();
        assert_eq!(
            store.completion_locks_len(),
            1,
            "prune must keep entries with outstanding Arc borrowers"
        );
        drop(held);
        store.prune_completion_locks();
        assert_eq!(
            store.completion_locks_len(),
            0,
            "prune must retire the entry once the borrower drops"
        );
    }

    #[test]
    fn concurrent_put_lookup_race_free() {
        let store = Arc::new(MultipartStateStore::new());
        let mut handles = Vec::new();
        // 8 threads x 250 distinct upload ids, each thread verifying its own
        // writes are immediately visible.
        for tid in 0..8u32 {
            let st = Arc::clone(&store);
            handles.push(thread::spawn(move || {
                for i in 0..250u32 {
                    let id = format!("u-{tid}-{i}");
                    let ctx = sample_ctx("b", &id);
                    st.put(&id, ctx);
                    let got = st.get(&id).expect("self-put must be visible");
                    assert_eq!(got.key, id);
                }
            }));
        }
        for h in handles {
            h.join().expect("worker thread panicked");
        }
        assert_eq!(store.len(), 8 * 250, "all puts must persist");
    }

    #[test]
    fn multipart_state_get_after_panic_recovers_via_poison() {
        let store = Arc::new(MultipartStateStore::new());
        store.put("u1", sample_ctx("b", "k"));
        // Panic while holding the write guard to poison the RwLock.
        let store_cl = Arc::clone(&store);
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut g = store_cl.by_upload_id.write().expect("clean lock");
            g.insert("u2".to_owned(), (sample_ctx("b", "k2"), Utc::now()));
            panic!("force-poison");
        }));
        assert!(
            store.by_upload_id.is_poisoned(),
            "write panic must poison by_upload_id lock"
        );
        // Reads and sweeps must still succeed through the recovery helpers.
        let got = store.get("u1").expect("get after poison must succeed");
        assert_eq!(got.bucket, "b");
        assert_eq!(got.key, "k");
        let n = store.sweep_stale(
            Utc::now() + chrono::Duration::hours(48),
            chrono::Duration::hours(1),
        );
        assert!(n >= 1, "stale sweep must run + recover via poison");
    }
}