//! `meld_core/auth.rs` — validation of HS256-signed bearer JWTs.

use std::fmt;

use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use thiserror::Error;
4
/// Settings that control how a bearer JWT is verified.
///
/// `Debug` is implemented by hand (not derived) so that the HMAC secret is
/// never written to logs or panic messages.
#[derive(Clone)]
pub struct JwtValidationConfig {
    /// Shared secret used to verify the token's HS256 signature.
    pub secret: String,
    /// When set, the token's `iss` claim must equal this value exactly.
    pub expected_issuer: Option<String>,
    /// When set, the token's `aud` claim must contain this value.
    pub expected_audience: Option<String>,
}

impl fmt::Debug for JwtValidationConfig {
    /// Formats the configuration with the `secret` field redacted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("JwtValidationConfig")
            .field("secret", &"<redacted>")
            .field("expected_issuer", &self.expected_issuer)
            .field("expected_audience", &self.expected_audience)
            .finish()
    }
}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct AuthPrincipal {
14    pub subject: String,
15    pub issuer: Option<String>,
16    pub audience: Vec<String>,
17    pub scopes: Vec<String>,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct JwtClaims {
22    pub sub: String,
23    pub exp: usize,
24    #[serde(default)]
25    pub iss: Option<String>,
26    #[serde(default)]
27    pub aud: Option<AudienceClaim>,
28    #[serde(default)]
29    pub scope: Option<String>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33#[serde(untagged)]
34pub enum AudienceClaim {
35    One(String),
36    Many(Vec<String>),
37}
38
39impl AudienceClaim {
40    fn into_vec(self) -> Vec<String> {
41        match self {
42            Self::One(value) => vec![value],
43            Self::Many(values) => values,
44        }
45    }
46}
47
48#[derive(Debug, Error)]
49pub enum AuthError {
50    #[error("invalid token: {0}")]
51    InvalidToken(String),
52    #[error("issuer mismatch")]
53    IssuerMismatch,
54    #[error("audience mismatch")]
55    AudienceMismatch,
56}
57
58pub fn validate_bearer_jwt(
59    token: &str,
60    cfg: &JwtValidationConfig,
61) -> Result<AuthPrincipal, AuthError> {
62    let mut validation = Validation::new(Algorithm::HS256);
63    validation.validate_exp = true;
64    validation.validate_aud = false;
65    validation
66        .required_spec_claims
67        .extend(["sub".to_string(), "exp".to_string()]);
68
69    let token_data = decode::<JwtClaims>(
70        token,
71        &DecodingKey::from_secret(cfg.secret.as_bytes()),
72        &validation,
73    )
74    .map_err(|err| AuthError::InvalidToken(err.to_string()))?;
75
76    let claims = token_data.claims;
77    if let Some(expected) = cfg.expected_issuer.as_deref() {
78        if claims.iss.as_deref() != Some(expected) {
79            return Err(AuthError::IssuerMismatch);
80        }
81    }
82
83    let audience = claims.aud.map(AudienceClaim::into_vec).unwrap_or_default();
84    if let Some(expected) = cfg.expected_audience.as_deref() {
85        if !audience.iter().any(|value| value == expected) {
86            return Err(AuthError::AudienceMismatch);
87        }
88    }
89
90    let scopes = claims
91        .scope
92        .unwrap_or_default()
93        .split_whitespace()
94        .map(str::to_string)
95        .collect::<Vec<_>>();
96
97    Ok(AuthPrincipal {
98        subject: claims.sub,
99        issuer: claims.iss,
100        audience,
101        scopes,
102    })
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};

    const SECRET: &str = "dev-secret";
    const ISSUER: &str = "https://issuer.local";
    const AUDIENCE: &str = "meld-api";

    /// Signs `claims` with HS256 using `secret`.
    fn issue_token(secret: &str, claims: &JwtClaims) -> String {
        encode(
            &Header::new(Algorithm::HS256),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .expect("token should encode")
    }

    /// A claim set that is valid for `ISSUER` / `AUDIENCE` and expires far in
    /// the future.
    fn base_claims() -> JwtClaims {
        JwtClaims {
            sub: "user-1".to_string(),
            exp: 4_102_444_800,
            iss: Some(ISSUER.to_string()),
            aud: Some(AudienceClaim::One(AUDIENCE.to_string())),
            scope: Some("read:notes write:notes".to_string()),
        }
    }

    /// Builds a validation config that uses `SECRET` and the given expectations.
    fn config(issuer: Option<&str>, audience: Option<&str>) -> JwtValidationConfig {
        JwtValidationConfig {
            secret: SECRET.to_string(),
            expected_issuer: issuer.map(str::to_string),
            expected_audience: audience.map(str::to_string),
        }
    }

    #[test]
    fn validates_token_and_maps_principal() {
        let token = issue_token(SECRET, &base_claims());
        let cfg = config(Some(ISSUER), Some(AUDIENCE));

        let principal = validate_bearer_jwt(&token, &cfg).expect("token should validate");
        assert_eq!(principal.subject, "user-1");
        assert_eq!(principal.issuer.as_deref(), Some(ISSUER));
        assert_eq!(principal.audience, vec![AUDIENCE.to_string()]);
        assert_eq!(
            principal.scopes,
            vec!["read:notes".to_string(), "write:notes".to_string()]
        );
    }

    #[test]
    fn rejects_issuer_mismatch() {
        let token = issue_token(SECRET, &base_claims());
        let cfg = config(Some("https://other-issuer.local"), None);

        let err = validate_bearer_jwt(&token, &cfg).expect_err("issuer mismatch should fail");
        assert!(matches!(err, AuthError::IssuerMismatch));
    }

    #[test]
    fn rejects_audience_mismatch() {
        let token = issue_token(SECRET, &base_claims());
        let cfg = config(None, Some("other-aud"));

        let err = validate_bearer_jwt(&token, &cfg).expect_err("audience mismatch should fail");
        assert!(matches!(err, AuthError::AudienceMismatch));
    }

    #[test]
    fn rejects_token_signed_with_other_secret() {
        let token = issue_token("some-other-secret", &base_claims());
        let cfg = config(None, None);

        let err = validate_bearer_jwt(&token, &cfg).expect_err("bad signature should fail");
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn rejects_expired_token() {
        let claims = JwtClaims {
            exp: 1,
            ..base_claims()
        };
        let token = issue_token(SECRET, &claims);
        let cfg = config(None, None);

        let err = validate_bearer_jwt(&token, &cfg).expect_err("expired token should fail");
        assert!(matches!(err, AuthError::InvalidToken(_)));
    }

    #[test]
    fn accepts_expected_audience_inside_list() {
        let claims = JwtClaims {
            aud: Some(AudienceClaim::Many(vec![
                "other-api".to_string(),
                AUDIENCE.to_string(),
            ])),
            ..base_claims()
        };
        let token = issue_token(SECRET, &claims);
        let cfg = config(None, Some(AUDIENCE));

        let principal = validate_bearer_jwt(&token, &cfg).expect("token should validate");
        assert_eq!(
            principal.audience,
            vec!["other-api".to_string(), AUDIENCE.to_string()]
        );
    }

    #[test]
    fn rejects_missing_audience_when_one_is_expected() {
        let claims = JwtClaims {
            aud: None,
            ..base_claims()
        };
        let token = issue_token(SECRET, &claims);
        let cfg = config(None, Some(AUDIENCE));

        let err = validate_bearer_jwt(&token, &cfg).expect_err("missing audience should fail");
        assert!(matches!(err, AuthError::AudienceMismatch));
    }
}
173}