use actr_framework::Bytes;
use actr_protocol::ActorResult;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::watch;
/// Default retention for completed results in [`DedupState`], sized to cover late
/// `RpcReliable` retries (30 seconds).
pub(crate) const DEDUP_TTL: Duration = Duration::from_millis(30_000);
pub(crate) type DedupWaiter = watch::Receiver<Option<ActorResult<Bytes>>>;
#[derive(Clone, Debug)]
enum CachedResult {
Done(ActorResult<Bytes>),
InFlight {
completion_tx: watch::Sender<Option<ActorResult<Bytes>>>,
},
}
/// One tracked request id in the dedup table.
#[derive(Debug)]
struct Entry {
    /// When the first copy of the request was seen; TTL eviction is measured from here.
    received_at: Instant,
    /// Either the in-flight completion channel or the cached final result.
    result: CachedResult,
}
#[derive(Debug, Default)]
pub(crate) struct DedupState {
entries: HashMap<String, Entry>,
ttl: Duration,
}
/// What [`DedupState::check_or_mark`] found for a request id.
#[derive(Debug)]
pub(crate) enum DedupOutcome {
    /// The id was unknown (or its cached entry expired) and is now recorded as in flight.
    Fresh,
    /// The original request is still running; this receiver is updated when it completes.
    InFlight(DedupWaiter),
    /// The original request already completed; this is a clone of its cached result.
    Duplicate(ActorResult<Bytes>),
}
impl DedupState {
    /// Creates an empty table whose completed entries are retained for [`DEDUP_TTL`].
    pub(crate) fn new() -> Self {
        Self {
            entries: HashMap::new(),
            ttl: DEDUP_TTL,
        }
    }

    /// Classifies `request_id` as fresh, in flight, or an already-completed duplicate.
    ///
    /// Expired completed entries are evicted first. An id whose cached result has aged
    /// past the TTL is therefore reported as [`DedupOutcome::Fresh`] again.
    ///
    /// * [`DedupOutcome::Fresh`]: the id was unknown and is now recorded as in flight.
    /// * [`DedupOutcome::InFlight`]: the id is still being processed. The returned
    ///   receiver is updated when [`DedupState::complete`] is called for it.
    /// * [`DedupOutcome::Duplicate`]: the id already completed. A clone of the cached
    ///   result is returned.
    ///
    /// In-flight entries are never evicted. An id that is marked `Fresh` but never passed
    /// to [`DedupState::complete`] stays in the table, and its waiters are never notified.
    pub(crate) fn check_or_mark(&mut self, request_id: &str) -> DedupOutcome {
        let now = Instant::now();
        self.evict_expired(now);

        // Look up by `&str` first so the duplicate path never allocates an owned key.
        if let Some(entry) = self.entries.get(request_id) {
            return match &entry.result {
                CachedResult::InFlight { completion_tx } => {
                    DedupOutcome::InFlight(completion_tx.subscribe())
                }
                CachedResult::Done(result) => DedupOutcome::Duplicate(result.clone()),
            };
        }

        // The initial receiver is not needed; duplicates obtain their own via `subscribe()`.
        let (completion_tx, _completion_rx) = watch::channel(None);
        self.entries.insert(
            request_id.to_string(),
            Entry {
                received_at: now,
                result: CachedResult::InFlight { completion_tx },
            },
        );
        DedupOutcome::Fresh
    }

    /// Records the final `result` for `request_id` and notifies any in-flight waiters.
    ///
    /// Unknown ids are ignored. Calling this for an id that is already completed replaces
    /// its cached result.
    pub(crate) fn complete(&mut self, request_id: &str, result: ActorResult<Bytes>) {
        if let Some(entry) = self.entries.get_mut(request_id) {
            if let CachedResult::InFlight { completion_tx } = &entry.result {
                // Clone only when someone is actually waiting. With no receivers, `send`
                // fails and discards the value anyway.
                if completion_tx.receiver_count() > 0 {
                    let _ = completion_tx.send(Some(result.clone()));
                }
            }
            // Replacing the variant drops the sender. Receivers still observe the value
            // sent above before they see the channel as closed.
            entry.result = CachedResult::Done(result);
        }
    }

    /// Drops completed entries older than `self.ttl`, measured from `received_at`.
    /// In-flight entries are always kept.
    ///
    /// NOTE(review): this walks the whole table on every `check_or_mark` call.
    fn evict_expired(&mut self, now: Instant) {
        let ttl = self.ttl;
        self.entries.retain(|_, e| match e.result {
            CachedResult::InFlight { .. } => true,
            CachedResult::Done(_) => now.duration_since(e.received_at) < ttl,
        });
    }

    /// Number of tracked entries (test-only introspection).
    #[cfg(test)]
    pub(crate) fn len(&self) -> usize {
        self.entries.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use actr_protocol::ActrError;

    fn ok_bytes(s: &str) -> ActorResult<Bytes> {
        Ok(Bytes::from(s.as_bytes().to_vec()))
    }

    /// A state whose completed entries are already expired on the next lookup.
    ///
    /// `Duration::ZERO` is used rather than a tiny non-zero TTL. With a non-zero TTL, two
    /// `Instant::now()` calls that return the same value would keep the entry alive and
    /// make the expiry tests flaky.
    fn expiring_state() -> DedupState {
        DedupState {
            ttl: Duration::ZERO,
            ..DedupState::new()
        }
    }

    #[test]
    fn fresh_request_is_marked_and_concurrent_duplicate_returns_waiter() {
        let mut d = DedupState::new();
        assert!(matches!(d.check_or_mark("req-1"), DedupOutcome::Fresh));
        assert_eq!(d.len(), 1);
        assert!(matches!(
            d.check_or_mark("req-1"),
            DedupOutcome::InFlight(_)
        ));
    }

    #[tokio::test]
    async fn in_flight_duplicate_waiter_receives_original_result() {
        let mut d = DedupState::new();
        assert!(matches!(d.check_or_mark("req-1"), DedupOutcome::Fresh));
        let mut waiter = match d.check_or_mark("req-1") {
            DedupOutcome::InFlight(waiter) => waiter,
            other => panic!("expected InFlight waiter, got {other:?}"),
        };
        d.complete("req-1", ok_bytes("hello"));
        waiter
            .changed()
            .await
            .expect("waiter should be notified of completion");
        let result = waiter
            .borrow()
            .clone()
            .expect("waiter should observe completed result");
        assert!(
            matches!(result, Ok(ref b) if b == "hello"),
            "expected waiter to receive original Ok(\"hello\")"
        );
    }

    #[tokio::test]
    async fn every_in_flight_waiter_receives_the_same_result() {
        let mut d = DedupState::new();
        assert!(matches!(d.check_or_mark("req-1"), DedupOutcome::Fresh));
        let mut waiters: Vec<DedupWaiter> = (0..3)
            .map(|_| match d.check_or_mark("req-1") {
                DedupOutcome::InFlight(waiter) => waiter,
                other => panic!("expected InFlight waiter, got {other:?}"),
            })
            .collect();
        d.complete("req-1", ok_bytes("shared"));
        for waiter in &mut waiters {
            waiter
                .changed()
                .await
                .expect("each waiter should be notified of completion");
            let result = waiter
                .borrow()
                .clone()
                .expect("each waiter should observe completed result");
            assert!(matches!(result, Ok(ref b) if b == "shared"));
        }
    }

    #[test]
    fn in_flight_entry_is_not_evicted_by_completed_cache_ttl() {
        let mut d = expiring_state();
        assert!(matches!(d.check_or_mark("req-slow"), DedupOutcome::Fresh));
        assert!(matches!(d.check_or_mark("req-other"), DedupOutcome::Fresh));
        d.complete("req-other", ok_bytes("done"));
        assert!(matches!(
            d.check_or_mark("req-slow"),
            DedupOutcome::InFlight(_)
        ));
        // The completed entry was evicted by the lookup above; the in-flight one survived.
        assert_eq!(d.len(), 1);
    }

    #[test]
    fn dedup_ttl_covers_reliable_rpc_retry_window() {
        assert!(
            DEDUP_TTL >= Duration::from_secs(20),
            "dedup TTL should cover late RpcReliable retries"
        );
    }

    #[test]
    fn completed_duplicate_returns_cached_success() {
        let mut d = DedupState::new();
        assert!(matches!(d.check_or_mark("req-1"), DedupOutcome::Fresh));
        d.complete("req-1", ok_bytes("hello"));
        let outcome = d.check_or_mark("req-1");
        assert!(
            matches!(outcome, DedupOutcome::Duplicate(Ok(ref b)) if b == "hello"),
            "expected cached Ok(\"hello\")"
        );
    }

    #[test]
    fn completed_duplicate_returns_cached_error() {
        let mut d = DedupState::new();
        assert!(matches!(d.check_or_mark("req-err"), DedupOutcome::Fresh));
        d.complete(
            "req-err",
            Err(ActrError::InvalidArgument("bad input".to_string())),
        );
        let outcome = d.check_or_mark("req-err");
        assert!(
            matches!(
                outcome,
                DedupOutcome::Duplicate(Err(ActrError::InvalidArgument(_)))
            ),
            "expected cached Err"
        );
    }

    #[test]
    fn expired_entry_is_evicted_and_treated_as_fresh() {
        let mut d = expiring_state();
        assert!(matches!(d.check_or_mark("req-old"), DedupOutcome::Fresh));
        d.complete("req-old", ok_bytes("v1"));
        assert!(matches!(d.check_or_mark("req-old"), DedupOutcome::Fresh));
    }

    #[test]
    fn different_request_ids_are_independent() {
        let mut d = DedupState::new();
        assert!(matches!(d.check_or_mark("req-a"), DedupOutcome::Fresh));
        d.complete("req-a", ok_bytes("a"));
        assert!(matches!(d.check_or_mark("req-b"), DedupOutcome::Fresh));
        assert!(matches!(
            d.check_or_mark("req-a"),
            DedupOutcome::Duplicate(_)
        ));
    }
}