1use std::cell::{Ref, RefMut};
2use std::task::{Context, Poll};
3use std::{cell::Cell, fmt, future::Future, marker::PhantomData, pin::Pin, rc::Rc};
4
5use serde::de::DeserializeOwned;
6
7#[cfg(feature = "cookie")]
8use coo_kie::{Cookie, ParseError as CookieParseError};
9
10use crate::http::error::PayloadError;
11use crate::http::header::{AsName, CONTENT_LENGTH, HeaderValue};
12use crate::http::{HeaderMap, HttpMessage, Payload, ResponseHead, StatusCode, Version};
13use crate::time::{Deadline, Millis};
14use crate::util::{Bytes, BytesMut, Extensions, Stream};
15
16use super::{ClientConfig, error::JsonPayloadError};
17
/// An HTTP response as seen by the client: the response head, a lazily
/// consumed body payload, and the client configuration used when reading it.
pub struct ClientResponse {
    /// Version, status, headers and extensions of the response.
    pub(crate) head: ResponseHead,
    /// Body payload; becomes `None` once it has been taken out
    /// (see `take_payload`).
    pub(crate) payload: Cell<Option<Payload>>,
    /// Client configuration; provides the payload size limit and timeout
    /// used by `body()` / `json()`.
    config: Rc<ClientConfig>,
}
24
25impl HttpMessage for ClientResponse {
26 fn message_headers(&self) -> &HeaderMap {
27 &self.head.headers
28 }
29
30 fn message_extensions(&self) -> Ref<'_, Extensions> {
31 self.head.extensions()
32 }
33
34 fn message_extensions_mut(&self) -> RefMut<'_, Extensions> {
35 self.head.extensions_mut()
36 }
37
38 #[cfg(feature = "cookie")]
39 fn cookies(&self) -> Result<Ref<'_, Vec<Cookie<'static>>>, CookieParseError> {
41 use crate::http::header::SET_COOKIE;
42
43 struct Cookies(Vec<Cookie<'static>>);
44
45 if self.message_extensions().get::<Cookies>().is_none() {
46 let mut cookies = Vec::new();
47 for hdr in self.message_headers().get_all(&SET_COOKIE) {
48 let s =
49 std::str::from_utf8(hdr.as_bytes()).map_err(CookieParseError::from)?;
50 cookies.push(Cookie::parse_encoded(s)?.into_owned());
51 }
52 self.message_extensions_mut().insert(Cookies(cookies));
53 }
54 Ok(Ref::map(self.message_extensions(), |ext| {
55 &ext.get::<Cookies>().unwrap().0
56 }))
57 }
58}
59
60impl ClientResponse {
61 #[doc(hidden)]
63 pub fn new(head: ResponseHead, payload: Payload, config: Rc<ClientConfig>) -> Self {
64 ClientResponse {
65 head,
66 config,
67 payload: Cell::new(Some(payload)),
68 }
69 }
70
71 #[cfg(feature = "ws")]
72 pub(crate) fn with_empty_payload(head: ResponseHead, config: Rc<ClientConfig>) -> Self {
73 ClientResponse::new(head, Payload::None, config)
74 }
75
76 #[inline]
77 pub(crate) fn head(&self) -> &ResponseHead {
78 &self.head
79 }
80
81 #[inline]
82 pub(crate) fn head_mut(&mut self) -> &mut ResponseHead {
83 &mut self.head
84 }
85
86 #[inline]
88 pub fn version(&self) -> Version {
89 self.head().version
90 }
91
92 #[inline]
94 pub fn status(&self) -> StatusCode {
95 self.head().status
96 }
97
98 #[inline]
99 pub fn header<N: AsName>(&self, name: N) -> Option<&HeaderValue> {
101 self.head().headers.get(name)
102 }
103
104 #[inline]
105 pub fn headers(&self) -> &HeaderMap {
107 &self.head().headers
108 }
109
110 #[inline]
111 pub fn headers_mut(&mut self) -> &mut HeaderMap {
113 &mut self.head_mut().headers
114 }
115
116 pub fn set_payload(&self, payload: Payload) {
118 self.payload.set(Some(payload));
119 }
120
121 pub fn take_payload(&self) -> Payload {
123 if let Some(pl) = self.payload.take() {
124 pl
125 } else {
126 Payload::None
127 }
128 }
129
130 #[inline]
132 pub fn extensions(&self) -> Ref<'_, Extensions> {
133 self.head().extensions()
134 }
135
136 #[inline]
138 pub fn extensions_mut(&self) -> RefMut<'_, Extensions> {
139 self.head().extensions_mut()
140 }
141}
142
143impl ClientResponse {
144 pub fn body(&self) -> MessageBody {
146 MessageBody::new(self)
147 }
148
149 pub fn json<T: DeserializeOwned>(&self) -> JsonBody<T> {
157 JsonBody::new(self)
158 }
159}
160
161impl Stream for ClientResponse {
162 type Item = Result<Bytes, PayloadError>;
163
164 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
165 if let Some(mut pl) = self.payload.take() {
166 let result = Pin::new(&mut pl).poll_next(cx);
167 self.payload.set(Some(pl));
168 result
169 } else {
170 Poll::Ready(None)
171 }
172 }
173}
174
175impl fmt::Debug for ClientResponse {
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?;
178 writeln!(f, " headers:")?;
179 for (key, val) in self.headers() {
180 writeln!(f, " {key:?}: {val:?}")?;
181 }
182 Ok(())
183 }
184}
185
/// Future that reads a response body into a single `Bytes` value.
#[derive(Debug)]
pub struct MessageBody {
    /// Parsed `Content-Length`, compared against the limit on the first poll.
    length: Option<usize>,
    /// Error found while constructing the future; returned on the first poll.
    err: Option<PayloadError>,
    /// Body reader; `None` only for futures built with an error.
    fut: Option<ReadBody>,
}
193
194impl MessageBody {
195 pub fn new(res: &ClientResponse) -> MessageBody {
197 let mut len = None;
198 if let Some(l) = res.headers().get(&CONTENT_LENGTH) {
199 if let Ok(s) = l.to_str() {
200 if let Ok(l) = s.parse::<usize>() {
201 len = Some(l);
202 } else {
203 return Self::err(PayloadError::UnknownLength);
204 }
205 } else {
206 return Self::err(PayloadError::UnknownLength);
207 }
208 }
209
210 MessageBody {
211 length: len,
212 err: None,
213 fut: Some(ReadBody::new(
214 res.take_payload(),
215 res.config.response_pl_limit,
216 res.config.response_pl_timeout,
217 )),
218 }
219 }
220
221 pub fn limit(mut self, limit: usize) -> Self {
223 if let Some(ref mut fut) = self.fut {
224 fut.limit = limit;
225 }
226 self
227 }
228
229 pub fn timeout(mut self, to: Millis) -> Self {
234 if let Some(ref mut fut) = self.fut {
235 fut.timeout.reset(to);
236 }
237 self
238 }
239
240 fn err(e: PayloadError) -> Self {
241 MessageBody {
242 fut: None,
243 err: Some(e),
244 length: None,
245 }
246 }
247}
248
249impl Future for MessageBody {
250 type Output = Result<Bytes, PayloadError>;
251
252 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
253 let this = self.get_mut();
254
255 if let Some(err) = this.err.take() {
256 return Poll::Ready(Err(err));
257 }
258
259 if let Some(len) = this.length.take() {
260 let limit = this.fut.as_ref().unwrap().limit;
261 if limit > 0 && len > limit {
262 return Poll::Ready(Err(PayloadError::Overflow));
263 }
264 }
265
266 Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx)
267 }
268}
269
/// Future that reads a response body and deserializes it from JSON into `U`.
#[derive(Debug)]
pub struct JsonBody<U> {
    /// Parsed `Content-Length` (when present and valid), compared against the
    /// limit on the first poll.
    length: Option<usize>,
    /// Error found while constructing the future (e.g. a non-JSON content
    /// type); returned on the first poll.
    err: Option<JsonPayloadError>,
    /// Body reader; `None` only when `err` was set at construction.
    fut: Option<ReadBody>,
    /// Marks the target type without storing a value of it.
    _t: PhantomData<U>,
}
283
284impl<U> JsonBody<U>
285where
286 U: DeserializeOwned,
287{
288 pub fn new(res: &ClientResponse) -> Self {
290 let json = if let Ok(Some(mime)) = res.mime_type() {
292 mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
293 } else {
294 false
295 };
296 if !json {
297 return JsonBody {
298 length: None,
299 fut: None,
300 err: Some(JsonPayloadError::ContentType),
301 _t: PhantomData,
302 };
303 }
304
305 let mut len = None;
306 if let Some(l) = res.headers().get(&CONTENT_LENGTH)
307 && let Ok(s) = l.to_str()
308 && let Ok(l) = s.parse::<usize>()
309 {
310 len = Some(l);
311 }
312
313 JsonBody {
314 length: len,
315 err: None,
316 fut: Some(ReadBody::new(
317 res.take_payload(),
318 res.config.response_pl_limit,
319 res.config.response_pl_timeout,
320 )),
321 _t: PhantomData,
322 }
323 }
324
325 pub fn limit(mut self, limit: usize) -> Self {
327 if let Some(ref mut fut) = self.fut {
328 fut.limit = limit;
329 }
330 self
331 }
332
333 pub fn timeout(mut self, to: Millis) -> Self {
338 if let Some(ref mut fut) = self.fut {
339 fut.timeout.reset(to);
340 }
341 self
342 }
343}
344
345impl<U> Unpin for JsonBody<U> where U: DeserializeOwned {}
346
347impl<U> Future for JsonBody<U>
348where
349 U: DeserializeOwned,
350{
351 type Output = Result<U, JsonPayloadError>;
352
353 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
354 if let Some(err) = self.err.take() {
355 return Poll::Ready(Err(err));
356 }
357
358 if let Some(len) = self.length.take() {
359 let limit = self.fut.as_ref().unwrap().limit;
360 if limit > 0 && len > limit {
361 return Poll::Ready(Err(JsonPayloadError::Payload(PayloadError::Overflow)));
362 }
363 }
364
365 let body = match Pin::new(&mut self.get_mut().fut.as_mut().unwrap()).poll(cx) {
366 Poll::Ready(result) => result?,
367 Poll::Pending => return Poll::Pending,
368 };
369 Poll::Ready(serde_json::from_slice::<U>(&body).map_err(JsonPayloadError::from))
370 }
371}
372
/// Accumulates a payload stream into one buffer, enforcing a size limit and
/// a deadline.
#[derive(Debug)]
struct ReadBody {
    /// Source of body chunks.
    stream: Payload,
    /// Bytes received so far.
    buf: BytesMut,
    /// Maximum total size in bytes; `0` disables the check.
    limit: usize,
    /// Deadline consulted whenever the stream has no data ready; once it has
    /// elapsed the read fails with a timed-out `Incomplete` error.
    timeout: Deadline,
}
380
381impl ReadBody {
382 fn new(stream: Payload, limit: usize, timeout: Millis) -> Self {
383 Self {
384 stream,
385 limit,
386 buf: BytesMut::with_capacity(std::cmp::min(limit, 32768)),
387 timeout: Deadline::new(timeout),
388 }
389 }
390}
391
392impl Future for ReadBody {
393 type Output = Result<Bytes, PayloadError>;
394
395 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
396 let this = self.get_mut();
397
398 loop {
399 return match Pin::new(&mut this.stream).poll_next(cx) {
400 Poll::Ready(Some(Ok(chunk))) => {
401 if this.limit > 0 && (this.buf.len() + chunk.len()) > this.limit {
402 Poll::Ready(Err(PayloadError::Overflow))
403 } else {
404 this.buf.extend_from_slice(&chunk);
405 continue;
406 }
407 }
408 Poll::Ready(None) => Poll::Ready(Ok(this.buf.take())),
409 Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err)),
410 Poll::Pending => {
411 if this.timeout.poll_elapsed(cx).is_ready() {
412 Poll::Ready(Err(PayloadError::Incomplete(Some(
413 std::io::Error::new(
414 std::io::ErrorKind::TimedOut,
415 "Operation timed out",
416 ),
417 ))))
418 } else {
419 Poll::Pending
420 }
421 }
422 };
423 }
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    use crate::{client::test::TestResponse, http::header};

    #[crate::rt_test]
    async fn test_body() {
        // A non-numeric Content-Length is rejected before reading.
        let res = TestResponse::with_header(header::CONTENT_LENGTH, "xxxx").finish();
        assert!(matches!(res.body().await, Err(PayloadError::UnknownLength)));

        // An advertised length above the configured limit overflows up front.
        let res = TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish();
        assert!(matches!(res.body().await, Err(PayloadError::Overflow)));

        // A payload within the limit is returned intact.
        let res = TestResponse::default()
            .set_payload(Bytes::from_static(b"test"))
            .finish();
        assert_eq!(res.body().await.ok().unwrap(), Bytes::from_static(b"test"));

        // The streamed size is checked against an explicit limit.
        let res = TestResponse::default()
            .set_payload(Bytes::from_static(b"11111111111111"))
            .finish();
        assert!(matches!(res.body().limit(5).await, Err(PayloadError::Overflow)));
    }

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct MyObject {
        name: String,
    }

    /// Equality for the two error variants these tests produce; any other
    /// combination compares unequal.
    fn json_eq(err: &JsonPayloadError, other: &JsonPayloadError) -> bool {
        matches!(
            (err, other),
            (
                JsonPayloadError::Payload(PayloadError::Overflow),
                JsonPayloadError::Payload(PayloadError::Overflow)
            ) | (JsonPayloadError::ContentType, JsonPayloadError::ContentType)
        )
    }

    #[crate::rt_test]
    async fn test_json_body() {
        // No content type at all.
        let res = TestResponse::default().finish();
        let outcome = JsonBody::<MyObject>::new(&res).await;
        assert!(json_eq(
            &outcome.err().unwrap(),
            &JsonPayloadError::ContentType
        ));

        // A non-JSON content type.
        let res = TestResponse::default()
            .header(
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/text"),
            )
            .finish();
        let outcome = JsonBody::<MyObject>::new(&res).await;
        assert!(json_eq(
            &outcome.err().unwrap(),
            &JsonPayloadError::ContentType
        ));

        // Advertised length above an explicit limit.
        let res = TestResponse::default()
            .header(
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            )
            .header(
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10000"),
            )
            .finish();
        let outcome = JsonBody::<MyObject>::new(&res).limit(100).await;
        assert!(json_eq(
            &outcome.err().unwrap(),
            &JsonPayloadError::Payload(PayloadError::Overflow)
        ));

        // A well-formed JSON body deserializes.
        let res = TestResponse::default()
            .header(
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            )
            .header(
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            )
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .finish();
        let outcome = JsonBody::<MyObject>::new(&res).await;
        let expected = MyObject {
            name: "test".to_owned(),
        };
        assert_eq!(outcome.ok().unwrap(), expected);
    }
}