1#![allow(unstable_name_collisions)]
67
68pub mod flow;
69pub mod percent_encoding;
70pub mod signature;
71
72pub use error::OAuthError;
73pub use parameter::Parameter;
74pub use signature::*;
75
76mod error;
77mod parameter;
78
79use parameter::has_unique_parameters;
80use percent_encoding::*;
81
82use std::str::FromStr;
83use std::time::SystemTime;
84
85use http::{
86 method::Method as HttpMethod,
87 uri::{Authority, PathAndQuery, Scheme},
88 HeaderMap, HeaderValue, Uri,
89};
90use itertools::Itertools;
91use uuid::Uuid;
92
/// A request produced by [`OAuthRequestBuilder::build`], carrying the signed OAuth
/// parameters in the location chosen by its [`AuthorizationScheme`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthRequest {
    /// Target URI. For [`AuthorizationScheme::Uri`] the OAuth and user parameters have been
    /// appended to its query string.
    pub uri: Uri,
    /// Headers to send with the request (`Authorization` and/or `Content-Type`, depending on
    /// the scheme).
    pub headers: HeaderMap,
    /// HTTP method: `POST` for the header and body schemes, `GET` for the URI scheme.
    pub method: HttpMethod,
    /// Form-encoded request body, or `None` when all parameters travel in the URI.
    pub body: Option<String>,
}
105
/// Key and secret that identify the client (consumer) application.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConsumerCredentials {
    /// Consumer key; the request builder sends it as `oauth_consumer_key`.
    pub key: String,
    /// Consumer secret. The request builder never sends it; it only reaches the signature
    /// method as part of an [`AuthenticationLevel`].
    pub secret: String,
}

impl ConsumerCredentials {
    /// Creates credentials holding owned copies of `key` and `secret`.
    pub fn new(key: &str, secret: &str) -> Self {
        let key = key.to_owned();
        let secret = secret.to_owned();
        ConsumerCredentials { key, secret }
    }
}
121
/// An access token (token plus token secret) that the request is made on behalf of.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AccessToken {
    /// Token value; the request builder sends it as `oauth_token`.
    pub token: String,
    /// Token secret. The request builder never sends it; it only reaches the signature
    /// method as part of [`AuthenticationLevel::Token`].
    pub secret: String,
}

impl AccessToken {
    /// Creates an access token holding owned copies of `token` and `secret`.
    pub fn new(token: &str, secret: &str) -> Self {
        let token = token.to_owned();
        let secret = secret.to_owned();
        AccessToken { token, secret }
    }
}
137
/// The credentials a request is authenticated and signed with.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AuthenticationLevel {
    /// Consumer credentials only; the built request carries `oauth_consumer_key` but no
    /// `oauth_token`.
    Consumer(ConsumerCredentials),
    /// Consumer credentials plus an access token; the built request carries both
    /// `oauth_consumer_key` and `oauth_token`.
    Token(ConsumerCredentials, AccessToken),
}
146
/// Where the OAuth protocol parameters are placed in the built request.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AuthorizationScheme {
    /// In an `Authorization: OAuth ...` header of a `POST` request; user parameters go in
    /// the form-encoded body. This is the default.
    Header,
    /// In the form-encoded body of a `POST` request, alongside the user parameters.
    Body,
    /// In the query string of a `GET` request, alongside the user parameters.
    Uri,
}

impl Default for AuthorizationScheme {
    /// Returns [`AuthorizationScheme::Header`].
    fn default() -> AuthorizationScheme {
        AuthorizationScheme::Header
    }
}
164
impl OAuthRequest {
    /// Starts building a signed request for `request_uri`, authenticated with
    /// `auth_level` and signed by `signature_method`.
    ///
    /// Shorthand for [`OAuthRequestBuilder::new`].
    pub fn builder(
        request_uri: Uri,
        auth_level: AuthenticationLevel,
        signature_method: Box<dyn SignatureMethod>,
    ) -> OAuthRequestBuilder {
        OAuthRequestBuilder::new(request_uri, auth_level, signature_method)
    }
}
175
/// Builder that collects parameters and produces a signed [`OAuthRequest`] via `build`.
#[derive(Debug, Clone)]
pub struct OAuthRequestBuilder {
    // Target URI as supplied by the caller.
    request_uri: Uri,
    // Credentials: supply `oauth_consumer_key` / `oauth_token` and are handed to the signer.
    auth_level: AuthenticationLevel,
    // Supplies its own parameter (`as_parameter`) and signs the base string (`sign`).
    signature_method: Box<dyn SignatureMethod>,
    // Where the protocol parameters are placed; `AuthorizationScheme::Header` by default.
    scheme: AuthorizationScheme,
    // User (non-protocol) parameters; signed and sent in the body or query string.
    parameters: Vec<Parameter>,
    // Additional protocol parameters; a `realm` entry is handled separately by `build`.
    extra_auth: Vec<Parameter>,
}
186
187impl OAuthRequestBuilder {
188 pub fn new(
192 request_uri: Uri,
193 auth_level: AuthenticationLevel,
194 signature_method: Box<dyn SignatureMethod>,
195 ) -> Self {
196 OAuthRequestBuilder {
197 request_uri,
198 auth_level,
199 signature_method,
200 scheme: AuthorizationScheme::default(),
201 parameters: vec![],
202 extra_auth: vec![],
203 }
204 }
205
206 pub fn add_parameters(mut self, parameters: &[Parameter]) -> Self {
208 self.parameters.extend_from_slice(parameters);
209 self
210 }
211
212 pub fn add_auth_parameters(mut self, parameters: &[Parameter]) -> Self {
216 self.extra_auth.extend_from_slice(parameters);
217 self
218 }
219
220 pub fn scheme(mut self, scheme: AuthorizationScheme) -> Self {
222 self.scheme = scheme;
223 self
224 }
225
226 pub fn build(self) -> Result<OAuthRequest, Box<dyn std::error::Error>> {
231 let mut auth_params = self.extra_auth;
232 let mut user_params = self.parameters;
233
234 let mut realm = None;
236 auth_params = auth_params
237 .into_iter()
238 .filter_map(|p| {
239 if p.name == "realm" {
240 realm = Some(p);
241 None
242 } else {
243 Some(p)
244 }
245 })
246 .collect();
247
248 auth_params.push(Parameter {
249 name: String::from("oauth_version"),
250 value: String::from("1.0"),
251 });
252 auth_params.push(Parameter {
253 name: String::from("oauth_timestamp"),
254 value: format!(
255 "{}",
256 SystemTime::now()
257 .duration_since(SystemTime::UNIX_EPOCH)
258 .unwrap()
259 .as_secs()
260 ),
261 });
262 auth_params.push(Parameter {
263 name: String::from("oauth_nonce"),
264 value: Uuid::new_v4().hyphenated().to_string(),
265 });
266 auth_params.push(self.signature_method.as_parameter());
267 {
268 let (AuthenticationLevel::Consumer(c) | AuthenticationLevel::Token(c, _)) =
269 &self.auth_level;
270 auth_params.push(Parameter {
271 name: String::from("oauth_consumer_key"),
272 value: c.key.clone(),
273 });
274 }
275 if let AuthenticationLevel::Token(_, t) = &self.auth_level {
276 auth_params.push(Parameter {
277 name: String::from("oauth_token"),
278 value: t.token.clone(),
279 });
280 }
281
282 let mut all_params = auth_params.clone();
283 all_params.extend_from_slice(&user_params);
284 if !has_unique_parameters(all_params.iter()) {
285 return Err(Box::new(OAuthError::DuplicateParameters));
286 };
287 all_params.sort();
288
289 let base_method = match &self.scheme {
290 AuthorizationScheme::Header | AuthorizationScheme::Body => "POST",
291 AuthorizationScheme::Uri => "GET",
292 };
293
294 let base_uri = normalize_uri(self.request_uri.clone());
295
296 let base_params = all_params
297 .iter()
298 .map(|p| p.encoded())
299 .intersperse(String::from('&'))
300 .collect::<String>();
301
302 let base_string = format!(
303 "{}&{}&{}",
304 base_method,
305 encode_string(&base_uri.to_string()),
306 encode_string(&base_params)
307 );
308
309 let signature = self.signature_method.sign(&base_string, &self.auth_level)?;
310 all_params.push(signature.clone());
311 auth_params.push(signature);
312
313 if let Some(r) = realm {
314 all_params.push(r.clone());
315 auth_params.push(r);
316 }
317
318 let mut headers = HeaderMap::new();
319
320 match &self.scheme {
321 AuthorizationScheme::Header => {
322 auth_params.sort();
323 user_params.sort();
324
325 headers.insert(
326 "Authorization",
327 HeaderValue::from_str(&format!(
328 "OAuth {}",
329 &auth_params
330 .iter()
331 .map(|p| p.encoded())
332 .intersperse(String::from(','))
333 .collect::<String>(),
334 ))?,
335 );
336
337 headers.insert(
338 "Content-Type",
339 HeaderValue::from_static("application/x-www-form-urlencoded"),
340 );
341
342 let body = user_params
343 .iter()
344 .map(|p| p.encoded())
345 .intersperse(String::from('&'))
346 .collect::<String>();
347
348 Ok(OAuthRequest {
349 uri: self.request_uri,
350 headers,
351 method: HttpMethod::POST,
352 body: Some(body),
353 })
354 }
355 AuthorizationScheme::Body => {
356 auth_params.sort();
357 user_params.sort();
358
359 headers.insert(
360 "Content-Type",
361 HeaderValue::from_static("application/x-www-form-urlencoded"),
362 );
363
364 let body = all_params
365 .iter()
366 .map(|p| p.encoded())
367 .intersperse(String::from('&'))
368 .collect::<String>();
369
370 Ok(OAuthRequest {
371 uri: self.request_uri,
372 headers,
373 method: HttpMethod::POST,
374 body: Some(body),
375 })
376 }
377 AuthorizationScheme::Uri => {
378 all_params.sort();
379
380 let mut uri = self.request_uri.into_parts();
381
382 let params = all_params
383 .iter()
384 .map(|p| p.encoded())
385 .intersperse(String::from('&'))
386 .collect::<String>();
387
388 match uri.path_and_query {
389 None => {
390 uri.path_and_query =
391 Some(PathAndQuery::from_str(&(String::from("?") + ¶ms))?)
392 }
393 Some(pq) => match pq.query() {
394 None => {
395 uri.path_and_query = Some(PathAndQuery::from_str(&format!(
396 "{}?{}",
397 pq.path(),
398 params
399 ))?)
400 }
401 Some(q) => {
402 uri.path_and_query = Some(PathAndQuery::from_str(&format!(
403 "{}?{}&{}",
404 pq.path(),
405 q,
406 params
407 ))?)
408 }
409 },
410 };
411
412 Ok(OAuthRequest {
413 uri: Uri::from_parts(uri)?,
414 headers,
415 method: HttpMethod::GET,
416 body: None,
417 })
418 }
419 }
420 }
421}
422
423fn normalize_uri(original: Uri) -> Uri {
424 let port = original.port_u16();
425 let mut normalized_uri = original.into_parts();
426
427 normalized_uri.scheme = normalized_uri
430 .scheme
431 .map(|s| Scheme::from_str(&s.as_str().to_lowercase()).unwrap());
432 normalized_uri.authority = normalized_uri
433 .authority
434 .map(|a| Authority::from_str(&a.as_str().to_lowercase()).unwrap());
435
436 match &normalized_uri.scheme {
438 Some(s)
439 if s.as_str() == "http" && port == Some(80)
440 || s.as_str() == "https" && port == Some(443) =>
441 {
442 normalized_uri.authority = normalized_uri
443 .authority
444 .map(|a| Authority::from_str(a.host()).unwrap());
445 }
446 _ => {}
447 };
448
449 normalized_uri.path_and_query = normalized_uri
450 .path_and_query
451 .map(|pq| pq.path().parse().unwrap());
452
453 Uri::from_parts(normalized_uri).unwrap()
454}