use std::sync::Arc;
use arc_swap::ArcSwap;
use redis::Commands;
use crate::{
config::StatusConfig,
metrics::Metrics,
request_id::RequestId,
status::{StatusStore, StatusUpdate, SubmissionStatus, SubmissionStatusRecord},
};
/// Namespace prefix for every Redis key written by this store; the request id is
/// appended to it to form the full key, and `record_count` scans for `{KEY_PREFIX}*`.
const KEY_PREFIX: &str = "rele:s:";
/// [`StatusStore`] implementation that keeps each submission status record as a JSON
/// string in Redis, written with an expiry taken from the current [`StatusConfig`].
pub struct RedisStatusStore {
/// Redis client handle; a fresh connection is requested from it for each operation.
client: redis::Client,
/// Current configuration, replaceable at runtime through `reload_config`.
config: ArcSwap<StatusConfig>,
/// Metrics sink notified about record creation, status transitions, expiry and counts.
metrics: Arc<Metrics>,
}
impl RedisStatusStore {
    /// Creates a store backed by the Redis server at `redis_url`.
    ///
    /// A connection is opened and a `PING` is issued up front, so an unreachable or
    /// misconfigured server is reported here rather than on first use. The probe
    /// connection is dropped afterwards; later operations open their own.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the client cannot be created from the
    /// URL, the connection cannot be established, or the `PING` fails. Where the URL
    /// is included in a message or log line, its credentials are masked.
    pub fn open(
        redis_url: &str,
        config: &StatusConfig,
        metrics: Arc<Metrics>,
    ) -> Result<Arc<Self>, String> {
        // Redis URLs may embed a password; never echo the raw value into logs/errors.
        let safe_url = Self::redact_url(redis_url);

        let client = redis::Client::open(redis_url)
            .map_err(|e| format!("failed to open Redis client for {safe_url}: {e}"))?;
        let mut conn = client
            .get_connection()
            .map_err(|e| format!("failed to connect to Redis at {safe_url}: {e}"))?;
        let _: String = redis::cmd("PING")
            .query(&mut conn)
            .map_err(|e| format!("Redis PING failed: {e}"))?;

        tracing::info!(url = safe_url.as_str(), "Redis status store connected");
        Ok(Arc::new(Self {
            client,
            config: ArcSwap::from_pointee(config.clone()),
            metrics,
        }))
    }

    /// Returns `url` with anything that may carry credentials masked.
    ///
    /// Everything between the scheme separator and the last `@` (the user-info part)
    /// is replaced by `***`, and a query string, if any, is replaced by `?***`
    /// because socket-style URLs can pass the password as a query parameter.
    /// The masking is deliberately conservative: it may hide more than strictly
    /// necessary, which is harmless for diagnostics.
    fn redact_url(url: &str) -> String {
        let (scheme, rest) = match url.split_once("://") {
            Some((scheme, rest)) => (Some(scheme), rest),
            None => (None, url),
        };
        // Split on the LAST '@' so a password that itself contains '@' is fully hidden.
        let (had_userinfo, rest) = match rest.rsplit_once('@') {
            Some((_, after)) => (true, after),
            None => (false, rest),
        };
        let (rest, had_query) = match rest.split_once('?') {
            Some((before, _)) => (before, true),
            None => (rest, false),
        };

        let mut out = String::with_capacity(url.len());
        if let Some(scheme) = scheme {
            out.push_str(scheme);
            out.push_str("://");
        }
        if had_userinfo {
            out.push_str("***@");
        }
        out.push_str(rest);
        if had_query {
            out.push_str("?***");
        }
        out
    }

    /// Opens a new connection for a single operation.
    ///
    /// Returns `None` (after logging a warning) when the connection cannot be
    /// established, letting callers skip the operation instead of failing.
    // NOTE(review): a connection is opened per call; if this sits on a hot path a
    // pooled or shared connection may be worth considering.
    fn get_conn(&self) -> Option<redis::Connection> {
        self.client
            .get_connection()
            .map_err(|e| {
                tracing::warn!(error = %e, "Redis connection failed (degraded mode)");
            })
            .ok()
    }

    /// Builds the Redis key for `request_id`: [`KEY_PREFIX`] followed by the id string.
    fn key(request_id: &RequestId) -> String {
        format!("{KEY_PREFIX}{}", request_id.as_str())
    }
}
impl StatusStore for RedisStatusStore {
fn put(&self, record: SubmissionStatusRecord) {
let cfg = self.config.load();
let k = Self::key(&record.request_id);
let ttl: u64 = cfg.ttl_seconds;
let json = match serde_json::to_string(&record) {
Ok(j) => j,
Err(e) => {
tracing::warn!(error = %e, "Redis put: serialisation failed");
return;
}
};
let Some(mut conn) = self.get_conn() else { return };
if let Err(e) = conn.set_ex::<_, _, ()>(&k, &json, ttl) {
tracing::warn!(error = %e, key = %k, "Redis SET EX failed");
} else {
self.metrics.status_record_created();
}
}
fn update_status(&self, request_id: &RequestId, key_id: &str, update: StatusUpdate) {
let k = Self::key(request_id);
let Some(mut conn) = self.get_conn() else { return };
let raw: Option<String> = match conn.get(&k) {
Ok(v) => v,
Err(e) => {
tracing::warn!(error = %e, "Redis GET failed in update_status");
return;
}
};
let Some(raw) = raw else { return };
let mut record: SubmissionStatusRecord = match serde_json::from_str(&raw) {
Ok(r) => r,
Err(e) => {
tracing::warn!(error = %e, "Redis update_status: deserialisation failed");
return;
}
};
if record.key_id != key_id || record.is_expired() || record.status.is_terminal() {
return;
}
let s = update.status;
let c = update.code.as_ref()
.and_then(|c| serde_json::to_value(c).ok())
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "none".into());
record.status = s.clone();
record.code = update.code;
if update.message.is_some() { record.message = update.message; }
record.updated_at = chrono::Utc::now();
let cfg = self.config.load();
let ttl: u64 = cfg.ttl_seconds;
match serde_json::to_string(&record) {
Ok(json) => {
if let Err(e) = conn.set_ex::<_, _, ()>(&k, &json, ttl) {
tracing::warn!(error = %e, "Redis SET EX failed in update_status");
} else {
let status_s = serde_json::to_value(&s)
.ok()
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "unknown".into());
self.metrics.status_transitioned(&status_s, &c);
}
}
Err(e) => {
tracing::warn!(error = %e, "Redis update_status: serialisation failed");
}
}
}
fn get(&self, request_id: &RequestId, key_id: &str) -> Option<SubmissionStatusRecord> {
let k = Self::key(request_id);
let mut conn = self.get_conn()?;
let raw: Option<String> = match conn.get(&k) {
Ok(v) => v,
Err(e) => {
tracing::warn!(error = %e, "Redis GET failed");
return None;
}
};
let raw = raw?;
let record: SubmissionStatusRecord = match serde_json::from_str(&raw) {
Ok(r) => r,
Err(e) => {
tracing::warn!(error = %e, "Redis get: deserialisation failed");
return None;
}
};
if record.key_id != key_id { return None; }
if record.is_expired() {
let _: () = conn.del(&k).unwrap_or(());
self.metrics.status_record_expired_one();
return None;
}
Some(record)
}
fn expire_old_records(&self) {
}
fn record_count(&self) -> usize {
let Some(mut conn) = self.get_conn() else { return 0 };
let pattern = format!("{KEY_PREFIX}*");
let mut cursor: u64 = 0;
let mut total = 0usize;
loop {
let result: redis::RedisResult<(u64, Vec<String>)> = redis::cmd("SCAN")
.arg(cursor).arg("MATCH").arg(&pattern).arg("COUNT").arg(100)
.query(&mut conn);
match result {
Ok((new_cursor, keys)) => {
total += keys.len();
cursor = new_cursor;
if cursor == 0 { break; }
}
Err(e) => {
tracing::warn!(error = %e, "Redis SCAN failed in record_count");
break;
}
}
}
self.metrics.status_set_current(total);
total
}
fn reload_config(&self, config: &StatusConfig) {
self.config.store(Arc::new(config.clone()));
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::status::recipient_domains_from;

    /// Redis endpoint for the integration tests, read from `REDIS_TEST_URL`.
    /// Tests that need a server return early when it is not set.
    fn redis_url() -> Option<String> {
        std::env::var("REDIS_TEST_URL").ok()
    }

    /// Configuration used by every Redis-backed test.
    fn test_cfg() -> StatusConfig {
        StatusConfig {
            enabled: true,
            store: "redis".into(),
            ttl_seconds: 60,
            max_records: 1000,
            cleanup_interval_seconds: 60,
            db_path: None,
            redis_url: None,
        }
    }

    /// Builds a record for `id` owned by `key` with the given record TTL.
    fn make_record(id: &RequestId, key: &str, ttl: u64) -> SubmissionStatusRecord {
        let recipients = ["user@example.com".to_string()];
        let domains = recipient_domains_from(&recipients, &[]);
        SubmissionStatusRecord::new(id.clone(), key.into(), domains, 1, ttl)
    }

    /// Opens a store at `url` with the test configuration and fresh metrics.
    fn open_store(url: &str) -> Result<Arc<RedisStatusStore>, String> {
        RedisStatusStore::open(url, &test_cfg(), Arc::new(Metrics::new()))
    }

    #[test]
    fn record_serialises_and_deserialises() {
        let id = RequestId::new();
        let original = make_record(&id, "k", 3600);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: SubmissionStatusRecord = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.request_id.as_str(), id.as_str());
        assert_eq!(decoded.key_id, "k");
    }

    #[test]
    fn key_prefix_correct() {
        let key = RedisStatusStore::key(&RequestId::new());
        assert!(key.starts_with("rele:s:req_"), "key: {key}");
    }

    #[test]
    fn redis_put_and_get() {
        let url = match redis_url() {
            Some(u) => u,
            None => return,
        };
        let store = open_store(&url).expect("Redis connect");
        let id = RequestId::new();
        store.put(make_record(&id, "key-a", 60));
        let fetched = store.get(&id, "key-a");
        assert!(fetched.is_some(), "record must be retrievable");
    }

    #[test]
    fn redis_wrong_key_returns_none() {
        let url = match redis_url() {
            Some(u) => u,
            None => return,
        };
        let store = open_store(&url).unwrap();
        let id = RequestId::new();
        store.put(make_record(&id, "key-a", 60));
        assert!(store.get(&id, "key-b").is_none());
    }

    #[test]
    fn redis_update_transitions_status() {
        let url = match redis_url() {
            Some(u) => u,
            None => return,
        };
        let store = open_store(&url).unwrap();
        let id = RequestId::new();
        store.put(make_record(&id, "key-a", 60));

        let accepted = StatusUpdate {
            status: SubmissionStatus::SmtpAccepted,
            code: None,
            message: Some("ok".into()),
        };
        store.update_status(&id, "key-a", accepted);

        let fetched = store.get(&id, "key-a").unwrap();
        assert_eq!(fetched.status, SubmissionStatus::SmtpAccepted);
    }

    #[test]
    fn redis_terminal_not_overwritten() {
        let url = match redis_url() {
            Some(u) => u,
            None => return,
        };
        let store = open_store(&url).unwrap();
        let id = RequestId::new();
        store.put(make_record(&id, "k", 60));

        let accepted = StatusUpdate {
            status: SubmissionStatus::SmtpAccepted,
            code: None,
            message: None,
        };
        store.update_status(&id, "k", accepted);

        // A later failure report must not replace the status already recorded.
        let failed = StatusUpdate {
            status: SubmissionStatus::SmtpFailed,
            code: Some(crate::status::ErrorCode::SmtpUnavailable),
            message: None,
        };
        store.update_status(&id, "k", failed);

        let fetched = store.get(&id, "k").unwrap();
        assert_eq!(
            fetched.status,
            SubmissionStatus::SmtpAccepted,
            "terminal must not change"
        );
    }

    #[test]
    fn redis_expired_record_returns_none() {
        let url = match redis_url() {
            Some(u) => u,
            None => return,
        };
        let store = open_store(&url).unwrap();
        let id = RequestId::new();
        // A record TTL of zero is used here to obtain an already-expired record.
        store.put(make_record(&id, "k", 0));
        assert!(store.get(&id, "k").is_none(), "expired record must return None");
    }
}