use std::mem;
use std::borrow::Cow;
use std::collections::HashMap;
use chrono::Utc;
use serde::{Deserialize, Serialize};
use serde_json;
use crate::code_grant::error::{AccessTokenError, AccessTokenErrorType};
use crate::primitives::authorizer::Authorizer;
use crate::primitives::issuer::{IssuedToken, Issuer};
use crate::primitives::grant::{Extensions, Grant};
use crate::primitives::registrar::{Registrar, RegistrarError};
use crate::primitives::scope::Scope;
/// Wire format of a token endpoint JSON response.
///
/// Every field is optional and is omitted from the serialized JSON when it is `None`.
/// `BearerToken::to_json` fills the success fields and leaves `error` as `None`.
#[derive(Deserialize, Serialize)]
pub struct TokenResponse {
    /// The issued access token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub access_token: Option<String>,
    /// Refresh token, present only if one was issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Token type string (`BearerToken::to_json` always writes `"bearer"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_type: Option<String>,
    /// Number of seconds until the access token expires.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_in: Option<i64>,
    /// Scope of the token, as rendered by the `Scope` type's `to_string`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Error code of a failed request; not populated by `BearerToken::to_json`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
/// Client authentication data supplied alongside the request body.
///
/// Where this comes from (presumably an HTTP `Authorization` header) is up to the frontend
/// implementing [`Request`]. During validation, `Username` is treated as an unauthenticated
/// client id and `UsernamePassword` as a client id plus passphrase.
#[non_exhaustive]
pub enum Authorization<'a> {
    /// No authorization data was provided.
    None,
    /// Only a username (used as the client id).
    Username(Cow<'a, str>),
    /// A username (client id) and a password (passphrase bytes).
    UsernamePassword(Cow<'a, str>, Cow<'a, [u8]>),
}
/// An access token request as seen by the token endpoint.
///
/// Frontends implement this to expose the parameters of an incoming request.
pub trait Request {
    /// Whether the request is well-formed; validation fails immediately when this is `false`.
    fn valid(&self) -> bool;
    /// The authorization code being exchanged.
    fn code(&self) -> Option<Cow<str>>;
    /// Client authentication supplied outside of the body parameters.
    fn authorization(&self) -> Authorization;
    /// The `client_id` body parameter.
    fn client_id(&self) -> Option<Cow<str>>;
    /// The `redirect_uri` parameter; it must parse as a URL.
    fn redirect_uri(&self) -> Option<Cow<str>>;
    /// The `grant_type` parameter; only `authorization_code` is accepted.
    fn grant_type(&self) -> Option<Cow<str>>;
    /// Look up an additional parameter by name (validation reads `client_secret` through it).
    fn extension(&self, key: &str) -> Option<Cow<str>>;
    /// Whether a `client_secret` in the body may authenticate the client.
    ///
    /// Defaults to `false`, in which case a body `client_secret` is ignored and the body
    /// `client_id` is treated as unauthenticated.
    fn allow_credentials_in_body(&self) -> bool {
        false
    }
}
/// Hook that computes the extensions attached to the grant being issued.
pub trait Extension {
    /// Produce the extensions for the issued grant.
    ///
    /// `data` holds the extensions stored with the recovered grant. The returned set becomes
    /// the extension set of the issued grant. `Err(())` aborts the flow with
    /// [`Error::invalid`].
    fn extend(&mut self, request: &dyn Request, data: Extensions)
        -> std::result::Result<Extensions, ()>;
}
/// The unit type is the "no extensions" handler.
impl Extension for () {
    /// Ignore the request and the stored extensions and always succeed with an empty set.
    fn extend(
        &mut self, _request: &dyn Request, _data: Extensions,
    ) -> std::result::Result<Extensions, ()> {
        let nothing = Extensions::new();
        Ok(nothing)
    }
}
/// The primitives the access token flow needs from its surrounding endpoint.
pub trait Endpoint {
    /// Registrar used to check client credentials.
    fn registrar(&self) -> &dyn Registrar;
    /// Authorizer from which the grant belonging to an authorization code is extracted.
    fn authorizer(&mut self) -> &mut dyn Authorizer;
    /// Issuer that creates the access token for the final grant.
    fn issuer(&mut self) -> &mut dyn Issuer;
    /// Extension hook run on the recovered grant before issuance.
    fn extension(&mut self) -> &mut dyn Extension;
}
/// Client credentials collected from the request during validation.
///
/// Starts as `None`. The first credential source recorded decides the variant; any further
/// source turns it into `Duplicate`, which is rejected.
enum Credentials<'a> {
    /// No credentials seen yet.
    None,
    /// A client id together with a passphrase to verify.
    Authenticated {
        client_id: &'a str,
        passphrase: &'a [u8],
    },
    /// A client id without a passphrase.
    Unauthenticated { client_id: &'a str },
    /// Credentials were supplied by more than one source.
    Duplicate,
}
/// State machine for exchanging an authorization code for an access token.
///
/// Drive it with `advance`: each returned `Output` names an operation for the caller to
/// perform, and the result is fed back as the matching `Input`.
pub struct AccessToken {
    state: AccessTokenState,
}
/// Internal state of an [`AccessToken`] machine.
enum AccessTokenState {
    /// Request validated; waiting for the client to be authenticated.
    Authenticate {
        client: String,
        // Passphrase to check against the registrar, if the client supplied one.
        passdata: Option<Vec<u8>>,
        code: String,
        redirect_uri: url::Url,
    },
    /// Client authenticated; waiting for the grant stored under `code`.
    Recover {
        client: String,
        code: String,
        redirect_uri: url::Url,
    },
    /// Grant recovered and checked; waiting for the extension hook.
    Extend {
        saved_params: Box<Grant>,
        // The extensions taken out of `saved_params`.
        extensions: Extensions,
    },
    /// Final grant assembled; waiting for the issuer to create a token.
    Issue {
        grant: Box<Grant>,
    },
    /// The flow failed; this error is reported for every further input.
    Err(Error),
}
/// Result of the operation last requested by an [`AccessToken`] machine.
pub enum Input<'req> {
    /// NOTE(review): `AccessToken::advance` has no arm for this variant; unless the machine
    /// is already in an error state, it results in an empty primitive error.
    Request(&'req dyn Request),
    /// The client was successfully authenticated (answers `Output::Authenticate`).
    Authenticated,
    /// The grant recovered for the code, or `None` if there was none (answers
    /// `Output::Recover`).
    Recovered(Option<Box<Grant>>),
    /// Extensions to attach to the issued grant (answers `Output::Extend`).
    Extended {
        access_extensions: Extensions,
    },
    /// The token created by the issuer (answers `Output::Issue`).
    Issued(IssuedToken),
    /// No new information; just report the current state.
    None,
}
/// Next step requested by an [`AccessToken`] machine, borrowing from the machine.
pub enum Output<'machine> {
    /// Check the client's credentials, then answer with `Input::Authenticated`.
    Authenticate {
        client: &'machine str,
        passdata: Option<&'machine [u8]>,
    },
    /// Recover the grant for `code`, then answer with `Input::Recovered`.
    Recover {
        code: &'machine str,
    },
    /// Run the extension hook on these extensions, then answer with `Input::Extended`.
    Extend {
        extensions: &'machine mut Extensions,
    },
    /// Issue a token for `grant`, then answer with `Input::Issued`.
    Issue {
        grant: &'machine Grant,
    },
    /// The flow completed successfully.
    Ok(BearerToken),
    /// The flow failed.
    Err(Box<Error>),
}
impl AccessToken {
    /// Create a new flow by validating `request`.
    ///
    /// A request that fails validation does not fail here. The machine starts in the error
    /// state, and the error is reported by the first call to [`AccessToken::advance`].
    pub fn new(request: &dyn Request) -> Self {
        AccessToken {
            state: Self::validate(request).unwrap_or_else(AccessTokenState::Err),
        }
    }

    /// Feed the result of the last requested operation into the machine and get the next step.
    ///
    /// Pass [`Input::None`] to query the current step without changing state. An input that
    /// does not match the current state moves the machine into an error state carrying an empty
    /// [`PrimitiveError`]. Once in an error state, that error is returned for every input.
    pub fn advance(&mut self, input: Input) -> Output<'_> {
        self.state = match (self.take(), input) {
            (current, Input::None) => current,
            (
                AccessTokenState::Authenticate {
                    client,
                    code,
                    redirect_uri,
                    ..
                },
                Input::Authenticated,
            ) => Self::authenticated(client, code, redirect_uri),
            (
                AccessTokenState::Recover {
                    client, redirect_uri, ..
                },
                Input::Recovered(grant),
            ) => Self::recovered(client, redirect_uri, grant).unwrap_or_else(AccessTokenState::Err),
            (AccessTokenState::Extend { saved_params, .. }, Input::Extended { access_extensions }) => {
                Self::issue(saved_params, access_extensions)
            }
            (AccessTokenState::Issue { grant }, Input::Issued(token)) => {
                // Terminal step. `take` already replaced the state, so after success the machine
                // is left holding the placeholder (empty primitive) error.
                return Output::Ok(Self::finish(grant, token));
            }
            (AccessTokenState::Err(err), _) => AccessTokenState::Err(err),
            (_, _) => AccessTokenState::Err(Error::Primitive(Box::new(PrimitiveError::empty()))),
        };

        self.output()
    }

    /// Describe the current state, borrowing the data the caller needs for the next operation.
    fn output(&mut self) -> Output<'_> {
        match &mut self.state {
            AccessTokenState::Err(err) => Output::Err(Box::new(err.clone())),
            AccessTokenState::Authenticate { client, passdata, .. } => Output::Authenticate {
                client,
                passdata: passdata.as_deref(),
            },
            AccessTokenState::Recover { code, .. } => Output::Recover { code },
            AccessTokenState::Extend { extensions, .. } => Output::Extend { extensions },
            AccessTokenState::Issue { grant } => Output::Issue { grant },
        }
    }

    /// Move the current state out, leaving an empty primitive error in its place.
    fn take(&mut self) -> AccessTokenState {
        mem::replace(
            &mut self.state,
            AccessTokenState::Err(Error::Primitive(Box::new(PrimitiveError::empty()))),
        )
    }

    /// Check the request and extract everything needed for the authentication step.
    ///
    /// # Errors
    ///
    /// - [`Error::invalid`] in any of these cases:
    ///   - the request is not valid or lacks a grant type;
    ///   - credentials come from more than one source, or from none;
    ///   - the redirect URI is missing or unparsable;
    ///   - the code is missing.
    /// - `UnsupportedGrantType` if the grant type is anything but `authorization_code`.
    fn validate(request: &dyn Request) -> Result<AccessTokenState> {
        if !request.valid() {
            return Err(Error::invalid());
        }

        let authorization = request.authorization();
        let client_id = request.client_id();
        let client_secret = request.extension("client_secret");

        // Credentials may come from the authorization data and/or the body. Supplying both
        // turns `credentials` into `Duplicate`, which `into_client` rejects below.
        let mut credentials = Credentials::None;

        match &authorization {
            Authorization::None => {}
            Authorization::Username(username) => credentials.unauthenticated(&username),
            Authorization::UsernamePassword(username, password) => {
                credentials.authenticate(&username, &password)
            }
        }

        if let Some(client_id) = &client_id {
            match &client_secret {
                Some(auth) if request.allow_credentials_in_body() => {
                    credentials.authenticate(client_id.as_ref(), auth.as_ref().as_bytes())
                }
                // A body secret is ignored unless explicitly allowed.
                Some(_) | None => credentials.unauthenticated(client_id.as_ref()),
            }
        }

        match request.grant_type() {
            Some(ref cow) if cow == "authorization_code" => (),
            None => return Err(Error::invalid()),
            Some(_) => return Err(Error::invalid_with(AccessTokenErrorType::UnsupportedGrantType)),
        };

        let (client_id, passdata) = credentials.into_client().ok_or_else(Error::invalid)?;

        let redirect_uri = request
            .redirect_uri()
            .ok_or_else(Error::invalid)?
            .parse()
            .map_err(|_| Error::invalid())?;

        let code = request.code().ok_or_else(Error::invalid)?;

        Ok(AccessTokenState::Authenticate {
            client: client_id.to_string(),
            passdata: passdata.map(Vec::from),
            redirect_uri,
            code: code.into_owned(),
        })
    }

    /// The client passed authentication; next, recover the grant for `code`.
    fn authenticated(client: String, code: String, redirect_uri: url::Url) -> AccessTokenState {
        AccessTokenState::Recover {
            client,
            code,
            redirect_uri,
        }
    }

    /// Check the grant recovered for the code and move its extensions into the `Extend` state.
    ///
    /// # Errors
    ///
    /// Returns `invalid_grant` (RFC 6749 §5.2) in each of these cases:
    /// - no grant was recovered;
    /// - the grant belongs to another client or redirect URI;
    /// - the grant has expired.
    fn recovered(
        client_id: String, redirect_uri: url::Url, grant: Option<Box<Grant>>,
    ) -> Result<AccessTokenState> {
        // An authorization code with no grant behind it is an invalid grant, the same error
        // category as the mismatch and expiry checks below (RFC 6749 §5.2).
        let mut saved_params =
            grant.ok_or_else(|| Error::invalid_with(AccessTokenErrorType::InvalidGrant))?;

        // The code must have been issued to the authenticated client for the same redirect URI.
        if saved_params.client_id.as_str() != client_id || saved_params.redirect_uri != redirect_uri {
            return Err(Error::invalid_with(AccessTokenErrorType::InvalidGrant));
        }

        if saved_params.until < Utc::now() {
            return Err(Error::invalid_with(AccessTokenErrorType::InvalidGrant));
        }

        let extensions = mem::take(&mut saved_params.extensions);
        Ok(AccessTokenState::Extend {
            saved_params,
            extensions,
        })
    }

    /// Build the final grant from the recovered one, with `extensions` replacing its extensions.
    fn issue(grant: Box<Grant>, extensions: Extensions) -> AccessTokenState {
        AccessTokenState::Issue {
            grant: Box::new(Grant { extensions, ..*grant }),
        }
    }

    /// Pair the issued token with the scope of the grant it was issued for.
    fn finish(grant: Box<Grant>, token: IssuedToken) -> BearerToken {
        BearerToken(token, grant.scope.clone())
    }
}
pub fn access_token(handler: &mut dyn Endpoint, request: &dyn Request) -> Result<BearerToken> {
enum Requested<'a> {
None,
Authenticate {
client: &'a str,
passdata: Option<&'a [u8]>,
},
Recover(&'a str),
Extend {
extensions: &'a mut Extensions,
},
Issue {
grant: &'a Grant,
},
}
let mut access_token = AccessToken::new(request);
let mut requested = Requested::None;
loop {
let input = match requested {
Requested::None => Input::None,
Requested::Authenticate { client, passdata } => {
handler
.registrar()
.check(client, passdata)
.map_err(|err| match err {
RegistrarError::Unspecified => Error::unauthorized("basic"),
RegistrarError::PrimitiveError => Error::Primitive(Box::new(PrimitiveError {
grant: None,
extensions: None,
})),
})?;
Input::Authenticated
}
Requested::Recover(code) => {
let opt_grant = handler.authorizer().extract(code).map_err(|_| {
Error::Primitive(Box::new(PrimitiveError {
grant: None,
extensions: None,
}))
})?;
Input::Recovered(opt_grant.map(Box::new))
}
Requested::Extend { extensions } => {
let access_extensions = handler
.extension()
.extend(request, extensions.clone())
.map_err(|_| Error::invalid())?;
Input::Extended { access_extensions }
}
Requested::Issue { grant } => {
let token = handler.issuer().issue(grant.clone()).map_err(|_| {
Error::Primitive(Box::new(PrimitiveError {
grant: None,
extensions: None,
}))
})?;
Input::Issued(token)
}
};
requested = match access_token.advance(input) {
Output::Authenticate { client, passdata } => Requested::Authenticate { client, passdata },
Output::Recover { code } => Requested::Recover(code),
Output::Extend { extensions } => Requested::Extend { extensions },
Output::Issue { grant } => Requested::Issue { grant },
Output::Ok(token) => return Ok(token),
Output::Err(e) => return Err(*e),
};
}
}
impl<'a> Credentials<'a> {
    /// Record a client id together with the passphrase that should authenticate it.
    ///
    /// If credentials were already recorded, the result is `Duplicate` instead.
    pub fn authenticate(&mut self, client_id: &'a str, passphrase: &'a [u8]) {
        let candidate = Credentials::Authenticated {
            client_id,
            passphrase,
        };
        self.add(candidate);
    }

    /// Record a client id without a passphrase.
    ///
    /// If credentials were already recorded, the result is `Duplicate` instead.
    pub fn unauthenticated(&mut self, client_id: &'a str) {
        let candidate = Credentials::Unauthenticated { client_id };
        self.add(candidate);
    }

    /// The single client id and optional passphrase.
    ///
    /// Returns `None` when no credentials or more than one set were recorded.
    pub fn into_client(self) -> Option<(&'a str, Option<&'a [u8]>)> {
        let (id, secret) = match self {
            Credentials::Authenticated {
                client_id,
                passphrase,
            } => (client_id, Some(passphrase)),
            Credentials::Unauthenticated { client_id } => (client_id, None),
            Credentials::None | Credentials::Duplicate => return None,
        };
        Some((id, secret))
    }

    /// Store `new` if nothing was recorded yet; otherwise mark the credentials as duplicated.
    fn add(&mut self, new: Self) {
        *self = if let Credentials::None = self {
            new
        } else {
            Credentials::Duplicate
        };
    }
}
/// Failure of the access token flow.
#[derive(Clone)]
pub enum Error {
    /// The request was rejected; the description holds the error reported to the client.
    Invalid(ErrorDescription),
    /// Client authentication failed; the `String` is the authentication scheme passed to
    /// `Error::unauthorized` (e.g. `"basic"`).
    Unauthorized(ErrorDescription, String),
    /// A backing primitive failed, or the state machine received an unexpected input.
    Primitive(Box<PrimitiveError>),
}
/// Details of a primitive (registrar, authorizer, issuer) failure.
///
/// Every construction visible in this file leaves both fields `None`.
#[derive(Clone)]
pub struct PrimitiveError {
    /// The grant involved in the failure, if known.
    pub grant: Option<Grant>,
    /// The extensions involved in the failure, if known.
    pub extensions: Option<Extensions>,
}
/// The client-facing error of a rejected access token request.
#[derive(Clone)]
pub struct ErrorDescription {
    pub(crate) error: AccessTokenError,
}
/// Result type of the access token flow.
type Result<T> = std::result::Result<T, Error>;
/// A successfully issued token, paired with the scope of the grant it was issued for.
pub struct BearerToken(pub(crate) IssuedToken, pub(crate) Scope);
impl Error {
pub fn invalid() -> Self {
Error::Invalid(ErrorDescription {
error: AccessTokenError::default(),
})
}
pub(crate) fn invalid_with(with_type: AccessTokenErrorType) -> Self {
Error::Invalid(ErrorDescription {
error: {
let mut error = AccessTokenError::default();
error.set_type(with_type);
error
},
})
}
pub fn unauthorized(authtype: &str) -> Error {
Error::Unauthorized(
ErrorDescription {
error: {
let mut error = AccessTokenError::default();
error.set_type(AccessTokenErrorType::InvalidClient);
error
},
},
authtype.to_string(),
)
}
pub fn description(&mut self) -> Option<&mut AccessTokenError> {
match self {
Error::Invalid(description) => Some(description.description()),
Error::Unauthorized(description, _) => Some(description.description()),
Error::Primitive(_) => None,
}
}
}
impl PrimitiveError {
pub fn empty() -> Self {
PrimitiveError {
grant: None,
extensions: None,
}
}
}
impl ErrorDescription {
pub fn new(error: AccessTokenError) -> Self {
Self { error }
}
pub fn to_json(&self) -> String {
let asmap = self
.error
.iter()
.map(|(k, v)| (k.to_string(), v.into_owned()))
.collect::<HashMap<String, String>>();
serde_json::to_string(&asmap).unwrap()
}
pub fn description(&mut self) -> &mut AccessTokenError {
&mut self.error
}
}
impl BearerToken {
    /// Serialize the token as the JSON body of a successful token response (RFC 6749 §5.1).
    ///
    /// `expires_in` is the number of whole seconds from now until the token's `until`
    /// timestamp. It is clamped at zero, so an already-expired token never advertises a
    /// negative lifetime.
    pub fn to_json(&self) -> String {
        let remaining = self.0.until.signed_duration_since(Utc::now());
        let token_response = TokenResponse {
            access_token: Some(self.0.token.clone()),
            refresh_token: self.0.refresh.clone(),
            token_type: Some("bearer".to_owned()),
            // A negative lifetime is meaningless to clients; report an expired token as 0.
            expires_in: Some(remaining.num_seconds().max(0)),
            scope: Some(self.1.to_string()),
            error: None,
        };

        serde_json::to_string(&token_response)
            .expect("a struct of optional strings and integers always serializes")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::primitives::issuer::TokenType;

    /// Serialize `token` with `to_json` and parse the JSON back into a [`TokenResponse`].
    fn round_trip(token: BearerToken) -> TokenResponse {
        let encoded = token.to_json();
        serde_json::from_str::<TokenResponse>(&encoded).unwrap()
    }

    /// A token with a refresh token round-trips every success field.
    #[test]
    fn bearer_token_encoding() {
        let issued = IssuedToken {
            token: "access".into(),
            refresh: Some("refresh".into()),
            until: Utc::now(),
            token_type: TokenType::Bearer,
        };
        let decoded = round_trip(BearerToken(issued, "scope".parse().unwrap()));

        assert_eq!(decoded.access_token, Some("access".to_owned()));
        assert_eq!(decoded.refresh_token, Some("refresh".to_owned()));
        assert_eq!(decoded.scope, Some("scope".to_owned()));
        assert_eq!(decoded.token_type, Some("bearer".to_owned()));
        assert!(decoded.expires_in.is_some());
    }

    /// A token built without a refresh token decodes without one.
    #[test]
    fn no_refresh_encoding() {
        let issued = IssuedToken::without_refresh("access".into(), Utc::now());
        let decoded = round_trip(BearerToken(issued, "scope".parse().unwrap()));

        assert_eq!(decoded.access_token, Some("access".to_owned()));
        assert_eq!(decoded.refresh_token, None);
        assert_eq!(decoded.scope, Some("scope".to_owned()));
        assert_eq!(decoded.token_type, Some("bearer".to_owned()));
        assert!(decoded.expires_in.is_some());
    }
}