use crate::types::{RequestId, Response};
use dashmap::DashMap;
use std::{
cmp::Ordering,
collections::BinaryHeap,
sync::{
Arc, Mutex,
atomic::{AtomicU64, Ordering as AtomicOrdering},
},
time::{Duration, Instant},
};
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
/// TTL used by `RequestQueue::default`: how long a request may stay pending after
/// `RequestQueue::activate` before it is resolved as `PendingResponse::Timeout`.
const DEFAULT_REQUEST_TTL: Duration = Duration::from_secs(10);
/// Outcome delivered to a caller awaiting a queued request.
// NOTE(review): the lint presumably fires because `Response` is much larger than the unit
// `Timeout` variant. Boxing the payload would change the `Response(Response)` constructor that
// callers use, so the size lint is allowed instead.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub(crate) enum PendingResponse {
/// The request was answered with this response.
Response(Response),
/// The request's deadline passed before a response was delivered.
Timeout,
}
impl PendingResponse {
    /// Test helper: `true` only for the [`PendingResponse::Timeout`] variant.
    #[cfg(test)]
    #[inline]
    fn matches_timeout(&self) -> bool {
        match self {
            Self::Timeout => true,
            Self::Response(_) => false,
        }
    }
}
/// State for one pending request: the one-shot channel used to resolve the waiting caller and
/// its optional deadline.
pub(crate) struct RequestHandle {
/// Consumed by `send` / `send_timeout` to deliver the outcome to the receiver that
/// `RequestQueue::push` returned.
sender: oneshot::Sender<PendingResponse>,
/// Created fresh in `RequestHandle::new` and never read in this module; it is only held for its
/// lifetime. NOTE(review): confirm whether anything still relies on it.
_cancellation_token: CancellationToken,
/// Instant at or after which the request counts as expired; `None` means no deadline (not yet
/// activated, or expiration disabled by a zero TTL).
expires_at: Option<Instant>,
}
/// Tracks in-flight requests and resolves each one with either a response or a timeout.
///
/// Cloning is cheap: every field except `ttl` is behind an `Arc`, so clones share the same
/// pending map and expiration heap.
#[derive(Clone)]
pub(crate) struct RequestQueue {
/// Handles for requests that are still waiting, keyed by request id.
pending: Arc<DashMap<RequestId, RequestHandle>>,
/// Deadlines of activated requests, ordered so the earliest one is popped first (see the `Ord`
/// impl of `RequestExpiry`). Entries can be stale, so each one is checked against `pending`
/// before it is acted on.
expirations: Arc<Mutex<BinaryHeap<RequestExpiry>>>,
/// Increasing counter that tags each heap entry, breaking ties between equal deadlines.
next_expiry_seq: Arc<AtomicU64>,
/// How long an activated request may wait; zero disables expiration.
ttl: Duration,
}
/// Expiration-heap entry recording when an activated request should time out.
struct RequestExpiry {
/// Deadline copied into the handle when the request was activated.
expires_at: Instant,
/// Unique tag from `RequestQueue::next_expiry_seq`; orders entries with equal deadlines.
sequence: u64,
/// Request this deadline belongs to; not part of equality or ordering.
id: RequestId,
}
impl PartialEq for RequestExpiry {
    /// Compares the `(expires_at, sequence)` key only, matching what `Ord::cmp` orders by;
    /// `id` does not take part.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.expires_at, self.sequence);
        let rhs = (other.expires_at, other.sequence);
        lhs == rhs
    }
}
// `eq` compares exactly the `(expires_at, sequence)` key that `Ord::cmp` orders by, so full
// equivalence holds and `Eq` is valid.
impl Eq for RequestExpiry {}
impl PartialOrd for RequestExpiry {
    /// Always `Some`: defers to the total order given by the `Ord` implementation.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let total = Ord::cmp(self, other);
        Some(total)
    }
}
impl Ord for RequestExpiry {
    /// Reverses the natural `(expires_at, sequence)` order. `BinaryHeap` is a max-heap, so this
    /// makes it pop the earliest deadline first, and among equal deadlines the lowest sequence.
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        let mine = (self.expires_at, self.sequence);
        let theirs = (other.expires_at, other.sequence);
        mine.cmp(&theirs).reverse()
    }
}
impl RequestHandle {
    /// Creates a handle that resolves its caller through `sender`.
    ///
    /// A zero `ttl` produces a handle without a deadline. A `ttl` too large to add to the current
    /// [`Instant`] is also treated as "no deadline" instead of panicking on overflow.
    pub(super) fn new(sender: oneshot::Sender<PendingResponse>, ttl: Duration) -> Self {
        // `checked_add` avoids the overflow panic of `Instant + Duration`. The branch also keeps
        // the clock read lazy, so the zero-TTL path never calls `Instant::now()`.
        let expires_at = if ttl.is_zero() {
            None
        } else {
            Instant::now().checked_add(ttl)
        };
        Self {
            sender,
            _cancellation_token: CancellationToken::new(),
            expires_at,
        }
    }

    /// Delivers `resp` to the waiting receiver, consuming the handle.
    ///
    /// If the receiver is already gone, the response is discarded. With the `tracing` feature
    /// enabled, the failure is logged.
    pub(crate) fn send(self, resp: Response) {
        if let Err(_err) = self.sender.send(PendingResponse::Response(resp)) {
            #[cfg(feature = "tracing")]
            tracing::error!(
                logger = "neva",
                "Request handler failed to send response: {:?}",
                _err
            );
        }
    }

    /// Resolves the waiting receiver with [`PendingResponse::Timeout`], consuming the handle.
    ///
    /// If the receiver is already gone, nothing is delivered. With the `tracing` feature
    /// enabled, the failure is logged.
    #[inline]
    pub(crate) fn send_timeout(self) {
        if let Err(_err) = self.sender.send(PendingResponse::Timeout) {
            #[cfg(feature = "tracing")]
            tracing::error!(
                logger = "neva",
                "Request handler failed to send timeout response: {:?}",
                _err
            );
        }
    }
}
impl RequestQueue {
#[inline]
pub(crate) fn new(ttl: Duration) -> Self {
Self {
pending: Arc::new(DashMap::new()),
expirations: Arc::new(Mutex::new(BinaryHeap::new())),
next_expiry_seq: Arc::new(AtomicU64::new(0)),
ttl,
}
}
#[inline]
pub(crate) fn push(&self, id: &RequestId) -> oneshot::Receiver<PendingResponse> {
let (sender, receiver) = oneshot::channel();
let mut handle = RequestHandle::new(sender, self.ttl);
handle.expires_at = None;
self.pending.insert(id.clone(), handle);
receiver
}
#[inline]
pub(crate) fn activate(&self, id: &RequestId) {
if let Some(mut handle) = self.pending.get_mut(id) {
let Some(expires_at) = (!self.ttl.is_zero()).then_some(Instant::now() + self.ttl)
else {
handle.expires_at = None;
return;
};
handle.expires_at = Some(expires_at);
drop(handle);
let sequence = self.next_expiry_seq.fetch_add(1, AtomicOrdering::Relaxed);
if let Ok(mut expirations) = self.expirations.lock() {
expirations.push(RequestExpiry {
expires_at,
sequence,
id: id.clone(),
});
}
}
self.cleanup_expired();
}
#[inline]
pub(crate) fn pop(&self, id: &RequestId) -> Option<RequestHandle> {
if self.is_expired(id) {
if let Some((_, handle)) = self.pending.remove(id) {
handle.send_timeout();
}
return None;
}
self.pending.remove(id).map(|(_, handle)| handle)
}
#[inline]
pub(crate) fn complete(&self, resp: Response) {
self.cleanup_expired();
if let Some(sender) = self.pop(&resp.full_id()) {
sender.send(resp)
}
}
#[inline]
fn cleanup_expired(&self) {
let now = Instant::now();
let mut expired = Vec::new();
if let Ok(mut expirations) = self.expirations.lock() {
while expirations
.peek()
.is_some_and(|entry| entry.expires_at <= now)
{
let entry = expirations.pop().expect("peeked entry must exist");
expired.push((entry.id, entry.expires_at));
}
}
for (id, expires_at) in expired {
let should_remove = self
.pending
.get(&id)
.is_some_and(|handle| handle.expires_at == Some(expires_at));
if should_remove && let Some((_, handle)) = self.pending.remove(&id) {
handle.send_timeout();
}
}
}
#[inline]
fn is_expired(&self, id: &RequestId) -> bool {
self.pending
.get(id)
.and_then(|handle| handle.expires_at)
.is_some_and(|expires_at| expires_at <= Instant::now())
}
}
impl Default for RequestQueue {
#[inline]
fn default() -> Self {
Self::new(DEFAULT_REQUEST_TTL)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio::time::{Duration, sleep, timeout};

    #[test]
    fn it_pushes_and_pops_request() {
        let queue = RequestQueue::default();
        let id = RequestId::Number(1);
        let receiver = queue.push(&id);
        let handle = queue.pop(&id);
        assert!(handle.is_some(), "Expected handle to exist");
        assert!(
            queue.pop(&id).is_none(),
            "Handle should be removed after pop"
        );
        drop(receiver);
    }

    #[tokio::test]
    async fn it_sends_and_receives() {
        let queue = RequestQueue::default();
        let id = RequestId::Number(1);
        let receiver = queue.push(&id);
        let handle = queue.pop(&id).expect("Should have handle");
        let expected = Response::success(id, json!({ "content": "done" }));
        handle.send(expected.clone());
        let Response::Ok(expected) = expected else {
            unreachable!()
        };
        let PendingResponse::Response(Response::Ok(actual)) =
            timeout(Duration::from_secs(1), receiver)
                .await
                .expect("Receiver should complete")
                .expect("Sender should send")
        else {
            unreachable!()
        };
        assert_eq!(actual.result, expected.result);
        assert_eq!(actual.id, expected.id);
    }

    #[tokio::test]
    async fn it_sends_response_if_pending() {
        let queue = RequestQueue::default();
        let id = RequestId::Number(1);
        let receiver = queue.push(&id);
        let response = Response::success(id, json!({ "content": "done" }));
        queue.complete(response.clone());
        let Response::Ok(response) = response else {
            unreachable!()
        };
        let PendingResponse::Response(Response::Ok(actual)) =
            timeout(Duration::from_secs(1), receiver)
                .await
                .expect("Should receive within timeout")
                .expect("Should receive response")
        else {
            unreachable!()
        };
        assert_eq!(actual.result, response.result);
    }

    #[test]
    fn it_does_nothing_if_not_pending() {
        let queue = RequestQueue::default();
        let id = RequestId::Number(1);
        let response = Response::success(id, json!({ "content": "done" }));
        queue.complete(response);
    }

    #[test]
    fn it_does_remove_expired_pending_requests() {
        let queue = RequestQueue::new(Duration::from_millis(1));
        let id = RequestId::Number(1);
        let _receiver = queue.push(&id);
        queue.activate(&id);
        // Synchronous test: blocking the thread is fine here.
        std::thread::sleep(Duration::from_millis(10));
        assert!(queue.pop(&id).is_none());
    }

    #[tokio::test]
    async fn pop_does_not_close_non_target_receivers() {
        let queue = RequestQueue::new(Duration::from_millis(5));
        let expired_id = RequestId::Number(1);
        let live_id = RequestId::Number(2);
        let _expired = queue.push(&expired_id);
        let live = queue.push(&live_id);
        queue.activate(&expired_id);
        // Async tests sleep asynchronously instead of blocking the runtime thread.
        sleep(Duration::from_millis(10)).await;
        assert!(queue.pop(&expired_id).is_none());
        let response = Response::success(live_id, json!({ "content": "done" }));
        queue.complete(response);
        // `timeout(..).await.is_ok()` alone would also pass if the live sender had been dropped,
        // because the receiver then resolves to `Err(RecvError)`. Unwrap both layers and check
        // what was actually delivered.
        let outcome = timeout(Duration::from_secs(1), live)
            .await
            .expect("live receiver should resolve within timeout")
            .expect("non-target receiver should remain open");
        assert!(
            matches!(outcome, PendingResponse::Response(_)),
            "live request should receive its response"
        );
    }

    #[tokio::test]
    async fn pop_sends_timeout_response_for_expired_request() {
        let queue = RequestQueue::new(Duration::from_millis(5));
        let id = RequestId::Number(1);
        let receiver = queue.push(&id);
        queue.activate(&id);
        sleep(Duration::from_millis(10)).await;
        assert!(queue.pop(&id).is_none());
        assert!(
            receiver
                .await
                .expect("expired request should resolve")
                .matches_timeout(),
            "expired request should resolve as timeout"
        );
    }

    #[tokio::test]
    async fn cleanup_sends_timeout_response_for_expired_requests() {
        let queue = RequestQueue::new(Duration::from_millis(5));
        let expired_id = RequestId::Number(1);
        let live_id = RequestId::Number(2);
        let expired = queue.push(&expired_id);
        let _live = queue.push(&live_id);
        queue.activate(&expired_id);
        sleep(Duration::from_millis(10)).await;
        let response = Response::success(live_id, json!({ "content": "done" }));
        queue.complete(response);
        assert!(
            expired
                .await
                .expect("expired request should resolve")
                .matches_timeout(),
            "expired request should resolve as timeout"
        );
    }

    #[test]
    fn push_does_not_start_ttl_until_activated() {
        let queue = RequestQueue::new(Duration::from_millis(1));
        let id = RequestId::Number(1);
        let _receiver = queue.push(&id);
        std::thread::sleep(Duration::from_millis(10));
        assert!(queue.pop(&id).is_some());
    }
}