use crate::{routes::HandlerArgs, types::RequestError};
use core::fmt;
use pin_project::pin_project;
use serde_json::value::RawValue;
use std::{
convert::Infallible,
future::Future,
panic,
task::{ready, Context, Poll},
};
use tokio::task::JoinSet;
use tower::util::{BoxCloneSyncService, Oneshot};
use super::Response;
/// Future for a single routed request: a one-shot call into a boxed handler
/// service, optionally instrumented with a tracing span.
#[pin_project]
pub struct RouteFuture {
    /// The in-flight handler call; structurally pinned so it can be polled
    /// through the projection.
    #[pin]
    inner: Oneshot<
        BoxCloneSyncService<HandlerArgs, Option<Box<RawValue>>, Infallible>,
        HandlerArgs,
    >,
    /// Span entered while `inner` is being polled, if one was attached.
    span: Option<tracing::Span>,
}
impl RouteFuture {
    /// Wraps an in-flight handler call with no tracing span attached.
    pub(crate) const fn new(
        inner: Oneshot<
            BoxCloneSyncService<HandlerArgs, Option<Box<RawValue>>, Infallible>,
            HandlerArgs,
        >,
    ) -> Self {
        Self { span: None, inner }
    }

    /// Attaches `span`, replacing any previously attached span. The span is
    /// entered each time this future is polled.
    pub(crate) fn with_span(mut self, span: tracing::Span) -> Self {
        self.span = Some(span);
        self
    }
}
impl fmt::Debug for RouteFuture {
    /// Prints only the type name; no fields are shown, so the output is
    /// marked non-exhaustive.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("RouteFuture");
        builder.finish_non_exhaustive()
    }
}
impl Future for RouteFuture {
type Output = Result<Option<Box<RawValue>>, Infallible>;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let _enter = this.span.as_ref().map(tracing::Span::enter);
this.inner.poll(cx)
}
}
/// Execution state of a [`BatchFuture`]'s route futures.
#[pin_project(project = BatchFutureInnerProj)]
enum BatchFutureInner {
    /// Futures are still being collected and have not been started yet.
    Prepping(Vec<RouteFuture>),
    /// Futures have been collected into a `JoinSet` and are being awaited.
    Running(
        #[pin]
        JoinSet<Result<Option<Box<RawValue>>, Infallible>>,
    ),
}
impl BatchFutureInner {
    /// Number of futures held: queued futures while prepping, or tasks not
    /// yet joined while running.
    fn len(&self) -> usize {
        match self {
            Self::Running(set) => set.len(),
            Self::Prepping(queued) => queued.len(),
        }
    }

    /// Whether the current state holds no futures.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Moves from `Prepping` to `Running` by collecting every queued future
    /// into a `JoinSet`. Does nothing if already running.
    fn run(&mut self) {
        let Self::Prepping(queued) = self else {
            return;
        };
        let set: JoinSet<_> = queued.drain(..).collect();
        *self = Self::Running(set);
    }
}
impl fmt::Debug for BatchFutureInner {
    /// Shows the state as a single field, "prepared" or "running", holding
    /// the number of futures in that state.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (label, count) = match self {
            Self::Prepping(queued) => ("prepared", queued.len()),
            Self::Running(set) => ("running", set.len()),
        };
        f.debug_struct("BatchFutureInner")
            .field(label, &count)
            .finish_non_exhaustive()
    }
}
/// Future that drives every routed request in a batch to completion and
/// assembles the serialized reply. In single-request mode (`single`) the
/// reply is one response instead of a JSON array.
#[pin_project]
pub struct BatchFuture {
    /// Route futures: queued at first, later running in a `JoinSet`.
    #[pin]
    futs: BatchFutureInner,
    /// Serialized responses gathered so far, including ones pushed directly
    /// (for example parse errors).
    resps: Vec<Box<RawValue>>,
    /// When true, the output is the lone response rather than an array.
    single: bool,
    /// Span entered while the batch is polled, if one was attached.
    span: Option<tracing::Span>,
}
impl fmt::Debug for BatchFuture {
    /// Shows the execution state, the single/batch mode, and the number of
    /// collected responses.
    ///
    /// The tracing span is omitted, so the output is marked non-exhaustive.
    /// This matches the other `Debug` impls for the routing futures.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BatchFuture")
            .field("state", &self.futs)
            .field("single", &self.single)
            .field("responses", &self.resps.len())
            .finish_non_exhaustive()
    }
}
impl BatchFuture {
    /// Creates an empty batch in the preparing state. It pre-allocates room
    /// for `capacity` futures and `capacity` responses.
    ///
    /// `single` selects whether the final output is one bare response
    /// rather than a JSON array of responses.
    pub(crate) fn new_with_capacity(single: bool, capacity: usize) -> Self {
        let futs = BatchFutureInner::Prepping(Vec::with_capacity(capacity));
        let resps = Vec::with_capacity(capacity);
        Self {
            futs,
            resps,
            single,
            span: None,
        }
    }

    /// Attaches `span`, replacing any previously attached span. The span is
    /// entered while the batch is polled.
    pub(crate) fn with_span(mut self, span: tracing::Span) -> Self {
        self.span = Some(span);
        self
    }

    /// Queues a route future; it starts running on the batch's first poll.
    ///
    /// # Panics
    ///
    /// Panics if the batch has already started running.
    pub(crate) fn push(&mut self, fut: RouteFuture) {
        match &mut self.futs {
            BatchFutureInner::Prepping(queued) => queued.push(fut),
            BatchFutureInner::Running(_) => panic!("pushing into a running batch future"),
        }
    }

    /// Records an already-serialized response that needs no routing.
    pub(crate) fn push_resp(&mut self, resp: Box<RawValue>) {
        self.resps.push(resp);
    }

    /// Records a parse-error response.
    pub(crate) fn push_parse_error(&mut self) {
        let resp = Response::parse_error();
        self.push_resp(resp);
    }

    /// Queues the route future if routing succeeded; otherwise records a
    /// parse-error response. The specific `RequestError` is discarded.
    ///
    /// # Panics
    ///
    /// Panics on `Ok` if the batch has already started running (see `push`).
    pub(crate) fn push_parse_result(&mut self, result: Result<RouteFuture, RequestError>) {
        if let Ok(fut) = result {
            self.push(fut);
        } else {
            self.push_parse_error();
        }
    }

    /// Number of route futures currently held: queued, or running and not
    /// yet joined. Responses pushed directly are not counted.
    pub(crate) fn len(&self) -> usize {
        self.futs.len()
    }
}
impl std::future::Future for BatchFuture {
    type Output = Option<Box<RawValue>>;

    /// Drives the batch to completion.
    ///
    /// On the first poll, the queued futures are collected into a `JoinSet`.
    /// If nothing was queued and no response was recorded, a parse-error
    /// response is returned immediately instead. After that, task results
    /// are gathered as tasks finish.
    ///
    /// The future resolves to:
    /// - `None` if no response was produced;
    /// - the lone response in single mode;
    /// - a JSON array of all responses otherwise.
    ///
    /// A panic inside a task is re-raised here. A task that failed without
    /// panicking is logged and skipped.
    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let preparing = matches!(self.futs, BatchFutureInner::Prepping(_));
        if preparing {
            if self.resps.is_empty() && self.futs.is_empty() {
                return Poll::Ready(Some(Response::parse_error()));
            }
            self.futs.run();
        }

        let this = self.project();
        let _span_guard = this.span.as_ref().map(|span| span.enter());
        let BatchFutureInnerProj::Running(mut set) = this.futs.project() else {
            unreachable!()
        };

        // `ready!` propagates `Pending` (waker registered by the join set)
        // until every task has been joined.
        while let Some(joined) = ready!(set.poll_join_next(cx)) {
            match joined {
                Ok(output) => {
                    if let Some(resp) = unwrap_infallible!(output) {
                        this.resps.push(resp);
                    }
                }
                Err(err) => {
                    tracing::error!(?err, "panic or cancel in batch future");
                    if let Ok(reason) = err.try_into_panic() {
                        panic::resume_unwind(reason);
                    }
                }
            }
        }

        // The join set is exhausted: assemble the final reply.
        if this.resps.is_empty() {
            return Poll::Ready(None);
        }
        if *this.single {
            let lone = this.resps.pop().unwrap_or_else(Response::parse_error);
            return Poll::Ready(Some(lone));
        }
        let array = serde_json::value::to_raw_value(&this.resps)
            .unwrap_or_else(|_| Response::serialization_failure(RawValue::NULL));
        Poll::Ready(Some(array))
    }
}