use crate::CowStr;
use crate::IntoStatic;
use crate::types::string::{Did, Nsid};
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use ouroboros::self_referencing;
use serde::{Deserialize, Serialize};
use signature::Verifier;
use smol_str::SmolStr;
use smol_str::format_smolstr;
use thiserror::Error;
#[cfg(feature = "crypto-p256")]
use p256::ecdsa::{Signature as P256Signature, VerifyingKey as P256VerifyingKey};
#[cfg(feature = "crypto-k256")]
use k256::ecdsa::{Signature as K256Signature, VerifyingKey as K256VerifyingKey};
/// Errors that can occur while parsing, verifying, or validating a service auth JWT.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, Error, miette::Diagnostic)]
#[non_exhaustive]
pub enum ServiceAuthError {
/// The token does not have the expected JWT structure; the payload describes the problem.
#[error("malformed JWT: {0}")]
MalformedToken(CowStr<'static>),
/// A token segment could not be base64-decoded (converted from `base64::DecodeError`).
#[error("base64 decode error: {0}")]
Base64Decode(#[from] base64::DecodeError),
/// The decoded header or claims were not valid JSON for their expected shape
/// (converted from `serde_json::Error`).
#[error("JSON parsing error: {0}")]
JsonParse(#[from] serde_json::Error),
/// The signature did not verify against the supplied public key.
#[error("invalid signature")]
InvalidSignature,
/// The header's `alg` value, in combination with the supplied key, is not supported.
#[error("unsupported algorithm: {alg}")]
UnsupportedAlgorithm {
/// The `alg` value taken from the JWT header.
alg: SmolStr,
},
/// The token's `exp` claim is at or before the current time.
#[error("token expired at {exp} (current time: {now})")]
Expired {
/// The `exp` claim of the token.
exp: i64,
/// The current timestamp observed when the error was produced.
now: i64,
},
/// The token's `aud` claim does not match the audience the caller expected.
#[error("audience mismatch: expected {expected}, got {actual}")]
AudienceMismatch {
/// The audience the caller required.
expected: Did<'static>,
/// The audience found in the token.
actual: Did<'static>,
},
/// The token's `lxm` claim is absent or does not match the required method.
#[error("method mismatch: expected {expected}, got {actual:?}")]
MethodMismatch {
/// The method NSID the caller required.
expected: Nsid<'static>,
/// The method NSID found in the token, if any.
actual: Option<Nsid<'static>>,
},
/// A required field was missing; the payload names the field.
// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("missing required field: {0}")]
MissingField(&'static str),
/// Key or signature bytes could not be interpreted by the crypto backend.
#[error("crypto error: {0}")]
Crypto(CowStr<'static>),
}
/// The JOSE header of a JWT, holding only the fields this module reads.
///
/// String fields are marked `#[serde(borrow)]` so they may borrow from the
/// deserialization input instead of allocating.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtHeader<'a> {
/// Signature algorithm identifier; `verify_signature` compares it against
/// `"ES256"` and `"ES256K"`.
#[serde(borrow)]
pub alg: CowStr<'a>,
/// Token type as found in the header's `typ` field.
#[serde(borrow)]
pub typ: CowStr<'a>,
}
impl IntoStatic for JwtHeader<'_> {
    type Output = JwtHeader<'static>;

    /// Converts the header into a `'static` version by converting each field.
    fn into_static(self) -> Self::Output {
        let JwtHeader { alg, typ } = self;
        let alg = alg.into_static();
        let typ = typ.into_static();
        JwtHeader { alg, typ }
    }
}
/// The claim set carried in the payload of a service auth JWT.
///
/// Borrowing fields are marked `#[serde(borrow)]` so they may reference the
/// deserialization input instead of allocating.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceAuthClaims<'a> {
/// Issuer DID (`iss` claim).
#[serde(borrow)]
pub iss: Did<'a>,
/// Audience DID (`aud` claim); compared against the expected audience in `validate`.
#[serde(borrow)]
pub aud: Did<'a>,
/// Expiration time (`exp` claim); compared against `chrono::Utc::now().timestamp()`.
pub exp: i64,
/// Issued-at time (`iat` claim). Not checked by any method in this file.
pub iat: i64,
/// Optional token identifier (`jti` claim); omitted from serialized output when `None`.
#[serde(borrow, skip_serializing_if = "Option::is_none")]
pub jti: Option<CowStr<'a>>,
/// Optional method NSID (`lxm` claim) checked by `check_method`/`require_method`;
/// omitted from serialized output when `None`.
#[serde(borrow, skip_serializing_if = "Option::is_none")]
pub lxm: Option<Nsid<'a>>,
}
impl<'a> IntoStatic for ServiceAuthClaims<'a> {
    type Output = ServiceAuthClaims<'static>;

    /// Converts the claims into a `'static` version, converting every
    /// lifetime-carrying field and copying the timestamps unchanged.
    fn into_static(self) -> Self::Output {
        let ServiceAuthClaims {
            iss,
            aud,
            exp,
            iat,
            jti,
            lxm,
        } = self;

        ServiceAuthClaims {
            iss: iss.into_static(),
            aud: aud.into_static(),
            exp,
            iat,
            jti: jti.map(|id| id.into_static()),
            lxm: lxm.map(|method| method.into_static()),
        }
    }
}
impl<'a> ServiceAuthClaims<'a> {
    /// Checks the audience and expiration of these claims.
    ///
    /// The audience is checked first, then the expiration.
    ///
    /// # Errors
    ///
    /// * [`ServiceAuthError::AudienceMismatch`] if `aud` differs from `expected_aud`.
    /// * [`ServiceAuthError::Expired`] if `exp` is at or before the current time.
    pub fn validate(&self, expected_aud: &Did) -> Result<(), ServiceAuthError> {
        if self.aud.as_str() != expected_aud.as_str() {
            return Err(ServiceAuthError::AudienceMismatch {
                expected: expected_aud.clone().into_static(),
                actual: self.aud.clone().into_static(),
            });
        }

        // Read the clock exactly once so the `now` reported in the error is the
        // same instant the expiry decision was based on.
        let now = chrono::Utc::now().timestamp();
        if self.exp <= now {
            return Err(ServiceAuthError::Expired { exp: self.exp, now });
        }

        Ok(())
    }

    /// Returns `true` if `exp` is at or before the current UTC timestamp.
    pub fn is_expired(&self) -> bool {
        let now = chrono::Utc::now().timestamp();
        self.exp <= now
    }

    /// Returns `true` only if the `lxm` claim is present and equal to `nsid`.
    ///
    /// A token without an `lxm` claim never matches.
    pub fn check_method(&self, nsid: &Nsid) -> bool {
        matches!(&self.lxm, Some(lxm) if lxm.as_str() == nsid.as_str())
    }

    /// Like [`check_method`](Self::check_method), but reports a mismatch as an error.
    ///
    /// # Errors
    ///
    /// Returns [`ServiceAuthError::MethodMismatch`] when the `lxm` claim is
    /// missing or differs from `nsid`.
    pub fn require_method(&self, nsid: &Nsid) -> Result<(), ServiceAuthError> {
        if self.check_method(nsid) {
            return Ok(());
        }

        Err(ServiceAuthError::MethodMismatch {
            expected: nsid.clone().into_static(),
            actual: self.lxm.as_ref().map(|l| l.clone().into_static()),
        })
    }
}
/// A decoded JWT that owns its buffers together with the header and claims
/// parsed from them.
///
/// The struct is self-referencing (via `ouroboros`): `header` borrows from
/// `header_buf` and `claims` borrows from `payload_buf`.
#[self_referencing]
pub struct ParsedJwt {
// Base64url-decoded bytes of the header segment; `header` borrows from this.
header_buf: Vec<u8>,
// Base64url-decoded bytes of the payload segment; `claims` borrows from this.
payload_buf: Vec<u8>,
// The original compact token string, kept so the signing input can be sliced out.
token: String,
// Base64url-decoded signature bytes.
signature: Vec<u8>,
// Header deserialized from `header_buf`.
#[borrows(header_buf)]
#[covariant]
header: JwtHeader<'this>,
// Claims deserialized from `payload_buf`.
#[borrows(payload_buf)]
#[covariant]
claims: ServiceAuthClaims<'this>,
}
impl ParsedJwt {
    /// Returns the bytes that were signed: everything in the original token
    /// before the second `.` (that is, `<header>.<payload>`).
    ///
    /// # Panics
    ///
    /// Panics if the stored token contains fewer than two `.` characters.
    pub fn signing_input(&self) -> &[u8] {
        self.with_token(|token| {
            // Byte offset of the second '.' marks the end of the signed portion.
            let end = token
                .match_indices('.')
                .map(|(offset, _)| offset)
                .nth(1)
                .unwrap();
            &token.as_bytes()[..end]
        })
    }

    /// Borrows the parsed JWT header.
    pub fn header(&self) -> &JwtHeader<'_> {
        self.borrow_header()
    }

    /// Borrows the parsed claims.
    pub fn claims(&self) -> &ServiceAuthClaims<'_> {
        self.borrow_claims()
    }

    /// Borrows the decoded signature bytes.
    pub fn signature(&self) -> &[u8] {
        self.borrow_signature()
    }

    /// Consumes the parsed token and returns an owned (`'static`) copy of its header.
    pub fn into_header(self) -> JwtHeader<'static> {
        let owned = self.borrow_header().clone();
        owned.into_static()
    }

    /// Consumes the parsed token and returns an owned (`'static`) copy of its claims.
    pub fn into_claims(self) -> ServiceAuthClaims<'static> {
        let owned = self.borrow_claims().clone();
        owned.into_static()
    }
}
pub fn parse_jwt(token: &str) -> Result<ParsedJwt, ServiceAuthError> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(ServiceAuthError::MalformedToken(CowStr::new_static(
"JWT must have exactly 3 parts separated by dots",
)));
}
let header_b64 = parts[0];
let payload_b64 = parts[1];
let signature_b64 = parts[2];
let header_buf = URL_SAFE_NO_PAD.decode(header_b64)?;
let payload_buf = URL_SAFE_NO_PAD.decode(payload_b64)?;
let signature = URL_SAFE_NO_PAD.decode(signature_b64)?;
let _header: JwtHeader = serde_json::from_slice(&header_buf)?;
let _claims: ServiceAuthClaims = serde_json::from_slice(&payload_buf)?;
Ok(ParsedJwtBuilder {
header_buf,
payload_buf,
token: token.to_string(),
signature,
header_builder: |buf| {
serde_json::from_slice(buf).expect("header was validated")
},
claims_builder: |buf| {
serde_json::from_slice(buf).expect("claims were validated")
},
}
.build())
}
/// A public key usable for verifying a JWT signature.
///
/// Each variant is only compiled in when its corresponding crypto feature is enabled.
#[derive(Debug, Clone)]
pub enum PublicKey {
/// A P-256 ECDSA verifying key; paired with the `"ES256"` algorithm in `verify_signature`.
#[cfg(feature = "crypto-p256")]
P256(P256VerifyingKey),
/// A K-256 ECDSA verifying key; paired with the `"ES256K"` algorithm in `verify_signature`.
#[cfg(feature = "crypto-k256")]
K256(K256VerifyingKey),
}
impl PublicKey {
    /// Builds a P-256 public key from SEC1-encoded bytes.
    ///
    /// # Errors
    ///
    /// Returns [`ServiceAuthError::Crypto`] if the bytes are not accepted as a
    /// P-256 verifying key.
    #[cfg(feature = "crypto-p256")]
    pub fn from_p256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> {
        match P256VerifyingKey::from_sec1_bytes(bytes) {
            Ok(verifying_key) => Ok(Self::P256(verifying_key)),
            Err(err) => Err(ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
                "invalid P-256 key: {}",
                err
            )))),
        }
    }

    /// Builds a K-256 public key from SEC1-encoded bytes.
    ///
    /// # Errors
    ///
    /// Returns [`ServiceAuthError::Crypto`] if the bytes are not accepted as a
    /// K-256 verifying key.
    #[cfg(feature = "crypto-k256")]
    pub fn from_k256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> {
        match K256VerifyingKey::from_sec1_bytes(bytes) {
            Ok(verifying_key) => Ok(Self::K256(verifying_key)),
            Err(err) => Err(ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
                "invalid K-256 key: {}",
                err
            )))),
        }
    }
}
/// Verifies the signature of a parsed JWT against `public_key`.
///
/// The header's `alg` must agree with the key type: `"ES256"` with a P-256
/// key, or `"ES256K"` with a K-256 key. Any other combination, including a
/// supported algorithm paired with the other key type, is rejected.
///
/// # Errors
///
/// * [`ServiceAuthError::UnsupportedAlgorithm`] if the `alg`/key pair is not
///   one of the supported combinations.
/// * [`ServiceAuthError::Crypto`] if the signature bytes cannot be parsed.
/// * [`ServiceAuthError::InvalidSignature`] if verification fails.
pub fn verify_signature(
    parsed: &ParsedJwt,
    public_key: &PublicKey,
) -> Result<(), ServiceAuthError> {
    let alg = parsed.header().alg.as_str();
    let signing_input = parsed.signing_input();
    let signature = parsed.signature();

    match public_key {
        #[cfg(feature = "crypto-p256")]
        PublicKey::P256(verifying_key) if alg == "ES256" => {
            let parsed_sig = match P256Signature::from_slice(signature) {
                Ok(sig) => sig,
                Err(err) => {
                    return Err(ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
                        "invalid ES256 signature: {}",
                        err
                    ))));
                }
            };
            verifying_key
                .verify(signing_input, &parsed_sig)
                .map_err(|_| ServiceAuthError::InvalidSignature)
        }
        #[cfg(feature = "crypto-k256")]
        PublicKey::K256(verifying_key) if alg == "ES256K" => {
            let parsed_sig = match K256Signature::from_slice(signature) {
                Ok(sig) => sig,
                Err(err) => {
                    return Err(ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
                        "invalid ES256K signature: {}",
                        err
                    ))));
                }
            };
            verifying_key
                .verify(signing_input, &parsed_sig)
                .map_err(|_| ServiceAuthError::InvalidSignature)
        }
        _ => Err(ServiceAuthError::UnsupportedAlgorithm {
            alg: SmolStr::new(alg),
        }),
    }
}
/// Parses `token`, verifies its signature with `public_key`, and returns the
/// claims as an owned value.
///
/// This function only parses the token and checks the signature. It does not
/// check expiration, audience, or method; callers that need those checks must
/// run them on the returned claims (see [`ServiceAuthClaims::validate`] and
/// [`ServiceAuthClaims::require_method`]).
///
/// # Errors
///
/// Propagates any error from [`parse_jwt`] or [`verify_signature`].
pub fn verify_service_jwt(
    token: &str,
    public_key: &PublicKey,
) -> Result<ServiceAuthClaims<'static>, ServiceAuthError> {
    let jwt = parse_jwt(token)?;
    verify_signature(&jwt, public_key).map(|()| jwt.into_claims())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds claims with a fixed issuer and audience, no `jti`, and the given
    /// timestamps and method.
    fn claims_with(
        exp: i64,
        iat: i64,
        lxm: Option<Nsid<'static>>,
    ) -> ServiceAuthClaims<'static> {
        ServiceAuthClaims {
            iss: Did::new("did:plc:test").unwrap(),
            aud: Did::new("did:web:example.com").unwrap(),
            exp,
            iat,
            jti: None,
            lxm,
        }
    }

    /// A token with more than three dot-separated parts is rejected as malformed.
    #[test]
    fn test_parse_jwt_invalid_format() {
        assert!(matches!(
            parse_jwt("not.a.valid.jwt.with.too.many.parts"),
            Err(ServiceAuthError::MalformedToken(_))
        ));
    }

    /// `is_expired` is true for a past `exp` and false for a future one.
    #[test]
    fn test_claims_expiration() {
        let now = chrono::Utc::now().timestamp();

        let expired_claims = claims_with(now - 100, now - 200, None);
        assert!(expired_claims.is_expired());

        let valid_claims = claims_with(now + 100, now, None);
        assert!(!valid_claims.is_expired());
    }

    /// `validate` accepts the matching audience and rejects a different one.
    #[test]
    fn test_audience_validation() {
        let now = chrono::Utc::now().timestamp();
        let claims = claims_with(now + 100, now, None);

        let expected_aud = Did::new("did:web:example.com").unwrap();
        assert!(claims.validate(&expected_aud).is_ok());

        let wrong_aud = Did::new("did:web:wrong.com").unwrap();
        let outcome = claims.validate(&wrong_aud);
        assert!(matches!(
            outcome,
            Err(ServiceAuthError::AudienceMismatch { .. })
        ));
    }

    /// `check_method` matches only the NSID carried in the `lxm` claim.
    #[test]
    fn test_method_check() {
        let claims = claims_with(
            chrono::Utc::now().timestamp() + 100,
            chrono::Utc::now().timestamp(),
            Some(Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap()),
        );

        let expected = Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap();
        assert!(claims.check_method(&expected));

        let wrong = Nsid::new("app.bsky.feed.getTimeline").unwrap();
        assert!(!claims.check_method(&wrong));
    }
}