1use std::{collections::HashMap, convert::TryFrom, fmt, time::Duration};
7
8use http::{header::AUTHORIZATION, Method};
9use oauth1_request::signature_method::HmacSha1 as DefaultSM;
10use oauth1_request::signature_method::SignatureMethod;
11use reqwest::{
12 header::HeaderMap, header::HeaderName, header::HeaderValue,
13 IntoUrl,
14};
15
16#[cfg(feature = "blocking")]
17use reqwest::blocking::{
18 RequestBuilder as ReqwestRequestBuilder, Response,
19 Client as ReqwestClient, Body
20};
21
22#[cfg(not(feature = "blocking"))]
23use reqwest::{
24 RequestBuilder as ReqwestRequestBuilder, Response,
25 Client as ReqwestClient, Body
26};
27
28#[cfg(all(feature = "multipart", feature = "blocking"))]
29use reqwest::blocking::multipart;
30
31#[cfg(all(not(feature = "blocking"), feature = "multipart"))]
32use reqwest::multipart;
33
34use serde::Serialize;
35use url::Url;
36
37use crate::{
38 Error, OAuthParameters, SecretsProvider, SignResult, Signer, OAUTH_KEY_PREFIX, REALM_KEY,
39};
40
41#[derive(Debug)]
43pub struct RequestBuilder<TSigner>
44where
45 TSigner: Clone,
46{
47 method: Method,
48 inner: ReqwestRequestBuilder,
49 signer: TSigner,
50 url: Option<Url>,
51 body: String,
52 query_oauth_parameters: HashMap<String, String>,
53 form_oauth_parameters: HashMap<String, String>,
54}
55
56impl RequestBuilder<()> {
57 pub fn sign<'a, TSecrets>(
62 self,
63 secrets: TSecrets,
64 ) -> RequestBuilder<Signer<'a, TSecrets, DefaultSM>>
65 where
66 TSecrets: SecretsProvider + Clone,
67 {
68 self.sign_with_params(secrets, OAuthParameters::new())
69 }
70
71 pub fn sign_with_params<'a, TSecrets, TSM>(
73 self,
74 secrets: TSecrets,
75 params: OAuthParameters<'a, TSM>,
76 ) -> RequestBuilder<Signer<'a, TSecrets, TSM>>
77 where
78 TSecrets: SecretsProvider + Clone,
79 TSM: SignatureMethod + Clone,
80 {
81 RequestBuilder {
82 inner: self.inner,
83 method: self.method,
84 url: self.url,
85 body: self.body,
86 signer: Signer::new(secrets.into(), params),
87 query_oauth_parameters: self.query_oauth_parameters,
88 form_oauth_parameters: self.form_oauth_parameters,
89 }
90 }
91}
92
93impl<TSecrets, TSM> RequestBuilder<Signer<'_, TSecrets, TSM>>
94where
95 TSecrets: SecretsProvider + Clone,
96 TSM: SignatureMethod + Clone,
97{
98 #[cfg(feature = "blocking")]
109 pub fn send(self) -> Result<Response, Error> {
110 Ok(self.generate_signature()?.send()?)
111 }
112
113 #[cfg(not(feature = "blocking"))]
121 pub async fn send(self) -> Result<Response, Error> {
122 Ok(self.generate_signature()?.send().await?)
123 }
124
125 pub fn generate_signature(self) -> SignResult<ReqwestRequestBuilder> {
127 if let Some(url) = self.url {
128 let (is_q, url, payload) = match url.query() {
129 None | Some("") => {
130 (false, url, self.body.as_ref())
132 }
133 Some(q) => {
134 let mut pure_url = url.clone();
136 pure_url.set_query(None);
137 (true, pure_url, q)
138 }
139 };
140 let oauth_params: HashMap<String, String> = self
141 .form_oauth_parameters
142 .into_iter()
143 .chain(self.query_oauth_parameters.into_iter())
144 .collect();
145
146 let signature = self
147 .signer
148 .override_oauth_parameter(oauth_params)
149 .generate_signature(self.method, url, payload, is_q)?;
150 Ok(self.inner.header(AUTHORIZATION, signature))
153 } else {
154 Ok(self.inner)
156 }
157 }
158}
159
160impl<TSigner> RequestBuilder<TSigner>
161where
162 TSigner: Clone,
163{
164 pub(crate) fn new<T: IntoUrl + Clone>(
165 client: &ReqwestClient,
166 method: Method,
167 url: T,
168 signer: TSigner,
169 ) -> Self {
170 match url.clone().into_url() {
171 Ok(url) => {
172 let mut query_oauth_params: HashMap<String, String> = HashMap::new();
173 let stealed_url = steal_oauth_params_from_url(url, &mut query_oauth_params);
174 RequestBuilder {
175 inner: client.request(method.clone(), stealed_url.clone()),
176 method,
177 url: Some(stealed_url),
178 body: String::new(),
179 signer: signer,
180 query_oauth_parameters: query_oauth_params,
181 form_oauth_parameters: HashMap::new(),
182 }
183 }
184 Err(_) => RequestBuilder {
185 inner: client.request(method.clone(), url),
186 method,
187 url: None,
188 body: String::new(),
189 signer: signer,
190 query_oauth_parameters: HashMap::new(),
191 form_oauth_parameters: HashMap::new(),
192 },
193 }
194 }
195
196 pub fn query<T: Serialize + ?Sized>(mut self, query: &T) -> Self {
218 let query = steal_oauth_params(query, &mut self.query_oauth_parameters);
220
221 if let Some(ref mut url) = self.url {
223 let mut pairs = url.query_pairs_mut();
224 let serializer = serde_urlencoded::Serializer::new(&mut pairs);
225
226 let _ = query.serialize(serializer);
227 }
228 if let Some(ref mut url) = self.url {
230 if let Some("") = url.query() {
231 url.set_query(None);
232 }
233 }
234 self.inner = self.inner.query(&query);
236 self
237 }
238
239 pub fn form<T: Serialize + ?Sized + Clone>(mut self, form: &T) -> Self {
241 self.form_oauth_parameters.clear();
243 let form = steal_oauth_params(form, &mut self.query_oauth_parameters);
245
246 match serde_urlencoded::to_string(form.clone()) {
247 Ok(body) => {
248 self.inner = self.inner.form(&form);
249 self.body = body;
250 self
251 }
252 Err(_) => self.pass_through(|b| b.form(&form)),
253 }
254 }
255
256 #[cfg(feature = "json")]
273 pub fn json<T: Serialize + ?Sized>(self, json: &T) -> Self {
274 self.pass_through(|b: ReqwestRequestBuilder| b.json(json))
275 }
276
277 pub fn query_without_capture<T: Serialize>(self, query: &T) -> Self {
286 self.pass_through(|b| b.query(query))
287 }
288
289 pub fn form_without_capture<T: Serialize + ?Sized>(self, form: &T) -> Self {
294 self.pass_through(|b| b.form(form))
295 }
296
297 fn pass_through<F>(self, f: F) -> Self
301 where
302 F: FnOnce(ReqwestRequestBuilder) -> ReqwestRequestBuilder,
303 {
304 RequestBuilder {
305 inner: f(self.inner),
306 ..self
307 }
308 }
309
310 pub fn header<K, V>(self, key: K, value: V) -> Self
312 where
313 HeaderName: TryFrom<K>,
314 <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
315 HeaderValue: TryFrom<V>,
316 <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
317 {
318 self.pass_through(|b| b.header(key, value))
319 }
320
321 pub fn headers(mut self, headers: HeaderMap) -> Self {
325 self.inner = self.inner.headers(headers);
326 self
327 }
328
329 pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
331 where
332 U: fmt::Display,
333 P: fmt::Display,
334 {
335 self.pass_through(|b| b.basic_auth(username, password))
336 }
337
338 pub fn bearer_auth<T>(self, token: T) -> Self
340 where
341 T: fmt::Display,
342 {
343 self.pass_through(|b| b.bearer_auth(token))
344 }
345
346 pub fn body<T: Into<Body>>(mut self, body: T) -> Self {
348 self.inner = self.inner.body(body);
349 self
350 }
351
352 pub fn timeout(mut self, timeout: Duration) -> Self {
358 self.inner = self.inner.timeout(timeout);
359 self
360 }
361
362 #[cfg(feature = "multipart")]
384 pub fn multipart(self, multipart: multipart::Form) -> Self {
385 self.pass_through(|b| b.multipart(multipart))
386 }
387
388 pub fn fetch_mode_no_cors(self) -> Self {
398 self
399 }
400
401 pub fn try_clone(&self) -> Option<Self> {
406 match self.inner.try_clone() {
407 Some(inner) => Some(RequestBuilder {
408 inner,
409 method: self.method.clone(),
410 url: self.url.clone(),
411 body: self.body.clone(),
412 signer: self.signer.clone(),
413 query_oauth_parameters: self.query_oauth_parameters.clone(),
414 form_oauth_parameters: self.form_oauth_parameters.clone(),
415 }),
416 None => None,
417 }
418 }
419}
420
421fn steal_oauth_params<T>(
422 query: &T,
423 oauth_map: &mut HashMap<String, String>,
424) -> Vec<(String, String)>
425where
426 T: Serialize + ?Sized,
427{
428 let mut empty_url = Url::parse("http://example.com/")
429 .expect("failed to parse the http://example.com/, that is unexpected behavior.");
431 {
432 let mut pairs = empty_url.query_pairs_mut();
433 let serializer = serde_urlencoded::Serializer::new(&mut pairs);
434 let _ = query.serialize(serializer);
435 }
436
437 steal_oauth_params_core(&empty_url, oauth_map)
439 .into_iter()
440 .map(|(k, v)| (k.to_string(), v.to_string()))
441 .collect()
442}
443
444fn steal_oauth_params_from_url(mut url: Url, oauth_map: &mut HashMap<String, String>) -> Url {
445 let remainder = steal_oauth_params_core(&url, oauth_map);
446 url.set_query(None);
448 if remainder.len() > 0 {
449 let mut serializer = url.query_pairs_mut();
451 for (k, v) in remainder {
452 serializer.append_pair(&k, &v);
453 }
454 }
455
456 url
457}
458
459fn steal_oauth_params_core(
460 url: &Url,
461 oauth_map: &mut HashMap<String, String>,
462) -> Vec<(String, String)> {
463 url.query_pairs()
465 .into_iter()
466 .map(|(k, v)| (k.to_string(), v.to_string()))
467 .filter_map(|(k, v)| {
468 if k.starts_with(OAUTH_KEY_PREFIX) || k == REALM_KEY {
469 oauth_map.insert(k, v);
471 None
472 } else {
473 Some((k, v))
474 }
475 })
476 .collect()
477}
478
#[cfg(test)]
mod tests {
    use http::header::AUTHORIZATION;

    // Only the client type is needed; the builder and response aliases were unused.
    #[cfg(feature = "blocking")]
    use reqwest::blocking::Client as ReqwestClient;

    #[cfg(not(feature = "blocking"))]
    use reqwest::Client as ReqwestClient;

    use crate::{
        OAuthClientProvider, OAuthParameters, Secrets, OAUTH_NONCE_KEY, OAUTH_TIMESTAMP_KEY,
    };

    /// Extracts the `oauth_signature` value from an `OAuth ...` authorization
    /// header, percent-decodes it, and strips the surrounding quotes.
    ///
    /// Panics if the header lacks the `OAuth ` prefix or an `oauth_signature` entry.
    fn extract_signature(auth_header: &str) -> String {
        let content = auth_header.strip_prefix("OAuth ").unwrap();
        let mapped_header = content
            .split(',')
            .map(|item| item.splitn(2, '=').collect::<Vec<&str>>())
            .filter(|v| v.len() == 2)
            .map(|v| (v[0], v[1]))
            .collect::<Vec<(&str, &str)>>();
        let sig_content = mapped_header.iter().find(|(k, _)| k == &"oauth_signature");
        percent_encoding::percent_decode_str(sig_content.unwrap().1)
            .decode_utf8_lossy()
            .trim_matches('"')
            .to_string()
    }

    // Reference behavior of plain reqwest: successive `query` calls accumulate.
    #[test]
    fn call_multiple_queries() {
        let req = ReqwestClient::new()
            .get("https://example.com")
            .query(&[("a", "b")])
            .query(&[("c", "d")])
            .build()
            .unwrap();
        assert_eq!(req.url().to_string(), "https://example.com/?a=b&c=d");
    }

    // Reference behavior of plain reqwest: the last `form` call replaces the
    // body, while the query string is left untouched.
    #[test]
    fn call_multiple_forms() {
        let req = ReqwestClient::new()
            .post("https://example.com")
            .query(&[("this is", "query")])
            .form(&[("a", "b")])
            .form(&[("c", "d")])
            .build()
            .unwrap();
        let decoded_body = String::from_utf8_lossy(req.body().unwrap().as_bytes().unwrap());
        assert_eq!(req.url().to_string(), "https://example.com/?this+is=query");
        assert_eq!(decoded_body, "c=d");
    }

    // OAuth-prefixed form fields are captured and left out of the body.
    #[test]
    fn capture_post_query() {
        let endpoint = "https://photos.example.net/initiate";
        let c_key = "dpf43f3p2l4k3l03";
        let c_secret = "kd94hf93k423kf44";
        let nonce = "wIjqoS";
        let timestamp = 137_131_200u64;

        let secrets = Secrets::new(c_key, c_secret);
        let params = OAuthParameters::new()
            .nonce(nonce)
            .timestamp(timestamp)
            .callback("http://printer.example.com/ready")
            .realm("photos");

        let req = ReqwestClient::new()
            .oauth1_with_params(secrets, params)
            .post(endpoint)
            .form(&[("少女", "終末旅行"), ("oauth_should_be_ignored", "true")]);
        let body = req.body;
        assert_eq!(
            body,
            "%E5%B0%91%E5%A5%B3=%E7%B5%82%E6%9C%AB%E6%97%85%E8%A1%8C"
        );
    }

    #[test]
    fn sign_post_query() {
        let endpoint = "https://photos.example.net/initiate";
        let c_key = "dpf43f3p2l4k3l03";
        let c_secret = "kd94hf93k423kf44";
        let nonce = "wIjqoS";
        let timestamp = 137_131_200u64;

        let secrets = Secrets::new(c_key, c_secret);
        let params = OAuthParameters::new()
            .nonce(nonce)
            .timestamp(timestamp)
            .callback("http://printer.example.com/ready")
            .realm("photos");

        let req = ReqwestClient::new()
            .oauth1_with_params(secrets, params)
            .post(endpoint)
            .generate_signature()
            .unwrap()
            .build()
            .unwrap();

        let sign = req.headers().get(AUTHORIZATION);
        assert_eq!(
            extract_signature(sign.unwrap().to_str().unwrap()),
            "74KNZJeDHnMBp0EMJ9ZHt/XKycU="
        );
    }

    // OAuth-prefixed parameters in the endpoint URL are captured and stripped.
    #[test]
    fn capture_get_query() {
        let endpoint = "https://photos.example.net/photos?file=vacation.jpg&size=original&oauth_should_be_ignored=true";
        let c_key = "dpf43f3p2l4k3l03";
        let c_secret = "kd94hf93k423kf44";
        let token = "nnch734d00sl2jdk";
        let token_secret = "pfkkdhi9sl3r4s00";
        let nonce = "wIjqoS";
        let timestamp = 137_131_200u64;

        let secrets = Secrets::new(c_key, c_secret).token(token, token_secret);
        let params = OAuthParameters::new().nonce(nonce).timestamp(timestamp);

        let req = ReqwestClient::new()
            .oauth1_with_params(secrets, params)
            .get(endpoint);
        let query = req.url.unwrap().query().unwrap().to_string();
        assert_eq!(query, "file=vacation.jpg&size=original");
    }

    #[test]
    fn sign_get_query() {
        let endpoint = "http://photos.example.net/photos?file=vacation.jpg&size=original";
        let c_key = "dpf43f3p2l4k3l03";
        let c_secret = "kd94hf93k423kf44";
        let token = "nnch734d00sl2jdk";
        let token_secret = "pfkkdhi9sl3r4s00";
        let nonce = "chapoH";
        let timestamp = 137_131_202u64;

        let secrets = Secrets::new(c_key, c_secret).token(token, token_secret);
        let params = OAuthParameters::new()
            .nonce(nonce)
            .timestamp(timestamp)
            .realm("Photos");
        let req = ReqwestClient::new()
            .oauth1_with_params(secrets, params)
            .get(endpoint)
            .generate_signature()
            .unwrap()
            .build()
            .unwrap();

        let sign = req.headers().get(AUTHORIZATION);
        assert_eq!(
            extract_signature(sign.unwrap().to_str().unwrap()),
            "MdpQcU8iPSUjWoN/UDMsK2sui9I="
        );
    }

    // Same expected signature as `sign_get_query`, but the realm comes from the
    // URL and the nonce/timestamp from `query` instead of `OAuthParameters`.
    #[test]
    fn sign_get_query_with_query_oauth_params() {
        let endpoint =
            "http://photos.example.net/photos?file=vacation.jpg&size=original&realm=Photos";
        let c_key = "dpf43f3p2l4k3l03";
        let c_secret = "kd94hf93k423kf44";
        let token = "nnch734d00sl2jdk";
        let token_secret = "pfkkdhi9sl3r4s00";
        let nonce = "chapoH";
        let timestamp = 137_131_202u64;

        let secrets = Secrets::new(c_key, c_secret).token(token, token_secret);
        let req = ReqwestClient::new()
            .oauth1(secrets)
            .get(endpoint)
            .query(&[
                (OAUTH_NONCE_KEY, nonce),
                (OAUTH_TIMESTAMP_KEY, &format!("{}", timestamp)),
            ])
            .generate_signature()
            .unwrap()
            .build()
            .unwrap();

        let sign = req.headers().get(AUTHORIZATION);
        assert_eq!(
            extract_signature(sign.unwrap().to_str().unwrap()),
            "MdpQcU8iPSUjWoN/UDMsK2sui9I="
        );
    }

    #[test]
    fn capture_body() {
        let endpoint = url::Url::parse("https://api.twitter.com/1.1/statuses/update.json").unwrap();
        let c_key = "xvz1evFS4wEEPTGEFPHBog";
        let c_secret = "kAcSOqF21Fu85e7zjz7ZN2U4ZRhfV3WpwPAoE3Z7kBw";
        let nonce = "kYjzVBB8Y0ZFabxSWbWovY3uYSQ2pTgmZeNu2VS4cg";
        let timestamp = 1_318_622_958u64;
        let token = "370773112-GmHxMAgYyLbNEtIKZeRNFsMKPR9EyMZeS9weJAEb";
        let token_secret = "LswwdoUaIvS8ltyTt5jkRh4J50vUPVVHtR2YPi5kE";

        let secrets = Secrets::new(c_key, c_secret).token(token, token_secret);
        let params = OAuthParameters::new().nonce(nonce).timestamp(timestamp);

        let req = ReqwestClient::new()
            .oauth1_with_params(secrets, params)
            .post(endpoint)
            .form(&[
                ("include_entities", "true"),
                (
                    "status",
                    "Hello Ladies + Gentlemen, a signed OAuth request!",
                ),
            ]);

        let body = req.body;
        assert_eq!(
            body,
            "include_entities=true&status=Hello+Ladies+%2B+Gentlemen%2C+a+signed+OAuth+request%21"
        );
    }

    #[test]
    fn sign_post_body() {
        let endpoint = url::Url::parse("https://api.twitter.com/1.1/statuses/update.json").unwrap();
        let c_key = "xvz1evFS4wEEPTGEFPHBog";
        let c_secret = "kAcSOqF21Fu85e7zjz7ZN2U4ZRhfV3WpwPAoE3Z7kBw";
        let nonce = "kYjzVBB8Y0ZFabxSWbWovY3uYSQ2pTgmZeNu2VS4cg";
        let timestamp = 1_318_622_958u64;
        let token = "370773112-GmHxMAgYyLbNEtIKZeRNFsMKPR9EyMZeS9weJAEb";
        let token_secret = "LswwdoUaIvS8ltyTt5jkRh4J50vUPVVHtR2YPi5kE";

        let secrets = Secrets::new(c_key, c_secret).token(token, token_secret);
        let params = OAuthParameters::new()
            .nonce(nonce)
            .timestamp(timestamp)
            .version(true);

        let req = ReqwestClient::new()
            .oauth1_with_params(secrets, params)
            .post(endpoint)
            .form(&[
                ("include_entities", "true"),
                (
                    "status",
                    "Hello Ladies + Gentlemen, a signed OAuth request!",
                ),
            ])
            .generate_signature()
            .unwrap()
            .build()
            .unwrap();

        let sign = req.headers().get(AUTHORIZATION);
        assert_eq!(
            extract_signature(sign.unwrap().to_str().unwrap()),
            "hCtSmYh+iHYCEqBWrE7C7hYmtUk="
        );
    }
}
771