1use hmac::{Hmac, Mac};
7use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
8use sha1::Sha1;
9use std::collections::BTreeMap;
10use std::time::{SystemTime, UNIX_EPOCH};
11use url::Url;
12
13const ENCODE_SET: &AsciiSet = &CONTROLS
15 .add(b' ')
16 .add(b'"')
17 .add(b'#')
18 .add(b'$')
19 .add(b'%')
20 .add(b'&')
21 .add(b'+')
22 .add(b',')
23 .add(b'/')
24 .add(b':')
25 .add(b';')
26 .add(b'<')
27 .add(b'=')
28 .add(b'>')
29 .add(b'?')
30 .add(b'@')
31 .add(b'[')
32 .add(b'\\')
33 .add(b']')
34 .add(b'^')
35 .add(b'`')
36 .add(b'{')
37 .add(b'|')
38 .add(b'}');
39
/// Consumer (application) credentials used to sign every request.
#[derive(Clone, Debug)]
pub struct OAuthConsumer {
    /// Public identifier, emitted as the `oauth_consumer_key` parameter.
    pub key: String,
    /// Shared secret; only used to build the HMAC signing key, never emitted itself.
    pub secret: String,
}
46
/// A user token paired with its secret.
///
/// The derived `Default` produces empty strings for both fields.
#[derive(Clone, Debug, Default)]
pub struct OAuthToken {
    /// Token value, emitted as the `oauth_token` parameter.
    pub token: String,
    /// Token secret; forms the second half of the HMAC signing key.
    pub secret: String,
}
53
54pub struct OAuth1Signer {
56 consumer: OAuthConsumer,
57 token: Option<OAuthToken>,
58}
59
60impl OAuth1Signer {
61 pub fn new(consumer: OAuthConsumer) -> Self {
63 Self {
64 consumer,
65 token: None,
66 }
67 }
68
69 pub fn with_token(mut self, token: OAuthToken) -> Self {
71 self.token = Some(token);
72 self
73 }
74
75 pub fn sign(
80 &self,
81 method: &str,
82 url: &str,
83 extra_params: &[(String, String)],
84 ) -> String {
85 let timestamp = SystemTime::now()
86 .duration_since(UNIX_EPOCH)
87 .unwrap()
88 .as_secs()
89 .to_string();
90
91 let nonce = generate_nonce();
92
93 self.sign_with_timestamp_nonce(method, url, extra_params, ×tamp, &nonce)
94 }
95
96 pub fn sign_with_timestamp_nonce(
98 &self,
99 method: &str,
100 url: &str,
101 extra_params: &[(String, String)],
102 timestamp: &str,
103 nonce: &str,
104 ) -> String {
105 let parsed_url = Url::parse(url).expect("Invalid URL");
107 let base_url = format!(
108 "{}://{}{}",
109 parsed_url.scheme(),
110 parsed_url.host_str().unwrap_or(""),
111 parsed_url.path()
112 );
113
114 let url_params: Vec<(String, String)> = parsed_url
116 .query_pairs()
117 .map(|(k, v)| (k.to_string(), v.to_string()))
118 .collect();
119
120 let mut oauth_params: BTreeMap<String, String> = BTreeMap::new();
121 oauth_params.insert("oauth_consumer_key".to_string(), self.consumer.key.clone());
122 oauth_params.insert("oauth_nonce".to_string(), nonce.to_string());
123 oauth_params.insert("oauth_signature_method".to_string(), "HMAC-SHA1".to_string());
124 oauth_params.insert("oauth_timestamp".to_string(), timestamp.to_string());
125 oauth_params.insert("oauth_version".to_string(), "1.0".to_string());
126
127 if let Some(ref token) = self.token {
128 oauth_params.insert("oauth_token".to_string(), token.token.clone());
129 }
130
131 let signature = self.calculate_signature(method, &base_url, &url_params, extra_params, &oauth_params);
134 oauth_params.insert("oauth_signature".to_string(), signature);
135
136 let auth_params: Vec<String> = oauth_params
138 .iter()
139 .map(|(k, v)| format!("{}=\"{}\"", k, percent_encode(v)))
140 .collect();
141
142 format!("OAuth {}", auth_params.join(", "))
143 }
144
145 fn calculate_signature(
146 &self,
147 method: &str,
148 base_url: &str,
149 url_params: &[(String, String)],
150 extra_params: &[(String, String)],
151 oauth_params: &BTreeMap<String, String>,
152 ) -> String {
153 let mut all_params: BTreeMap<String, String> = oauth_params.clone();
155 for (k, v) in url_params {
156 all_params.insert(k.clone(), v.clone());
157 }
158 for (k, v) in extra_params {
159 all_params.insert(k.clone(), v.clone());
160 }
161
162 let param_string: String = all_params
164 .iter()
165 .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(v)))
166 .collect::<Vec<_>>()
167 .join("&");
168
169 let base_string = format!(
171 "{}&{}&{}",
172 method.to_uppercase(),
173 percent_encode(base_url),
174 percent_encode(¶m_string)
175 );
176
177 let token_secret = self
179 .token
180 .as_ref()
181 .map(|t| t.secret.as_str())
182 .unwrap_or("");
183 let signing_key = format!("{}&{}", percent_encode(&self.consumer.secret), percent_encode(token_secret));
184
185 let mut mac = Hmac::<Sha1>::new_from_slice(signing_key.as_bytes())
187 .expect("HMAC can take key of any size");
188 mac.update(base_string.as_bytes());
189 let result = mac.finalize();
190
191 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, result.into_bytes())
193 }
194}
195
196fn percent_encode(s: &str) -> String {
198 utf8_percent_encode(s, ENCODE_SET).to_string()
199}
200
201fn generate_nonce() -> String {
203 let mut rng = rand::thread_rng();
204 let bytes: [u8; 16] = rand::Rng::gen(&mut rng);
205 bytes.iter().map(|b| format!("{:02x}", b)).collect()
206}
207
208pub fn parse_oauth_response(response: &str) -> BTreeMap<String, String> {
210 response
211 .split('&')
212 .filter_map(|pair| {
213 let mut parts = pair.splitn(2, '=');
214 match (parts.next(), parts.next()) {
215 (Some(key), Some(value)) => Some((
216 urlencoding::decode(key).unwrap_or_default().into_owned(),
217 urlencoding::decode(value).unwrap_or_default().into_owned(),
218 )),
219 _ => None,
220 }
221 })
222 .collect()
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `OAuthConsumer` from string slices.
    fn make_consumer(key: &str, secret: &str) -> OAuthConsumer {
        OAuthConsumer {
            key: key.to_owned(),
            secret: secret.to_owned(),
        }
    }

    /// Builds an `OAuthToken` from string slices.
    fn make_token(token: &str, secret: &str) -> OAuthToken {
        OAuthToken {
            token: token.to_owned(),
            secret: secret.to_owned(),
        }
    }

    #[test]
    fn test_oauth_consumer_creation() {
        let c = make_consumer("test_key", "test_secret");
        assert_eq!(c.key, "test_key");
        assert_eq!(c.secret, "test_secret");
    }

    #[test]
    fn test_oauth1_signer_creation() {
        let s = OAuth1Signer::new(make_consumer("test_key", "test_secret"));
        assert!(s.token.is_none());
    }

    #[test]
    fn test_oauth1_signer_with_token() {
        let s = OAuth1Signer::new(make_consumer("test_key", "test_secret"))
            .with_token(make_token("token123", "tokensecret"));
        let attached = s.token.as_ref().expect("token should be attached");
        assert_eq!(attached.token, "token123");
    }

    #[test]
    fn test_sign_generates_authorization_header() {
        let s = OAuth1Signer::new(make_consumer("test_consumer_key", "test_consumer_secret"));
        let header = s.sign_with_timestamp_nonce(
            "GET",
            "https://example.com/api/test",
            &[],
            "1234567890",
            "abc123nonce",
        );

        assert!(header.starts_with("OAuth "));
        let expected_fragments = [
            "oauth_consumer_key=\"test_consumer_key\"",
            "oauth_signature_method=\"HMAC-SHA1\"",
            "oauth_timestamp=\"1234567890\"",
            "oauth_nonce=\"abc123nonce\"",
            "oauth_version=\"1.0\"",
            "oauth_signature=",
        ];
        for fragment in expected_fragments {
            assert!(header.contains(fragment), "missing {} in {}", fragment, header);
        }
    }

    #[test]
    fn test_sign_with_token_includes_oauth_token() {
        let s = OAuth1Signer::new(make_consumer("consumer_key", "consumer_secret"))
            .with_token(make_token("user_token", "user_secret"));
        let header = s.sign_with_timestamp_nonce(
            "GET",
            "https://example.com/api",
            &[],
            "1234567890",
            "nonce123",
        );
        assert!(header.contains("oauth_token=\"user_token\""));
    }

    #[test]
    fn test_percent_encode() {
        let cases = [
            ("hello world", "hello%20world"),
            ("foo=bar&baz", "foo%3Dbar%26baz"),
            ("simple", "simple"),
        ];
        for (input, expected) in cases {
            assert_eq!(percent_encode(input), expected);
        }
    }

    #[test]
    fn test_parse_oauth_response() {
        let parsed =
            parse_oauth_response("oauth_token=abc123&oauth_token_secret=xyz789&mfa_token=mfa456");
        assert_eq!(parsed.get("oauth_token").map(String::as_str), Some("abc123"));
        assert_eq!(parsed.get("oauth_token_secret").map(String::as_str), Some("xyz789"));
        assert_eq!(parsed.get("mfa_token").map(String::as_str), Some("mfa456"));
    }

    #[test]
    fn test_parse_oauth_response_with_encoded_values() {
        let parsed = parse_oauth_response("key=value%20with%20spaces&other=normal");
        assert_eq!(parsed.get("key").map(String::as_str), Some("value with spaces"));
        assert_eq!(parsed.get("other").map(String::as_str), Some("normal"));
    }

    #[test]
    fn test_nonce_generation() {
        let (first, second) = (generate_nonce(), generate_nonce());
        assert_eq!(first.len(), 32);
        assert_ne!(first, second);
    }

    #[test]
    fn test_sign_with_url_query_params() {
        // Credentials and expected signature from the photos.example.net example in
        // the OAuth Core 1.0 specification.
        let s = OAuth1Signer::new(make_consumer("dpf43f3p2l4k3l03", "kd94hf93k423kf44"))
            .with_token(make_token("nnch734d00sl2jdk", "pfkkdhi9sl3r4s00"));
        let header = s.sign_with_timestamp_nonce(
            "GET",
            "http://photos.example.net/photos?file=vacation.jpg&size=original",
            &[],
            "1191242096",
            "kllo9940pd9333jh",
        );

        assert!(header.starts_with("OAuth "));
        assert!(header.contains("oauth_signature="));
        assert!(
            header.contains("oauth_signature=\"tR3%2BTy81lMeYAr%2FFid0kMTYa%2FWM%3D\""),
            "Expected signature tR3+Ty81lMeYAr/Fid0kMTYa/WM= (url-encoded), got: {}",
            header
        );
    }

    #[test]
    fn test_sign_url_with_query_params_extracts_base_url() {
        let s = OAuth1Signer::new(make_consumer("consumer", "secret"));
        let header = s.sign_with_timestamp_nonce(
            "GET",
            "https://api.example.com/path?ticket=ST-123&login-url=https://sso.example.com/embed",
            &[],
            "1234567890",
            "testnonce",
        );

        assert!(header.starts_with("OAuth "));
        assert!(header.contains("oauth_consumer_key=\"consumer\""));
    }
}