use std::borrow::Cow;
use std::collections::HashMap;
use chrono::{Duration, Utc};
use crate::code_grant::{
accesstoken::TokenResponse,
error::{AccessTokenError, AccessTokenErrorType},
};
use crate::primitives::grant::Grant;
use crate::primitives::issuer::{RefreshedToken, Issuer};
use crate::primitives::registrar::{Registrar, RegistrarError};
/// Parameters of an incoming refresh-token request, as consumed by the refresh flow.
pub trait Request {
    /// Whether the request is valid as a whole.
    ///
    /// When this returns `false` the flow rejects the request with `invalid_request`.
    fn valid(&self) -> bool;

    /// The refresh token presented with the request, if any.
    fn refresh_token(&self) -> Option<Cow<'_, str>>;

    /// The scope requested for the new access token, if any.
    fn scope(&self) -> Option<Cow<'_, str>>;

    /// The `grant_type` parameter; the flow only accepts `refresh_token`.
    fn grant_type(&self) -> Option<Cow<'_, str>>;

    /// Client credentials as a `(client_id, passdata)` pair, if the request carried them.
    fn authorization(&self) -> Option<(Cow<'_, str>, Cow<'_, [u8]>)>;

    /// Look up an additional, extension-defined parameter by `key`.
    fn extension(&self, key: &str) -> Option<Cow<'_, str>>;
}
/// Primitives the refresh flow borrows from its surrounding endpoint.
pub trait Endpoint {
    /// The registrar used to check client credentials.
    fn registrar(&self) -> &dyn Registrar;

    /// The issuer used to recover refresh tokens and to issue refreshed tokens.
    fn issuer(&mut self) -> &mut dyn Issuer;
}
/// A refreshed token paired with the scope string of the grant it was issued for.
#[derive(Debug)]
pub struct BearerToken(RefreshedToken, String);

/// State machine driving a single refresh-token request.
///
/// Feed it [`Input`] through [`Refresh::advance`] and act on each returned [`Output`] until it
/// produces `Output::Ok` or `Output::Err`.
#[derive(Debug)]
pub struct Refresh {
    state: RefreshState,
}
/// Internal states of the [`Refresh`] machine.
#[derive(Debug)]
enum RefreshState {
    /// Waiting for the credentials supplied with the request to be checked.
    Authenticating {
        client: String,
        passdata: Option<Vec<u8>>,
        token: String,
    },
    /// Waiting for the grant behind the refresh token to be recovered.
    ///
    /// `authenticated` holds the client id when the request's credentials were accepted.
    Recovering {
        authenticated: Option<String>,
        token: String,
    },
    /// The grant was recovered without prior authentication; its owning client must now be
    /// authenticated, without any password.
    CoAuthenticating {
        grant: Box<Grant>,
        token: String,
    },
    /// All checks passed; waiting for the issuer to produce the refreshed token.
    Issuing {
        grant: Box<Grant>,
        token: String,
    },
    /// Failure state; the machine stays here once entered.
    Err(Error),
}
/// Input handed to [`Refresh::advance`] by whoever drives the state machine.
#[derive(Clone)]
pub enum Input<'req> {
    /// The client named by `Output::Unauthenticated` was authenticated successfully.
    Authenticated {
        /// Scope requested for the new token.
        scope: Option<Cow<'req, str>>,
    },
    /// Answer to `Output::RecoverRefresh`.
    Recovered {
        /// Scope requested for the new token.
        scope: Option<Cow<'req, str>>,
        /// The recovered grant; `None` makes the request fail with `invalid_grant`.
        grant: Option<Box<Grant>>,
    },
    /// Answer to `Output::Refresh`: the newly issued token.
    Refreshed(RefreshedToken),
    /// Carries no new information; used to start the machine and obtain its first `Output`.
    None,
}
/// What the [`Refresh`] machine needs next, or its final result.
#[derive(Debug)]
pub enum Output<'a> {
    /// This client must be authenticated; answer with `Input::Authenticated`.
    Unauthenticated {
        client: &'a str,
        pass: Option<&'a [u8]>,
    },
    /// The grant behind this refresh token must be recovered; answer with `Input::Recovered`.
    RecoverRefresh { token: &'a str },
    /// The token must be refreshed for this grant; answer with `Input::Refreshed`.
    Refresh { token: &'a str, grant: Box<Grant> },
    /// The request succeeded.
    Ok(BearerToken),
    /// The request failed.
    Err(Error),
}
/// Ways in which a refresh request can fail.
#[derive(Clone, Debug)]
pub enum Error {
    /// The request was rejected; the description holds the access token error to report.
    Invalid(ErrorDescription),
    /// Client authentication failed; the `String` is the authentication type given to
    /// `Error::unauthorized` (for example `"basic"`).
    Unauthorized(ErrorDescription, String),
    /// A registrar/issuer primitive failed, or the state machine was driven incorrectly.
    Primitive,
}

/// Wrapper around the access token error that `to_json` serializes.
#[derive(Clone, Debug)]
pub struct ErrorDescription {
    pub(crate) error: AccessTokenError,
}

/// Result type used throughout the refresh flow.
type Result<T> = std::result::Result<T, Error>;
impl Refresh {
pub fn new(request: &dyn Request) -> Self {
Refresh {
state: initialize(request).unwrap_or_else(RefreshState::Err),
}
}
pub fn advance<'req>(&mut self, input: Input<'req>) -> Output<'_> {
match (self.take(), input) {
(RefreshState::Err(error), _) => {
self.state = RefreshState::Err(error.clone());
Output::Err(error)
}
(
RefreshState::Authenticating {
client,
passdata: _,
token,
},
Input::Authenticated { .. },
) => {
self.state = authenticated(client, token);
self.output()
}
(RefreshState::Recovering { authenticated, token }, Input::Recovered { scope, grant }) => {
self.state = recovered_refresh(scope, authenticated, grant, token)
.unwrap_or_else(RefreshState::Err);
self.output()
}
(RefreshState::CoAuthenticating { grant, token }, Input::Authenticated { scope }) => {
self.state = co_authenticated(scope, grant, token).unwrap_or_else(RefreshState::Err);
self.output()
}
(RefreshState::Issuing { grant, token: _ }, Input::Refreshed(token)) => {
self.state = RefreshState::Err(Error::Primitive);
Output::Ok(issued(grant, token))
}
(current, Input::None) => {
match current {
RefreshState::Authenticating { .. } => self.state = current,
RefreshState::Recovering { .. } => self.state = current,
RefreshState::CoAuthenticating { .. } => (),
RefreshState::Issuing { .. } => (),
RefreshState::Err(_) => (),
}
self.output()
}
(_, _) => {
self.state = RefreshState::Err(Error::Primitive);
self.output()
}
}
}
fn take(&mut self) -> RefreshState {
core::mem::replace(&mut self.state, RefreshState::Err(Error::Primitive))
}
fn output(&self) -> Output<'_> {
match &self.state {
RefreshState::Authenticating { client, passdata, .. } => Output::Unauthenticated {
client,
pass: passdata.as_ref().map(|vec| vec.as_slice()),
},
RefreshState::CoAuthenticating { grant, .. } => Output::Unauthenticated {
client: &grant.client_id,
pass: None,
},
RefreshState::Recovering { token, .. } => Output::RecoverRefresh { token: &token },
RefreshState::Issuing { token, grant, .. } => Output::Refresh {
token,
grant: grant.clone(),
},
RefreshState::Err(error) => Output::Err(error.clone()),
}
}
}
impl<'req> Input<'req> {
pub fn take(&mut self) -> Self {
core::mem::replace(self, Input::None)
}
}
/// Run a complete refresh-token request against the primitives of `handler`.
///
/// Drives the [`Refresh`] state machine, answering each of its queries with the endpoint's
/// registrar or issuer, until the machine produces a token or an error.
///
/// # Errors
///
/// - Any error reported by the state machine.
/// - `Error::Primitive` when the issuer fails or the registrar reports a primitive error.
/// - `Error::unauthorized("basic")` when the registrar rejects the client.
pub fn refresh(handler: &mut dyn Endpoint, request: &dyn Request) -> Result<BearerToken> {
    // Owned copy of the machine's latest query. `Output` borrows the machine, so the query has to
    // be detached before the machine can be advanced again.
    enum Pending {
        Start,
        Authenticate { client: String, pass: Option<Vec<u8>> },
        Recover { token: String },
        Refresh { token: String, grant: Box<Grant> },
    }

    let mut machine = Refresh::new(request);
    let mut pending = Pending::Start;

    loop {
        let input = match pending {
            Pending::Start => Input::None,
            Pending::Authenticate { client, pass } => {
                if let Err(err) = handler.registrar().check(&client, pass.as_deref()) {
                    return Err(match err {
                        RegistrarError::PrimitiveError => Error::Primitive,
                        RegistrarError::Unspecified => Error::unauthorized("basic"),
                    });
                }
                Input::Authenticated {
                    scope: request.scope(),
                }
            }
            Pending::Recover { token } => {
                let grant = handler
                    .issuer()
                    .recover_refresh(&token)
                    .map_err(|()| Error::Primitive)?;
                Input::Recovered {
                    scope: request.scope(),
                    grant: grant.map(Box::new),
                }
            }
            Pending::Refresh { token, grant } => {
                let refreshed = handler
                    .issuer()
                    .refresh(&token, *grant)
                    .map_err(|()| Error::Primitive)?;
                Input::Refreshed(refreshed)
            }
        };

        pending = match machine.advance(input) {
            Output::Ok(token) => return Ok(token),
            Output::Err(error) => return Err(error),
            Output::Unauthenticated { client, pass } => Pending::Authenticate {
                client: client.to_owned(),
                pass: pass.map(<[u8]>::to_vec),
            },
            Output::RecoverRefresh { token } => Pending::Recover {
                token: token.to_owned(),
            },
            Output::Refresh { token, grant } => Pending::Refresh {
                token: token.to_owned(),
                grant,
            },
        };
    }
}
/// Check the static parts of `request` and choose the machine's starting state.
///
/// Requests that carry client credentials start by authenticating that client. All others go
/// straight to recovering the grant behind the refresh token.
///
/// # Errors
///
/// - `invalid_request` when the request is not valid, has no refresh token, or has no
///   `grant_type`.
/// - `unsupported_grant_type` when `grant_type` is anything other than `refresh_token`.
fn initialize(request: &dyn Request) -> Result<RefreshState> {
    if !request.valid() {
        return Err(Error::invalid(AccessTokenErrorType::InvalidRequest));
    }

    let token = match request.refresh_token() {
        Some(token) => token.into_owned(),
        None => return Err(Error::invalid(AccessTokenErrorType::InvalidRequest)),
    };

    match request.grant_type().as_deref() {
        Some("refresh_token") => {}
        Some(_) => return Err(Error::invalid(AccessTokenErrorType::UnsupportedGrantType)),
        None => return Err(Error::invalid(AccessTokenErrorType::InvalidRequest)),
    }

    let state = match request.authorization() {
        None => RefreshState::Recovering {
            token,
            authenticated: None,
        },
        Some((client, passdata)) => RefreshState::Authenticating {
            client: client.into_owned(),
            passdata: Some(passdata.into_owned()),
            token,
        },
    };
    Ok(state)
}
/// The request's credentials were accepted: remember the client and go recover the grant.
fn authenticated(client: String, token: String) -> RefreshState {
    let authenticated = Some(client);
    RefreshState::Recovering {
        authenticated,
        token,
    }
}
fn recovered_refresh(
scope: Option<Cow<str>>, authenticated: Option<String>, grant: Option<Box<Grant>>, token: String,
) -> Result<RefreshState> {
let grant = grant
.ok_or_else(|| Error::invalid(AccessTokenErrorType::InvalidGrant))?;
match authenticated {
Some(client) => {
if grant.client_id.as_str() != client {
Err(Error::invalid(AccessTokenErrorType::InvalidGrant))
} else {
validate(scope, grant, token)
}
}
None => Ok(RefreshState::CoAuthenticating { grant, token }),
}
}
/// The grant's owning client has now authenticated; only the scope remains to be validated.
fn co_authenticated(
    scope: Option<Cow<'_, str>>,
    grant: Box<Grant>,
    token: String,
) -> Result<RefreshState> {
    validate(scope, grant, token)
}
fn validate(scope: Option<Cow<str>>, grant: Box<Grant>, token: String) -> Result<RefreshState> {
if grant.until <= Utc::now() {
return Err(Error::invalid(AccessTokenErrorType::InvalidGrant));
}
let scope = match scope {
Some(scope) => Some(
scope
.parse()
.map_err(|_| Error::invalid(AccessTokenErrorType::InvalidScope))?,
),
None => None,
};
let scope = match scope {
Some(scope) => {
if !grant.scope.priviledged_to(&scope) {
return Err(Error::invalid(AccessTokenErrorType::InvalidScope));
}
scope
}
None => grant.scope.clone(),
};
let mut grant = grant;
grant.scope = scope;
grant.until = Utc::now() + Duration::hours(1);
Ok(RefreshState::Issuing { grant, token })
}
/// Pair the refreshed token with the scope string of the grant it was issued for.
fn issued(grant: Box<Grant>, token: RefreshedToken) -> BearerToken {
    let scope = grant.scope.to_string();
    BearerToken(token, scope)
}
impl Error {
fn invalid(kind: AccessTokenErrorType) -> Self {
Error::Invalid(ErrorDescription {
error: AccessTokenError::new(kind),
})
}
pub fn unauthorized(authtype: &str) -> Self {
Error::Unauthorized(
ErrorDescription {
error: AccessTokenError::new(AccessTokenErrorType::InvalidClient),
},
authtype.to_string(),
)
}
pub fn description(&mut self) -> Option<&mut AccessTokenError> {
match self {
Error::Invalid(description) => Some(description.description()),
Error::Unauthorized(description, _) => Some(description.description()),
Error::Primitive => None,
}
}
}
impl ErrorDescription {
    /// Mutable access to the wrapped access token error.
    pub fn description(&mut self) -> &mut AccessTokenError {
        &mut self.error
    }

    /// Serialize the error's key/value pairs as a flat JSON object.
    ///
    /// If a key repeats, the last value wins.
    pub fn to_json(&self) -> String {
        let mut fields: HashMap<String, String> = HashMap::new();
        for (key, value) in self.error.iter() {
            fields.insert(key.to_string(), value.into_owned());
        }
        serde_json::to_string(&fields).unwrap()
    }
}
impl BearerToken {
pub fn to_json(&self) -> String {
let remaining = self.0.until.signed_duration_since(Utc::now());
let token_response = TokenResponse {
access_token: Some(self.0.token.clone()),
refresh_token: self.0.refresh.clone(),
token_type: Some("bearer".to_owned()),
expires_in: Some(remaining.num_seconds()),
scope: Some(self.1.clone()),
error: None,
};
serde_json::to_string(&token_response).unwrap()
}
}