use std::collections::HashMap;
use std::sync::RwLock;
use chrono::{DateTime, Utc};
use crate::object_lock::LockMode;
use crate::tagging::TagSet;
/// Server-side encryption mode chosen for a multipart upload.
///
/// `Debug` is implemented by hand (not derived) so the raw SSE-C key never
/// appears in `{:?}` / `{:#?}` output such as logs or panic messages. Every
/// other variant formats exactly as a derived `Debug` would.
#[derive(Clone, PartialEq, Eq)]
pub enum MultipartSseMode {
    /// No server-side encryption.
    None,
    /// NOTE(review): presumably encryption with server-managed keys (the
    /// SSE-S3 analogue) — confirm against the encryption layer.
    SseS4,
    /// Customer-provided key (SSE-C).
    SseC {
        /// Raw 256-bit customer key. Secret: never printed by `Debug`.
        key: [u8; 32],
        /// MD5 digest of `key` (16 bytes).
        key_md5: [u8; 16],
    },
    /// Key held by a KMS, referenced by `key_id`.
    SseKms {
        key_id: String,
    },
}

impl std::fmt::Debug for MultipartSseMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => f.write_str("None"),
            Self::SseS4 => f.write_str("SseS4"),
            // The key is secret material: emit a placeholder instead of its bytes.
            // The MD5 is kept because it identifies the key without revealing it.
            Self::SseC { key: _, key_md5 } => f
                .debug_struct("SseC")
                .field("key", &format_args!("<redacted>"))
                .field("key_md5", key_md5)
                .finish(),
            Self::SseKms { key_id } => f.debug_struct("SseKms").field("key_id", key_id).finish(),
        }
    }
}
/// Per-upload settings stored for one multipart upload, keyed by upload id in
/// [`MultipartStateStore`].
///
/// NOTE(review): presumably captured when the upload is created and read back
/// by later part/complete requests — confirm against the callers.
///
/// The derived `Debug` includes `sse`'s `Debug` output, so whatever that
/// prints (including any key material) ends up in logs that format this type.
#[derive(Clone, Debug)]
pub struct MultipartUploadContext {
    /// Bucket name recorded for the upload.
    pub bucket: String,
    /// Object key recorded for the upload.
    pub key: String,
    /// Server-side encryption settings for the upload.
    pub sse: MultipartSseMode,
    /// Optional tag set; `None` when no tags are recorded.
    pub tags: Option<TagSet>,
    /// Optional object-lock mode.
    pub object_lock_mode: Option<LockMode>,
    /// Optional object-lock retain-until timestamp (UTC).
    pub object_lock_retain_until: Option<DateTime<Utc>>,
    /// Whether a legal hold is recorded for the upload.
    pub object_lock_legal_hold: bool,
}
/// In-memory map from multipart upload id to its [`MultipartUploadContext`],
/// shareable across threads (the map sits behind an `RwLock`).
///
/// Entries live only in this process's memory; nothing in this type persists
/// them.
pub struct MultipartStateStore {
    /// Upload id -> context. An `RwLock` lets concurrent lookups share the lock
    /// while inserts and removals take it exclusively.
    by_upload_id: RwLock<HashMap<String, MultipartUploadContext>>,
}
impl MultipartStateStore {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        let by_upload_id = RwLock::new(HashMap::new());
        Self { by_upload_id }
    }

    /// Runs `f` against the map while holding the shared (read) lock.
    ///
    /// # Panics
    ///
    /// Panics if the lock has been poisoned by a panic in another holder.
    fn with_read<R>(&self, f: impl FnOnce(&HashMap<String, MultipartUploadContext>) -> R) -> R {
        let guard = self
            .by_upload_id
            .read()
            .expect("multipart-state by_upload_id RwLock poisoned");
        f(&guard)
    }

    /// Runs `f` against the map while holding the exclusive (write) lock.
    ///
    /// # Panics
    ///
    /// Panics if the lock has been poisoned by a panic in another holder.
    fn with_write<R>(
        &self,
        f: impl FnOnce(&mut HashMap<String, MultipartUploadContext>) -> R,
    ) -> R {
        let mut guard = self
            .by_upload_id
            .write()
            .expect("multipart-state by_upload_id RwLock poisoned");
        f(&mut guard)
    }

    /// Stores `ctx` under `upload_id`, replacing any existing entry.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn put(&self, upload_id: &str, ctx: MultipartUploadContext) {
        self.with_write(|map| {
            // Any replaced context is dropped here, while the lock is still held.
            map.insert(upload_id.to_owned(), ctx);
        });
    }

    /// Returns a clone of the context stored under `upload_id`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    #[must_use]
    pub fn get(&self, upload_id: &str) -> Option<MultipartUploadContext> {
        self.with_read(|map| map.get(upload_id).cloned())
    }

    /// Deletes the entry for `upload_id`; does nothing when it is absent.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn remove(&self, upload_id: &str) {
        self.with_write(|map| {
            map.remove(upload_id);
        });
    }

    /// Number of stored uploads (test-only helper).
    #[cfg(test)]
    fn len(&self) -> usize {
        self.with_read(|map| map.len())
    }
}
impl Default for MultipartStateStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Builds a context for `bucket`/`key` with encryption, tags and object
    /// lock all left unset.
    fn sample_ctx(bucket: &str, key: &str) -> MultipartUploadContext {
        let (bucket, key) = (bucket.to_owned(), key.to_owned());
        MultipartUploadContext {
            bucket,
            key,
            sse: MultipartSseMode::None,
            tags: None,
            object_lock_mode: None,
            object_lock_retain_until: None,
            object_lock_legal_hold: false,
        }
    }

    #[test]
    fn put_get_remove_round_trip() {
        const UPLOAD_ID: &str = "upload-001";
        let store = MultipartStateStore::new();
        store.put(UPLOAD_ID, sample_ctx("b", "k"));

        let fetched = store.get(UPLOAD_ID).expect("entry must be present");
        assert_eq!(fetched.bucket, "b");
        assert_eq!(fetched.key, "k");
        assert_eq!(fetched.sse, MultipartSseMode::None);

        store.remove(UPLOAD_ID);
        assert!(store.get(UPLOAD_ID).is_none(), "entry must be gone");
    }

    #[test]
    fn sse_c_key_bytes_round_trip() {
        let store = MultipartStateStore::new();
        let key = [0xa5u8; 32];
        let key_md5 = [0xb6u8; 16];
        let ctx = MultipartUploadContext {
            sse: MultipartSseMode::SseC { key, key_md5 },
            ..sample_ctx("b", "k")
        };
        store.put("u-sse-c", ctx);

        match store.get("u-sse-c").expect("entry must be present").sse {
            MultipartSseMode::SseC {
                key: stored_key,
                key_md5: stored_md5,
            } => {
                assert_eq!(stored_key, key, "SSE-C key bytes must round-trip");
                assert_eq!(stored_md5, key_md5, "SSE-C MD5 must round-trip");
            }
            other => panic!("expected SseC variant, got {other:?}"),
        }
    }

    #[test]
    fn concurrent_put_lookup_race_free() {
        const THREADS: usize = 8;
        const PER_THREAD: usize = 250;
        let store = Arc::new(MultipartStateStore::new());

        // Spawn every worker before joining any, so they actually overlap.
        let workers: Vec<_> = (0..THREADS)
            .map(|tid| {
                let shared = Arc::clone(&store);
                thread::spawn(move || {
                    for i in 0..PER_THREAD {
                        let id = format!("u-{tid}-{i}");
                        shared.put(&id, sample_ctx("b", &id));
                        let fetched = shared.get(&id).expect("self-put must be visible");
                        assert_eq!(fetched.key, id);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().expect("worker thread panicked");
        }
        assert_eq!(store.len(), THREADS * PER_THREAD, "all puts must persist");
    }
}