#![warn(bad_style)]
#![warn(missing_docs)]
#![warn(unused)]
#![warn(unused_extern_crates)]
#![warn(unused_import_braces)]
#![warn(unused_qualifications)]
#![warn(unused_results)]
#![allow(unused_doc_comments)]
use lazy_static::lazy_static;
use log::debug;
use rand::{distributions::Alphanumeric, Rng};
use reqwest::{
blocking::{Client, RequestBuilder},
header::{AUTHORIZATION, CONTENT_TYPE},
StatusCode,
};
use ring::hmac;
use std::{
borrow::Cow,
collections::HashMap,
io::{self, Read},
iter,
};
use thiserror::Error;
use time::OffsetDateTime;
/// Result type returned by this crate's request functions, using [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// Re-export of the `reqwest` crate this library is built on, so callers can name its
/// types (e.g. the `StatusCode` carried by [`Error::HttpStatus`]) without declaring a
/// separate dependency.
pub use reqwest;
/// Errors returned by [`get`] and [`post`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
    /// The server answered with a status code other than `200 OK`.
    #[error("HTTP status error code: {0}")]
    HttpStatus(StatusCode),
    /// An I/O error, e.g. while reading the response body.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
    /// An error reported by `reqwest` while sending the request.
    #[error("reqwest error: {0}")]
    Reqwest(#[from] reqwest::Error),
}
// Single blocking HTTP client shared by `get` and `post`; it is constructed lazily,
// the first time either function dereferences `CLIENT`.
lazy_static! {
static ref CLIENT: Client = Client::new();
}
/// A key/secret credential pair.
///
/// The same type is used both for the consumer (application) credentials and for the
/// optional request/access token passed to the request functions.
#[derive(Clone, Debug)]
pub struct Token<'a> {
    /// Public part of the credential, sent as `oauth_consumer_key` (consumer) or
    /// `oauth_token` (token).
    pub key: Cow<'a, str>,
    /// Secret part of the credential. It is only used to build the HMAC signing key
    /// and is never placed in a request.
    pub secret: Cow<'a, str>,
}

impl<'a> Token<'a> {
    /// Creates a token from a key and a secret.
    ///
    /// Both arguments accept anything convertible into `Cow<'a, str>`. A borrowed
    /// `&str` is stored without copying; an owned `String` is moved in.
    pub fn new<K, S>(key: K, secret: S) -> Token<'a>
    where
        K: Into<Cow<'a, str>>,
        S: Into<Cow<'a, str>>,
    {
        Token {
            key: key.into(),
            secret: secret.into(),
        }
    }
}
/// Additional request parameters, mapping parameter names to values.
///
/// Because this is a `HashMap`, each parameter name can appear at most once.
pub type ParamList<'a> = HashMap<Cow<'a, str>, Cow<'a, str>>;

/// Inserts `key = value` into `param`, converting both into `Cow<str>`.
///
/// Returns the value previously stored under `key`, if any (it is replaced).
fn insert_param<'a, K, V>(param: &mut ParamList<'a>, key: K, value: V) -> Option<Cow<'a, str>>
where
    K: Into<Cow<'a, str>>,
    V: Into<Cow<'a, str>>,
{
    param.insert(key.into(), value.into())
}
/// Builds the normalized parameter string that goes into the signature base string.
///
/// Every name and value is percent-encoded. The pairs are then sorted by encoded name,
/// and by encoded value on ties, using byte ordering. Finally they are joined as
/// `name=value` with `&`.
///
/// The `(name, value)` tuples are sorted rather than the joined `"name=value"` strings.
/// With string sorting, `'='` would be compared against the following name bytes, and
/// it sorts after the digits, `'-'`, `'.'` and `'%'`. Keys such as `a` / `a1` would then
/// come out as `a1=…&a=…`, producing a signature the server rejects.
fn join_query(param: &ParamList<'_>) -> String {
    let mut pairs = param
        .iter()
        .map(|(k, v)| (encode(k), encode(v)))
        .collect::<Vec<_>>();
    pairs.sort();
    pairs
        .iter()
        .map(|(k, v)| format!("{}={}", k, v))
        .collect::<Vec<_>>()
        .join("&")
}
/// Bytes that `encode` escapes: every non-alphanumeric ASCII byte except the
/// RFC 3986 "unreserved" marks `~`, `_`, `.` and `-`. Non-ASCII bytes are always escaped.
const STRICT_ENCODE_SET: percent_encoding::AsciiSet = percent_encoding::NON_ALPHANUMERIC
    .remove(b'~')
    .remove(b'_')
    .remove(b'.')
    .remove(b'-');

/// Percent-encodes the UTF-8 bytes of `s` with `STRICT_ENCODE_SET`.
///
/// The result contains only ASCII letters, digits, `-`, `.`, `_`, `~` and `%XX` escapes.
fn encode(s: &str) -> String {
    let escaped = percent_encoding::percent_encode(s.as_bytes(), &STRICT_ENCODE_SET);
    escaped.to_string()
}
/// Computes the base64-encoded `HMAC-SHA1` signature of a request.
///
/// * The signature base string is `METHOD&encoded(uri)&encoded(query)`, where `query`
///   is the already-normalized parameter string.
/// * The signing key is `encoded(consumer_secret)&encoded(token_secret)`; a missing
///   token secret contributes an empty string, but the `&` is always present.
fn signature(
    method: &str,
    uri: &str,
    query: &str,
    consumer_secret: &str,
    token_secret: Option<&str>,
) -> String {
    let base = format!("{}&{}&{}", encode(method), encode(uri), encode(query));
    let key = format!(
        "{}&{}",
        encode(consumer_secret),
        encode(token_secret.unwrap_or(""))
    );
    // Only the base string is logged. The previous second line printed the base
    // string again under an "Authorization header" label, which was misleading.
    // The signing key is deliberately never logged because it contains the secrets.
    debug!("Signature base string: {}", base);
    let signing_key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, key.as_bytes());
    let signature = hmac::sign(&signing_key, base.as_bytes());
    base64::encode(signature.as_ref())
}
/// Formats the `Authorization` header value from the `oauth_`-prefixed parameters.
///
/// The output looks like `OAuth name1="value1", name2="value2"`, with the entries sorted.
/// Both names and values are percent-encoded, matching how they were signed in
/// `join_query`. In practice `oauth_` names are plain ASCII, so encoding them is a
/// no-op for well-formed input.
fn header(param: &ParamList<'_>) -> String {
    let mut pairs = param
        .iter()
        .filter(|&(k, _)| k.starts_with("oauth_"))
        .map(|(k, v)| format!("{}=\"{}\"", encode(k), encode(v)))
        .collect::<Vec<_>>();
    pairs.sort();
    format!("OAuth {}", pairs.join(", "))
}
/// Serializes the non-`oauth_` parameters as `name=value` pairs joined by `&`.
///
/// The result is used as the form-encoded `POST` body or the `GET` query string. Names
/// are percent-encoded as well as values. Otherwise a name containing `&`, `=` or a
/// space would corrupt the string and no longer match the encoded name that
/// `join_query` signed.
fn body(param: &ParamList<'_>) -> String {
    let mut pairs = param
        .iter()
        .filter(|&(k, _)| !k.starts_with("oauth_"))
        .map(|(k, v)| format!("{}={}", encode(k), encode(v)))
        .collect::<Vec<_>>();
    pairs.sort();
    pairs.join("&")
}
/// Builds and signs the OAuth parameters for one request.
///
/// The following protocol parameters are collected:
/// * `oauth_consumer_key`
/// * a fresh 32-character alphanumeric `oauth_nonce`
/// * `oauth_signature_method=HMAC-SHA1`
/// * the current Unix `oauth_timestamp`
/// * `oauth_version=1.0`
/// * `oauth_token`, when `token` is given
///
/// `other_param` is merged in afterwards, so it may override any of the above. The
/// whole set is signed, and the result is added as `oauth_signature`.
///
/// Returns `(authorization_header_value, body)`. The body holds the non-`oauth_`
/// parameters joined by `&`.
fn get_header(
    method: &str,
    uri: &str,
    consumer: &Token<'_>,
    token: Option<&Token<'_>>,
    other_param: Option<&ParamList<'_>>,
) -> (String, String) {
    let mut param = HashMap::new();
    let timestamp = OffsetDateTime::now_utc().unix_timestamp().to_string();
    let mut rng = rand::thread_rng();
    let nonce = iter::repeat(())
        .map(|()| rng.sample(Alphanumeric))
        .map(char::from)
        .take(32)
        .collect::<String>();
    let _ = insert_param(&mut param, "oauth_consumer_key", consumer.key.to_string());
    let _ = insert_param(&mut param, "oauth_nonce", nonce);
    let _ = insert_param(&mut param, "oauth_signature_method", "HMAC-SHA1");
    let _ = insert_param(&mut param, "oauth_timestamp", timestamp);
    let _ = insert_param(&mut param, "oauth_version", "1.0");
    if let Some(tk) = token {
        let _ = insert_param(&mut param, "oauth_token", tk.key.as_ref());
    }
    if let Some(ps) = other_param {
        for (k, v) in ps.iter() {
            let _ = insert_param(&mut param, k.as_ref(), v.as_ref());
        }
    }
    // Everything collected so far, both oauth_ and extra parameters, is signed.
    let sign = signature(
        method,
        uri,
        &join_query(&param),
        consumer.secret.as_ref(),
        token.map(|t| t.secret.as_ref()),
    );
    let _ = insert_param(&mut param, "oauth_signature", sign);
    (header(&param), body(&param))
}
/// Builds the signed OAuth `Authorization` header value and the parameter string for a
/// request, without sending anything.
///
/// * `method` is the HTTP method used in the signature base string (e.g. `"GET"`).
/// * `uri` is the request URI without a query string.
/// * `consumer` is the application credential.
/// * `token` is the optional user credential.
/// * `other_param` holds any extra request parameters; they are included in the
///   signature.
///
/// Returns `(header, body)`:
/// * `header` has the form `OAuth oauth_consumer_key="…", …, oauth_signature="…"`.
/// * `body` is the `&`-joined, percent-encoded non-`oauth_` parameters, to be sent as
///   the query string or form body.
///
/// Each call generates a new nonce and timestamp, so the output differs between calls.
pub fn authorization_header(
    method: &str,
    uri: &str,
    consumer: &Token<'_>,
    token: Option<&Token<'_>>,
    other_param: Option<&ParamList<'_>>,
) -> (String, String) {
    get_header(method, uri, consumer, token, other_param)
}
/// Sends a signed `GET` request to `uri` and returns the raw response body.
///
/// Parameters in `other_param` that do not start with `oauth_` are appended to `uri`
/// as a query string. Every parameter takes part in the signature. `uri` itself should
/// not already contain a query string.
///
/// # Errors
///
/// * [`Error::Reqwest`] if the request cannot be sent.
/// * [`Error::HttpStatus`] if the response status is anything other than `200 OK`.
/// * [`Error::Io`] if reading the response body fails.
pub fn get(
    uri: &str,
    consumer: &Token<'_>,
    token: Option<&Token<'_>>,
    other_param: Option<&ParamList<'_>>,
) -> Result<Vec<u8>> {
    let (header, body) = get_header("GET", uri, consumer, token, other_param);
    let req_uri = if body.is_empty() {
        uri.to_string()
    } else {
        format!("{}?{}", uri, body)
    };
    send(CLIENT.get(&req_uri).header(AUTHORIZATION, header))
}
/// Sends a signed `POST` request to `uri` and returns the raw response body.
///
/// Parameters in `other_param` that do not start with `oauth_` are sent as an
/// `application/x-www-form-urlencoded` body. Every parameter takes part in the
/// signature.
///
/// # Errors
///
/// * [`Error::Reqwest`] if the request cannot be sent.
/// * [`Error::HttpStatus`] if the response status is anything other than `200 OK`.
/// * [`Error::Io`] if reading the response body fails.
pub fn post(
    uri: &str,
    consumer: &Token<'_>,
    token: Option<&Token<'_>>,
    other_param: Option<&ParamList<'_>>,
) -> Result<Vec<u8>> {
    let (header, body) = get_header("POST", uri, consumer, token, other_param);
    send(
        CLIENT
            .post(uri)
            .body(body)
            .header(AUTHORIZATION, header)
            .header(CONTENT_TYPE, "application/x-www-form-urlencoded"),
    )
}
fn send(builder: RequestBuilder) -> Result<Vec<u8>> {
let mut response = builder.send()?;
if response.status() != StatusCode::OK {
return Err(Error::HttpStatus(response.status()));
}
let mut buf = vec![];
let _ = response.read_to_end(&mut buf)?;
Ok(buf)
}
#[cfg(test)]
mod tests {
    use super::{encode, join_query};

    /// Parameters are serialized as sorted, `&`-joined `name=value` pairs.
    #[test]
    fn query() {
        let mut params = super::ParamList::new();
        for &(name, value) in &[("aaa", "AAA"), ("bbbb", "BBBB")] {
            let _ = params.insert(name.into(), value.into());
        }
        assert_eq!("aaa=AAA&bbbb=BBBB", join_query(&params));
    }

    /// Reserved characters (`:`, `/`, `=`, `&`) are escaped; unreserved ones are kept.
    #[test]
    fn test_encode() {
        let uri = "http://oauthbin.com/v1/request-token";
        let query = concat!(
            "oauth_consumer_key=key&",
            "oauth_nonce=s6HGl3GhmsDsmpgeLo6lGtKs7rQEzzsA&",
            "oauth_signature_method=HMAC-SHA1&",
            "oauth_timestamp=1471445561&",
            "oauth_version=1.0",
        );
        let expected_query = concat!(
            "oauth_consumer_key%3Dkey%26",
            "oauth_nonce%3Ds6HGl3GhmsDsmpgeLo6lGtKs7rQEzzsA%26",
            "oauth_signature_method%3DHMAC-SHA1%26",
            "oauth_timestamp%3D1471445561%26",
            "oauth_version%3D1.0",
        );
        assert_eq!(encode("GET"), "GET");
        assert_eq!(
            encode(uri),
            "http%3A%2F%2Foauthbin.com%2Fv1%2Frequest-token"
        );
        assert_eq!(encode(query), expected_query);
    }
}