use std::borrow::Cow;
use std::collections::HashMap;
use chrono::{Duration, Utc};
use code_grant::error::{AccessTokenError, AccessTokenErrorType};
use primitives::issuer::{RefreshedToken, Issuer};
use primitives::registrar::{Registrar, RegistrarError};
/// Parameters of an access-token refresh request, as read by `refresh`.
pub trait Request {
/// Whether the request is well-formed; `refresh` rejects it with `InvalidRequest` when `false`.
fn valid(&self) -> bool;
/// The refresh token presented by the client; `refresh` requires it (`InvalidRequest` if absent).
fn refresh_token(&self) -> Option<Cow<str>>;
/// The requested scope, if any. `refresh` answers `InvalidScope` when it fails to parse or does
/// not compare `<=` the original grant's scope; when absent the grant's scope is kept.
fn scope(&self) -> Option<Cow<str>>;
/// The `grant_type` parameter. `refresh` requires exactly `"refresh_token"`: a missing value
/// yields `InvalidRequest`, any other value `UnsupportedGrantType`.
fn grant_type(&self) -> Option<Cow<str>>;
/// Client credentials, if presented, as `(client, passdata)`; `refresh` verifies them with
/// `Registrar::check` and then requires the grant to belong to that client.
fn authorization(&self) -> Option<(Cow<str>, Cow<[u8]>)>;
/// Look up an additional request parameter by `key`. Not consulted by `refresh`.
fn extension(&self, key: &str) -> Option<Cow<str>>;
}
/// The backends `refresh` relies on to authenticate clients and manage tokens.
pub trait Endpoint {
/// Registrar used to check client credentials, or a client presented without credentials.
fn registrar(&self) -> &dyn Registrar;
/// Issuer used to recover the grant behind a refresh token and to issue the refreshed token.
fn issuer(&mut self) -> &mut dyn Issuer;
}
/// Successful refresh result: the token returned by the issuer and the granted scope rendered
/// as a string (serialized by `BearerToken::to_json`).
pub struct BearerToken(RefreshedToken, String);
/// Failure of the refresh flow.
#[derive(Debug)]
pub enum Error {
/// The request or the referenced grant was rejected; the description holds the error to report.
Invalid(ErrorDescription),
/// Client authentication failed. `Error::unauthorized` fills the description with
/// `InvalidClient`; the string is the authentication scheme (`refresh` uses `"basic"`).
Unauthorized(ErrorDescription, String),
/// The registrar or issuer reported an internal failure; there is no description to report.
Primitive,
}
/// Describes a rejected request by wrapping the `AccessTokenError` to send back.
#[derive(Debug)]
pub struct ErrorDescription {
// The error payload; `ErrorDescription::to_json` serializes its key/value pairs.
error: AccessTokenError,
}
/// Result type of the refresh flow, failing with this module's `Error`.
type Result<T> = std::result::Result<T, Error>;
pub fn refresh(handler: &mut dyn Endpoint, request: &dyn Request)
-> Result<BearerToken>
{
if !request.valid() {
return Err(Error::invalid(AccessTokenErrorType::InvalidRequest))
}
let token = request.refresh_token();
let token = token.ok_or(Error::invalid(AccessTokenErrorType::InvalidRequest))?;
match request.grant_type() {
Some(ref cow) if cow == "refresh_token" => (),
None => return Err(Error::invalid(AccessTokenErrorType::InvalidRequest)),
Some(_) => return Err(Error::invalid(AccessTokenErrorType::UnsupportedGrantType)),
};
let authenticated = match request.authorization() {
Some((client, passdata)) => {
handler
.registrar()
.check(&client, Some(&passdata))
.map_err(|err| match err {
RegistrarError::PrimitiveError => Error::Primitive,
RegistrarError::Unspecified => Error::unauthorized("basic"),
})?;
Some(client)
},
None => None,
};
let grant = handler
.issuer()
.recover_refresh(&token)
.map_err(|()| Error::Primitive)?;
let grant = grant
.ok_or_else(|| Error::invalid(AccessTokenErrorType::InvalidGrant))?;
match authenticated {
Some(client) => {
if grant.client_id.as_str() != client {
return Err(Error::invalid(AccessTokenErrorType::InvalidGrant))
}
},
None => {
handler
.registrar()
.check(&grant.client_id, None)
.map_err(|err| match err {
RegistrarError::PrimitiveError => Error::Primitive,
RegistrarError::Unspecified => Error::unauthorized("basic"),
})?;;
}
}
if grant.until <= Utc::now() {
return Err(Error::invalid(AccessTokenErrorType::InvalidGrant));
}
let scope = match request.scope() {
Some(scope) => Some(scope.parse().map_err(|_| Error::invalid(AccessTokenErrorType::InvalidScope))?),
None => None,
};
let scope = match scope {
Some(scope) => {
if !(&scope <= &grant.scope) {
return Err(Error::invalid(AccessTokenErrorType::InvalidScope))
}
scope
},
None => grant.scope.clone(),
};
let str_scope = scope.to_string();
let mut grant = grant;
grant.scope = scope;
grant.until = Utc::now() + Duration::hours(1);
let token = handler
.issuer()
.refresh(&token, grant)
.map_err(|()| Error::Primitive)?;
Ok(BearerToken { 0: token, 1: str_scope })
}
impl Error {
fn invalid(kind: AccessTokenErrorType) -> Self {
Error::Invalid(ErrorDescription {
error: AccessTokenError::new(kind),
})
}
fn unauthorized(authtype: &str) -> Self {
Error::Unauthorized(ErrorDescription {
error: AccessTokenError::new(AccessTokenErrorType::InvalidClient),
},
authtype.to_string())
}
pub fn description(&mut self) -> Option<&mut AccessTokenError> {
match self {
Error::Invalid(description) => Some(description.description()),
Error::Unauthorized(description, _) => Some(description.description()),
Error::Primitive => None,
}
}
}
impl ErrorDescription {
pub fn description(&mut self) -> &mut AccessTokenError {
&mut self.error
}
pub fn to_json(self) -> String {
let asmap = self.error.into_iter()
.map(|(k, v)| (k.to_string(), v.into_owned()))
.collect::<HashMap<String, String>>();
serde_json::to_string(&asmap).unwrap()
}
}
impl BearerToken {
pub fn to_json(self) -> String {
let remaining = self.0.until.signed_duration_since(Utc::now());
let mut kvmap: HashMap<_, _> = vec![
("access_token", self.0.token),
("token_type", "bearer".to_string()),
("expires_in", remaining.num_seconds().to_string()),
("scope", self.1)].into_iter().collect();
if let Some(refresh) = self.0.refresh {
kvmap.insert("refresh_token", refresh);
}
serde_json::to_string(&kvmap).unwrap()
}
}