// ajj/routes/future.rs

1use crate::{routes::HandlerArgs, types::RequestError};
2use core::fmt;
3use pin_project::pin_project;
4use serde_json::value::RawValue;
5use std::{
6    convert::Infallible,
7    future::Future,
8    panic,
9    task::{ready, Context, Poll},
10};
11use tokio::task::JoinSet;
12use tower::util::{BoxCloneSyncService, Oneshot};
13
14use super::Response;
15
16/// A future produced by the [`Router`].
17///
18/// [`Router`]: crate::Router
19#[pin_project]
20pub struct RouteFuture {
21    /// The inner [`Route`] future.
22    ///
23    /// [`Route`]: crate::routes::Route
24    #[pin]
25    inner:
26        Oneshot<BoxCloneSyncService<HandlerArgs, Option<Box<RawValue>>, Infallible>, HandlerArgs>,
27    /// The span (if any).
28    span: Option<tracing::Span>,
29}
30
31impl RouteFuture {
32    /// Create a new route future.
33    pub(crate) const fn new(
34        inner: Oneshot<
35            BoxCloneSyncService<HandlerArgs, Option<Box<RawValue>>, Infallible>,
36            HandlerArgs,
37        >,
38    ) -> Self {
39        Self { inner, span: None }
40    }
41
42    /// Set the span for the future.
43    pub(crate) fn with_span(self, span: tracing::Span) -> Self {
44        Self {
45            span: Some(span),
46            ..self
47        }
48    }
49}
50
51impl fmt::Debug for RouteFuture {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        f.debug_struct("RouteFuture").finish_non_exhaustive()
54    }
55}
56
57impl Future for RouteFuture {
58    type Output = Result<Option<Box<RawValue>>, Infallible>;
59
60    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
61        let this = self.project();
62        let _enter = this.span.as_ref().map(tracing::Span::enter);
63
64        this.inner.poll(cx)
65    }
66}
67
68#[pin_project(project = BatchFutureInnerProj)]
69enum BatchFutureInner {
70    Prepping(Vec<RouteFuture>),
71    Running(#[pin] JoinSet<Result<Option<Box<RawValue>>, Infallible>>),
72}
73
74impl BatchFutureInner {
75    fn len(&self) -> usize {
76        match self {
77            Self::Prepping(futs) => futs.len(),
78            Self::Running(futs) => futs.len(),
79        }
80    }
81
82    fn is_empty(&self) -> bool {
83        self.len() == 0
84    }
85
86    fn run(&mut self) {
87        if let Self::Prepping(futs) = self {
88            let js = futs.drain(..).collect::<JoinSet<_>>();
89            *self = Self::Running(js);
90        }
91    }
92}
93
94impl fmt::Debug for BatchFutureInner {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        let mut s = f.debug_struct("BatchFutureInner");
97
98        match self {
99            Self::Prepping(futs) => s.field("prepared", &futs.len()),
100            Self::Running(futs) => s.field("running", &futs.len()),
101        }
102        .finish_non_exhaustive()
103    }
104}
105
106/// A collection of [`RouteFuture`]s that are executed concurrently.
107///
108/// This is the type returned by [`Router::call_batch_with_state`], and should
109/// only be instantiated by that method.
110///
111/// [`Router::call_batch_with_state`]: crate::Router::call_batch_with_state
112#[pin_project]
113pub struct BatchFuture {
114    /// The futures, either in the prepping or running state.
115    #[pin]
116    futs: BatchFutureInner,
117    /// The responses collected so far.
118    resps: Vec<Box<RawValue>>,
119    /// Whether the batch was a single request.
120    single: bool,
121
122    /// The service name, for tracing and metrics.
123    service_name: &'static str,
124
125    /// The span (if any).
126    span: Option<tracing::Span>,
127}
128
129impl fmt::Debug for BatchFuture {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        f.debug_struct("BatchFuture")
132            .field("state", &self.futs)
133            .field("responses", &self.resps.len())
134            .finish()
135    }
136}
137
138impl BatchFuture {
139    /// Create a new batch future with a capacity.
140    pub(crate) fn new_with_capacity(
141        single: bool,
142        service_name: &'static str,
143        capacity: usize,
144    ) -> Self {
145        Self {
146            futs: BatchFutureInner::Prepping(Vec::with_capacity(capacity)),
147            resps: Vec::with_capacity(capacity),
148            single,
149            service_name,
150            span: None,
151        }
152    }
153
154    /// Set the span for the future.
155    pub(crate) fn with_span(self, span: tracing::Span) -> Self {
156        Self {
157            span: Some(span),
158            ..self
159        }
160    }
161
162    /// Spawn a future into the batch.
163    pub(crate) fn push(&mut self, fut: RouteFuture) {
164        let BatchFutureInner::Prepping(ref mut futs) = self.futs else {
165            panic!("pushing into a running batch future");
166        };
167        futs.push(fut);
168    }
169
170    /// Push a response into the batch.
171    pub(crate) fn push_resp(&mut self, resp: Box<RawValue>) {
172        self.resps.push(resp);
173    }
174
175    /// Push a parse error into the batch.
176    pub(crate) fn push_parse_error(&mut self) {
177        crate::metrics::record_parse_error(self.service_name);
178        self.push_resp(Response::parse_error());
179    }
180
181    /// Push a parse result into the batch. Convenience function to simplify
182    /// [`Router`] logic.
183    ///
184    /// [`Router`]: crate::Router
185    pub(crate) fn push_parse_result(&mut self, result: Result<RouteFuture, RequestError>) {
186        match result {
187            Ok(fut) => self.push(fut),
188            Err(_) => self.push_parse_error(),
189        }
190    }
191
192    /// Get the number of futures in the batch.
193    pub(crate) fn len(&self) -> usize {
194        self.futs.len()
195    }
196}
197
198impl std::future::Future for BatchFuture {
199    type Output = Option<Box<RawValue>>;
200
201    fn poll(
202        mut self: std::pin::Pin<&mut Self>,
203        cx: &mut std::task::Context<'_>,
204    ) -> std::task::Poll<Self::Output> {
205        if matches!(self.futs, BatchFutureInner::Prepping(_)) {
206            // SPEC: empty arrays are invalid
207            if self.futs.is_empty() && self.resps.is_empty() {
208                crate::metrics::record_parse_error(self.service_name);
209                return Poll::Ready(Some(Response::parse_error()));
210            }
211            self.futs.run();
212        }
213
214        let this = self.project();
215        let _enter = this.span.as_ref().map(tracing::Span::enter);
216
217        let BatchFutureInnerProj::Running(mut futs) = this.futs.project() else {
218            unreachable!()
219        };
220
221        loop {
222            match ready!(futs.poll_join_next(cx)) {
223                Some(Ok(resp)) => {
224                    // SPEC: notifications receive no response.
225                    if let Some(resp) = unwrap_infallible!(resp) {
226                        this.resps.push(resp);
227                    }
228                }
229
230                // join set is drained, return the response(s)
231                None => {
232                    // SPEC: batches that contain only notifications receive no response.
233                    if this.resps.is_empty() {
234                        return Poll::Ready(None);
235                    }
236
237                    // SPEC: single requests return a single response
238                    // Batch requests return an array of responses
239                    let resp = if *this.single {
240                        this.resps.pop().unwrap_or_else(|| {
241                            // this should never happen, but just in case...
242                            crate::metrics::record_parse_error(this.service_name);
243                            Response::parse_error()
244                        })
245                    } else {
246                        // otherwise, we have an array of responses
247                        serde_json::value::to_raw_value(&this.resps)
248                            .unwrap_or_else(|_| Response::serialization_failure(RawValue::NULL))
249                    };
250
251                    return Poll::Ready(Some(resp));
252                }
253                // panic in a future, propagate it
254                Some(Err(err)) => {
255                    tracing::error!(?err, "panic or cancel in batch future");
256                    // propagate panics
257                    if let Ok(reason) = err.try_into_panic() {
258                        panic::resume_unwind(reason);
259                    }
260                }
261            }
262        }
263    }
264}
265
// Some code in this file is reproduced under the terms of the MIT license. It
267// originates from the `axum` crate. The original source code can be found at
268// the following URL, and the original license is included below.
269//
270// https://github.com/tokio-rs/axum/
271//
272// The MIT License (MIT)
273//
274// Copyright (c) 2019 Axum Contributors
275//
276// Permission is hereby granted, free of charge, to any
277// person obtaining a copy of this software and associated
278// documentation files (the "Software"), to deal in the
279// Software without restriction, including without
280// limitation the rights to use, copy, modify, merge,
281// publish, distribute, sublicense, and/or sell copies of
282// the Software, and to permit persons to whom the Software
283// is furnished to do so, subject to the following
284// conditions:
285//
286// The above copyright notice and this permission notice
287// shall be included in all copies or substantial portions
288// of the Software.
289//
290// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
291// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
292// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
293// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
294// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
295// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
296// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
297// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
298// DEALINGS IN THE SOFTWARE.