use crate::CowStr;
use crate::IntoStatic;
use crate::bos::{BosStr, DefaultStr};
use crate::types::string::{Did, DidService, Nsid};
use alloc::string::String;
use alloc::vec::Vec;
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use serde::{Deserialize, Serialize};
use signature::Verifier;
use smol_str::SmolStr;
use smol_str::format_smolstr;
use thiserror::Error;
#[cfg(feature = "crypto-p256")]
use p256::ecdsa::{Signature as P256Signature, VerifyingKey as P256VerifyingKey};
#[cfg(feature = "crypto-k256")]
use k256::ecdsa::{Signature as K256Signature, VerifyingKey as K256VerifyingKey};
/// Errors produced while parsing, verifying, or validating a service-auth JWT.
#[derive(Debug, Error, miette::Diagnostic)]
#[non_exhaustive]
pub enum ServiceAuthError {
    /// The token is structurally malformed (for example, it does not consist of
    /// exactly three dot-separated segments).
    #[error("malformed JWT: {0}")]
    MalformedToken(CowStr<'static>),
    /// A token segment could not be base64-decoded (in this module, segments are
    /// decoded as unpadded URL-safe base64).
    #[error("base64 decode error: {0}")]
    Base64Decode(#[from] base64::DecodeError),
    /// The header or claims segment was not valid JSON of the expected shape.
    #[error("JSON parsing error: {0}")]
    JsonParse(#[from] serde_json::Error),
    /// The signature was well-formed but did not verify against the public key.
    #[error("invalid signature")]
    InvalidSignature,
    /// The header `alg` is not supported, does not match the kind of public key
    /// supplied, or belongs to a crypto feature that is not enabled.
    #[error("unsupported algorithm: {alg}")]
    UnsupportedAlgorithm {
        /// The `alg` value taken from the JWT header.
        alg: SmolStr,
    },
    /// The token's `exp` is at or before the current time.
    #[error("token expired at {exp} (current time: {now})")]
    Expired {
        /// The token's `exp` claim, in Unix seconds.
        exp: i64,
        /// The current time used for the comparison, in Unix seconds.
        now: i64,
    },
    /// The DID part of the token's `aud` is not the expected DID.
    #[error("audience mismatch: expected {expected}, got {actual}")]
    AudienceMismatch {
        /// The DID the verifier expected.
        expected: Did,
        /// The full `aud` value from the token, including any `#service` fragment.
        actual: DidService,
    },
    /// The `#service` fragment of `aud` is not in the allow-list.
    #[error("service id mismatch: allowed {allowed:?}, got {actual:?}")]
    ServiceIdMismatch {
        /// The service ids the caller allowed.
        allowed: Vec<SmolStr>,
        /// The service fragment found in the token, if any.
        actual: Option<SmolStr>,
    },
    /// The token's `lxm` claim is absent or names a different method.
    #[error("method mismatch: expected {expected}, got {actual:?}")]
    MethodMismatch {
        /// The method NSID the caller required.
        expected: Nsid,
        /// The token's `lxm` claim, if present.
        actual: Option<Nsid>,
    },
    /// A required field was missing.
    // NOTE(review): not constructed by any code visible in this module;
    // presumably reserved for callers or future checks — confirm.
    #[error("missing required field: {0}")]
    MissingField(&'static str),
    /// A key or signature could not be decoded by the underlying crypto library.
    #[error("crypto error: {0}")]
    Crypto(CowStr<'static>),
}
/// Decoded JOSE header of a compact-serialized JWT.
///
/// Neither field is optional, so deserializing a header that lacks either
/// `alg` or `typ` fails.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtHeader<S: BosStr = DefaultStr> {
    /// Signature algorithm. `verify_signature` accepts `"ES256"` (P-256) and
    /// `"ES256K"` (K-256).
    pub alg: S,
    /// Token type header value, exactly as found in the token.
    pub typ: S,
}
impl<S> IntoStatic for JwtHeader<S>
where
    S: BosStr + IntoStatic,
    S::Output: BosStr,
{
    type Output = JwtHeader<S::Output>;

    /// Detaches both header strings from any borrowed input, yielding a
    /// header whose string type is `S::Output`.
    fn into_static(self) -> Self::Output {
        let JwtHeader { alg, typ } = self;
        let alg = alg.into_static();
        let typ = typ.into_static();
        JwtHeader { alg, typ }
    }
}
/// Claims carried in the payload of a service-auth JWT.
///
/// Optional claims are omitted from the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceAuthClaims<S: BosStr = DefaultStr> {
    /// Issuer DID.
    // NOTE(review): presumably the DID whose key signed the token; nothing in
    // this module checks `iss` against the verifying key — callers must.
    pub iss: Did<S>,
    /// Audience: a DID, optionally followed by a `#service` fragment
    /// (e.g. `did:web:example.com#bsky_appview`). Checked by `validate`.
    pub aud: DidService<S>,
    /// Expiry time in Unix seconds; compared with
    /// `chrono::Utc::now().timestamp()` by `is_expired`.
    pub exp: i64,
    /// Issued-at time.
    // NOTE(review): not inspected by this module; presumably Unix seconds like
    // `exp` — confirm.
    pub iat: i64,
    /// Optional token identifier; not inspected by this module.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<S>,
    /// Optional NSID of the method this token is scoped to; see
    /// `check_method` / `require_method`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lxm: Option<Nsid<S>>,
}
impl<S> IntoStatic for ServiceAuthClaims<S>
where
    S: BosStr + IntoStatic,
    S::Output: BosStr,
{
    type Output = ServiceAuthClaims<S::Output>;

    /// Converts every string-bearing claim to its `S::Output` form; the
    /// numeric timestamps are carried over unchanged.
    fn into_static(self) -> Self::Output {
        let ServiceAuthClaims {
            iss,
            aud,
            exp,
            iat,
            jti,
            lxm,
        } = self;
        ServiceAuthClaims {
            iss: iss.into_static(),
            aud: aud.into_static(),
            exp,
            iat,
            jti: jti.map(|token_id| token_id.into_static()),
            lxm: lxm.map(|method| method.into_static()),
        }
    }
}
impl<S: BosStr> ServiceAuthClaims<S> {
    /// Validates the audience, optional service id, and expiry of these claims.
    ///
    /// The checks run in this order:
    /// 1. The DID part of `aud` must equal `expected_aud`.
    /// 2. If `allowed_services` is non-empty and `aud` carries a `#service`
    ///    fragment, that fragment must be one of `allowed_services`. An `aud`
    ///    without a fragment is accepted whatever `allowed_services` contains.
    /// 3. The token must not be expired at the current time
    ///    (see [`Self::is_expired_at`]).
    ///
    /// This does not verify the signature. It also does not check `lxm`; use
    /// [`Self::require_method`] for that.
    ///
    /// # Errors
    ///
    /// Returns [`ServiceAuthError::AudienceMismatch`],
    /// [`ServiceAuthError::ServiceIdMismatch`] or [`ServiceAuthError::Expired`]
    /// for the first check that fails.
    pub fn validate<B, Svc>(
        &self,
        expected_aud: &Did<B>,
        allowed_services: &[Svc],
    ) -> Result<(), ServiceAuthError>
    where
        B: BosStr,
        Svc: AsRef<str>,
    {
        if self.aud.audience().as_str() != expected_aud.as_str() {
            return Err(ServiceAuthError::AudienceMismatch {
                expected: expected_aud.borrow().into_static(),
                // `self.aud` is already a validated `DidService`, so re-parsing
                // its string form into an owned value cannot fail.
                actual: DidService::new_owned(self.aud.as_str())
                    .expect("an already-validated DidService must re-parse"),
            });
        }
        if !allowed_services.is_empty() {
            if let Some(service) = self.aud.service() {
                if !allowed_services
                    .iter()
                    .any(|allowed| allowed.as_ref() == service)
                {
                    return Err(ServiceAuthError::ServiceIdMismatch {
                        allowed: allowed_services
                            .iter()
                            .map(|allowed| SmolStr::new(allowed.as_ref()))
                            .collect(),
                        actual: Some(SmolStr::new(service)),
                    });
                }
            }
        }
        // Read the clock once so the `now` reported in the error is the same
        // instant the expiry decision was made against.
        let now = chrono::Utc::now().timestamp();
        if self.is_expired_at(now) {
            return Err(ServiceAuthError::Expired { exp: self.exp, now });
        }
        Ok(())
    }

    /// Returns `true` if the token has expired as of the current wall-clock time.
    pub fn is_expired(&self) -> bool {
        self.is_expired_at(chrono::Utc::now().timestamp())
    }

    /// Returns `true` if the token is expired at `now` (Unix seconds).
    ///
    /// A token is treated as expired from its `exp` second onward
    /// (`exp <= now`).
    pub fn is_expired_at(&self, now: i64) -> bool {
        self.exp <= now
    }

    /// Returns `true` if the token's `lxm` claim is present and equals `nsid`.
    pub fn check_method(&self, nsid: &Nsid<CowStr<'_>>) -> bool {
        self.lxm
            .as_ref()
            .is_some_and(|lxm| lxm.as_str() == nsid.as_str())
    }

    /// Requires the token's `lxm` claim to be present and equal to `nsid`.
    ///
    /// # Errors
    ///
    /// Returns [`ServiceAuthError::MethodMismatch`] if `lxm` is absent or names
    /// a different method.
    pub fn require_method(&self, nsid: &Nsid<CowStr<'_>>) -> Result<(), ServiceAuthError> {
        if self.check_method(nsid) {
            return Ok(());
        }
        // Both values are already-validated NSIDs, so re-parsing their string
        // forms into owned `Nsid`s cannot fail. This safe re-parse (cold error
        // path only) replaces an `unsafe` unchecked construction.
        Err(ServiceAuthError::MethodMismatch {
            expected: Nsid::new_owned(nsid.as_str())
                .expect("an already-validated Nsid must re-parse"),
            actual: self.lxm.as_ref().map(|lxm| {
                Nsid::new_owned(lxm.as_str()).expect("an already-validated Nsid must re-parse")
            }),
        })
    }
}
/// A compact JWT split into its decoded parts, with the signature not yet checked.
///
/// Produced by `parse_jwt`. The raw signing input is kept so that the
/// signature can later be verified over the exact bytes that were signed.
#[derive(Debug, Clone)]
pub struct ParsedJwt<S: BosStr = DefaultStr> {
    /// Decoded JOSE header.
    header: JwtHeader,
    /// Decoded payload claims.
    claims: ServiceAuthClaims<S>,
    /// The `header.payload` text of the original token: the bytes covered by
    /// the signature.
    signing_input: String,
    /// Raw signature bytes, decoded from the third segment.
    signature: Vec<u8>,
}
impl<S: BosStr> ParsedJwt<S> {
    /// The bytes the signature covers: the original base64url header and
    /// payload segments joined by a single `.`.
    pub fn signing_input(&self) -> &[u8] {
        self.signing_input.as_str().as_bytes()
    }

    /// Borrows the decoded JOSE header.
    pub fn header(&self) -> &JwtHeader {
        &self.header
    }

    /// Borrows the decoded claims. They have not been validated.
    pub fn claims(&self) -> &ServiceAuthClaims<S> {
        &self.claims
    }

    /// The raw signature bytes decoded from the third token segment.
    pub fn signature(&self) -> &[u8] {
        self.signature.as_slice()
    }

    /// Consumes the parsed token and returns only its header.
    pub fn into_header(self) -> JwtHeader {
        let Self { header, .. } = self;
        header
    }

    /// Consumes the parsed token and returns only its claims.
    pub fn into_claims(self) -> ServiceAuthClaims<S> {
        let Self { claims, .. } = self;
        claims
    }
}
/// Splits a compact-serialized JWT into its header, claims and signature.
///
/// Each of the three dot-separated segments is decoded as unpadded URL-safe
/// base64. The header and payload are then parsed as JSON. The signature is
/// **not** verified here; use `verify_signature` or `verify_service_jwt`.
///
/// # Errors
///
/// - [`ServiceAuthError::MalformedToken`] if `token` does not have exactly three
///   dot-separated segments.
/// - [`ServiceAuthError::Base64Decode`] if a segment is not valid unpadded
///   URL-safe base64. All segments are decoded before any JSON is parsed.
/// - [`ServiceAuthError::JsonParse`] if the header or claims JSON is invalid.
pub fn parse_jwt(token: &str) -> Result<ParsedJwt, ServiceAuthError> {
    // Non-capturing closure, so it is `Copy` and can be passed more than once.
    let malformed = || {
        ServiceAuthError::MalformedToken(CowStr::new_static(
            "JWT must have exactly 3 parts separated by dots",
        ))
    };
    // The signing input is exactly the token text before the last dot
    // (`header.payload`). Borrowing it avoids a `Vec` of parts and
    // re-assembling the string with `format!`.
    let (signing_input, signature_b64) = token.rsplit_once('.').ok_or_else(malformed)?;
    let (header_b64, payload_b64) = signing_input.split_once('.').ok_or_else(malformed)?;
    // A dot left in the middle segment means there were more than three parts.
    if payload_b64.contains('.') {
        return Err(malformed());
    }
    let header_buf = URL_SAFE_NO_PAD.decode(header_b64)?;
    let payload_buf = URL_SAFE_NO_PAD.decode(payload_b64)?;
    let signature = URL_SAFE_NO_PAD.decode(signature_b64)?;
    let header: JwtHeader = serde_json::from_slice(&header_buf)?;
    let claims: ServiceAuthClaims = serde_json::from_slice(&payload_buf)?;
    Ok(ParsedJwt {
        header,
        claims,
        signing_input: String::from(signing_input),
        signature,
    })
}
/// A public key used to verify service-auth JWT signatures.
///
/// Which variants exist depends on the enabled crate features: `P256` requires
/// `crypto-p256` and `K256` requires `crypto-k256`. With neither feature
/// enabled the enum has no variants.
#[derive(Debug, Clone)]
pub enum PublicKey {
    /// NIST P-256 key; pairs with JWT `alg` `"ES256"` in `verify_signature`.
    #[cfg(feature = "crypto-p256")]
    P256(P256VerifyingKey),
    /// secp256k1 key; pairs with JWT `alg` `"ES256K"` in `verify_signature`.
    #[cfg(feature = "crypto-k256")]
    K256(K256VerifyingKey),
}
impl PublicKey {
    /// Builds a P-256 key from SEC1-encoded public-key bytes.
    ///
    /// # Errors
    ///
    /// Returns [`ServiceAuthError::Crypto`] if the bytes are not a valid P-256 point.
    #[cfg(feature = "crypto-p256")]
    pub fn from_p256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> {
        P256VerifyingKey::from_sec1_bytes(bytes)
            .map(PublicKey::P256)
            .map_err(|err| {
                ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
                    "invalid P-256 key: {}",
                    err
                )))
            })
    }

    /// Builds a secp256k1 key from SEC1-encoded public-key bytes.
    ///
    /// # Errors
    ///
    /// Returns [`ServiceAuthError::Crypto`] if the bytes are not a valid K-256 point.
    #[cfg(feature = "crypto-k256")]
    pub fn from_k256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> {
        K256VerifyingKey::from_sec1_bytes(bytes)
            .map(PublicKey::K256)
            .map_err(|err| {
                ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
                    "invalid K-256 key: {}",
                    err
                )))
            })
    }
}
/// Verifies the signature of `parsed` over its signing input using `public_key`.
///
/// The header `alg` must pair with the key type: `"ES256"` with
/// [`PublicKey::P256`] or `"ES256K"` with [`PublicKey::K256`]. Verification is
/// done through `signature::Verifier`. The claims are not inspected.
///
/// # Errors
///
/// - [`ServiceAuthError::UnsupportedAlgorithm`] for any other `alg`/key
///   combination, including a supported `alg` paired with the wrong key type
///   or a disabled crypto feature.
/// - [`ServiceAuthError::Crypto`] if the signature bytes are not a valid
///   fixed-size ECDSA signature.
/// - [`ServiceAuthError::InvalidSignature`] if verification fails.
pub fn verify_signature(
    parsed: &ParsedJwt,
    public_key: &PublicKey,
) -> Result<(), ServiceAuthError> {
    let alg = parsed.header().alg.as_str();
    // The signing input and signature are read inside each arm. Binding them up
    // front left unused locals (and warnings) when no crypto feature is enabled.
    match (alg, public_key) {
        #[cfg(feature = "crypto-p256")]
        ("ES256", PublicKey::P256(key)) => {
            let sig = P256Signature::from_slice(parsed.signature()).map_err(|e| {
                ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
                    "invalid ES256 signature: {}",
                    e
                )))
            })?;
            key.verify(parsed.signing_input(), &sig)
                .map_err(|_| ServiceAuthError::InvalidSignature)
        }
        #[cfg(feature = "crypto-k256")]
        ("ES256K", PublicKey::K256(key)) => {
            let sig = K256Signature::from_slice(parsed.signature()).map_err(|e| {
                ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
                    "invalid ES256K signature: {}",
                    e
                )))
            })?;
            key.verify(parsed.signing_input(), &sig)
                .map_err(|_| ServiceAuthError::InvalidSignature)
        }
        _ => Err(ServiceAuthError::UnsupportedAlgorithm {
            alg: SmolStr::new(alg),
        }),
    }
}
pub fn verify_service_jwt(
token: &str,
public_key: &PublicKey,
) -> Result<ServiceAuthClaims, ServiceAuthError> {
let parsed = parse_jwt(token)?;
verify_signature(&parsed, public_key)?;
Ok(parsed.into_claims())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_jwt_invalid_format() {
        let result = parse_jwt("not.a.valid.jwt.with.too.many.parts");
        assert!(matches!(result, Err(ServiceAuthError::MalformedToken(_))));
        // Too few segments is just as malformed as too many.
        assert!(matches!(
            parse_jwt("only.two"),
            Err(ServiceAuthError::MalformedToken(_))
        ));
        assert!(matches!(
            parse_jwt(""),
            Err(ServiceAuthError::MalformedToken(_))
        ));
    }

    #[test]
    fn test_parse_jwt_segments() {
        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"ES256K","typ":"JWT"}"#);
        let payload = URL_SAFE_NO_PAD.encode(
            br#"{"iss":"did:plc:test","aud":"did:web:example.com","exp":2000000000,"iat":1000000000}"#,
        );
        let sig = URL_SAFE_NO_PAD.encode([1u8, 2, 3]);
        let token = format!("{}.{}.{}", header, payload, sig);

        let parsed = parse_jwt(&token).unwrap();
        assert_eq!(parsed.header().alg.as_str(), "ES256K");
        assert_eq!(parsed.header().typ.as_str(), "JWT");
        // The signing input must be exactly the first two segments as they
        // appeared in the token.
        assert_eq!(
            parsed.signing_input(),
            format!("{}.{}", header, payload).as_bytes()
        );
        assert_eq!(parsed.signature(), &[1u8, 2, 3]);
        assert_eq!(parsed.claims().iss.as_str(), "did:plc:test");
        assert_eq!(parsed.claims().exp, 2000000000);
        assert!(parsed.claims().lxm.is_none());
    }

    #[test]
    fn test_claims_expiration() {
        let now = chrono::Utc::now().timestamp();
        let expired_claims: ServiceAuthClaims = ServiceAuthClaims {
            iss: Did::new_static("did:plc:test").unwrap(),
            aud: DidService::new_static("did:web:example.com").unwrap(),
            exp: now - 100,
            iat: now - 200,
            jti: None,
            lxm: None,
        };
        assert!(expired_claims.is_expired());
        let valid_claims: ServiceAuthClaims = ServiceAuthClaims {
            iss: Did::new_static("did:plc:test").unwrap(),
            aud: DidService::new_static("did:web:example.com").unwrap(),
            exp: now + 100,
            iat: now,
            jti: None,
            lxm: None,
        };
        assert!(!valid_claims.is_expired());
    }

    /// Builds claims issued by `did:plc:test` for `aud` that expire 100 seconds from now.
    fn claims_with_aud(aud: &str) -> ServiceAuthClaims {
        ServiceAuthClaims {
            iss: Did::new_static("did:plc:test").unwrap(),
            aud: DidService::new_owned(aud).unwrap(),
            exp: chrono::Utc::now().timestamp() + 100,
            iat: chrono::Utc::now().timestamp(),
            jti: None,
            lxm: None,
        }
    }

    #[test]
    fn test_audience_validation() {
        let expected_aud = Did::new("did:web:example.com").unwrap();
        assert!(
            claims_with_aud("did:web:example.com")
                .validate(&expected_aud, &[] as &[&str])
                .is_ok()
        );
        assert!(
            claims_with_aud("did:web:example.com#bsky_appview")
                .validate(&expected_aud, &[] as &[&str])
                .is_ok()
        );
        // An audience without a service fragment is accepted even when an
        // allow-list is supplied.
        assert!(
            claims_with_aud("did:web:example.com")
                .validate(&expected_aud, &["bsky_appview"])
                .is_ok()
        );
        assert!(
            claims_with_aud("did:web:example.com#bsky_appview")
                .validate(&expected_aud, &["bsky_appview"])
                .is_ok()
        );
        assert!(matches!(
            claims_with_aud("did:web:example.com#other").validate(&expected_aud, &["bsky_appview"]),
            Err(ServiceAuthError::ServiceIdMismatch { .. })
        ));
        let wrong_aud = Did::new("did:web:wrong.com").unwrap();
        assert!(matches!(
            claims_with_aud("did:web:example.com#bsky_appview")
                .validate(&wrong_aud, &["bsky_appview"]),
            Err(ServiceAuthError::AudienceMismatch { .. })
        ));
    }

    #[test]
    fn test_validate_rejects_expired() {
        let mut claims = claims_with_aud("did:web:example.com");
        let exp = chrono::Utc::now().timestamp() - 10;
        claims.exp = exp;
        let expected_aud = Did::new("did:web:example.com").unwrap();
        match claims.validate(&expected_aud, &[] as &[&str]) {
            Err(ServiceAuthError::Expired { exp: got, now }) => {
                assert_eq!(got, exp);
                assert!(now >= exp);
            }
            other => panic!("expected Expired, got {:?}", other),
        }
    }

    #[test]
    fn test_method_check() {
        let claims: ServiceAuthClaims = ServiceAuthClaims {
            iss: Did::new_static("did:plc:test").unwrap(),
            aud: DidService::new_static("did:web:example.com").unwrap(),
            exp: chrono::Utc::now().timestamp() + 100,
            iat: chrono::Utc::now().timestamp(),
            jti: None,
            lxm: Some(Nsid::new_static("app.bsky.feed.getFeedSkeleton").unwrap()),
        };
        let expected = Nsid::new_static("app.bsky.feed.getFeedSkeleton").unwrap();
        assert!(claims.check_method(&expected));
        let wrong = Nsid::new_static("app.bsky.feed.getTimeline").unwrap();
        assert!(!claims.check_method(&wrong));
    }

    #[test]
    fn test_require_method() {
        let mut claims = claims_with_aud("did:web:example.com");
        let wanted = Nsid::new_static("app.bsky.feed.getFeedSkeleton").unwrap();
        // No `lxm` claim at all.
        assert!(matches!(
            claims.require_method(&wanted),
            Err(ServiceAuthError::MethodMismatch { actual: None, .. })
        ));
        // An `lxm` naming a different method.
        claims.lxm = Some(Nsid::new_static("app.bsky.feed.getTimeline").unwrap());
        assert!(matches!(
            claims.require_method(&wanted),
            Err(ServiceAuthError::MethodMismatch {
                actual: Some(_),
                ..
            })
        ));
        // The matching method.
        claims.lxm = Some(Nsid::new_static("app.bsky.feed.getFeedSkeleton").unwrap());
        assert!(claims.require_method(&wanted).is_ok());
    }
}