rocket_community/local/asynchronous/
response.rs

1use std::future::Future;
2use std::io;
3use std::{
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use tokio::io::{AsyncRead, ReadBuf};
9
10use crate::http::CookieJar;
11use crate::{Request, Response};
12
13/// An `async` response from a dispatched [`LocalRequest`](super::LocalRequest).
14///
15/// This `LocalResponse` implements [`tokio::io::AsyncRead`]. As such, if
16/// [`into_string()`](LocalResponse::into_string()) and
17/// [`into_bytes()`](LocalResponse::into_bytes()) do not suffice, the response's
18/// body can be read directly:
19///
20/// ```rust
21/// # #[macro_use] extern crate rocket_community as rocket;
22/// use std::io;
23///
24/// use rocket::local::asynchronous::Client;
25/// use rocket::tokio::io::AsyncReadExt;
26/// use rocket::http::Status;
27///
28/// #[get("/")]
29/// fn hello_world() -> &'static str {
30///     "Hello, world!"
31/// }
32///
33/// #[launch]
34/// fn rocket() -> _ {
35///     rocket::build().mount("/", routes![hello_world])
36///     #    .reconfigure(rocket::Config::debug_default())
37/// }
38///
39/// # async fn read_body_manually() -> io::Result<()> {
40/// // Dispatch a `GET /` request.
41/// let client = Client::tracked(rocket()).await.expect("valid rocket");
42/// let mut response = client.get("/").dispatch().await;
43///
44/// // Check metadata validity.
45/// assert_eq!(response.status(), Status::Ok);
46/// assert_eq!(response.body().preset_size(), Some(13));
47///
48/// // Read 10 bytes of the body. Note: in reality, we'd use `into_string()`.
49/// let mut buffer = [0; 10];
50/// response.read(&mut buffer).await?;
51/// assert_eq!(buffer, "Hello, wor".as_bytes());
52/// # Ok(())
53/// # }
54/// # rocket::async_test(read_body_manually()).expect("read okay");
55/// ```
56///
57/// For more, see [the top-level documentation](../index.html#localresponse).
58pub struct LocalResponse<'c> {
59    // XXX: SAFETY: This (dependent) field must come first due to drop order!
60    response: Response<'c>,
61    cookies: CookieJar<'c>,
62    _request: Box<Request<'c>>,
63}
64
65impl Drop for LocalResponse<'_> {
66    fn drop(&mut self) {}
67}
68
impl<'c> LocalResponse<'c> {
    /// Builds a future that dispatches `req` through `f` and packages the
    /// resulting [`Response`] into a `LocalResponse` that also owns the (boxed)
    /// request.
    ///
    /// `f` is handed a `&'c Request<'c>` that points into the boxed request.
    /// `f` is only invoked once the returned future is polled, because the call
    /// lives inside the `async move` block. After `f`'s future resolves, every
    /// cookie yielded by `response.cookies()` is copied, as an owned value,
    /// into a fresh [`CookieJar`] via `add_original`.
    pub(crate) fn new<F, O>(req: Request<'c>, f: F) -> impl Future<Output = LocalResponse<'c>>
    where
        F: FnOnce(&'c Request<'c>) -> O + Send,
        O: Future<Output = Response<'c>> + Send,
    {
        // `LocalResponse` is a self-referential structure. In particular,
        // `response` and `cookies` can refer to `_request` and its contents. As
        // such, we must
        //   1) Ensure `Request` has a stable address.
        //
        //      This is done by `Box`ing the `Request`, using only the stable
        //      address thereafter.
        //
        //   2) Ensure no refs to `Request` or its contents leak with a lifetime
        //      extending beyond that of `&self`.
        //
        //      We have no methods that return an `&Request`. However, we must
        //      also ensure that `Response` doesn't leak any such references. To
        //      do so, we don't expose the `Response` directly in any way;
        //      otherwise, methods like `.headers()` could, in conjunction with
        //      particular crafted `Responder`s, potentially be used to obtain a
        //      reference to contents of `Request`. All methods, instead, return
        //      references bounded by `self`. This is easily verified by noting
        //      that 1) `LocalResponse` fields are private, and 2) all `impl`s
        //      of `LocalResponse` aside from this method abstract the lifetime
        //      away as `'_`, ensuring it is not used for any output value.
        let boxed_req = Box::new(req);
        // SAFETY: The `Request` lives in the heap allocation owned by
        // `boxed_req`. That box is only ever moved, never dropped or mutated:
        // it goes into the `async` block and then into the returned
        // `LocalResponse`. Moving a `Box` does not move its heap contents, so
        // the pointee stays at this address for as long as the `LocalResponse`
        // exists. The fabricated `'c` lifetime is never exposed past `&self`
        // (see point 2 above). Field declaration order drops `response` and
        // `cookies` before `_request`.
        // NOTE(review): moving a `Box` after deriving a shared reference from
        // it may be reported by Miri under Stacked Borrows (Box uniqueness).
        // Consider holding the request behind a raw pointer / aliasable box
        // instead — TODO confirm with `cargo miri test`.
        let request: &'c Request<'c> = unsafe { &*(&*boxed_req as *const _) };

        async move {
            // NOTE: The cookie jar `secure` state will not reflect the last
            // known value in `request.cookies()`. This is okay: new cookies
            // should never be added to the resulting jar which is the only time
            // the value is used to set cookie defaults.
            let response: Response<'c> = f(request).await;
            let mut cookies = CookieJar::new(None, request.rocket());
            for cookie in response.cookies() {
                cookies.add_original(cookie.into_owned());
            }

            // `boxed_req` is moved (not dropped) into the result, keeping the
            // allocation that `request` points into alive alongside `response`.
            LocalResponse {
                _request: boxed_req,
                cookies,
                response,
            }
        }
    }
}
118
119impl LocalResponse<'_> {
120    pub(crate) fn _response(&self) -> &Response<'_> {
121        &self.response
122    }
123
124    pub(crate) fn _cookies(&self) -> &CookieJar<'_> {
125        &self.cookies
126    }
127
128    pub(crate) async fn _into_string(mut self) -> io::Result<String> {
129        self.response.body_mut().to_string().await
130    }
131
132    pub(crate) async fn _into_bytes(mut self) -> io::Result<Vec<u8>> {
133        self.response.body_mut().to_bytes().await
134    }
135
136    #[cfg(feature = "json")]
137    async fn _into_json<T>(self) -> Option<T>
138    where
139        T: Send + serde::de::DeserializeOwned + 'static,
140    {
141        self.blocking_read(|r| serde_json::from_reader(r))
142            .await?
143            .ok()
144    }
145
146    #[cfg(feature = "msgpack")]
147    async fn _into_msgpack<T>(self) -> Option<T>
148    where
149        T: Send + serde::de::DeserializeOwned + 'static,
150    {
151        self.blocking_read(|r| rmp_serde::from_read(r)).await?.ok()
152    }
153
154    #[cfg(any(feature = "json", feature = "msgpack"))]
155    async fn blocking_read<T, F>(mut self, f: F) -> Option<T>
156    where
157        T: Send + 'static,
158        F: FnOnce(&mut dyn io::Read) -> T + Send + 'static,
159    {
160        use tokio::io::AsyncReadExt;
161        use tokio::sync::mpsc;
162
163        struct ChanReader {
164            last: Option<io::Cursor<Vec<u8>>>,
165            rx: mpsc::Receiver<io::Result<Vec<u8>>>,
166        }
167
168        impl std::io::Read for ChanReader {
169            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
170                loop {
171                    if let Some(ref mut cursor) = self.last {
172                        if cursor.position() < cursor.get_ref().len() as u64 {
173                            return std::io::Read::read(cursor, buf);
174                        }
175                    }
176
177                    if let Some(buf) = self.rx.blocking_recv() {
178                        self.last = Some(io::Cursor::new(buf?));
179                    } else {
180                        return Ok(0);
181                    }
182                }
183            }
184        }
185
186        let (tx, rx) = mpsc::channel(2);
187        let reader = tokio::task::spawn_blocking(move || {
188            let mut reader = ChanReader { last: None, rx };
189            f(&mut reader)
190        });
191
192        loop {
193            // TODO: Try to fill as much as the buffer before send it off?
194            let mut buf = Vec::with_capacity(1024);
195            match self.read_buf(&mut buf).await {
196                Ok(0) => break,
197                Ok(_) => tx.send(Ok(buf)).await.ok()?,
198                Err(e) => {
199                    tx.send(Err(e)).await.ok()?;
200                    break;
201                }
202            }
203        }
204
205        // NOTE: We _must_ drop tx now to prevent a deadlock!
206        drop(tx);
207
208        reader.await.ok()
209    }
210
211    // Generates the public API methods, which call the private methods above.
212    pub_response_impl!("# use rocket::local::asynchronous::Client;\n\
213        use rocket::local::asynchronous::LocalResponse;" async await);
214}
215
216impl AsyncRead for LocalResponse<'_> {
217    fn poll_read(
218        mut self: Pin<&mut Self>,
219        cx: &mut Context<'_>,
220        buf: &mut ReadBuf<'_>,
221    ) -> Poll<io::Result<()>> {
222        Pin::new(self.response.body_mut()).poll_read(cx, buf)
223    }
224}
225
226impl std::fmt::Debug for LocalResponse<'_> {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        self._response().fmt(f)
229    }
230}