use std::{
collections::HashMap,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use parking_lot::RwLock;
use arc_swap::ArcSwap;
use chrono::Utc;
use crate::{
config::StatusConfig,
metrics::Metrics,
request_id::RequestId,
status::{Domain, StatusStore, StatusStoreError, StatusUpdate, SubmissionStatusRecord},
};
/// Process-local [`StatusStore`] that keeps submission status records in a `HashMap`
/// keyed by the request-id string.
pub struct InMemoryStatusStore {
    /// Sink for store events (records created, transitioned, expired, current count).
    metrics: Arc<Metrics>,
    /// Current configuration; replaced wholesale by `reload_config`.
    config: ArcSwap<StatusConfig>,
    /// Status records keyed by `request_id.as_str()`, behind a reader/writer lock.
    records: RwLock<HashMap<String, SubmissionStatusRecord>>,
    /// Running count of records removed because they had expired.
    expired_total: AtomicU64,
}
impl InMemoryStatusStore {
    /// Creates an empty store that works from its own copy of `config` and reports
    /// events to `metrics`.
    pub fn new(config: &StatusConfig, metrics: Arc<Metrics>) -> Arc<Self> {
        let initial_config = ArcSwap::from_pointee(config.clone());
        let store = Self {
            records: RwLock::new(HashMap::new()),
            config: initial_config,
            expired_total: AtomicU64::new(0),
            metrics,
        };
        Arc::new(store)
    }

    /// Number of records this store has dropped because they had expired.
    ///
    /// Records evicted purely for capacity are not included. The value is read with
    /// `Relaxed` ordering, so it is a counter snapshot, not a synchronization point.
    pub fn expired_total(&self) -> u64 {
        self.expired_total.load(Ordering::Relaxed)
    }
}
impl StatusStore for InMemoryStatusStore {
fn put_received(&self, record: SubmissionStatusRecord) -> Result<(), StatusStoreError> {
let key = record.request_id.as_str().to_string();
let mut map = self.records.write();
let cfg = self.config.load();
if map.contains_key(&key) {
return Ok(());
}
if map.len() >= cfg.max_records {
let to_remove = map.iter().min_by_key(|(_, r)| r.created_at).map(|(k, _)| k.clone());
if let Some(k) = to_remove { map.remove(&k); }
}
map.insert(key, record);
self.metrics.status_record_created();
Ok(())
}
fn set_recipient_metadata(
&self,
request_id: &RequestId,
key_id: &str,
recipient_domains: Vec<Domain>,
recipient_count: u32,
) -> Result<(), StatusStoreError> {
let mut map = self.records.write();
if let Some(record) = map.get_mut(request_id.as_str()) {
if record.key_id == key_id && !record.is_expired() {
record.recipient_domains = recipient_domains;
record.recipient_count = recipient_count;
record.updated_at = Utc::now();
}
}
Ok(())
}
fn update_status(&self, request_id: &RequestId, key_id: &str, update: StatusUpdate) -> Result<(), StatusStoreError> {
let mut map = self.records.write();
if let Some(record) = map.get_mut(request_id.as_str()) {
if record.key_id != key_id || record.is_expired() || record.status.is_terminal() {
return Ok(());
}
let now = Utc::now();
record.status = update.status;
record.code = update.code;
if update.message.is_some() {
record.message = update.message;
}
record.updated_at = now;
let status_str = serde_json::to_value(&record.status)
.ok().and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "unknown".into());
let code_str = record.code.as_ref()
.and_then(|c| serde_json::to_value(c).ok())
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "none".into());
self.metrics.status_transitioned(&status_str, &code_str);
}
Ok(())
}
fn get(&self, request_id: &RequestId, key_id: &str) -> Result<Option<SubmissionStatusRecord>, StatusStoreError> {
{
let map = self.records.read();
let Some(record) = map.get(request_id.as_str()) else { return Ok(None); };
if record.key_id != key_id {
return Ok(None);
}
if !record.is_expired() {
return Ok(Some(record.clone()));
}
}
let mut map = self.records.write();
if let Some(record) = map.get(request_id.as_str()) {
if record.is_expired() {
map.remove(request_id.as_str());
self.expired_total.fetch_add(1, Ordering::Relaxed);
self.metrics.status_record_expired_one();
}
}
Ok(None)
}
fn expire_old_records(&self) {
let now = Utc::now();
let mut map = self.records.write();
let cfg = self.config.load();
let before = map.len();
map.retain(|_, r| r.expires_at > now);
let removed = before - map.len();
if removed > 0 {
self.expired_total.fetch_add(removed as u64, Ordering::Relaxed);
self.metrics.status_records_expired(removed);
}
if map.len() > cfg.max_records {
let excess = map.len() - cfg.max_records;
let mut keys_by_age: Vec<(String, chrono::DateTime<Utc>)> = map
.iter()
.map(|(k, r)| (k.clone(), r.created_at))
.collect();
keys_by_age.sort_by_key(|(_, t)| *t);
for (k, _) in keys_by_age.into_iter().take(excess) {
map.remove(&k);
}
}
}
fn record_count(&self) -> usize {
let count = self.records.read().len();
self.metrics.status_set_current(count);
count
}
fn reload_config(&self, config: &StatusConfig) {
self.config.store(Arc::new(config.clone()));
}
}
/// A [`StatusStore`] that keeps nothing: every write succeeds and is discarded, every
/// lookup misses, and the record count is always zero.
pub struct NoopStatusStore;

impl StatusStore for NoopStatusStore {
    fn put_received(&self, _record: SubmissionStatusRecord) -> Result<(), StatusStoreError> {
        Ok(())
    }

    fn set_recipient_metadata(
        &self,
        _request_id: &RequestId,
        _key_id: &str,
        _recipient_domains: Vec<Domain>,
        _recipient_count: u32,
    ) -> Result<(), StatusStoreError> {
        Ok(())
    }

    fn update_status(
        &self,
        _request_id: &RequestId,
        _key_id: &str,
        _update: StatusUpdate,
    ) -> Result<(), StatusStoreError> {
        Ok(())
    }

    fn get(
        &self,
        _request_id: &RequestId,
        _key_id: &str,
    ) -> Result<Option<SubmissionStatusRecord>, StatusStoreError> {
        Ok(None)
    }

    fn expire_old_records(&self) {}

    fn record_count(&self) -> usize {
        0
    }

    fn reload_config(&self, _config: &StatusConfig) {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::status::SubmissionStatus;

    fn test_cfg() -> StatusConfig {
        StatusConfig {
            enabled: true,
            store: "memory".into(),
            ttl_seconds: 3600,
            max_records: 100,
            cleanup_interval_seconds: 60,
            db_path: None,
            redis_url: None,
        }
    }

    /// Builds a fresh in-memory store with its own metrics registry.
    fn new_store(cfg: &StatusConfig) -> Arc<InMemoryStatusStore> {
        InMemoryStatusStore::new(cfg, Arc::new(Metrics::new()))
    }

    fn make_record(request_id: RequestId, key_id: &str, ttl: u64) -> SubmissionStatusRecord {
        SubmissionStatusRecord::new_received(request_id, key_id.into(), ttl)
    }

    #[test]
    fn put_and_get_returns_record() {
        let store = new_store(&test_cfg());
        let id = RequestId::new();
        store.put_received(make_record(id.clone(), "key-a", 3600)).unwrap();
        let record = store.get(&id, "key-a").unwrap().expect("record should be present");
        assert_eq!(record.key_id, "key-a");
    }

    #[test]
    fn get_with_wrong_key_returns_none() {
        let store = new_store(&test_cfg());
        let id = RequestId::new();
        store.put_received(make_record(id.clone(), "key-a", 3600)).unwrap();
        assert!(store.get(&id, "key-b").unwrap().is_none());
    }

    #[test]
    fn duplicate_put_keeps_first_record() {
        let store = new_store(&test_cfg());
        let id = RequestId::new();
        store.put_received(make_record(id.clone(), "key-a", 3600)).unwrap();
        store.put_received(make_record(id.clone(), "key-b", 3600)).unwrap();
        assert_eq!(store.record_count(), 1);
        assert!(store.get(&id, "key-a").unwrap().is_some());
        assert!(store.get(&id, "key-b").unwrap().is_none());
    }

    #[test]
    fn expired_record_returns_none() {
        let store = new_store(&test_cfg());
        let id = RequestId::new();
        store.put_received(make_record(id.clone(), "key-a", 0)).unwrap();
        assert!(store.get(&id, "key-a").unwrap().is_none());
        // The lookup lazily drops the expired record and counts it.
        assert_eq!(store.record_count(), 0);
        assert_eq!(store.expired_total(), 1);
    }

    #[test]
    fn update_status_transitions_correctly() {
        let store = new_store(&test_cfg());
        let id = RequestId::new();
        store.put_received(make_record(id.clone(), "key-a", 3600)).unwrap();
        store
            .update_status(&id, "key-a", StatusUpdate {
                status: SubmissionStatus::SmtpAccepted,
                code: None,
                message: Some("accepted".into()),
            })
            .unwrap();
        let r = store.get(&id, "key-a").unwrap().unwrap();
        assert_eq!(r.status, SubmissionStatus::SmtpAccepted);
        assert_eq!(r.message.as_deref(), Some("accepted"));
    }

    #[test]
    fn update_status_with_wrong_key_is_ignored() {
        let store = new_store(&test_cfg());
        let id = RequestId::new();
        store.put_received(make_record(id.clone(), "key-a", 3600)).unwrap();
        store
            .update_status(&id, "key-b", StatusUpdate {
                status: SubmissionStatus::SmtpAccepted,
                code: None,
                message: None,
            })
            .unwrap();
        let r = store.get(&id, "key-a").unwrap().unwrap();
        assert_ne!(r.status, SubmissionStatus::SmtpAccepted, "foreign key must not update");
    }

    #[test]
    fn terminal_status_is_not_updated() {
        let store = new_store(&test_cfg());
        let id = RequestId::new();
        store.put_received(make_record(id.clone(), "key-a", 3600)).unwrap();
        store
            .update_status(&id, "key-a", StatusUpdate {
                status: SubmissionStatus::SmtpAccepted,
                code: None,
                message: None,
            })
            .unwrap();
        store
            .update_status(&id, "key-a", StatusUpdate {
                status: SubmissionStatus::SmtpFailed,
                code: Some(crate::error::ErrorCode::SmtpUnavailable),
                message: None,
            })
            .unwrap();
        let r = store.get(&id, "key-a").unwrap().unwrap();
        assert_eq!(r.status, SubmissionStatus::SmtpAccepted, "terminal status must not change");
    }

    #[test]
    fn expire_old_records_removes_expired() {
        let store = new_store(&test_cfg());
        let id1 = RequestId::new();
        let id2 = RequestId::new();
        store.put_received(make_record(id1.clone(), "key-a", 0)).unwrap();
        store.put_received(make_record(id2.clone(), "key-a", 3600)).unwrap();
        store.expire_old_records();
        assert_eq!(store.record_count(), 1);
        assert_eq!(store.expired_total(), 1);
        assert!(store.get(&id2, "key-a").unwrap().is_some());
    }

    #[test]
    fn expire_old_records_trims_to_max_records_after_reload() {
        let store = new_store(&test_cfg());
        let ids: Vec<RequestId> = (0..3).map(|_| RequestId::new()).collect();
        for id in &ids {
            store.put_received(make_record(id.clone(), "key-a", 3600)).unwrap();
            std::thread::sleep(std::time::Duration::from_millis(1));
        }
        let mut smaller = test_cfg();
        smaller.max_records = 1;
        store.reload_config(&smaller);
        store.expire_old_records();
        assert_eq!(store.record_count(), 1);
        assert!(store.get(&ids[2], "key-a").unwrap().is_some(), "newest record must survive");
        // Capacity trimming is not an expiry.
        assert_eq!(store.expired_total(), 0);
    }

    #[test]
    fn max_records_evicts_oldest() {
        let mut cfg = test_cfg();
        cfg.max_records = 2;
        let store = new_store(&cfg);
        let ids: Vec<RequestId> = (0..3).map(|_| RequestId::new()).collect();
        for id in &ids {
            store.put_received(make_record(id.clone(), "key-a", 3600)).unwrap();
            // Keep `created_at` strictly increasing so "oldest" is unambiguous.
            std::thread::sleep(std::time::Duration::from_millis(1));
        }
        assert_eq!(store.record_count(), 2, "must be capped at max_records");
        assert!(store.get(&ids[0], "key-a").unwrap().is_none(), "oldest record must be evicted");
        assert!(store.get(&ids[1], "key-a").unwrap().is_some());
        assert!(store.get(&ids[2], "key-a").unwrap().is_some());
    }

    #[test]
    fn noop_store_always_returns_none() {
        let store = NoopStatusStore;
        let id = RequestId::new();
        store.put_received(make_record(id.clone(), "k", 3600)).unwrap();
        assert!(store.get(&id, "k").unwrap().is_none());
        assert_eq!(store.record_count(), 0);
    }
}