oauth_client/
lib.rs

1// Copyright 2016 oauth-client-rs Developers
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! OAuth 1.0 client library for Rust.
9//!
10//! [Repository](https://github.com/gifnksm/oauth-client-rs)
11//!
12//! # Examples
13//!
14//! Send request for request token.
15//!
16//! ```
17//! # use oauth_client::DefaultRequestBuilder;
18//! const REQUEST_TOKEN: &str = "http://oauthbin.com/v1/request-token";
19//! let consumer = oauth_client::Token::new("key", "secret");
20//! let bytes = oauth_client::get::<DefaultRequestBuilder>(REQUEST_TOKEN, &consumer, None, None, &()).unwrap();
21//! ```
22#![warn(bad_style)]
23#![warn(missing_docs)]
24#![warn(unused)]
25#![warn(unused_extern_crates)]
26#![warn(unused_import_braces)]
27#![warn(unused_qualifications)]
28#![warn(unused_results)]
29#![allow(unused_doc_comments)]
30#![cfg_attr(docsrs, feature(doc_cfg))]
31
32use http::{
33    header::{HeaderName, AUTHORIZATION, CONTENT_TYPE},
34    HeaderValue, StatusCode,
35};
36use log::debug;
37use rand::{distributions::Alphanumeric, Rng};
38use ring::hmac;
39use std::{borrow::Cow, collections::HashMap, convert::TryFrom, io, iter, mem::MaybeUninit};
40use thiserror::Error;
41use time::OffsetDateTime;
42#[cfg(all(feature = "reqwest-blocking"))]
43use ::{
44    lazy_static::lazy_static,
45    reqwest::blocking::Client,
46    std::{io::Read, str::FromStr},
47    url::Url,
48};
49
50/// Re-exporting `reqwest` crate.
51#[cfg(feature = "client-reqwest")]
52#[cfg_attr(docsrs, doc(cfg(feature = "client-reqwest")))]
53pub use reqwest;
54use std::fmt::Debug;
55
56/// Error type.
57#[derive(Debug, Error)]
58#[non_exhaustive]
59pub enum Error {
60    /// An error happening due to a HTTP status error.
61    #[error("HTTP status error code: {0}")]
62    HttpStatus(StatusCode),
63
64    /// An error happening due to a IO error.
65    #[error("IO error: {0}")]
66    Io(#[from] io::Error),
67
68    /// An error happening due to a HTTP request error.
69    #[error("HTTP request error: {0}")]
70    HttpRequest(Box<dyn std::error::Error + Send + Sync + 'static>),
71}
72
#[cfg(feature = "client-reqwest")]
#[cfg_attr(docsrs, doc(cfg(feature = "client-reqwest")))]
impl From<reqwest::Error> for Error {
    /// Wraps a `reqwest` failure as [`Error::HttpRequest`], so `?` can be applied to
    /// `reqwest` results inside functions returning this crate's [`Error`].
    fn from(err: reqwest::Error) -> Self {
        let boxed: Box<dyn std::error::Error + Send + Sync + 'static> = Box::new(err);
        Error::HttpRequest(boxed)
    }
}
80
// Process-wide blocking client, created lazily on first use. Every
// `DefaultRequestBuilder` starts its request from this one `Client` instead of
// constructing a new client per request.
#[cfg(feature = "reqwest-blocking")]
lazy_static! {
    static ref CLIENT: Client = Client::new();
}
85
/// An OAuth credential pair: a public `key` and its matching `secret`.
///
/// The same type is used for the consumer (client) credentials and for the
/// request/access token credentials.
#[derive(Clone, Debug)]
pub struct Token<'a> {
    /// Public identifier of the credential (sent as `oauth_consumer_key` / `oauth_token`).
    pub key: Cow<'a, str>,
    /// Shared secret; used only to build the HMAC signing key, never sent.
    pub secret: Cow<'a, str>,
}

impl<'a> Token<'a> {
    /// Builds a token from anything convertible into `Cow<'a, str>`: `&str` values are
    /// borrowed, `String` values are stored as owned.
    ///
    /// # Examples
    ///
    /// ```
    /// let consumer = oauth_client::Token::new("key", "secret");
    /// ```
    pub fn new<K, S>(key: K, secret: S) -> Token<'a>
    where
        K: Into<Cow<'a, str>>,
        S: Into<Cow<'a, str>>,
    {
        let key: Cow<'a, str> = key.into();
        let secret: Cow<'a, str> = secret.into();
        Self { key, secret }
    }
}
114
/// Parameter name → value map used to collect OAuth and request parameters.
///
/// Alias for `HashMap<Cow<'a, str>, Cow<'a, str>>`.
pub type ParamList<'a> = HashMap<Cow<'a, str>, Cow<'a, str>>;

/// Stores `key = value` in `param`, converting both sides into `Cow<'a, str>`.
///
/// Returns the value previously stored under `key`, if any (same as `HashMap::insert`).
fn insert_param<'a, K, V>(param: &mut ParamList<'a>, key: K, value: V) -> Option<Cow<'a, str>>
where
    K: Into<Cow<'a, str>>,
    V: Into<Cow<'a, str>>,
{
    let (k, v): (Cow<'a, str>, Cow<'a, str>) = (key.into(), value.into());
    param.insert(k, v)
}
125
126fn join_query(param: &ParamList<'_>) -> String {
127    let mut pairs = param
128        .iter()
129        .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
130        .collect::<Vec<_>>();
131    pairs.sort();
132    pairs.join("&")
133}
134
135// Encode all but the unreserved characters defined in
136// RFC 3986, section 2.3. "Unreserved Characters"
137// https://tools.ietf.org/html/rfc3986#page-12
138//
139// This is required by
140// OAuth Core 1.0, section 5.1. "Parameter Encoding"
141// https://oauth.net/core/1.0/#encoding_parameters
142const STRICT_ENCODE_SET: percent_encoding::AsciiSet = percent_encoding::NON_ALPHANUMERIC
143    .remove(b'-')
144    .remove(b'.')
145    .remove(b'_')
146    .remove(b'~');
147
148use self::percent_encode_string as encode;
149/// Percent-encode the string in the manner defined in RFC 3986
150pub fn percent_encode_string(s: &str) -> Cow<str> {
151    percent_encoding::percent_encode(s.as_bytes(), &STRICT_ENCODE_SET).collect()
152}
153
154/// Create signature. See <https://dev.twitter.com/oauth/overview/creating-signatures>
155pub fn signature(
156    method: &str,
157    uri: &str,
158    query: &str,
159    consumer_secret: &str,
160    token_secret: Option<&str>,
161) -> String {
162    let base = format!("{}&{}&{}", encode(method), encode(uri), encode(query));
163    let key = format!(
164        "{}&{}",
165        encode(consumer_secret),
166        encode(token_secret.unwrap_or(""))
167    );
168    debug!("Signature base string: {}", base);
169    debug!("Authorization header: Authorization: {}", base);
170    let signing_key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, key.as_bytes());
171    let signature = hmac::sign(&signing_key, base.as_bytes());
172    base64::encode(signature.as_ref())
173}
174
175/// Things that can go wrong while verifying a request's signature
176#[derive(Debug, Error)]
177#[non_exhaustive]
178pub enum VerifyError {
179    /// No authorization header
180    #[error("Authorization header not found")]
181    NoAuthorizationHeader,
182
183    /// Invalid header
184    #[error("Non ASCII values in header: {0}")]
185    NonAsciiHeader(#[source] http::header::ToStrError),
186
187    /// Invalid params
188    #[error("Invalid key value pair in query params")]
189    InvalidKeyValuePair,
190}
191
192/// Generic request type. Allows you to pass any [`reqwest::Request`]-like object.
193/// You're gonna need to wrap whatever client's `Request` type you're using in your own
194/// type, as the orphan rules won't allow you to `impl` this trait.
195pub trait GenericRequest {
196    /// Headers
197    fn headers(&self) -> &http::header::HeaderMap<HeaderValue>;
198
199    /// Url
200    fn url(&self) -> &str;
201
202    /// Method.
203    fn method(&self) -> &str;
204}
205
#[cfg(feature = "client-reqwest")]
#[cfg_attr(docsrs, doc(cfg(feature = "client-reqwest")))]
impl GenericRequest for reqwest::Request {
    // Inherent `reqwest::Request` methods win over trait methods during method
    // resolution, so `self.headers()`, `self.url()` and `self.method()` below call
    // reqwest's own accessors rather than recursing into this impl.
    fn headers(&self) -> &http::header::HeaderMap<HeaderValue> {
        self.headers()
    }

    /// Returns the `Host` header if present and convertible to a string, otherwise the
    /// request URL.
    fn url(&self) -> &str {
        if let Some(host) = self.headers().get("host") {
            // The Host header is preferred over the URL stored in the request, which the
            // original author considered possibly inaccurate.
            // NOTE(review): per RFC 7230 section 5.4 a Host header holds only
            // `host[:port]` (no scheme, path or query). Returning it here would make the
            // signature base URI differ from the one used when signing. Confirm intent.
            host.to_str().unwrap_or_else(|_e| self.url().as_str())
        } else {
            self.url().as_str()
        }
    }

    fn method(&self) -> &str {
        self.method().as_str()
    }
}
227
228/// Verifies that the provided request's signature is valid.
229/// The `url_middleware` argument allows you to modify the url before it's used to calculate the
230/// signature. This could be useful for tests, where there can be multiple `localhost` urls.
231///
232/// # Examples
233///
234/// ```
235/// # use std::borrow::Cow;
236/// # use oauth_client::{Error, RequestBuilder, Token};
237/// # use oauth_client::reqwest::header::{HeaderName, HeaderValue};
238/// # use std::convert::TryFrom;
239/// #[derive(Debug)]
240/// struct DummyRequestBuilder(reqwest::RequestBuilder);
241///
242/// impl RequestBuilder for DummyRequestBuilder {
243///     type ReturnValue = reqwest::Request;
244///     type ClientBuilder = reqwest::Client;
245///
246///     fn new(method: http::Method, url: &'_ str, client: &Self::ClientBuilder) -> Self {
247///         Self(client.request(method, url))
248///     }
249///     fn body(mut self, b: String) -> Self {
250///         self.0 = self.0.body(b); self
251///     }
252///     fn header<K, V>(mut self, key: K, val: V) -> Self
253///         where
254///             HeaderName: TryFrom<K>,
255///             HeaderValue: TryFrom<V>,
256///             <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
257///             <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
258///     {
259///         self.0 = self.0.header(key, val); self
260///     }
261///     fn send(self) -> Result<Self::ReturnValue, Error> {
262///         Ok(self.0.build()?)
263///     }
264/// }
265/// let client = reqwest::Client::new();
266/// let token = Token::new("key", "secret");
267/// let request = oauth_client::get::<DummyRequestBuilder>(
268///     "http://localhost/",
269///     &token,
270///     None,
271///     None,
272///     &client,
273/// ).unwrap();
274/// assert!(
275///     oauth_client::check_signature_request(request, &token.secret, None, |u| Cow::from(u)).unwrap(),
276///     "Invalid signature"
277/// );
278/// ```
279pub fn check_signature_request<R: GenericRequest>(
280    request: R,
281    consumer_secret: &str,
282    token_secret: Option<&str>,
283    url_middleware: impl for<'a> FnOnce(&'a str) -> Cow<'a, str>,
284) -> Result<bool, VerifyError> {
285    let authorization_header = request
286        .headers()
287        .get("Authorization")
288        .ok_or(VerifyError::NoAuthorizationHeader)?;
289
290    let (provided_signature, mut auth_params_without_signature): (Vec<&str>, Vec<&str>) =
291        authorization_header
292            .to_str()
293            .map_err(VerifyError::NonAsciiHeader)?
294            .split(',')
295            .map(str::trim)
296            .partition(|x| x.starts_with("oauth_signature="));
297
298    assert_eq!(
299        provided_signature.len(),
300        1,
301        "provided_signature: {:?}",
302        provided_signature
303    );
304    let provided_signature = provided_signature.first().unwrap();
305    let all_other_max = auth_params_without_signature.len() - 1;
306    let mut all_together_max = all_other_max;
307    let mut query_params = None;
308    let mut url = request.url();
309    if let Some(qm_i) = request.url().rfind('?') {
310        // Strip query params from url
311        url = &url[..qm_i];
312        let qp = request.url()[qm_i + 1..].split('&').collect::<Vec<&str>>();
313        all_together_max += qp.len();
314        query_params = Some(qp);
315    }
316    fn split_key_value_pair(qp: &str) -> Result<(&str, &str), VerifyError> {
317        qp.split_once('=')
318            .ok_or(VerifyError::InvalidKeyValuePair)
319            .map(|(k, v)| (k, v))
320    }
321    // First one starts with "OAuth oauth_callback=..."
322    auth_params_without_signature[0] = &auth_params_without_signature[0]["OAuth ".len()..];
323    let query: Result<Vec<(&str, &str)>, VerifyError> = auth_params_without_signature
324        .into_iter()
325        .map(|qp|
326            split_key_value_pair(qp).map(|(k,v)| (k, &v[1..v.len()-1]))
327        )
328        // Append the query from URL params at the end
329        .chain(
330            if let Some(query_params) = query_params.take() {
331                query_params.into_iter()
332            } else {
333                Vec::new().into_iter()
334            }.map(split_key_value_pair)
335        )
336        .collect();
337    let mut query = query?;
338    query.sort_by(|(a, _), (b, _)| a.cmp(b));
339
340    let query: String = query
341        .iter()
342        .enumerate()
343        .flat_map(|(i, (k, v))| [k, "=", v, if i == all_together_max { &"" } else { &"&" }])
344        .collect();
345
346    // Fix the url provided by reqwest::Request, e.g. being `localhost` instead of `127.0.0.1`
347    let url = url_middleware(url);
348
349    return Ok(check_signature(
350        &provided_signature["oauth_signature=\"".len()..provided_signature.len() - 1],
351        request.method(),
352        &url,
353        &query,
354        consumer_secret,
355        token_secret,
356    ));
357}
358
359/// Checks if the signature created by the given request data is the same
360/// as the provided signature.
361///
362/// See [`check_signature_request`] for a function that automatically
363/// does this with any [`GenericRequest`]
364pub fn check_signature(
365    signature_to_check: &str,
366    method: &str,
367    uri: &str,
368    query: &str,
369    consumer_secret: &str,
370    token_secret: Option<&str>,
371) -> bool {
372    let signature = signature(method, uri, query, consumer_secret, token_secret);
373    let new_encoded_signature = encode(&signature);
374
375    new_encoded_signature == signature_to_check
376}
377
378/// Construct plain-text header.
379///
380/// See https://datatracker.ietf.org/doc/html/rfc5849#section-3.5.1
381fn header(param: &ParamList<'_>) -> String {
382    let mut pairs = param
383        .iter()
384        .filter(|&(k, _)| k.starts_with("oauth_"))
385        .map(|(k, v)| format!("{}=\"{}\"", k, encode(v)))
386        .collect::<Vec<_>>();
387    pairs.sort();
388    format!("OAuth {}", pairs.join(", "))
389}
390
391/// Construct plain-text body from 'ParamList'
392fn body(param: &ParamList<'_>) -> String {
393    let mut pairs = param
394        .iter()
395        .filter(|&(k, _)| !k.starts_with("oauth_"))
396        .map(|(k, v)| format!("{}={}", k, encode(v)))
397        .collect::<Vec<_>>();
398    pairs.sort();
399    pairs.join("&")
400}
401
402/// Create header and body
403fn get_header(
404    method: &str,
405    uri: &str,
406    consumer: &Token<'_>,
407    token: Option<&Token<'_>>,
408    other_param: Option<&ParamList<'_>>,
409) -> (String, String) {
410    let mut param = HashMap::new();
411    let timestamp = format!("{}", OffsetDateTime::now_utc().unix_timestamp());
412    let mut rng = rand::thread_rng();
413    let nonce = iter::repeat(())
414        .map(|()| rng.sample(Alphanumeric))
415        .map(char::from)
416        .take(32)
417        .collect::<String>();
418
419    let _ = insert_param(&mut param, "oauth_consumer_key", consumer.key.to_string());
420    let _ = insert_param(&mut param, "oauth_nonce", nonce);
421    let _ = insert_param(&mut param, "oauth_signature_method", "HMAC-SHA1");
422    let _ = insert_param(&mut param, "oauth_timestamp", timestamp);
423    let _ = insert_param(&mut param, "oauth_version", "1.0");
424    if let Some(tk) = token {
425        let _ = insert_param(&mut param, "oauth_token", tk.key.as_ref());
426    }
427
428    if let Some(ps) = other_param {
429        for (k, v) in ps.iter() {
430            let _ = insert_param(&mut param, k.as_ref(), v.as_ref());
431        }
432    }
433
434    let sign = signature(
435        method,
436        uri,
437        join_query(&param).as_ref(),
438        consumer.secret.as_ref(),
439        token.map(|t| t.secret.as_ref()),
440    );
441    let _ = insert_param(&mut param, "oauth_signature", sign);
442
443    (header(&param), body(&param))
444}
445
446/// Create an authorization header.
447/// See <https://dev.twitter.com/oauth/overview/authorizing-requests>
448///
449/// # Examples
450///
451/// ```
452/// # extern crate oauth_client;
453/// # fn main() {
454/// const REQUEST_TOKEN: &str = "http://oauthbin.com/v1/request-token";
455/// let consumer = oauth_client::Token::new("key", "secret");
456/// let header = oauth_client::authorization_header(
457///   "GET", REQUEST_TOKEN, &consumer, None, None
458/// );
459/// # }
460/// ```
461pub fn authorization_header(
462    method: &str,
463    uri: &str,
464    consumer: &Token<'_>,
465    token: Option<&Token<'_>>,
466    other_param: Option<&ParamList<'_>>,
467) -> (String, String) {
468    get_header(method, uri, consumer, token, other_param)
469}
470
471/// Send authorized GET request to the specified URL.
472/// `consumer` is a consumer token.
473///
474/// # Examples
475///
476/// ```
477/// # use oauth_client::DefaultRequestBuilder;
478/// let REQUEST_TOKEN: &str = "http://oauthbin.com/v1/request-token";
479/// let consumer = oauth_client::Token::new("key", "secret");
480/// let resp = oauth_client::get::<DefaultRequestBuilder>(REQUEST_TOKEN, &consumer, None, None, &()).unwrap();
481/// ```
482pub fn get<RB: RequestBuilder>(
483    uri: &str,
484    consumer: &Token<'_>,
485    token: Option<&Token<'_>>,
486    other_param: Option<&ParamList<'_>>,
487    client: &RB::ClientBuilder,
488) -> Result<RB::ReturnValue, Error> {
489    let (header, body) = get_header("GET", uri, consumer, token, other_param);
490    let req_uri = if !body.is_empty() {
491        format!("{}?{}", uri, body)
492    } else {
493        uri.to_string()
494    };
495
496    let rsp = RB::new(http::Method::GET, &req_uri, client)
497        .header(AUTHORIZATION, header)
498        .send()?;
499    Ok(rsp)
500}
501
502/// Send authorized POST request to the specified URL.
503/// `consumer` is a consumer token.
504///
505/// # Examples
506///
507/// ```
508/// # use oauth_client::DefaultRequestBuilder;
509/// let request = oauth_client::Token::new("key", "secret");
510/// let ACCESS_TOKEN: &'static str = "https://oauthbin.com/v1/access-token";
511/// let consumer = oauth_client::Token::new("key", "secret");
512/// let resp = oauth_client::post::<DefaultRequestBuilder>(ACCESS_TOKEN, &consumer, Some(&request), None, &()).unwrap();
513/// ```
514pub fn post<RB: RequestBuilder>(
515    uri: &str,
516    consumer: &Token<'_>,
517    token: Option<&Token<'_>>,
518    other_param: Option<&ParamList<'_>>,
519    client: &RB::ClientBuilder,
520) -> Result<RB::ReturnValue, Error> {
521    let (header, body) = get_header("POST", uri, consumer, token, other_param);
522
523    RB::new(http::Method::POST, uri, client)
524        .body(body)
525        .header(AUTHORIZATION, header)
526        .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
527        .send()
528}
529
/// Ready-made [`RequestBuilder`] built on a shared, process-wide blocking `reqwest`
/// client. Use it if you don't need a custom HTTP client and are fine with bundling
/// `reqwest`.
#[cfg(feature = "reqwest-blocking")]
#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-blocking")))]
#[derive(Debug)]
pub struct DefaultRequestBuilder {
    /// The `reqwest` request being assembled.
    inner: reqwest::blocking::RequestBuilder,
}
538
#[cfg(feature = "reqwest-blocking")]
#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-blocking")))]
impl RequestBuilder for DefaultRequestBuilder {
    /// The response body, read into a `String`.
    type ReturnValue = String;
    /// Nothing needs to be supplied: the shared global blocking client is used.
    type ClientBuilder = ();

    /// Starts a request on the shared client.
    ///
    /// If the url is wrong then it will fail only during send: an unparsable `url` is
    /// reported as [`Error::HttpRequest`] by [`RequestBuilder::send`] instead of
    /// panicking here.
    fn new(method: http::Method, url: &'_ str, _: &Self::ClientBuilder) -> Self {
        // Handing the `&str` to reqwest (via `IntoUrl`) stores any parse error inside
        // the builder, where `send` surfaces it, instead of `Url::from_str(..).unwrap()`.
        Self {
            inner: CLIENT.request(method, url),
        }
    }

    fn body(mut self, b: String) -> Self {
        self.inner = self.inner.body(b);

        self
    }

    fn header<K, V>(mut self, key: K, val: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
    {
        self.inner = self.inner.header(key, val);

        self
    }

    /// Sends the request and returns the response body.
    ///
    /// # Errors
    ///
    /// * URL, transport or builder errors are reported as [`Error::HttpRequest`].
    /// * Any status other than `200 OK` is reported as [`Error::HttpStatus`].
    /// * A failure while reading the body is reported as [`Error::Io`].
    fn send(self) -> Result<Self::ReturnValue, Error> {
        let mut response = self.inner.send()?;
        if response.status() != StatusCode::OK {
            return Err(Error::HttpStatus(response.status()));
        }
        let mut buf = String::with_capacity(200);
        let _ = response.read_to_string(&mut buf)?;
        Ok(buf)
    }
}
578
579/// A generic request builder. Allows you to use any HTTP client.
580/// See [`DefaultRequestBuilder`] for one that uses [`reqwest::Client`].
581pub trait RequestBuilder: Debug {
582    /// Generic return value allows you to return a future, allowing the possibility
583    /// of using this library in `async` environments.
584    type ReturnValue;
585
586    /// This is useful for reusing existing connection pools.
587    type ClientBuilder;
588
589    /// Construct the request builder
590    fn new(method: http::Method, url: &'_ str, client: &Self::ClientBuilder) -> Self;
591
592    /// Set the body
593    fn body(self, b: String) -> Self;
594
595    /// Set a header
596    fn header<K, V>(self, key: K, val: V) -> Self
597    where
598        HeaderName: TryFrom<K>,
599        HeaderValue: TryFrom<V>,
600        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
601        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>;
602
603    /// A `build`-like function that also sends the request
604    fn send(self) -> Result<Self::ReturnValue, Error>
605    where
606        Self: Sized;
607}
608
609/// Errors possible while using [`parse_query_string`].
610#[derive(Debug, Error)]
611pub enum ParseQueryError {
612    /// You provided more keys than there actually were to parse.
613    /// Empty string.
614    #[error("Not enough key value pairs provided. Was: {0}")]
615    NotEnoughPairs(usize),
616
617    /// Lacks an `=`, or nothing after the `=` sign in some key value pair.
618    #[error("One of the key value pairs was invalid.")]
619    InvalidKeyValuePair,
620}
621
// Unit tests for `parse_query_string`, the `get` -> `check_signature_request` round
// trip, `join_query` and `percent_encode_string` (`encode`).
#[cfg(test)]
mod tests {
    use super::*;
    use log::LevelFilter;
    use std::collections::HashMap;

    // `sort == false`: pairs come back in input order.
    #[test]
    fn parse_dont_sort_doesnt_sort() {
        let input = "b=BBB&a=AAA";
        let [(a_key, a), (b_key, b)] = parse_query_string(input, false, &["a", "b"]).unwrap();
        assert_eq!(a_key, "b");
        assert_eq!(b_key, "a");
        assert_eq!(b, "AAA");
        assert_eq!(a, "BBB");
    }

    // `sort == true`: out-of-order input comes back sorted by key.
    #[test]
    fn parse_sort_out_of_order() {
        let input = "b=BBB&a=AAA";
        let [(a_key, a), (b_key, b)] = parse_query_string(input, true, &["a", "b"]).unwrap();
        assert_eq!(a_key, "a");
        assert_eq!(b_key, "b");
        assert_eq!(a, "AAA");
        assert_eq!(b, "BBB");
    }

    // `sort == true` on already sorted input keeps the order.
    #[test]
    fn parse_sort_already_sorted() {
        let input = "a=AAA&b=BBB";
        let [(a_key, a), (b_key, b)] = parse_query_string(input, true, &["a", "b"]).unwrap();
        assert_eq!(a_key, "a");
        assert_eq!(b_key, "b");
        assert_eq!(a, "AAA");
        assert_eq!(b, "BBB");
    }

    // Asking for a key that is absent ("x") yields `NotEnoughPairs`; the unrequested
    // "b" pair is skipped, not counted.
    #[test]
    fn parse_invalid_keys() {
        let input = "a=AAA&b=BBB";
        match parse_query_string(input, true, &["a", "x"]) {
            Ok(_) => panic!("Should error"),
            Err(e) => match e {
                ParseQueryError::NotEnoughPairs(_) => {}
                _ => panic!("Wrong error"),
            },
        }
    }

    // Empty input is `NotEnoughPairs`, not `InvalidKeyValuePair`.
    #[test]
    fn parse_empty_string() {
        let input = "";
        // These std assertions illustrate why the parser splits with `split_terminator`:
        // plain `split` yields one empty segment for "" (which would be rejected as an
        // invalid pair), while `split_terminator` yields no segments at all.
        assert_eq!("".split('&').collect::<Vec<_>>(), [""]);
        assert_eq!("&".split('&').collect::<Vec<_>>(), ["", ""]);
        assert_eq!(0, "".split_terminator('&').count());
        match parse_query_string(input, true, &["a", "b"]) {
            Ok(_) => panic!("Should error"),
            Err(e) => match e {
                ParseQueryError::NotEnoughPairs(_) => {}
                _ => panic!("Wrong error"),
            },
        }
    }

    // A segment without `=` ("x") is `InvalidKeyValuePair`.
    #[test]
    fn parse_invalid_format() {
        let input = "x&";
        match parse_query_string(input, true, &["x"]) {
            Ok(_) => panic!("Should error"),
            Err(e) => match e {
                ParseQueryError::InvalidKeyValuePair => {}
                _ => panic!("Wrong error"),
            },
        }
    }

    // Round trip: a request built (not sent) by `get` must verify against the same
    // consumer secret. Covers `oauth_*` params in the header, other params in the URL
    // query, and a value containing `?`, `=` and `&` (`oauth_callback`).
    #[test]
    fn check_signature_request_test() {
        // NOTE(review): `init()` presumably fails if a global logger is already
        // installed, so only one test in this binary can do this.
        simple_logger::SimpleLogger::new()
            .with_level(LevelFilter::Trace)
            .init()
            .unwrap();
        // Builder whose `send` only builds the `reqwest::Request`, so nothing hits the network.
        #[derive(Debug)]
        struct DummyRequestBuilder(reqwest::RequestBuilder);

        impl RequestBuilder for DummyRequestBuilder {
            type ReturnValue = reqwest::Request;
            type ClientBuilder = reqwest::Client;

            fn new(method: http::Method, url: &'_ str, client: &Self::ClientBuilder) -> Self {
                Self(client.request(method, url))
            }
            fn body(mut self, b: String) -> Self {
                self.0 = self.0.body(b);

                self
            }
            fn header<K, V>(mut self, key: K, val: V) -> Self
            where
                HeaderName: TryFrom<K>,
                HeaderValue: TryFrom<V>,
                <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
                <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
            {
                self.0 = self.0.header(key, val);

                self
            }
            fn send(self) -> Result<Self::ReturnValue, Error> {
                let rv = self.0.build()?;
                Ok(rv)
            }
        }
        let client = reqwest::Client::new();
        let token = Token::new("key", "secret");
        let param_list = {
            let mut hm: ParamList<'_> = HashMap::with_capacity(2);
            assert!(hm
                .insert(
                    "oauth_any_string".into(),
                    "gets_put_into_auth_header".into()
                )
                .is_none());
            assert!(hm
                .insert(
                    "doesnt_start_with_oauth".into(),
                    "doesnt_get_put_into_auth_header".into()
                )
                .is_none());
            assert!(hm
                .insert("oauth_callback".into(), "http://xd.xy?xy=xz&xd=xx".into())
                .is_none());

            hm
        };
        let request = get::<DummyRequestBuilder>(
            // FIXME: Trailing slash important, otherwise it fails, dunno how to fix
            // (Likely cause: URL parsing normalizes `http://localhost` to
            // `http://localhost/`, so the verifier's URL differs from the signed URI.)
            "http://localhost/",
            &token,
            None,
            Some(&param_list),
            &client,
        )
        .unwrap();
        assert!(check_signature_request(request, &token.secret, None, |u| u.into()).unwrap());
    }

    // `join_query` renders `key=value` pairs, sorts them and joins with `&`.
    #[test]
    fn query() {
        let mut map = HashMap::new();
        let _ = map.insert("aaa".into(), "AAA".into());
        let _ = map.insert("bbbb".into(), "BBBB".into());
        let query = join_query(&map);
        assert_eq!("aaa=AAA&bbbb=BBBB", query);
    }

    // Percent-encoding leaves unreserved characters alone and escapes `:`, `/`, `=`
    // and `&` with uppercase hex.
    #[test]
    fn test_encode() {
        let method = "GET";
        let uri = "http://oauthbin.com/v1/request-token";
        let encoded_uri = "http%3A%2F%2Foauthbin.com%2Fv1%2Frequest-token";
        let query = [
            "oauth_consumer_key=key&",
            "oauth_nonce=s6HGl3GhmsDsmpgeLo6lGtKs7rQEzzsA&",
            "oauth_signature_method=HMAC-SHA1&",
            "oauth_timestamp=1471445561&",
            "oauth_version=1.0",
        ]
        .iter()
        .cloned()
        .collect::<String>();
        let encoded_query = [
            "oauth_consumer_key%3Dkey%26",
            "oauth_nonce%3Ds6HGl3GhmsDsmpgeLo6lGtKs7rQEzzsA%26",
            "oauth_signature_method%3DHMAC-SHA1%26",
            "oauth_timestamp%3D1471445561%26",
            "oauth_version%3D1.0",
        ]
        .iter()
        .cloned()
        .collect::<String>();

        assert_eq!(encode(method), "GET");
        assert_eq!(encode(uri), encoded_uri);
        assert_eq!(encode(&query), encoded_query);
    }
}
808
809use log::warn;
810
811/// Utility function to parse the `Authorization` header from an HTTP request.
812///
813/// Assumptions:
814/// 1. Keys are distinct
815///
816/// Arguments:
817/// 1. Key to search
818/// 2. Whether to sort the return value (for reproducibility).
819///
820///    Set to true if in doubt. If the server changes its order of arguments you'll be fine.
821///
822/// 3. The names of the keys. (If put more than existing, or invalid then error might happen
823///    because we are looking for all provided keys.)
824pub fn parse_query_string<'q, const N: usize>(
825    query_string: &'q str,
826    sort: bool,
827    keys: &[&str; N],
828) -> Result<[(&'q str, &'q str); N], ParseQueryError> {
829    // Create an uninitialized array of `MaybeUninit`. The `assume_init` is
830    // safe because the type we are claiming to have initialized here is a
831    // bunch of `MaybeUninit`s, which do not require initialization.
832    let mut rv: [MaybeUninit<(&str, &str)>; N] = unsafe { MaybeUninit::uninit().assume_init() };
833
834    let mut num_inserted = 0;
835    for kv in query_string.split_terminator('&') {
836        let mut iter = kv.split_terminator('=');
837        let key = iter.next().ok_or(ParseQueryError::InvalidKeyValuePair)?;
838        let val = iter.next().ok_or(ParseQueryError::InvalidKeyValuePair)?;
839
840        if keys.contains(&key) {
841            // Dropping a `MaybeUninit` does nothing. Thus using element
842            // assignment instead of `ptr::write` does not cause the old
843            // uninitialized value to be dropped.
844            rv[num_inserted] = MaybeUninit::new((key, val));
845            num_inserted += 1;
846        } else {
847            warn!("Unexpected key {:?}. (value {:?})", key, val);
848        }
849    }
850
851    if num_inserted < N {
852        return Err(ParseQueryError::NotEnoughPairs(num_inserted));
853    }
854
855    // Everything is initialized. Transmute the array to the
856    // initialized type.
857    let mut rv: [(&str, &str); N] = unsafe { std::mem::transmute_copy(&rv) };
858
859    if sort {
860        // NOTE: Assumption: keys are distinct
861        rv.sort_unstable_by_key(|&(k, _v)| k);
862    }
863
864    Ok(rv)
865}