Skip to main content

auth_framework/protocols/
oauth1.rs

//! OAuth 1.0a protocol support (RFC 5849).
//!
//! Provides HMAC-SHA1, HMAC-SHA256 (a common extension not defined by RFC 5849), and
//! PLAINTEXT signature generation, authorization header construction, and the data
//! structures for the three-legged OAuth 1.0a flow.
5
6use crate::errors::{AuthError, Result};
7use base64::Engine;
8use ring::hmac;
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13/// OAuth 1.0a consumer credentials (application).
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct OAuthConsumer {
16    pub key: String,
17    pub secret: String,
18}
19
20/// OAuth 1.0a token credentials (user-authorized).
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct OAuthToken {
23    pub token: String,
24    pub secret: String,
25}
26
27/// OAuth 1.0a signature method.
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
29pub enum SignatureMethod {
30    HmacSha1,
31    HmacSha256,
32    Plaintext,
33}
34
35impl SignatureMethod {
36    pub fn as_str(&self) -> &'static str {
37        match self {
38            Self::HmacSha1 => "HMAC-SHA1",
39            Self::HmacSha256 => "HMAC-SHA256",
40            Self::Plaintext => "PLAINTEXT",
41        }
42    }
43}
44
/// Outcome of signing one request with OAuth 1.0a.
#[derive(Debug, Clone)]
pub struct OAuthSignedRequest {
    /// Ready-to-send value for the HTTP `Authorization` header (starts with `OAuth `).
    pub authorization_header: String,
    /// The exact signature base string that was signed; handy for diagnosing mismatches.
    pub signature_base_string: String,
    /// The raw `oauth_signature` value, before percent-encoding for the header.
    pub signature: String,
}
55
56/// OAuth 1.0a request token response (temporary credentials).
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct RequestTokenResponse {
59    pub oauth_token: String,
60    pub oauth_token_secret: String,
61    pub oauth_callback_confirmed: bool,
62}
63
64/// OAuth 1.0a access token response.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct AccessTokenResponse {
67    pub oauth_token: String,
68    pub oauth_token_secret: String,
69}
70
71/// OAuth 1.0a client for constructing signed requests.
72pub struct OAuth1Client {
73    consumer: OAuthConsumer,
74    signature_method: SignatureMethod,
75}
76
77impl OAuth1Client {
78    /// Create a new OAuth 1.0a client.
79    pub fn new(consumer: OAuthConsumer, signature_method: SignatureMethod) -> Result<Self> {
80        if consumer.key.is_empty() || consumer.secret.is_empty() {
81            return Err(AuthError::validation(
82                "Consumer key and secret must not be empty",
83            ));
84        }
85        Ok(Self {
86            consumer,
87            signature_method,
88        })
89    }
90
91    /// Sign an HTTP request using OAuth 1.0a.
92    ///
93    /// Returns the signed request with the Authorization header value.
94    pub fn sign_request(
95        &self,
96        method: &str,
97        url: &str,
98        token: Option<&OAuthToken>,
99        extra_params: Option<&BTreeMap<String, String>>,
100    ) -> Result<OAuthSignedRequest> {
101        let nonce = generate_nonce()?;
102        let timestamp = SystemTime::now()
103            .duration_since(UNIX_EPOCH)
104            .unwrap_or_default()
105            .as_secs()
106            .to_string();
107
108        // Collect OAuth parameters
109        let mut params = BTreeMap::new();
110        params.insert("oauth_consumer_key".to_string(), self.consumer.key.clone());
111        params.insert("oauth_nonce".to_string(), nonce);
112        params.insert(
113            "oauth_signature_method".to_string(),
114            self.signature_method.as_str().to_string(),
115        );
116        params.insert("oauth_timestamp".to_string(), timestamp);
117        params.insert("oauth_version".to_string(), "1.0".to_string());
118
119        if let Some(t) = token {
120            params.insert("oauth_token".to_string(), t.token.clone());
121        }
122
123        if let Some(extra) = extra_params {
124            for (k, v) in extra {
125                params.insert(k.clone(), v.clone());
126            }
127        }
128
129        // Build signature base string
130        let param_string: String = params
131            .iter()
132            .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(v)))
133            .collect::<Vec<_>>()
134            .join("&");
135
136        let base_string = format!(
137            "{}&{}&{}",
138            method.to_uppercase(),
139            percent_encode(url),
140            percent_encode(&param_string)
141        );
142
143        // Compute signature
144        let token_secret = token.map(|t| t.secret.as_str()).unwrap_or("");
145        let signing_key = format!(
146            "{}&{}",
147            percent_encode(&self.consumer.secret),
148            percent_encode(token_secret)
149        );
150
151        let signature = match self.signature_method {
152            SignatureMethod::HmacSha1 => {
153                let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, signing_key.as_bytes());
154                let tag = hmac::sign(&key, base_string.as_bytes());
155                base64::engine::general_purpose::STANDARD.encode(tag.as_ref())
156            }
157            SignatureMethod::HmacSha256 => {
158                let key = hmac::Key::new(hmac::HMAC_SHA256, signing_key.as_bytes());
159                let tag = hmac::sign(&key, base_string.as_bytes());
160                base64::engine::general_purpose::STANDARD.encode(tag.as_ref())
161            }
162            SignatureMethod::Plaintext => signing_key.clone(),
163        };
164
165        // Build Authorization header
166        params.insert("oauth_signature".to_string(), signature.clone());
167
168        let auth_header = format!(
169            "OAuth {}",
170            params
171                .iter()
172                .filter(|(k, _)| k.starts_with("oauth_"))
173                .map(|(k, v)| format!("{}=\"{}\"", percent_encode(k), percent_encode(v)))
174                .collect::<Vec<_>>()
175                .join(", ")
176        );
177
178        Ok(OAuthSignedRequest {
179            authorization_header: auth_header,
180            signature_base_string: base_string,
181            signature,
182        })
183    }
184
185    /// Build the authorization URL for the user to visit.
186    pub fn build_authorize_url(&self, base_url: &str, request_token: &str) -> String {
187        format!(
188            "{}?oauth_token={}",
189            base_url,
190            percent_encode(request_token)
191        )
192    }
193
194    /// Parse a request token response body (form-encoded).
195    pub fn parse_request_token_response(body: &str) -> Result<RequestTokenResponse> {
196        let params = parse_form_body(body);
197        let token = params
198            .get("oauth_token")
199            .ok_or_else(|| AuthError::validation("Missing oauth_token"))?
200            .clone();
201        let secret = params
202            .get("oauth_token_secret")
203            .ok_or_else(|| AuthError::validation("Missing oauth_token_secret"))?
204            .clone();
205        let confirmed = params
206            .get("oauth_callback_confirmed")
207            .map(|v| v == "true")
208            .unwrap_or(false);
209
210        Ok(RequestTokenResponse {
211            oauth_token: token,
212            oauth_token_secret: secret,
213            oauth_callback_confirmed: confirmed,
214        })
215    }
216
217    /// Parse an access token response body (form-encoded).
218    pub fn parse_access_token_response(body: &str) -> Result<AccessTokenResponse> {
219        let params = parse_form_body(body);
220        let token = params
221            .get("oauth_token")
222            .ok_or_else(|| AuthError::validation("Missing oauth_token"))?
223            .clone();
224        let secret = params
225            .get("oauth_token_secret")
226            .ok_or_else(|| AuthError::validation("Missing oauth_token_secret"))?
227            .clone();
228
229        Ok(AccessTokenResponse {
230            oauth_token: token,
231            oauth_token_secret: secret,
232        })
233    }
234}
235
/// Percent-encode `s` the way OAuth 1.0a requires (RFC 3986 unreserved set).
///
/// The UTF-8 bytes `A-Z`, `a-z`, `0-9`, `-`, `.`, `_` and `~` are copied through. Every
/// other byte becomes `%XX` with uppercase hexadecimal digits.
fn percent_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes().fold(String::with_capacity(s.len()), |mut out, b| {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~');
        if unreserved {
            out.push(char::from(b));
        } else {
            out.push('%');
            out.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
        }
        out
    })
}
251
/// Parse an `application/x-www-form-urlencoded` response body into a name → value map.
///
/// Surrounding whitespace (e.g. a trailing newline) is ignored and empty `&`-separated
/// segments are skipped. Names and values are decoded: `+` becomes a space and `%XX`
/// becomes the corresponding byte. Malformed `%` escapes are kept literally, and invalid
/// UTF-8 is replaced with U+FFFD. A segment without `=` yields an empty value. When a name
/// repeats, the last occurrence wins.
fn parse_form_body(body: &str) -> BTreeMap<String, String> {
    /// Decode a single form-encoded name or value.
    fn decode(component: &str) -> String {
        let bytes = component.as_bytes();
        let mut decoded = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            match bytes[i] {
                b'+' => {
                    decoded.push(b' ');
                    i += 1;
                }
                b'%' => {
                    let hi = bytes.get(i + 1).and_then(|&c| char::from(c).to_digit(16));
                    let lo = bytes.get(i + 2).and_then(|&c| char::from(c).to_digit(16));
                    if let (Some(hi), Some(lo)) = (hi, lo) {
                        // Two hex digits are at most 0xFF, so the cast is lossless.
                        decoded.push((hi * 16 + lo) as u8);
                        i += 3;
                    } else {
                        decoded.push(b'%');
                        i += 1;
                    }
                }
                other => {
                    decoded.push(other);
                    i += 1;
                }
            }
        }
        String::from_utf8_lossy(&decoded).into_owned()
    }

    body.trim()
        .split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (decode(key), decode(value))
        })
        .collect()
}
263
264/// Generate a cryptographically random nonce.
265fn generate_nonce() -> Result<String> {
266    use ring::rand::{SecureRandom, SystemRandom};
267    let rng = SystemRandom::new();
268    let mut buf = [0u8; 16];
269    rng.fill(&mut buf)
270        .map_err(|_| AuthError::crypto("Failed to generate nonce".to_string()))?;
271    Ok(hex::encode(buf))
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Client credentials used in the RFC 5849 examples.
    fn test_consumer() -> OAuthConsumer {
        OAuthConsumer {
            key: "dpf43f3p2l4k3l03".to_string(),
            secret: "kd94hf93k423kf44".to_string(),
        }
    }

    /// Token credentials used in the RFC 5849 examples.
    fn test_token() -> OAuthToken {
        OAuthToken {
            token: "nnch734d00sl2jdk".to_string(),
            secret: "pfkkdhi9sl3r4s00".to_string(),
        }
    }

    #[test]
    fn test_create_client() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        assert_eq!(client.consumer.key, "dpf43f3p2l4k3l03");
    }

    #[test]
    fn test_empty_consumer_rejected() {
        let consumer = OAuthConsumer {
            key: String::new(),
            secret: "secret".to_string(),
        };
        assert!(OAuth1Client::new(consumer, SignatureMethod::HmacSha1).is_err());
    }

    #[test]
    fn test_empty_consumer_secret_rejected() {
        let consumer = OAuthConsumer {
            key: "key".to_string(),
            secret: String::new(),
        };
        assert!(OAuth1Client::new(consumer, SignatureMethod::HmacSha1).is_err());
    }

    #[test]
    fn test_sign_request_hmac_sha1() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let signed = client
            .sign_request("GET", "https://api.example.com/resource", None, None)
            .unwrap();

        assert!(signed.authorization_header.starts_with("OAuth "));
        assert!(signed.authorization_header.contains("oauth_consumer_key="));
        assert!(signed.authorization_header.contains("oauth_signature="));
        assert!(signed.authorization_header.contains("oauth_nonce="));
        // Base64 of a 20-byte SHA-1 MAC is 28 characters including padding.
        assert_eq!(signed.signature.len(), 28);
    }

    #[test]
    fn test_sign_request_with_token() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let token = test_token();
        let signed = client
            .sign_request("POST", "https://api.example.com/post", Some(&token), None)
            .unwrap();

        assert!(signed
            .authorization_header
            .contains("oauth_token=\"nnch734d00sl2jdk\""));
        assert!(!signed.signature.is_empty());
    }

    #[test]
    fn test_sign_request_hmac_sha256() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha256).unwrap();
        let signed = client
            .sign_request("GET", "https://api.example.com/resource", None, None)
            .unwrap();
        assert!(signed.authorization_header.contains("HMAC-SHA256"));
        // Base64 of a 32-byte SHA-256 MAC is 44 characters including padding.
        assert_eq!(signed.signature.len(), 44);
    }

    #[test]
    fn test_sign_request_plaintext() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::Plaintext).unwrap();
        let signed = client
            .sign_request("GET", "https://api.example.com/resource", None, None)
            .unwrap();
        // PLAINTEXT signature = encoded consumer secret & encoded token secret; with no
        // token the trailing '&' remains, and it is percent-encoded again in the header.
        assert_eq!(signed.signature, "kd94hf93k423kf44&");
        assert!(signed
            .authorization_header
            .contains("oauth_signature=\"kd94hf93k423kf44%26\""));
    }

    #[test]
    fn test_sign_request_plaintext_with_token() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::Plaintext).unwrap();
        let token = test_token();
        let signed = client
            .sign_request("GET", "https://api.example.com/resource", Some(&token), None)
            .unwrap();
        assert_eq!(signed.signature, "kd94hf93k423kf44&pfkkdhi9sl3r4s00");
    }

    #[test]
    fn test_extra_params_signed_and_oauth_ones_in_header() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let mut extra = std::collections::BTreeMap::new();
        extra.insert("oauth_callback".to_string(), "oob".to_string());
        extra.insert("status".to_string(), "hi".to_string());
        let signed = client
            .sign_request("POST", "https://api.example.com/update", None, Some(&extra))
            .unwrap();

        assert!(signed.authorization_header.contains("oauth_callback=\"oob\""));
        assert!(!signed.authorization_header.contains("status"));
        assert!(signed.signature_base_string.contains("status%3Dhi"));
    }

    #[test]
    fn test_percent_encode() {
        assert_eq!(percent_encode("hello"), "hello");
        assert_eq!(percent_encode("hello world"), "hello%20world");
        assert_eq!(percent_encode("a&b=c"), "a%26b%3Dc");
        assert_eq!(percent_encode("~.-_"), "~.-_");
        assert_eq!(percent_encode("é"), "%C3%A9");
        assert_eq!(percent_encode(""), "");
    }

    #[test]
    fn test_signature_base_string_format() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let signed = client
            .sign_request("GET", "https://api.example.com/1/resource", None, None)
            .unwrap();
        assert!(signed.signature_base_string.starts_with("GET&"));
        assert!(signed
            .signature_base_string
            .contains("https%3A%2F%2Fapi.example.com%2F1%2Fresource"));
    }

    #[test]
    fn test_build_authorize_url() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let url = client.build_authorize_url(
            "https://api.example.com/authorize",
            "hh5s93j4hdidpola",
        );
        assert_eq!(
            url,
            "https://api.example.com/authorize?oauth_token=hh5s93j4hdidpola"
        );
    }

    #[test]
    fn test_parse_request_token_response() {
        let body = "oauth_token=hh5s93j4hdidpola&oauth_token_secret=hdhd0244k9j7ao03&oauth_callback_confirmed=true";
        let resp = OAuth1Client::parse_request_token_response(body).unwrap();
        assert_eq!(resp.oauth_token, "hh5s93j4hdidpola");
        assert_eq!(resp.oauth_token_secret, "hdhd0244k9j7ao03");
        assert!(resp.oauth_callback_confirmed);
    }

    #[test]
    fn test_parse_request_token_callback_not_confirmed() {
        let body = "oauth_token=hh5s93j4hdidpola&oauth_token_secret=hdhd0244k9j7ao03";
        let resp = OAuth1Client::parse_request_token_response(body).unwrap();
        assert!(!resp.oauth_callback_confirmed);
    }

    #[test]
    fn test_parse_request_token_missing_field() {
        let body = "oauth_token=xyz";
        assert!(OAuth1Client::parse_request_token_response(body).is_err());
    }

    #[test]
    fn test_parse_access_token_response() {
        let body = "oauth_token=nnch734d00sl2jdk&oauth_token_secret=pfkkdhi9sl3r4s00";
        let resp = OAuth1Client::parse_access_token_response(body).unwrap();
        assert_eq!(resp.oauth_token, "nnch734d00sl2jdk");
        assert_eq!(resp.oauth_token_secret, "pfkkdhi9sl3r4s00");
    }

    #[test]
    fn test_parse_access_token_missing_secret() {
        let body = "oauth_token=nnch734d00sl2jdk";
        assert!(OAuth1Client::parse_access_token_response(body).is_err());
    }

    #[test]
    fn test_different_consumers_different_signatures() {
        let c1 = OAuth1Client::new(
            OAuthConsumer {
                key: "key1".to_string(),
                secret: "secret1".to_string(),
            },
            SignatureMethod::HmacSha1,
        )
        .unwrap();
        let c2 = OAuth1Client::new(
            OAuthConsumer {
                key: "key2".to_string(),
                secret: "secret2".to_string(),
            },
            SignatureMethod::HmacSha1,
        )
        .unwrap();

        let s1 = c1
            .sign_request("GET", "https://example.com", None, None)
            .unwrap();
        let s2 = c2
            .sign_request("GET", "https://example.com", None, None)
            .unwrap();
        assert_ne!(s1.signature, s2.signature);
    }

    #[test]
    fn test_signature_method_as_str() {
        assert_eq!(SignatureMethod::HmacSha1.as_str(), "HMAC-SHA1");
        assert_eq!(SignatureMethod::HmacSha256.as_str(), "HMAC-SHA256");
        assert_eq!(SignatureMethod::Plaintext.as_str(), "PLAINTEXT");
    }
}