use std::fmt;
use std::pin::Pin;
use std::time::{Duration, SystemTime, SystemTimeError, UNIX_EPOCH};
use async_trait::async_trait;
use base64::prelude::*;
use ed25519_dalek::{SignatureError, Signer, Verifier};
use futures::Future;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
pub use ed25519_dalek::{Signature, SigningKey, VerifyingKey};
pub use rand::rngs::OsRng;
/// Broad category of an [`Error`], identifying which stage of token handling failed.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ErrorKind {
    /// Signing or signature verification was refused or failed.
    Auth,
    /// A token segment was not valid base64.
    Base64,
    /// Built by [`Error::fetch`]; NOTE(review): presumably a failure to fetch external data
    /// (e.g. inside a resolver) — confirm with callers.
    Fetch,
    /// The token's structure or header was malformed or unsupported.
    Format,
    /// JSON (de)serialization failed.
    Json,
    /// A time check failed: the token is expired, or a system clock conversion failed.
    Time,
}

/// Error produced while signing or verifying tokens: an [`ErrorKind`] plus a message.
#[derive(Debug)]
pub struct Error {
    kind: ErrorKind,
    message: String,
}

impl Error {
    /// Builds an error of the given `kind` carrying `message`.
    pub fn new(kind: ErrorKind, message: String) -> Self {
        Error { message, kind }
    }

    /// Returns the category of this error.
    pub fn kind(&self) -> ErrorKind {
        self.kind
    }

    /// Splits the error into its kind and message.
    pub fn into_inner(self) -> (ErrorKind, String) {
        let Error { kind, message } = self;
        (kind, message)
    }

    /// Builds an [`ErrorKind::Auth`] error from any displayable message.
    pub fn auth<M: fmt::Display>(message: M) -> Self {
        Error {
            kind: ErrorKind::Auth,
            message: message.to_string(),
        }
    }

    /// Builds an [`ErrorKind::Format`] error from any displayable cause.
    pub fn format<M: fmt::Display>(cause: M) -> Self {
        Error {
            kind: ErrorKind::Format,
            message: cause.to_string(),
        }
    }

    /// Builds an [`ErrorKind::Fetch`] error whose message is the `Debug` rendering of `info`.
    pub fn fetch<Info: fmt::Debug>(info: Info) -> Self {
        Error {
            kind: ErrorKind::Fetch,
            message: format!("{:?}", info),
        }
    }
}

impl fmt::Display for Error {
    /// Renders as `<Kind>: <message>`, e.g. `Auth: invalid bearer token`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self { kind, message } = self;
        write!(f, "{kind:?}: {message}")
    }
}

impl std::error::Error for Error {}
impl From<base64::DecodeError> for Error {
fn from(cause: base64::DecodeError) -> Self {
Self::new(ErrorKind::Base64, cause.to_string())
}
}
impl From<serde_json::Error> for Error {
fn from(cause: serde_json::Error) -> Self {
Self::new(ErrorKind::Json, cause.to_string())
}
}
impl From<SignatureError> for Error {
fn from(cause: SignatureError) -> Self {
Self::new(ErrorKind::Auth, cause.to_string())
}
}
impl From<SystemTimeError> for Error {
fn from(cause: SystemTimeError) -> Self {
Self::new(ErrorKind::Time, cause.to_string())
}
}
/// Source of actor identities used to verify bearer tokens.
///
/// Implementors supply [`Resolve::resolve`]; [`Resolve::verify`] has a default implementation
/// that checks a token (and every token it inherits from) against the resolved actors.
#[async_trait]
pub trait Resolve: Send + Sync {
    /// Host identifier; a token's issuer (`iss`) is of this type and is passed to `resolve`.
    type HostId: Serialize + DeserializeOwned + fmt::Debug + Send + Sync;
    /// Actor identifier carried in each token.
    type ActorId: Serialize + DeserializeOwned + fmt::Debug + Send + Sync;
    /// Application-defined claims carried in each token.
    type Claims: Serialize + DeserializeOwned + Send + Sync;

    /// Returns the [`Actor`] for `actor_id` at `host`, as determined by the implementation.
    ///
    /// During verification the returned actor's public key is used to check the token's
    /// signature, and the token is rejected if the returned actor's id differs from
    /// `actor_id`. A public-key-only actor is sufficient.
    async fn resolve(
        &self,
        host: &Self::HostId,
        actor_id: &Self::ActorId,
    ) -> Result<Actor<Self::ActorId>, Error>;

    /// Verifies `encoded` as of `now` and returns it with its decoded claims chain.
    ///
    /// Each token in the chain (this one and, recursively, any it inherits from) is checked
    /// for its validity window at `now` and for a signature by the actor `resolve` returns.
    ///
    /// # Errors
    ///
    /// Fails if any token in the chain is malformed, expired, signed by the wrong key, issued
    /// for a different actor than the one resolved, or claims an expiry later than its parent.
    async fn verify(
        &self,
        encoded: String,
        now: SystemTime,
    ) -> Result<SignedToken<Self::HostId, Self::ActorId, Self::Claims>, Error>
    where
        Self::ActorId: PartialEq,
    {
        let claims = verify_claims(self, &encoded, now).await?;
        Ok(SignedToken::new(claims, encoded))
    }
}
async fn decode_and_verify_token<R: Resolve + ?Sized>(
resolver: &R,
encoded: &str,
now: SystemTime,
) -> Result<Token<R::HostId, R::ActorId, R::Claims>, Error>
where
R::ActorId: PartialEq,
{
let (message, signature) = token_signature(encoded)?;
let token: Token<R::HostId, R::ActorId, R::Claims> = decode_token(message)?;
if token.is_expired(now) {
return Err(Error::new(ErrorKind::Time, "token is expired".into()));
}
let actor = resolver.resolve(&token.iss, &token.actor_id).await?;
if actor.id != token.actor_id {
return Err(Error::auth(
"attempted to use a bearer token for a different actor",
));
}
if let Err(cause) = actor.public_key().verify(message.as_bytes(), &signature) {
Err(Error::auth(format!("invalid bearer token: {cause}")))
} else {
Ok(token)
}
}
type Verification<'a, H, A, C> =
Pin<Box<dyn Future<Output = Result<Claims<H, A, C>, Error>> + Send + 'a>>;
fn verify_claims<'a, R>(
resolver: &'a R,
encoded: &'a str,
now: SystemTime,
) -> Verification<'a, R::HostId, R::ActorId, R::Claims>
where
R: Resolve + ?Sized,
R::ActorId: PartialEq,
{
Box::pin(async move {
let token = decode_and_verify_token(resolver, encoded, now).await?;
if let Some(parent) = token.inherit {
let parent_claims = verify_claims(resolver, &parent, now).await?;
if token.exp <= parent_claims.exp {
parent_claims.consume(token.iss, token.actor_id, token.custom)
} else {
Err(Error::new(
ErrorKind::Time,
"cannot extend the expiration time of a recursive token".into(),
))
}
} else {
Ok(Claims::new(
token.exp,
token.iss,
token.actor_id,
token.custom,
))
}
})
}
/// Key material held by an [`Actor`].
///
/// `Public` can only check signatures; `Private` holds the full signing key.
enum Key {
    Public(VerifyingKey),
    Private(SigningKey),
}

impl Key {
    /// Returns `true` when this key can produce signatures.
    fn has_private_key(&self) -> bool {
        matches!(self, Self::Private(_))
    }
}
pub struct Actor<A> {
id: A,
key: Key,
}
impl<A> Actor<A> {
pub fn new(id: A) -> Self {
Self::with_keypair(id, SigningKey::generate(&mut OsRng))
}
pub fn with_keypair(id: A, keypair: SigningKey) -> Self {
Self {
id,
key: Key::Private(keypair),
}
}
pub fn with_public_key(id: A, public_key: VerifyingKey) -> Self {
Self {
id,
key: Key::Public(public_key),
}
}
pub fn id(&self) -> &A {
&self.id
}
pub fn has_private_key(&self) -> bool {
self.key.has_private_key()
}
pub fn public_key(&self) -> VerifyingKey {
match &self.key {
Key::Public(public_key) => *public_key,
Key::Private(keypair) => keypair.verifying_key(),
}
}
fn sign_token_inner<H, C>(&self, token: &Token<H, A, C>) -> Result<String, Error>
where
H: Serialize,
A: Serialize,
C: Serialize,
{
let keypair = match &self.key {
Key::Private(keypair) => Ok(keypair),
Key::Public(_) => Err(Error::auth("cannot sign a token without a private key")),
}?;
let header = BASE64_STANDARD.encode(serde_json::to_string(&TokenHeader::default())?);
let claims = BASE64_STANDARD.encode(serde_json::to_string(&token)?);
let signature = keypair.try_sign(format!("{header}.{claims}").as_bytes())?;
let signature = BASE64_STANDARD.encode(signature.to_bytes());
Ok(format!("{header}.{claims}.{signature}"))
}
pub fn sign_token<H, C>(&self, token: Token<H, A, C>) -> Result<SignedToken<H, A, C>, Error>
where
H: Serialize,
A: Serialize,
C: Serialize,
{
let jwt = self.sign_token_inner(&token)?;
let claims = Claims {
exp: token.exp,
host: token.iss,
actor_id: token.actor_id,
claims: token.custom,
inherit: None,
};
Ok(SignedToken::new(claims, jwt))
}
pub fn consume_and_sign<H, C>(
&self,
token: SignedToken<H, A, C>,
host_id: H,
claims: C,
now: SystemTime,
) -> Result<SignedToken<H, A, C>, Error>
where
H: Serialize + Clone,
A: Serialize + Clone,
C: Serialize + Clone,
{
let (token, claims) = Token::consume(token, now, host_id.clone(), self.id.clone(), claims)?;
let token = self.sign_token_inner(&token)?;
Ok(SignedToken::new(claims, token))
}
}
impl<A: Clone> Clone for Actor<A> {
fn clone(&self) -> Self {
Actor {
id: self.id.clone(),
key: match &self.key {
Key::Public(public_key) => Key::Public(*public_key),
Key::Private(keypair) => Key::Public(keypair.verifying_key()),
},
}
}
}
/// Prints only the actor's id, never its key material.
impl<A: fmt::Debug> fmt::Debug for Actor<A> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let id = &self.id;
        write!(f, "actor {id:?}")
    }
}
/// JOSE header written into, and required on, every token.
///
/// `decode_token` rejects any token whose header is not exactly `TokenHeader::default()`.
///
/// NOTE(review): signatures are Ed25519, which RFC 8037 registers as `"alg": "EdDSA"`.
/// `"ES256"` denotes ECDSA P-256. Changing the value would invalidate every existing token,
/// because of the exact-match check above, so any change needs a migration.
#[derive(Eq, PartialEq, Debug, Deserialize, Serialize)]
struct TokenHeader {
    alg: String,
    typ: String,
}

impl Default for TokenHeader {
    fn default() -> TokenHeader {
        let alg = String::from("ES256");
        let typ = String::from("JWT");
        Self { alg, typ }
    }
}
/// Decoded claims of a token chain.
///
/// The outermost link describes the most recent token. `inherit` points to the claims of the
/// token it was derived from, if any.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Claims<H, A, C> {
    /// Expiration time, in whole seconds since the Unix epoch.
    exp: u64,
    /// Issuing host of this link's token.
    host: H,
    /// Actor this link's token authenticates.
    actor_id: A,
    /// Application-defined claims of this link.
    claims: C,
    /// Claims of the parent token, if this token was derived from one.
    inherit: Option<Box<Claims<H, A, C>>>,
}

impl<H, A, C> Claims<H, A, C> {
    /// Creates a root link with no inherited parent.
    fn new(exp: u64, host: H, actor_id: A, claims: C) -> Self {
        Self {
            exp,
            host,
            actor_id,
            claims,
            inherit: None,
        }
    }

    /// Derives a child link for `actor_id` at `host`, keeping `self` as its parent.
    ///
    /// The child inherits the parent's expiry. This never fails; the `Result` is kept for
    /// existing callers.
    fn consume(self, host: H, actor_id: A, claims: C) -> Result<Self, Error> {
        // Copy the raw seconds rather than round-tripping through `SystemTime`, which panics
        // for timestamps the platform cannot represent.
        let exp = self.exp;
        Ok(Self {
            exp,
            host,
            actor_id,
            claims,
            inherit: Some(Box::new(self)),
        })
    }

    /// Returns this link's expiry as a [`SystemTime`].
    ///
    /// # Panics
    ///
    /// Panics if `exp` is too large to represent as a `SystemTime` on this platform.
    fn expires(&self) -> SystemTime {
        UNIX_EPOCH + Duration::from_secs(self.exp)
    }
}
/// Iterator over a [`Claims`] chain.
///
/// Yields `(host, actor_id, claims)` for the outermost link first, then for each inherited
/// parent in turn.
pub struct Iter<'a, H, A, C> {
    claims: Option<&'a Claims<H, A, C>>,
}

impl<'a, H: 'a, A: 'a, C: 'a> Iterator for Iter<'a, H, A, C> {
    type Item = (&'a H, &'a A, &'a C);

    fn next(&mut self) -> Option<Self::Item> {
        let current = self.claims.take()?;
        self.claims = current.inherit.as_deref();
        Some((&current.host, &current.actor_id, &current.claims))
    }
}

impl<H, A, C> Claims<H, A, C> {
    /// Iterates over this link and all of its ancestors, newest first.
    pub fn iter(&self) -> Iter<'_, H, A, C> {
        Iter { claims: Some(self) }
    }
}

impl<H: PartialEq, A: PartialEq, C> Claims<H, A, C> {
    /// Returns the claims of the newest link issued by `host` for `actor_id`, if any.
    pub fn get(&self, host: &H, actor_id: &A) -> Option<&C> {
        self.iter()
            .find(|&(h, a, _)| h == host && a == actor_id)
            .map(|(_, _, c)| c)
    }
}

impl<'a, H, A, C> IntoIterator for &'a Claims<H, A, C> {
    type Item = (&'a H, &'a A, &'a C);
    type IntoIter = Iter<'a, H, A, C>;

    fn into_iter(self) -> Self::IntoIter {
        Iter { claims: Some(self) }
    }
}
/// Unsigned token payload; serialized as JSON to form the middle segment of the JWT.
#[derive(Clone, Eq, PartialEq, Deserialize, Serialize)]
pub struct Token<H, A, C> {
    /// Issuing host.
    iss: H,
    /// Issued-at time, in whole seconds since the Unix epoch.
    iat: u64,
    /// Expiration time, in whole seconds since the Unix epoch.
    exp: u64,
    /// Actor this token authenticates.
    actor_id: A,
    /// Application-defined claims.
    custom: C,
    /// Encoded JWT of the parent token this one was derived from, if any.
    inherit: Option<String>,
}

impl<H, A, C> Token<H, A, C> {
    /// Creates a root token issued by `iss` for `actor_id`, valid from `iat` for `ttl`.
    ///
    /// Both timestamps are truncated to whole seconds.
    ///
    /// # Panics
    ///
    /// Panics if `iat` is earlier than the Unix epoch, or if `iat + ttl` overflows [`Duration`].
    pub fn new(iss: H, iat: SystemTime, ttl: Duration, actor_id: A, claims: C) -> Self {
        let iat = iat.duration_since(UNIX_EPOCH).expect("duration");
        let exp = iat + ttl;
        Self {
            iss,
            iat: iat.as_secs(),
            exp: exp.as_secs(),
            actor_id,
            custom: claims,
            inherit: None,
        }
    }

    /// Builds a child of `parent`, issued by `host_id` for `actor_id` at `iat`.
    ///
    /// Returns the unsigned child token, which embeds `parent`'s JWT and keeps its expiry,
    /// together with the extended claims chain.
    fn consume(
        parent: SignedToken<H, A, C>,
        iat: SystemTime,
        host_id: H,
        actor_id: A,
        claims: C,
    ) -> Result<(Self, Claims<H, A, C>), Error>
    where
        H: Clone,
        A: Clone,
        C: Clone,
    {
        let iat = iat.duration_since(UNIX_EPOCH)?;
        // Read the parent's expiry in seconds directly. Going through `SystemTime` panics for
        // timestamps the platform cannot represent.
        let exp = parent.claims.exp;
        let token = Self {
            iss: host_id.clone(),
            iat: iat.as_secs(),
            exp,
            actor_id: actor_id.clone(),
            custom: claims.clone(),
            inherit: Some(parent.jwt),
        };
        let claims = parent.claims.consume(host_id, actor_id, claims)?;
        Ok((token, claims))
    }

    /// Returns the issuing host.
    pub fn issuer(&self) -> &H {
        &self.iss
    }

    /// Returns the actor this token authenticates.
    pub fn actor_id(&self) -> &A {
        &self.actor_id
    }

    /// Returns `true` if `now` is before the issue time or at/after the expiration time.
    ///
    /// `iat`/`exp` are decoded from untrusted input and checked here before the signature is
    /// verified. A timestamp too large to represent as a [`SystemTime`] on this platform is
    /// treated as invalid: the token is reported expired instead of panicking.
    pub fn is_expired(&self, now: SystemTime) -> bool {
        let iat = UNIX_EPOCH.checked_add(Duration::from_secs(self.iat));
        let exp = UNIX_EPOCH.checked_add(Duration::from_secs(self.exp));
        match (iat, exp) {
            (Some(iat), Some(exp)) => now < iat || now >= exp,
            _ => true,
        }
    }
}
/// Summarizes who the (possibly unverified) token claims to authenticate; custom claims are
/// not printed.
impl<H: fmt::Display, A: fmt::Display, C> fmt::Debug for Token<H, A, C> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (actor, host) = (&self.actor_id, &self.iss);
        write!(
            f,
            "JWT token claiming to authenticate actor {actor} at host {host}"
        )
    }
}
/// An encoded JWT together with its decoded claims chain.
#[derive(Clone, Eq, PartialEq)]
pub struct SignedToken<H, A, C> {
    claims: Claims<H, A, C>,
    jwt: String,
}

impl<H, A, C> SignedToken<H, A, C> {
    /// Pairs decoded `claims` with the encoded `jwt` they came from.
    fn new(claims: Claims<H, A, C>, jwt: String) -> Self {
        Self { claims, jwt }
    }

    /// Returns the decoded claims chain.
    pub fn claims(&self) -> &Claims<H, A, C> {
        &self.claims
    }

    /// Returns the expiry recorded in the outermost claims link.
    pub fn expires(&self) -> SystemTime {
        self.claims().expires()
    }

    /// Returns the encoded JWT.
    pub fn jwt(&self) -> &str {
        self.jwt.as_str()
    }

    /// Consumes the token, returning the encoded JWT.
    pub fn into_jwt(self) -> String {
        let Self { jwt, .. } = self;
        jwt
    }
}
/// Prints the encoded JWT followed by its decoded claims chain.
impl<H: fmt::Debug, A: fmt::Debug, C: fmt::Debug> fmt::Debug for SignedToken<H, A, C> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (jwt, claims) = (&self.jwt, &self.claims);
        write!(f, "JWT {jwt} which claims {claims:?}")
    }
}
/// Splits `encoded` at its last `.` into the signed message (`<header>.<claims>`) and the
/// decoded Ed25519 signature.
///
/// # Errors
///
/// [`ErrorKind::Format`] if `encoded` ends with `.` or contains no `.`;
/// [`ErrorKind::Base64`] if the signature segment is not valid base64; [`ErrorKind::Auth`] if
/// the decoded bytes are not a well-formed signature.
fn token_signature(encoded: &str) -> Result<(&str, Signature), Error> {
    if encoded.ends_with('.') {
        return Err(Error::format("encoded token cannot end with ."));
    }

    let (message, signature_b64) = encoded
        .rsplit_once('.')
        .ok_or_else(|| Error::format(format!("invalid token: {}", encoded)))?;

    let signature_bytes = BASE64_STANDARD.decode(signature_b64)?;
    let signature = Signature::try_from(signature_bytes.as_slice())?;
    Ok((message, signature))
}
/// Decodes the `<header>.<claims>` message of a token into a [`Token`].
///
/// The signature is not checked here.
///
/// # Errors
///
/// [`ErrorKind::Format`] if there is no `.` or the header is not exactly
/// `TokenHeader::default()`; [`ErrorKind::Base64`] / [`ErrorKind::Json`] if a segment fails to
/// decode or parse.
fn decode_token<H, A, C>(encoded: &str) -> Result<Token<H, A, C>, Error>
where
    H: DeserializeOwned,
    A: DeserializeOwned,
    C: DeserializeOwned,
{
    let (header_b64, body_b64) = encoded
        .split_once('.')
        .ok_or_else(|| Error::format(format!("invalid token: {}", encoded)))?;

    let header_json = BASE64_STANDARD.decode(header_b64)?;
    let header: TokenHeader = serde_json::from_slice(&header_json)?;
    if header != TokenHeader::default() {
        return Err(Error::format(format!(
            "unsupported bearer token type: {header:?}"
        )));
    }

    let body_json = BASE64_STANDARD.decode(body_b64)?;
    Ok(serde_json::from_slice(&body_json)?)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Upper bound on the encoded size of a simple token, to catch accidental bloat.
    const SIZE_LIMIT: usize = 8000;

    /// Signs a 30-second token for a fresh actor at `example.com`.
    fn signed_example() -> (Actor<String>, SignedToken<String, String, ()>) {
        let actor = Actor::new("actor".to_string());
        let token = Token::new(
            "example.com".to_string(),
            SystemTime::now(),
            Duration::from_secs(30),
            actor.id().to_string(),
            (),
        );
        let signed = actor.sign_token(token).unwrap();
        (actor, signed)
    }

    #[test]
    fn test_format() {
        let (_, signed) = signed_example();
        let (message, _) = token_signature(signed.jwt()).unwrap();
        assert!(signed.jwt().starts_with(message));
        assert!(signed.jwt().len() < SIZE_LIMIT);
    }

    #[test]
    fn test_signature_verifies_and_claims_round_trip() {
        let (actor, signed) = signed_example();
        let (message, signature) = token_signature(signed.jwt()).unwrap();
        assert!(actor
            .public_key()
            .verify(message.as_bytes(), &signature)
            .is_ok());

        let decoded: Token<String, String, ()> = decode_token(message).unwrap();
        assert_eq!(decoded.issuer(), "example.com");
        assert_eq!(decoded.actor_id(), actor.id());
        assert!(!decoded.is_expired(SystemTime::now()));

        assert_eq!(
            signed.claims().get(&"example.com".to_string(), actor.id()),
            Some(&())
        );
    }

    #[test]
    fn test_public_key_only_actor_cannot_sign() {
        let (actor, _) = signed_example();
        // Cloning deliberately drops the private key.
        let verifier = actor.clone();
        assert!(actor.has_private_key());
        assert!(!verifier.has_private_key());

        let token = Token::new(
            "example.com".to_string(),
            SystemTime::now(),
            Duration::from_secs(30),
            verifier.id().to_string(),
            (),
        );
        let err = verifier.sign_token(token).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::Auth);
    }

    #[test]
    fn test_malformed_tokens_are_rejected() {
        assert_eq!(
            token_signature("abc.").unwrap_err().kind(),
            ErrorKind::Format
        );
        assert_eq!(
            token_signature("no-dots").unwrap_err().kind(),
            ErrorKind::Format
        );
    }
}