use std::fmt;
use std::pin::Pin;
use std::str::FromStr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use futures::Future;
use rand::rngs::OsRng;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use signature::{Signer, Verifier};
pub use ed25519_dalek::{Keypair, PublicKey, Signature};
/// Resolves actors to their public keys and validates (or consumes) signed bearer tokens.
///
/// Implementors supply `host` and `resolve`; `consume` and `validate` have default
/// implementations built on top of them.
#[async_trait]
pub trait Resolve: Send + Sync
where
<Self::Host as FromStr>::Err: fmt::Display,
{
/// The host type. A token's `iss` field is parsed into this type during validation, and
/// `host().to_string()` becomes the `iss` of tokens created by `consume`.
type Host: fmt::Display + FromStr + Send + Sync;
/// Identifies an actor; compared against the id of the actor returned by `resolve`.
type ActorId: DeserializeOwned + PartialEq + Send + Sync;
/// Application-specific claims, carried in a token's `custom` field.
type Claims: DeserializeOwned + Send + Sync;
/// The host this resolver acts as.
fn host(&self) -> Self::Host;
/// Looks up the actor `actor_id` at `host`, including the public key used to verify its
/// signatures. Errors are propagated unchanged by `validate`.
/// NOTE(review): presumably a network/key-store fetch; `Error::not_found` (kind `Fetch`)
/// looks intended for the missing-actor case — confirm against implementors.
async fn resolve(
&self,
host: &Self::Host,
actor_id: &Self::ActorId,
) -> Result<Actor<Self::ActorId>>;
/// Validates the encoded `token` at `now`, then builds a new token that inherits from it.
///
/// The new token is issued by `self.host()` at `now` (truncated to whole seconds since the
/// Unix epoch), claims `actor_id` with the custom `claims`, expires at the same time as the
/// validated parent, and embeds the encoded parent as `inherit`. It is returned UNSIGNED,
/// together with the parent's validated claims.
///
/// # Errors
///
/// Any error from `validate`, or an `ErrorKind::Time` error if `now` is before the Unix epoch.
async fn consume(
&self,
actor_id: Self::ActorId,
claims: Self::Claims,
token: String,
now: SystemTime,
) -> Result<(
Token<Self::ActorId, Self::Claims>,
Claims<Self::ActorId, Self::Claims>,
)> {
let parent_claims = self.validate(&token, now).await?;
let iat = (now.duration_since(UNIX_EPOCH)).map_err(|e| Error::new(ErrorKind::Time, e))?;
// The child can never outlive its parent: it inherits the parent's expiry.
let token = Token {
iss: self.host().to_string(),
iat: iat.as_secs(),
exp: parent_claims.exp,
actor_id,
custom: claims,
inherit: Some(token),
};
Ok((token, parent_claims))
}
/// Verifies an encoded `<header>.<claims>.<signature>` token and returns its claims,
/// including (recursively) the claims of every token it inherits from.
///
/// Steps, in order:
/// 1. Split off the signature and decode the header and claims (the header must equal
///    `TokenHeader::default()`).
/// 2. Reject the token if it is expired at `now`.
/// 3. Parse `iss` as a `Host` and resolve the claimed actor.
/// 4. Check the resolved actor's id and the signature.
/// 5. Validate the `inherit` parent, if any, with the same `now`.
///
/// NOTE(review): presumably written as a hand-boxed future (rather than an `async fn`) so
/// that the recursive call for inherited tokens compiles.
///
/// # Errors
///
/// - `Time` if the token is expired or its timestamps are invalid.
/// - `Format` if `iss` does not parse as a `Host`.
/// - `Auth` if the actor id does not match or the signature does not verify.
/// - Any decoding error, or any error returned by `resolve`.
fn validate<'a>(
&'a self,
encoded: &'a str,
now: SystemTime,
) -> Pin<Box<dyn Future<Output = Result<Claims<Self::ActorId, Self::Claims>>> + Send + 'a>>
{
Box::pin(async move {
let (message, signature) = token_signature(encoded)?;
let token = decode_token(message)?;
// Expiry is checked before any key lookup or signature check, so the token fields
// read here have not been authenticated yet.
if token.is_expired(now)? {
return Err(Error::new(ErrorKind::Time, "token is expired"));
}
let host = token
.iss
.parse()
.map_err(|e| Error::new(ErrorKind::Format, e))?;
let actor = self.resolve(&host, &token.actor_id).await?;
if actor.id != token.actor_id {
return Err(Error::new(
ErrorKind::Auth,
"attempted to use bearer token for different actor",
));
} else if let Err(cause) = actor.public_key().verify(message.as_bytes(), &signature) {
return Err(Error::new(
ErrorKind::Auth,
format!("invalid bearer token: {}", cause),
));
}
// An inherited parent is validated against the same `now`, so the whole chain becomes
// invalid as soon as any ancestor expires.
if let Some(parent) = token.inherit {
let parent_claims = self.validate(&parent, now).await?;
Ok(Claims::consume(
token.exp,
token.iss,
token.actor_id,
token.custom,
parent_claims,
))
} else {
Ok(Claims::new(
token.exp,
token.iss,
token.actor_id,
token.custom,
))
}
})
}
}
/// Broad category of an [`Error`], so callers can react to a failure without parsing its
/// message.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ErrorKind {
    /// A token segment was not valid base64.
    Base64,
    /// Something could not be fetched (see [`Error::not_found`]).
    Fetch,
    /// The token, or a field inside it, did not have the expected format.
    Format,
    /// A token segment was not valid JSON for the expected type.
    Json,
    /// Authentication failed: bad signature, wrong actor, or unusable key material.
    Auth,
    /// A timestamp was invalid, or the token is expired.
    Time,
}

impl fmt::Display for ErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            Self::Auth => "authentication",
            Self::Base64 => "base64 format",
            Self::Fetch => "key fetch",
            Self::Format => "token format",
            Self::Json => "json format",
            Self::Time => "time",
        })
    }
}

/// The error type returned by every fallible operation in this module.
pub struct Error {
    kind: ErrorKind,
    message: String,
}

impl Error {
    /// Creates an error of the given `kind` with a human-readable `message`.
    pub fn new<M: fmt::Display>(kind: ErrorKind, message: M) -> Self {
        Self {
            kind,
            message: message.to_string(),
        }
    }

    /// Shorthand for an [`ErrorKind::Fetch`] error with the message `"not found"`.
    pub fn not_found() -> Self {
        Self::new(ErrorKind::Fetch, "not found")
    }

    /// Returns the category of this error.
    pub fn kind(&self) -> ErrorKind {
        self.kind
    }
}

impl std::error::Error for Error {}

impl fmt::Debug for Error {
    /// Debug output is identical to `Display`: `"<kind> error: <message>"`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{} error: {}", self.kind, self.message)
    }
}

/// Result type used throughout this module.
pub type Result<T> = std::result::Result<T, Error>;
/// The validated claims of a bearer token, linked to the claims of the token it
/// inherits from (if any), which in turn may link to its own parent.
#[derive(Clone, Debug)]
pub struct Claims<I, C> {
    exp: u64,
    host: String,
    actor_id: I,
    claims: C,
    inherit: Option<Box<Claims<I, C>>>,
}

impl<I: PartialEq, C> Claims<I, C> {
    /// Claims of a token that has no parent.
    fn new(exp: u64, host: String, actor_id: I, claims: C) -> Self {
        Claims {
            exp,
            host,
            actor_id,
            claims,
            inherit: None,
        }
    }

    /// Claims of a token derived from (inheriting) the token described by `parent`.
    fn consume(exp: u64, host: String, actor_id: I, claims: C, parent: Self) -> Self {
        Claims {
            inherit: Some(Box::new(parent)),
            ..Self::new(exp, host, actor_id, claims)
        }
    }

    /// Returns the custom claims granted by `actor_id` at `host`, searching this token first
    /// and then each inherited parent in turn; `None` if no link in the chain matches.
    pub fn get(&self, host: &str, actor_id: &I) -> Option<&C> {
        std::iter::successors(Some(self), |link| link.inherit.as_deref())
            .find(|link| host == link.host && *actor_id == link.actor_id)
            .map(|link| &link.claims)
    }
}
/// The claims section of a bearer token, serialized as JSON between the header and the
/// signature.
#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
pub struct Token<I, C> {
    /// Issuer host; parsed into `Resolve::Host` during validation.
    iss: String,
    /// Issue time, in whole seconds since the Unix epoch.
    iat: u64,
    /// Expiry time, in whole seconds since the Unix epoch.
    exp: u64,
    /// The actor claimed by (and expected to have signed) this token.
    actor_id: I,
    /// Application-specific claims.
    custom: C,
    /// The complete encoded parent token this one was derived from, if any.
    inherit: Option<String>,
}
impl<I, C> Token<I, C> {
pub fn new(iss: String, iat: SystemTime, ttl: Duration, actor_id: I, claims: C) -> Self {
let iat = iat.duration_since(UNIX_EPOCH).unwrap();
let exp = iat + ttl;
Self {
iss,
iat: iat.as_secs(),
exp: exp.as_secs(),
actor_id,
custom: claims,
inherit: None,
}
}
pub fn issuer(&'_ self) -> &'_ str {
&self.iss
}
pub fn actor_id(&'_ self) -> &'_ I {
&self.actor_id
}
pub fn is_expired(&self, now: SystemTime) -> Result<bool> {
let iat = UNIX_EPOCH + Duration::from_secs(self.iat);
let exp = UNIX_EPOCH + Duration::from_secs(self.exp);
let ttl = exp
.duration_since(iat)
.map_err(|e| Error::new(ErrorKind::Time, e))?;
match now.duration_since(iat) {
Ok(elapsed) => Ok(elapsed > ttl),
Err(cause) => Err(Error::new(ErrorKind::Time, cause)),
}
}
pub fn claims(&'_ self) -> Claims<I, C>
where
I: Clone,
C: Clone,
{
Claims {
exp: self.exp,
host: self.iss.clone(),
actor_id: self.actor_id.clone(),
claims: self.custom.clone(),
inherit: None,
}
}
}
impl<I: DeserializeOwned, C: DeserializeOwned> FromStr for Token<I, C> {
type Err = Error;
fn from_str(token: &str) -> Result<Self> {
let token: Vec<&str> = token.split('.').collect();
if token.len() != 3 {
return Err(Error::new(
ErrorKind::Format,
"Expected a bearer token in the format '<header>.<claims>.<data>'",
));
}
let token = base64_decode(token[1])?;
json_decode(&token)
}
}
impl<I: fmt::Display, C> fmt::Debug for Token<I, C> {
    /// Debug output is the same summary as `Display`. It omits the timestamps, custom claims
    /// and inherited token.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}

impl<I: fmt::Display, C> fmt::Display for Token<I, C> {
    /// Summarizes the token by its claimed actor and issuing host.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (actor, host) = (&self.actor_id, &self.iss);
        write!(f, "JWT token claiming Actor {} at host {}", actor, host)
    }
}
/// Key material held for an [`Actor`]: either only a public key (enough to verify) or a full
/// keypair (also able to sign).
enum Key {
    Public(PublicKey),
    Secret(Keypair),
}

/// An identity that can verify tokens and, if it holds a secret key, sign them.
pub struct Actor<I> {
    id: I,
    key: Key,
}

impl<I> Actor<I> {
    /// Generates a fresh random Ed25519 keypair from `OsRng`.
    pub fn new_keypair() -> Keypair {
        let mut rng = OsRng {};
        Keypair::generate(&mut rng)
    }

    /// Creates an actor with a freshly generated keypair, so it can sign tokens.
    pub fn new(id: I) -> Self {
        Actor {
            id,
            key: Key::Secret(Self::new_keypair()),
        }
    }

    /// Creates a signing-capable actor from raw secret-key and public-key bytes.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::Auth`] error if either key is malformed, or if `public_key` is
    /// not the public key derived from `secret`.
    pub fn with_keypair(id: I, public_key: &[u8], secret: &[u8]) -> Result<Self> {
        let keypair = Keypair::from_bytes(&[secret, public_key].concat())
            .map_err(|e| Error::new(ErrorKind::Auth, e))?;

        // `Keypair::from_bytes` accepts any two well-formed halves without checking that they
        // belong together. Signing with a mismatched public key produces signatures that never
        // verify, and is the double-public-key signing oracle (RUSTSEC-2022-0093) that can leak
        // the secret key.
        let derived = PublicKey::from(&keypair.secret);
        if derived.as_bytes() != keypair.public.as_bytes() {
            return Err(Error::new(
                ErrorKind::Auth,
                "public key does not match secret key",
            ));
        }

        Ok(Self {
            id,
            key: Key::Secret(keypair),
        })
    }

    /// Creates a verify-only actor from raw public-key bytes.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::Auth`] error if the key is malformed.
    pub fn with_public_key(id: I, public_key: &[u8]) -> Result<Self> {
        let key = PublicKey::from_bytes(public_key).map_err(|e| Error::new(ErrorKind::Auth, e))?;
        Ok(Self {
            id,
            key: Key::Public(key),
        })
    }

    /// Returns this actor's id.
    pub fn id(&self) -> &I {
        &self.id
    }

    /// Returns this actor's public key, whether or not it also holds the secret key.
    pub fn public_key(&self) -> &PublicKey {
        match &self.key {
            Key::Public(public) => public,
            Key::Secret(secret) => &secret.public,
        }
    }

    /// Encodes and signs `token` as `<header>.<claims>.<signature>`. Each segment is base64:
    /// the header and claims are JSON, and the signature is the Ed25519 signature over
    /// `<header>.<claims>`.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::Auth`] error if this actor holds only a public key, or an
    /// [`ErrorKind::Json`] error if serialization fails.
    pub fn sign_token<C: Serialize>(&self, token: &Token<I, C>) -> Result<String>
    where
        I: Serialize,
    {
        let keypair = if let Key::Secret(keypair) = &self.key {
            keypair
        } else {
            return Err(Error::new(
                ErrorKind::Auth,
                "cannot sign a token without a private key",
            ));
        };

        let header = base64_json_encode(&TokenHeader::default())?;
        let claims = base64_json_encode(&token)?;
        let signature = base64::encode(
            &keypair
                .sign(format!("{}.{}", header, claims).as_bytes())
                .to_bytes()[..],
        );
        Ok(format!("{}.{}.{}", header, claims, signature))
    }
}

impl<I: Clone> Clone for Actor<I> {
    /// Cloning keeps only the public key: the clone can verify signatures but cannot sign.
    fn clone(&self) -> Self {
        let key = self.public_key().clone();
        Actor {
            key: Key::Public(key),
            id: self.id.clone(),
        }
    }
}
/// The header segment of an encoded token. `decode_token` accepts only headers equal to
/// `TokenHeader::default()`.
#[derive(Deserialize, Eq, PartialEq, Serialize)]
struct TokenHeader {
    alg: String,
    typ: String,
}

impl Default for TokenHeader {
    // NOTE(review): "ES256" is the JWS identifier for ECDSA P-256/SHA-256, but tokens are
    // signed with Ed25519 (`Keypair::sign`); RFC 8037 registers "EdDSA" for that. Changing the
    // value here would make `decode_token` reject every token issued with the current header,
    // so coordinate the change across all peers before fixing it.
    fn default() -> Self {
        let alg = String::from("ES256");
        let typ = String::from("JWT");
        Self { alg, typ }
    }
}
fn token_signature(encoded: &str) -> Result<(&str, Signature)> {
if encoded.ends_with('.') {
return Err(Error::new(
ErrorKind::Format,
"encoded token cannot end with .",
));
}
let i = encoded
.rfind('.')
.ok_or_else(|| Error::new(ErrorKind::Format, format!("invalid token: {}", encoded)))?;
let message = &encoded[..i];
let signature =
base64::decode(&encoded[(i + 1)..]).map_err(|e| Error::new(ErrorKind::Base64, e))?;
let signature =
signature::Signature::from_bytes(&signature).map_err(|e| Error::new(ErrorKind::Auth, e))?;
Ok((message, signature))
}
fn decode_token<I: DeserializeOwned, C: DeserializeOwned>(encoded: &str) -> Result<Token<I, C>> {
let i = encoded
.find('.')
.ok_or_else(|| Error::new(ErrorKind::Format, format!("invalid token: {}", encoded)))?;
let header = base64_decode(&encoded[..i])?;
let header: TokenHeader =
serde_json::from_slice(&header).map_err(|e| Error::new(ErrorKind::Json, e))?;
if header != TokenHeader::default() {
return Err(Error::new(
ErrorKind::Format,
"Unsupported bearer token type",
));
}
let token = base64_decode(&encoded[(i + 1)..])?;
serde_json::from_slice(&token).map_err(|e| Error::new(ErrorKind::Json, e))
}
fn base64_decode(encoded: &str) -> Result<Vec<u8>> {
base64::decode(encoded).map_err(|e| Error::new(ErrorKind::Base64, e))
}
fn json_decode<'de, T: Deserialize<'de>>(encoded: &'de [u8]) -> Result<T> {
serde_json::from_slice(encoded).map_err(|e| Error::new(ErrorKind::Json, e))
}
fn base64_json_encode<T: Serialize>(data: &T) -> Result<String> {
let as_str = serde_json::to_string(data).map_err(|e| Error::new(ErrorKind::Json, e))?;
Ok(base64::encode(&as_str))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on the length of an encoded token.
    const SIZE_LIMIT: usize = 8000;

    /// A 30-second token issued now by `example.com`, claiming `actor`.
    fn sample_token(actor: &Actor<String>) -> Token<String, ()> {
        Token::new(
            "example.com".to_string(),
            SystemTime::now(),
            Duration::from_secs(30),
            actor.id().to_string(),
            (),
        )
    }

    #[test]
    fn test_format() {
        let actor = Actor::new("actor".to_string());
        let token = sample_token(&actor);
        let encoded = actor.sign_token(&token).unwrap();
        let (message, _) = token_signature(&encoded).unwrap();
        assert!(encoded.starts_with(message));
        println!("length {}", encoded.len());
        assert!(encoded.len() < SIZE_LIMIT);
    }

    #[test]
    fn test_signature_and_claims_round_trip() {
        let actor = Actor::new("actor".to_string());
        let token = sample_token(&actor);
        let encoded = actor.sign_token(&token).unwrap();

        // The signature must verify with the signer's key and fail with any other key.
        let (message, signature) = token_signature(&encoded).unwrap();
        assert!(actor
            .public_key()
            .verify(message.as_bytes(), &signature)
            .is_ok());
        let other = Actor::new("other".to_string());
        assert!(other
            .public_key()
            .verify(message.as_bytes(), &signature)
            .is_err());

        // The claims must survive both decoding paths unchanged.
        let decoded: Token<String, ()> = decode_token(message).unwrap();
        assert_eq!(decoded, token);
        let parsed: Token<String, ()> = encoded.parse().unwrap();
        assert_eq!(parsed, token);
    }

    #[test]
    fn test_expiry() {
        let actor = Actor::new("actor".to_string());
        let token = sample_token(&actor);
        let now = SystemTime::now();
        assert!(!token.is_expired(now).unwrap());
        assert!(token.is_expired(now + Duration::from_secs(60)).unwrap());
    }
}