Skip to main content

ntex/client/
response.rs

1use std::cell::{Ref, RefMut};
2use std::task::{Context, Poll};
3use std::{cell::Cell, fmt, future::Future, marker::PhantomData, pin::Pin, rc::Rc};
4
5use serde::de::DeserializeOwned;
6
7#[cfg(feature = "cookie")]
8use coo_kie::{Cookie, ParseError as CookieParseError};
9
10use crate::http::error::PayloadError;
11use crate::http::header::{AsName, CONTENT_LENGTH, HeaderValue};
12use crate::http::{HeaderMap, HttpMessage, Payload, ResponseHead, StatusCode, Version};
13use crate::time::{Deadline, Millis};
14use crate::util::{Bytes, BytesMut, Extensions, Stream};
15
16use super::{ClientConfig, error::JsonPayloadError};
17
18/// Client Response
19pub struct ClientResponse {
20    pub(crate) head: ResponseHead,
21    pub(crate) payload: Cell<Option<Payload>>,
22    config: Rc<ClientConfig>,
23}
24
25impl HttpMessage for ClientResponse {
26    fn message_headers(&self) -> &HeaderMap {
27        &self.head.headers
28    }
29
30    fn message_extensions(&self) -> Ref<'_, Extensions> {
31        self.head.extensions()
32    }
33
34    fn message_extensions_mut(&self) -> RefMut<'_, Extensions> {
35        self.head.extensions_mut()
36    }
37
38    #[cfg(feature = "cookie")]
39    /// Load request cookies.
40    fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> {
41        use crate::http::header::SET_COOKIE;
42
43        struct Cookies(Vec<Cookie<'static>>);
44
45        if self.message_extensions().get::<Cookies>().is_none() {
46            let mut cookies = Vec::new();
47            for hdr in self.message_headers().get_all(&SET_COOKIE) {
48                let s =
49                    std::str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?;
50                cookies.push(Cookie::parse_encoded(s)?.into_owned());
51            }
52            self.message_extensions_mut().insert(Cookies(cookies));
53        }
54        Ok(Ref::map(self.message_extensions(), |ext| {
55            &ext.get::<Cookies>().unwrap().0
56        }))
57    }
58}
59
60impl ClientResponse {
61    /// Create new client response instance
62    #[doc(hidden)]
63    pub fn new(head: ResponseHead, payload: Payload, config: Rc<ClientConfig>) -> Self {
64        ClientResponse {
65            head,
66            config,
67            payload: Cell::new(Some(payload)),
68        }
69    }
70
71    #[cfg(feature = "ws")]
72    pub(crate) fn with_empty_payload(head: ResponseHead, config: Rc<ClientConfig>) -> Self {
73        ClientResponse::new(head, Payload::None, config)
74    }
75
76    #[inline]
77    pub(crate) fn head(&self) -> &ResponseHead {
78        &self.head
79    }
80
81    #[inline]
82    pub(crate) fn head_mut(&mut self) -> &mut ResponseHead {
83        &mut self.head
84    }
85
86    /// Read the Request Version.
87    #[inline]
88    pub fn version(&self) -> Version {
89        self.head().version
90    }
91
92    /// Get the status from the server.
93    #[inline]
94    pub fn status(&self) -> StatusCode {
95        self.head().status
96    }
97
98    #[inline]
99    /// Returns a reference to the header value.
100    pub fn header<N: AsName>(&self, name: N) -> Option<&HeaderValue> {
101        self.head().headers.get(name)
102    }
103
104    #[inline]
105    /// Returns response's headers.
106    pub fn headers(&self) -> &HeaderMap {
107        &self.head().headers
108    }
109
110    #[inline]
111    /// Returns mutable response's headers.
112    pub fn headers_mut(&mut self) -> &mut HeaderMap {
113        &mut self.head_mut().headers
114    }
115
116    /// Set a body and return previous body value
117    pub fn set_payload(&self, payload: Payload) {
118        self.payload.set(Some(payload));
119    }
120
121    /// Get response's payload
122    pub fn take_payload(&self) -> Payload {
123        if let Some(pl) = self.payload.take() {
124            pl
125        } else {
126            Payload::None
127        }
128    }
129
130    /// Request extensions
131    #[inline]
132    pub fn extensions(&self) -> Ref<'_, Extensions> {
133        self.head().extensions()
134    }
135
136    /// Mutable reference to a the request's extensions
137    #[inline]
138    pub fn extensions_mut(&self) -> RefMut<'_, Extensions> {
139        self.head().extensions_mut()
140    }
141}
142
143impl ClientResponse {
144    /// Loads http response's body.
145    pub fn body(&self) -> MessageBody {
146        MessageBody::new(self)
147    }
148
149    /// Loads and parse `application/json` encoded body.
150    /// Return `JsonBody<T>` future. It resolves to a `T` value.
151    ///
152    /// Returns error:
153    ///
154    /// * content type is not `application/json`
155    /// * content length is greater than 256k
156    pub fn json<T: DeserializeOwned>(&self) -> JsonBody<T> {
157        JsonBody::new(self)
158    }
159}
160
161impl Stream for ClientResponse {
162    type Item = Result<Bytes, PayloadError>;
163
164    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
165        if let Some(mut pl) = self.payload.take() {
166            let result = Pin::new(&mut pl).poll_next(cx);
167            self.payload.set(Some(pl));
168            result
169        } else {
170            Poll::Ready(None)
171        }
172    }
173}
174
175impl fmt::Debug for ClientResponse {
176    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177        writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?;
178        writeln!(f, "  headers:")?;
179        for (key, val) in self.headers() {
180            writeln!(f, "    {key:?}: {val:?}")?;
181        }
182        Ok(())
183    }
184}
185
186#[derive(Debug)]
187/// Future that resolves to a complete http message body.
188pub struct MessageBody {
189    length: Option<usize>,
190    err: Option<PayloadError>,
191    fut: Option<ReadBody>,
192}
193
194impl MessageBody {
195    /// Create `MessageBody` for request.
196    pub fn new(res: &ClientResponse) -> MessageBody {
197        let mut len = None;
198        if let Some(l) = res.headers().get(&CONTENT_LENGTH) {
199            if let Ok(s) = l.to_str() {
200                if let Ok(l) = s.parse::<usize>() {
201                    len = Some(l);
202                } else {
203                    return Self::err(PayloadError::UnknownLength);
204                }
205            } else {
206                return Self::err(PayloadError::UnknownLength);
207            }
208        }
209
210        MessageBody {
211            length: len,
212            err: None,
213            fut: Some(ReadBody::new(
214                res.take_payload(),
215                res.config.response_pl_limit,
216                res.config.response_pl_timeout,
217            )),
218        }
219    }
220
221    /// Change max size of payload. By default max size is 256Kb
222    pub fn limit(mut self, limit: usize) -> Self {
223        if let Some(ref mut fut) = self.fut {
224            fut.limit = limit;
225        }
226        self
227    }
228
229    /// Set operation timeout.
230    ///
231    /// By default timeout is set to 10 seconds. Set 0 millis to disable
232    /// timeout.
233    pub fn timeout(mut self, to: Millis) -> Self {
234        if let Some(ref mut fut) = self.fut {
235            fut.timeout.reset(to);
236        }
237        self
238    }
239
240    fn err(e: PayloadError) -> Self {
241        MessageBody {
242            fut: None,
243            err: Some(e),
244            length: None,
245        }
246    }
247}
248
249impl Future for MessageBody {
250    type Output = Result<Bytes, PayloadError>;
251
252    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
253        let this = self.get_mut();
254
255        if let Some(err) = this.err.take() {
256            return Poll::Ready(Err(err));
257        }
258
259        if let Some(len) = this.length.take() {
260            let limit = this.fut.as_ref().unwrap().limit;
261            if limit > 0 && len > limit {
262                return Poll::Ready(Err(PayloadError::Overflow));
263            }
264        }
265
266        Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx)
267    }
268}
269
270#[derive(Debug)]
271/// Response's payload json parser, it resolves to a deserialized `T` value.
272///
273/// Returns error:
274///
275/// * content type is not `application/json`
276/// * content length is greater than 64k
277pub struct JsonBody<U> {
278    length: Option<usize>,
279    err: Option<JsonPayloadError>,
280    fut: Option<ReadBody>,
281    _t: PhantomData<U>,
282}
283
284impl<U> JsonBody<U>
285where
286    U: DeserializeOwned,
287{
288    /// Create `JsonBody` for request.
289    pub fn new(res: &ClientResponse) -> Self {
290        // check content-type
291        let json = if let Ok(Some(mime)) = res.mime_type() {
292            mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
293        } else {
294            false
295        };
296        if !json {
297            return JsonBody {
298                length: None,
299                fut: None,
300                err: Some(JsonPayloadError::ContentType),
301                _t: PhantomData,
302            };
303        }
304
305        let mut len = None;
306        if let Some(l) = res.headers().get(&CONTENT_LENGTH)
307            && let Ok(s) = l.to_str()
308            && let Ok(l) = s.parse::<usize>()
309        {
310            len = Some(l);
311        }
312
313        JsonBody {
314            length: len,
315            err: None,
316            fut: Some(ReadBody::new(
317                res.take_payload(),
318                res.config.response_pl_limit,
319                res.config.response_pl_timeout,
320            )),
321            _t: PhantomData,
322        }
323    }
324
325    /// Change max size of payload. By default max size is 64Kb
326    pub fn limit(mut self, limit: usize) -> Self {
327        if let Some(ref mut fut) = self.fut {
328            fut.limit = limit;
329        }
330        self
331    }
332
333    /// Set operation timeout.
334    ///
335    /// By default timeout is set to 10 seconds. Set 0 millis to disable
336    /// timeout.
337    pub fn timeout(mut self, to: Millis) -> Self {
338        if let Some(ref mut fut) = self.fut {
339            fut.timeout.reset(to);
340        }
341        self
342    }
343}
344
345impl<U> Unpin for JsonBody<U> where U: DeserializeOwned {}
346
347impl<U> Future for JsonBody<U>
348where
349    U: DeserializeOwned,
350{
351    type Output = Result<U, JsonPayloadError>;
352
353    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
354        if let Some(err) = self.err.take() {
355            return Poll::Ready(Err(err));
356        }
357
358        if let Some(len) = self.length.take() {
359            let limit = self.fut.as_ref().unwrap().limit;
360            if limit > 0 && len > limit {
361                return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow)));
362            }
363        }
364
365        let body = match Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx) {
366            Poll::Ready(result) => result?,
367            Poll::Pending => return Poll::Pending,
368        };
369        Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from))
370    }
371}
372
373#[derive(Debug)]
374struct ReadBody {
375    stream: Payload,
376    buf: BytesMut,
377    limit: usize,
378    timeout: Deadline,
379}
380
381impl ReadBody {
382    fn new(stream: Payload, limit: usize, timeout: Millis) -> Self {
383        Self {
384            stream,
385            limit,
386            buf: BytesMut::with_capacity(std::cmp::min(limit, 32768)),
387            timeout: Deadline::new(timeout),
388        }
389    }
390}
391
392impl Future for ReadBody {
393    type Output = Result<Bytes, PayloadError>;
394
395    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
396        let this = self.get_mut();
397
398        loop {
399            return match Pin::new(&mut this.stream).poll_next(cx) {
400                Poll::Ready(Some(Ok(chunk))) => {
401                    if this.limit > 0 && (this.buf.len() + chunk.len()) > this.limit {
402                        Poll::Ready(Err(PayloadError::Overflow))
403                    } else {
404                        this.buf.extend_from_slice(&chunk);
405                        continue;
406                    }
407                }
408                Poll::Ready(None) => Poll::Ready(Ok(this.buf.take())),
409                Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err)),
410                Poll::Pending => {
411                    if this.timeout.poll_elapsed(cx).is_ready() {
412                        Poll::Ready(Err(PayloadError::Incomplete(Some(
413                            std::io::Error::new(
414                                std::io::ErrorKind::TimedOut,
415                                "Operation timed out",
416                            ),
417                        ))))
418                    } else {
419                        Poll::Pending
420                    }
421                }
422            };
423        }
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    use crate::{client::test::TestResponse, http::header};

    #[crate::rt_test]
    async fn test_body() {
        // Content-Length that is not a number.
        let res = TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish();
        let err = res.body().await.err().unwrap();
        assert!(matches!(err, PayloadError::UnknownLength), "error");

        // Declared length larger than the default limit.
        let res = TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish();
        let err = res.body().await.err().unwrap();
        assert!(matches!(err, PayloadError::Overflow), "error");

        // Small body is read completely.
        let res = TestResponse::default()
            .set_payload(Bytes::from_static(b"test"))
            .finish();
        let body = res.body().await.ok().unwrap();
        assert_eq!(body, Bytes::from_static(b"test"));

        // Streamed body larger than an explicit limit.
        let res = TestResponse::default()
            .set_payload(Bytes::from_static(b"11111111111111"))
            .finish();
        let err = res.body().limit(5).await.err().unwrap();
        assert!(matches!(err, PayloadError::Overflow), "error");
    }

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct MyObject {
        name: String,
    }

    /// Compares only the error variants these tests care about; anything else is unequal.
    fn json_eq(err: &JsonPayloadError, other: &JsonPayloadError) -> bool {
        matches!(
            (err, other),
            (
                JsonPayloadError::Payload(PayloadError::Overflow),
                JsonPayloadError::Payload(PayloadError::Overflow)
            ) | (JsonPayloadError::ContentType, JsonPayloadError::ContentType)
        )
    }

    #[crate::rt_test]
    async fn test_json_body() {
        // No content type at all.
        let res = TestResponse::default().finish();
        let err = JsonBody::<MyObject>::new(&res).await.err().unwrap();
        assert!(json_eq(&err, &JsonPayloadError::ContentType));

        // Non-JSON content type.
        let res = TestResponse::default()
            .header(
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/text"),
            )
            .finish();
        let err = JsonBody::<MyObject>::new(&res).await.err().unwrap();
        assert!(json_eq(&err, &JsonPayloadError::ContentType));

        // Declared length above an explicit limit.
        let res = TestResponse::default()
            .header(
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            )
            .header(
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10000"),
            )
            .finish();
        let err = JsonBody::<MyObject>::new(&res).limit(100).await.err().unwrap();
        assert!(json_eq(
            &err,
            &JsonPayloadError::Payload(PayloadError::Overflow)
        ));

        // Valid JSON body is deserialized.
        let res = TestResponse::default()
            .header(
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            )
            .header(
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            )
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .finish();
        let obj = JsonBody::<MyObject>::new(&res).await.ok().unwrap();
        assert_eq!(
            obj,
            MyObject {
                name: "test".to_owned()
            }
        );
    }
}