use indexmap::{IndexMap, map::Entry};
use ruma::time::{Duration, Instant};
use tracing::warn;
use uuid::Uuid;
/// Bounds applied to the collection of requests that are awaiting a response.
#[derive(Debug, Clone)]
pub(crate) struct RequestLimits {
    /// Maximum number of requests that may be pending simultaneously; further
    /// insertions are refused once this many are stored.
    pub(crate) max_pending_requests: usize,
    /// How long after insertion a request is still accepted; past this
    /// deadline it is treated as expired.
    pub(crate) response_timeout: Duration,
}
/// Requests that are waiting for a response, keyed by request id.
///
/// Each stored value carries a deadline derived from
/// [`RequestLimits::response_timeout`].
pub(super) struct PendingRequests<T> {
    /// Stored requests together with their expiry deadline.
    requests: IndexMap<Uuid, Expirable<T>>,
    /// Capacity and timeout configuration for this collection.
    limits: RequestLimits,
}
impl<T> PendingRequests<T> {
    /// Creates an empty collection governed by `limits`.
    ///
    /// Storage for `limits.max_pending_requests` entries is allocated upfront.
    pub(super) fn new(limits: RequestLimits) -> Self {
        Self { requests: IndexMap::with_capacity(limits.max_pending_requests), limits }
    }

    /// Stores `value` under `key` with a deadline of `now + response_timeout`.
    ///
    /// If the collection is already at `max_pending_requests`, entries whose
    /// deadline has passed are purged first (see [`Self::remove_expired`]).
    /// This keeps stale requests from occupying slots that could serve new ones.
    ///
    /// Returns a mutable reference to the stored value, or `None` if the limit
    /// is still reached after purging.
    ///
    /// # Panics
    ///
    /// Panics if `key` is already present ("uuid collision"). Adding
    /// `response_timeout` to the current instant may also panic if the result
    /// is not representable (e.g. an extremely large timeout).
    pub(super) fn insert(&mut self, key: Uuid, value: T) -> Option<&mut T> {
        if self.requests.len() >= self.limits.max_pending_requests {
            // Expired entries keep counting towards the limit until they are
            // removed; reclaim their slots before turning the request away.
            self.remove_expired();
            if self.requests.len() >= self.limits.max_pending_requests {
                return None;
            }
        }

        let Entry::Vacant(entry) = self.requests.entry(key) else {
            panic!("uuid collision");
        };

        let expirable = Expirable::new(value, Instant::now() + self.limits.response_timeout);
        let inserted = entry.insert(expirable);
        Some(&mut inserted.value)
    }

    /// Removes the request stored under `key` and returns its value.
    ///
    /// # Errors
    ///
    /// - `"Received response for an unknown request"` if `key` is not stored.
    /// - `"Dropping response for an expired request"` if the request's deadline
    ///   has passed. The entry is removed in this case as well.
    pub(super) fn extract(&mut self, key: &Uuid) -> Result<T, &'static str> {
        let value =
            self.requests.swap_remove(key).ok_or("Received response for an unknown request")?;
        value.value().ok_or("Dropping response for an expired request")
    }

    /// Drops every request whose deadline has passed, logging a warning with
    /// the id of each dropped request.
    pub(super) fn remove_expired(&mut self) {
        self.requests.retain(|id, req| {
            let expired = req.expired();
            if expired {
                warn!(?id, "Dropping response for an expired request");
            }
            !expired
        });
    }
}
/// A value paired with the instant at which it stops being valid.
struct Expirable<T> {
    /// The wrapped value.
    value: T,
    /// Deadline: once the current instant reaches it, the value is expired.
    expires_at: Instant,
}
impl<T> Expirable<T> {
fn new(value: T, expires_at: Instant) -> Self {
Self { value, expires_at }
}
fn value(self) -> Option<T> {
(!self.expired()).then_some(self.value)
}
fn expired(&self) -> bool {
Instant::now() >= self.expires_at
}
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use uuid::Uuid;

    use super::{PendingRequests, RequestLimits};

    struct Dummy;

    #[test]
    fn insertion_limits_for_pending_requests_work() {
        let mut pending: PendingRequests<Dummy> = PendingRequests::new(RequestLimits {
            max_pending_requests: 1,
            response_timeout: Duration::from_secs(10),
        });

        let first = Uuid::new_v4();
        assert!(pending.insert(first, Dummy).is_some());
        assert!(!pending.requests.is_empty());

        // The limit is reached, so the second request is refused and never stored.
        let second = Uuid::new_v4();
        assert!(pending.insert(second, Dummy).is_none());

        assert!(pending.extract(&first).is_ok());
        assert_eq!(
            pending.extract(&second).err(),
            Some("Received response for an unknown request")
        );

        // Extracting `first` freed its slot.
        assert!(pending.insert(second, Dummy).is_some());
        assert!(pending.extract(&second).is_ok());
        assert!(pending.requests.is_empty());
    }

    #[test]
    fn time_limits_for_pending_requests_work() {
        let mut pending: PendingRequests<Dummy> = PendingRequests::new(RequestLimits {
            max_pending_requests: 10,
            response_timeout: Duration::from_secs(1),
        });

        let key = Uuid::new_v4();
        assert!(pending.insert(key, Dummy).is_some());
        std::thread::sleep(Duration::from_secs(2));

        // An expired request is reported as expired (not unknown) and is removed.
        assert_eq!(
            pending.extract(&key).err(),
            Some("Dropping response for an expired request")
        );
        assert!(pending.requests.is_empty());

        assert!(pending.insert(Uuid::new_v4(), Dummy).is_some());
        assert!(pending.insert(Uuid::new_v4(), Dummy).is_some());
        std::thread::sleep(Duration::from_millis(500));
        pending.remove_expired();

        let key = Uuid::new_v4();
        assert!(pending.insert(key, Dummy).is_some());
        assert_eq!(pending.requests.len(), 3);

        // The two older requests pass their deadline; `key` does not yet.
        std::thread::sleep(Duration::from_millis(500));
        pending.remove_expired();
        assert_eq!(pending.requests.len(), 1);

        assert!(pending.extract(&key).is_ok());
        assert!(pending.requests.is_empty());
    }
}