1#![warn(missing_docs)]
5
6use hmac::{Hmac, Mac};
7use http_types::{Method, Url};
8use percent_encoding::{AsciiSet, PercentEncode, NON_ALPHANUMERIC};
9use rand::distributions::Alphanumeric;
10use rand::prelude::*;
11use serde::{Deserialize, Serialize};
12use sha1::Sha1;
13use std::borrow::Cow;
14use std::collections::BTreeMap;
15use std::fmt;
16use std::time::{SystemTime, UNIX_EPOCH};
17
18const PERCENT_ENCODING_SET: &AsciiSet = &NON_ALPHANUMERIC
19 .remove(b'-')
20 .remove(b'.')
21 .remove(b'_')
22 .remove(b'~');
23
24fn percent_encode<T: ?Sized + AsRef<[u8]>>(data: &T) -> PercentEncode<'_> {
25 percent_encoding::percent_encode(data.as_ref(), PERCENT_ENCODING_SET)
26}
27
28#[must_use]
30pub fn encode_auth_parameters(params: &BTreeMap<String, String>) -> String {
31 let mut out = String::new();
32 let params: BTreeMap<String, String> = params
33 .iter()
34 .map(|(x, y)| (percent_encode(x).collect(), percent_encode(y).collect()))
35 .collect();
36 let mut params = params.iter();
37 if let Some((k, v)) = params.next() {
38 out.push_str(k);
39 out.push_str("=\"");
40 out.push_str(v);
41 out.push('"');
42 }
43 for (k, v) in params {
44 out.push_str(", ");
45 out.push_str(k);
46 out.push_str("=\"");
47 out.push_str(v);
48 out.push('"');
49 }
50 out
51}
52
53fn encode_url_parameters(params: &BTreeMap<String, String>) -> String {
54 let mut out = String::new();
55 let params: BTreeMap<String, String> = params
56 .iter()
57 .map(|(x, y)| (percent_encode(x).collect(), percent_encode(y).collect()))
58 .collect();
59 let mut params = params.iter();
60 if let Some((k, v)) = params.next() {
61 out.push_str(k);
62 out.push('=');
63 out.push_str(v);
64 }
65 for (k, v) in params {
66 out.push('&');
67 out.push_str(k);
68 out.push('=');
69 out.push_str(v);
70 }
71 out
72}
73
74#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
76#[serde(transparent)]
77pub struct Token(pub String);
78
79#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
81#[serde(transparent)]
82pub struct ClientId(pub String);
83
84#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
86#[serde(transparent)]
87pub struct ClientSecret(pub String);
88
89#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
91#[serde(transparent)]
92pub struct TokenSecret(pub String);
93
94#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
96pub struct SigningKey {
97 pub client_secret: ClientSecret,
99 pub token_secret: Option<TokenSecret>,
101}
102
103impl SigningKey {
104 #[must_use]
106 pub fn with_token(client_secret: ClientSecret, token_secret: TokenSecret) -> Self {
107 Self {
108 client_secret,
109 token_secret: Some(token_secret),
110 }
111 }
112
113 #[must_use]
115 pub fn without_token(client_secret: ClientSecret) -> Self {
116 Self {
117 client_secret,
118 token_secret: None,
119 }
120 }
121}
122
123impl fmt::Display for SigningKey {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 if let Some(token_secret) = &self.token_secret {
126 write!(f, "{}&{}", self.client_secret.0, token_secret.0)
127 } else {
128 write!(f, "{}&", self.client_secret.0)
129 }
130 }
131}
132
133fn normalize_url(mut url: Url) -> Url {
134 if let Some(host) = url.host_str() {
135 let host = host.to_lowercase();
136 url.set_host(Some(&host))
137 .expect("lowercasing shouldn't change host validity");
138 }
139 url.set_fragment(None);
140 url.set_query(None);
141 url
142}
143
144#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
146pub struct SignableRequest {
147 #[serde(with = "serde_with::rust::display_fromstr")]
149 pub method: Method,
150 normalized_url: Url,
151 pub parameters: BTreeMap<String, String>,
153}
154
155impl SignableRequest {
156 #[must_use]
158 pub fn new(method: Method, url: Url, parameters: BTreeMap<String, String>) -> Self {
159 let normalized_url = normalize_url(url);
160 Self {
161 method,
162 normalized_url,
163 parameters,
164 }
165 }
166
167 #[must_use]
169 pub fn url(&self) -> &Url {
170 &self.normalized_url
171 }
172}
173
/// Data that a [`SignatureMethod`] can sign.
pub trait Signable {
    /// Returns the exact bytes to sign, borrowing from `self` when no
    /// conversion is needed.
    fn to_bytes(&self) -> Cow<'_, [u8]>;
}

/// A `String` is signed as its UTF-8 bytes, without copying.
impl Signable for String {
    fn to_bytes(&self) -> Cow<'_, [u8]> {
        Cow::from(self.as_bytes())
    }
}

/// A `&str` is signed as its UTF-8 bytes, without copying.
impl Signable for &str {
    fn to_bytes(&self) -> Cow<'_, [u8]> {
        Cow::from(str::as_bytes(self))
    }
}
191
192impl Signable for SignableRequest {
193 fn to_bytes(&self) -> Cow<'_, [u8]> {
194 let method = self.method.to_string().into_bytes();
195 let url = percent_encode(self.url().as_str());
196 let parameters = encode_url_parameters(&self.parameters).into_bytes();
197 let mut vec =
198 Vec::with_capacity(method.len() + self.url().as_str().len() + parameters.len() + 10);
199 vec.extend_from_slice(&method);
200 vec.push(b'&');
201 for x in url {
202 vec.extend_from_slice(x.as_bytes());
203 }
204 vec.push(b'&');
205 for x in percent_encode(¶meters) {
206 vec.extend_from_slice(x.as_bytes());
207 }
208 Cow::Owned(vec)
209 }
210}
211
212#[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
214pub enum SignatureMethod {
215 HmacSha1,
217 Plaintext,
219}
220
221impl SignatureMethod {
222 pub fn sign(self, data: &impl Signable, key: &SigningKey) -> String {
224 let key = key.to_string();
225 match self {
226 Self::HmacSha1 => {
227 let data = data.to_bytes();
228 let mut mac = Hmac::<Sha1>::new_varkey(key.as_bytes())
229 .expect("HMAC has no key length restrictions");
230 mac.input(&data);
231 base64::encode(&mac.result().code())
232 }
233 Self::Plaintext => key,
234 }
235 }
236}
237
238impl fmt::Display for SignatureMethod {
239 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240 let string = match self {
241 Self::HmacSha1 => "HMAC-SHA1",
242 Self::Plaintext => "PLAINTEXT",
243 };
244 write!(f, "{}", string)
245 }
246}
247
248#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
250#[serde(transparent)]
251pub struct Nonce(String);
252
253impl Nonce {
254 #[must_use]
256 pub fn generate() -> Self {
257 Self(thread_rng().sample_iter(Alphanumeric).take(16).collect())
258 }
259}
260
/// Current Unix time in whole seconds, used for `oauth_timestamp`.
///
/// # Panics
///
/// Panics if the system clock reads earlier than the Unix epoch.
fn timestamp() -> u64 {
    let now = SystemTime::now();
    let since_epoch = now
        .duration_since(UNIX_EPOCH)
        .expect("Bad system time!");
    since_epoch.as_secs()
}
267
268#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
270pub struct OAuthData {
271 pub client_id: ClientId,
273 pub token: Option<Token>,
275 pub signature_method: SignatureMethod,
277 pub nonce: Nonce,
279}
280
281#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
283pub enum AuthorizationType {
284 RequestToken {
286 callback: String,
288 },
289 AccessToken {
291 verifier: String,
293 },
294 Request,
296}
297
298impl OAuthData {
299 #[must_use]
301 pub fn authorization(
302 &self,
303 mut req: SignableRequest,
304 typ: AuthorizationType,
305 key: &SigningKey,
306 ) -> String {
307 req.parameters.extend(self.parameters());
308 match typ {
309 AuthorizationType::RequestToken { callback } => {
310 req.parameters.insert("oauth_callback".into(), callback);
311 }
312 AuthorizationType::AccessToken { verifier } => {
313 req.parameters.insert("oauth_verifier".into(), verifier);
314 }
315 AuthorizationType::Request => {}
316 }
317 let signature = self.signature_method.sign(&req, key);
318 req.parameters.insert("oauth_signature".into(), signature);
319 format!("OAuth {}", encode_auth_parameters(&req.parameters))
320 }
321
322 #[must_use]
324 pub fn parameters(&self) -> BTreeMap<String, String> {
325 let mut params = BTreeMap::new();
326 params.insert("oauth_consumer_key".into(), self.client_id.0.clone());
327 if let Some(token) = &self.token {
328 params.insert("oauth_token".into(), token.0.clone());
329 }
330 params.insert(
331 "oauth_signature_method".into(),
332 self.signature_method.to_string(),
333 );
334 params.insert("oauth_timestamp".into(), timestamp().to_string());
335 params.insert("oauth_nonce".into(), self.nonce.0.clone());
336 params
337 }
338
339 pub fn regen_nonce(&mut self) {
342 self.nonce = Nonce::generate();
343 }
344}
345
346pub fn receive_token<'a>(
352 data: &'a mut OAuthData,
353 key: &mut SigningKey,
354 resp: &str,
355) -> Result<&'a Token, serde_urlencoded::de::Error> {
356 #[derive(Deserialize)]
357 struct Response {
358 pub oauth_token: Token,
359 pub oauth_token_secret: TokenSecret,
360 }
361
362 let resp: Response = serde_urlencoded::from_str(resp)?;
363 let _ = data.token.take();
364 let token = &*data.token.get_or_insert(resp.oauth_token);
365 key.token_secret = Some(resp.oauth_token_secret);
366 Ok(token)
367}
368
369pub fn get_verifier(callback: &Url) -> Result<String, serde_urlencoded::de::Error> {
374 #[derive(Deserialize)]
375 struct Response {
376 pub oauth_token: Token,
377 pub oauth_verifier: String,
378 }
379
380 let query = callback.query().unwrap_or("");
381 let resp: Response = serde_urlencoded::from_str(query)?;
382 Ok(resp.oauth_verifier)
383}
384
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    #[test]
    fn encode_auth_parameters() {
        let mut params = BTreeMap::new();
        params.insert("hello".into(), "World!".into());
        params.insert("abc".into(), "def".into());
        params.insert("zzz".into(), "aaa".into());
        assert_eq!(
            super::encode_auth_parameters(&params),
            r#"abc="def", hello="World%21", zzz="aaa""#
        );
    }

    #[test]
    fn encode_url_parameters() {
        // Adapted from the RFC 5849 §3.4.1.3 example, minus the duplicate `a3`
        // parameter, which a `BTreeMap` cannot hold.
        let mut params = BTreeMap::new();
        params.insert("b5".into(), "=%3D".into());
        params.insert("a3".into(), "a".into());
        params.insert("c@".into(), "".into());
        params.insert("a2".into(), "r b".into());
        params.insert("oauth_consumer_key".into(), "9djdj82h48djs9d2".into());
        params.insert("oauth_token".into(), "kkk9d7dh3k39sjv7".into());
        params.insert("oauth_signature_method".into(), "HMAC-SHA1".into());
        params.insert("oauth_timestamp".into(), "137131201".into());
        params.insert("oauth_nonce".into(), "7d8f3e4a".into());
        params.insert("c2".into(), "".into());
        assert_eq!(
            super::encode_url_parameters(&params),
            r#"a2=r%20b&a3=a&b5=%3D%253D&c%40=&c2=&oauth_consumer_key=9djdj82h48djs9d2&oauth_nonce=7d8f3e4a&oauth_signature_method=HMAC-SHA1&oauth_timestamp=137131201&oauth_token=kkk9d7dh3k39sjv7"#
        );
    }

    #[test]
    fn encode_request() {
        use super::Signable;
        use http_types::{Method, Url};
        // Query parameters are stripped from the URL, so they are repeated here.
        let mut params = BTreeMap::new();
        params.insert("b5".into(), "=%3D".into());
        params.insert("a3".into(), "a".into());
        params.insert("c@".into(), "".into());
        params.insert("a2".into(), "r b".into());
        params.insert("oauth_consumer_key".into(), "9djdj82h48djs9d2".into());
        params.insert("oauth_token".into(), "kkk9d7dh3k39sjv7".into());
        params.insert("oauth_signature_method".into(), "HMAC-SHA1".into());
        params.insert("oauth_timestamp".into(), "137131201".into());
        params.insert("oauth_nonce".into(), "7d8f3e4a".into());
        params.insert("c2".into(), "".into());
        let url = Url::parse("http://example.com/request?b5=%3D%253D&a3=a&c%40=&a2=r%20b").unwrap();
        let req = super::SignableRequest::new(Method::Post, url, params);
        assert_eq!(
            std::str::from_utf8(&*req.to_bytes()).unwrap(),
            r#"POST&http%3A%2F%2Fexample.com%2Frequest&a2%3Dr%2520b%26a3%3Da%26b5%3D%253D%25253D%26c%2540%3D%26c2%3D%26oauth_consumer_key%3D9djdj82h48djs9d2%26oauth_nonce%3D7d8f3e4a%26oauth_signature_method%3DHMAC-SHA1%26oauth_timestamp%3D137131201%26oauth_token%3Dkkk9d7dh3k39sjv7"#
        );
    }

    #[test]
    fn nonce() {
        for _ in 0..20 {
            let nonce = super::Nonce::generate();
            assert_eq!(nonce.0.len(), 16);
            assert!(!nonce.0.chars().any(|x| !x.is_ascii_alphanumeric()));
        }
    }

    #[test]
    fn sign_plaintext() {
        use super::*;
        let client_secret = ClientSecret("client".into());
        let token_secret = TokenSecret("token".into());
        let without_token = SigningKey::without_token(client_secret.clone());
        let with_token = SigningKey::with_token(client_secret, token_secret);
        let data = "";
        let sig_without = SignatureMethod::Plaintext.sign(&data, &without_token);
        let sig_with = SignatureMethod::Plaintext.sign(&data, &with_token);
        assert_eq!(&sig_without, "client&");
        assert_eq!(&sig_with, "client&token");
    }

    #[test]
    fn sign_hmac() {
        use super::*;
        let client_secret = ClientSecret("client".into());
        let token_secret = TokenSecret("token".into());
        let without_token = SigningKey::without_token(client_secret.clone());
        let with_token = SigningKey::with_token(client_secret, token_secret);
        let data = "Hello, world!";
        let sig_without = SignatureMethod::HmacSha1.sign(&data, &without_token);
        let sig_with = SignatureMethod::HmacSha1.sign(&data, &with_token);
        assert_eq!(&sig_without, "QtZYxkuvnXbp2Pj0dE4nqYXdR5A=");
        assert_eq!(&sig_with, "4e3uNt5iHa7cMOSKMeY6mil2jew=");
    }
}