use crate::types::Response;
use crate::types::mrtr::state::now_secs;
use futures_util::future::BoxFuture;
use std::sync::Arc;
/// Storage abstraction for responses cached under a string tag.
///
/// Every method returns a boxed future so the trait stays object-safe, and
/// implementors must be `Send + Sync` so one store can be shared across tasks.
pub trait RequestStateStore: Send + Sync {
    /// Looks up the response stored under `tag`, yielding `None` when the
    /// store has nothing to return for it.
    fn get<'a>(&'a self, tag: &'a str) -> BoxFuture<'a, Option<Response>>;

    /// Stores `response` under `tag` together with its expiry value `exp`.
    ///
    /// The in-memory implementation in this file compares `exp` against
    /// `now_secs()`; other implementations are presumably expected to use the
    /// same clock — confirm before relying on it.
    fn put<'a>(&'a self, tag: &'a str, response: Response, exp: u64) -> BoxFuture<'a, ()>;

    /// Acquires a reservation for `tag` and returns an opaque token; the
    /// reservation is held for as long as the token is alive.
    ///
    /// The default implementation performs no locking at all: it resolves
    /// immediately to a boxed unit value, so dropping the token has no effect.
    fn reserve<'a>(&'a self, _tag: &'a str) -> BoxFuture<'a, Box<dyn Send>> {
        Box::pin(std::future::ready(Box::new(()) as Box<dyn Send>))
    }
}
/// In-process [`RequestStateStore`] backed by two concurrent hash maps.
///
/// Nothing is persisted: all state lives in the maps below and is lost when
/// the value is dropped.
#[derive(Debug, Default)]
pub struct InMemoryStateStore {
/// Cached responses keyed by tag, each paired with the `exp` value that was
/// passed to `put`; `get` treats an entry as live while `exp > now_secs()`.
entries: dashmap::DashMap<String, (Response, u64)>,
/// Per-tag async mutexes handed out by `reserve`; `put` sweeps the ones that
/// are no longer referenced outside this map.
locks: dashmap::DashMap<String, Arc<tokio::sync::Mutex<()>>>,
}
impl InMemoryStateStore {
    /// Creates an empty store with no cached responses and no reservation
    /// locks; equivalent to the derived `Default` implementation.
    pub fn new() -> Self {
        Default::default()
    }
}
impl RequestStateStore for InMemoryStateStore {
    /// Returns a clone of the response cached under `tag` while its expiry is
    /// still strictly greater than `now_secs()`.
    ///
    /// Yields `None` when there is no entry, or when the entry has expired; an
    /// expired entry is evicted as a side effect.
    fn get<'a>(&'a self, tag: &'a str) -> BoxFuture<'a, Option<Response>> {
        Box::pin(async move {
            let now = now_secs();
            // The map's read guard is a temporary of this `let` statement, so
            // it is released before the eviction below takes a write lock on
            // the same shard.
            let hit = match self.entries.get(tag) {
                Some(entry) if entry.1 > now => Some(entry.0.clone()),
                Some(_) => None,
                None => return None,
            };
            if hit.is_none() {
                // Re-check the expiry under the write lock: a concurrent `put`
                // may have replaced the stale entry with a fresh one after the
                // read guard was released, and an unconditional `remove` would
                // throw that fresh entry away.
                self.entries.remove_if(tag, |_, (_, exp)| *exp <= now);
            }
            hit
        })
    }

    /// Caches `response` under `tag` with expiry `exp`, replacing any previous
    /// entry for the same tag.
    ///
    /// Before inserting, sweeps both maps: every cached entry whose expiry is
    /// not greater than `now_secs()` is dropped, and so is every reservation
    /// mutex that nobody outside the map still references. The new entry is
    /// inserted after the sweep, so it is stored even when `exp` is already in
    /// the past (a later `get` or `put` evicts it).
    fn put<'a>(&'a self, tag: &'a str, response: Response, exp: u64) -> BoxFuture<'a, ()> {
        Box::pin(async move {
            let now = now_secs();
            self.entries.retain(|_, (_, e)| *e > now);
            // A strong count of 1 means the map holds the only `Arc`: no task
            // is holding or waiting on that mutex, so it can be discarded.
            self.locks.retain(|_, m| Arc::strong_count(m) > 1);
            self.entries.insert(tag.to_owned(), (response, exp));
        })
    }

    /// Waits for exclusive ownership of the per-tag mutex and returns the
    /// guard as an opaque token; dropping the token releases the reservation.
    ///
    /// The mutex for `tag` is created on first use.
    fn reserve<'a>(&'a self, tag: &'a str) -> BoxFuture<'a, Box<dyn Send>> {
        Box::pin(async move {
            // Clone the `Arc` inside this statement so the map's entry guard
            // (a temporary) is gone before the `.await` below.
            let mutex = self
                .locks
                .entry(tag.to_owned())
                .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
                .clone();
            // The owned guard keeps its own `Arc` to the mutex, which is what
            // keeps `put`'s sweep from discarding a lock that is in use.
            let guard = mutex.lock_owned().await;
            Box::new(guard) as Box<dyn Send>
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RequestId;

    /// Expiry value five minutes past the current `now_secs()` reading.
    fn live_exp() -> u64 {
        now_secs() + 300
    }

    /// Expiry value one second before the current `now_secs()` reading.
    fn stale_exp() -> u64 {
        now_secs().saturating_sub(1)
    }

    /// Builds a success response carrying the numeric request id `id`.
    fn resp(id: i64) -> Response {
        Response::success(RequestId::Number(id), serde_json::json!({ "ok": true }))
    }

    /// A response stored with a future expiry is handed back by `get`.
    #[tokio::test]
    async fn put_then_get_returns_the_cached_response() {
        let cache = InMemoryStateStore::new();
        assert!(cache.get("tag").await.is_none());

        cache.put("tag", resp(1), live_exp()).await;

        let cached = cache.get("tag").await.expect("cached response");
        assert_eq!(*cached.id(), RequestId::Number(1));
    }

    /// An entry whose expiry is already in the past is never returned.
    #[tokio::test]
    async fn expired_entries_are_not_returned() {
        let cache = InMemoryStateStore::new();
        cache.put("tag", resp(1), stale_exp()).await;

        assert!(cache.get("tag").await.is_none());
    }

    /// A later `put` sweeps out entries that expired in the meantime.
    #[tokio::test]
    async fn put_evicts_expired_entries() {
        let cache = InMemoryStateStore::new();
        cache.put("old", resp(1), stale_exp()).await;
        cache.put("new", resp(2), live_exp()).await;

        assert_eq!(cache.entries.len(), 1);
        assert!(cache.entries.contains_key("new"));
    }

    /// Two tasks racing through reserve -> get -> put for the same tag must
    /// run the guarded section only once between them.
    #[tokio::test]
    async fn reserve_serialises_concurrent_final_rounds() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        /// One simulated final round: take the reservation and, if nothing is
        /// cached yet, count the run and cache a response.
        async fn final_round(cache: Arc<InMemoryStateStore>, handler_runs: Arc<AtomicUsize>) {
            let _reservation = cache.reserve("tag").await;
            if cache.get("tag").await.is_some() {
                return;
            }
            handler_runs.fetch_add(1, Ordering::SeqCst);
            // Give the other task a chance to run while the reservation is held.
            tokio::task::yield_now().await;
            cache.put("tag", resp(1), live_exp()).await;
        }

        let cache = Arc::new(InMemoryStateStore::new());
        let handler_runs = Arc::new(AtomicUsize::new(0));

        let first = tokio::spawn(final_round(Arc::clone(&cache), Arc::clone(&handler_runs)));
        let second = tokio::spawn(final_round(Arc::clone(&cache), Arc::clone(&handler_runs)));
        first.await.expect("task a");
        second.await.expect("task b");

        assert_eq!(
            handler_runs.load(Ordering::SeqCst),
            1,
            "the final-round handler must run exactly once across identical retries"
        );
        assert!(cache.get("tag").await.is_some());
    }

    /// Once a reservation token is dropped, the next `put` discards its mutex.
    #[tokio::test]
    async fn put_sweeps_released_reservation_locks() {
        let cache = InMemoryStateStore::new();
        {
            let _reservation = cache.reserve("tag").await;
            assert_eq!(cache.locks.len(), 1);
        }

        cache.put("other", resp(1), live_exp()).await;

        assert!(cache.locks.is_empty());
    }
}