#![warn(bad_style)]
#![warn(missing_docs)]
#![warn(unused)]
#![warn(unused_extern_crates)]
#![warn(unused_import_braces)]
#![warn(unused_qualifications)]
#![warn(unused_results)]
#![allow(unused_doc_comments)]
#![cfg_attr(docsrs, feature(doc_cfg))]
use http::{
header::{HeaderName, AUTHORIZATION, CONTENT_TYPE},
HeaderValue, StatusCode,
};
use log::debug;
use rand::{distributions::Alphanumeric, Rng};
use ring::hmac;
use std::{borrow::Cow, collections::HashMap, convert::TryFrom, io, iter, mem::MaybeUninit};
use thiserror::Error;
use time::OffsetDateTime;
#[cfg(all(feature = "reqwest-blocking"))]
use ::{
lazy_static::lazy_static,
reqwest::blocking::Client,
std::{io::Read, str::FromStr},
url::Url,
};
#[cfg(feature = "client-reqwest")]
#[cfg_attr(docsrs, doc(cfg(feature = "client-reqwest")))]
pub use reqwest;
use std::fmt::Debug;
/// Errors returned by the request helpers ([`get`], [`post`]) and by
/// [`RequestBuilder::send`] implementations.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
/// The server answered with a status code that the request builder treats as a failure.
#[error("HTTP status error code: {0}")]
HttpStatus(StatusCode),
/// An I/O error, e.g. while reading a response body.
#[error("IO error: {0}")]
Io(#[from] io::Error),
/// The underlying HTTP client reported an error (e.g. a `reqwest::Error`).
#[error("HTTP request error: {0}")]
HttpRequest(Box<dyn std::error::Error + Send + Sync + 'static>),
}
#[cfg(feature = "client-reqwest")]
#[cfg_attr(docsrs, doc(cfg(feature = "client-reqwest")))]
impl From<reqwest::Error> for Error {
    /// Boxes the `reqwest` error and wraps it as [`Error::HttpRequest`].
    fn from(err: reqwest::Error) -> Self {
        let boxed: Box<dyn std::error::Error + Send + Sync> = Box::new(err);
        Error::HttpRequest(boxed)
    }
}
// Process-wide blocking client shared by every `DefaultRequestBuilder`; built lazily on first
// access via `lazy_static!`.
#[cfg(feature = "reqwest-blocking")]
lazy_static! {
static ref CLIENT: Client = Client::new();
}
/// A key/secret credential pair: either the consumer (client) credentials or a token.
#[derive(Clone, Debug)]
pub struct Token<'a> {
/// Public identifier, sent as `oauth_consumer_key` (consumer) or `oauth_token` (token).
pub key: Cow<'a, str>,
/// Shared secret; only used to derive the HMAC signing key.
pub secret: Cow<'a, str>,
}
impl<'a> Token<'a> {
pub fn new<K, S>(key: K, secret: S) -> Token<'a>
where
K: Into<Cow<'a, str>>,
S: Into<Cow<'a, str>>,
{
Token {
key: key.into(),
secret: secret.into(),
}
}
}
/// Extra request parameters. Keys starting with `oauth_` are sent in the `Authorization`
/// header; all other keys go into the query string (GET) or form body (POST). All are signed.
pub type ParamList<'a> = HashMap<Cow<'a, str>, Cow<'a, str>>;
/// Inserts `key = value` into `param`, converting both sides into `Cow<str>`.
///
/// Returns the value previously stored under `key`, if any.
fn insert_param<'a, K, V>(param: &mut ParamList<'a>, key: K, value: V) -> Option<Cow<'a, str>>
where
    K: Into<Cow<'a, str>>,
    V: Into<Cow<'a, str>>,
{
    let (name, val): (Cow<'a, str>, Cow<'a, str>) = (key.into(), value.into());
    param.insert(name, val)
}
/// Builds the normalized parameter string used in the signature base string.
///
/// Every name and value is percent-encoded. The pairs are sorted by encoded name and then by
/// encoded value, as RFC 5849 §3.4.1.3.2 requires. They are then joined as `name=value`
/// separated by `&`.
fn join_query(param: &ParamList<'_>) -> String {
    let mut pairs = param
        .iter()
        .map(|(k, v)| (encode(k), encode(v)))
        .collect::<Vec<_>>();
    // Sort the (name, value) tuples, not the formatted "name=value" strings. Otherwise a name
    // that is a prefix of another (e.g. `a` vs `a1`) is ordered by the byte following it
    // compared against '=', which breaks interoperability with compliant verifiers.
    pairs.sort();
    pairs
        .iter()
        .map(|(k, v)| format!("{}={}", k, v))
        .collect::<Vec<_>>()
        .join("&")
}
/// Characters escaped by [`percent_encode_string`]: everything except ASCII letters, digits,
/// `-`, `.`, `_` and `~`, i.e. the RFC 3986 "unreserved" set that OAuth 1.0 (RFC 5849 §3.6)
/// leaves unencoded.
const STRICT_ENCODE_SET: percent_encoding::AsciiSet = percent_encoding::NON_ALPHANUMERIC
.remove(b'-')
.remove(b'.')
.remove(b'_')
.remove(b'~');
use self::percent_encode_string as encode;
/// Percent-encodes `s` using the OAuth 1.0 rules (see `STRICT_ENCODE_SET`).
///
/// Returns a borrowed `Cow` when `s` contains nothing that needs escaping, and an owned one
/// otherwise.
pub fn percent_encode_string(s: &str) -> Cow<str> {
    // `From<PercentEncode>` only allocates when escaping actually produced more than one chunk.
    // Collecting via `FromIterator` would always allocate a fresh `String`.
    Cow::from(percent_encoding::percent_encode(
        s.as_bytes(),
        &STRICT_ENCODE_SET,
    ))
}
/// Computes the OAuth 1.0 HMAC-SHA1 signature and returns it base64-encoded.
///
/// The signature base string is `METHOD&URI&QUERY`, with each of the three parts
/// percent-encoded. `query` is expected to already be the normalized `name=value&...`
/// parameter string. The signing key is `consumer_secret&token_secret`, each part
/// percent-encoded; an empty token secret is used when `token_secret` is `None`.
///
/// The returned value is *not* percent-encoded.
pub fn signature(
    method: &str,
    uri: &str,
    query: &str,
    consumer_secret: &str,
    token_secret: Option<&str>,
) -> String {
    let base = format!("{}&{}&{}", encode(method), encode(uri), encode(query));
    // The key is built from secrets, so it is deliberately never logged.
    let key = format!(
        "{}&{}",
        encode(consumer_secret),
        encode(token_secret.unwrap_or(""))
    );
    debug!("Signature base string: {}", base);
    let signing_key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, key.as_bytes());
    let signature = hmac::sign(&signing_key, base.as_bytes());
    base64::encode(signature.as_ref())
}
/// Errors returned by [`check_signature_request`] when the request cannot be examined.
/// A signature that simply does not match is not an error; it yields `Ok(false)`.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum VerifyError {
/// The request has no `Authorization` header.
#[error("Authorization header not found")]
NoAuthorizationHeader,
/// The `Authorization` header could not be converted to a string.
#[error("Non ASCII values in header: {0}")]
NonAsciiHeader(#[source] http::header::ToStrError),
/// A parameter (in the header or the URL query string) was not a `key=value` pair.
#[error("Invalid key value pair in query params")]
InvalidKeyValuePair,
}
/// Minimal read-only view of an incoming HTTP request, as needed by [`check_signature_request`].
pub trait GenericRequest {
/// Request headers; the `Authorization` header is read from here.
fn headers(&self) -> &http::header::HeaderMap<HeaderValue>;
/// The request URL; its query string, if any, follows a `?` and is parsed as parameters.
fn url(&self) -> &str;
/// The HTTP method as it should appear in the signature base string (e.g. `"GET"`).
fn method(&self) -> &str;
}
#[cfg(feature = "client-reqwest")]
#[cfg_attr(docsrs, doc(cfg(feature = "client-reqwest")))]
impl GenericRequest for reqwest::Request {
    // Every method delegates to the inherent `reqwest::Request` method of the same name.
    // The calls are spelled out with the type path so the delegation is explicit.
    fn headers(&self) -> &http::header::HeaderMap<HeaderValue> {
        reqwest::Request::headers(self)
    }

    // Prefers the `Host` header when it is present and convertible to `&str`; otherwise falls
    // back to the full request URL.
    // NOTE(review): with a Host header only the host is returned (no scheme, path or query), so
    // `check_signature_request` will see no query parameters. Presumably callers rebuild the
    // URL via `url_middleware`; confirm.
    fn url(&self) -> &str {
        let full_url = reqwest::Request::url(self).as_str();
        match reqwest::Request::headers(self).get("host") {
            Some(host) => host.to_str().unwrap_or(full_url),
            None => full_url,
        }
    }

    fn method(&self) -> &str {
        reqwest::Request::method(self).as_str()
    }
}
/// Verifies the OAuth 1.0 HMAC-SHA1 signature carried in a request's `Authorization` header.
///
/// The signature base string is rebuilt from three parts:
/// * the request method;
/// * the request URL up to its first `?`, passed through `url_middleware`;
/// * the normalized parameter list.
///
/// The parameter list combines every `Authorization` header parameter except
/// `oauth_signature` and `realm` with every `key=value` pair of the URL query string. It is
/// sorted by key, then value, and joined with `&`. Header values are used exactly as sent
/// (still percent-encoded), without their surrounding quotes.
///
/// Returns `Ok(false)` in any of these cases:
/// * the header does not use the `OAuth` scheme;
/// * it carries no `oauth_signature`, or more than one;
/// * the signature does not match.
///
/// # Errors
///
/// * [`VerifyError::NoAuthorizationHeader`] if the request has no `Authorization` header.
/// * [`VerifyError::NonAsciiHeader`] if the header cannot be read as a string.
/// * [`VerifyError::InvalidKeyValuePair`] if a header parameter is not of the form
///   `key="value"`, or a non-empty query-string segment has no `=`.
pub fn check_signature_request<R: GenericRequest>(
    request: R,
    consumer_secret: &str,
    token_secret: Option<&str>,
    url_middleware: impl for<'a> FnOnce(&'a str) -> Cow<'a, str>,
) -> Result<bool, VerifyError> {
    // Splits `key=value` at the first `=`.
    fn split_key_value_pair(qp: &str) -> Result<(&str, &str), VerifyError> {
        qp.split_once('=').ok_or(VerifyError::InvalidKeyValuePair)
    }

    let authorization_header = request
        .headers()
        .get("Authorization")
        .ok_or(VerifyError::NoAuthorizationHeader)?
        .to_str()
        .map_err(VerifyError::NonAsciiHeader)?;

    // Strip the auth scheme before splitting, so the first parameter is parsed like all the
    // others. HTTP auth scheme names are case-insensitive.
    let auth_params = match authorization_header.trim_start().split_once(' ') {
        Some((scheme, rest)) if scheme.eq_ignore_ascii_case("OAuth") => rest,
        _ => return Ok(false),
    };

    let mut provided_signature = None;
    let mut params: Vec<(&str, &str)> = Vec::new();
    for param in auth_params
        .split(',')
        .map(str::trim)
        .filter(|p| !p.is_empty())
    {
        let (key, quoted) = split_key_value_pair(param)?;
        // Header values must be quoted; reject rather than slice blindly (which could panic).
        let value = quoted
            .strip_prefix('"')
            .and_then(|v| v.strip_suffix('"'))
            .ok_or(VerifyError::InvalidKeyValuePair)?;
        match key {
            "oauth_signature" => {
                // Several signatures are ambiguous; refuse to pick one.
                if provided_signature.replace(value).is_some() {
                    return Ok(false);
                }
            }
            // RFC 5849 §3.4.1.3.1: `realm` is excluded from the signature base string.
            "realm" => {}
            _ => params.push((key, value)),
        }
    }
    let provided_signature = match provided_signature {
        Some(sig) => sig,
        None => return Ok(false),
    };

    // The query string starts at the first `?`; everything before it is the base string URI.
    let full_url = request.url();
    let (url, query_string) = full_url.split_once('?').unwrap_or((full_url, ""));
    // Empty segments (bare `?`, trailing or doubled `&`) carry no parameter.
    for pair in query_string.split('&').filter(|p| !p.is_empty()) {
        params.push(split_key_value_pair(pair)?);
    }

    // RFC 5849 §3.4.1.3.2: sort by name, then by value.
    params.sort_unstable();
    let query = params
        .iter()
        .map(|(k, v)| format!("{}={}", k, v))
        .collect::<Vec<_>>()
        .join("&");

    let url = url_middleware(url);
    Ok(check_signature(
        provided_signature,
        request.method(),
        &url,
        &query,
        consumer_secret,
        token_secret,
    ))
}
/// Returns `true` if `signature_to_check` equals the percent-encoded signature that
/// [`signature`] computes from the remaining arguments.
///
/// `signature_to_check` is expected in percent-encoded form, exactly as it appears inside the
/// quotes of the `oauth_signature` header parameter.
pub fn check_signature(
    signature_to_check: &str,
    method: &str,
    uri: &str,
    query: &str,
    consumer_secret: &str,
    token_secret: Option<&str>,
) -> bool {
    let raw = signature(method, uri, query, consumer_secret, token_secret);
    let expected = encode(&raw);
    let (expected, provided) = (expected.as_bytes(), signature_to_check.as_bytes());
    // The expected length is fixed (encoded HMAC-SHA1 output), so rejecting a length mismatch
    // early reveals nothing about the secret.
    if expected.len() != provided.len() {
        return false;
    }
    // Accumulate differences over every byte instead of stopping at the first mismatch. This
    // keeps the comparison time from revealing how much of a forged signature was correct.
    expected
        .iter()
        .zip(provided)
        .fold(0u8, |acc, (a, b)| acc | (a ^ b))
        == 0
}
/// Formats the `Authorization` header value from the parameters whose names start with `oauth_`.
///
/// Each parameter becomes `name="value"`, with both name and value percent-encoded
/// (RFC 5849 §3.5.1). The entries are sorted, joined with `", "` and prefixed with `OAuth `.
fn header(param: &ParamList<'_>) -> String {
    let mut pairs = param
        .iter()
        .filter(|&(k, _)| k.starts_with("oauth_"))
        .map(|(k, v)| format!("{}=\"{}\"", encode(k), encode(v)))
        .collect::<Vec<_>>();
    pairs.sort();
    format!("OAuth {}", pairs.join(", "))
}
/// Formats the parameters whose names do *not* start with `oauth_` as `name=value` pairs
/// joined by `&`, for use as a query string or form body.
///
/// Names and values are both percent-encoded. This way a name containing `&`, `=` or spaces
/// cannot corrupt the resulting string, and the names match the encoded form that
/// `join_query` signs. Pairs are sorted to keep the output deterministic.
fn body(param: &ParamList<'_>) -> String {
    let mut pairs = param
        .iter()
        .filter(|&(k, _)| !k.starts_with("oauth_"))
        .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
        .collect::<Vec<_>>();
    pairs.sort();
    pairs.join("&")
}
/// Builds the signed `Authorization` header value and the non-OAuth parameter string.
///
/// 1. Generates a timestamp (Unix seconds, UTC) and a 32-character alphanumeric nonce.
/// 2. Combines them with the fixed `oauth_*` protocol parameters, the optional `oauth_token`,
///    and `other_param`. `other_param` is inserted after the others, so it overrides any
///    same-named entry.
/// 3. Signs the whole set with HMAC-SHA1 and adds `oauth_signature`.
///
/// Returns `(header(..), body(..))`.
fn get_header(
    method: &str,
    uri: &str,
    consumer: &Token<'_>,
    token: Option<&Token<'_>>,
    other_param: Option<&ParamList<'_>>,
) -> (String, String) {
    let mut param = HashMap::new();
    let timestamp = OffsetDateTime::now_utc().unix_timestamp().to_string();
    let mut rng = rand::thread_rng();
    let nonce = iter::repeat(())
        .map(|()| rng.sample(Alphanumeric))
        .map(char::from)
        .take(32)
        .collect::<String>();
    let _ = insert_param(&mut param, "oauth_consumer_key", consumer.key.to_string());
    let _ = insert_param(&mut param, "oauth_nonce", nonce);
    let _ = insert_param(&mut param, "oauth_signature_method", "HMAC-SHA1");
    let _ = insert_param(&mut param, "oauth_timestamp", timestamp);
    let _ = insert_param(&mut param, "oauth_version", "1.0");
    if let Some(tk) = token {
        let _ = insert_param(&mut param, "oauth_token", tk.key.as_ref());
    }
    if let Some(ps) = other_param {
        for (k, v) in ps.iter() {
            let _ = insert_param(&mut param, k.as_ref(), v.as_ref());
        }
    }
    let sign = signature(
        method,
        uri,
        &join_query(&param),
        consumer.secret.as_ref(),
        token.map(|t| t.secret.as_ref()),
    );
    // The signature is added after signing, so it is never part of its own base string.
    let _ = insert_param(&mut param, "oauth_signature", sign);
    (header(&param), body(&param))
}
/// Computes a signed OAuth 1.0 `Authorization` header without sending anything.
///
/// Returns `(header, params)`:
/// * `header` is the complete `OAuth ...` header value, signature included.
/// * `params` holds the `other_param` entries whose names do not start with `oauth_`, joined
///   as `name=value&...` with values percent-encoded. It is meant for a query string or form
///   body.
pub fn authorization_header(
    method: &str,
    uri: &str,
    consumer: &Token<'_>,
    token: Option<&Token<'_>>,
    other_param: Option<&ParamList<'_>>,
) -> (String, String) {
    let (header_value, extra_params) = get_header(method, uri, consumer, token, other_param);
    (header_value, extra_params)
}
/// Sends a signed GET request to `uri` through the request builder `RB`.
///
/// Parameters from `other_param` whose names do not start with `oauth_` are appended to `uri`
/// as a query string. The `oauth_*` parameters and the signature are sent in the
/// `Authorization` header.
///
/// # Errors
///
/// Whatever `RB::send` reports.
pub fn get<RB: RequestBuilder>(
    uri: &str,
    consumer: &Token<'_>,
    token: Option<&Token<'_>>,
    other_param: Option<&ParamList<'_>>,
    client: &RB::ClientBuilder,
) -> Result<RB::ReturnValue, Error> {
    let (auth_value, query) = get_header("GET", uri, consumer, token, other_param);
    // NOTE(review): the separator is always `?`, so this assumes `uri` has no query string of
    // its own.
    let target = if query.is_empty() {
        uri.to_string()
    } else {
        format!("{}?{}", uri, query)
    };
    RB::new(http::Method::GET, &target, client)
        .header(AUTHORIZATION, auth_value)
        .send()
}
/// Sends a signed POST request to `uri` through the request builder `RB`.
///
/// Parameters from `other_param` whose names do not start with `oauth_` are sent as an
/// `application/x-www-form-urlencoded` body. The `oauth_*` parameters and the signature are
/// sent in the `Authorization` header.
///
/// # Errors
///
/// Whatever `RB::send` reports.
pub fn post<RB: RequestBuilder>(
    uri: &str,
    consumer: &Token<'_>,
    token: Option<&Token<'_>>,
    other_param: Option<&ParamList<'_>>,
    client: &RB::ClientBuilder,
) -> Result<RB::ReturnValue, Error> {
    let (auth_value, form) = get_header("POST", uri, consumer, token, other_param);
    let request = RB::new(http::Method::POST, uri, client).body(form);
    request
        .header(AUTHORIZATION, auth_value)
        .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
        .send()
}
/// Blocking [`RequestBuilder`] backed by a process-wide `reqwest::blocking::Client`.
///
/// `send()` returns the response body as a `String`. Any status other than `200 OK` is
/// reported as [`Error::HttpStatus`].
#[cfg(feature = "reqwest-blocking")]
#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-blocking")))]
#[derive(Debug)]
pub struct DefaultRequestBuilder {
    inner: reqwest::blocking::RequestBuilder,
}

#[cfg(feature = "reqwest-blocking")]
#[cfg_attr(docsrs, doc(cfg(feature = "reqwest-blocking")))]
impl RequestBuilder for DefaultRequestBuilder {
    type ReturnValue = String;
    type ClientBuilder = ();

    /// Starts a request on the shared client. An unparsable `url` does not panic; the parse
    /// error is reported by [`RequestBuilder::send`] instead.
    fn new(method: http::Method, url: &'_ str, _: &Self::ClientBuilder) -> Self {
        let rb = match Url::from_str(url) {
            Ok(parsed) => CLIENT.request(method, parsed),
            // Hand the raw string to reqwest. It re-parses it and records the failure inside
            // the builder, which surfaces as a `reqwest::Error` from `send()`.
            Err(_) => CLIENT.request(method, url),
        };
        Self { inner: rb }
    }

    fn body(mut self, b: String) -> Self {
        self.inner = self.inner.body(b);
        self
    }

    fn header<K, V>(mut self, key: K, val: V) -> Self
    where
        HeaderName: TryFrom<K>,
        HeaderValue: TryFrom<V>,
        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
    {
        self.inner = self.inner.header(key, val);
        self
    }

    fn send(self) -> Result<Self::ReturnValue, Error> {
        let mut response = self.inner.send()?;
        if response.status() != StatusCode::OK {
            return Err(Error::HttpStatus(response.status()));
        }
        let mut buf = String::with_capacity(200);
        let _ = response.read_to_string(&mut buf)?;
        Ok(buf)
    }
}
/// Abstraction over an HTTP client's request builder, so that [`get`] and [`post`] can drive
/// any client implementation.
pub trait RequestBuilder: Debug {
/// What [`RequestBuilder::send`] yields on success (e.g. a response body or a built request).
type ReturnValue;
/// Client handle (or configuration) passed to [`RequestBuilder::new`].
type ClientBuilder;
/// Starts a request with the given method and URL using `client`.
fn new(method: http::Method, url: &'_ str, client: &Self::ClientBuilder) -> Self;
/// Sets the request body.
fn body(self, b: String) -> Self;
/// Adds a header to the request.
fn header<K, V>(self, key: K, val: V) -> Self
where
HeaderName: TryFrom<K>,
HeaderValue: TryFrom<V>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>;
/// Sends (or otherwise finalizes) the request.
fn send(self) -> Result<Self::ReturnValue, Error>
where
Self: Sized;
}
/// Errors returned by [`parse_query_string`].
#[derive(Debug, Error)]
pub enum ParseQueryError {
/// Fewer pairs with expected keys were found than requested; holds how many were found.
#[error("Not enough key value pairs provided. Was: {0}")]
NotEnoughPairs(usize),
/// A segment of the query string was not of the form `key=value`.
#[error("One of the key value pairs was invalid.")]
InvalidKeyValuePair,
}
#[cfg(test)]
mod tests {
    use super::*;
    use log::LevelFilter;
    use std::collections::HashMap;

    #[test]
    fn parse_dont_sort_doesnt_sort() {
        let input = "b=BBB&a=AAA";
        let [(a_key, a), (b_key, b)] = parse_query_string(input, false, &["a", "b"]).unwrap();
        assert_eq!(a_key, "b");
        assert_eq!(b_key, "a");
        assert_eq!(b, "AAA");
        assert_eq!(a, "BBB");
    }

    #[test]
    fn parse_sort_out_of_order() {
        let input = "b=BBB&a=AAA";
        let [(a_key, a), (b_key, b)] = parse_query_string(input, true, &["a", "b"]).unwrap();
        assert_eq!(a_key, "a");
        assert_eq!(b_key, "b");
        assert_eq!(a, "AAA");
        assert_eq!(b, "BBB");
    }

    #[test]
    fn parse_sort_already_sorted() {
        let input = "a=AAA&b=BBB";
        let [(a_key, a), (b_key, b)] = parse_query_string(input, true, &["a", "b"]).unwrap();
        assert_eq!(a_key, "a");
        assert_eq!(b_key, "b");
        assert_eq!(a, "AAA");
        assert_eq!(b, "BBB");
    }

    #[test]
    fn parse_invalid_keys() {
        let input = "a=AAA&b=BBB";
        match parse_query_string(input, true, &["a", "x"]) {
            Ok(_) => panic!("Should error"),
            Err(e) => match e {
                ParseQueryError::NotEnoughPairs(_) => {}
                _ => panic!("Wrong error"),
            },
        }
    }

    #[test]
    fn parse_empty_string() {
        let input = "";
        assert_eq!("".split('&').collect::<Vec<_>>(), [""]);
        assert_eq!("&".split('&').collect::<Vec<_>>(), ["", ""]);
        assert_eq!(0, "".split_terminator('&').count());
        match parse_query_string(input, true, &["a", "b"]) {
            Ok(_) => panic!("Should error"),
            Err(e) => match e {
                ParseQueryError::NotEnoughPairs(_) => {}
                _ => panic!("Wrong error"),
            },
        }
    }

    #[test]
    fn parse_invalid_format() {
        let input = "x&";
        match parse_query_string(input, true, &["x"]) {
            Ok(_) => panic!("Should error"),
            Err(e) => match e {
                ParseQueryError::InvalidKeyValuePair => {}
                _ => panic!("Wrong error"),
            },
        }
    }

    #[test]
    fn check_signature_request_test() {
        simple_logger::SimpleLogger::new()
            .with_level(LevelFilter::Trace)
            .init()
            .unwrap();

        // Request builder that only builds the request, so the signed request can be checked
        // locally without any network access.
        #[derive(Debug)]
        struct DummyRequestBuilder(reqwest::RequestBuilder);
        impl RequestBuilder for DummyRequestBuilder {
            type ReturnValue = reqwest::Request;
            type ClientBuilder = reqwest::Client;
            fn new(method: http::Method, url: &'_ str, client: &Self::ClientBuilder) -> Self {
                Self(client.request(method, url))
            }
            fn body(mut self, b: String) -> Self {
                self.0 = self.0.body(b);
                self
            }
            fn header<K, V>(mut self, key: K, val: V) -> Self
            where
                HeaderName: TryFrom<K>,
                HeaderValue: TryFrom<V>,
                <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
                <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
            {
                self.0 = self.0.header(key, val);
                self
            }
            fn send(self) -> Result<Self::ReturnValue, Error> {
                let rv = self.0.build()?;
                Ok(rv)
            }
        }

        let client = reqwest::Client::new();
        let token = Token::new("key", "secret");
        let param_list = {
            let mut hm: ParamList<'_> = HashMap::with_capacity(2);
            assert!(hm
                .insert(
                    "oauth_any_string".into(),
                    "gets_put_into_auth_header".into()
                )
                .is_none());
            assert!(hm
                .insert(
                    "doesnt_start_with_oauth".into(),
                    "doesnt_get_put_into_auth_header".into()
                )
                .is_none());
            assert!(hm
                .insert("oauth_callback".into(), "http://xd.xy?xy=xz&xd=xx".into())
                .is_none());
            hm
        };
        let request = get::<DummyRequestBuilder>(
            "http://localhost/",
            &token,
            None,
            Some(&param_list),
            &client,
        )
        .unwrap();
        assert!(check_signature_request(request, &token.secret, None, |u| u.into()).unwrap());
    }

    #[test]
    fn query() {
        let mut map = HashMap::new();
        let _ = map.insert("aaa".into(), "AAA".into());
        let _ = map.insert("bbbb".into(), "BBBB".into());
        let query = join_query(&map);
        assert_eq!("aaa=AAA&bbbb=BBBB", query);
    }

    #[test]
    fn test_encode() {
        let method = "GET";
        let uri = "http://oauthbin.com/v1/request-token";
        let encoded_uri = "http%3A%2F%2Foauthbin.com%2Fv1%2Frequest-token";
        let query = [
            "oauth_consumer_key=key&",
            "oauth_nonce=s6HGl3GhmsDsmpgeLo6lGtKs7rQEzzsA&",
            "oauth_signature_method=HMAC-SHA1&",
            "oauth_timestamp=1471445561&",
            "oauth_version=1.0",
        ]
        .iter()
        .cloned()
        .collect::<String>();
        let encoded_query = [
            "oauth_consumer_key%3Dkey%26",
            "oauth_nonce%3Ds6HGl3GhmsDsmpgeLo6lGtKs7rQEzzsA%26",
            "oauth_signature_method%3DHMAC-SHA1%26",
            "oauth_timestamp%3D1471445561%26",
            "oauth_version%3D1.0",
        ]
        .iter()
        .cloned()
        .collect::<String>();
        assert_eq!(encode(method), "GET");
        assert_eq!(encode(uri), encoded_uri);
        assert_eq!(encode(&query), encoded_query);
    }
}
use log::warn;
/// Parses `query_string` (`k=v&k=v...`) into exactly one `(key, value)` pair for each entry of
/// `keys`, which should be distinct.
///
/// Parsing rules:
/// * Only the first `=` of a segment separates key from value, so values may contain `=`.
/// * No percent-decoding is done.
/// * Pairs whose key is not in `keys` are logged and skipped.
/// * Repeats of a key that was already collected are also logged and skipped; the first
///   occurrence wins.
///
/// Output order: when `sort` is `false`, pairs appear in the order they occur in
/// `query_string`. When `sort` is `true`, they are sorted by key.
///
/// # Errors
///
/// * [`ParseQueryError::InvalidKeyValuePair`] if a `&`-separated segment contains no `=`.
/// * [`ParseQueryError::NotEnoughPairs`] if fewer than `N` expected keys were found.
pub fn parse_query_string<'q, const N: usize>(
    query_string: &'q str,
    sort: bool,
    keys: &[&str; N],
) -> Result<[(&'q str, &'q str); N], ParseQueryError> {
    // `(&str, &str)` is `Copy`, so a placeholder-filled array replaces the previous
    // MaybeUninit/transmute_copy code without any `unsafe`.
    let mut rv: [(&'q str, &'q str); N] = [("", ""); N];
    let mut num_inserted = 0;
    for kv in query_string.split_terminator('&') {
        let (key, val) = kv
            .split_once('=')
            .ok_or(ParseQueryError::InvalidKeyValuePair)?;
        if !keys.contains(&key) {
            warn!("Unexpected key {:?}. (value {:?})", key, val);
        } else if rv[..num_inserted].iter().any(|&(k, _)| k == key) {
            // Skipping repeats means every stored key is a distinct member of `keys`.
            // `num_inserted` therefore never exceeds `N`, and the index below cannot go out
            // of bounds on hostile input.
            warn!("Duplicate key {:?}. (value {:?})", key, val);
        } else {
            rv[num_inserted] = (key, val);
            num_inserted += 1;
        }
    }
    if num_inserted < N {
        return Err(ParseQueryError::NotEnoughPairs(num_inserted));
    }
    if sort {
        rv.sort_unstable_by_key(|&(k, _v)| k);
    }
    Ok(rv)
}