use std::{
collections::HashMap,
sync::{
atomic::{AtomicU64, Ordering},
Arc, RwLock,
},
};
use arc_swap::ArcSwap;
use chrono::Utc;
use crate::{
config::StatusConfig,
request_id::RequestId,
status::{StatusStore, StatusUpdate, SubmissionStatusRecord},
};
/// Process-local [`StatusStore`] that keeps every record in a `HashMap` guarded by an `RwLock`.
///
/// Records are keyed by the string form of their request id. The configuration lives in an
/// `ArcSwap`, so it can be replaced without taking the map lock.
pub struct InMemoryStatusStore {
    /// Live records, keyed by `RequestId::as_str()`.
    records: RwLock<HashMap<String, SubmissionStatusRecord>>,
    /// Current configuration snapshot; swapped wholesale on reload.
    config: ArcSwap<StatusConfig>,
    /// Running count of records removed because they had expired
    /// (capacity-based evictions are not counted here).
    expired_total: AtomicU64,
}

impl InMemoryStatusStore {
    /// Creates an empty store holding its own copy of `config`.
    ///
    /// The store is returned inside an `Arc` so callers can share ownership of it.
    pub fn new(config: &StatusConfig) -> Arc<Self> {
        let initial_config = ArcSwap::from_pointee(config.clone());
        let store = Self {
            records: RwLock::new(HashMap::new()),
            config: initial_config,
            expired_total: AtomicU64::new(0),
        };
        Arc::new(store)
    }

    /// Number of records removed so far because they had expired.
    ///
    /// Records dropped to enforce `max_records` are not included. The read uses relaxed
    /// ordering, so the value is a point-in-time snapshot suitable for metrics.
    pub fn expired_total(&self) -> u64 {
        let counter = &self.expired_total;
        counter.load(Ordering::Relaxed)
    }
}
impl StatusStore for InMemoryStatusStore {
    /// Inserts `record` under its request id, replacing any record already stored for that id.
    ///
    /// If inserting a *new* key would reach `max_records`, the record with the oldest
    /// `created_at` is evicted first. Overwriting an existing key never evicts, because it
    /// does not grow the map.
    ///
    /// At most one record is evicted per call. If `max_records` was lowered via
    /// `reload_config`, the map can stay above the cap until `expire_old_records` trims it.
    fn put(&self, record: SubmissionStatusRecord) {
        let key = record.request_id.as_str().to_string();
        let mut map = self.records.write().unwrap();
        let cfg = self.config.load();
        // Only a brand-new key grows the map; evicting on an overwrite would discard an
        // unrelated live record for no reason.
        if !map.contains_key(&key) && map.len() >= cfg.max_records {
            let oldest = map
                .iter()
                .min_by_key(|(_, r)| r.created_at)
                .map(|(k, _)| k.clone());
            if let Some(k) = oldest {
                map.remove(&k);
            }
        }
        map.insert(key, record);
    }

    /// Applies `update` to the record for `request_id`.
    ///
    /// The update is silently ignored when any of these hold:
    /// - no record exists for the id;
    /// - the record belongs to a different `key_id`;
    /// - the record has expired;
    /// - the record is already in a terminal status.
    ///
    /// `message` is overwritten only when the update carries one. `updated_at` is set to the
    /// current time. Expired records are not purged here; `get` and `expire_old_records` do
    /// that.
    fn update_status(&self, request_id: &RequestId, key_id: &str, update: StatusUpdate) {
        let mut map = self.records.write().unwrap();
        if let Some(record) = map.get_mut(request_id.as_str()) {
            if record.key_id != key_id || record.is_expired() || record.status.is_terminal() {
                return;
            }
            record.status = update.status;
            record.code = update.code;
            if let Some(message) = update.message {
                record.message = Some(message);
            }
            record.updated_at = Utc::now();
        }
    }

    /// Returns a clone of the record for `request_id` if it exists, belongs to `key_id`, and
    /// has not expired.
    ///
    /// A lookup under the wrong `key_id` returns `None` and leaves the record untouched. An
    /// expired record is removed lazily and counted in `expired_total`.
    fn get(&self, request_id: &RequestId, key_id: &str) -> Option<SubmissionStatusRecord> {
        // Fast path: a shared read lock is enough for the common "present and live" case.
        {
            let map = self.records.read().unwrap();
            let record = map.get(request_id.as_str())?;
            if record.key_id != key_id {
                return None;
            }
            if !record.is_expired() {
                return Some(record.clone());
            }
        }
        // Slow path: the record looked expired. Re-check under the write lock, because
        // another thread may have removed or replaced it after the read lock was released.
        let mut map = self.records.write().unwrap();
        if let Some(record) = map.get(request_id.as_str()) {
            if record.is_expired() {
                map.remove(request_id.as_str());
                self.expired_total.fetch_add(1, Ordering::Relaxed);
            }
        }
        None
    }

    /// Purges expired records, then trims the store down to `max_records`.
    ///
    /// Every record whose `expires_at` is not strictly after "now" is dropped and counted in
    /// `expired_total`. If more than `max_records` remain, the oldest (by `created_at`) are
    /// evicted. Those capacity evictions are not counted as expirations.
    fn expire_old_records(&self) {
        let now = Utc::now();
        let mut map = self.records.write().unwrap();
        let cfg = self.config.load();

        let before = map.len();
        // NOTE(review): this compares `expires_at` directly instead of calling
        // `is_expired()` like `get`/`update_status` do; confirm both use the same boundary.
        map.retain(|_, r| r.expires_at > now);
        let removed = before - map.len();
        if removed > 0 {
            self.expired_total.fetch_add(removed as u64, Ordering::Relaxed);
        }

        if map.len() > cfg.max_records {
            let excess = map.len() - cfg.max_records;
            let mut keys_by_age: Vec<(String, chrono::DateTime<Utc>)> = map
                .iter()
                .map(|(k, r)| (k.clone(), r.created_at))
                .collect();
            // Stability is irrelevant: HashMap iteration order already makes ties arbitrary.
            keys_by_age.sort_unstable_by_key(|(_, t)| *t);
            for (k, _) in keys_by_age.into_iter().take(excess) {
                map.remove(&k);
            }
        }
    }

    /// Number of records currently held.
    ///
    /// Includes expired records that have not yet been purged.
    fn record_count(&self) -> usize {
        self.records.read().unwrap().len()
    }

    /// Replaces the active configuration with a copy of `config`.
    ///
    /// The new limits apply from the next `put` / `expire_old_records` call. Existing
    /// records are not re-evaluated immediately.
    fn reload_config(&self, config: &StatusConfig) {
        self.config.store(Arc::new(config.clone()));
    }
}
/// A [`StatusStore`] that discards everything it is given.
///
/// Writes are dropped, lookups always miss, and the record count is always zero.
pub struct NoopStatusStore;

impl StatusStore for NoopStatusStore {
    /// Discards the record.
    fn put(&self, _record: SubmissionStatusRecord) {}

    /// Ignores the update; there is nothing to update.
    fn update_status(&self, _request_id: &RequestId, _key_id: &str, _update: StatusUpdate) {}

    /// Always reports "not found".
    fn get(&self, _request_id: &RequestId, _key_id: &str) -> Option<SubmissionStatusRecord> {
        None
    }

    /// Nothing is stored, so nothing can expire.
    fn expire_old_records(&self) {}

    /// Always zero.
    fn record_count(&self) -> usize {
        0
    }

    /// Configuration has no effect on a no-op store.
    fn reload_config(&self, _config: &StatusConfig) {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::status::{ErrorCode, SubmissionStatus};

    /// Baseline config: enabled, memory store, `ttl_seconds = 3600`, room for 100 records.
    fn test_cfg() -> StatusConfig {
        StatusConfig {
            enabled: true,
            store: "memory".into(),
            ttl_seconds: 3600,
            max_records: 100,
            cleanup_interval_seconds: 60,
        }
    }

    /// Builds a record for `request_id` owned by `key_id` with the given `ttl`
    /// (a `ttl` of 0 is used below to produce an already-expired record).
    fn make_record(request_id: RequestId, key_id: &str, ttl: u64) -> SubmissionStatusRecord {
        use crate::status::recipient_domains_from;
        SubmissionStatusRecord::new(
            request_id,
            key_id.into(),
            recipient_domains_from(&["user@example.com".to_string()], &[]),
            1,
            ttl,
        )
    }

    #[test]
    fn put_and_get_returns_record() {
        let store = InMemoryStatusStore::new(&test_cfg());
        let id = RequestId::new();
        store.put(make_record(id.clone(), "key-a", 3600));
        assert!(store.get(&id, "key-a").is_some());
    }

    #[test]
    fn get_with_wrong_key_returns_none() {
        let store = InMemoryStatusStore::new(&test_cfg());
        let id = RequestId::new();
        store.put(make_record(id.clone(), "key-a", 3600));
        assert!(store.get(&id, "key-b").is_none());
        // A wrong-key lookup must not disturb the owner's record.
        assert!(store.get(&id, "key-a").is_some());
        assert_eq!(store.record_count(), 1);
    }

    #[test]
    fn expired_record_returns_none() {
        let store = InMemoryStatusStore::new(&test_cfg());
        let id = RequestId::new();
        store.put(make_record(id.clone(), "key-a", 0));
        assert!(store.get(&id, "key-a").is_none());
        // The read path lazily purges the expired record and counts it.
        assert_eq!(store.record_count(), 0);
        assert_eq!(store.expired_total(), 1);
    }

    #[test]
    fn update_status_transitions_correctly() {
        let store = InMemoryStatusStore::new(&test_cfg());
        let id = RequestId::new();
        store.put(make_record(id.clone(), "key-a", 3600));
        store.update_status(&id, "key-a", StatusUpdate {
            status: SubmissionStatus::SmtpAccepted,
            code: None,
            message: Some("accepted".into()),
        });
        let r = store.get(&id, "key-a").unwrap();
        assert_eq!(r.status, SubmissionStatus::SmtpAccepted);
    }

    #[test]
    fn update_status_with_wrong_key_is_ignored() {
        let store = InMemoryStatusStore::new(&test_cfg());
        let id = RequestId::new();
        store.put(make_record(id.clone(), "key-a", 3600));
        let before = store.get(&id, "key-a").unwrap().status;
        store.update_status(&id, "key-b", StatusUpdate {
            status: SubmissionStatus::SmtpAccepted,
            code: None,
            message: None,
        });
        let after = store.get(&id, "key-a").unwrap().status;
        assert_eq!(after, before, "a foreign key must not update the record");
    }

    #[test]
    fn terminal_status_is_not_updated() {
        let store = InMemoryStatusStore::new(&test_cfg());
        let id = RequestId::new();
        store.put(make_record(id.clone(), "key-a", 3600));
        store.update_status(&id, "key-a", StatusUpdate {
            status: SubmissionStatus::SmtpAccepted,
            code: None,
            message: None,
        });
        store.update_status(&id, "key-a", StatusUpdate {
            status: SubmissionStatus::SmtpFailed,
            code: Some(ErrorCode::SmtpUnavailable),
            message: None,
        });
        let r = store.get(&id, "key-a").unwrap();
        assert_eq!(r.status, SubmissionStatus::SmtpAccepted, "terminal status must not change");
    }

    #[test]
    fn expire_old_records_removes_expired() {
        let store = InMemoryStatusStore::new(&test_cfg());
        let id1 = RequestId::new();
        let id2 = RequestId::new();
        store.put(make_record(id1.clone(), "key-a", 0));
        store.put(make_record(id2.clone(), "key-a", 3600));
        store.expire_old_records();
        assert_eq!(store.record_count(), 1);
        assert_eq!(store.expired_total(), 1);
        assert!(store.get(&id2, "key-a").is_some());
    }

    #[test]
    fn max_records_evicts_oldest() {
        let mut cfg = test_cfg();
        cfg.max_records = 2;
        let store = InMemoryStatusStore::new(&cfg);
        let ids: Vec<RequestId> = (0..3).map(|_| RequestId::new()).collect();
        for id in &ids {
            store.put(make_record(id.clone(), "key-a", 3600));
            // Keep `created_at` strictly increasing so "oldest" is well-defined.
            std::thread::sleep(std::time::Duration::from_millis(1));
        }
        assert_eq!(store.record_count(), 2, "must be capped at max_records");
        assert!(store.get(&ids[0], "key-a").is_none(), "oldest record must be evicted");
        assert!(store.get(&ids[1], "key-a").is_some());
        assert!(store.get(&ids[2], "key-a").is_some());
        // Capacity eviction is not an expiration.
        assert_eq!(store.expired_total(), 0);
    }

    #[test]
    fn noop_store_always_returns_none() {
        let store = NoopStatusStore;
        let id = RequestId::new();
        store.put(make_record(id.clone(), "k", 3600));
        assert!(store.get(&id, "k").is_none());
        assert_eq!(store.record_count(), 0);
    }
}