use crate::context;
use crate::util::{Compact, TimeUntil};
use fnv::FnvHashMap;
use std::collections::hash_map;
use std::task::{Context, Poll};
use tokio::sync::oneshot;
use tokio_util::time::delay_queue::{self, DelayQueue};
use tracing::Span;
/// Tracks client requests that are awaiting a response, together with each request's deadline.
///
/// Requests are keyed by their `u64` request ID. An entry leaves the tracker when the request is
/// completed, canceled, or its deadline expires.
#[derive(Debug)]
pub struct InFlightRequests<Resp> {
    // Per-request bookkeeping (context, span, response sender, deadline key), keyed by request ID.
    request_data: FnvHashMap<u64, RequestData<Resp>>,
    // Pending deadlines; each queued item is the request ID it belongs to. An item is inserted
    // alongside each `request_data` entry and removed when that request is completed, canceled,
    // or expires (or all at once by `complete_all_requests`).
    deadlines: DelayQueue<u64>,
}
impl<Resp> Default for InFlightRequests<Resp> {
fn default() -> Self {
Self {
request_data: Default::default(),
deadlines: Default::default(),
}
}
}
/// Bookkeeping for a single in-flight request.
#[derive(Debug)]
struct RequestData<Resp> {
    /// The context the request was issued with; its `deadline` determines the request's timeout.
    ctx: context::Context,
    /// The tracing span associated with the request.
    span: Span,
    /// Sending half of the channel on which the request's response (or deadline error) is
    /// delivered.
    response_completion: oneshot::Sender<Resp>,
    /// Handle to this request's entry in the deadline queue, used to remove it when the request
    /// finishes before its deadline.
    deadline_key: delay_queue::Key,
}
/// Returned by `InFlightRequests::insert_request` when a request with the same ID is already in
/// flight.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AlreadyExistsError;

impl std::fmt::Display for AlreadyExistsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("a request with this ID is already in flight")
    }
}

// Lets the error be boxed as `Box<dyn Error>` and propagated with `?`.
impl std::error::Error for AlreadyExistsError {}
impl<Res> InFlightRequests<Res> {
pub fn len(&self) -> usize {
self.request_data.len()
}
pub fn is_empty(&self) -> bool {
self.request_data.is_empty()
}
pub fn insert_request(&mut self, request_id: u64, ctx: context::Context, span: Span, response_completion: oneshot::Sender<Res>) -> Result<(), AlreadyExistsError> {
match self.request_data.entry(request_id) {
hash_map::Entry::Vacant(vacant) => {
let timeout = ctx.deadline.time_until();
let deadline_key = self.deadlines.insert(request_id, timeout);
vacant.insert(RequestData {
ctx,
span,
response_completion,
deadline_key,
});
Ok(())
},
hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
}
}
pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option<Span> {
if let Some(request_data) = self.request_data.remove(&request_id) {
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
let _ = request_data.response_completion.send(result);
return Some(request_data.span);
}
tracing::debug!("No in-flight request found for request_id = {request_id}.");
None
}
pub fn complete_all_requests<'a>(&'a mut self, mut result: impl FnMut() -> Res + 'a) -> impl Iterator<Item = Span> + 'a {
self.deadlines.clear();
self.request_data.drain().map(move |(_, request_data)| {
let _ = request_data.response_completion.send(result());
request_data.span
})
}
pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
if let Some(request_data) = self.request_data.remove(&request_id) {
self.request_data.compact(0.1);
self.deadlines.remove(&request_data.deadline_key);
Some((request_data.ctx, request_data.span))
} else {
None
}
}
pub fn poll_expired(&mut self, cx: &mut Context, expired_error: impl Fn() -> Res) -> Poll<Option<u64>> {
self.deadlines.poll_expired(cx).map(|expired| {
let request_id = expired?.into_inner();
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::error!("DeadlineExceeded");
self.request_data.compact(0.1);
let _ = request_data.response_completion.send(expired_error());
}
Some(request_id)
})
}
}