use actr_framework::Bytes;
use actr_protocol::ActorResult;
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// How long a request id stays in the dedup table, measured from when it was
/// first marked (15 seconds).
pub(crate) const DEDUP_TTL: Duration = Duration::from_millis(15_000);
/// Processing state remembered for a single request id.
#[derive(Debug, Clone)]
enum CachedResult {
    /// The id has been marked by `check_or_mark` but no result is stored yet.
    InFlight,
    /// The result recorded by `complete`, replayed to later duplicates.
    Done(ActorResult<Bytes>),
}
/// One tracked request id in the dedup table.
#[derive(Debug)]
struct Entry {
    // When the id was first marked; eviction age is measured from this instant.
    received_at: Instant,
    // Whether the request is still in flight or has a stored result.
    result: CachedResult,
}
#[derive(Debug, Default)]
pub(crate) struct DedupState {
entries: HashMap<String, Entry>,
ttl: Duration,
}
/// Result of looking a request id up with `DedupState::check_or_mark`.
#[derive(Debug)]
pub(crate) enum DedupOutcome {
    /// The id was already completed; carries a clone of the stored result.
    Duplicate(ActorResult<Bytes>),
    /// The id is tracked, but no result has been stored for it yet.
    InFlight,
    /// The id was not tracked. It has now been recorded as in flight.
    Fresh,
}
impl DedupState {
pub(crate) fn new() -> Self {
Self {
entries: HashMap::new(),
ttl: DEDUP_TTL,
}
}
pub(crate) fn check_or_mark(&mut self, request_id: &str) -> DedupOutcome {
let now = Instant::now();
self.evict_expired(now);
match self.entries.get(request_id) {
None => {
self.entries.insert(
request_id.to_string(),
Entry {
received_at: now,
result: CachedResult::InFlight,
},
);
DedupOutcome::Fresh
}
Some(entry) => match &entry.result {
CachedResult::InFlight => DedupOutcome::InFlight,
CachedResult::Done(r) => DedupOutcome::Duplicate(r.clone()),
},
}
}
pub(crate) fn complete(&mut self, request_id: &str, result: ActorResult<Bytes>) {
if let Some(entry) = self.entries.get_mut(request_id) {
entry.result = CachedResult::Done(result);
}
}
fn evict_expired(&mut self, now: Instant) {
self.entries
.retain(|_, e| now.duration_since(e.received_at) < self.ttl);
}
#[cfg(test)]
pub(crate) fn len(&self) -> usize {
self.entries.len()
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `DedupState`.
    use super::*;

    fn ok_bytes(s: &str) -> ActorResult<Bytes> {
        Ok(Bytes::from(s.as_bytes().to_vec()))
    }

    #[test]
    fn fresh_request_is_marked_in_flight() {
        let mut d = DedupState::new();
        assert!(matches!(d.check_or_mark("req-1"), DedupOutcome::Fresh));
        assert_eq!(d.len(), 1);
    }

    #[test]
    fn concurrent_duplicate_returns_in_flight() {
        let mut d = DedupState::new();
        d.check_or_mark("req-1");
        assert!(matches!(d.check_or_mark("req-1"), DedupOutcome::InFlight));
    }

    #[test]
    fn completed_duplicate_returns_cached_response() {
        let mut d = DedupState::new();
        d.check_or_mark("req-1");
        d.complete("req-1", ok_bytes("hello"));
        let outcome = d.check_or_mark("req-1");
        assert!(
            matches!(outcome, DedupOutcome::Duplicate(Ok(ref b)) if b == "hello"),
            "expected cached Ok(\"hello\")"
        );
    }

    #[test]
    fn error_response_is_cached_and_returned() {
        use actr_protocol::ActrError;
        let mut d = DedupState::new();
        d.check_or_mark("req-err");
        d.complete(
            "req-err",
            Err(ActrError::InvalidArgument("bad input".to_string())),
        );
        let outcome = d.check_or_mark("req-err");
        assert!(
            matches!(
                outcome,
                DedupOutcome::Duplicate(Err(ActrError::InvalidArgument(_)))
            ),
            "expected cached Err"
        );
    }

    #[test]
    fn expired_entry_is_evicted_and_treated_as_fresh() {
        // A zero TTL makes every existing entry expire on the next call,
        // because eviction keeps only entries with age strictly below the TTL.
        // A tiny non-zero TTL (e.g. 1ns) is flaky: on clocks with coarse
        // resolution two consecutive `Instant::now()` calls can be equal,
        // giving age 0, which would keep the entry alive.
        let mut d = DedupState {
            ttl: Duration::ZERO,
            ..DedupState::new()
        };
        d.check_or_mark("req-old");
        d.complete("req-old", ok_bytes("v1"));
        d.check_or_mark("req-new");
        assert!(matches!(d.check_or_mark("req-old"), DedupOutcome::Fresh));
    }

    #[test]
    fn different_request_ids_are_independent() {
        let mut d = DedupState::new();
        d.check_or_mark("req-a");
        d.complete("req-a", ok_bytes("a"));
        assert!(matches!(d.check_or_mark("req-b"), DedupOutcome::Fresh));
        assert!(matches!(
            d.check_or_mark("req-a"),
            DedupOutcome::Duplicate(_)
        ));
    }

    #[test]
    fn completing_unknown_request_is_a_noop() {
        let mut d = DedupState::new();
        d.complete("never-seen", ok_bytes("x"));
        assert_eq!(d.len(), 0);
        assert!(matches!(d.check_or_mark("never-seen"), DedupOutcome::Fresh));
    }
}