use jsonwebtoken::{Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use super::scope::{ScopeParseError, SmartScopeSet};
#[derive(Debug, Error)]
pub enum TokenError {
#[error("JWT decode/verify failed: {0}")]
Jwt(#[from] jsonwebtoken::errors::Error),
#[error("scope claim malformed: {0}")]
ScopeParse(#[from] ScopeParseError),
}
/// Claims carried by a SMART-on-FHIR access token, decoded from the JWT payload.
///
/// Registered JWT claims (`sub`, `iss`, `aud`, `exp`, `iat`, `nbf`) sit next to
/// the SMART launch-context claims (`scope`, `patient`, `encounter`,
/// `fhirUser`). Every other claim in the payload is kept verbatim in
/// [`AccessToken::extras`].
///
/// Optional claims that are `None` are omitted when serializing, rather than
/// being written as JSON `null`, so a re-serialized token only contains the
/// claims it actually carries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessToken {
    /// Subject of the token; required.
    pub sub: String,
    /// Issuer, if present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,
    /// Audience, kept as raw JSON because RFC 7519 allows either a single
    /// string or an array of strings here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub aud: Option<serde_json::Value>,
    /// Expiry as a JWT NumericDate (seconds since the Unix epoch); required.
    pub exp: i64,
    /// Issued-at time (NumericDate), if present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub iat: Option<i64>,
    /// Not-before time (NumericDate), if present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub nbf: Option<i64>,
    /// Raw SMART scope string; parse it with [`AccessToken::scope_set`]. Required.
    pub scope: String,
    /// Launch-context patient id, if present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub patient: Option<String>,
    /// Launch-context encounter id, if present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub encounter: Option<String>,
    /// The `fhirUser` claim (a FHIR resource reference), if present.
    #[serde(rename = "fhirUser", default, skip_serializing_if = "Option::is_none")]
    pub fhir_user: Option<String>,
    /// All claims not captured by a named field above, preserved as raw JSON.
    #[serde(flatten)]
    pub extras: std::collections::BTreeMap<String, serde_json::Value>,
}
impl AccessToken {
    /// Parses the `scope` claim into a [`SmartScopeSet`].
    ///
    /// Parsing is done on demand, so decoding a token never fails because of
    /// the scope string alone.
    ///
    /// # Errors
    ///
    /// Propagates the [`ScopeParseError`] returned by `SmartScopeSet::parse`.
    pub fn scope_set(&self) -> Result<SmartScopeSet, ScopeParseError> {
        let scope = &self.scope;
        SmartScopeSet::parse(scope)
    }

    /// Copies the SMART launch-context claims (`patient`, `encounter`,
    /// `fhirUser`) out of the token into a `LaunchContext`.
    pub fn launch_context(&self) -> super::context::LaunchContext {
        use super::context::LaunchContext;

        let Self {
            patient,
            encounter,
            fhir_user,
            ..
        } = self;

        LaunchContext {
            patient_id: patient.clone(),
            encounter_id: encounter.clone(),
            fhir_user: fhir_user.clone(),
        }
    }
}
pub struct TokenValidator {
key: DecodingKey,
validation: Validation,
}
impl TokenValidator {
pub fn rs256_from_pem(pem: &[u8]) -> Result<Self, TokenError> {
Ok(Self {
key: DecodingKey::from_rsa_pem(pem)?,
validation: Validation::new(Algorithm::RS256),
})
}
pub fn es256_from_pem(pem: &[u8]) -> Result<Self, TokenError> {
Ok(Self {
key: DecodingKey::from_ec_pem(pem)?,
validation: Validation::new(Algorithm::ES256),
})
}
pub fn hs256_from_secret(secret: &[u8]) -> Self {
Self {
key: DecodingKey::from_secret(secret),
validation: Validation::new(Algorithm::HS256),
}
}
pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
self.validation.set_audience(&[audience.into()]);
self
}
pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
self.validation.set_issuer(&[issuer.into()]);
self
}
pub fn decode(&self, jwt: &str) -> Result<AccessToken, TokenError> {
let token = jsonwebtoken::decode::<AccessToken>(jwt, &self.key, &self.validation)?;
Ok(token.claims)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};

    /// HMAC secret shared by every test token.
    const SECRET: &[u8] = b"unit-test-secret";

    /// Current Unix time in whole seconds.
    fn now() -> i64 {
        chrono::Utc::now().timestamp()
    }

    /// Signs `claims` with HS256 under `secret`. Panics if signing fails,
    /// which would mean the test setup itself is broken.
    fn encode_hs256(claims: &serde_json::Value, secret: &[u8]) -> String {
        let header = Header::new(Algorithm::HS256);
        let signing_key = EncodingKey::from_secret(secret);
        encode(&header, claims, &signing_key).expect("HS256 sign")
    }

    #[test]
    fn round_trips_smart_claims_via_hs256() {
        let issued_at = now();
        let payload = serde_json::json!({
            "sub": "alice@example.org",
            "iss": "https://auth.example.org",
            "aud": "https://fhir.example.org",
            "exp": issued_at + 3600,
            "iat": issued_at,
            "scope": "openid fhirUser patient/Observation.read",
            "patient": "alice-001",
            "fhirUser": "Practitioner/dr-jones"
        });
        let jwt = encode_hs256(&payload, SECRET);

        let validator = TokenValidator::hs256_from_secret(SECRET)
            .with_audience("https://fhir.example.org")
            .with_issuer("https://auth.example.org");
        let token = validator.decode(&jwt).unwrap();

        assert_eq!(token.sub, "alice@example.org");
        assert_eq!(token.patient.as_deref(), Some("alice-001"));
        assert_eq!(token.fhir_user.as_deref(), Some("Practitioner/dr-jones"));
        assert_eq!(token.scope_set().unwrap().0.len(), 3);
    }

    #[test]
    fn rejects_audience_mismatch() {
        let payload = serde_json::json!({
            "sub": "x",
            "exp": now() + 3600,
            "aud": "https://fhir.example.org",
            "scope": "openid",
        });
        let jwt = encode_hs256(&payload, SECRET);

        let validator = TokenValidator::hs256_from_secret(SECRET)
            .with_audience("https://different-server.example.org");
        let err = validator
            .decode(&jwt)
            .expect_err("token for another audience must be rejected");

        assert!(matches!(err, TokenError::Jwt(_)));
    }

    #[test]
    fn rejects_expired_token() {
        let payload = serde_json::json!({
            "sub": "x",
            "exp": 1,
            "scope": "openid",
        });
        let jwt = encode_hs256(&payload, SECRET);

        let err = TokenValidator::hs256_from_secret(SECRET)
            .decode(&jwt)
            .expect_err("expired token must be rejected");

        assert!(matches!(err, TokenError::Jwt(_)));
    }

    #[test]
    fn launch_context_extracted_from_claims() {
        let payload = serde_json::json!({
            "sub": "x",
            "exp": now() + 3600,
            "scope": "patient/*.read",
            "patient": "alice-001",
            "encounter": "enc-9",
        });
        let jwt = encode_hs256(&payload, SECRET);

        let context = TokenValidator::hs256_from_secret(SECRET)
            .decode(&jwt)
            .unwrap()
            .launch_context();

        assert_eq!(context.patient_id.as_deref(), Some("alice-001"));
        assert_eq!(context.encounter_id.as_deref(), Some("enc-9"));
    }
}