1#![warn(bad_style)]
23#![warn(missing_docs)]
24#![warn(unused)]
25#![warn(unused_extern_crates)]
26#![warn(unused_import_braces)]
27#![warn(unused_qualifications)]
28#![warn(unused_results)]
29#![allow(unused_doc_comments)]
30#![cfg_attr(docsrs, feature(doc_cfg))]
31
32use http::{
33 header::{HeaderName, AUTHORIZATION, CONTENT_TYPE},
34 HeaderValue, StatusCode,
35};
36use log::debug;
37use rand::{distributions::Alphanumeric, Rng};
38use ring::hmac;
39use std::{borrow::Cow, collections::HashMap, convert::TryFrom, io, iter, mem::MaybeUninit};
40use thiserror::Error;
41use time::OffsetDateTime;
42#[cfg(all(feature = "reqwest-blocking"))]
43use ::{
44 lazy_static::lazy_static,
45 reqwest::blocking::Client,
46 std::{io::Read, str::FromStr},
47 url::Url,
48};
49
50#[cfg(feature = "client-reqwest")]
52#[cfg_attr(docsrs, doc(cfg(feature = "client-reqwest")))]
53pub use reqwest;
54use std::fmt::Debug;
55
/// Errors returned by [`get`], [`post`] and [`RequestBuilder::send`] implementations.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
    /// The server answered with a status code that the request builder treats as a failure.
    #[error("HTTP status error code: {0}")]
    HttpStatus(StatusCode),

    /// An I/O error occurred, e.g. while reading the response body.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// The underlying HTTP client failed to build or send the request.
    #[error("HTTP request error: {0}")]
    HttpRequest(Box<dyn std::error::Error + Send + Sync + 'static>),
}
72
#[cfg(feature = "client-reqwest")]
#[cfg_attr(docsrs, doc(cfg(feature = "client-reqwest")))]
impl From<reqwest::Error> for Error {
    /// Wraps any `reqwest` failure as [`Error::HttpRequest`].
    fn from(err: reqwest::Error) -> Self {
        Error::HttpRequest(Box::new(err))
    }
}
80
#[cfg(feature = "reqwest-blocking")]
lazy_static! {
    /// Shared blocking HTTP client, created on first use and used by every
    /// [`DefaultRequestBuilder`] so connections can be reused across requests.
    static ref CLIENT: Client = Client::new();
}
85
/// A key/secret pair, used both for the consumer (client) credentials and for the
/// request/access token.
#[derive(Clone, Debug)]
pub struct Token<'a> {
    /// Public identifier, sent as `oauth_consumer_key` (consumer) or `oauth_token` (token).
    pub key: Cow<'a, str>,
    /// Shared secret; never sent, only used to derive the HMAC signing key.
    pub secret: Cow<'a, str>,
}

impl<'a> Token<'a> {
    /// Creates a token from anything convertible into `Cow<str>`, e.g. `&str` (borrowed,
    /// no allocation) or `String` (owned).
    pub fn new<K, S>(key: K, secret: S) -> Token<'a>
    where
        K: Into<Cow<'a, str>>,
        S: Into<Cow<'a, str>>,
    {
        let key = key.into();
        let secret = secret.into();
        Self { key, secret }
    }
}
114
/// Raw (not yet percent-encoded) request parameters, keyed by name. This crate
/// percent-encodes them itself when computing the signature.
pub type ParamList<'a> = HashMap<Cow<'a, str>, Cow<'a, str>>;

/// Stores `value` under `key` in `param` and returns whatever was stored there before.
fn insert_param<'a, K, V>(param: &mut ParamList<'a>, key: K, value: V) -> Option<Cow<'a, str>>
where
    K: Into<Cow<'a, str>>,
    V: Into<Cow<'a, str>>,
{
    let (name, val): (Cow<'a, str>, Cow<'a, str>) = (key.into(), value.into());
    param.insert(name, val)
}
125
126fn join_query(param: &ParamList<'_>) -> String {
127 let mut pairs = param
128 .iter()
129 .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
130 .collect::<Vec<_>>();
131 pairs.sort();
132 pairs.join("&")
133}
134
135const STRICT_ENCODE_SET: percent_encoding::AsciiSet = percent_encoding::NON_ALPHANUMERIC
143 .remove(b'-')
144 .remove(b'.')
145 .remove(b'_')
146 .remove(b'~');
147
148use self::percent_encode_string as encode;
149pub fn percent_encode_string(s: &str) -> Cow<str> {
151 percent_encoding::percent_encode(s.as_bytes(), &STRICT_ENCODE_SET).collect()
152}
153
154pub fn signature(
156 method: &str,
157 uri: &str,
158 query: &str,
159 consumer_secret: &str,
160 token_secret: Option<&str>,
161) -> String {
162 let base = format!("{}&{}&{}", encode(method), encode(uri), encode(query));
163 let key = format!(
164 "{}&{}",
165 encode(consumer_secret),
166 encode(token_secret.unwrap_or(""))
167 );
168 debug!("Signature base string: {}", base);
169 debug!("Authorization header: Authorization: {}", base);
170 let signing_key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, key.as_bytes());
171 let signature = hmac::sign(&signing_key, base.as_bytes());
172 base64::encode(signature.as_ref())
173}
174
/// Errors returned by [`check_signature_request`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum VerifyError {
    /// The request has no `Authorization` header.
    #[error("Authorization header not found")]
    NoAuthorizationHeader,

    /// The `Authorization` header value could not be read as a string (`HeaderValue::to_str`
    /// failed).
    #[error("Non ASCII values in header: {0}")]
    NonAsciiHeader(#[source] http::header::ToStrError),

    /// A parameter in the `Authorization` header or the URL query string could not be split
    /// into a key and a value.
    #[error("Invalid key value pair in query params")]
    InvalidKeyValuePair,
}
191
/// Read-only view of a request, i.e. the parts [`check_signature_request`] needs to verify
/// an OAuth signature.
pub trait GenericRequest {
    /// The request headers; the `Authorization` header is read from here.
    fn headers(&self) -> &http::header::HeaderMap<HeaderValue>;

    /// The request URL. Anything after the `?` is treated as the query string, whose
    /// parameters are included in the signature check.
    fn url(&self) -> &str;

    /// The HTTP method, e.g. `"GET"`.
    fn method(&self) -> &str;
}
205
#[cfg(feature = "client-reqwest")]
#[cfg_attr(docsrs, doc(cfg(feature = "client-reqwest")))]
impl GenericRequest for reqwest::Request {
    fn headers(&self) -> &http::header::HeaderMap<HeaderValue> {
        // Explicit path: delegate to the inherent `reqwest::Request::headers`, not to this method.
        reqwest::Request::headers(self)
    }

    /// Returns the `host` header if it is present and readable as a string; otherwise returns
    /// the request's full URL.
    ///
    /// NOTE(review): a `host` header value carries no scheme, path or query string, so in that
    /// case no query parameters reach the signature check. Confirm this is intended.
    fn url(&self) -> &str {
        let full_url = reqwest::Request::url(self).as_str();
        match reqwest::Request::headers(self).get("host") {
            Some(host) => host.to_str().unwrap_or(full_url),
            None => full_url,
        }
    }

    fn method(&self) -> &str {
        reqwest::Request::method(self).as_str()
    }
}
227
228pub fn check_signature_request<R: GenericRequest>(
280 request: R,
281 consumer_secret: &str,
282 token_secret: Option<&str>,
283 url_middleware: impl for<'a> FnOnce(&'a str) -> Cow<'a, str>,
284) -> Result<bool, VerifyError> {
285 let authorization_header = request
286 .headers()
287 .get("Authorization")
288 .ok_or(VerifyError::NoAuthorizationHeader)?;
289
290 let (provided_signature, mut auth_params_without_signature): (Vec<&str>, Vec<&str>) =
291 authorization_header
292 .to_str()
293 .map_err(VerifyError::NonAsciiHeader)?
294 .split(',')
295 .map(str::trim)
296 .partition(|x| x.starts_with("oauth_signature="));
297
298 assert_eq!(
299 provided_signature.len(),
300 1,
301 "provided_signature: {:?}",
302 provided_signature
303 );
304 let provided_signature = provided_signature.first().unwrap();
305 let all_other_max = auth_params_without_signature.len() - 1;
306 let mut all_together_max = all_other_max;
307 let mut query_params = None;
308 let mut url = request.url();
309 if let Some(qm_i) = request.url().rfind('?') {
310 url = &url[..qm_i];
312 let qp = request.url()[qm_i + 1..].split('&').collect::<Vec<&str>>();
313 all_together_max += qp.len();
314 query_params = Some(qp);
315 }
316 fn split_key_value_pair(qp: &str) -> Result<(&str, &str), VerifyError> {
317 qp.split_once('=')
318 .ok_or(VerifyError::InvalidKeyValuePair)
319 .map(|(k, v)| (k, v))
320 }
321 auth_params_without_signature[0] = &auth_params_without_signature[0]["OAuth ".len()..];
323 let query: Result<Vec<(&str, &str)>, VerifyError> = auth_params_without_signature
324 .into_iter()
325 .map(|qp|
326 split_key_value_pair(qp).map(|(k,v)| (k, &v[1..v.len()-1]))
327 )
328 .chain(
330 if let Some(query_params) = query_params.take() {
331 query_params.into_iter()
332 } else {
333 Vec::new().into_iter()
334 }.map(split_key_value_pair)
335 )
336 .collect();
337 let mut query = query?;
338 query.sort_by(|(a, _), (b, _)| a.cmp(b));
339
340 let query: String = query
341 .iter()
342 .enumerate()
343 .flat_map(|(i, (k, v))| [k, "=", v, if i == all_together_max { &"" } else { &"&" }])
344 .collect();
345
346 let url = url_middleware(url);
348
349 return Ok(check_signature(
350 &provided_signature["oauth_signature=\"".len()..provided_signature.len() - 1],
351 request.method(),
352 &url,
353 &query,
354 consumer_secret,
355 token_secret,
356 ));
357}
358
359pub fn check_signature(
365 signature_to_check: &str,
366 method: &str,
367 uri: &str,
368 query: &str,
369 consumer_secret: &str,
370 token_secret: Option<&str>,
371) -> bool {
372 let signature = signature(method, uri, query, consumer_secret, token_secret);
373 let new_encoded_signature = encode(&signature);
374
375 new_encoded_signature == signature_to_check
376}
377
378fn header(param: &ParamList<'_>) -> String {
382 let mut pairs = param
383 .iter()
384 .filter(|&(k, _)| k.starts_with("oauth_"))
385 .map(|(k, v)| format!("{}=\"{}\"", k, encode(v)))
386 .collect::<Vec<_>>();
387 pairs.sort();
388 format!("OAuth {}", pairs.join(", "))
389}
390
391fn body(param: &ParamList<'_>) -> String {
393 let mut pairs = param
394 .iter()
395 .filter(|&(k, _)| !k.starts_with("oauth_"))
396 .map(|(k, v)| format!("{}={}", k, encode(v)))
397 .collect::<Vec<_>>();
398 pairs.sort();
399 pairs.join("&")
400}
401
402fn get_header(
404 method: &str,
405 uri: &str,
406 consumer: &Token<'_>,
407 token: Option<&Token<'_>>,
408 other_param: Option<&ParamList<'_>>,
409) -> (String, String) {
410 let mut param = HashMap::new();
411 let timestamp = format!("{}", OffsetDateTime::now_utc().unix_timestamp());
412 let mut rng = rand::thread_rng();
413 let nonce = iter::repeat(())
414 .map(|()| rng.sample(Alphanumeric))
415 .map(char::from)
416 .take(32)
417 .collect::<String>();
418
419 let _ = insert_param(&mut param, "oauth_consumer_key", consumer.key.to_string());
420 let _ = insert_param(&mut param, "oauth_nonce", nonce);
421 let _ = insert_param(&mut param, "oauth_signature_method", "HMAC-SHA1");
422 let _ = insert_param(&mut param, "oauth_timestamp", timestamp);
423 let _ = insert_param(&mut param, "oauth_version", "1.0");
424 if let Some(tk) = token {
425 let _ = insert_param(&mut param, "oauth_token", tk.key.as_ref());
426 }
427
428 if let Some(ps) = other_param {
429 for (k, v) in ps.iter() {
430 let _ = insert_param(&mut param, k.as_ref(), v.as_ref());
431 }
432 }
433
434 let sign = signature(
435 method,
436 uri,
437 join_query(¶m).as_ref(),
438 consumer.secret.as_ref(),
439 token.map(|t| t.secret.as_ref()),
440 );
441 let _ = insert_param(&mut param, "oauth_signature", sign);
442
443 (header(¶m), body(¶m))
444}
445
446pub fn authorization_header(
462 method: &str,
463 uri: &str,
464 consumer: &Token<'_>,
465 token: Option<&Token<'_>>,
466 other_param: Option<&ParamList<'_>>,
467) -> (String, String) {
468 get_header(method, uri, consumer, token, other_param)
469}
470
471pub fn get<RB: RequestBuilder>(
483 uri: &str,
484 consumer: &Token<'_>,
485 token: Option<&Token<'_>>,
486 other_param: Option<&ParamList<'_>>,
487 client: &RB::ClientBuilder,
488) -> Result<RB::ReturnValue, Error> {
489 let (header, body) = get_header("GET", uri, consumer, token, other_param);
490 let req_uri = if !body.is_empty() {
491 format!("{}?{}", uri, body)
492 } else {
493 uri.to_string()
494 };
495
496 let rsp = RB::new(http::Method::GET, &req_uri, client)
497 .header(AUTHORIZATION, header)
498 .send()?;
499 Ok(rsp)
500}
501
502pub fn post<RB: RequestBuilder>(
515 uri: &str,
516 consumer: &Token<'_>,
517 token: Option<&Token<'_>>,
518 other_param: Option<&ParamList<'_>>,
519 client: &RB::ClientBuilder,
520) -> Result<RB::ReturnValue, Error> {
521 let (header, body) = get_header("POST", uri, consumer, token, other_param);
522
523 RB::new(http::Method::POST, uri, client)
524 .body(body)
525 .header(AUTHORIZATION, header)
526 .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
527 .send()
528}
529
#[cfg(feature = "reqwest-blocking")]
/// [`RequestBuilder`] backed by a shared, blocking `reqwest` client.
///
/// Its `send` returns the response body as a `String`.
#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-blocking")))]
#[derive(Debug)]
pub struct DefaultRequestBuilder {
    /// The wrapped `reqwest` request builder that every builder method forwards to.
    inner: reqwest::blocking::RequestBuilder,
}
538
#[cfg(feature = "reqwest-blocking")]
#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-blocking")))]
impl RequestBuilder for DefaultRequestBuilder {
    /// The response body, read into a string.
    type ReturnValue = String;
    /// No configuration: every request goes through the shared `CLIENT`.
    type ClientBuilder = ();

    fn new(method: http::Method, url: &'_ str, _: &Self::ClientBuilder) -> Self {
        let rb = match Url::from_str(url) {
            Ok(parsed) => CLIENT.request(method, parsed),
            // Never panic on a caller-supplied URL. Handing the raw string to reqwest makes it
            // record the parse error, which `send` then returns as `Error::HttpRequest`.
            Err(_) => CLIENT.request(method, url),
        };
        Self { inner: rb }
    }

    fn body(mut self, b: String) -> Self {
        self.inner = self.inner.body(b);

        self
    }

    fn header<K, V>(mut self, key: K, val: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
    {
        self.inner = self.inner.header(key, val);

        self
    }

    /// Sends the request and reads the body.
    ///
    /// # Errors
    ///
    /// * [`Error::HttpRequest`] if the request could not be built or sent.
    /// * [`Error::HttpStatus`] for any non-2xx status. Previously only `200 OK` was accepted,
    ///   so e.g. `201 Created` was reported as an error.
    /// * [`Error::Io`] if the body cannot be read as UTF-8 text.
    fn send(self) -> Result<Self::ReturnValue, Error> {
        let mut response = self.inner.send()?;
        let status = response.status();
        if !status.is_success() {
            return Err(Error::HttpStatus(status));
        }
        let mut buf = String::with_capacity(200);
        let _ = response.read_to_string(&mut buf)?;
        Ok(buf)
    }
}
578
/// Abstraction over an HTTP request builder, so [`get`] and [`post`] can drive any HTTP
/// client (or a test double that only builds the request).
pub trait RequestBuilder: Debug {
    /// What [`RequestBuilder::send`] produces, e.g. the response body or the built request.
    type ReturnValue;

    /// Client or configuration handed to [`RequestBuilder::new`].
    type ClientBuilder;

    /// Starts a request with the given method and URL.
    fn new(method: http::Method, url: &'_ str, client: &Self::ClientBuilder) -> Self;

    /// Sets the request body.
    fn body(self, b: String) -> Self;

    /// Attaches the header `key: val` to the request.
    fn header<K, V>(self, key: K, val: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>;

    /// Finishes the request (sends it, or for a test double just builds it) and returns the
    /// result.
    fn send(self) -> Result<Self::ReturnValue, Error>
    where
        Self: Sized;
}
608
/// Errors returned by [`parse_query_string`].
#[derive(Debug, Error)]
pub enum ParseQueryError {
    /// Fewer pairs with a requested key were found than were requested. Holds the number found.
    #[error("Not enough key value pairs provided. Was: {0}")]
    NotEnoughPairs(usize),

    /// A `&`-separated segment could not be split into a key and a value.
    #[error("One of the key value pairs was invalid.")]
    InvalidKeyValuePair,
}
621
#[cfg(test)]
mod tests {
    use super::*;
    use log::LevelFilter;
    use std::collections::HashMap;

    /// Without sorting, pairs come back in query-string order.
    #[test]
    fn parse_dont_sort_doesnt_sort() {
        let input = "b=BBB&a=AAA";
        let [(first_key, first), (second_key, second)] =
            parse_query_string(input, false, &["a", "b"]).unwrap();
        assert_eq!(first_key, "b");
        assert_eq!(second_key, "a");
        assert_eq!(second, "AAA");
        assert_eq!(first, "BBB");
    }

    #[test]
    fn parse_sort_out_of_order() {
        let input = "b=BBB&a=AAA";
        let [(a_key, a), (b_key, b)] = parse_query_string(input, true, &["a", "b"]).unwrap();
        assert_eq!(a_key, "a");
        assert_eq!(b_key, "b");
        assert_eq!(a, "AAA");
        assert_eq!(b, "BBB");
    }

    #[test]
    fn parse_sort_already_sorted() {
        let input = "a=AAA&b=BBB";
        let [(a_key, a), (b_key, b)] = parse_query_string(input, true, &["a", "b"]).unwrap();
        assert_eq!(a_key, "a");
        assert_eq!(b_key, "b");
        assert_eq!(a, "AAA");
        assert_eq!(b, "BBB");
    }

    /// A requested key that never appears results in `NotEnoughPairs`.
    #[test]
    fn parse_invalid_keys() {
        let input = "a=AAA&b=BBB";
        match parse_query_string(input, true, &["a", "x"]) {
            Ok(_) => panic!("Should error"),
            Err(e) => match e {
                ParseQueryError::NotEnoughPairs(_) => {}
                _ => panic!("Wrong error"),
            },
        }
    }

    #[test]
    fn parse_empty_string() {
        let input = "";
        // Documents why `split_terminator` is used: `split` yields an empty segment for "".
        assert_eq!("".split('&').collect::<Vec<_>>(), [""]);
        assert_eq!("&".split('&').collect::<Vec<_>>(), ["", ""]);
        assert_eq!(0, "".split_terminator('&').count());
        match parse_query_string(input, true, &["a", "b"]) {
            Ok(_) => panic!("Should error"),
            Err(e) => match e {
                ParseQueryError::NotEnoughPairs(_) => {}
                _ => panic!("Wrong error"),
            },
        }
    }

    #[test]
    fn parse_invalid_format() {
        let input = "x&";
        match parse_query_string(input, true, &["x"]) {
            Ok(_) => panic!("Should error"),
            Err(e) => match e {
                ParseQueryError::InvalidKeyValuePair => {}
                _ => panic!("Wrong error"),
            },
        }
    }

    /// Round trip: sign a GET request with `get`, then verify it with `check_signature_request`.
    #[test]
    fn check_signature_request_test() {
        simple_logger::SimpleLogger::new()
            .with_level(LevelFilter::Trace)
            .init()
            .unwrap();

        /// Request builder that only builds the request instead of sending it.
        #[derive(Debug)]
        struct DummyRequestBuilder(reqwest::RequestBuilder);

        impl RequestBuilder for DummyRequestBuilder {
            type ReturnValue = reqwest::Request;
            type ClientBuilder = reqwest::Client;

            fn new(method: http::Method, url: &'_ str, client: &Self::ClientBuilder) -> Self {
                Self(client.request(method, url))
            }
            fn body(mut self, b: String) -> Self {
                self.0 = self.0.body(b);

                self
            }
            fn header<K, V>(mut self, key: K, val: V) -> Self
            where
                HeaderName: TryFrom<K>,
                HeaderValue: TryFrom<V>,
                <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
                <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
            {
                self.0 = self.0.header(key, val);

                self
            }
            fn send(self) -> Result<Self::ReturnValue, Error> {
                let rv = self.0.build()?;
                Ok(rv)
            }
        }
        let client = reqwest::Client::new();
        let token = Token::new("key", "secret");
        let param_list = {
            let mut hm: ParamList<'_> = HashMap::with_capacity(2);
            assert!(hm
                .insert(
                    "oauth_any_string".into(),
                    "gets_put_into_auth_header".into()
                )
                .is_none());
            assert!(hm
                .insert(
                    "doesnt_start_with_oauth".into(),
                    "doesnt_get_put_into_auth_header".into()
                )
                .is_none());
            // A value full of reserved characters must survive encoding on both sides.
            assert!(hm
                .insert("oauth_callback".into(), "http://xd.xy?xy=xz&xd=xx".into())
                .is_none());

            hm
        };
        let request = get::<DummyRequestBuilder>(
            "http://localhost/",
            &token,
            None,
            Some(&param_list),
            &client,
        )
        .unwrap();
        assert!(check_signature_request(request, &token.secret, None, |u| u.into()).unwrap());
    }

    #[test]
    fn query() {
        let mut map = HashMap::new();
        let _ = map.insert("aaa".into(), "AAA".into());
        let _ = map.insert("bbbb".into(), "BBBB".into());
        let query = join_query(&map);
        assert_eq!("aaa=AAA&bbbb=BBBB", query);
    }

    #[test]
    fn test_encode() {
        let method = "GET";
        let uri = "http://oauthbin.com/v1/request-token";
        let encoded_uri = "http%3A%2F%2Foauthbin.com%2Fv1%2Frequest-token";
        let query = [
            "oauth_consumer_key=key&",
            "oauth_nonce=s6HGl3GhmsDsmpgeLo6lGtKs7rQEzzsA&",
            "oauth_signature_method=HMAC-SHA1&",
            "oauth_timestamp=1471445561&",
            "oauth_version=1.0",
        ]
        .iter()
        .cloned()
        .collect::<String>();
        let encoded_query = [
            "oauth_consumer_key%3Dkey%26",
            "oauth_nonce%3Ds6HGl3GhmsDsmpgeLo6lGtKs7rQEzzsA%26",
            "oauth_signature_method%3DHMAC-SHA1%26",
            "oauth_timestamp%3D1471445561%26",
            "oauth_version%3D1.0",
        ]
        .iter()
        .cloned()
        .collect::<String>();

        assert_eq!(encode(method), "GET");
        assert_eq!(encode(uri), encoded_uri);
        assert_eq!(encode(&query), encoded_query);
    }
}
808
809use log::warn;
810
811pub fn parse_query_string<'q, const N: usize>(
825 query_string: &'q str,
826 sort: bool,
827 keys: &[&str; N],
828) -> Result<[(&'q str, &'q str); N], ParseQueryError> {
829 let mut rv: [MaybeUninit<(&str, &str)>; N] = unsafe { MaybeUninit::uninit().assume_init() };
833
834 let mut num_inserted = 0;
835 for kv in query_string.split_terminator('&') {
836 let mut iter = kv.split_terminator('=');
837 let key = iter.next().ok_or(ParseQueryError::InvalidKeyValuePair)?;
838 let val = iter.next().ok_or(ParseQueryError::InvalidKeyValuePair)?;
839
840 if keys.contains(&key) {
841 rv[num_inserted] = MaybeUninit::new((key, val));
845 num_inserted += 1;
846 } else {
847 warn!("Unexpected key {:?}. (value {:?})", key, val);
848 }
849 }
850
851 if num_inserted < N {
852 return Err(ParseQueryError::NotEnoughPairs(num_inserted));
853 }
854
855 let mut rv: [(&str, &str); N] = unsafe { std::mem::transmute_copy(&rv) };
858
859 if sort {
860 rv.sort_unstable_by_key(|&(k, _v)| k);
862 }
863
864 Ok(rv)
865}