1use std::{convert::TryFrom, fmt, net, rc::Rc, time::Duration};
2
3use bytes::Bytes;
4use futures_core::Stream;
5use serde::Serialize;
6
7use actix_http::{
8 body::MessageBody,
9 error::HttpError,
10 header::{self, HeaderMap, HeaderValue, TryIntoHeaderPair},
11 ConnectionType, Method, RequestHead, Uri, Version,
12};
13
14use crate::{
15 client::ClientConfig,
16 error::{FreezeRequestError, InvalidUrl},
17 frozen::FrozenClientRequest,
18 sender::{PrepForSendingError, RequestSender, SendClientRequest},
19 BoxError,
20};
21
22#[cfg(feature = "cookies")]
23use crate::cookie::{Cookie, CookieJar};
24
25pub struct ClientRequest {
46 pub(crate) head: RequestHead,
47 err: Option<HttpError>,
48 addr: Option<net::SocketAddr>,
49 response_decompress: bool,
50 timeout: Option<Duration>,
51 config: ClientConfig,
52
53 #[cfg(feature = "cookies")]
54 cookies: Option<CookieJar>,
55}
56
57impl ClientRequest {
58 pub(crate) fn new<U>(method: Method, uri: U, config: ClientConfig) -> Self
60 where
61 Uri: TryFrom<U>,
62 <Uri as TryFrom<U>>::Error: Into<HttpError>,
63 {
64 ClientRequest {
65 config,
66 head: RequestHead::default(),
67 err: None,
68 addr: None,
69 #[cfg(feature = "cookies")]
70 cookies: None,
71 timeout: None,
72 response_decompress: true,
73 }
74 .method(method)
75 .uri(uri)
76 }
77
78 #[inline]
80 pub fn uri<U>(mut self, uri: U) -> Self
81 where
82 Uri: TryFrom<U>,
83 <Uri as TryFrom<U>>::Error: Into<HttpError>,
84 {
85 match Uri::try_from(uri) {
86 Ok(uri) => self.head.uri = uri,
87 Err(e) => self.err = Some(e.into()),
88 }
89 self
90 }
91
92 pub fn get_uri(&self) -> &Uri {
94 &self.head.uri
95 }
96
97 pub fn address(mut self, addr: net::SocketAddr) -> Self {
102 self.addr = Some(addr);
103 self
104 }
105
106 #[inline]
108 pub fn method(mut self, method: Method) -> Self {
109 self.head.method = method;
110 self
111 }
112
113 pub fn get_method(&self) -> &Method {
115 &self.head.method
116 }
117
118 #[doc(hidden)]
122 #[inline]
123 pub fn version(mut self, version: Version) -> Self {
124 self.head.version = version;
125 self
126 }
127
128 pub fn get_version(&self) -> &Version {
130 &self.head.version
131 }
132
133 pub fn get_peer_addr(&self) -> &Option<net::SocketAddr> {
135 &self.head.peer_addr
136 }
137
138 #[inline]
140 pub fn headers(&self) -> &HeaderMap {
141 &self.head.headers
142 }
143
144 #[inline]
146 pub fn headers_mut(&mut self) -> &mut HeaderMap {
147 &mut self.head.headers
148 }
149
150 pub fn insert_header(mut self, header: impl TryIntoHeaderPair) -> Self {
152 match header.try_into_pair() {
153 Ok((key, value)) => {
154 self.head.headers.insert(key, value);
155 }
156 Err(e) => self.err = Some(e.into()),
157 };
158
159 self
160 }
161
162 pub fn insert_header_if_none(mut self, header: impl TryIntoHeaderPair) -> Self {
164 match header.try_into_pair() {
165 Ok((key, value)) => {
166 if !self.head.headers.contains_key(&key) {
167 self.head.headers.insert(key, value);
168 }
169 }
170 Err(e) => self.err = Some(e.into()),
171 };
172
173 self
174 }
175
176 pub fn append_header(mut self, header: impl TryIntoHeaderPair) -> Self {
187 match header.try_into_pair() {
188 Ok((key, value)) => self.head.headers.append(key, value),
189 Err(e) => self.err = Some(e.into()),
190 };
191
192 self
193 }
194
195 #[inline]
197 pub fn camel_case(mut self) -> Self {
198 self.head.set_camel_case_headers(true);
199 self
200 }
201
202 #[inline]
205 pub fn force_close(mut self) -> Self {
206 self.head.set_connection_type(ConnectionType::Close);
207 self
208 }
209
210 #[inline]
212 pub fn content_type<V>(mut self, value: V) -> Self
213 where
214 HeaderValue: TryFrom<V>,
215 <HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
216 {
217 match HeaderValue::try_from(value) {
218 Ok(value) => {
219 self.head.headers.insert(header::CONTENT_TYPE, value);
220 }
221 Err(e) => self.err = Some(e.into()),
222 }
223 self
224 }
225
226 #[inline]
228 pub fn content_length(self, len: u64) -> Self {
229 let mut buf = itoa::Buffer::new();
230 self.insert_header((header::CONTENT_LENGTH, buf.format(len)))
231 }
232
233 pub fn basic_auth(self, username: impl fmt::Display, password: impl fmt::Display) -> Self {
237 let auth = format!("{}:{}", username, password);
238
239 self.insert_header((
240 header::AUTHORIZATION,
241 format!("Basic {}", base64::encode(&auth)),
242 ))
243 }
244
245 pub fn bearer_auth(self, token: impl fmt::Display) -> Self {
247 self.insert_header((header::AUTHORIZATION, format!("Bearer {}", token)))
248 }
249
250 #[cfg(feature = "cookies")]
266 pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
267 if self.cookies.is_none() {
268 let mut jar = CookieJar::new();
269 jar.add(cookie.into_owned());
270 self.cookies = Some(jar)
271 } else {
272 self.cookies.as_mut().unwrap().add(cookie.into_owned());
273 }
274 self
275 }
276
277 pub fn no_decompress(mut self) -> Self {
279 self.response_decompress = false;
280 self
281 }
282
283 pub fn timeout(mut self, timeout: Duration) -> Self {
288 self.timeout = Some(timeout);
289 self
290 }
291
292 pub fn query<T: Serialize>(
294 mut self,
295 query: &T,
296 ) -> Result<Self, serde_urlencoded::ser::Error> {
297 let mut parts = self.head.uri.clone().into_parts();
298
299 if let Some(path_and_query) = parts.path_and_query {
300 let query = serde_urlencoded::to_string(query)?;
301 let path = path_and_query.path();
302 parts.path_and_query = format!("{}?{}", path, query).parse().ok();
303
304 match Uri::from_parts(parts) {
305 Ok(uri) => self.head.uri = uri,
306 Err(e) => self.err = Some(e.into()),
307 }
308 }
309
310 Ok(self)
311 }
312
313 pub fn freeze(self) -> Result<FrozenClientRequest, FreezeRequestError> {
316 let slf = match self.prep_for_sending() {
317 Ok(slf) => slf,
318 Err(e) => return Err(e.into()),
319 };
320
321 let request = FrozenClientRequest {
322 head: Rc::new(slf.head),
323 addr: slf.addr,
324 response_decompress: slf.response_decompress,
325 timeout: slf.timeout,
326 config: slf.config,
327 };
328
329 Ok(request)
330 }
331
332 pub fn send_body<B>(self, body: B) -> SendClientRequest
334 where
335 B: MessageBody + 'static,
336 {
337 let slf = match self.prep_for_sending() {
338 Ok(slf) => slf,
339 Err(e) => return e.into(),
340 };
341
342 RequestSender::Owned(slf.head).send_body(
343 slf.addr,
344 slf.response_decompress,
345 slf.timeout,
346 &slf.config,
347 body,
348 )
349 }
350
351 pub fn send_json<T: Serialize>(self, value: &T) -> SendClientRequest {
353 let slf = match self.prep_for_sending() {
354 Ok(slf) => slf,
355 Err(e) => return e.into(),
356 };
357
358 RequestSender::Owned(slf.head).send_json(
359 slf.addr,
360 slf.response_decompress,
361 slf.timeout,
362 &slf.config,
363 value,
364 )
365 }
366
367 pub fn send_form<T: Serialize>(self, value: &T) -> SendClientRequest {
371 let slf = match self.prep_for_sending() {
372 Ok(slf) => slf,
373 Err(e) => return e.into(),
374 };
375
376 RequestSender::Owned(slf.head).send_form(
377 slf.addr,
378 slf.response_decompress,
379 slf.timeout,
380 &slf.config,
381 value,
382 )
383 }
384
385 pub fn send_stream<S, E>(self, stream: S) -> SendClientRequest
387 where
388 S: Stream<Item = Result<Bytes, E>> + 'static,
389 E: Into<BoxError> + 'static,
390 {
391 let slf = match self.prep_for_sending() {
392 Ok(slf) => slf,
393 Err(e) => return e.into(),
394 };
395
396 RequestSender::Owned(slf.head).send_stream(
397 slf.addr,
398 slf.response_decompress,
399 slf.timeout,
400 &slf.config,
401 stream,
402 )
403 }
404
405 pub fn send(self) -> SendClientRequest {
407 let slf = match self.prep_for_sending() {
408 Ok(slf) => slf,
409 Err(e) => return e.into(),
410 };
411
412 RequestSender::Owned(slf.head).send(
413 slf.addr,
414 slf.response_decompress,
415 slf.timeout,
416 &slf.config,
417 )
418 }
419
420 fn prep_for_sending(#[allow(unused_mut)] mut self) -> Result<Self, PrepForSendingError> {
422 if let Some(e) = self.err {
423 return Err(e.into());
424 }
425
426 let uri = &self.head.uri;
428 if uri.host().is_none() {
429 return Err(InvalidUrl::MissingHost.into());
430 } else if uri.scheme().is_none() {
431 return Err(InvalidUrl::MissingScheme.into());
432 } else if let Some(scheme) = uri.scheme() {
433 match scheme.as_str() {
434 "http" | "ws" | "https" | "wss" => {}
435 _ => return Err(InvalidUrl::UnknownScheme.into()),
436 }
437 } else {
438 return Err(InvalidUrl::UnknownScheme.into());
439 }
440
441 #[cfg(feature = "cookies")]
443 if let Some(ref mut jar) = self.cookies {
444 let cookie: String = jar
445 .delta()
446 .map(|c| c.stripped().encoded().to_string())
448 .collect::<Vec<_>>()
449 .join("; ");
450
451 if !cookie.is_empty() {
452 self.head
453 .headers
454 .insert(header::COOKIE, HeaderValue::from_str(&cookie).unwrap());
455 }
456 }
457
458 let mut slf = self;
459
460 if slf.response_decompress {
464 #[allow(clippy::vec_init_then_push)]
466 #[cfg(feature = "__compress")]
467 let accept_encoding = {
468 let mut encoding = vec![];
469
470 #[cfg(feature = "compress-brotli")]
471 {
472 encoding.push("br");
473 }
474
475 #[cfg(feature = "compress-gzip")]
476 {
477 encoding.push("gzip");
478 encoding.push("deflate");
479 }
480
481 #[cfg(feature = "compress-zstd")]
482 encoding.push("zstd");
483
484 assert!(
485 !encoding.is_empty(),
486 "encoding can not be empty unless __compress feature has been explicitly enabled"
487 );
488
489 encoding.join(", ")
490 };
491
492 #[cfg(not(feature = "__compress"))]
495 let accept_encoding = "identity";
496
497 slf = slf.insert_header_if_none((header::ACCEPT_ENCODING, accept_encoding));
498 }
499
500 Ok(slf)
501 }
502}
503
504impl fmt::Debug for ClientRequest {
505 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
506 writeln!(
507 f,
508 "\nClientRequest {:?} {} {}",
509 self.head.version, self.head.method, self.head.uri
510 )?;
511 writeln!(f, " headers:")?;
512 for (key, val) in self.head.headers.iter() {
513 writeln!(f, " {:?}: {:?}", key, val)?;
514 }
515 Ok(())
516 }
517}
518
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use actix_http::header::HttpDate;

    use super::*;
    use crate::Client;

    /// Returns the value of header `name` on `req` as a string slice.
    ///
    /// Panics if the header is absent or its value cannot be viewed as a string.
    fn header_str(req: &ClientRequest, name: header::HeaderName) -> &str {
        req.head.headers.get(name).unwrap().to_str().unwrap()
    }

    /// The debug output names the type and lists the headers.
    #[actix_rt::test]
    async fn test_debug() {
        let request = Client::new().get("/").append_header(("x-test", "111"));
        let dump = format!("{:?}", request);

        assert!(dump.contains("ClientRequest"));
        assert!(dump.contains("x-test"));
    }

    /// Builder methods set the expected headers and version; conditional chaining works.
    #[actix_rt::test]
    async fn test_basics() {
        let req = Client::new()
            .put("/")
            .version(Version::HTTP_2)
            .insert_header((header::DATE, HttpDate::from(SystemTime::now())))
            .content_type("plain/text")
            .append_header((header::SERVER, "awc"));

        // A present optional value adds its header ...
        let user_agent = Some("server");
        let req = match user_agent {
            Some(val) => req.append_header((header::USER_AGENT, val)),
            None => req,
        };

        // ... while an absent one leaves the request untouched.
        let allow = Option::<&str>::None;
        let req = match allow {
            Some(_val) => req.append_header((header::ALLOW, "1")),
            None => req,
        };

        let mut req = req.content_length(100);

        for present in [
            header::CONTENT_TYPE,
            header::DATE,
            header::SERVER,
            header::USER_AGENT,
        ] {
            assert!(req.headers().contains_key(present));
        }
        for absent in [header::ALLOW, header::EXPECT] {
            assert!(!req.headers().contains_key(absent));
        }
        assert_eq!(req.head.version, Version::HTTP_2);

        let _ = req.headers_mut();
        let _ = req.send_body("");
    }

    /// A default header configured on the client shows up on new requests.
    #[actix_rt::test]
    async fn test_client_header() {
        let client = Client::builder()
            .add_default_header((header::CONTENT_TYPE, "111"))
            .finish();
        let req = client.get("/");

        assert_eq!(header_str(&req, header::CONTENT_TYPE), "111");
    }

    /// A header inserted on the request overrides the client default.
    #[actix_rt::test]
    async fn test_client_header_override() {
        let client = Client::builder()
            .add_default_header((header::CONTENT_TYPE, "111"))
            .finish();
        let req = client
            .get("/")
            .insert_header((header::CONTENT_TYPE, "222"));

        assert_eq!(header_str(&req, header::CONTENT_TYPE), "222");
    }

    /// Basic auth encodes `username:password`, including an empty password.
    #[actix_rt::test]
    async fn client_basic_auth() {
        let with_password = Client::new().get("/").basic_auth("username", "password");
        assert_eq!(
            header_str(&with_password, header::AUTHORIZATION),
            "Basic dXNlcm5hbWU6cGFzc3dvcmQ="
        );

        let empty_password = Client::new().get("/").basic_auth("username", "");
        assert_eq!(
            header_str(&empty_password, header::AUTHORIZATION),
            "Basic dXNlcm5hbWU6"
        );
    }

    /// Bearer auth passes the token through verbatim.
    #[actix_rt::test]
    async fn client_bearer_auth() {
        let req = Client::new().get("/").bearer_auth("someS3cr3tAutht0k3n");

        assert_eq!(
            header_str(&req, header::AUTHORIZATION),
            "Bearer someS3cr3tAutht0k3n"
        );
    }

    /// Query pairs are url-encoded in order and joined with `&`.
    #[actix_rt::test]
    async fn client_query() {
        let pairs = [("key1", "val1"), ("key2", "val2")];
        let req = Client::new().get("/").query(&pairs).unwrap();

        assert_eq!(req.get_uri().query().unwrap(), "key1=val1&key2=val2");
    }
}