awc/responses/response_body.rs

1use std::{
2    future::Future,
3    mem,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use actix_http::{error::PayloadError, header, HttpMessage};
9use bytes::Bytes;
10use futures_core::Stream;
11use pin_project_lite::pin_project;
12
13use super::{read_body::ReadBody, ResponseTimeout, DEFAULT_BODY_LIMIT};
14use crate::ClientResponse;
15
16pin_project! {
17    /// A `Future` that reads a body stream, resolving as [`Bytes`].
18    ///
19    /// # Errors
20    /// `Future` implementation returns error if:
21    /// - content type is not `application/json`;
22    /// - content length is greater than [limit](JsonBody::limit) (default: 2 MiB).
23    pub struct ResponseBody<S> {
24        #[pin]
25        body: Option<ReadBody<S>>,
26        length: Option<usize>,
27        timeout: ResponseTimeout,
28        err: Option<PayloadError>,
29    }
30}
31
32#[deprecated(since = "3.0.0", note = "Renamed to `ResponseBody`.")]
33pub type MessageBody<B> = ResponseBody<B>;
34
35impl<S> ResponseBody<S>
36where
37    S: Stream<Item = Result<Bytes, PayloadError>>,
38{
39    /// Creates a body stream reader from a response by taking its payload.
40    pub fn new(res: &mut ClientResponse<S>) -> ResponseBody<S> {
41        let length = match res.headers().get(&header::CONTENT_LENGTH) {
42            Some(value) => {
43                let len = value.to_str().ok().and_then(|s| s.parse::<usize>().ok());
44
45                match len {
46                    None => return Self::err(PayloadError::UnknownLength),
47                    len => len,
48                }
49            }
50            None => None,
51        };
52
53        ResponseBody {
54            body: Some(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)),
55            length,
56            timeout: mem::take(&mut res.timeout),
57            err: None,
58        }
59    }
60
61    /// Change max size limit of payload.
62    ///
63    /// The default limit is 2 MiB.
64    pub fn limit(mut self, limit: usize) -> Self {
65        if let Some(ref mut body) = self.body {
66            body.limit = limit;
67        }
68
69        self
70    }
71
72    fn err(err: PayloadError) -> Self {
73        ResponseBody {
74            body: None,
75            length: None,
76            timeout: ResponseTimeout::default(),
77            err: Some(err),
78        }
79    }
80}
81
82impl<S> Future for ResponseBody<S>
83where
84    S: Stream<Item = Result<Bytes, PayloadError>>,
85{
86    type Output = Result<Bytes, PayloadError>;
87
88    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
89        let this = self.project();
90
91        if let Some(err) = this.err.take() {
92            return Poll::Ready(Err(err));
93        }
94
95        if let Some(len) = this.length.take() {
96            let body = Option::as_ref(&this.body).unwrap();
97            if len > body.limit {
98                return Poll::Ready(Err(PayloadError::Overflow));
99            }
100        }
101
102        this.timeout.poll_timeout(cx)?;
103
104        this.body.as_pin_mut().unwrap().poll(cx)
105    }
106}
107
#[cfg(test)]
mod tests {
    use static_assertions::assert_impl_all;

    use super::*;
    use crate::{http::header, test::TestResponse};

    assert_impl_all!(ResponseBody<()>: Unpin);

    #[actix_rt::test]
    async fn read_body() {
        // An unparsable Content-Length is rejected before any data is read.
        let mut res = TestResponse::with_header((header::CONTENT_LENGTH, "xxxx")).finish();
        assert!(matches!(
            res.body().await.unwrap_err(),
            PayloadError::UnknownLength
        ));

        // A declared Content-Length above the default limit is rejected up front.
        let mut res = TestResponse::with_header((header::CONTENT_LENGTH, "10000000")).finish();
        assert!(matches!(res.body().await.unwrap_err(), PayloadError::Overflow));

        let mut res = TestResponse::default()
            .set_payload(Bytes::from_static(b"test"))
            .finish();
        assert_eq!(res.body().await.unwrap(), Bytes::from_static(b"test"));

        // A payload larger than a custom limit is rejected while it is being read.
        let mut res = TestResponse::default()
            .set_payload(Bytes::from_static(b"11111111111111"))
            .finish();
        assert!(matches!(
            res.body().limit(5).await.unwrap_err(),
            PayloadError::Overflow
        ));
    }

    #[actix_rt::test]
    async fn read_body_custom_limit_applies_to_declared_length() {
        let mut res = TestResponse::with_header((header::CONTENT_LENGTH, "10")).finish();
        assert!(matches!(
            res.body().limit(5).await.unwrap_err(),
            PayloadError::Overflow
        ));
    }

    #[actix_rt::test]
    async fn read_body_limit_does_not_mask_unknown_length() {
        let mut res = TestResponse::with_header((header::CONTENT_LENGTH, "xxxx")).finish();
        assert!(matches!(
            res.body().limit(5).await.unwrap_err(),
            PayloadError::UnknownLength
        ));
    }
}