1use std::mem;
3use std::borrow::Cow;
4use std::collections::HashMap;
5
6use chrono::Utc;
7use serde::{Deserialize, Serialize};
8use serde_json;
9
10use crate::code_grant::error::{AccessTokenError, AccessTokenErrorType};
11use crate::primitives::authorizer::Authorizer;
12use crate::primitives::issuer::{IssuedToken, Issuer};
13use crate::primitives::grant::{Extensions, Grant};
14use crate::primitives::registrar::{Registrar, RegistrarError};
15use crate::primitives::scope::Scope;
16
/// JSON body of a token endpoint response.
///
/// Every field is optional and is left out of the serialized JSON when it is `None`.
#[derive(Deserialize, Serialize)]
pub struct TokenResponse {
    /// The access token string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub access_token: Option<String>,

    /// The refresh token, when one was issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,

    /// Type of the access token; [`BearerToken::to_json`] always writes `"bearer"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_type: Option<String>,

    /// Seconds until the access token expires, measured when the response was serialized.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_in: Option<i64>,

    /// The granted scope in its string form.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,

    /// Error code; [`BearerToken::to_json`] leaves this unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
44
/// Client credentials presented through the request's authorization.
///
/// NOTE(review): presumably decoded from an HTTP `Authorization` header (client rejections are
/// reported with the `basic` scheme) — confirm against the frontend.
#[non_exhaustive]
pub enum Authorization<'a> {
    /// No credentials were presented this way.
    None,
    /// Only a client id; the client is treated as unauthenticated.
    Username(Cow<'a, str>),
    /// A client id and the passphrase that has to be verified for it.
    UsernamePassword(Cow<'a, str>, Cow<'a, [u8]>),
}
55
/// The parameters of an access token request using the authorization code grant.
pub trait Request {
    /// Whether the request passed the implementation's own well-formedness checks.
    ///
    /// When this is `false`, the request is rejected with `Error::invalid()` before any
    /// parameter is read.
    fn valid(&self) -> bool;

    /// The `code` parameter: the authorization code to exchange. Required.
    fn code(&self) -> Option<Cow<str>>;

    /// Client credentials presented through the request's authorization.
    fn authorization(&self) -> Authorization;

    /// The `client_id` parameter, if present.
    fn client_id(&self) -> Option<Cow<str>>;

    /// The `redirect_uri` parameter. Required; it must parse as a URL and must equal the
    /// redirect URI stored with the grant.
    fn redirect_uri(&self) -> Option<Cow<str>>;

    /// The `grant_type` parameter. Only `authorization_code` is accepted.
    fn grant_type(&self) -> Option<Cow<str>>;

    /// Any other parameter by name; `client_secret` is read through this accessor.
    fn extension(&self, key: &str) -> Option<Cow<str>>;

    /// Whether a `client_secret` parameter may authenticate the `client_id` parameter.
    ///
    /// Defaults to `false`, in which case such a secret is ignored and the client is treated
    /// as unauthenticated. RFC 6749 §2.3.1 marks credentials in the request body as
    /// NOT RECOMMENDED.
    fn allow_credentials_in_body(&self) -> bool {
        false
    }
}
93
/// Hook that decides which extension data is attached to an issued token.
pub trait Extension {
    /// Compute the token's extensions from the request and from `data`, the extension data
    /// that was stored with the grant.
    ///
    /// The returned value replaces the grant's extensions before the token is issued.
    /// Returning `Err(())` makes [`access_token`] reject the request with `Error::invalid()`.
    fn extend(&mut self, request: &dyn Request, data: Extensions)
        -> std::result::Result<Extensions, ()>;
}
105
106impl Extension for () {
107 fn extend(&mut self, _: &dyn Request, _: Extensions) -> std::result::Result<Extensions, ()> {
108 Ok(Extensions::new())
109 }
110}
111
/// The primitives [`access_token`] needs to process a request.
pub trait Endpoint {
    /// Registrar used to check the client id and passphrase (`Registrar::check`).
    fn registrar(&self) -> &dyn Registrar;

    /// Authorizer from which the grant belonging to the authorization code is extracted.
    fn authorizer(&mut self) -> &mut dyn Authorizer;

    /// Issuer that creates the token (and possibly a refresh token) for the grant.
    fn issuer(&mut self) -> &mut dyn Issuer;

    /// Extension hook that computes the extensions attached to the issued token.
    fn extension(&mut self) -> &mut dyn Extension;
}
132
/// Client credentials gathered while validating a request.
///
/// At most one set of credentials may be supplied; a second one turns the value into
/// `Duplicate`.
enum Credentials<'a> {
    /// Nothing supplied yet.
    None,
    /// A client id together with a passphrase that must be verified.
    Authenticated {
        client_id: &'a str,
        passphrase: &'a [u8],
    },
    /// A client id without a passphrase.
    Unauthenticated { client_id: &'a str },
    /// More than one set of credentials was supplied; the request is rejected.
    Duplicate,
}
151
/// State machine that processes an access token request for the authorization code grant.
///
/// Drive it with [`AccessToken::advance`], answering each [`Output`] with the matching
/// [`Input`]; [`access_token`] does exactly this against an [`Endpoint`].
pub struct AccessToken {
    state: AccessTokenState,
}
182
/// Internal state of an [`AccessToken`] state machine.
enum AccessTokenState {
    /// Request validated; waiting for the client to be authenticated.
    Authenticate {
        client: String,
        /// Passphrase to verify, `None` when the client did not present one.
        passdata: Option<Vec<u8>>,
        code: String,
        redirect_uri: url::Url,
    },
    /// Client authenticated; waiting for the grant stored under `code`.
    Recover {
        client: String,
        code: String,
        redirect_uri: url::Url,
    },
    /// Grant recovered and checked; waiting for the extension hook.
    ///
    /// `extensions` holds the data taken out of `saved_params`.
    Extend {
        saved_params: Box<Grant>,
        extensions: Extensions,
    },
    /// Waiting for a token to be issued for `grant`.
    Issue {
        grant: Box<Grant>,
    },
    /// The request failed; the machine keeps reporting this error.
    Err(Error),
}
207
/// Answer to the previous [`Output`], passed to [`AccessToken::advance`].
pub enum Input<'req> {
    /// The request itself.
    ///
    /// NOTE(review): `AccessToken::advance` has no arm for this variant, so passing it moves
    /// the machine into an error state; the request is only consumed by `AccessToken::new`.
    Request(&'req dyn Request),
    /// The client was authenticated successfully.
    Authenticated,
    /// The grant stored for the code, or `None` if there is none (rejected as invalid).
    Recovered(Option<Box<Grant>>),
    /// Result of the extension hook.
    Extended {
        /// Replaces the grant's extensions before the token is issued.
        access_extensions: Extensions,
    },
    /// The token issued for the grant.
    Issued(IssuedToken),
    /// No new information; the current output is repeated.
    None,
}
226
/// What an [`AccessToken`] needs next, as returned by [`AccessToken::advance`].
pub enum Output<'machine> {
    /// Authenticate the client, then answer with [`Input::Authenticated`].
    Authenticate {
        /// The client id to check.
        client: &'machine str,
        /// The passphrase to verify, `None` when the client presented none.
        passdata: Option<&'machine [u8]>,
    },
    /// Look up the grant for the authorization code, then answer with [`Input::Recovered`].
    Recover {
        /// The authorization code from the request.
        code: &'machine str,
    },
    /// Run the extension hook, then answer with [`Input::Extended`].
    Extend {
        /// The extension data that was stored with the recovered grant.
        extensions: &'machine mut Extensions,
    },
    /// Issue a token for the grant, then answer with [`Input::Issued`].
    Issue {
        /// The grant, with its extensions replaced by the hook's result.
        grant: &'machine Grant,
    },
    /// The request succeeded; further calls to `advance` only report errors.
    Ok(BearerToken),
    /// The request failed.
    Err(Box<Error>),
}
274
275impl AccessToken {
276 pub fn new(request: &dyn Request) -> Self {
278 AccessToken {
279 state: Self::validate(request).unwrap_or_else(AccessTokenState::Err),
280 }
281 }
282
283 pub fn advance(&mut self, input: Input) -> Output<'_> {
285 self.state = match (self.take(), input) {
286 (current, Input::None) => current,
287 (
288 AccessTokenState::Authenticate {
289 client,
290 code,
291 redirect_uri,
292 ..
293 },
294 Input::Authenticated,
295 ) => Self::authenticated(client, code, redirect_uri),
296 (
297 AccessTokenState::Recover {
298 client, redirect_uri, ..
299 },
300 Input::Recovered(grant),
301 ) => Self::recovered(client, redirect_uri, grant).unwrap_or_else(AccessTokenState::Err),
302 (AccessTokenState::Extend { saved_params, .. }, Input::Extended { access_extensions }) => {
303 Self::issue(saved_params, access_extensions)
304 }
305 (AccessTokenState::Issue { grant }, Input::Issued(token)) => {
306 return Output::Ok(Self::finish(grant, token));
307 }
308 (AccessTokenState::Err(err), _) => AccessTokenState::Err(err),
309 (_, _) => AccessTokenState::Err(Error::Primitive(Box::new(PrimitiveError::empty()))),
310 };
311
312 self.output()
313 }
314
315 fn output(&mut self) -> Output<'_> {
316 match &mut self.state {
317 AccessTokenState::Err(err) => Output::Err(Box::new(err.clone())),
318 AccessTokenState::Authenticate { client, passdata, .. } => Output::Authenticate {
319 client,
320 passdata: passdata.as_ref().map(Vec::as_slice),
321 },
322 AccessTokenState::Recover { code, .. } => Output::Recover { code },
323 AccessTokenState::Extend { extensions, .. } => Output::Extend { extensions },
324 AccessTokenState::Issue { grant } => Output::Issue { grant },
325 }
326 }
327
328 fn take(&mut self) -> AccessTokenState {
329 mem::replace(
330 &mut self.state,
331 AccessTokenState::Err(Error::Primitive(Box::new(PrimitiveError::empty()))),
332 )
333 }
334
335 fn validate(request: &dyn Request) -> Result<AccessTokenState> {
336 if !request.valid() {
337 return Err(Error::invalid());
338 }
339
340 let authorization = request.authorization();
341 let client_id = request.client_id();
342 let client_secret = request.extension("client_secret");
343
344 let mut credentials = Credentials::None;
345
346 match &authorization {
347 Authorization::None => {}
348 Authorization::Username(username) => credentials.unauthenticated(&username),
349 Authorization::UsernamePassword(username, password) => {
350 credentials.authenticate(&username, &password)
351 }
352 }
353
354 if let Some(client_id) = &client_id {
355 match &client_secret {
356 Some(auth) if request.allow_credentials_in_body() => {
357 credentials.authenticate(client_id.as_ref(), auth.as_ref().as_bytes())
358 }
359 Some(_) | None => credentials.unauthenticated(client_id.as_ref()),
361 }
362 }
363
364 match request.grant_type() {
365 Some(ref cow) if cow == "authorization_code" => (),
366 None => return Err(Error::invalid()),
367 Some(_) => return Err(Error::invalid_with(AccessTokenErrorType::UnsupportedGrantType)),
368 };
369
370 let (client_id, passdata) = credentials.into_client().ok_or_else(Error::invalid)?;
371
372 let redirect_uri = request
373 .redirect_uri()
374 .ok_or_else(Error::invalid)?
375 .parse()
376 .map_err(|_| Error::invalid())?;
377
378 let code = request.code().ok_or_else(Error::invalid)?;
379
380 Ok(AccessTokenState::Authenticate {
381 client: client_id.to_string(),
382 passdata: passdata.map(Vec::from),
383 redirect_uri,
384 code: code.into_owned(),
385 })
386 }
387
388 fn authenticated(client: String, code: String, redirect_uri: url::Url) -> AccessTokenState {
389 AccessTokenState::Recover {
390 client,
391 code,
392 redirect_uri,
393 }
394 }
395
396 fn recovered(
397 client_id: String, redirect_uri: url::Url, grant: Option<Box<Grant>>,
398 ) -> Result<AccessTokenState> {
399 let mut saved_params = match grant {
400 None => return Err(Error::invalid()),
401 Some(v) => v,
402 };
403
404 if (saved_params.client_id.as_str(), &saved_params.redirect_uri) != (&client_id, &redirect_uri) {
405 return Err(Error::invalid_with(AccessTokenErrorType::InvalidGrant));
406 }
407
408 if saved_params.until < Utc::now() {
409 return Err(Error::invalid_with(AccessTokenErrorType::InvalidGrant));
410 }
411
412 let extensions = mem::take(&mut saved_params.extensions);
413 Ok(AccessTokenState::Extend {
414 saved_params,
415 extensions,
416 })
417 }
418
419 fn issue(grant: Box<Grant>, extensions: Extensions) -> AccessTokenState {
420 AccessTokenState::Issue {
421 grant: Box::new(Grant { extensions, ..*grant }),
422 }
423 }
424
425 fn finish(grant: Box<Grant>, token: IssuedToken) -> BearerToken {
426 BearerToken(token, grant.scope.clone())
427 }
428}
429
430pub fn access_token(handler: &mut dyn Endpoint, request: &dyn Request) -> Result<BearerToken> {
433 enum Requested<'a> {
434 None,
435 Authenticate {
436 client: &'a str,
437 passdata: Option<&'a [u8]>,
438 },
439 Recover(&'a str),
440 Extend {
441 extensions: &'a mut Extensions,
442 },
443 Issue {
444 grant: &'a Grant,
445 },
446 }
447
448 let mut access_token = AccessToken::new(request);
449 let mut requested = Requested::None;
450
451 loop {
452 let input = match requested {
453 Requested::None => Input::None,
454 Requested::Authenticate { client, passdata } => {
455 handler
456 .registrar()
457 .check(client, passdata)
458 .map_err(|err| match err {
459 RegistrarError::Unspecified => Error::unauthorized("basic"),
460 RegistrarError::PrimitiveError => Error::Primitive(Box::new(PrimitiveError {
461 grant: None,
462 extensions: None,
463 })),
464 })?;
465 Input::Authenticated
466 }
467 Requested::Recover(code) => {
468 let opt_grant = handler.authorizer().extract(code).map_err(|_| {
469 Error::Primitive(Box::new(PrimitiveError {
470 grant: None,
471 extensions: None,
472 }))
473 })?;
474 Input::Recovered(opt_grant.map(Box::new))
475 }
476 Requested::Extend { extensions } => {
477 let access_extensions = handler
478 .extension()
479 .extend(request, extensions.clone())
480 .map_err(|_| Error::invalid())?;
481 Input::Extended { access_extensions }
482 }
483 Requested::Issue { grant } => {
484 let token = handler.issuer().issue(grant.clone()).map_err(|_| {
485 Error::Primitive(Box::new(PrimitiveError {
486 grant: None,
488 extensions: None,
489 }))
490 })?;
491 Input::Issued(token)
492 }
493 };
494
495 requested = match access_token.advance(input) {
496 Output::Authenticate { client, passdata } => Requested::Authenticate { client, passdata },
497 Output::Recover { code } => Requested::Recover(code),
498 Output::Extend { extensions } => Requested::Extend { extensions },
499 Output::Issue { grant } => Requested::Issue { grant },
500 Output::Ok(token) => return Ok(token),
501 Output::Err(e) => return Err(*e),
502 };
503 }
504}
505
506impl<'a> Credentials<'a> {
507 pub fn authenticate(&mut self, client_id: &'a str, passphrase: &'a [u8]) {
508 self.add(Credentials::Authenticated {
509 client_id,
510 passphrase,
511 })
512 }
513
514 pub fn unauthenticated(&mut self, client_id: &'a str) {
515 self.add(Credentials::Unauthenticated { client_id })
516 }
517
518 pub fn into_client(self) -> Option<(&'a str, Option<&'a [u8]>)> {
519 match self {
520 Credentials::Authenticated {
521 client_id,
522 passphrase,
523 } => Some((client_id, Some(passphrase))),
524 Credentials::Unauthenticated { client_id } => Some((client_id, None)),
525 _ => None,
526 }
527 }
528
529 fn add(&mut self, new: Self) {
530 *self = match self {
531 Credentials::None => new,
532 _ => Credentials::Duplicate,
533 };
534 }
535}
536
/// Reasons why an access token request failed.
#[derive(Clone)]
pub enum Error {
    /// The request, or the grant it refers to, is invalid.
    Invalid(ErrorDescription),

    /// The client could not be authenticated; the `String` names the authentication scheme
    /// (`"basic"` when produced by [`access_token`]).
    Unauthorized(ErrorDescription, String),

    /// An underlying primitive failed, or the state machine received an unexpected input.
    Primitive(Box<PrimitiveError>),
}
552
/// Context attached to an [`Error::Primitive`].
#[derive(Clone)]
pub struct PrimitiveError {
    /// The grant involved in the failure, if available.
    pub grant: Option<Grant>,

    /// The extensions involved in the failure, if available.
    pub extensions: Option<Extensions>,
}
578
/// Error details carried by [`Error`], wrapping an [`AccessTokenError`].
#[derive(Clone)]
pub struct ErrorDescription {
    pub(crate) error: AccessTokenError,
}
585
/// Result type of the access token flow, failing with [`Error`].
type Result<T> = std::result::Result<T, Error>;
587
/// A successfully issued token together with the scope of the grant it was issued for.
pub struct BearerToken(pub(crate) IssuedToken, pub(crate) Scope);
590
591impl Error {
592 pub fn invalid() -> Self {
594 Error::Invalid(ErrorDescription {
595 error: AccessTokenError::default(),
596 })
597 }
598
599 pub(crate) fn invalid_with(with_type: AccessTokenErrorType) -> Self {
600 Error::Invalid(ErrorDescription {
601 error: {
602 let mut error = AccessTokenError::default();
603 error.set_type(with_type);
604 error
605 },
606 })
607 }
608
609 pub fn unauthorized(authtype: &str) -> Error {
611 Error::Unauthorized(
612 ErrorDescription {
613 error: {
614 let mut error = AccessTokenError::default();
615 error.set_type(AccessTokenErrorType::InvalidClient);
616 error
617 },
618 },
619 authtype.to_string(),
620 )
621 }
622
623 pub fn description(&mut self) -> Option<&mut AccessTokenError> {
628 match self {
629 Error::Invalid(description) => Some(description.description()),
630 Error::Unauthorized(description, _) => Some(description.description()),
631 Error::Primitive(_) => None,
632 }
633 }
634}
635
636impl PrimitiveError {
637 pub fn empty() -> Self {
639 PrimitiveError {
640 grant: None,
641 extensions: None,
642 }
643 }
644}
645
646impl ErrorDescription {
647 pub fn new(error: AccessTokenError) -> Self {
649 Self { error }
650 }
651
652 pub fn to_json(&self) -> String {
655 let asmap = self
656 .error
657 .iter()
658 .map(|(k, v)| (k.to_string(), v.into_owned()))
659 .collect::<HashMap<String, String>>();
660 serde_json::to_string(&asmap).unwrap()
661 }
662
663 pub fn description(&mut self) -> &mut AccessTokenError {
665 &mut self.error
666 }
667}
668
669impl BearerToken {
670 pub fn to_json(&self) -> String {
673 let remaining = self.0.until.signed_duration_since(Utc::now());
674 let token_response = TokenResponse {
675 access_token: Some(self.0.token.clone()),
676 refresh_token: self.0.refresh.clone(),
677 token_type: Some("bearer".to_owned()),
678 expires_in: Some(remaining.num_seconds()),
679 scope: Some(self.1.to_string()),
680 error: None,
681 };
682
683 serde_json::to_string(&token_response).unwrap()
684 }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use crate::primitives::issuer::TokenType;

    /// Encode `token` as JSON and decode it again as a [`TokenResponse`].
    fn round_trip(token: &BearerToken) -> TokenResponse {
        let encoded = token.to_json();
        serde_json::from_str::<TokenResponse>(&encoded).unwrap()
    }

    /// Fields that do not depend on whether a refresh token was issued.
    fn check_shared_fields(response: &TokenResponse) {
        assert_eq!(response.access_token, Some("access".to_owned()));
        assert_eq!(response.scope, Some("scope".to_owned()));
        assert_eq!(response.token_type, Some("bearer".to_owned()));
        assert!(response.expires_in.is_some());
    }

    #[test]
    fn bearer_token_encoding() {
        let issued = IssuedToken {
            token: "access".into(),
            refresh: Some("refresh".into()),
            until: Utc::now(),
            token_type: TokenType::Bearer,
        };
        let response = round_trip(&BearerToken(issued, "scope".parse().unwrap()));

        assert_eq!(response.refresh_token, Some("refresh".to_owned()));
        check_shared_fields(&response);
    }

    #[test]
    fn no_refresh_encoding() {
        let issued = IssuedToken::without_refresh("access".into(), Utc::now());
        let response = round_trip(&BearerToken(issued, "scope".parse().unwrap()));

        assert_eq!(response.refresh_token, None);
        check_shared_fields(&response);
    }
}