use std::borrow::Cow;
use std::collections::HashMap;
use chrono::{Duration, Utc};
use serde_json;
use code_grant::error::{AccessTokenError, AccessTokenErrorType};
use primitives::authorizer::Authorizer;
use primitives::issuer::{IssuedToken, Issuer};
use primitives::grant::{Extensions, Grant};
use primitives::registrar::{Registrar, RegistrarError};
/// The view of an incoming access token request that `access_token` works with.
///
/// Every accessor returns `None` when the corresponding value is absent from the request.
pub trait Request {
    /// Whether the request could be interpreted at all.
    ///
    /// `access_token` returns `Error::Invalid` before calling any other method when this is
    /// `false`.
    fn valid(&self) -> bool;

    /// The authorization code the client wants to redeem.
    fn code(&self) -> Option<Cow<str>>;

    /// Client credentials as a `(client id, passphrase)` pair.
    ///
    /// NOTE(review): presumably decoded from an HTTP `Authorization` header; confirm against
    /// the implementors.
    fn authorization(&self) -> Option<(Cow<str>, Cow<[u8]>)>;

    /// The client id given as a plain request parameter.
    ///
    /// `access_token` treats this as a second source of credentials; a request providing both
    /// this and `authorization` is rejected.
    fn client_id(&self) -> Option<Cow<str>>;

    /// The redirect uri of the request, compared against the one stored with the code.
    fn redirect_uri(&self) -> Option<Cow<str>>;

    /// The requested grant type; `access_token` only accepts `authorization_code`.
    fn grant_type(&self) -> Option<Cow<str>>;

    /// Looks up an additional request parameter by name.
    ///
    /// `access_token` uses this to read the `client_secret` parameter.
    fn extension(&self, key: &str) -> Option<Cow<str>>;

    /// Whether a `client_secret` parameter may authenticate the client named by `client_id`.
    ///
    /// Defaults to `false`, in which case `access_token` ignores the secret and treats the
    /// client as unauthenticated.
    fn allow_credentials_in_body(&self) -> bool {
        false
    }
}
/// Hook that decides which extension data is attached to a newly issued access token.
///
/// `access_token` invokes it once per request, after the authorization code has been checked
/// and before the token is issued.
pub trait Extension {
/// Produces the extensions for the access token.
///
/// `request` is the request being processed and `data` holds the extensions that were stored
/// with the authorization code. The `Ok` value becomes the `extensions` of the issued grant;
/// on `Err`, `access_token` fails with `Error::Invalid`.
fn extend(&mut self, request: &dyn Request, data: Extensions) -> std::result::Result<Extensions, ()>;
}
/// The unit type serves as an extension handler that attaches no extension data.
impl Extension for () {
    /// Ignores both the request and the extensions stored with the code, and always succeeds
    /// with a freshly created, empty set of extensions.
    fn extend(&mut self, _request: &dyn Request, _data: Extensions) -> std::result::Result<Extensions, ()> {
        let empty = Extensions::new();
        Ok(empty)
    }
}
/// The collaborators `access_token` needs in order to process a request.
pub trait Endpoint {
/// The registrar that `access_token` asks to check the client id and optional passphrase.
fn registrar(&self) -> &dyn Registrar;
/// The authorizer from which `access_token` extracts the grant stored for a code.
fn authorizer(&mut self) -> &mut dyn Authorizer;
/// The issuer that `access_token` asks to issue the token for the validated grant.
fn issuer(&mut self) -> &mut dyn Issuer;
/// The extension handler that turns the code's extensions into the token's extensions.
fn extension(&mut self) -> & mut dyn Extension;
}
/// Accumulator for the client credentials found in a request.
///
/// Starts as `None`; each credential source that is present is added through
/// `authenticate`/`unauthenticated`. Adding a second source turns the value into `Duplicate`.
enum Credentials<'a> {
/// No credentials have been added.
None,
/// A client id together with a passphrase.
Authenticated {
client_id: &'a str,
passphrase: &'a [u8],
},
/// A client id without any passphrase.
Unauthenticated {
client_id: &'a str,
},
/// More than one set of credentials was added; `into_client` yields `None` for this state.
Duplicate,
}
pub fn access_token(handler: &mut dyn Endpoint, request: &dyn Request) -> Result<BearerToken> {
if !request.valid() {
return Err(Error::invalid())
}
let authorization = request.authorization();
let client_id = request.client_id();
let client_secret = request.extension("client_secret");
let mut credentials = Credentials::None;
if let Some((client_id, auth)) = &authorization {
credentials.authenticate(client_id.as_ref(), auth.as_ref());
}
if let Some(client_id) = &client_id {
match &client_secret {
Some(auth) if request.allow_credentials_in_body() => {
credentials.authenticate(client_id.as_ref(), auth.as_ref().as_bytes())
},
Some(_) | None => {
credentials.unauthenticated(client_id.as_ref())
},
}
}
let (client_id, auth) = credentials
.into_client()
.ok_or_else(Error::invalid)?;
handler.registrar()
.check(&client_id, auth)
.map_err(|err| match err {
RegistrarError::Unspecified => Error::unauthorized("basic"),
RegistrarError::PrimitiveError => Error::Primitive(PrimitiveError {
grant: None,
extensions: None,
}),
})?;
match request.grant_type() {
Some(ref cow) if cow == "authorization_code" => (),
None => return Err(Error::invalid()),
Some(_) => return Err(Error::invalid_with(AccessTokenErrorType::UnsupportedGrantType)),
};
let code = request
.code()
.ok_or_else(Error::invalid)?;
let code = code.as_ref();
let saved_params = match handler.authorizer().extract(code) {
Err(()) => return Err(Error::Primitive(PrimitiveError {
grant: None,
extensions: None,
})),
Ok(None) => return Err(Error::invalid()),
Ok(Some(v)) => v,
};
let redirect_uri = request
.redirect_uri()
.ok_or_else(Error::invalid)?;
let redirect_uri = redirect_uri
.as_ref()
.parse()
.map_err(|_| Error::invalid())?;
if (saved_params.client_id.as_ref(), &saved_params.redirect_uri) != (client_id, &redirect_uri) {
return Err(Error::invalid_with(AccessTokenErrorType::InvalidGrant))
}
if saved_params.until < Utc::now() {
return Err(Error::invalid_with(AccessTokenErrorType::InvalidGrant))
}
let code_extensions = saved_params.extensions;
let access_extensions = handler.extension().extend(request, code_extensions);
let access_extensions = match access_extensions {
Ok(extensions) => extensions,
Err(_) => return Err(Error::invalid()),
};
let token = handler.issuer().issue(Grant {
client_id: saved_params.client_id,
owner_id: saved_params.owner_id,
redirect_uri: saved_params.redirect_uri,
scope: saved_params.scope.clone(),
until: Utc::now() + Duration::hours(1),
extensions: access_extensions,
}).map_err(|()| Error::Primitive(PrimitiveError {
grant: None,
extensions: None,
}))?;
Ok(BearerToken{ 0: token, 1: saved_params.scope.to_string() })
}
impl<'a> Credentials<'a> {
    /// Adds credentials consisting of a client id and a passphrase.
    ///
    /// If credentials were added before, the state becomes `Duplicate`.
    pub fn authenticate(&mut self, client_id: &'a str, passphrase: &'a [u8]) {
        self.add(Credentials::Authenticated { client_id, passphrase })
    }

    /// Adds credentials consisting of a client id only.
    ///
    /// If credentials were added before, the state becomes `Duplicate`.
    pub fn unauthenticated(&mut self, client_id: &'a str) {
        self.add(Credentials::Unauthenticated { client_id })
    }

    /// Returns the client id and the optional passphrase if exactly one set of credentials was
    /// added, and `None` if there were none or more than one.
    pub fn into_client(self) -> Option<(&'a str, Option<&'a [u8]>)> {
        match self {
            Credentials::Authenticated { client_id, passphrase } => {
                Some((client_id, Some(passphrase)))
            }
            Credentials::Unauthenticated { client_id } => Some((client_id, None)),
            Credentials::None | Credentials::Duplicate => None,
        }
    }

    /// Records `new`, degrading to `Duplicate` when something was already recorded.
    fn add(&mut self, new: Self) {
        // Assign directly: the previous value is not needed, so there is no reason to go
        // through `mem::replace` and discard its `#[must_use]` result.
        *self = match *self {
            Credentials::None => new,
            Credentials::Authenticated { .. }
            | Credentials::Unauthenticated { .. }
            | Credentials::Duplicate => Credentials::Duplicate,
        };
    }
}
/// The ways in which `access_token` can fail.
pub enum Error {
/// The request was rejected; the description holds the error to report.
Invalid(ErrorDescription),
/// The client could not be authenticated. The `String` is the authentication type that was
/// passed to `Error::unauthorized` (`"basic"` when created by `access_token`).
Unauthorized(ErrorDescription, String),
/// One of the endpoint's primitives (registrar, authorizer or issuer) reported a failure.
Primitive(PrimitiveError),
}
/// Payload of `Error::Primitive`.
///
/// Every `PrimitiveError` constructed in this file sets both fields to `None`.
pub struct PrimitiveError {
/// An optional grant associated with the failure.
pub grant: Option<Grant>,
/// Optional extension data associated with the failure.
pub extensions: Option<Extensions>,
}
/// Wrapper around the `AccessTokenError` that describes a rejected request.
///
/// The inner error can be modified through `description` and encoded with `to_json`.
pub struct ErrorDescription {
error: AccessTokenError,
}
/// Result type of this module, with the error fixed to this module's `Error`.
type Result<T> = std::result::Result<T, Error>;
/// A successfully issued token: the `IssuedToken` from the issuer followed by the scope of the
/// grant in its string form.
pub struct BearerToken(IssuedToken, String);
impl Error {
fn invalid() -> Self {
Error::Invalid(ErrorDescription {
error: AccessTokenError::default()
})
}
fn invalid_with(with_type: AccessTokenErrorType) -> Self {
Error::Invalid(ErrorDescription {
error: {
let mut error = AccessTokenError::default();
error.set_type(with_type);
error
},
})
}
fn unauthorized(authtype: &str) -> Error {
Error::Unauthorized(ErrorDescription {
error: {
let mut error = AccessTokenError::default();
error.set_type(AccessTokenErrorType::InvalidClient);
error
},
},
authtype.to_string())
}
pub fn description(&mut self) -> Option<&mut AccessTokenError> {
match self {
Error::Invalid(description) => Some(description.description()),
Error::Unauthorized(description, _) => Some(description.description()),
Error::Primitive(_) => None,
}
}
}
impl ErrorDescription {
pub fn to_json(self) -> String {
let asmap = self.error.into_iter()
.map(|(k, v)| (k.to_string(), v.into_owned()))
.collect::<HashMap<String, String>>();
serde_json::to_string(&asmap).unwrap()
}
pub fn description(&mut self) -> &mut AccessTokenError {
&mut self.error
}
}
impl BearerToken {
pub fn to_json(self) -> String {
#[derive(Serialize)]
struct Serial<'a> {
access_token: &'a str,
#[serde(skip_serializing_if="Option::is_none")]
refresh_token: Option<&'a str>,
token_type: &'a str,
expires_in: String,
scope: &'a str,
}
let remaining = self.0.until.signed_duration_since(Utc::now());
let serial = Serial {
access_token: self.0.token.as_str(),
refresh_token: Some(self.0.refresh.as_str())
.filter(|_| self.0.refreshable()),
token_type: "bearer",
expires_in: remaining.num_seconds().to_string(),
scope: self.1.as_str(),
};
serde_json::to_string(&serial).unwrap()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Encodes `token`, parses the JSON back into a flat string map and checks every field.
    ///
    /// `refresh` is the expected value of `refresh_token`; `None` means the key must be absent.
    fn check_encoding(token: BearerToken, refresh: Option<&str>) {
        let encoded = token.to_json();
        let mut fields: HashMap<String, String> = serde_json::from_str(&encoded).unwrap();

        assert_eq!(fields.remove("access_token"), Some("access".to_string()));
        assert_eq!(fields.remove("refresh_token"), refresh.map(|value| value.to_string()));
        assert_eq!(fields.remove("scope"), Some("scope".to_string()));
        assert_eq!(fields.remove("token_type"), Some("bearer".to_string()));
        assert!(fields.remove("expires_in").is_some());
    }

    /// A token with a refresh token encodes all five fields.
    #[test]
    fn bearer_token_encoding() {
        let issued = IssuedToken {
            token: "access".into(),
            refresh: "refresh".into(),
            until: Utc::now(),
        };
        check_encoding(BearerToken(issued, "scope".into()), Some("refresh"));
    }

    /// A token created without refresh omits the `refresh_token` field.
    #[test]
    fn no_refresh_encoding() {
        let issued = IssuedToken::without_refresh("access".into(), Utc::now());
        check_encoding(BearerToken(issued, "scope".into()), None);
    }
}