Skip to main content

trillium_client/
response_body.rs

1use crate::{Error, Pool, pool::PoolEntry};
2use encoding_rs::Encoding;
3use futures_lite::{AsyncRead, AsyncReadExt, AsyncWriteExt};
4use std::{
5    fmt::{self, Debug, Formatter},
6    io, mem,
7    pin::Pin,
8    task::{Context, Poll, ready},
9};
10use trillium_http::{
11    Body, BodySource, Headers, HttpConfig, MutCow, ReceivedBody, ReceivedBodyState,
12};
13use trillium_server_common::{Runtime, Transport, url::Origin};
14
/// A response body received from a server.
///
/// Most of the time this represents a body that will be read from the underlying transport, but it
/// can also wrap an override body installed by middleware via [`ConnExt::set_response_body`]
/// — e.g. cache hits, mocked responses, or circuit-breaker short-circuits. Reads, encoding
/// handling, and `max_len` enforcement work transparently across both cases.
///
/// [`ConnExt::set_response_body`]: crate::ConnExt::set_response_body
///
/// ```rust
/// use trillium_client::Client;
/// use trillium_testing::{client_config, with_server};
///
/// with_server("hello from trillium", |url| async move {
///     let client = Client::new(client_config());
///     let mut conn = client.get(url).await?;
///     let body = conn.response_body(); //<-
///     assert_eq!(Some(19), body.content_length());
///     assert_eq!("hello from trillium", body.read_string().await?);
///     Ok(())
/// });
/// ```
///
/// ## Bounds checking
///
/// Every `ResponseBody` has a maximum length beyond which it will return an error, expressed as a
/// `u64`. To override this on the specific `ResponseBody`, use [`ResponseBody::with_max_len`] or
/// [`ResponseBody::set_max_len`]. The bound is enforced on override bodies as well as
/// transport-backed ones, so a user-set memory cap holds even when middleware has replaced the
/// body with externally-sourced bytes.
pub struct ResponseBody<'a> {
    /// Where the bytes come from and which lifecycle state the body is in: a transport-backed
    /// body, an override body, a pending transport close, or fully closed.
    inner: ResponseBodyInner<'a>,
    /// Set on `'static` Received bodies built via
    /// [`Conn::take_response_body`][crate::Conn::take_response_body]. `recycle` and `Drop`
    /// consult it to decide whether to drain (keepalive) or close the underlying transport.
    /// `None` for borrowed bodies (the conn handles their cleanup) and for Override bodies (no
    /// transport is attached at this layer — `take_response_body` already evicted any leftover
    /// transport before returning).
    cleanup: Option<CleanupContext>,
    /// Trailers harvested off the inner [`ReceivedBody`] when it reaches `End`. The
    /// EOF-driven recycle in `poll_read` moves the `ReceivedBody` out before the caller can
    /// observe its trailers, so they're captured here to outlive it and surfaced through
    /// [`BodySource::trailers`].
    trailers: Option<Headers>,
}
60
/// The source / lifecycle state behind a [`ResponseBody`].
// NOTE(review): the lint is presumably silenced because `Received` is much larger than the
// other variants and boxing it was judged not worthwhile — confirm.
#[allow(clippy::large_enum_variant)]
enum ResponseBodyInner<'a> {
    /// A body read from a transport.
    Received(ReceivedBody<'a, Box<dyn Transport>>),
    /// A body supplied in place of the transport-backed one.
    Override(OverrideBody<'a>),
    /// A transport close is in flight; `poll_read` drives this future to completion before
    /// moving to `Closed`.
    Closing(Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>),
    /// Nothing left to read. Also the placeholder left behind by `take_inner`.
    Closed,
}
68
/// Pool of boxed transports keyed by [`Origin`]; the `h1` prefix matches the
/// `h1_pool_origin` field that stores it.
type H1Pool = Pool<Origin, Box<dyn Transport>>;
70
/// Carries everything `Drop for ResponseBody` and [`ResponseBody::recycle`] need to release
/// a transport without re-deriving from a [`crate::Conn`] (which is gone by then).
///
/// `h1_pool_origin: Some` means "keepalive transport, pool is configured — insert here on
/// completion." `None` means "close on completion" (non-keepalive *or* no pool). The same
/// instance is cloned into the body's `on_completion` callback in
/// [`Conn::take_received_body`][crate::Conn::take_received_body], so the user-driven
/// drain/read paths and the Drop/recycle paths share one source of truth for what to do
/// with the transport when the body finishes.
#[derive(Clone)]
pub(crate) struct CleanupContext {
    /// Runtime used to spawn background work (closing a transport, or draining a body that
    /// was dropped before reaching its end).
    pub(crate) runtime: Runtime,
    /// Pool and origin key to return the transport to, or `None` to close it instead.
    pub(crate) h1_pool_origin: Option<(H1Pool, Origin)>,
}
85
86impl CleanupContext {
87    /// Hand a freshly-released transport off to its destination — pool insert (sync) or
88    /// spawn close.
89    pub(crate) fn handoff(&self, mut transport: Box<dyn Transport>) {
90        match &self.h1_pool_origin {
91            Some((pool, origin)) => {
92                log::trace!("body transferred, returning to pool");
93                pool.insert(origin.clone(), PoolEntry::new(transport, None));
94            }
95            None => {
96                self.runtime.clone().spawn(async move {
97                    let _ = transport.close().await;
98                });
99            }
100        }
101    }
102}
103
/// A body that stands in for the transport-backed one, together with the read settings
/// that [`ResponseBody`] would otherwise take from a [`ReceivedBody`].
pub(crate) struct OverrideBody<'a> {
    /// The replacement body, either borrowed or owned.
    body: MutCow<'a, Body>,
    /// Character encoding used when the body is decoded to a `String`.
    encoding: &'static Encoding,
    /// Maximum number of bytes that may be read before reads fail.
    max_len: u64,
    /// Buffer capacity used by `read_bytes` when the body length is unknown.
    initial_len: usize,
    /// Upper bound on the capacity `read_bytes` reserves up front for a known length.
    max_preallocate: usize,
}
111
112impl AsyncRead for OverrideBody<'_> {
113    fn poll_read(
114        mut self: Pin<&mut Self>,
115        cx: &mut Context<'_>,
116        buf: &mut [u8],
117    ) -> Poll<io::Result<usize>> {
118        let remaining = self.max_len.saturating_sub(self.body.bytes_read());
119        if remaining == 0 && !buf.is_empty() {
120            return Poll::Ready(Err(io::Error::other(Error::ReceivedBodyTooLong(
121                self.max_len,
122            ))));
123        }
124        let cap = remaining.min(buf.len() as u64) as usize;
125        Pin::new(&mut *self.body).poll_read(cx, &mut buf[..cap])
126    }
127}
128
129impl<'a> OverrideBody<'a> {
130    pub(crate) fn new(
131        body: impl Into<MutCow<'a, Body>>,
132        encoding: &'static Encoding,
133        http_config: &HttpConfig,
134    ) -> Self {
135        Self {
136            body: body.into(),
137            encoding,
138            max_len: http_config.received_body_max_len(),
139            max_preallocate: http_config.received_body_max_preallocate(),
140            initial_len: http_config.received_body_initial_len(),
141        }
142    }
143}
144
145impl Debug for ResponseBody<'_> {
146    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
147        match &self.inner {
148            ResponseBodyInner::Received(rb) => f.debug_tuple("ResponseBody").field(rb).finish(),
149            ResponseBodyInner::Override(o) => f
150                .debug_struct("ResponseBody (override)")
151                .field("body", &*o.body)
152                .field("encoding", &o.encoding.name())
153                .field("max_len", &o.max_len)
154                .finish(),
155            ResponseBodyInner::Closing(_) => f.debug_tuple("ResponseBody (closing)").finish(),
156            ResponseBodyInner::Closed => f.debug_tuple("ResponseBody (closed)").finish(),
157        }
158    }
159}
160
161impl AsyncRead for ResponseBody<'_> {
162    fn poll_read(
163        mut self: Pin<&mut Self>,
164        cx: &mut Context<'_>,
165        buf: &mut [u8],
166    ) -> Poll<io::Result<usize>> {
167        let mut bytes = 0;
168        loop {
169            match &mut self.inner {
170                ResponseBodyInner::Received(rb) => bytes = ready!(Pin::new(rb).poll_read(cx, buf))?,
171                ResponseBodyInner::Override(o) => bytes = ready!(Pin::new(o).poll_read(cx, buf))?,
172                ResponseBodyInner::Closing(fut) => {
173                    ready!(fut.as_mut().poll(cx));
174                    self.inner = ResponseBodyInner::Closed;
175                    break;
176                }
177
178                ResponseBodyInner::Closed => break,
179            };
180
181            // Inline transport settlement — see take_received_body's `cleanup` param for
182            // why this isn't done via on_completion.
183            if bytes == 0
184                && let Some((mut rb, cleanup)) = self.prepare_for_recycle()
185                && rb.state() == ReceivedBodyState::End
186                && let Some(mut transport) = rb.take_transport()
187            {
188                self.trailers = Pin::new(&mut rb).trailers();
189                if let Some((pool, origin)) = cleanup.h1_pool_origin {
190                    pool.insert(origin, PoolEntry::new(transport, None));
191                } else {
192                    self.inner = ResponseBodyInner::Closing(Box::pin(async move {
193                        if let Err(e) = transport.close().await {
194                            log::warn!("transport close failed during ResponseBody EOF: {e}");
195                        }
196                    }));
197                }
198            } else {
199                break;
200            }
201        }
202
203        Poll::Ready(Ok(bytes))
204    }
205}
206
207impl ResponseBody<'_> {
208    fn take_inner(&mut self) -> ResponseBodyInner<'_> {
209        mem::replace(&mut self.inner, ResponseBodyInner::Closed)
210    }
211
212    fn max_preallocate(&self) -> usize {
213        match &self.inner {
214            ResponseBodyInner::Received(rb) => rb.max_preallocate(),
215            ResponseBodyInner::Override(override_body) => override_body.max_preallocate,
216            _ => 0,
217        }
218    }
219
220    fn max_len(&self) -> u64 {
221        match &self.inner {
222            ResponseBodyInner::Received(rb) => rb.max_len(),
223            ResponseBodyInner::Override(override_body) => override_body.max_len,
224            _ => 0,
225        }
226    }
227
228    fn initial_len(&self) -> usize {
229        match &self.inner {
230            ResponseBodyInner::Received(rb) => rb.initial_len(),
231            ResponseBodyInner::Override(override_body) => override_body.initial_len,
232            _ => 0,
233        }
234    }
235
236    fn encoding(&self) -> &'static Encoding {
237        match &self.inner {
238            ResponseBodyInner::Received(rb) => rb.encoding(),
239            ResponseBodyInner::Override(override_body) => override_body.encoding,
240            _ => encoding_rs::WINDOWS_1252,
241        }
242    }
243
244    /// Similar to [`ResponseBody::read_string`], but returns the raw bytes. This is useful for
245    /// bodies that are not text.
246    ///
247    /// You can use this in conjunction with `encoding` if you need different handling of malformed
248    /// character encoding than the lossy conversion provided by [`ResponseBody::read_string`].
249    ///
250    /// An empty or nonexistent body will yield an empty Vec, not an error.
251    ///
252    /// # Errors
253    ///
254    /// This will return an error if there is an IO error on the underlying transport such as a
255    /// disconnect.
256    ///
257    /// This will also return an error if the length exceeds the maximum length. To configure the
258    /// value on this specific request body, use [`ResponseBody::with_max_len`] or
259    /// [`ResponseBody::set_max_len`]
260    pub async fn read_bytes(mut self) -> Result<Vec<u8>, Error> {
261        let mut vec = if let Some(len) = self.content_length() {
262            if len > self.max_len() {
263                return Err(Error::ReceivedBodyTooLong(self.max_len()));
264            }
265
266            let len =
267                usize::try_from(len).map_err(|_| Error::ReceivedBodyTooLong(self.max_len()))?;
268
269            Vec::with_capacity(len.min(self.max_preallocate()))
270        } else {
271            Vec::with_capacity(self.initial_len())
272        };
273
274        self.read_to_end(&mut vec).await?;
275
276        Ok(vec)
277    }
278
279    /// Reads the entire body to a `String`.
280    ///
281    /// Uses the encoding determined by the content-type (mime) charset. If an encoding problem
282    /// is encountered, the returned `String` will contain utf8 replacement characters.
283    ///
284    /// Note that this can only be performed once per Conn, as the underlying data is not cached
285    /// anywhere. This is the only copy of the body contents.
286    ///
287    /// An empty or nonexistent body will yield an empty String, not an error
288    ///
289    /// # Errors
290    ///
291    /// This will return an error if there is an IO error on the
292    /// underlying transport such as a disconnect
293    ///
294    ///
295    /// This will also return an error if the length exceeds the maximum length. To configure the
296    /// value on this specific response body, use [`ResponseBody::with_max_len`] or
297    /// [`ResponseBody::set_max_len`].
298    pub async fn read_string(self) -> Result<String, Error> {
299        let encoding = self.encoding();
300        let bytes = self.read_bytes().await?;
301        let (s, _, _) = encoding.decode(&bytes);
302        Ok(s.to_string())
303    }
304
    /// Set the maximum content length to read, returning self
    ///
    /// This protects against a memory-use denial-of-service attack wherein an untrusted peer sends
    /// an unbounded response body. This is especially important when using
    /// [`ResponseBody::read_string`] and [`ResponseBody::read_bytes`] instead of streaming with
    /// `AsyncRead`.
    ///
    /// The default value can be found documented [in the trillium-http
    /// crate](https://docs.trillium.rs/trillium_http/struct.httpconfig#received_body_max_len)
    #[must_use]
    pub fn with_max_len(mut self, max_len: u64) -> Self {
        self.set_max_len(max_len);
        self
    }
319
320    /// Set the maximum content length to read
321    ///
322    /// This protects against a memory-use denial-of-service attack wherein an untrusted peer sends
323    /// an unbounded request body. This is especially important when using
324    /// [`ResponseBody::read_string`] and [`ResponseBody::read_bytes`] instead of streaming with
325    /// `AsyncRead`.
326    ///
327    /// The default value can be found documented [in the trillium-http
328    /// crate](https://docs.trillium.rs/trillium_http/struct.httpconfig#received_body_max_len)
329    pub fn set_max_len(&mut self, max_len: u64) -> &mut Self {
330        match &mut self.inner {
331            ResponseBodyInner::Received(rb) => {
332                rb.set_max_len(max_len);
333            }
334            ResponseBodyInner::Override(o) => {
335                o.max_len = max_len;
336            }
337            _ => {}
338        }
339        self
340    }
341
342    /// The trailers received after the response body, if any.
343    ///
344    /// Returns `None` until the body has been read to end-of-stream, and only on protocols
345    /// that delivered a trailer section (HTTP/1.1 chunked with trailers, HTTP/2, HTTP/3).
346    /// Reading the body via [`read_string`](Self::read_string)/[`read_bytes`](Self::read_bytes)
347    /// consumes it, so to observe trailers drive the body to completion through its
348    /// [`AsyncRead`](futures_lite::AsyncRead) interface and then call this.
349    pub fn trailers(&self) -> Option<&Headers> {
350        match &self.inner {
351            ResponseBodyInner::Received(rb) => rb.trailers_ref(),
352            // Captured off the inner ReceivedBody when it was recycled at end-of-stream.
353            _ => self.trailers.as_ref(),
354        }
355    }
356
357    /// The content-length of this body, if available.
358    ///
359    /// Usually derived from the content-length header. If the response uses
360    /// transfer-encoding chunked, this will be `None`.
361    pub fn content_length(&self) -> Option<u64> {
362        match &self.inner {
363            ResponseBodyInner::Received(rb) => rb.content_length(),
364            ResponseBodyInner::Override(o) => o.body.len(),
365            _ => None,
366        }
367    }
368
369    fn prepare_for_recycle(
370        &mut self,
371    ) -> Option<(
372        ReceivedBody<'static, Box<dyn Transport + 'static>>,
373        CleanupContext,
374    )> {
375        let cleanup = self.cleanup.take()?;
376
377        let ResponseBodyInner::Received(rb) = self.take_inner() else {
378            return None;
379        };
380
381        let rb = rb.try_into_owned()?;
382
383        Some((rb, cleanup))
384    }
385}
386
387async fn drain(rb: &mut ReceivedBody<'static, Box<dyn Transport + 'static>>) -> io::Result<u64> {
388    let copy_loops_per_yield = rb.copy_loops_per_yield();
389    trillium_http::copy(rb, futures_lite::io::sink(), copy_loops_per_yield).await
390}
391
392async fn recycle(
393    mut rb: ReceivedBody<'static, Box<dyn Transport + 'static>>,
394    h1_pool_origin: Option<(H1Pool, Origin)>,
395) {
396    if let Some((pool, origin)) = h1_pool_origin {
397        match drain(&mut rb).await {
398            Ok(drained) => {
399                if rb.state() == ReceivedBodyState::End
400                    && let Some(transport) = rb.take_transport()
401                {
402                    log::trace!(
403                        "drained {drained} bytes, returning transport to pool for {origin:?}"
404                    );
405                    pool.insert(origin, PoolEntry::new(transport, None));
406                    return;
407                }
408            }
409            Err(e) => log::warn!("drain failed during recycle: {e}"),
410        }
411    }
412
413    if let Some(mut transport) = rb.take_transport()
414        && let Err(e) = transport.close().await
415    {
416        log::warn!("transport close failed during recycle: {e}");
417    }
418}
419
420impl Drop for ResponseBody<'_> {
421    fn drop(&mut self) {
422        let Some((mut rb, cleanup)) = self.prepare_for_recycle() else {
423            return;
424        };
425
426        // fast sync path for reclaiming an owned http/1.1 received body that's at End
427        if rb.state() == ReceivedBodyState::End
428            && cleanup.h1_pool_origin.is_some()
429            && let Some(transport) = rb.take_transport()
430            && let Some((pool, origin)) = cleanup.h1_pool_origin
431        {
432            pool.insert(origin, PoolEntry::new(transport, None));
433        } else {
434            cleanup.runtime.spawn(recycle(rb, cleanup.h1_pool_origin));
435        }
436    }
437}
438
439impl BodySource for ResponseBody<'static> {
440    fn trailers(self: Pin<&mut Self>) -> Option<Headers> {
441        let this = self.get_mut();
442        match &mut this.inner {
443            ResponseBodyInner::Received(rb) => Pin::new(rb).trailers(),
444            ResponseBodyInner::Override(o) => o.body.trailers(),
445            // Recycled at EOF — trailers were captured off the ReceivedBody before it was
446            // moved out. See `ResponseBody::trailers`.
447            _ => this.trailers.take(),
448        }
449    }
450}
451
452impl<'a> From<ReceivedBody<'a, Box<dyn Transport>>> for ResponseBody<'a> {
453    fn from(received_body: ReceivedBody<'a, Box<dyn Transport>>) -> Self {
454        Self {
455            inner: ResponseBodyInner::Received(received_body),
456            cleanup: None,
457            trailers: None,
458        }
459    }
460}
461
462impl<'a> From<OverrideBody<'a>> for ResponseBody<'a> {
463    fn from(o: OverrideBody<'a>) -> Self {
464        Self {
465            inner: ResponseBodyInner::Override(o),
466            cleanup: None,
467            trailers: None,
468        }
469    }
470}
471
472impl ResponseBody<'static> {
473    pub(crate) fn received_owned(
474        body: ReceivedBody<'static, Box<dyn Transport>>,
475        cleanup: CleanupContext,
476    ) -> Self {
477        Self {
478            inner: ResponseBodyInner::Received(body),
479            cleanup: Some(cleanup),
480            trailers: None,
481        }
482    }
483
484    /// Drains and pools the underlying transport when worthwhile, closes it otherwise.
485    ///
486    /// Use this to release a keepalive transport synchronously before reissuing a request on
487    /// the same client — the redirect/retry handler pattern. For an h1.1 keepalive transport
488    /// this drives the body to EOF and returns the transport to the pool. For a non-keepalive
489    /// transport this calls `transport.close()` directly without draining (since draining
490    /// would just waste bytes on a connection we're about to close).
491    ///
492    /// For an Override body (cache hit, mocked response, tee), this is a no-op — the body's
493    /// own components handle their own teardown when dropped.
494    pub async fn recycle(mut self) {
495        let Some((rb, cleanup)) = self.prepare_for_recycle() else {
496            return;
497        };
498
499        recycle(rb, cleanup.h1_pool_origin).await;
500    }
501}
502
503impl<'a> IntoFuture for ResponseBody<'a> {
504    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
505    type Output = trillium_http::Result<String>;
506
507    fn into_future(self) -> Self::IntoFuture {
508        Box::pin(async move { self.read_string().await })
509    }
510}
511
512const _: fn() = || {
513    fn assert_send_sync<T: Send + Sync + ?Sized>() {}
514    assert_send_sync::<ResponseBody<'static>>();
515    assert_send_sync::<ResponseBody<'_>>();
516};