auth_framework/protocols/
oauth1.rs1use crate::errors::{AuthError, Result};
7use base64::Engine;
8use ring::hmac;
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct OAuthConsumer {
16 pub key: String,
17 pub secret: String,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct OAuthToken {
23 pub token: String,
24 pub secret: String,
25}
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
29pub enum SignatureMethod {
30 HmacSha1,
31 HmacSha256,
32 Plaintext,
33}
34
35impl SignatureMethod {
36 pub fn as_str(&self) -> &'static str {
37 match self {
38 Self::HmacSha1 => "HMAC-SHA1",
39 Self::HmacSha256 => "HMAC-SHA256",
40 Self::Plaintext => "PLAINTEXT",
41 }
42 }
43}
44
/// Result of signing a request with [`OAuth1Client::sign_request`].
#[derive(Clone, Debug)]
pub struct OAuthSignedRequest {
    /// Complete `Authorization` header value, starting with `OAuth `.
    pub authorization_header: String,
    /// The signature base string that was signed (useful when debugging
    /// signature mismatches with a provider).
    pub signature_base_string: String,
    /// The raw (not percent-encoded) signature. For PLAINTEXT this is the
    /// signing key itself, so treat it as secret material.
    pub signature: String,
}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct RequestTokenResponse {
59 pub oauth_token: String,
60 pub oauth_token_secret: String,
61 pub oauth_callback_confirmed: bool,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct AccessTokenResponse {
67 pub oauth_token: String,
68 pub oauth_token_secret: String,
69}
70
71pub struct OAuth1Client {
73 consumer: OAuthConsumer,
74 signature_method: SignatureMethod,
75}
76
77impl OAuth1Client {
78 pub fn new(consumer: OAuthConsumer, signature_method: SignatureMethod) -> Result<Self> {
80 if consumer.key.is_empty() || consumer.secret.is_empty() {
81 return Err(AuthError::validation(
82 "Consumer key and secret must not be empty",
83 ));
84 }
85 Ok(Self {
86 consumer,
87 signature_method,
88 })
89 }
90
91 pub fn sign_request(
95 &self,
96 method: &str,
97 url: &str,
98 token: Option<&OAuthToken>,
99 extra_params: Option<&BTreeMap<String, String>>,
100 ) -> Result<OAuthSignedRequest> {
101 let nonce = generate_nonce()?;
102 let timestamp = SystemTime::now()
103 .duration_since(UNIX_EPOCH)
104 .unwrap_or_default()
105 .as_secs()
106 .to_string();
107
108 let mut params = BTreeMap::new();
110 params.insert("oauth_consumer_key".to_string(), self.consumer.key.clone());
111 params.insert("oauth_nonce".to_string(), nonce);
112 params.insert(
113 "oauth_signature_method".to_string(),
114 self.signature_method.as_str().to_string(),
115 );
116 params.insert("oauth_timestamp".to_string(), timestamp);
117 params.insert("oauth_version".to_string(), "1.0".to_string());
118
119 if let Some(t) = token {
120 params.insert("oauth_token".to_string(), t.token.clone());
121 }
122
123 if let Some(extra) = extra_params {
124 for (k, v) in extra {
125 params.insert(k.clone(), v.clone());
126 }
127 }
128
129 let param_string: String = params
131 .iter()
132 .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(v)))
133 .collect::<Vec<_>>()
134 .join("&");
135
136 let base_string = format!(
137 "{}&{}&{}",
138 method.to_uppercase(),
139 percent_encode(url),
140 percent_encode(¶m_string)
141 );
142
143 let token_secret = token.map(|t| t.secret.as_str()).unwrap_or("");
145 let signing_key = format!(
146 "{}&{}",
147 percent_encode(&self.consumer.secret),
148 percent_encode(token_secret)
149 );
150
151 let signature = match self.signature_method {
152 SignatureMethod::HmacSha1 => {
153 let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, signing_key.as_bytes());
154 let tag = hmac::sign(&key, base_string.as_bytes());
155 base64::engine::general_purpose::STANDARD.encode(tag.as_ref())
156 }
157 SignatureMethod::HmacSha256 => {
158 let key = hmac::Key::new(hmac::HMAC_SHA256, signing_key.as_bytes());
159 let tag = hmac::sign(&key, base_string.as_bytes());
160 base64::engine::general_purpose::STANDARD.encode(tag.as_ref())
161 }
162 SignatureMethod::Plaintext => signing_key.clone(),
163 };
164
165 params.insert("oauth_signature".to_string(), signature.clone());
167
168 let auth_header = format!(
169 "OAuth {}",
170 params
171 .iter()
172 .filter(|(k, _)| k.starts_with("oauth_"))
173 .map(|(k, v)| format!("{}=\"{}\"", percent_encode(k), percent_encode(v)))
174 .collect::<Vec<_>>()
175 .join(", ")
176 );
177
178 Ok(OAuthSignedRequest {
179 authorization_header: auth_header,
180 signature_base_string: base_string,
181 signature,
182 })
183 }
184
185 pub fn build_authorize_url(&self, base_url: &str, request_token: &str) -> String {
187 format!(
188 "{}?oauth_token={}",
189 base_url,
190 percent_encode(request_token)
191 )
192 }
193
194 pub fn parse_request_token_response(body: &str) -> Result<RequestTokenResponse> {
196 let params = parse_form_body(body);
197 let token = params
198 .get("oauth_token")
199 .ok_or_else(|| AuthError::validation("Missing oauth_token"))?
200 .clone();
201 let secret = params
202 .get("oauth_token_secret")
203 .ok_or_else(|| AuthError::validation("Missing oauth_token_secret"))?
204 .clone();
205 let confirmed = params
206 .get("oauth_callback_confirmed")
207 .map(|v| v == "true")
208 .unwrap_or(false);
209
210 Ok(RequestTokenResponse {
211 oauth_token: token,
212 oauth_token_secret: secret,
213 oauth_callback_confirmed: confirmed,
214 })
215 }
216
217 pub fn parse_access_token_response(body: &str) -> Result<AccessTokenResponse> {
219 let params = parse_form_body(body);
220 let token = params
221 .get("oauth_token")
222 .ok_or_else(|| AuthError::validation("Missing oauth_token"))?
223 .clone();
224 let secret = params
225 .get("oauth_token_secret")
226 .ok_or_else(|| AuthError::validation("Missing oauth_token_secret"))?
227 .clone();
228
229 Ok(AccessTokenResponse {
230 oauth_token: token,
231 oauth_token_secret: secret,
232 })
233 }
234}
235
/// Percent-encodes `s` byte by byte for OAuth signing.
///
/// ASCII letters, digits and `-` `.` `_` `~` are left as-is. Every other byte
/// (including each byte of a multi-byte UTF-8 character) becomes `%XX` with
/// upper-case hex digits.
fn percent_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(s.len());
    s.bytes().for_each(|b| {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(b));
        } else {
            out.push('%');
            out.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
        }
    });
    out
}
251
/// Parses an `application/x-www-form-urlencoded` body (such as a token endpoint
/// response) into a name -> value map.
///
/// Names and values are decoded (`+` -> space, `%XX` -> byte); malformed
/// escapes are kept literally and non-UTF-8 results are replaced with U+FFFD.
/// Whitespace around the whole body (e.g. a trailing newline) and empty
/// segments (empty body, `a=1&&b=2`, trailing `&`) are ignored. A segment
/// without `=` yields an empty value. If a name repeats, the last one wins.
fn parse_form_body(body: &str) -> BTreeMap<String, String> {
    // Decodes one form component; local so this parser is self-contained.
    fn decode_component(raw: &str) -> String {
        let bytes = raw.as_bytes();
        let mut decoded = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            match bytes[i] {
                b'+' => {
                    decoded.push(b' ');
                    i += 1;
                }
                b'%' => {
                    let escaped = match (bytes.get(i + 1), bytes.get(i + 2)) {
                        (Some(&hi), Some(&lo)) => (hi as char)
                            .to_digit(16)
                            .zip((lo as char).to_digit(16))
                            // Both nibbles are < 16, so the value fits in a u8.
                            .map(|(h, l)| (h * 16 + l) as u8),
                        _ => None,
                    };
                    match escaped {
                        Some(byte) => {
                            decoded.push(byte);
                            i += 3;
                        }
                        None => {
                            decoded.push(b'%');
                            i += 1;
                        }
                    }
                }
                other => {
                    decoded.push(other);
                    i += 1;
                }
            }
        }
        String::from_utf8_lossy(&decoded).into_owned()
    }

    body.trim()
        .split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (name, value) = pair.split_once('=').unwrap_or((pair, ""));
            (decode_component(name), decode_component(value))
        })
        .collect()
}
263
264fn generate_nonce() -> Result<String> {
266 use ring::rand::{SecureRandom, SystemRandom};
267 let rng = SystemRandom::new();
268 let mut buf = [0u8; 16];
269 rng.fill(&mut buf)
270 .map_err(|_| AuthError::crypto("Failed to generate nonce".to_string()))?;
271 Ok(hex::encode(buf))
272}
273
#[cfg(test)]
mod tests {
    //! Unit tests for the OAuth 1.0 client. Every signature uses a random nonce
    //! and the current time, so HMAC outputs are checked structurally (length,
    //! header contents, base string shape) rather than against fixed vectors.
    //! PLAINTEXT output is deterministic and is checked exactly.
    use super::*;

    /// Example consumer credentials.
    fn test_consumer() -> OAuthConsumer {
        OAuthConsumer {
            key: "dpf43f3p2l4k3l03".to_string(),
            secret: "kd94hf93k423kf44".to_string(),
        }
    }

    /// Example token credentials.
    fn test_token() -> OAuthToken {
        OAuthToken {
            token: "nnch734d00sl2jdk".to_string(),
            secret: "pfkkdhi9sl3r4s00".to_string(),
        }
    }

    #[test]
    fn test_create_client() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        assert_eq!(client.consumer.key, "dpf43f3p2l4k3l03");
        assert_eq!(client.signature_method, SignatureMethod::HmacSha1);
    }

    #[test]
    fn test_empty_consumer_rejected() {
        let empty_key = OAuthConsumer {
            key: String::new(),
            secret: "secret".to_string(),
        };
        assert!(OAuth1Client::new(empty_key, SignatureMethod::HmacSha1).is_err());

        let empty_secret = OAuthConsumer {
            key: "key".to_string(),
            secret: String::new(),
        };
        assert!(OAuth1Client::new(empty_secret, SignatureMethod::HmacSha1).is_err());
    }

    #[test]
    fn test_sign_request_hmac_sha1() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let signed = client
            .sign_request("GET", "https://api.example.com/resource", None, None)
            .unwrap();

        let header = &signed.authorization_header;
        assert!(header.starts_with("OAuth "));
        assert!(header.contains("oauth_consumer_key=\"dpf43f3p2l4k3l03\""));
        assert!(header.contains("oauth_signature_method=\"HMAC-SHA1\""));
        assert!(header.contains("oauth_version=\"1.0\""));
        assert!(header.contains("oauth_signature="));
        assert!(header.contains("oauth_nonce="));
        assert!(header.contains("oauth_timestamp="));
        // No token was supplied, so none may appear in the header.
        assert!(!header.contains("oauth_token="));
        // Base64 of a 20-byte HMAC-SHA1 tag.
        assert_eq!(signed.signature.len(), 28);
    }

    #[test]
    fn test_sign_request_with_token() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let token = test_token();
        let signed = client
            .sign_request("POST", "https://api.example.com/post", Some(&token), None)
            .unwrap();

        assert!(signed
            .authorization_header
            .contains("oauth_token=\"nnch734d00sl2jdk\""));
        assert!(signed.signature_base_string.starts_with("POST&"));
        // The token secret is key material only and must never be transmitted.
        assert!(!signed.authorization_header.contains("pfkkdhi9sl3r4s00"));
        assert_eq!(signed.signature.len(), 28);
    }

    #[test]
    fn test_sign_request_hmac_sha256() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha256).unwrap();
        let signed = client
            .sign_request("GET", "https://api.example.com/resource", None, None)
            .unwrap();
        assert!(signed.authorization_header.contains("HMAC-SHA256"));
        // Base64 of a 32-byte HMAC-SHA256 tag.
        assert_eq!(signed.signature.len(), 44);
    }

    #[test]
    fn test_sign_request_plaintext() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::Plaintext).unwrap();
        let signed = client
            .sign_request("GET", "https://api.example.com/resource", None, None)
            .unwrap();
        // PLAINTEXT signature is "<consumer_secret>&<token_secret>"; no token here.
        assert_eq!(signed.signature, "kd94hf93k423kf44&");
        assert!(signed
            .authorization_header
            .contains("oauth_signature=\"kd94hf93k423kf44%26\""));
    }

    #[test]
    fn test_sign_request_plaintext_with_token() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::Plaintext).unwrap();
        let token = test_token();
        let signed = client
            .sign_request("GET", "https://api.example.com/resource", Some(&token), None)
            .unwrap();
        assert_eq!(signed.signature, "kd94hf93k423kf44&pfkkdhi9sl3r4s00");
    }

    #[test]
    fn test_extra_params_signed_but_only_oauth_ones_in_header() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let mut extra: std::collections::BTreeMap<String, String> =
            std::collections::BTreeMap::new();
        extra.insert("status".to_string(), "hello world".to_string());
        extra.insert("oauth_callback".to_string(), "oob".to_string());

        let signed = client
            .sign_request("POST", "https://api.example.com/update", None, Some(&extra))
            .unwrap();

        // Encoded once into the parameter string, then again in the base string.
        assert!(signed
            .signature_base_string
            .contains("status%3Dhello%2520world"));
        assert!(signed.authorization_header.contains("oauth_callback=\"oob\""));
        assert!(!signed.authorization_header.contains("status"));
    }

    #[test]
    fn test_each_request_gets_fresh_nonce() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let first = client
            .sign_request("GET", "https://api.example.com/resource", None, None)
            .unwrap();
        let second = client
            .sign_request("GET", "https://api.example.com/resource", None, None)
            .unwrap();
        assert_ne!(first.signature_base_string, second.signature_base_string);
    }

    #[test]
    fn test_percent_encode() {
        assert_eq!(percent_encode("hello"), "hello");
        assert_eq!(percent_encode("hello world"), "hello%20world");
        assert_eq!(percent_encode("a&b=c"), "a%26b%3Dc");
        assert_eq!(percent_encode("~.-_"), "~.-_");
        assert_eq!(percent_encode("é*"), "%C3%A9%2A");
    }

    #[test]
    fn test_signature_base_string_format() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let signed = client
            .sign_request("GET", "https://api.example.com/1/resource", None, None)
            .unwrap();
        let base = &signed.signature_base_string;
        assert!(base.starts_with("GET&"));
        assert!(base.contains("https%3A%2F%2Fapi.example.com%2F1%2Fresource"));
        // Parameters follow the URI, sorted, with the consumer key first.
        assert!(base.contains("%2Fresource&oauth_consumer_key%3Ddpf43f3p2l4k3l03%26oauth_nonce%3D"));
        assert!(base.contains("oauth_version%3D1.0"));
    }

    #[test]
    fn test_build_authorize_url() {
        let client = OAuth1Client::new(test_consumer(), SignatureMethod::HmacSha1).unwrap();
        let url = client.build_authorize_url("https://api.example.com/authorize", "hh5s93j4hdidpola");
        assert_eq!(
            url,
            "https://api.example.com/authorize?oauth_token=hh5s93j4hdidpola"
        );

        let encoded = client.build_authorize_url("https://api.example.com/authorize", "a b/c");
        assert_eq!(
            encoded,
            "https://api.example.com/authorize?oauth_token=a%20b%2Fc"
        );
    }

    #[test]
    fn test_parse_request_token_response() {
        let body = "oauth_token=hh5s93j4hdidpola&oauth_token_secret=hdhd0244k9j7ao03&oauth_callback_confirmed=true";
        let resp = OAuth1Client::parse_request_token_response(body).unwrap();
        assert_eq!(resp.oauth_token, "hh5s93j4hdidpola");
        assert_eq!(resp.oauth_token_secret, "hdhd0244k9j7ao03");
        assert!(resp.oauth_callback_confirmed);
    }

    #[test]
    fn test_parse_request_token_callback_not_confirmed() {
        let missing = OAuth1Client::parse_request_token_response("oauth_token=a&oauth_token_secret=b")
            .unwrap();
        assert!(!missing.oauth_callback_confirmed);

        let explicit_false = OAuth1Client::parse_request_token_response(
            "oauth_token=a&oauth_token_secret=b&oauth_callback_confirmed=false",
        )
        .unwrap();
        assert!(!explicit_false.oauth_callback_confirmed);
    }

    #[test]
    fn test_parse_request_token_missing_field() {
        assert!(OAuth1Client::parse_request_token_response("oauth_token=xyz").is_err());
    }

    #[test]
    fn test_parse_access_token_response() {
        let body = "oauth_token=nnch734d00sl2jdk&oauth_token_secret=pfkkdhi9sl3r4s00";
        let resp = OAuth1Client::parse_access_token_response(body).unwrap();
        assert_eq!(resp.oauth_token, "nnch734d00sl2jdk");
        assert_eq!(resp.oauth_token_secret, "pfkkdhi9sl3r4s00");
    }

    #[test]
    fn test_parse_access_token_missing_field() {
        assert!(OAuth1Client::parse_access_token_response("oauth_token=xyz").is_err());
        assert!(OAuth1Client::parse_access_token_response("oauth_token_secret=xyz").is_err());
    }

    #[test]
    fn test_different_consumers_different_signatures() {
        let make_client = |key: &str, secret: &str| {
            OAuth1Client::new(
                OAuthConsumer {
                    key: key.to_string(),
                    secret: secret.to_string(),
                },
                SignatureMethod::HmacSha1,
            )
            .unwrap()
        };
        let s1 = make_client("key1", "secret1")
            .sign_request("GET", "https://example.com", None, None)
            .unwrap();
        let s2 = make_client("key2", "secret2")
            .sign_request("GET", "https://example.com", None, None)
            .unwrap();
        assert_ne!(s1.signature, s2.signature);
    }

    #[test]
    fn test_signature_method_as_str() {
        assert_eq!(SignatureMethod::HmacSha1.as_str(), "HMAC-SHA1");
        assert_eq!(SignatureMethod::HmacSha256.as_str(), "HMAC-SHA256");
        assert_eq!(SignatureMethod::Plaintext.as_str(), "PLAINTEXT");
    }
}