use std::borrow::Borrow;
use std::collections::HashSet;
use std::error;
use std::fmt;
use std::fs::File;
use std::io::{self, Cursor, Read};
use std::ops::Deref;
use std::str::FromStr;
use std::time::Duration;
use crate::cors;
use crate::jwt::{self, jwa, jwk, jws};
use chrono::{self, DateTime, Utc};
use rocket::http::{ContentType, Method, Status};
use rocket::response::{Responder, Response};
use rocket::Request;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json;
use uuid::Uuid;
use crate::{ByteSequence, JsonValue};
/// Errors produced while building, encoding, encrypting, validating or serving tokens.
#[derive(Debug)]
pub enum Error {
    /// `encode` was called on a token that is already encoded.
    TokenAlreadyEncoded,
    /// `decode` was called on a token that is already decoded.
    TokenAlreadyDecoded,
    /// The operation (e.g. serializing the token for a response) needs an encoded token.
    TokenNotEncoded,
    /// The operation (e.g. reading claims or the header) needs a decoded token.
    TokenNotDecoded,
    /// A refresh-token operation was requested but the token has no refresh token.
    NoRefreshToken,
    /// `encrypt` was called on a refresh token that is already encrypted.
    RefreshTokenAlreadyEncrypted,
    /// `decrypt` was called on a refresh token that is already decrypted.
    RefreshTokenAlreadyDecrypted,
    /// The operation (e.g. reading the refresh token's claims) needs a decrypted refresh token.
    RefreshTokenNotDecrypted,
    /// The operation (e.g. `RefreshToken::to_string`) needs an encrypted refresh token.
    RefreshTokenNotEncrypted,
    /// The requested service is not one of the configured audiences.
    InvalidService,
    /// The token's issuer is missing or differs from the configured issuer.
    InvalidIssuer,
    /// The token's audience is missing, omits the service, or is not within the configured audiences.
    InvalidAudience,
    /// Free-form error message.
    GenericError(String),
    /// I/O failure, e.g. while reading key files.
    IOError(io::Error),
    /// Error reported by the JWT library (boxed).
    JWTError(Box<jwt::errors::Error>),
    /// JSON serialization failure.
    TokenSerializationError(serde_json::Error),
}
impl_from_error!(io::Error, Error::IOError);
impl_from_error!(serde_json::Error, Error::TokenSerializationError);
impl_from_error!(String, Error::GenericError);
impl From<jwt::errors::Error> for Error {
fn from(jwt: jwt::errors::Error) -> Self {
Error::JWTError(Box::new(jwt))
}
}
impl<'a> From<&'a str> for Error {
fn from(s: &'a str) -> Error {
Error::GenericError(s.to_string())
}
}
impl error::Error for Error {
fn description(&self) -> &str {
match *self {
Error::TokenAlreadyEncoded => "Token is already encoded",
Error::TokenAlreadyDecoded => "Token is already decoded",
Error::TokenNotEncoded => "Token is not encoded and cannot be used in this context",
Error::TokenNotDecoded => "Token is not decoded and cannot be used in this context",
Error::NoRefreshToken => "Refresh token is not present",
Error::RefreshTokenAlreadyEncrypted => "Refresh token is already encrypted and signed",
Error::RefreshTokenAlreadyDecrypted => {
"Refresh token is already decrypted and verified"
}
Error::RefreshTokenNotDecrypted => {
"Refresh token is not decrypted and cannot be used in this context"
}
Error::RefreshTokenNotEncrypted => {
"Refresh token is not encrypted and cannot be used in this context"
}
Error::InvalidService => "Service requested is not in the list of intended audiences",
Error::InvalidIssuer => "The token has an invalid issuer",
Error::InvalidAudience => "The token has invalid audience",
Error::JWTError(ref e) => e.description(),
Error::IOError(ref e) => e.description(),
Error::TokenSerializationError(ref e) => e.description(),
Error::GenericError(ref e) => e,
}
}
fn cause(&self) -> Option<&dyn error::Error> {
match *self {
Error::JWTError(ref e) => Some(e),
Error::IOError(ref e) => Some(e),
Error::TokenSerializationError(ref e) => Some(e),
_ => Some(self),
}
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Error::JWTError(ref e) => fmt::Display::fmt(e, f),
Error::IOError(ref e) => fmt::Display::fmt(e, f),
Error::TokenSerializationError(ref e) => fmt::Display::fmt(e, f),
Error::GenericError(ref e) => fmt::Display::fmt(e, f),
_ => write!(f, "{}", error::Error::description(self)),
}
}
}
impl<'r> Responder<'r> for Error {
fn respond_to(self, _: &Request<'_>) -> Result<Response<'r>, Status> {
error_!("Token Error: {:?}", self);
match self {
Error::InvalidService | Error::InvalidIssuer | Error::InvalidAudience => {
Err(Status::Forbidden)
}
Error::JWTError(ref e) => {
use crate::jwt::errors::Error::*;
let status = match **e {
ValidationError(_)
| JsonError(_)
| DecodeBase64(_)
| Utf8(_)
| UnspecifiedCryptographicError => Status::Unauthorized,
_ => Status::InternalServerError,
};
Err(status)
}
_ => Err(Status::InternalServerError),
}
}
}
/// Generates a random version-4 UUID (RFC 4122 variant), used as the token `jti`.
///
/// # Errors
/// Returns an error if the crate's random number generator fails to fill the
/// buffer, or if `Uuid::from_bytes` rejects it.
fn make_uuid() -> Result<Uuid, Error> {
    use crate::jwt::jwa::SecureRandom;
    let mut bytes = vec![0; 16];
    crate::rng()
        .fill(&mut bytes)
        .map_err(|_| "Unable to generate UUID")?;
    // Stamp the version nibble (4 = random) and the RFC 4122 variant bits.
    // Without this the result is 128 arbitrary bits rather than a well-formed v4 UUID.
    bytes[6] = (bytes[6] & 0x0f) | 0x40;
    bytes[8] = (bytes[8] & 0x3f) | 0x80;
    Ok(Uuid::from_bytes(&bytes).map_err(|e| e.to_string())?)
}
/// Builds a JWS header whose only registered field is the signature algorithm.
///
/// Falls back to `SignatureAlgorithm::None` when no algorithm is configured.
fn make_header(signature_algorithm: Option<jwa::SignatureAlgorithm>) -> jws::Header<jwt::Empty> {
    let algorithm = signature_algorithm.unwrap_or(jwa::SignatureAlgorithm::None);
    jws::Header::from_registered_header(jws::RegisteredHeader {
        algorithm,
        ..Default::default()
    })
}
fn make_registered_claims(
subject: &str,
now: DateTime<Utc>,
expiry_duration: Duration,
issuer: &jwt::StringOrUri,
audience: &jwt::SingleOrMultiple<jwt::StringOrUri>,
) -> Result<jwt::RegisteredClaims, crate::Error> {
let expiry_duration = chrono::Duration::from_std(expiry_duration).map_err(|e| e.to_string())?;
Ok(jwt::RegisteredClaims {
issuer: Some(issuer.clone()),
subject: Some(FromStr::from_str(subject).map_err(|e| Error::JWTError(Box::new(e)))?),
audience: Some(audience.clone()),
issued_at: Some(now.into()),
not_before: Some(now.into()),
expiry: Some((now + expiry_duration).into()),
id: Some(make_uuid()?.urn().to_string()),
})
}
/// Builds a decoded (unsigned) JWT carrying `private_claims` plus freshly generated
/// registered claims. The header records `signature_algorithm`, or `None` if absent.
#[cfg_attr(feature = "clippy_lints", allow(too_many_arguments))]
fn make_token<P: Serialize + DeserializeOwned + 'static>(
    subject: &str,
    issuer: &jwt::StringOrUri,
    audience: &jwt::SingleOrMultiple<jwt::StringOrUri>,
    expiry_duration: Duration,
    private_claims: P,
    signature_algorithm: Option<jwa::SignatureAlgorithm>,
    now: DateTime<Utc>,
) -> Result<jwt::JWT<P, jwt::Empty>, crate::Error> {
    let claims = jwt::ClaimsSet::<P> {
        registered: make_registered_claims(subject, now, expiry_duration, issuer, audience)?,
        private: private_claims,
    };
    Ok(jwt::JWT::new_decoded(make_header(signature_algorithm), claims))
}
fn verify_service(config: &Configuration, service: &str) -> Result<(), Error> {
if !config.audience.contains(&FromStr::from_str(service)?) {
Err(Error::InvalidService)
} else {
Ok(())
}
}
fn verify_issuer(config: &Configuration, issuer: &jwt::StringOrUri) -> Result<(), Error> {
if *issuer == config.issuer {
Ok(())
} else {
Err(Error::InvalidIssuer)
}
}
/// Checks that every audience in `audience` is also a configured audience.
///
/// Returns `InvalidAudience` if any entry falls outside the configured set.
fn verify_audience(
    config: &Configuration,
    audience: &jwt::SingleOrMultiple<jwt::StringOrUri>,
) -> Result<(), Error> {
    let permitted: HashSet<&jwt::StringOrUri> = config.audience.iter().collect();
    if audience.iter().all(|candidate| permitted.contains(candidate)) {
        Ok(())
    } else {
        Err(Error::InvalidAudience)
    }
}
/// CORS options type produced by `Configuration::cors_option`.
pub type TokenGetterCorsOptions = cors::Cors;
/// HTTP methods allowed by the CORS options built in `Configuration::cors_option` (GET only).
const TOKEN_GETTER_METHODS: &[Method] = &[Method::Get];
/// Request headers allowed by the CORS options built in `Configuration::cors_option`.
const TOKEN_GETTER_HEADERS: &[&str] = &[
    "Authorization",
    "Accept",
    "Accept-Language",
    "Content-Language",
    "Content-Type",
    "Origin",
];
/// Settings used to issue and validate tokens.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Configuration {
    /// Issuer placed in issued tokens; refresh tokens must carry exactly this issuer.
    pub issuer: jwt::StringOrUri,
    /// Origins allowed by the CORS options returned from `cors_option`.
    pub allowed_origins: cors::AllOrSome<HashSet<cors::headers::Url>>,
    /// Audiences placed in issued tokens; a requested service must be one of them.
    pub audience: jwt::SingleOrMultiple<jwt::StringOrUri>,
    /// Algorithm recorded in the token header; when absent the header uses `SignatureAlgorithm::None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub signature_algorithm: Option<jwa::SignatureAlgorithm>,
    /// Secret used for signing and signature verification; defaults to `Secret::None`.
    #[serde(default)]
    pub secret: Secret,
    /// Lifetime of access tokens; defaults to 86400 seconds (24 hours).
    #[serde(
        with = "crate::serde_custom::duration",
        default = "Configuration::default_expiry_duration"
    )]
    pub expiry_duration: Duration,
    /// Refresh-token settings; refresh tokens are disabled when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub refresh_token: Option<RefreshTokenConfiguration>,
}
const DEFAULT_EXPIRY_DURATION: u64 = 86400;
impl Configuration {
fn default_expiry_duration() -> Duration {
Duration::from_secs(DEFAULT_EXPIRY_DURATION)
}
pub(crate) fn cors_option(&self) -> TokenGetterCorsOptions {
cors::Cors {
allowed_origins: self.allowed_origins.clone(),
allowed_methods: TOKEN_GETTER_METHODS
.iter()
.cloned()
.map(From::from)
.collect(),
allowed_headers: cors::AllOrSome::Some(
TOKEN_GETTER_HEADERS
.iter()
.map(|s| s.to_string().into())
.collect(),
),
allow_credentials: true,
..Default::default()
}
}
pub fn refresh_token_enabled(&self) -> bool {
self.refresh_token.is_some()
}
pub fn refresh_token(&self) -> &RefreshTokenConfiguration {
self.refresh_token.as_ref().unwrap()
}
pub fn keys(&self) -> Result<Keys, Error> {
let (encryption, decryption) = if self.refresh_token_enabled() {
let key = &self.refresh_token().key;
(Some(key.for_encryption()?), Some(key.for_decryption()?))
} else {
(None, None)
};
Ok(Keys {
signing: self.secret.for_signing()?,
signature_verification: self.secret.for_verification()?,
encryption: encryption,
decryption: decryption,
})
}
}
/// Settings for issuing encrypted refresh tokens.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RefreshTokenConfiguration {
    /// Key management algorithm recorded in the refresh token's JWE header.
    pub cek_algorithm: jwa::KeyManagementAlgorithm,
    /// Content encryption algorithm recorded in the refresh token's JWE header.
    pub enc_algorithm: jwa::ContentEncryptionAlgorithm,
    /// Key used to encrypt/decrypt refresh tokens (`Secret::RSAKeyPair` is not supported here).
    pub key: Secret,
    /// Lifetime of refresh tokens; defaults to 86400 seconds (24 hours), like access tokens.
    #[serde(
        with = "crate::serde_custom::duration",
        default = "Configuration::default_expiry_duration"
    )]
    pub expiry_duration: Duration,
}
/// Private claims represented as arbitrary JSON.
pub type PrivateClaim = JsonValue;
/// The JWT carried inside a refresh token: JSON private claims, `jwt::Empty` extra headers.
pub type RefreshTokenPayload = jwt::JWT<JsonValue, jwt::Empty>;
/// Compact JWE wrapping a `RefreshTokenPayload`.
pub type RefreshTokenJWE = jwt::jwe::Compact<RefreshTokenPayload, jwt::Empty>;
/// A refresh token: a JWE that is either in encrypted or decrypted form.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct RefreshToken(RefreshTokenJWE);
impl RefreshToken {
#[cfg_attr(feature = "clippy_lints", allow(too_many_arguments))] fn new_decrypted(
subject: &str,
issuer: &jwt::StringOrUri,
audience: &jwt::SingleOrMultiple<jwt::StringOrUri>,
expiry_duration: Duration,
payload: &JsonValue,
signature_algorithm: Option<jwa::SignatureAlgorithm>,
cek_algorithm: jwa::KeyManagementAlgorithm,
enc_algorithm: jwa::ContentEncryptionAlgorithm,
now: DateTime<Utc>,
) -> Result<Self, crate::Error> {
let token = make_token(
subject,
issuer,
audience,
expiry_duration,
payload.clone(),
signature_algorithm,
now,
)?;
let jwe = jwt::JWE::new_decrypted(
From::from(jwt::jwe::RegisteredHeader {
cek_algorithm: cek_algorithm,
enc_algorithm: enc_algorithm,
media_type: Some("JOSE".to_string()),
content_type: Some("JOSE".to_string()),
..Default::default()
}),
token,
);
Ok(RefreshToken(jwe))
}
pub fn new_encrypted(token: &str) -> Self {
RefreshToken(jwt::JWE::new_encrypted(token))
}
pub fn unwrap(self) -> RefreshTokenJWE {
self.0
}
pub fn encrypted(&self) -> bool {
match *self.borrow() {
jwt::jwe::Compact::Decrypted { .. } => false,
jwt::jwe::Compact::Encrypted(_) => true,
}
}
pub fn decrypted(&self) -> bool {
!self.encrypted()
}
fn encryption_option(&self) -> Result<jwa::EncryptionOptions, Error> {
let headers = &self.0.header()?.registered;
let need_nonce = if let jwa::KeyManagementAlgorithm::A128GCMKW
| jwa::KeyManagementAlgorithm::A192GCMKW
| jwa::KeyManagementAlgorithm::A256GCMKW = headers.cek_algorithm
{
true
} else if let jwa::ContentEncryptionAlgorithm::A128GCM
| jwa::ContentEncryptionAlgorithm::A192GCM
| jwa::ContentEncryptionAlgorithm::A256GCM = headers.enc_algorithm
{
true
} else {
false
};
if need_nonce {
let nonce = crate::auth::util::generate_salt(96 / 8)
.map_err(|_| Error::GenericError("An unknown error".to_string()))?;
Ok(jwa::EncryptionOptions::AES_GCM { nonce: nonce })
} else {
Ok(jwa::EncryptionOptions::None)
}
}
pub fn encrypt(self, secret: &jws::Secret, key: &jwk::JWK<jwt::Empty>) -> Result<Self, Error> {
if self.encrypted() {
Err(Error::RefreshTokenAlreadyEncrypted)?
}
let options = self.encryption_option()?;
let (header, jws) = self.unwrap().unwrap_decrypted();
let jws = jws.into_encoded(secret)?;
let jwe = jwt::JWE::new_decrypted(header, jws);
let jwe = jwe.into_encrypted(key, &options)?;
Ok(From::from(jwe))
}
pub fn decrypt(
self,
secret: &jws::Secret,
key: &jwk::JWK<jwt::Empty>,
signing_algorithm: jwa::SignatureAlgorithm,
cek_algorithm: jwa::KeyManagementAlgorithm,
enc_algorithm: jwa::ContentEncryptionAlgorithm,
) -> Result<Self, Error> {
if self.decrypted() {
Err(Error::RefreshTokenAlreadyDecrypted)?
}
let jwe = self.unwrap();
let jwe = jwe.into_decrypted(key, cek_algorithm, enc_algorithm)?;
let (header, jws) = jwe.unwrap_decrypted();
let jws = jws.into_decoded(secret, signing_algorithm)?;
let jwe = jwt::JWE::new_decrypted(header, jws);
Ok(From::from(jwe))
}
fn claims_set(&self) -> Result<&jwt::ClaimsSet<JsonValue>, Error> {
if !self.decrypted() {
Err(Error::RefreshTokenNotDecrypted)?;
}
Ok(self.0.payload()?.payload()?)
}
pub fn payload(&self) -> Result<&JsonValue, Error> {
Ok(&self.claims_set()?.private)
}
pub fn validate(
&self,
service: &str,
config: &Configuration,
options: Option<jwt::ValidationOptions>,
) -> Result<(), Error> {
use std::str::FromStr;
let options = options.unwrap_or_else(|| jwt::ValidationOptions {
claim_presence_options: jwt::ClaimPresenceOptions {
issued_at: jwt::Presence::Required,
not_before: jwt::Presence::Required,
expiry: jwt::Presence::Required,
..Default::default()
},
..Default::default()
});
let claims_set = self.claims_set()?;
let issuer = claims_set
.registered
.issuer
.as_ref()
.ok_or_else(|| Error::InvalidIssuer)?;
let audience = claims_set
.registered
.audience
.as_ref()
.ok_or_else(|| Error::InvalidAudience)?;
verify_service(config, service)
.and_then(|_| {
if audience.contains(&FromStr::from_str(service)?) {
Ok(())
} else {
Err(Error::InvalidAudience)
}
})
.and_then(|_| verify_audience(config, audience))
.and_then(|_| verify_issuer(config, issuer))
.and_then(|_| {
claims_set
.registered
.validate(options)
.map_err(|e| Error::JWTError(Box::new(jwt::errors::Error::ValidationError(e))))
})
}
pub fn to_string(&self) -> Result<String, Error> {
Ok(self
.0
.encrypted()
.map_err(|_| Error::RefreshTokenNotEncrypted)?
.to_string())
}
}
impl Borrow<RefreshTokenJWE> for RefreshToken {
fn borrow(&self) -> &RefreshTokenJWE {
&self.0
}
}
impl From<RefreshTokenJWE> for RefreshToken {
fn from(value: RefreshTokenJWE) -> Self {
RefreshToken(value)
}
}
/// An access token with its lifetime, issue time and optional refresh token.
///
/// This is the value serialized as JSON by the `Responder` implementation.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct Token<T> {
    /// The access token JWT with private claims of type `T`; may be encoded or decoded.
    pub token: jwt::JWT<T, jwt::Empty>,
    /// Lifetime of the access token (taken from `Configuration::expiry_duration` when issued).
    #[serde(with = "crate::serde_custom::duration")]
    pub expires_in: Duration,
    /// Time the access token was issued.
    pub issued_at: DateTime<Utc>,
    /// Optional refresh token; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<RefreshToken>,
}
impl<T> Clone for Token<T>
where
    T: Serialize + DeserializeOwned + Clone,
{
    /// Field-by-field clone; `expires_in` and `issued_at` are copied.
    fn clone(&self) -> Self {
        let Token {
            token,
            expires_in,
            issued_at,
            refresh_token,
        } = self;
        Token {
            token: token.clone(),
            expires_in: *expires_in,
            issued_at: *issued_at,
            refresh_token: refresh_token.clone(),
        }
    }
}
impl<T: Serialize + DeserializeOwned + 'static> Token<T> {
fn with_configuration_and_time(
config: &Configuration,
subject: &str,
service: &str,
private_claims: T,
refresh_token_payload: Option<&JsonValue>,
now: DateTime<Utc>,
) -> Result<Self, crate::Error> {
verify_service(config, service)?;
let access_token = make_token(
subject,
&config.issuer,
&config.audience,
config.expiry_duration,
private_claims,
config.signature_algorithm,
now,
)?;
let refresh_token = match config.refresh_token {
None => None,
Some(ref refresh_token_config) => match refresh_token_payload {
Some(payload) => Some(RefreshToken::new_decrypted(
subject,
&config.issuer,
&config.audience,
refresh_token_config.expiry_duration,
payload,
config.signature_algorithm,
refresh_token_config.cek_algorithm,
refresh_token_config.enc_algorithm,
now,
)?),
None => None,
},
};
let issued_at = access_token
.payload()
.unwrap()
.registered
.issued_at
.unwrap();
let token = Token::<T> {
token: access_token,
expires_in: config.expiry_duration,
issued_at: *issued_at.deref(),
refresh_token: refresh_token,
};
Ok(token)
}
pub fn with_configuration(
config: &Configuration,
subject: &str,
service: &str,
private_claims: T,
refresh_token_payload: Option<&JsonValue>,
) -> Result<Self, crate::Error> {
Self::with_configuration_and_time(
config,
subject,
service,
private_claims,
refresh_token_payload,
Utc::now(),
)
}
pub fn encode(mut self, secret: &jws::Secret) -> Result<Self, Error> {
match self.token {
jwt::jws::Compact::Encoded(_) => Err(Error::TokenAlreadyEncoded),
jwt @ jwt::jws::Compact::Decoded { .. } => {
self.token = jwt.into_encoded(secret)?;
Ok(self)
}
}
}
pub fn decode(
mut self,
secret: &jws::Secret,
algorithm: jwa::SignatureAlgorithm,
) -> Result<Self, Error> {
match self.token {
jwt @ jwt::jws::Compact::Encoded(_) => {
self.token = jwt.into_decoded(secret, algorithm)?;
Ok(self)
}
jwt::jws::Compact::Decoded { .. } => Err(Error::TokenAlreadyDecoded),
}
}
fn serialize(self) -> Result<String, Error> {
if self.is_decoded() {
Err(Error::TokenNotEncoded)?
}
let serialized = serde_json::to_string(&self)?;
Ok(serialized)
}
fn respond<'r>(self) -> Result<Response<'r>, Error> {
let serialized = self.serialize()?;
Response::build()
.header(ContentType::JSON)
.sized_body(Cursor::new(serialized))
.ok()
}
pub fn is_decoded(&self) -> bool {
match self.token {
jwt::jws::Compact::Encoded(_) => false,
jwt::jws::Compact::Decoded { .. } => true,
}
}
pub fn is_encoded(&self) -> bool {
!self.is_decoded()
}
pub fn registered_claims(&self) -> Result<&jwt::RegisteredClaims, crate::Error> {
match self.token {
jwt::jws::Compact::Encoded(_) => Err(Error::TokenNotDecoded)?,
ref jwt @ jwt::jws::Compact::Decoded { .. } => Ok(match_extract!(*jwt,
jwt::jws::Compact::Decoded {
payload: jwt::ClaimsSet { ref registered, .. },
..
},
registered)?),
}
}
pub fn private_claims(&self) -> Result<&T, crate::Error> {
match self.token {
jwt::jws::Compact::Encoded(_) => Err(Error::TokenNotDecoded)?,
ref jwt @ jwt::jws::Compact::Decoded { .. } => Ok(match_extract!(*jwt,
jwt::jws::Compact::Decoded {
payload: jwt::ClaimsSet { ref private, .. },
..
},
private)?),
}
}
pub fn header(&self) -> Result<&jwt::jws::Header<jwt::Empty>, crate::Error> {
match self.token {
jwt::jws::Compact::Encoded(_) => Err(Error::TokenNotDecoded)?,
ref jwt @ jwt::jws::Compact::Decoded { .. } => Ok(match_extract!(*jwt,
jwt::jws::Compact::Decoded {
ref header,
..
},
header)?),
}
}
pub fn encoded_token(&self) -> Result<String, crate::Error> {
Ok(self
.token
.encoded()
.map_err(|e| Error::JWTError(Box::new(e)))?
.to_string())
}
pub fn refresh_token(&self) -> Option<&RefreshToken> {
self.refresh_token.as_ref()
}
pub fn encrypt_refresh_token(
mut self,
secret: &jws::Secret,
key: &jwk::JWK<jwt::Empty>,
) -> Result<Self, Error> {
let refresh_token = self.refresh_token.ok_or_else(|| Error::NoRefreshToken)?;
let refresh_token = refresh_token.encrypt(secret, key)?;
self.refresh_token = Some(refresh_token);
Ok(self)
}
pub fn decrypt_refresh_token(
mut self,
secret: &jws::Secret,
key: &jwk::JWK<jwt::Empty>,
signing_algorithm: jwa::SignatureAlgorithm,
cek_algorithm: jwa::KeyManagementAlgorithm,
enc_algorithm: jwa::ContentEncryptionAlgorithm,
) -> Result<Self, Error> {
let refresh_token = self.refresh_token.ok_or_else(|| Error::NoRefreshToken)?;
let refresh_token =
refresh_token.decrypt(secret, key, signing_algorithm, cek_algorithm, enc_algorithm)?;
self.refresh_token = Some(refresh_token);
Ok(self)
}
pub fn has_refresh_token(&self) -> bool {
self.refresh_token.is_some()
}
}
impl<'r, T: Serialize + DeserializeOwned + 'static> Responder<'r> for Token<T> {
fn respond_to(self, request: &Request<'_>) -> Result<Response<'r>, Status> {
match self.respond() {
Ok(r) => Ok(r),
Err(e) => Err::<String, Error>(e).respond_to(request),
}
}
}
/// Source of key material for signing, verification or encryption.
///
/// Deserialized with `#[serde(untagged)]`, so serde tries the variants in
/// declaration order. Reordering them changes how configuration input is interpreted.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum Secret {
    /// No key. Signing and verification use `jws::Secret::None`; encryption is rejected.
    None,
    /// Key bytes provided directly in the configuration.
    ByteSequence(ByteSequence),
    /// Key bytes read from a file.
    Bytes {
        /// Path of the file whose raw contents are the key.
        path: String,
    },
    /// RSA key pair loaded from files.
    RSAKeyPair {
        /// Path of the private key file, used for signing.
        rsa_private: String,
        /// Path of the public key file, used for verification.
        rsa_public: String,
    },
}
impl Default for Secret {
fn default() -> Self {
Secret::None
}
}
impl Secret {
pub(super) fn for_signing(&self) -> Result<jws::Secret, Error> {
match *self {
Secret::None => Ok(jws::Secret::None),
Secret::ByteSequence(ref bytes) => Ok(jws::Secret::Bytes(bytes.as_bytes())),
Secret::Bytes { ref path } => Ok(jws::Secret::Bytes(Self::read_file_to_bytes(path)?)),
Secret::RSAKeyPair {
ref rsa_private, ..
} => Ok(jws::Secret::rsa_keypair_from_file(rsa_private)?),
}
}
pub(super) fn for_verification(&self) -> Result<jws::Secret, Error> {
match *self {
Secret::None => Ok(jws::Secret::None),
Secret::ByteSequence(ref bytes) => Ok(jws::Secret::Bytes(bytes.as_bytes())),
Secret::Bytes { ref path } => Ok(jws::Secret::Bytes(Self::read_file_to_bytes(path)?)),
Secret::RSAKeyPair { ref rsa_public, .. } => {
Ok(jws::Secret::public_key_from_file(rsa_public)?)
}
}
}
pub(super) fn for_encryption(&self) -> Result<jwk::JWK<jwt::Empty>, Error> {
match *self {
Secret::None => Err(Error::GenericError(
"A key is required for encryption".to_string(),
)),
Secret::ByteSequence(ref bytes) => Ok(jwk::JWK::new_octect_key(
&bytes.as_bytes(),
Default::default(),
)),
Secret::Bytes { ref path } => Ok(jwk::JWK::new_octect_key(
&Self::read_file_to_bytes(path)?,
Default::default(),
)),
Secret::RSAKeyPair { .. } => Err(Error::GenericError("Not supported yet".to_string())),
}
}
pub(super) fn for_decryption(&self) -> Result<jwk::JWK<jwt::Empty>, Error> {
self.for_encryption()
}
fn read_file_to_bytes(path: &str) -> Result<Vec<u8>, Error> {
let mut file = File::open(path)?;
let mut bytes = Vec::<u8>::new();
let _ = file.read_to_end(&mut bytes)?;
Ok(bytes)
}
}
/// Key material loaded from a `Configuration` by `Configuration::keys`.
pub struct Keys {
    /// Secret used to sign tokens.
    pub signing: jws::Secret,
    /// Secret used to verify token signatures.
    pub signature_verification: jws::Secret,
    /// Refresh-token encryption key; `None` when refresh tokens are not configured.
    pub encryption: Option<jwk::JWK<jwt::Empty>>,
    /// Refresh-token decryption key; `None` when refresh tokens are not configured.
    pub decryption: Option<jwk::JWK<jwt::Empty>>,
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use std::time::Duration;
use chrono::{DateTime, NaiveDateTime, Utc};
use serde_json;
use super::*;
use crate::jwt;
use crate::{JsonMap, JsonValue};
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct TestClaims {
company: String,
department: String,
}
impl Default for TestClaims {
fn default() -> Self {
TestClaims {
company: "ACME".to_string(),
department: "Toilet Cleaning".to_string(),
}
}
}
fn make_config(refresh_token: bool) -> Configuration {
let refresh_token = if refresh_token {
Some(RefreshTokenConfiguration {
cek_algorithm: jwt::jwa::KeyManagementAlgorithm::A256GCMKW,
enc_algorithm: jwt::jwa::ContentEncryptionAlgorithm::A256GCM,
key: Secret::ByteSequence(ByteSequence::Bytes(vec![0; 256 / 8])),
expiry_duration: Duration::from_secs(86400),
})
} else {
None
};
let allowed_origins = ["https://www.example.com"];
let (allowed_origins, _) = crate::cors::AllowedOrigins::some(&allowed_origins);
Configuration {
issuer: FromStr::from_str("https://www.acme.com").unwrap(),
allowed_origins: allowed_origins,
audience: jwt::SingleOrMultiple::Single(
FromStr::from_str("https://www.example.com/").unwrap(),
),
signature_algorithm: Some(jwt::jwa::SignatureAlgorithm::HS512),
secret: Secret::ByteSequence(ByteSequence::String("secret".to_string())),
expiry_duration: Duration::from_secs(120),
refresh_token: refresh_token,
}
}
fn refresh_token_payload() -> JsonValue {
let mut map = JsonMap::with_capacity(1);
let _ = map.insert("test".to_string(), From::from("foobar"));
JsonValue::Object(map)
}
fn make_refresh_token() -> RefreshToken {
RefreshToken::new_decrypted(
"foobar",
&FromStr::from_str("https://www.acme.com").unwrap(),
&jwt::SingleOrMultiple::Single(FromStr::from_str("https://www.example.com").unwrap()),
Duration::from_secs(120),
&refresh_token_payload(),
Some(Default::default()),
jwt::jwa::KeyManagementAlgorithm::A256GCMKW,
jwt::jwa::ContentEncryptionAlgorithm::A256GCM,
Utc::now(),
)
.unwrap()
}
fn make_token(refresh_token: bool) -> Token<TestClaims> {
let refresh_token = if refresh_token {
Some(make_refresh_token())
} else {
None
};
Token {
token: jwt::JWT::new_decoded(
jwt::jws::Header::default(),
jwt::ClaimsSet {
private: Default::default(),
registered: Default::default(),
},
),
expires_in: Duration::from_secs(120),
issued_at: Utc::now(),
refresh_token: refresh_token,
}
}
#[test]
fn refresh_token_encryption_round_trip() {
let key = jwt::jwk::JWK::new_octect_key(&[0; 256 / 8], Default::default());
let signing_secret = jwt::jws::Secret::bytes_from_str("secret");
let refresh_token = make_refresh_token();
assert!(refresh_token.decrypted());
let encrypted_refresh_token =
not_err!(refresh_token.clone().encrypt(&signing_secret, &key));
assert!(encrypted_refresh_token.encrypted());
let decrypted_refresh_token = not_err!(encrypted_refresh_token.decrypt(
&signing_secret,
&key,
Default::default(),
jwt::jwa::KeyManagementAlgorithm::A256GCMKW,
jwt::jwa::ContentEncryptionAlgorithm::A256GCM
));
assert!(decrypted_refresh_token.decrypted());
let actual_refresh_token_payload: &JsonValue = decrypted_refresh_token.payload().unwrap();
let map = actual_refresh_token_payload.as_object().unwrap();
assert_eq!(map.get("test").unwrap().as_str().unwrap(), "foobar");
}
#[test]
fn serializing_and_deserializing_round_trip() {
let key = jwt::jwk::JWK::new_octect_key(&[0; 256 / 8], Default::default());
let signing_secret = jwt::jws::Secret::bytes_from_str("secret");
let token = make_token(true);
let token = not_err!(token.encode(&signing_secret));
assert!(token.is_encoded());
let token = not_err!(token.encrypt_refresh_token(&signing_secret, &key));
assert!(token.refresh_token().unwrap().encrypted());
let serialized = not_err!(serde_json::to_string_pretty(&token));
let deserialized: Token<TestClaims> = not_err!(serde_json::from_str(&serialized));
assert_eq!(deserialized, token);
let token = not_err!(token.decode(&signing_secret, Default::default()));
let token = not_err!(token.decrypt_refresh_token(
&signing_secret,
&key,
Default::default(),
jwt::jwa::KeyManagementAlgorithm::A256GCMKW,
jwt::jwa::ContentEncryptionAlgorithm::A256GCM
));
let private = not_err!(token.private_claims());
assert_eq!(*private, Default::default());
let refresh_token = token.refresh_token().unwrap();
let actual_refresh_token_payload: &JsonValue = refresh_token.payload().unwrap();
let map = actual_refresh_token_payload.as_object().unwrap();
assert_eq!(map.get("test").unwrap().as_str().unwrap(), "foobar");
}
#[test]
#[should_panic(expected = "TokenAlreadyEncoded")]
fn panics_when_encoding_encoded() {
let token = make_token(false);
let token = not_err!(token.encode(&jwt::jws::Secret::bytes_from_str("secret")));
let _ = token
.encode(&jwt::jws::Secret::bytes_from_str("secret"))
.unwrap();
}
#[test]
#[should_panic(expected = "TokenAlreadyDecoded")]
fn panics_when_decoding_decoded() {
let token = make_token(false);
let _ = token
.decode(
&jwt::jws::Secret::bytes_from_str("secret"),
Default::default(),
)
.unwrap();
}
#[test]
#[should_panic(expected = "RefreshTokenAlreadyEncrypted")]
fn panics_when_encrypting_encrypted() {
let key = jwt::jwk::JWK::new_octect_key(&[0; 256 / 8], Default::default());
let signing_secret = jwt::jws::Secret::bytes_from_str("secret");
let token = make_token(true);
let token = not_err!(token.encrypt_refresh_token(&signing_secret, &key));
let _ = token.encrypt_refresh_token(&signing_secret, &key).unwrap();
}
#[test]
#[should_panic(expected = "RefreshTokenAlreadyDecrypted")]
fn panics_when_decrypting_decrypted() {
let key = jwt::jwk::JWK::new_octect_key(&[0; 256 / 8], Default::default());
let signing_secret = jwt::jws::Secret::bytes_from_str("secret");
let token = make_token(true);
let _ = token
.decrypt_refresh_token(
&signing_secret,
&key,
Default::default(),
jwt::jwa::KeyManagementAlgorithm::A256GCMKW,
jwt::jwa::ContentEncryptionAlgorithm::A256GCM,
)
.unwrap();
}
#[test]
fn token_serialization_smoke_test() {
let expected_token = make_token(false);
let token = not_err!(expected_token
.clone()
.encode(&jwt::jws::Secret::bytes_from_str("secret")));
let serialized = not_err!(token.serialize());
let deserialized: Token<TestClaims> = not_err!(serde_json::from_str(&serialized));
let actual_token = not_err!(deserialized.decode(
&jwt::jws::Secret::bytes_from_str("secret"),
Default::default()
));
assert_eq!(expected_token, actual_token);
}
#[test]
fn token_response_smoke_test() {
let expected_token = make_token(false);
let token = not_err!(expected_token
.clone()
.encode(&jwt::jws::Secret::bytes_from_str("secret")));
let mut response = not_err!(token.respond());
assert_eq!(response.status(), Status::Ok);
let body_str = not_none!(response.body().and_then(|body| body.into_string()));
let deserialized: Token<TestClaims> = not_err!(serde_json::from_str(&body_str));
let actual_token = not_err!(deserialized.decode(
&jwt::jws::Secret::bytes_from_str("secret"),
Default::default()
));
assert_eq!(expected_token, actual_token);
}
#[test]
fn secrets_are_transformed_for_signing_correctly() {
let none = Secret::None;
assert_matches_non_debug!(not_err!(none.for_signing()), jwt::jws::Secret::None);
let string = Secret::ByteSequence(ByteSequence::String("secret".to_string()));
assert_matches_non_debug!(not_err!(string.for_signing()), jwt::jws::Secret::Bytes(_));
let rsa = Secret::RSAKeyPair {
rsa_private: "test/fixtures/rsa_private_key.der".to_string(),
rsa_public: "test/fixtures/rsa_public_key.der".to_string(),
};
assert_matches_non_debug!(not_err!(rsa.for_signing()), jwt::jws::Secret::RSAKeyPair(_));
}
#[test]
fn secrets_are_transformed_for_verification_correctly() {
let none = Secret::None;
assert_matches_non_debug!(not_err!(none.for_verification()), jwt::jws::Secret::None);
let string = Secret::ByteSequence(ByteSequence::String("secret".to_string()));
assert_matches_non_debug!(
not_err!(string.for_verification()),
jwt::jws::Secret::Bytes(_)
);
let rsa = Secret::RSAKeyPair {
rsa_private: "test/fixtures/rsa_private_key.der".to_string(),
rsa_public: "test/fixtures/rsa_public_key.der".to_string(),
};
assert_matches_non_debug!(
not_err!(rsa.for_verification()),
jwt::jws::Secret::PublicKey(_)
);
}
#[test]
fn token_created_with_refresh_token_disabled() {
let configuration = make_config(false);
let mut map = JsonMap::with_capacity(1);
let _ = map.insert("test".to_string(), From::from("foobar"));
let refresh_token_payload = JsonValue::Object(map);
let now = DateTime::<Utc>::from_utc(NaiveDateTime::from_timestamp(0, 0), Utc);
let expected_expiry = now + chrono::Duration::from_std(Duration::from_secs(120)).unwrap();
let token = not_err!(Token::<TestClaims>::with_configuration_and_time(
&configuration,
"Donald Trump",
"https://www.example.com/",
Default::default(),
Some(&refresh_token_payload),
now
));
let registered = not_err!(token.registered_claims());
assert_eq!(
registered.issuer,
Some(FromStr::from_str("https://www.acme.com").unwrap())
);
assert_eq!(
registered.subject,
Some(FromStr::from_str("Donald Trump").unwrap())
);
assert_eq!(
registered.audience,
Some(jwt::SingleOrMultiple::Single(
FromStr::from_str("https://www.example.com").unwrap()
))
);
assert_eq!(registered.issued_at, Some(now.into()));
assert_eq!(registered.not_before, Some(now.into()));
assert_eq!(registered.expiry, Some(expected_expiry.into()));
let private = not_err!(token.private_claims());
assert_eq!(*private, Default::default());
let header = not_err!(token.header());
assert_eq!(
header.registered.algorithm,
jwt::jwa::SignatureAlgorithm::HS512
);
assert!(token.refresh_token().is_none());
}
#[test]
fn token_created_with_no_refresh_token_payload() {
let configuration = make_config(true);
let now = DateTime::<Utc>::from_utc(NaiveDateTime::from_timestamp(0, 0), Utc);
let expected_expiry = now + chrono::Duration::from_std(Duration::from_secs(120)).unwrap();
let token = not_err!(Token::<TestClaims>::with_configuration_and_time(
&configuration,
"Donald Trump",
"https://www.example.com/",
Default::default(),
None,
now
));
let registered = not_err!(token.registered_claims());
assert_eq!(
registered.issuer,
Some(FromStr::from_str("https://www.acme.com").unwrap())
);
assert_eq!(
registered.subject,
Some(FromStr::from_str("Donald Trump").unwrap())
);
assert_eq!(
registered.audience,
Some(jwt::SingleOrMultiple::Single(
FromStr::from_str("https://www.example.com").unwrap()
))
);
assert_eq!(registered.issued_at, Some(now.into()));
assert_eq!(registered.not_before, Some(now.into()));
assert_eq!(registered.expiry, Some(expected_expiry.into()));
let private = not_err!(token.private_claims());
assert_eq!(*private, Default::default());
let header = not_err!(token.header());
assert_eq!(
header.registered.algorithm,
jwt::jwa::SignatureAlgorithm::HS512
);
assert!(token.refresh_token().is_none());
}
/// Issues a token at the Unix epoch with a refresh payload of `{"test": "foobar"}`.
///
/// Checks the same registered claims, private claims and HS512 header as the
/// payload-less case. Also checks that the attached refresh token's payload equals
/// the JSON value that was supplied.
#[test]
fn token_created_with_refresh_token() {
    let config = make_config(true);

    // Refresh payload: {"test": "foobar"}
    let mut fields = JsonMap::with_capacity(1);
    let _ = fields.insert("test".to_string(), From::from("foobar"));
    let refresh_payload = JsonValue::Object(fields);

    let issued = DateTime::<Utc>::from_utc(NaiveDateTime::from_timestamp(0, 0), Utc);
    let lifetime = chrono::Duration::from_std(Duration::from_secs(120)).unwrap();
    let expires = issued + lifetime;

    let token = not_err!(Token::<TestClaims>::with_configuration_and_time(
        &config,
        "Donald Trump",
        "https://www.example.com/",
        Default::default(),
        Some(&refresh_payload),
        issued
    ));

    // Registered (standard) claims.
    let claims = not_err!(token.registered_claims());
    assert_eq!(claims.issuer, Some(FromStr::from_str("https://www.acme.com").unwrap()));
    assert_eq!(claims.subject, Some(FromStr::from_str("Donald Trump").unwrap()));
    assert_eq!(
        claims.audience,
        Some(jwt::SingleOrMultiple::Single(FromStr::from_str("https://www.example.com").unwrap()))
    );
    assert_eq!(claims.issued_at, Some(issued.into()));
    assert_eq!(claims.not_before, Some(issued.into()));
    assert_eq!(claims.expiry, Some(expires.into()));

    let private_claims = not_err!(token.private_claims());
    assert_eq!(*private_claims, Default::default());

    let jose_header = not_err!(token.header());
    assert_eq!(jose_header.registered.algorithm, jwt::jwa::SignatureAlgorithm::HS512);

    // The refresh token must round-trip the supplied payload unchanged.
    let refresh = token.refresh_token().unwrap();
    let decoded: &JsonValue = refresh.payload().unwrap();
    assert_eq!(*decoded, refresh_payload);
}
/// Creating a token for the service string `"invalid"` must fail, and unwrapping
/// the result must panic with a message containing `InvalidService`.
#[test]
#[should_panic(expected = "InvalidService")]
fn validates_service_correctly() {
    let config = make_config(true);
    let epoch = DateTime::<Utc>::from_utc(NaiveDateTime::from_timestamp(0, 0), Utc);
    let outcome = Token::<TestClaims>::with_configuration_and_time(
        &config,
        "Donald Trump",
        "invalid",
        Default::default(),
        None,
        epoch,
    );
    let _ = outcome.unwrap();
}
/// An unmodified refresh token from `make_refresh_token` validates against service
/// `https://www.example.com/` under `make_config(true)`.
#[test]
fn refresh_token_validates_correctly() {
    let config = make_config(true);
    not_err!(make_refresh_token().validate("https://www.example.com/", &config, None));
}
/// Removing the `iss` claim from a refresh token makes validation fail with
/// `InvalidIssuer`.
#[test]
#[should_panic(expected = "InvalidIssuer")]
fn refresh_token_validates_missing_issuer() {
    let config = make_config(true);
    let mut encrypted = make_refresh_token().unwrap();
    // Reach through the encryption layer and the signature layer to the claims set.
    let claims = encrypted.payload_mut().unwrap().payload_mut().unwrap();
    claims.registered.issuer = None;
    RefreshToken(encrypted)
        .validate("https://www.example.com/", &config, None)
        .unwrap();
}
/// Removing the `aud` claim from a refresh token makes validation fail with
/// `InvalidAudience`.
#[test]
#[should_panic(expected = "InvalidAudience")]
fn refresh_token_validates_missing_audience() {
    let config = make_config(true);
    let mut encrypted = make_refresh_token().unwrap();
    // Reach through the encryption layer and the signature layer to the claims set.
    let claims = encrypted.payload_mut().unwrap().payload_mut().unwrap();
    claims.registered.audience = None;
    RefreshToken(encrypted)
        .validate("https://www.example.com/", &config, None)
        .unwrap();
}
/// Validating an unmodified refresh token against service `https://www.invalid.com/`
/// fails with `InvalidService`.
#[test]
#[should_panic(expected = "InvalidService")]
fn refresh_token_validates_invalid_service() {
    let config = make_config(true);
    let token = make_refresh_token();
    let verdict = token.validate("https://www.invalid.com/", &config, None);
    verdict.unwrap();
}
/// Sets the configured audience to `https://www.invalid.com/` and validates against
/// that same service. The refresh token from `make_refresh_token` is expected to be
/// rejected with `InvalidAudience`.
#[test]
#[should_panic(expected = "InvalidAudience")]
fn refresh_token_validates_mismatch_service_and_audience() {
    let mut config = make_config(true);
    config.audience =
        jwt::SingleOrMultiple::Single(FromStr::from_str("https://www.invalid.com/").unwrap());
    let token = make_refresh_token();
    let verdict = token.validate("https://www.invalid.com/", &config, None);
    verdict.unwrap();
}
/// Replaces the refresh token's `aud` claim with a list of two audiences. The list
/// includes the service being validated against (`https://www.example.com/`).
/// Validation is expected to fail with `InvalidAudience`.
#[test]
#[should_panic(expected = "InvalidAudience")]
fn refresh_token_validates_invalid_audience() {
    let config = make_config(true);
    let mut encrypted = make_refresh_token().unwrap();
    // Reach through the encryption layer and the signature layer to the claims set.
    let claims = encrypted.payload_mut().unwrap().payload_mut().unwrap();
    let audiences = vec![
        FromStr::from_str("https://www.invalid.com/").unwrap(),
        FromStr::from_str("https://www.example.com/").unwrap(),
    ];
    claims.registered.audience = Some(jwt::SingleOrMultiple::Multiple(audiences));
    RefreshToken(encrypted)
        .validate("https://www.example.com/", &config, None)
        .unwrap();
}
/// Rewriting the refresh token's `iss` claim to `https://www.invalid.com/` makes
/// validation fail with `InvalidIssuer`.
#[test]
#[should_panic(expected = "InvalidIssuer")]
fn refresh_token_validates_invalid_issuer() {
    let config = make_config(true);
    let mut encrypted = make_refresh_token().unwrap();
    // Reach through the encryption layer and the signature layer to the claims set.
    let claims = encrypted.payload_mut().unwrap().payload_mut().unwrap();
    let forged_issuer = FromStr::from_str("https://www.invalid.com/").unwrap();
    claims.registered.issuer = Some(forged_issuer);
    RefreshToken(encrypted)
        .validate("https://www.example.com/", &config, None)
        .unwrap();
}
/// Removing the `iat` claim makes validation fail. The panic message is expected to
/// contain `MissingRequiredClaims(["iat"])`.
#[test]
#[should_panic(expected = "MissingRequiredClaims([\"iat\"])")]
fn refresh_token_validates_missing_issued_at() {
    let config = make_config(true);
    let mut encrypted = make_refresh_token().unwrap();
    // Reach through the encryption layer and the signature layer to the claims set.
    let claims = encrypted.payload_mut().unwrap().payload_mut().unwrap();
    claims.registered.issued_at = None;
    RefreshToken(encrypted)
        .validate("https://www.example.com/", &config, None)
        .unwrap();
}
/// Removing the `nbf` claim makes validation fail. The panic message is expected to
/// contain `MissingRequiredClaims(["nbf"])`.
#[test]
#[should_panic(expected = "MissingRequiredClaims([\"nbf\"])")]
fn refresh_token_validates_missing_not_before() {
    let config = make_config(true);
    let mut encrypted = make_refresh_token().unwrap();
    // Reach through the encryption layer and the signature layer to the claims set.
    let claims = encrypted.payload_mut().unwrap().payload_mut().unwrap();
    claims.registered.not_before = None;
    RefreshToken(encrypted)
        .validate("https://www.example.com/", &config, None)
        .unwrap();
}
/// Removing the `exp` claim makes validation fail. The panic message is expected to
/// contain `MissingRequiredClaims(["exp"])`.
#[test]
#[should_panic(expected = "MissingRequiredClaims([\"exp\"])")]
fn refresh_token_validates_missing_expiry() {
    let config = make_config(true);
    let mut encrypted = make_refresh_token().unwrap();
    // Reach through the encryption layer and the signature layer to the claims set.
    let claims = encrypted.payload_mut().unwrap().payload_mut().unwrap();
    claims.registered.expiry = None;
    RefreshToken(encrypted)
        .validate("https://www.example.com/", &config, None)
        .unwrap();
}
}