use crate::util::{Compact, TimeUntil};
use fnv::FnvHashMap;
use futures::future::{AbortHandle, AbortRegistration};
use std::{
collections::hash_map,
task::{Context, Poll},
time::SystemTime,
};
use tokio_util::time::delay_queue::{self, DelayQueue};
use tracing::Span;
#[derive(Debug, Default)]
pub struct InFlightRequests {
request_data: FnvHashMap<u64, RequestData>,
deadlines: DelayQueue<u64>,
}
#[derive(Debug)]
struct RequestData {
abort_handle: AbortHandle,
deadline_key: delay_queue::Key,
span: Span,
}
#[derive(Debug)]
pub struct AlreadyExistsError;
impl InFlightRequests {
pub fn len(&self) -> usize {
self.request_data.len()
}
pub fn start_request(
&mut self,
request_id: u64,
deadline: SystemTime,
span: Span,
) -> Result<AbortRegistration, AlreadyExistsError> {
match self.request_data.entry(request_id) {
hash_map::Entry::Vacant(vacant) => {
let timeout = deadline.time_until();
let (abort_handle, abort_registration) = AbortHandle::new_pair();
let deadline_key = self.deadlines.insert(request_id, timeout);
vacant.insert(RequestData {
abort_handle,
deadline_key,
span,
});
Ok(abort_registration)
}
hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
}
}
pub fn cancel_request(&mut self, request_id: u64) -> bool {
if let Some(RequestData {
span,
abort_handle,
deadline_key,
}) = self.request_data.remove(&request_id)
{
let _entered = span.enter();
self.request_data.compact(0.1);
abort_handle.abort();
self.deadlines.remove(&deadline_key);
tracing::info!("ReceiveCancel");
true
} else {
false
}
}
pub fn remove_request(&mut self, request_id: u64) -> Option<Span> {
if let Some(request_data) = self.request_data.remove(&request_id) {
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
Some(request_data.span)
} else {
None
}
}
pub fn poll_expired(&mut self, cx: &mut Context) -> Poll<Option<u64>> {
if self.deadlines.is_empty() {
return Poll::Ready(None);
}
self.deadlines.poll_expired(cx).map(|expired| {
let expired = expired?;
if let Some(RequestData {
abort_handle, span, ..
}) = self.request_data.remove(expired.get_ref())
{
let _entered = span.enter();
self.request_data.compact(0.1);
abort_handle.abort();
tracing::error!("DeadlineExceeded");
}
Some(expired.into_inner())
})
}
}
impl Drop for InFlightRequests {
fn drop(&mut self) {
self.request_data
.values()
.for_each(|request_data| request_data.abort_handle.abort())
}
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use futures::{
future::{pending, Abortable},
FutureExt,
};
use futures_test::task::noop_context;
#[tokio::test]
async fn start_request_increases_len() {
let mut in_flight_requests = InFlightRequests::default();
assert_eq!(in_flight_requests.len(), 0);
in_flight_requests
.start_request(0, SystemTime::now(), Span::current())
.unwrap();
assert_eq!(in_flight_requests.len(), 1);
}
#[tokio::test]
async fn polling_expired_aborts() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now(), Span::current())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
tokio::time::pause();
tokio::time::advance(std::time::Duration::from_secs(1000)).await;
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Ready(Some(_))
);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))
);
assert_eq!(in_flight_requests.len(), 0);
}
#[tokio::test]
async fn cancel_request_aborts() {
let mut in_flight_requests = InFlightRequests::default();
let abort_registration = in_flight_requests
.start_request(0, SystemTime::now(), Span::current())
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
assert!(in_flight_requests.cancel_request(0));
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Ready(Err(_))
);
assert_eq!(in_flight_requests.len(), 0);
}
#[tokio::test]
async fn remove_request_doesnt_abort() {
let mut in_flight_requests = InFlightRequests::default();
assert!(in_flight_requests.deadlines.is_empty());
let abort_registration = in_flight_requests
.start_request(
0,
SystemTime::now() + std::time::Duration::from_secs(10),
Span::current(),
)
.unwrap();
let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration));
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Pending
);
assert!(!in_flight_requests.deadlines.is_empty());
assert_matches!(in_flight_requests.remove_request(0), Some(_));
assert!(in_flight_requests.deadlines.is_empty());
assert_matches!(
in_flight_requests.poll_expired(&mut noop_context()),
Poll::Ready(None)
);
assert_matches!(
abortable_future.poll_unpin(&mut noop_context()),
Poll::Pending
);
assert_eq!(in_flight_requests.len(), 0);
}
}