use std::{collections::HashMap, convert::TryFrom, fmt, time::Duration};
use http::{header::AUTHORIZATION, Method};
use oauth1_request::signature_method::HmacSha1 as DefaultSM;
use oauth1_request::signature_method::SignatureMethod;
use reqwest::{
header::HeaderMap, header::HeaderName, header::HeaderValue, multipart, Body,
Client as RequwestClient, IntoUrl, RequestBuilder as ReqwestRequestBuilder, Response,
};
use serde::Serialize;
use url::Url;
use crate::{
Error, OAuthParameters, SecretsProvider, SignResult, Signer, OAUTH_KEY_PREFIX, REALM_KEY,
};
/// OAuth 1.0-aware wrapper around `reqwest::RequestBuilder`.
///
/// Builder calls are forwarded to the inner reqwest builder. Alongside it, the wrapper records
/// the pieces needed to sign the request: the method, the URL, the captured form body, and any
/// OAuth parameters pulled out of the query string or the form.
///
/// `TSigner` is `()` for a builder with no secrets attached yet (see `sign` and
/// `sign_with_params`), or a `Signer` once secrets are attached. Trait bounds on `TSigner` are
/// placed on the impl blocks that need them rather than on the type itself.
#[derive(Debug)]
pub struct RequestBuilder<TSigner> {
    /// HTTP method; handed to the signer when the signature is generated.
    method: Method,
    /// Underlying reqwest builder that every call is forwarded to.
    inner: ReqwestRequestBuilder,
    /// `()` before secrets are attached, a `Signer` afterwards.
    signer: TSigner,
    /// Parsed request URL (with OAuth query pairs removed) used for signing. `None` when the
    /// URL given at construction failed to parse; the request is then left unsigned.
    url: Option<Url>,
    /// URL-encoded form body captured by `form`; empty when no form has been set.
    body: String,
    /// OAuth parameter overrides taken from the URL's query string and from `query` calls.
    query_oauth_parameters: HashMap<String, String>,
    /// OAuth parameter overrides intended to come from `form` calls; cleared on every `form`.
    form_oauth_parameters: HashMap<String, String>,
}
impl RequestBuilder<()> {
    /// Attaches `secrets` to this unsigned builder.
    ///
    /// Uses the parameters from `OAuthParameters::new()` and the `HmacSha1` signature method.
    /// Equivalent to `sign_with_params(secrets, OAuthParameters::new())`.
    pub fn sign<'a, TSecrets>(
        self,
        secrets: TSecrets,
    ) -> RequestBuilder<Signer<'a, TSecrets, DefaultSM>>
    where
        TSecrets: SecretsProvider + Clone,
    {
        let default_params = OAuthParameters::new();
        self.sign_with_params(secrets, default_params)
    }

    /// Attaches `secrets` together with explicit OAuth `params`.
    ///
    /// Every captured piece of request state is carried over unchanged. Only the signer slot
    /// changes, from `()` to a `Signer` built from `secrets` and `params`.
    pub fn sign_with_params<'a, TSecrets, TSM>(
        self,
        secrets: TSecrets,
        params: OAuthParameters<'a, TSM>,
    ) -> RequestBuilder<Signer<'a, TSecrets, TSM>>
    where
        TSecrets: SecretsProvider + Clone,
        TSM: SignatureMethod + Clone,
    {
        let RequestBuilder {
            method,
            inner,
            signer: (),
            url,
            body,
            query_oauth_parameters,
            form_oauth_parameters,
        } = self;
        RequestBuilder {
            method,
            inner,
            signer: Signer::new(secrets.into(), params),
            url,
            body,
            query_oauth_parameters,
            form_oauth_parameters,
        }
    }
}
impl<TSecrets, TSM> RequestBuilder<Signer<'_, TSecrets, TSM>>
where
    TSecrets: SecretsProvider + Clone,
    TSM: SignatureMethod + Clone,
{
    /// Signs the request (see `generate_signature`) and sends it.
    ///
    /// # Errors
    ///
    /// Fails if signature generation fails or if reqwest's `send` returns an error.
    pub async fn send(self) -> Result<Response, Error> {
        let signed = self.generate_signature()?;
        let response = signed.send().await?;
        Ok(response)
    }

    /// Computes the OAuth signature and returns the inner reqwest builder with the
    /// `Authorization` header set to it.
    ///
    /// What gets signed depends on the captured URL:
    ///
    /// - If its query string is non-empty, that query string is the signed payload, and the URL
    ///   is passed to the signer with its query removed.
    /// - Otherwise the payload is the form body captured by `form`, which is empty if no form
    ///   was set.
    ///
    /// OAuth parameters captured from the form and from the query are passed to the signer as
    /// overrides. Query-derived values win when both define the same key.
    ///
    /// If no URL was captured (it failed to parse at construction), the inner builder is
    /// returned unchanged and unsigned.
    ///
    /// NOTE(review): when a query string is present, the captured form body is not passed to
    /// the signer at all. RFC 5849 §3.4.1.3.1 puts both query and form-encoded body parameters
    /// into the signature base string; confirm how `Signer` is meant to handle that case.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the signer's `generate_signature`.
    pub fn generate_signature(self) -> SignResult<ReqwestRequestBuilder> {
        let captured_url = match self.url {
            Some(captured_url) => captured_url,
            // Nothing was captured at construction, so there is nothing to sign.
            None => return Ok(self.inner),
        };

        // Form-derived overrides are inserted first so query-derived ones replace them.
        let overrides: HashMap<String, String> = self
            .form_oauth_parameters
            .into_iter()
            .chain(self.query_oauth_parameters)
            .collect();

        let (is_q, signing_url, payload) = match captured_url.query() {
            None | Some("") => (false, captured_url, self.body.as_str()),
            Some(query_string) => {
                let mut bare_url = captured_url.clone();
                bare_url.set_query(None);
                (true, bare_url, query_string)
            }
        };

        let authorization = self
            .signer
            .override_oauth_parameter(overrides)
            .generate_signature(self.method, signing_url, payload, is_q)?;
        Ok(self.inner.header(AUTHORIZATION, authorization))
    }
}
impl<TSigner> RequestBuilder<TSigner>
where
    TSigner: Clone,
{
    /// Creates a builder for a `method` request to `url` on `client`, carrying `signer`.
    ///
    /// When `url` parses, every query pair whose key starts with `OAUTH_KEY_PREFIX` or equals
    /// `REALM_KEY` is removed and recorded as an OAuth parameter override. The stripped URL is
    /// what gets sent and what is captured for signing.
    ///
    /// When `url` does not parse, the raw value is handed to reqwest unchanged (presumably so
    /// reqwest reports the error on build/send). No URL is captured, so no signature will be
    /// generated.
    pub(crate) fn new<T: IntoUrl + Clone>(
        client: &RequwestClient,
        method: Method,
        url: T,
        signer: TSigner,
    ) -> Self {
        match url.clone().into_url() {
            Ok(parsed) => {
                let mut query_oauth_parameters = HashMap::new();
                let stripped = steal_oauth_params_from_url(parsed, &mut query_oauth_parameters);
                RequestBuilder {
                    inner: client.request(method.clone(), stripped.clone()),
                    method,
                    url: Some(stripped),
                    body: String::new(),
                    signer,
                    query_oauth_parameters,
                    form_oauth_parameters: HashMap::new(),
                }
            }
            Err(_) => RequestBuilder {
                inner: client.request(method.clone(), url),
                method,
                url: None,
                body: String::new(),
                signer,
                query_oauth_parameters: HashMap::new(),
                form_oauth_parameters: HashMap::new(),
            },
        }
    }

    /// Appends query parameters, like `reqwest::RequestBuilder::query`.
    ///
    /// Pairs whose key starts with `OAUTH_KEY_PREFIX` or equals `REALM_KEY` are not sent.
    /// Instead they are recorded in `query_oauth_parameters` as OAuth parameter overrides.
    /// The remaining pairs are appended to the captured URL, so they are signed, and to the
    /// inner reqwest builder, so they are sent.
    pub fn query<T: Serialize + ?Sized>(mut self, query: &T) -> Self {
        let query = steal_oauth_params(query, &mut self.query_oauth_parameters);
        if let Some(ref mut url) = self.url {
            {
                // `pairs` mutably borrows `url`; this scope ends that borrow before
                // `url.query()` is read below.
                let mut pairs = url.query_pairs_mut();
                let serializer = serde_urlencoded::Serializer::new(&mut pairs);
                // `query` is a list of string pairs, so serialization is not expected to fail.
                let _ = query.serialize(serializer);
            }
            // Appending nothing can leave an empty query behind; drop it entirely.
            if let Some("") = url.query() {
                url.set_query(None);
            }
        }
        self.inner = self.inner.query(&query);
        self
    }

    /// Sets a URL-encoded form body, like `reqwest::RequestBuilder::form`.
    ///
    /// Pairs whose key starts with `OAUTH_KEY_PREFIX` or equals `REALM_KEY` are removed from
    /// the form and recorded in `form_oauth_parameters`. The remaining pairs become the request
    /// body and are captured for signing.
    ///
    /// Each call replaces both the body and the form-derived OAuth parameters captured by any
    /// previous `form` call.
    pub fn form<T: Serialize + ?Sized>(mut self, form: &T) -> Self {
        // A new form replaces the previous body, so forget OAuth parameters taken from it.
        self.form_oauth_parameters.clear();
        // Steal into the *form* map, not the query map. Otherwise the `clear()` above would
        // not drop parameters taken from an earlier form, and they would leak into the
        // signature.
        let form = steal_oauth_params(form, &mut self.form_oauth_parameters);
        match serde_urlencoded::to_string(&form) {
            Ok(body) => {
                self.inner = self.inner.form(&form);
                self.body = body;
                self
            }
            Err(_) => self.pass_through(|b| b.form(&form)),
        }
    }

    /// Appends query parameters to the inner reqwest builder only.
    ///
    /// Nothing is captured. The pairs are sent but are not part of the OAuth signature, and
    /// OAuth-looking keys are not extracted.
    pub fn query_without_capture<T: Serialize + ?Sized>(self, query: &T) -> Self {
        self.pass_through(|b| b.query(query))
    }

    /// Sets a form body on the inner reqwest builder only.
    ///
    /// Nothing is captured. The form is sent but is not part of the OAuth signature.
    ///
    /// NOTE(review): a body captured by an earlier `form` call stays in `body` and would
    /// still be signed even though this call replaces what is sent. Confirm this is intended.
    pub fn form_without_capture<T: Serialize + ?Sized>(self, form: &T) -> Self {
        self.pass_through(|b| b.form(form))
    }

    /// Applies `f` to the inner reqwest builder, leaving all captured signing state untouched.
    fn pass_through<F>(self, f: F) -> Self
    where
        F: FnOnce(ReqwestRequestBuilder) -> ReqwestRequestBuilder,
    {
        RequestBuilder {
            inner: f(self.inner),
            ..self
        }
    }

    /// Adds a header to the inner reqwest builder. Headers are not part of the signature.
    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
        HeaderValue: TryFrom<V>,
        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
    {
        self.pass_through(|b| b.header(key, value))
    }

    /// Merges `headers` into the inner reqwest builder. Headers are not part of the signature.
    pub fn headers(self, headers: HeaderMap) -> Self {
        self.pass_through(|b| b.headers(headers))
    }

    /// Forwards HTTP Basic authentication to the inner reqwest builder.
    ///
    /// NOTE(review): this also targets the `Authorization` header that `generate_signature`
    /// sets. Combining both is probably unintended; confirm.
    pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
    where
        U: fmt::Display,
        P: fmt::Display,
    {
        self.pass_through(|b| b.basic_auth(username, password))
    }

    /// Forwards bearer-token authentication to the inner reqwest builder.
    ///
    /// NOTE(review): this also targets the `Authorization` header that `generate_signature`
    /// sets. Combining both is probably unintended; confirm.
    pub fn bearer_auth<T>(self, token: T) -> Self
    where
        T: fmt::Display,
    {
        self.pass_through(|b| b.bearer_auth(token))
    }

    /// Sets a raw body on the inner reqwest builder. The body is not captured for signing.
    ///
    /// NOTE(review): a body captured by an earlier `form` call is kept and would still be
    /// signed. Confirm this is intended.
    pub fn body<T: Into<Body>>(self, body: T) -> Self {
        self.pass_through(|b| b.body(body))
    }

    /// Sets a request timeout on the inner reqwest builder.
    pub fn timeout(self, timeout: Duration) -> Self {
        self.pass_through(|b| b.timeout(timeout))
    }

    /// Sets a multipart body on the inner reqwest builder. The body is not captured for signing.
    pub fn multipart(self, multipart: multipart::Form) -> Self {
        self.pass_through(|b| b.multipart(multipart))
    }

    /// Returns the builder unchanged; the call is not forwarded to the inner reqwest builder.
    pub fn fetch_mode_no_cors(self) -> Self {
        self
    }

    /// Clones the builder, including all captured signing state.
    ///
    /// Returns `None` whenever the inner `reqwest::RequestBuilder::try_clone` returns `None`.
    pub fn try_clone(&self) -> Option<Self> {
        self.inner.try_clone().map(|inner| RequestBuilder {
            inner,
            method: self.method.clone(),
            url: self.url.clone(),
            body: self.body.clone(),
            signer: self.signer.clone(),
            query_oauth_parameters: self.query_oauth_parameters.clone(),
            form_oauth_parameters: self.form_oauth_parameters.clone(),
        })
    }
}
/// Serializes `query` as URL-encoded pairs and splits them into OAuth and non-OAuth parameters.
///
/// Pairs whose key starts with `OAUTH_KEY_PREFIX` or equals `REALM_KEY` are inserted into
/// `oauth_map`. All other pairs are returned (decoded) in serialization order.
///
/// The result of `serialize` is discarded. If serialization fails, only the pairs written
/// before the failure (possibly none) are seen.
fn steal_oauth_params<T>(
    query: &T,
    oauth_map: &mut HashMap<String, String>,
) -> Vec<(String, String)>
where
    T: Serialize + ?Sized,
{
    // Serialize into the query of a throwaway URL, then read the pairs back decoded.
    let mut scratch = Url::parse("http://example.com/")
        .expect("failed to parse the http://example.com/, that is unexpected behavior.");
    {
        let mut pairs = scratch.query_pairs_mut();
        let serializer = serde_urlencoded::Serializer::new(&mut pairs);
        let _ = query.serialize(serializer);
    }
    // The core helper already returns owned `String` pairs; no re-conversion is needed.
    steal_oauth_params_core(&scratch, oauth_map)
}
/// Moves the OAuth pairs (key starts with `OAUTH_KEY_PREFIX` or equals `REALM_KEY`) from
/// `url`'s query into `oauth_map`.
///
/// Returns `url` holding only the remaining pairs, re-encoded in their original order. If no
/// pairs remain, the returned URL has no query component at all.
fn steal_oauth_params_from_url(mut url: Url, oauth_map: &mut HashMap<String, String>) -> Url {
    let remainder = steal_oauth_params_core(&url, oauth_map);
    url.set_query(None);
    // Only open the query serializer when there is something to write, so a URL with no
    // remaining pairs keeps `query() == None`.
    if !remainder.is_empty() {
        let mut pairs = url.query_pairs_mut();
        for (key, value) in &remainder {
            pairs.append_pair(key, value);
        }
    }
    url
}
/// Splits `url`'s decoded query pairs into OAuth parameters and everything else.
///
/// Pairs whose key starts with `OAUTH_KEY_PREFIX` or equals `REALM_KEY` are inserted into
/// `oauth_map`, overwriting any existing value for that key (so a repeated key keeps its last
/// value). All other pairs are returned in query order.
fn steal_oauth_params_core(
    url: &Url,
    oauth_map: &mut HashMap<String, String>,
) -> Vec<(String, String)> {
    let mut remainder = Vec::new();
    // An explicit loop makes the side effect on `oauth_map` visible, instead of hiding it
    // inside a lazy `filter_map`.
    for (key, value) in url.query_pairs() {
        let (key, value) = (key.into_owned(), value.into_owned());
        if key.starts_with(OAUTH_KEY_PREFIX) || key == REALM_KEY {
            oauth_map.insert(key, value);
        } else {
            remainder.push((key, value));
        }
    }
    remainder
}
#[cfg(test)]
mod tests {
    use http::header::AUTHORIZATION;
    use crate::{
        OAuthClientProvider, OAuthParameters, Secrets, OAUTH_NONCE_KEY, OAUTH_TIMESTAMP_KEY,
    };

    /// Extracts the `oauth_signature` value from an `OAuth k1="v1",k2="v2",...` header.
    ///
    /// The value is percent-decoded and its surrounding quotes are trimmed. Entries are split
    /// on a bare `,` and keys are compared exactly, so a space after a comma would break the
    /// lookup. Panics if the `OAuth ` prefix or the `oauth_signature` entry is missing.
    fn extract_signature(auth_header: &str) -> String {
        let content = auth_header.strip_prefix("OAuth ").unwrap();
        let mapped_header = content
            .split(',')
            .map(|item| item.splitn(2, '=').collect::<Vec<&str>>())
            .filter(|v| v.len() == 2)
            .map(|v| (v[0], v[1]))
            .collect::<Vec<(&str, &str)>>();
        let sig_content = mapped_header.iter().find(|(k, _)| k == &"oauth_signature");
        percent_encoding::percent_decode_str(sig_content.unwrap().1)
            .decode_utf8_lossy()
            .trim_matches('"')
            .to_string()
    }

    // Plain reqwest behaviour that the wrapper's `query` mirrors: repeated calls append.
    #[test]
    fn call_multiple_queries() {
        let req = reqwest::Client::new()
            .get("https://example.com")
            .query(&[("a", "b")])
            .query(&[("c", "d")])
            .build()
            .unwrap();
        assert_eq!(req.url().to_string(), "https://example.com/?a=b&c=d");
    }

    // Plain reqwest behaviour that the wrapper's `form` mirrors: the last form replaces the
    // body, and the URL query is left untouched.
    #[test]
    fn call_multiple_forms() {
        let req = reqwest::Client::new()
            .post("https://example.com")
            .query(&[("this is", "query")])
            .form(&[("a", "b")])
            .form(&[("c", "d")])
            .build()
            .unwrap();
        let decoded_body = String::from_utf8_lossy(req.body().unwrap().as_bytes().unwrap());
        assert_eq!(req.url().to_string(), "https://example.com/?this+is=query");
        assert_eq!(decoded_body, "c=d");
    }

    // `oauth_`-prefixed form pairs are removed from the captured body, and non-ASCII
    // keys/values are percent-encoded as UTF-8. (The local named `url` actually holds the
    // captured form body.)
    #[test]
    fn capture_post_query() {
        let endpoint = "https://photos.example.net/initiate";
        let c_key = "dpf43f3p2l4k3l03";
        let c_secret = "kd94hf93k423kf44";
        let nonce = "wIjqoS";
        let timestamp = 137_131_200u64;
        let secrets = Secrets::new(c_key, c_secret);
        let params = OAuthParameters::new()
            .nonce(nonce)
            .timestamp(timestamp)
            .callback("http://printer.example.com/ready")
            .realm("photos");
        let req = reqwest::Client::new()
            .oauth1_with_params(secrets, params)
            .post(endpoint)
            .form(&[("少女", "終末旅行"), ("oauth_should_be_ignored", "true")]);
        let url = req.body;
        assert_eq!(
            url,
            "%E5%B0%91%E5%A5%B3=%E7%B5%82%E6%9C%AB%E6%97%85%E8%A1%8C"
        );
    }

    // Signature of a bare POST with no query or body. The inputs and expected signature
    // appear to be the temporary-credential example from RFC 5849 §1.2.
    #[test]
    fn sign_post_query() {
        let endpoint = "https://photos.example.net/initiate";
        let c_key = "dpf43f3p2l4k3l03";
        let c_secret = "kd94hf93k423kf44";
        let nonce = "wIjqoS";
        let timestamp = 137_131_200u64;
        let secrets = Secrets::new(c_key, c_secret);
        let params = OAuthParameters::new()
            .nonce(nonce)
            .timestamp(timestamp)
            .callback("http://printer.example.com/ready")
            .realm("photos");
        let req = reqwest::Client::new()
            .oauth1_with_params(secrets, params)
            .post(endpoint)
            .generate_signature()
            .unwrap()
            .build()
            .unwrap();
        let sign = req.headers().get(AUTHORIZATION);
        assert_eq!(
            extract_signature(sign.unwrap().to_str().unwrap()),
            "74KNZJeDHnMBp0EMJ9ZHt/XKycU="
        );
    }

    // `oauth_`-prefixed pairs in the URL passed to `get` are stripped from the captured URL;
    // the other pairs keep their order.
    #[test]
    fn capture_get_query() {
        let endpoint = "https://photos.example.net/photos?file=vacation.jpg&size=original&oauth_should_be_ignored=true";
        let c_key = "dpf43f3p2l4k3l03";
        let c_secret = "kd94hf93k423kf44";
        let token = "nnch734d00sl2jdk";
        let token_secret = "pfkkdhi9sl3r4s00";
        let nonce = "wIjqoS";
        let timestamp = 137_131_200u64;
        let secrets = Secrets::new(c_key, c_secret).token(token, token_secret);
        let params = OAuthParameters::new().nonce(nonce).timestamp(timestamp);
        let req = reqwest::Client::new()
            .oauth1_with_params(secrets, params)
            .get(endpoint);
        let query = req.url.unwrap().query().unwrap().to_string();
        assert_eq!(query, "file=vacation.jpg&size=original")
    }

    // Signature of a GET whose query string is the signed payload. The inputs and expected
    // signature appear to be the protected-resource example from RFC 5849 §1.2.
    #[test]
    fn sign_get_query() {
        let endpoint = "http://photos.example.net/photos?file=vacation.jpg&size=original";
        let c_key = "dpf43f3p2l4k3l03";
        let c_secret = "kd94hf93k423kf44";
        let token = "nnch734d00sl2jdk";
        let token_secret = "pfkkdhi9sl3r4s00";
        let nonce = "chapoH";
        let timestamp = 137_131_202u64;
        let secrets = Secrets::new(c_key, c_secret).token(token, token_secret);
        let params = OAuthParameters::new()
            .nonce(nonce)
            .timestamp(timestamp)
            .realm("Photos");
        let req = reqwest::Client::new()
            .oauth1_with_params(secrets, params)
            .get(endpoint)
            .generate_signature()
            .unwrap()
            .build()
            .unwrap();
        let sign = req.headers().get(AUTHORIZATION);
        assert_eq!(
            extract_signature(sign.unwrap().to_str().unwrap()),
            "MdpQcU8iPSUjWoN/UDMsK2sui9I="
        );
    }

    // Same expected signature as `sign_get_query`. Here realm, nonce and timestamp arrive via
    // the URL and `query()` instead of `OAuthParameters`, which checks that captured query
    // parameters override the signer's own values.
    #[test]
    fn sign_get_query_with_query_oauth_params() {
        let endpoint =
            "http://photos.example.net/photos?file=vacation.jpg&size=original&realm=Photos";
        let c_key = "dpf43f3p2l4k3l03";
        let c_secret = "kd94hf93k423kf44";
        let token = "nnch734d00sl2jdk";
        let token_secret = "pfkkdhi9sl3r4s00";
        let nonce = "chapoH";
        let timestamp = 137_131_202u64;
        let secrets = Secrets::new(c_key, c_secret).token(token, token_secret);
        let req = reqwest::Client::new()
            .oauth1(secrets)
            .get(endpoint)
            .query(&[
                (OAUTH_NONCE_KEY, nonce),
                (OAUTH_TIMESTAMP_KEY, &format!("{}", timestamp)),
            ])
            .generate_signature()
            .unwrap()
            .build()
            .unwrap();
        let sign = req.headers().get(AUTHORIZATION);
        assert_eq!(
            extract_signature(sign.unwrap().to_str().unwrap()),
            "MdpQcU8iPSUjWoN/UDMsK2sui9I="
        );
    }

    // The captured form body uses `+` for spaces and percent-encodes reserved characters.
    #[test]
    fn capture_body() {
        let endpoint = url::Url::parse("https://api.twitter.com/1.1/statuses/update.json").unwrap();
        let c_key = "xvz1evFS4wEEPTGEFPHBog";
        let c_secret = "kAcSOqF21Fu85e7zjz7ZN2U4ZRhfV3WpwPAoE3Z7kBw";
        let nonce = "kYjzVBB8Y0ZFabxSWbWovY3uYSQ2pTgmZeNu2VS4cg";
        let timestamp = 1_318_622_958u64;
        let token = "370773112-GmHxMAgYyLbNEtIKZeRNFsMKPR9EyMZeS9weJAEb";
        let token_secret = "LswwdoUaIvS8ltyTt5jkRh4J50vUPVVHtR2YPi5kE";
        let secrets = Secrets::new(c_key, c_secret).token(token, token_secret);
        let params = OAuthParameters::new().nonce(nonce).timestamp(timestamp);
        let req = reqwest::Client::new()
            .oauth1_with_params(secrets, params)
            .post(endpoint)
            .form(&[
                ("include_entities", "true"),
                (
                    "status",
                    "Hello Ladies + Gentlemen, a signed OAuth request!",
                ),
            ]);
        let body = req.body;
        assert_eq!(
            body,
            "include_entities=true&status=Hello+Ladies+%2B+Gentlemen%2C+a+signed+OAuth+request%21"
        )
    }

    // Signature of a POST whose form body is the signed payload. The inputs and expected
    // signature appear to be Twitter's published "Creating a signature" example.
    #[test]
    fn sign_post_body() {
        let endpoint = url::Url::parse("https://api.twitter.com/1.1/statuses/update.json").unwrap();
        let c_key = "xvz1evFS4wEEPTGEFPHBog";
        let c_secret = "kAcSOqF21Fu85e7zjz7ZN2U4ZRhfV3WpwPAoE3Z7kBw";
        let nonce = "kYjzVBB8Y0ZFabxSWbWovY3uYSQ2pTgmZeNu2VS4cg";
        let timestamp = 1_318_622_958u64;
        let token = "370773112-GmHxMAgYyLbNEtIKZeRNFsMKPR9EyMZeS9weJAEb";
        let token_secret = "LswwdoUaIvS8ltyTt5jkRh4J50vUPVVHtR2YPi5kE";
        let secrets = Secrets::new(c_key, c_secret).token(token, token_secret);
        let params = OAuthParameters::new()
            .nonce(nonce)
            .timestamp(timestamp)
            .version(true);
        let req = reqwest::Client::new()
            .oauth1_with_params(secrets, params)
            .post(endpoint)
            .form(&[
                ("include_entities", "true"),
                (
                    "status",
                    "Hello Ladies + Gentlemen, a signed OAuth request!",
                ),
            ])
            .generate_signature()
            .unwrap()
            .build()
            .unwrap();
        let sign = req.headers().get(AUTHORIZATION);
        assert_eq!(
            extract_signature(sign.unwrap().to_str().unwrap()),
            "hCtSmYh+iHYCEqBWrE7C7hYmtUk="
        );
    }
}