1use crate::apple_jwt::base64_url;
25use hmac::{Hmac, Mac};
26use sha2::Sha256;
27use std::time::{SystemTime, UNIX_EPOCH};
28
/// HMAC keyed with SHA-256 — the MAC primitive behind the JWT `HS256` algorithm
/// that [`mint`] signs with and [`verify`] checks.
type HmacSha256 = Hmac<Sha256>;
30
/// Claims written into an HS256 JWT by [`mint`] and recovered from one by [`verify`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JwtClaims {
    /// Subject (`sub` claim). [`verify`] rejects a token whose payload has no
    /// string `sub` with [`JwtError::BadEncoding`].
    pub sub: String,
    /// Issued-at (`iat` claim), in seconds since the Unix epoch. [`verify`] yields
    /// `0` when the claim is absent or not an unsigned integer.
    pub iat: u64,
    /// Expiry (`exp` claim), in seconds since the Unix epoch. [`verify`] rejects
    /// the token once the current Unix time is at or past this value. [`mint`]
    /// debug-asserts `exp > iat`.
    pub exp: u64,
    /// Issuer (`iss` claim). [`verify`] yields an empty string when the claim is
    /// absent or not a string.
    pub iss: String,
    /// Tenant, carried in the `https://pylonsync.com/tenant` claim. [`mint`] omits
    /// the claim when this is `None`.
    pub tenant_id: Option<String>,
    /// Roles, carried as a JSON array in the `https://pylonsync.com/roles` claim.
    /// [`mint`] omits the claim when empty; [`verify`] skips non-string entries.
    pub roles: Vec<String>,
}
50
/// Reasons [`verify`] can reject a token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtError {
    /// The token does not consist of exactly three `.`-separated segments.
    Malformed,
    /// A segment is not unpadded base64url, the header or payload is not valid
    /// JSON, or the payload lacks a string `sub` claim.
    BadEncoding,
    /// The header's `alg` is not the string `"HS256"` (this covers `"none"` and a
    /// missing `alg`).
    UnsupportedAlg,
    /// The HMAC-SHA256 signature does not match the header and payload.
    BadSignature,
    /// The `exp` claim is missing, not an unsigned integer, or not strictly in
    /// the future.
    Expired,
    /// The `iss` claim differs from the issuer the caller required.
    WrongIssuer,
}

impl std::fmt::Display for JwtError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Malformed => "JWT malformed",
            Self::BadEncoding => "JWT base64/JSON decode failed",
            Self::UnsupportedAlg => "JWT alg not supported (expected HS256)",
            Self::BadSignature => "JWT signature mismatch",
            Self::Expired => "JWT expired",
            Self::WrongIssuer => "JWT issuer mismatch",
        })
    }
}

/// Makes `JwtError` a standard error, so it can be boxed into
/// `Box<dyn std::error::Error>` and propagated with `?`. No variant wraps an
/// underlying cause, so the default `source()` (always `None`) is correct.
impl std::error::Error for JwtError {}
79
80pub fn mint(secret: &[u8], claims: &JwtClaims) -> String {
89 debug_assert!(
90 claims.exp > claims.iat,
91 "JWT exp ({}) must be > iat ({})",
92 claims.exp,
93 claims.iat
94 );
95 let header = serde_json::json!({"alg": "HS256", "typ": "JWT"});
96 let mut claims_obj = serde_json::Map::new();
97 claims_obj.insert("sub".into(), claims.sub.clone().into());
98 claims_obj.insert("iat".into(), claims.iat.into());
99 claims_obj.insert("exp".into(), claims.exp.into());
100 claims_obj.insert("iss".into(), claims.iss.clone().into());
101 if let Some(t) = &claims.tenant_id {
102 claims_obj.insert("https://pylonsync.com/tenant".into(), t.clone().into());
103 }
104 if !claims.roles.is_empty() {
105 claims_obj.insert(
106 "https://pylonsync.com/roles".into(),
107 serde_json::Value::Array(claims.roles.iter().cloned().map(Into::into).collect()),
108 );
109 }
110 let header_b64 = base64_url(serde_json::to_vec(&header).unwrap());
111 let claims_b64 = base64_url(serde_json::to_vec(&claims_obj).unwrap());
112 let signing_input = format!("{header_b64}.{claims_b64}");
113 let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
114 mac.update(signing_input.as_bytes());
115 let sig = mac.finalize().into_bytes();
116 let sig_b64 = base64_url(sig);
117 format!("{signing_input}.{sig_b64}")
118}
119
120pub fn verify(
123 token: &str,
124 secret: &[u8],
125 expected_issuer: Option<&str>,
126) -> Result<JwtClaims, JwtError> {
127 let mut parts = token.split('.');
128 let header_b64 = parts.next().ok_or(JwtError::Malformed)?;
129 let claims_b64 = parts.next().ok_or(JwtError::Malformed)?;
130 let sig_b64 = parts.next().ok_or(JwtError::Malformed)?;
131 if parts.next().is_some() {
132 return Err(JwtError::Malformed);
133 }
134
135 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
136 let header_bytes = URL_SAFE_NO_PAD
137 .decode(header_b64)
138 .map_err(|_| JwtError::BadEncoding)?;
139 let header: serde_json::Value =
140 serde_json::from_slice(&header_bytes).map_err(|_| JwtError::BadEncoding)?;
141 if header.get("alg").and_then(|v| v.as_str()) != Some("HS256") {
142 return Err(JwtError::UnsupportedAlg);
143 }
144
145 let signing_input = format!("{header_b64}.{claims_b64}");
146 let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
147 mac.update(signing_input.as_bytes());
148 let expected_sig = mac.finalize().into_bytes();
149 let provided_sig = URL_SAFE_NO_PAD
150 .decode(sig_b64)
151 .map_err(|_| JwtError::BadEncoding)?;
152 if !crate::constant_time_eq(&expected_sig, &provided_sig) {
153 return Err(JwtError::BadSignature);
154 }
155
156 let claims_bytes = URL_SAFE_NO_PAD
157 .decode(claims_b64)
158 .map_err(|_| JwtError::BadEncoding)?;
159 let claims: serde_json::Value =
160 serde_json::from_slice(&claims_bytes).map_err(|_| JwtError::BadEncoding)?;
161
162 let now = SystemTime::now()
163 .duration_since(UNIX_EPOCH)
164 .map(|d| d.as_secs())
165 .unwrap_or(0);
166 let exp = claims.get("exp").and_then(|v| v.as_u64()).unwrap_or(0);
167 if exp <= now {
168 return Err(JwtError::Expired);
169 }
170 let iss = claims
171 .get("iss")
172 .and_then(|v| v.as_str())
173 .unwrap_or_default()
174 .to_string();
175 if let Some(want) = expected_issuer {
176 if iss != want {
177 return Err(JwtError::WrongIssuer);
178 }
179 }
180
181 let sub = claims
182 .get("sub")
183 .and_then(|v| v.as_str())
184 .ok_or(JwtError::BadEncoding)?
185 .to_string();
186 let iat = claims.get("iat").and_then(|v| v.as_u64()).unwrap_or(0);
187 let tenant_id = claims
188 .get("https://pylonsync.com/tenant")
189 .and_then(|v| v.as_str())
190 .map(String::from);
191 let roles = claims
192 .get("https://pylonsync.com/roles")
193 .and_then(|v| v.as_array())
194 .map(|arr| {
195 arr.iter()
196 .filter_map(|v| v.as_str().map(String::from))
197 .collect()
198 })
199 .unwrap_or_default();
200
201 Ok(JwtClaims {
202 sub,
203 iat,
204 exp,
205 iss,
206 tenant_id,
207 roles,
208 })
209}
210
/// Returns `true` when `token` has the *shape* of a compact JWT: exactly three
/// `.`-separated segments, none of them empty.
///
/// This is a cheap syntactic test only; nothing is decoded or verified. Other
/// dotted credentials (for example `pk.<id>.<secret>` keys) also pass, so a caller
/// dispatching on token kind must test such prefixes before relying on this.
pub fn looks_like_jwt(token: &str) -> bool {
    let mut seen = 0usize;
    for segment in token.split('.') {
        seen += 1;
        // Any empty segment, or a fourth segment, rules the token out immediately.
        if segment.is_empty() || seen > 3 {
            return false;
        }
    }
    seen == 3
}
222
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds claims for subject `user-1` issued by `pylon-test`, with no tenant or
    /// roles. `iat` is the current Unix time in seconds and `exp` is `iat` shifted by
    /// `exp_secs_from_now` (computed in `i64`, so a negative offset is expressible).
    fn fixture_claims(exp_secs_from_now: i64) -> JwtClaims {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        JwtClaims {
            sub: "user-1".into(),
            iat: now,
            exp: (now as i64 + exp_secs_from_now) as u64,
            iss: "pylon-test".into(),
            tenant_id: None,
            roles: vec![],
        }
    }

    // mint -> verify with the same secret and matching issuer preserves `sub`/`iss`.
    #[test]
    fn round_trip_minimal_claims() {
        let secret = b"super-secret-pylon-key";
        let claims = fixture_claims(3600);
        let token = mint(secret, &claims);
        let decoded = verify(&token, secret, Some("pylon-test")).unwrap();
        assert_eq!(decoded.sub, "user-1");
        assert_eq!(decoded.iss, "pylon-test");
    }

    // The namespaced tenant and roles claims survive a round trip (roles in order).
    #[test]
    fn round_trip_with_tenant_and_roles() {
        let secret = b"k";
        let mut claims = fixture_claims(3600);
        claims.tenant_id = Some("acme".into());
        claims.roles = vec!["admin".into(), "billing".into()];
        let token = mint(secret, &claims);
        let decoded = verify(&token, secret, None).unwrap();
        assert_eq!(decoded.tenant_id.as_deref(), Some("acme"));
        assert_eq!(decoded.roles, vec!["admin", "billing"]);
    }

    // A correctly signed token whose `exp` (2) is long past is rejected as Expired.
    // The token is signed by hand rather than with `mint(fixture_claims(-N))`: that
    // would set `exp < iat` and trip `mint`'s `debug_assert!` in debug builds.
    #[test]
    fn expired_token_rejected() {
        let secret = b"k";
        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"HS256","typ":"JWT"}"#);
        let claims =
            URL_SAFE_NO_PAD.encode(br#"{"sub":"user-1","iat":1,"exp":2,"iss":"pylon-test"}"#);
        let signing_input = format!("{header}.{claims}");
        use hmac::{Hmac, Mac};
        use sha2::Sha256;
        let mut mac = Hmac::<Sha256>::new_from_slice(secret).unwrap();
        mac.update(signing_input.as_bytes());
        let sig = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
        let token = format!("{signing_input}.{sig}");
        assert_eq!(verify(&token, secret, None), Err(JwtError::Expired));
    }

    // `mint` refuses `exp == iat`. The guard is a `debug_assert!`, so this test is
    // only compiled when debug assertions are enabled.
    #[test]
    #[should_panic(expected = "JWT exp")]
    #[cfg(debug_assertions)]
    fn mint_panics_on_exp_le_iat_in_debug() {
        let secret = b"k";
        let mut claims = fixture_claims(0);
        claims.exp = claims.iat;
        let _ = mint(secret, &claims);
    }

    // Verifying with a different key than the one used to mint fails the MAC check.
    #[test]
    fn wrong_secret_rejected() {
        let secret = b"k";
        let claims = fixture_claims(3600);
        let token = mint(secret, &claims);
        assert_eq!(
            verify(&token, b"different-secret", None),
            Err(JwtError::BadSignature)
        );
    }

    // A valid token is still rejected when the caller demands a different issuer.
    #[test]
    fn wrong_issuer_rejected() {
        let secret = b"k";
        let claims = fixture_claims(3600);
        let token = mint(secret, &claims);
        assert_eq!(
            verify(&token, secret, Some("different-issuer")),
            Err(JwtError::WrongIssuer)
        );
    }

    // An unsigned `alg: "none"` token (empty third segment) is refused with
    // UnsupportedAlg, because `verify` checks `alg` before looking at the signature.
    #[test]
    fn alg_none_rejected() {
        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"none","typ":"JWT"}"#);
        let claims = URL_SAFE_NO_PAD.encode(br#"{"sub":"attacker","exp":99999999999}"#);
        let token = format!("{header}.{claims}.");
        let result = verify(&token, b"any-secret", None);
        assert_eq!(result, Err(JwtError::UnsupportedAlg));
    }

    // Too many segments, too few segments and the empty string are all Malformed.
    #[test]
    fn malformed_token_rejected() {
        assert_eq!(
            verify("not.a.jwt.too-many-parts", b"k", None),
            Err(JwtError::Malformed)
        );
        assert_eq!(
            verify("only-one-part", b"k", None),
            Err(JwtError::Malformed)
        );
        assert_eq!(verify("", b"k", None), Err(JwtError::Malformed));
    }

    // `looks_like_jwt` is a shape check only: three non-empty dotted segments.
    // Note the last case: a `pk.`-style key is classified as JWT-shaped too.
    #[test]
    fn looks_like_jwt_classifies() {
        assert!(looks_like_jwt("aaa.bbb.ccc"));
        assert!(!looks_like_jwt("pylon_abcdef"));
        assert!(!looks_like_jwt("aaa.bbb"));
        assert!(!looks_like_jwt(""));
        assert!(!looks_like_jwt("aaa..ccc"));
        assert!(looks_like_jwt("pk.key_abc.secret"));
    }

    // Pins the known overlap: a realistic `pk.`-prefixed key also passes
    // `looks_like_jwt`. As the test name states, a dispatcher must check the `pk.`
    // prefix before falling back to the JWT shape test.
    #[test]
    fn pk_token_overlaps_jwt_shape_dispatcher_must_check_prefix_first() {
        let pk_like =
            "pk.key_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
        assert!(pk_like.starts_with("pk."));
        assert!(looks_like_jwt(pk_like));
    }
}