integrationos_domain/algebra/
oauth.rs

1use anyhow::{Context, Result};
2use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3use base64::prelude::*;
4use core::str;
5use hmac::{Hmac, Mac};
6use http::Method;
7use indexmap::IndexMap;
8use percent_encoding::{AsciiSet, PercentEncode, NON_ALPHANUMERIC};
9use rand::{thread_rng, RngCore};
10use reqwest::Url;
11use sha1::Sha1;
12use sha2::{Sha256, Sha512};
13use std::{
14    borrow::Cow,
15    fmt,
16    time::{SystemTime, UNIX_EPOCH},
17};
18
/// Maximum length, in Base64 characters, of a generated nonce.
const NONCE_LEN: usize = 12;

/// Context attached to HMAC key-setup errors; as the message says, HMAC accepts keys
/// of any length, so this path is not expected to trigger.
const HMAC_LENGTH_ERROR: &str = "HMAC has no key length restrictions";

// OAuth 1.0 protocol parameter names.
const OAUTH_CONSUMER_KEY: &str = "oauth_consumer_key";
const OAUTH_TOKEN: &str = "oauth_token";
const OAUTH_SIGNATURE_METHOD: &str = "oauth_signature_method";
const OAUTH_SIGNATURE: &str = "oauth_signature";
const OAUTH_TIMESTAMP: &str = "oauth_timestamp";
const OAUTH_NONCE: &str = "oauth_nonce";
const OAUTH_VERSION: &str = "oauth_version";
const OAUTH_CALLBACK: &str = "oauth_callback";
const OAUTH_VERIFIER: &str = "oauth_verifier";
30
31const EXCLUDE: &AsciiSet = &NON_ALPHANUMERIC
32    .remove(b'-')
33    .remove(b'.')
34    .remove(b'_')
35    .remove(b'~');
36
37fn percent_encode<T: ?Sized + AsRef<[u8]>>(data: &T) -> PercentEncode<'_> {
38    percent_encoding::percent_encode(data.as_ref(), EXCLUDE)
39}
40
/// Signature algorithms available for OAuth 1.0 requests.
///
/// The `Display` form is the name sent as `oauth_signature_method`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SignatureMethod {
    /// HMAC over SHA-1.
    HmacSha1,
    /// HMAC over SHA-256.
    HmacSha256,
    /// HMAC over SHA-512.
    HmacSha512,
    /// The signing key itself is used as the signature.
    PlainText,
}

impl fmt::Display for SignatureMethod {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let wire_name = match self {
            Self::HmacSha1 => "HMAC-SHA1",
            Self::HmacSha256 => "HMAC-SHA256",
            Self::HmacSha512 => "HMAC-SHA512",
            Self::PlainText => "PLAINTEXT",
        };
        f.write_str(wire_name)
    }
}
59
60#[derive(Debug, Clone, PartialEq, Eq)]
61pub struct SignableRequest {
62    pub method: Method,
63    pub uri: Url,
64    pub parameters: IndexMap<String, String>,
65}
66
67impl SignableRequest {
68    fn as_bytes(&self) -> Result<Cow<[u8]>> {
69        let method = self.method.to_string();
70        let normalized_uri = {
71            let mut url = self.uri.clone();
72            if let Some(host) = url.host_str() {
73                url.set_host(Some(&host.to_lowercase()))
74                    .context("OAuth 1.0 URI lowercasing shouldn't change host validity")?;
75            }
76            url.set_fragment(None);
77            url.set_query(None);
78            url
79        };
80
81        let encoded_url = percent_encode(normalized_uri.as_str());
82        let encoded_url_params = encode_url_parameters(&self.parameters);
83        let encoded_params = percent_encode(&encoded_url_params);
84
85        let result = format!("{}&{}&{}", method, encoded_url, encoded_params);
86
87        Ok(Cow::Owned(result.into_bytes()))
88    }
89
90    fn sorted_parameters(&self) -> SignableRequest {
91        let mut params = self.parameters.clone();
92        params.sort_keys();
93        SignableRequest {
94            method: self.method.clone(),
95            uri: self.uri.clone(),
96            parameters: params,
97        }
98    }
99}
100
101#[derive(Debug, Clone)]
102pub struct SigningKey {
103    pub client_secret: String,
104    pub token_secret: Option<String>,
105}
106
107impl fmt::Display for SigningKey {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        if let Some(token_secret) = &self.token_secret {
110            write!(
111                f,
112                "{}&{}",
113                percent_encode(&self.client_secret),
114                percent_encode(&token_secret)
115            )
116        } else {
117            write!(f, "{}&", percent_encode(&self.client_secret))
118        }
119    }
120}
121
122impl SignatureMethod {
123    pub fn sign(self, data: &SignableRequest, key: &SigningKey) -> Result<String> {
124        let key = key.to_string();
125
126        match self {
127            Self::HmacSha1 => {
128                let mut mac =
129                    Hmac::<Sha1>::new_from_slice(key.as_bytes()).context(HMAC_LENGTH_ERROR)?;
130                mac.update(&data.sorted_parameters().as_bytes()?);
131                let result = mac.finalize().into_bytes();
132                Ok(BASE64_STANDARD.encode(result))
133            }
134            Self::HmacSha256 => {
135                let mut mac =
136                    Hmac::<Sha256>::new_from_slice(key.as_bytes()).context(HMAC_LENGTH_ERROR)?;
137                mac.update(&data.sorted_parameters().as_bytes()?);
138                let result = mac.finalize().into_bytes();
139                Ok(BASE64_STANDARD.encode(result))
140            }
141            Self::HmacSha512 => {
142                let mut mac =
143                    Hmac::<Sha512>::new_from_slice(key.as_bytes()).context(HMAC_LENGTH_ERROR)?;
144                mac.update(&data.sorted_parameters().as_bytes()?);
145                let result = mac.finalize().into_bytes();
146                Ok(BASE64_STANDARD.encode(result))
147            }
148            Self::PlainText => Ok(key),
149        }
150    }
151}
152
153#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
154pub struct Nonce(pub String);
155impl Nonce {
156    pub fn generate() -> Result<Nonce> {
157        let mut rng = thread_rng();
158        let mut rand = [0_u8; NONCE_LEN * 3 / 4];
159        rng.fill_bytes(&mut rand);
160
161        let i = rand.iter().position(|&b| b != 0).unwrap_or(rand.len());
162        let rand = &rand[i..];
163
164        let mut buf = [0u8; NONCE_LEN];
165        let len = URL_SAFE_NO_PAD
166            .encode_slice(rand, &mut buf)
167            .context("Failed to encode nonce to Base64")?;
168
169        let nonce_str = str::from_utf8(&buf[..len])
170            .context("Failed to convert nonce bytes to UTF-8")?
171            .to_string();
172
173        Ok(Nonce(nonce_str))
174    }
175}
176
177#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
178pub struct OAuthData {
179    pub client_id: String,
180    pub token: Option<String>,
181    pub signature_method: SignatureMethod,
182    pub nonce: Nonce,
183}
184
/// Which kind of OAuth 1.0 request an `Authorization` header is being built for.
#[derive(Debug, Clone, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum AuthorizationType {
    /// Request-token step; `callback` is sent as `oauth_callback`.
    RequestToken { callback: String },
    /// Access-token step; `verifier` is sent as `oauth_verifier`.
    AccessToken { verifier: String },
    /// Ordinary signed request; no extra parameter is added.
    Request,
}
191
192fn timestamp() -> Result<u64> {
193    Ok(SystemTime::now()
194        .duration_since(UNIX_EPOCH)
195        .context("Bad system time!")?
196        .as_secs())
197}
198
199impl OAuthData {
200    pub fn authorization(
201        &self,
202        mut req: SignableRequest,
203        typ: AuthorizationType,
204        key: &SigningKey,
205        realm: Option<String>,
206    ) -> Result<String> {
207        req.parameters.extend(self.parameters()?);
208
209        match typ {
210            AuthorizationType::RequestToken { callback } => {
211                req.parameters.insert(OAUTH_CALLBACK.into(), callback);
212            }
213            AuthorizationType::AccessToken { verifier } => {
214                req.parameters.insert(OAUTH_VERIFIER.into(), verifier);
215            }
216            AuthorizationType::Request => {}
217        }
218
219        let signature = self.signature_method.sign(&req, key)?;
220        req.parameters.insert(OAUTH_SIGNATURE.into(), signature);
221
222        // Only include OAuth parameters in the Authorization header
223        let oauth_params: IndexMap<_, _> = req
224            .parameters
225            .iter()
226            .filter(|(k, _)| k.starts_with("oauth_"))
227            .map(|(k, v)| (k.clone(), v.clone()))
228            .collect();
229
230        Ok(match realm {
231            Some(realm) => format!(
232                "OAuth realm=\"{}\",{}",
233                realm,
234                encode_auth_header(&oauth_params)
235            ),
236            None => format!("OAuth {}", encode_auth_header(&oauth_params)),
237        })
238    }
239
240    pub fn parameters(&self) -> Result<IndexMap<String, String>> {
241        let mut params = IndexMap::new();
242
243        params.insert(OAUTH_CONSUMER_KEY.into(), self.client_id.clone());
244        if let Some(token) = &self.token {
245            params.insert(OAUTH_TOKEN.into(), token.clone());
246        }
247        params.insert(
248            OAUTH_SIGNATURE_METHOD.into(),
249            self.signature_method.to_string(),
250        );
251        params.insert(OAUTH_TIMESTAMP.into(), timestamp()?.to_string());
252        params.insert(OAUTH_NONCE.into(), self.nonce.0.clone());
253        params.insert(OAUTH_VERSION.into(), "1.0".into());
254        Ok(params)
255    }
256}
257fn encode_auth_header(req: &IndexMap<String, String>) -> String {
258    req.iter()
259        .map(|(k, v)| format!(r#"{}="{}""#, percent_encode(k), percent_encode(v)))
260        .collect::<Vec<String>>()
261        .join(",")
262}
263
264fn encode_url_parameters(params: &IndexMap<String, String>) -> String {
265    params
266        .iter()
267        .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(v)))
268        .collect::<Vec<String>>()
269        .join("&")
270}