use std::{
collections::HashMap,
str::FromStr,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use anyhow::Context;
use jsonwebtoken::{
Algorithm, EncodingKey, Header, TokenData, dangerous::insecure_decode, encode,
errors::Error as JwtError,
};
use pem::Pem;
use serde::{Deserialize, Serialize};
use snap_tokens::v0::{Pssid, SnapTokenClaims};
use thiserror::Error;
use uuid::Uuid;
use super::{
api::{TokenRequest, TokenResponse},
fake_idp::FakeIdp,
};
use crate::{
authorization_server::fake_idp::FAKE_IDP_ISSUER,
dto::{TokenExchangerConfigDto, TokenExchangerStateDto},
};
/// Value reported in the `token_type` field of a successful token exchange response.
pub const NO_ACCESS_TOKEN_TYPE: &str = "N_A";
/// Token type identifier for JWTs; reported as `issued_token_type` in token exchange responses.
pub const JWT_TOKEN_TYPE: &str = "urn:ietf:params:oauth:token-type:jwt";
/// The only `grant_type` value accepted by `TokenExchangeImpl::exchange`.
pub const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// The only `subject_token_type` value accepted by `TokenExchangeImpl::exchange`.
pub const ID_TOKEN_TYPE: &str = "urn:ietf:params:oauth:token-type:id_token";
/// Client identifier string for the edge app.
// NOTE(review): not referenced in this part of the file; presumably used by callers elsewhere — confirm.
pub const EDGE_APP_CLIENT_ID: &str = "edge_app";
/// Something that can verify an ID token and hand back its decoded contents.
pub trait IdentityProvider {
/// Verifies `id_token` and returns the decoded token carrying [`OpenIdToken`] claims.
///
/// # Errors
///
/// Returns a [`VerifyIdTokenError`] when the token cannot be verified.
fn verify_id_token(&self, id_token: &str)
-> Result<TokenData<OpenIdToken>, VerifyIdTokenError>;
}
/// Errors returned by [`IdentityProvider::verify_id_token`].
#[derive(Debug, Error, PartialEq)]
pub enum VerifyIdTokenError {
/// The JWT library reported an error for the ID token.
#[error("JWT error: {0}")]
JwtError(#[from] JwtError),
}
/// A token exchange service: turns a [`TokenRequest`] into a [`TokenResponse`].
///
/// Implementors must be `Send + Sync`.
pub trait TokenExchange: Send + Sync {
/// Processes one token exchange request.
///
/// Takes `&mut self`, so implementations may update internal state while handling a request.
///
/// # Errors
///
/// Returns a [`TokenExchangeError`] when the request is rejected or cannot be served.
fn exchange(&mut self, req: TokenRequest) -> Result<TokenResponse, TokenExchangeError>;
}
/// Errors returned by [`TokenExchange::exchange`].
#[derive(Debug, Error, PartialEq)]
pub enum TokenExchangeError {
/// The JWT library reported an error (for example while decoding the subject token or
/// encoding the issued token).
#[error("JWT error: {0}")]
JwtError(#[from] JwtError),
/// The identity provider failed to verify the subject ID token.
#[error("token ID verification error: {0}")]
VerifyIdTokenError(#[from] VerifyIdTokenError),
/// The request's `grant_type` is not supported; carries the rejected value.
#[error("invalid grant type: {0}")]
InvalidGrantType(String),
/// The request's `subject_token_type` is not supported; carries the rejected value.
#[error("unsupported subject token type: {0}")]
InvalidSubjectTokenType(String),
/// The subject token's `iss` claim names no known identity provider; carries that `iss` value.
#[error("unknown identity provider: {0}")]
UnknownIdentityProvider(String),
}
/// Claims read from a subject ID token during token exchange.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenIdToken {
/// `iss` claim; used to select the identity provider that verifies the token.
pub(crate) iss: String,
/// `sub` claim; after verification it is the key under which the subject's `Pssid` is stored.
pub(crate) sub: String,
/// `aud` claim.
pub(crate) aud: String,
/// `exp` claim.
// NOTE(review): presumably seconds since the Unix epoch, as is conventional for JWTs — confirm.
pub(crate) exp: i64,
/// `iat` claim.
// NOTE(review): presumably seconds since the Unix epoch, as is conventional for JWTs — confirm.
pub(crate) iat: i64,
}
/// Static configuration of the token exchanger.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TokenExchangeConfig {
/// PEM-encoded private key used to sign issued snap tokens (loaded as an Ed25519 key).
private_key: Pem,
/// How long an issued token is valid; added to the current time to compute `exp` and
/// reported (in whole seconds) as `expires_in`.
token_lifetime: Duration,
/// Identity provider used to verify ID tokens whose issuer is `FAKE_IDP_ISSUER`.
fake_idp: FakeIdp,
}
impl From<&TokenExchangeConfig> for TokenExchangerConfigDto {
fn from(config: &TokenExchangeConfig) -> Self {
Self {
private_key: config.private_key.to_string(),
token_lifetime: config.token_lifetime,
fake_idp: (&config.fake_idp).into(),
}
}
}
impl TryFrom<TokenExchangerConfigDto> for TokenExchangeConfig {
    type Error = anyhow::Error;

    /// Parses a configuration DTO into a [`TokenExchangeConfig`].
    ///
    /// # Errors
    ///
    /// Fails when the private key is not valid PEM or when the fake IDP configuration
    /// cannot be converted; the key is checked first.
    fn try_from(value: TokenExchangerConfigDto) -> Result<Self, Self::Error> {
        let token_lifetime = value.token_lifetime;

        let private_key = Pem::from_str(&value.private_key)
            .context("invalid PEM format for session token issuer key")?;

        let fake_idp =
            FakeIdp::try_from(value.fake_idp).context("invalid fake IDP configuration")?;

        Ok(Self {
            private_key,
            token_lifetime,
            fake_idp,
        })
    }
}
impl TokenExchangeConfig {
    /// Creates a configuration from a signing key and a token lifetime.
    ///
    /// The fake identity provider is initialised with `FakeIdp::default()`.
    pub fn new(private_key: Pem, token_lifetime: Duration) -> Self {
        let fake_idp = FakeIdp::default();
        Self {
            private_key,
            token_lifetime,
            fake_idp,
        }
    }
}
/// Subject identifier taken from the `sub` claim of a verified ID token.
///
/// Used as the key of the exchanger's subject-to-`Pssid` mapping.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Ssid(pub String);
/// Token exchanger that verifies ID tokens and issues signed snap tokens.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TokenExchangeImpl {
/// Signing key, token lifetime and identity provider settings.
config: TokenExchangeConfig,
/// The `Pssid` assigned to each subject seen so far; entries are created on a subject's
/// first successful verification and reused afterwards.
id_mapping: HashMap<Ssid, Pssid>,
}
impl From<&TokenExchangeImpl> for TokenExchangerStateDto {
    /// Snapshots the exchanger state as a DTO.
    ///
    /// Every mapping entry is flattened to a pair of strings: the raw subject identifier
    /// and the textual form of the assigned UUID.
    fn from(value: &TokenExchangeImpl) -> Self {
        let config = (&value.config).into();

        let id_mapping = value
            .id_mapping
            .iter()
            .map(|(ssid, pssid)| {
                let subject = ssid.0.clone();
                let assigned = pssid.0.to_string();
                (subject, assigned)
            })
            .collect();

        Self { config, id_mapping }
    }
}
impl TryFrom<TokenExchangerStateDto> for TokenExchangeImpl {
type Error = anyhow::Error;
fn try_from(value: TokenExchangerStateDto) -> Result<Self, Self::Error> {
let config = TokenExchangeConfig::try_from(value.config)?;
let id_mapping = value
.id_mapping
.into_iter()
.map(|(id, uuid)| {
Ok((
Ssid(id),
Pssid(Uuid::from_str(&uuid).context("invalid UUID")?),
))
})
.collect::<Result<_, Self::Error>>()?;
Ok(Self { config, id_mapping })
}
}
impl TokenExchangeImpl {
    /// Creates an exchanger with the given configuration and no known subjects.
    pub fn new(config: TokenExchangeConfig) -> Self {
        let id_mapping = HashMap::new();
        Self { config, id_mapping }
    }
}
impl TokenExchange for TokenExchangeImpl {
fn exchange(&mut self, req: TokenRequest) -> Result<TokenResponse, TokenExchangeError> {
tracing::debug!(request=?req, "Received token exchange request");
if req.grant_type != TOKEN_EXCHANGE_GRANT_TYPE {
tracing::debug!(grant_type=%req.grant_type, "Invalid grant type");
return Err(TokenExchangeError::InvalidGrantType(req.grant_type));
}
if req.subject_token_type != ID_TOKEN_TYPE {
tracing::debug!(subject_token_type=%req.subject_token_type, "Unsupported subject token type");
return Err(TokenExchangeError::InvalidSubjectTokenType(
req.subject_token_type,
));
}
let id_token = &req.subject_token;
let decoded_token = insecure_decode::<OpenIdToken>(id_token)?;
tracing::debug!(token=?decoded_token, "Exchanging token");
let verified_id_token = match decoded_token.claims.iss.as_str() {
FAKE_IDP_ISSUER => self.config.fake_idp.verify_id_token(id_token)?,
_ => {
return Err(TokenExchangeError::UnknownIdentityProvider(
decoded_token.claims.iss,
));
}
};
let pssid = self
.id_mapping
.entry(Ssid(verified_id_token.claims.sub.clone()))
.or_insert_with(|| Pssid(Uuid::new_v4()));
let snap_token_claims = SnapTokenClaims {
pssid: pssid.clone(),
exp: (SystemTime::now() + self.config.token_lifetime)
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
jti: uuid::Uuid::new_v4().to_string(),
};
let snap_token_enc_key =
EncodingKey::from_ed_pem(pem::encode(&self.config.private_key).as_bytes())
.expect("no fail");
let snap_token = encode(
&Header::new(Algorithm::EdDSA),
&snap_token_claims,
&snap_token_enc_key,
)?;
Ok(TokenResponse {
access_token: snap_token,
issued_token_type: JWT_TOKEN_TYPE.to_string(),
token_type: NO_ACCESS_TOKEN_TYPE.to_string(),
expires_in: self.config.token_lifetime.as_secs(),
scope: None,
})
}
}