1use crate::{routes::HandlerArgs, types::RequestError};
2use core::fmt;
3use pin_project::pin_project;
4use serde_json::value::RawValue;
5use std::{
6 convert::Infallible,
7 future::Future,
8 panic,
9 task::{ready, Context, Poll},
10};
11use tokio::task::JoinSet;
12use tower::util::{BoxCloneSyncService, Oneshot};
13
14use super::Response;
15
16#[pin_project]
20pub struct RouteFuture {
21 #[pin]
25 inner:
26 Oneshot<BoxCloneSyncService<HandlerArgs, Option<Box<RawValue>>, Infallible>, HandlerArgs>,
27 span: Option<tracing::Span>,
29}
30
31impl RouteFuture {
32 pub(crate) const fn new(
34 inner: Oneshot<
35 BoxCloneSyncService<HandlerArgs, Option<Box<RawValue>>, Infallible>,
36 HandlerArgs,
37 >,
38 ) -> Self {
39 Self { inner, span: None }
40 }
41
42 pub(crate) fn with_span(self, span: tracing::Span) -> Self {
44 Self {
45 span: Some(span),
46 ..self
47 }
48 }
49}
50
51impl fmt::Debug for RouteFuture {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 f.debug_struct("RouteFuture").finish_non_exhaustive()
54 }
55}
56
57impl Future for RouteFuture {
58 type Output = Result<Option<Box<RawValue>>, Infallible>;
59
60 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
61 let this = self.project();
62 let _enter = this.span.as_ref().map(tracing::Span::enter);
63
64 this.inner.poll(cx)
65 }
66}
67
68#[pin_project(project = BatchFutureInnerProj)]
69enum BatchFutureInner {
70 Prepping(Vec<RouteFuture>),
71 Running(#[pin] JoinSet<Result<Option<Box<RawValue>>, Infallible>>),
72}
73
74impl BatchFutureInner {
75 fn len(&self) -> usize {
76 match self {
77 Self::Prepping(futs) => futs.len(),
78 Self::Running(futs) => futs.len(),
79 }
80 }
81
82 fn is_empty(&self) -> bool {
83 self.len() == 0
84 }
85
86 fn run(&mut self) {
87 if let Self::Prepping(futs) = self {
88 let js = futs.drain(..).collect::<JoinSet<_>>();
89 *self = Self::Running(js);
90 }
91 }
92}
93
94impl fmt::Debug for BatchFutureInner {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 let mut s = f.debug_struct("BatchFutureInner");
97
98 match self {
99 Self::Prepping(futs) => s.field("prepared", &futs.len()),
100 Self::Running(futs) => s.field("running", &futs.len()),
101 }
102 .finish_non_exhaustive()
103 }
104}
105
106#[pin_project]
113pub struct BatchFuture {
114 #[pin]
116 futs: BatchFutureInner,
117 resps: Vec<Box<RawValue>>,
119 single: bool,
121
122 service_name: &'static str,
124
125 span: Option<tracing::Span>,
127}
128
129impl fmt::Debug for BatchFuture {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 f.debug_struct("BatchFuture")
132 .field("state", &self.futs)
133 .field("responses", &self.resps.len())
134 .finish()
135 }
136}
137
138impl BatchFuture {
139 pub(crate) fn new_with_capacity(
141 single: bool,
142 service_name: &'static str,
143 capacity: usize,
144 ) -> Self {
145 Self {
146 futs: BatchFutureInner::Prepping(Vec::with_capacity(capacity)),
147 resps: Vec::with_capacity(capacity),
148 single,
149 service_name,
150 span: None,
151 }
152 }
153
154 pub(crate) fn with_span(self, span: tracing::Span) -> Self {
156 Self {
157 span: Some(span),
158 ..self
159 }
160 }
161
162 pub(crate) fn push(&mut self, fut: RouteFuture) {
164 let BatchFutureInner::Prepping(ref mut futs) = self.futs else {
165 panic!("pushing into a running batch future");
166 };
167 futs.push(fut);
168 }
169
170 pub(crate) fn push_resp(&mut self, resp: Box<RawValue>) {
172 self.resps.push(resp);
173 }
174
175 pub(crate) fn push_parse_error(&mut self) {
177 crate::metrics::record_parse_error(self.service_name);
178 self.push_resp(Response::parse_error());
179 }
180
181 pub(crate) fn push_parse_result(&mut self, result: Result<RouteFuture, RequestError>) {
186 match result {
187 Ok(fut) => self.push(fut),
188 Err(_) => self.push_parse_error(),
189 }
190 }
191
192 pub(crate) fn len(&self) -> usize {
194 self.futs.len()
195 }
196}
197
198impl std::future::Future for BatchFuture {
199 type Output = Option<Box<RawValue>>;
200
201 fn poll(
202 mut self: std::pin::Pin<&mut Self>,
203 cx: &mut std::task::Context<'_>,
204 ) -> std::task::Poll<Self::Output> {
205 if matches!(self.futs, BatchFutureInner::Prepping(_)) {
206 if self.futs.is_empty() && self.resps.is_empty() {
208 crate::metrics::record_parse_error(self.service_name);
209 return Poll::Ready(Some(Response::parse_error()));
210 }
211 self.futs.run();
212 }
213
214 let this = self.project();
215 let _enter = this.span.as_ref().map(tracing::Span::enter);
216
217 let BatchFutureInnerProj::Running(mut futs) = this.futs.project() else {
218 unreachable!()
219 };
220
221 loop {
222 match ready!(futs.poll_join_next(cx)) {
223 Some(Ok(resp)) => {
224 if let Some(resp) = unwrap_infallible!(resp) {
226 this.resps.push(resp);
227 }
228 }
229
230 None => {
232 if this.resps.is_empty() {
234 return Poll::Ready(None);
235 }
236
237 let resp = if *this.single {
240 this.resps.pop().unwrap_or_else(|| {
241 crate::metrics::record_parse_error(this.service_name);
243 Response::parse_error()
244 })
245 } else {
246 serde_json::value::to_raw_value(&this.resps)
248 .unwrap_or_else(|_| Response::serialization_failure(RawValue::NULL))
249 };
250
251 return Poll::Ready(Some(resp));
252 }
253 Some(Err(err)) => {
255 tracing::error!(?err, "panic or cancel in batch future");
256 if let Ok(reason) = err.try_into_panic() {
258 panic::resume_unwind(reason);
259 }
260 }
261 }
262 }
263 }
264}
265
266