1use hmac::{Hmac, Mac};
7use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
8use sha1::Sha1;
9use std::collections::BTreeMap;
10use std::time::{SystemTime, UNIX_EPOCH};
11use url::Url;
12
13const ENCODE_SET: &AsciiSet = &CONTROLS
15 .add(b' ')
16 .add(b'"')
17 .add(b'#')
18 .add(b'$')
19 .add(b'%')
20 .add(b'&')
21 .add(b'+')
22 .add(b',')
23 .add(b'/')
24 .add(b':')
25 .add(b';')
26 .add(b'<')
27 .add(b'=')
28 .add(b'>')
29 .add(b'?')
30 .add(b'@')
31 .add(b'[')
32 .add(b'\\')
33 .add(b']')
34 .add(b'^')
35 .add(b'`')
36 .add(b'{')
37 .add(b'|')
38 .add(b'}');
39
/// Consumer (client) credentials identifying the application.
///
/// The signer sends `key` as `oauth_consumer_key`. It uses `secret` only to
/// derive the HMAC signing key.
#[derive(Clone, Debug)]
pub struct OAuthConsumer {
    /// Consumer key, transmitted in the `Authorization` header.
    pub key: String,
    /// Consumer secret, used as the first half of the signing key.
    pub secret: String,
}
46
/// Token credentials (token plus token secret) obtained during the OAuth flow.
///
/// `Default` yields empty strings for both fields.
#[derive(Clone, Default, Debug)]
pub struct OAuthToken {
    /// Token value, transmitted as `oauth_token`.
    pub token: String,
    /// Token secret, used as the second half of the signing key.
    pub secret: String,
}
53
54pub struct OAuth1Signer {
56 consumer: OAuthConsumer,
57 token: Option<OAuthToken>,
58}
59
60impl OAuth1Signer {
61 pub fn new(consumer: OAuthConsumer) -> Self {
63 Self {
64 consumer,
65 token: None,
66 }
67 }
68
69 pub fn with_token(mut self, token: OAuthToken) -> Self {
71 self.token = Some(token);
72 self
73 }
74
75 pub fn sign(&self, method: &str, url: &str, extra_params: &[(String, String)]) -> String {
80 let timestamp = SystemTime::now()
81 .duration_since(UNIX_EPOCH)
82 .unwrap()
83 .as_secs()
84 .to_string();
85
86 let nonce = generate_nonce();
87
88 self.sign_with_timestamp_nonce(method, url, extra_params, ×tamp, &nonce)
89 }
90
91 pub fn sign_with_timestamp_nonce(
93 &self,
94 method: &str,
95 url: &str,
96 extra_params: &[(String, String)],
97 timestamp: &str,
98 nonce: &str,
99 ) -> String {
100 let parsed_url = Url::parse(url).expect("Invalid URL");
102 let base_url = format!(
103 "{}://{}{}",
104 parsed_url.scheme(),
105 parsed_url.host_str().unwrap_or(""),
106 parsed_url.path()
107 );
108
109 let url_params: Vec<(String, String)> = parsed_url
111 .query_pairs()
112 .map(|(k, v)| (k.to_string(), v.to_string()))
113 .collect();
114
115 let mut oauth_params: BTreeMap<String, String> = BTreeMap::new();
116 oauth_params.insert("oauth_consumer_key".to_string(), self.consumer.key.clone());
117 oauth_params.insert("oauth_nonce".to_string(), nonce.to_string());
118 oauth_params.insert(
119 "oauth_signature_method".to_string(),
120 "HMAC-SHA1".to_string(),
121 );
122 oauth_params.insert("oauth_timestamp".to_string(), timestamp.to_string());
123 oauth_params.insert("oauth_version".to_string(), "1.0".to_string());
124
125 if let Some(ref token) = self.token {
126 oauth_params.insert("oauth_token".to_string(), token.token.clone());
127 }
128
129 let signature =
132 self.calculate_signature(method, &base_url, &url_params, extra_params, &oauth_params);
133 oauth_params.insert("oauth_signature".to_string(), signature);
134
135 let auth_params: Vec<String> = oauth_params
137 .iter()
138 .map(|(k, v)| format!("{}=\"{}\"", k, percent_encode(v)))
139 .collect();
140
141 format!("OAuth {}", auth_params.join(", "))
142 }
143
144 fn calculate_signature(
145 &self,
146 method: &str,
147 base_url: &str,
148 url_params: &[(String, String)],
149 extra_params: &[(String, String)],
150 oauth_params: &BTreeMap<String, String>,
151 ) -> String {
152 let mut all_params: BTreeMap<String, String> = oauth_params.clone();
154 for (k, v) in url_params {
155 all_params.insert(k.clone(), v.clone());
156 }
157 for (k, v) in extra_params {
158 all_params.insert(k.clone(), v.clone());
159 }
160
161 let param_string: String = all_params
163 .iter()
164 .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(v)))
165 .collect::<Vec<_>>()
166 .join("&");
167
168 let base_string = format!(
170 "{}&{}&{}",
171 method.to_uppercase(),
172 percent_encode(base_url),
173 percent_encode(¶m_string)
174 );
175
176 let token_secret = self.token.as_ref().map(|t| t.secret.as_str()).unwrap_or("");
178 let signing_key = format!(
179 "{}&{}",
180 percent_encode(&self.consumer.secret),
181 percent_encode(token_secret)
182 );
183
184 let mut mac = Hmac::<Sha1>::new_from_slice(signing_key.as_bytes())
186 .expect("HMAC can take key of any size");
187 mac.update(base_string.as_bytes());
188 let result = mac.finalize();
189
190 base64::Engine::encode(
192 &base64::engine::general_purpose::STANDARD,
193 result.into_bytes(),
194 )
195 }
196}
197
198fn percent_encode(s: &str) -> String {
200 utf8_percent_encode(s, ENCODE_SET).to_string()
201}
202
203fn generate_nonce() -> String {
205 let mut rng = rand::thread_rng();
206 let bytes: [u8; 16] = rand::Rng::gen(&mut rng);
207 bytes.iter().map(|b| format!("{:02x}", b)).collect()
208}
209
210pub fn parse_oauth_response(response: &str) -> BTreeMap<String, String> {
212 response
213 .split('&')
214 .filter_map(|pair| {
215 let mut parts = pair.splitn(2, '=');
216 match (parts.next(), parts.next()) {
217 (Some(key), Some(value)) => Some((
218 urlencoding::decode(key).unwrap_or_default().into_owned(),
219 urlencoding::decode(value).unwrap_or_default().into_owned(),
220 )),
221 _ => None,
222 }
223 })
224 .collect()
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds consumer credentials from string literals.
    fn consumer(key: &str, secret: &str) -> OAuthConsumer {
        OAuthConsumer {
            key: key.to_owned(),
            secret: secret.to_owned(),
        }
    }

    /// Builds token credentials from string literals.
    fn token(value: &str, secret: &str) -> OAuthToken {
        OAuthToken {
            token: value.to_owned(),
            secret: secret.to_owned(),
        }
    }

    #[test]
    fn test_oauth_consumer_creation() {
        let created = consumer("test_key", "test_secret");
        assert_eq!(created.key, "test_key");
        assert_eq!(created.secret, "test_secret");
    }

    #[test]
    fn test_oauth1_signer_creation() {
        let signer = OAuth1Signer::new(consumer("test_key", "test_secret"));
        assert!(signer.token.is_none());
    }

    #[test]
    fn test_oauth1_signer_with_token() {
        let signer = OAuth1Signer::new(consumer("test_key", "test_secret"))
            .with_token(token("token123", "tokensecret"));
        let attached = signer.token.as_ref().expect("token should be attached");
        assert_eq!(attached.token, "token123");
    }

    #[test]
    fn test_sign_generates_authorization_header() {
        let signer = OAuth1Signer::new(consumer("test_consumer_key", "test_consumer_secret"));

        let header = signer.sign_with_timestamp_nonce(
            "GET",
            "https://example.com/api/test",
            &[],
            "1234567890",
            "abc123nonce",
        );

        assert!(header.starts_with("OAuth "));
        let expected_fragments = [
            "oauth_consumer_key=\"test_consumer_key\"",
            "oauth_signature_method=\"HMAC-SHA1\"",
            "oauth_timestamp=\"1234567890\"",
            "oauth_nonce=\"abc123nonce\"",
            "oauth_version=\"1.0\"",
            "oauth_signature=",
        ];
        for fragment in expected_fragments {
            assert!(header.contains(fragment), "missing `{fragment}` in: {header}");
        }
    }

    #[test]
    fn test_sign_with_token_includes_oauth_token() {
        let signer = OAuth1Signer::new(consumer("consumer_key", "consumer_secret"))
            .with_token(token("user_token", "user_secret"));

        let header = signer.sign_with_timestamp_nonce(
            "GET",
            "https://example.com/api",
            &[],
            "1234567890",
            "nonce123",
        );

        assert!(header.contains("oauth_token=\"user_token\""));
    }

    #[test]
    fn test_percent_encode() {
        let cases = [
            ("hello world", "hello%20world"),
            ("foo=bar&baz", "foo%3Dbar%26baz"),
            ("simple", "simple"),
        ];
        for (input, expected) in cases {
            assert_eq!(percent_encode(input), expected);
        }
    }

    #[test]
    fn test_parse_oauth_response() {
        let parsed =
            parse_oauth_response("oauth_token=abc123&oauth_token_secret=xyz789&mfa_token=mfa456");

        assert_eq!(parsed.get("oauth_token").map(String::as_str), Some("abc123"));
        assert_eq!(parsed.get("oauth_token_secret").map(String::as_str), Some("xyz789"));
        assert_eq!(parsed.get("mfa_token").map(String::as_str), Some("mfa456"));
    }

    #[test]
    fn test_parse_oauth_response_with_encoded_values() {
        let parsed = parse_oauth_response("key=value%20with%20spaces&other=normal");

        assert_eq!(parsed.get("key").map(String::as_str), Some("value with spaces"));
        assert_eq!(parsed.get("other").map(String::as_str), Some("normal"));
    }

    #[test]
    fn test_nonce_generation() {
        let (first, second) = (generate_nonce(), generate_nonce());

        assert_eq!(first.len(), 32);
        assert_ne!(first, second);
    }

    #[test]
    fn test_sign_with_url_query_params() {
        // Published OAuth 1.0 example credentials and request.
        let signer = OAuth1Signer::new(consumer("dpf43f3p2l4k3l03", "kd94hf93k423kf44"))
            .with_token(token("nnch734d00sl2jdk", "pfkkdhi9sl3r4s00"));

        let url = "http://photos.example.net/photos?file=vacation.jpg&size=original";

        let auth_header =
            signer.sign_with_timestamp_nonce("GET", url, &[], "1191242096", "kllo9940pd9333jh");

        assert!(auth_header.contains("oauth_signature="));
        assert!(auth_header.starts_with("OAuth "));
        assert!(
            auth_header.contains("oauth_signature=\"tR3%2BTy81lMeYAr%2FFid0kMTYa%2FWM%3D\""),
            "Expected signature tR3+Ty81lMeYAr/Fid0kMTYa/WM= (url-encoded), got: {}",
            auth_header
        );
    }

    #[test]
    fn test_sign_url_with_query_params_extracts_base_url() {
        let signer = OAuth1Signer::new(consumer("consumer", "secret"));

        let url =
            "https://api.example.com/path?ticket=ST-123&login-url=https://sso.example.com/embed";

        let header = signer.sign_with_timestamp_nonce("GET", url, &[], "1234567890", "testnonce");

        assert!(header.starts_with("OAuth "));
        assert!(header.contains("oauth_consumer_key=\"consumer\""));
    }
}