1use crate::apple_jwt::base64_url;
25use hmac::{Hmac, Mac};
26use sha2::Sha256;
27use std::time::{SystemTime, UNIX_EPOCH};
28
/// HMAC-SHA256, the MAC behind the `HS256` JWT algorithm used by [`mint`] and [`verify`].
type HmacSha256 = Hmac<Sha256>;
30
/// Claims carried by tokens produced by [`mint`] and returned by [`verify`].
///
/// `iat` and `exp` are whole seconds since the Unix epoch: `verify` compares
/// `exp` against `SystemTime::now()` measured from `UNIX_EPOCH`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct JwtClaims {
    /// Subject (`sub` claim). Required: `verify` rejects tokens without it.
    pub sub: String,
    /// Issued-at time (`iat` claim), in Unix seconds.
    pub iat: u64,
    /// Expiry (`exp` claim), in Unix seconds; `verify` rejects the token once
    /// the current time is at or past this value.
    pub exp: u64,
    /// Issuer (`iss` claim).
    pub iss: String,
    /// Optional tenant, carried in the `https://pylonsync.com/tenant` claim.
    pub tenant_id: Option<String>,
    /// Role names, carried in the `https://pylonsync.com/roles` claim; the
    /// claim is omitted from minted tokens when this is empty.
    pub roles: Vec<String>,
}
50
/// Reasons [`verify`] can reject a token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtError {
    /// The token is not exactly three `.`-separated segments.
    Malformed,
    /// A segment is not valid unpadded base64url or not valid JSON, or the
    /// required `sub` claim is missing / not a string.
    BadEncoding,
    /// The header's `alg` is anything other than `"HS256"` (including `"none"`).
    UnsupportedAlg,
    /// The HMAC-SHA256 signature does not match the signing input under the
    /// supplied secret.
    BadSignature,
    /// `exp` is missing, not a non-negative integer, or not strictly in the future.
    Expired,
    /// `iss` differs from the issuer the caller required.
    WrongIssuer,
}

impl std::fmt::Display for JwtError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Malformed => "JWT malformed",
            Self::BadEncoding => "JWT base64/JSON decode failed",
            Self::UnsupportedAlg => "JWT alg not supported (expected HS256)",
            Self::BadSignature => "JWT signature mismatch",
            Self::Expired => "JWT expired",
            Self::WrongIssuer => "JWT issuer mismatch",
        })
    }
}

// Lets callers propagate `JwtError` with `?` into `Box<dyn Error>` and other
// error-chaining types. No variant wraps an underlying cause, so the default
// `source()` (None) is correct.
impl std::error::Error for JwtError {}
79
80pub fn mint(secret: &[u8], claims: &JwtClaims) -> String {
89 debug_assert!(
90 claims.exp > claims.iat,
91 "JWT exp ({}) must be > iat ({})",
92 claims.exp,
93 claims.iat
94 );
95 let header = serde_json::json!({"alg": "HS256", "typ": "JWT"});
96 let mut claims_obj = serde_json::Map::new();
97 claims_obj.insert("sub".into(), claims.sub.clone().into());
98 claims_obj.insert("iat".into(), claims.iat.into());
99 claims_obj.insert("exp".into(), claims.exp.into());
100 claims_obj.insert("iss".into(), claims.iss.clone().into());
101 if let Some(t) = &claims.tenant_id {
102 claims_obj.insert("https://pylonsync.com/tenant".into(), t.clone().into());
103 }
104 if !claims.roles.is_empty() {
105 claims_obj.insert(
106 "https://pylonsync.com/roles".into(),
107 serde_json::Value::Array(
108 claims.roles.iter().cloned().map(Into::into).collect(),
109 ),
110 );
111 }
112 let header_b64 = base64_url(serde_json::to_vec(&header).unwrap());
113 let claims_b64 = base64_url(serde_json::to_vec(&claims_obj).unwrap());
114 let signing_input = format!("{header_b64}.{claims_b64}");
115 let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
116 mac.update(signing_input.as_bytes());
117 let sig = mac.finalize().into_bytes();
118 let sig_b64 = base64_url(sig);
119 format!("{signing_input}.{sig_b64}")
120}
121
122pub fn verify(token: &str, secret: &[u8], expected_issuer: Option<&str>) -> Result<JwtClaims, JwtError> {
125 let mut parts = token.split('.');
126 let header_b64 = parts.next().ok_or(JwtError::Malformed)?;
127 let claims_b64 = parts.next().ok_or(JwtError::Malformed)?;
128 let sig_b64 = parts.next().ok_or(JwtError::Malformed)?;
129 if parts.next().is_some() {
130 return Err(JwtError::Malformed);
131 }
132
133 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
134 let header_bytes = URL_SAFE_NO_PAD
135 .decode(header_b64)
136 .map_err(|_| JwtError::BadEncoding)?;
137 let header: serde_json::Value =
138 serde_json::from_slice(&header_bytes).map_err(|_| JwtError::BadEncoding)?;
139 if header.get("alg").and_then(|v| v.as_str()) != Some("HS256") {
140 return Err(JwtError::UnsupportedAlg);
141 }
142
143 let signing_input = format!("{header_b64}.{claims_b64}");
144 let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
145 mac.update(signing_input.as_bytes());
146 let expected_sig = mac.finalize().into_bytes();
147 let provided_sig = URL_SAFE_NO_PAD
148 .decode(sig_b64)
149 .map_err(|_| JwtError::BadEncoding)?;
150 if !crate::constant_time_eq(&expected_sig, &provided_sig) {
151 return Err(JwtError::BadSignature);
152 }
153
154 let claims_bytes = URL_SAFE_NO_PAD
155 .decode(claims_b64)
156 .map_err(|_| JwtError::BadEncoding)?;
157 let claims: serde_json::Value =
158 serde_json::from_slice(&claims_bytes).map_err(|_| JwtError::BadEncoding)?;
159
160 let now = SystemTime::now()
161 .duration_since(UNIX_EPOCH)
162 .map(|d| d.as_secs())
163 .unwrap_or(0);
164 let exp = claims.get("exp").and_then(|v| v.as_u64()).unwrap_or(0);
165 if exp <= now {
166 return Err(JwtError::Expired);
167 }
168 let iss = claims
169 .get("iss")
170 .and_then(|v| v.as_str())
171 .unwrap_or_default()
172 .to_string();
173 if let Some(want) = expected_issuer {
174 if iss != want {
175 return Err(JwtError::WrongIssuer);
176 }
177 }
178
179 let sub = claims
180 .get("sub")
181 .and_then(|v| v.as_str())
182 .ok_or(JwtError::BadEncoding)?
183 .to_string();
184 let iat = claims.get("iat").and_then(|v| v.as_u64()).unwrap_or(0);
185 let tenant_id = claims
186 .get("https://pylonsync.com/tenant")
187 .and_then(|v| v.as_str())
188 .map(String::from);
189 let roles = claims
190 .get("https://pylonsync.com/roles")
191 .and_then(|v| v.as_array())
192 .map(|arr| {
193 arr.iter()
194 .filter_map(|v| v.as_str().map(String::from))
195 .collect()
196 })
197 .unwrap_or_default();
198
199 Ok(JwtClaims {
200 sub,
201 iat,
202 exp,
203 iss,
204 tenant_id,
205 roles,
206 })
207}
208
/// Cheap shape check: returns `true` iff `token` has exactly three non-empty
/// `.`-separated segments.
///
/// Nothing is decoded or verified. Any string of that shape matches, including
/// opaque keys such as `pk.key_abc.secret`. Callers dispatching on token type
/// must therefore test other recognisable prefixes before falling back to this.
pub fn looks_like_jwt(token: &str) -> bool {
    let mut seen = 0usize;
    for segment in token.split('.') {
        // An empty segment, or more than three segments, can never qualify.
        if segment.is_empty() {
            return false;
        }
        seen += 1;
        if seen > 3 {
            return false;
        }
    }
    seen == 3
}
220
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};

    /// Claims for `user-1` issued now by `pylon-test`, expiring
    /// `exp_secs_from_now` seconds from now.
    fn fixture_claims(exp_secs_from_now: i64) -> JwtClaims {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        JwtClaims {
            sub: "user-1".into(),
            iat: now,
            exp: (now as i64 + exp_secs_from_now) as u64,
            iss: "pylon-test".into(),
            tenant_id: None,
            roles: vec![],
        }
    }

    /// Encodes raw header/payload JSON and signs it with HS256 under `secret`.
    /// This bypasses `mint` so tests can craft payloads `mint` would never emit.
    fn sign_raw(secret: &[u8], header_json: &[u8], claims_json: &[u8]) -> String {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;
        let header = URL_SAFE_NO_PAD.encode(header_json);
        let claims = URL_SAFE_NO_PAD.encode(claims_json);
        let signing_input = format!("{header}.{claims}");
        let mut mac = Hmac::<Sha256>::new_from_slice(secret).unwrap();
        mac.update(signing_input.as_bytes());
        let sig = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
        format!("{signing_input}.{sig}")
    }

    #[test]
    fn round_trip_minimal_claims() {
        let secret = b"super-secret-pylon-key";
        let claims = fixture_claims(3600);
        let token = mint(secret, &claims);
        let decoded = verify(&token, secret, Some("pylon-test")).unwrap();
        assert_eq!(decoded.sub, "user-1");
        assert_eq!(decoded.iss, "pylon-test");
        assert_eq!(decoded, claims);
    }

    #[test]
    fn round_trip_with_tenant_and_roles() {
        let secret = b"k";
        let mut claims = fixture_claims(3600);
        claims.tenant_id = Some("acme".into());
        claims.roles = vec!["admin".into(), "billing".into()];
        let token = mint(secret, &claims);
        let decoded = verify(&token, secret, None).unwrap();
        assert_eq!(decoded.tenant_id.as_deref(), Some("acme"));
        assert_eq!(decoded.roles, vec!["admin", "billing"]);
        assert_eq!(decoded, claims);
    }

    #[test]
    fn expired_token_rejected() {
        let secret = b"k";
        // Correctly signed, but `exp` (2) is long in the past.
        let token = sign_raw(
            secret,
            br#"{"alg":"HS256","typ":"JWT"}"#,
            br#"{"sub":"user-1","iat":1,"exp":2,"iss":"pylon-test"}"#,
        );
        assert_eq!(verify(&token, secret, None), Err(JwtError::Expired));
    }

    #[test]
    fn missing_or_non_integer_exp_treated_as_expired() {
        let secret = b"k";
        let header = br#"{"alg":"HS256","typ":"JWT"}"#;
        let no_exp = sign_raw(secret, header, br#"{"sub":"user-1","iat":1,"iss":"pylon-test"}"#);
        assert_eq!(verify(&no_exp, secret, None), Err(JwtError::Expired));
        let string_exp = sign_raw(secret, header, br#"{"sub":"user-1","exp":"99999999999"}"#);
        assert_eq!(verify(&string_exp, secret, None), Err(JwtError::Expired));
    }

    #[test]
    fn missing_sub_rejected() {
        let secret = b"k";
        let token = sign_raw(
            secret,
            br#"{"alg":"HS256","typ":"JWT"}"#,
            br#"{"exp":99999999999,"iss":"pylon-test"}"#,
        );
        assert_eq!(
            verify(&token, secret, Some("pylon-test")),
            Err(JwtError::BadEncoding)
        );
    }

    #[test]
    fn tampered_claims_rejected() {
        let secret = b"k";
        let token = mint(secret, &fixture_claims(3600));
        let forged_claims = URL_SAFE_NO_PAD
            .encode(br#"{"sub":"attacker","exp":99999999999,"iss":"pylon-test"}"#);
        let mut segments: Vec<&str> = token.split('.').collect();
        segments[1] = forged_claims.as_str();
        let tampered = segments.join(".");
        assert_eq!(verify(&tampered, secret, None), Err(JwtError::BadSignature));
    }

    #[test]
    #[should_panic(expected = "JWT exp")]
    #[cfg(debug_assertions)]
    fn mint_panics_on_exp_le_iat_in_debug() {
        let secret = b"k";
        let mut claims = fixture_claims(0);
        claims.exp = claims.iat;
        let _ = mint(secret, &claims);
    }

    #[test]
    fn wrong_secret_rejected() {
        let secret = b"k";
        let claims = fixture_claims(3600);
        let token = mint(secret, &claims);
        assert_eq!(
            verify(&token, b"different-secret", None),
            Err(JwtError::BadSignature)
        );
    }

    #[test]
    fn wrong_issuer_rejected() {
        let secret = b"k";
        let claims = fixture_claims(3600);
        let token = mint(secret, &claims);
        assert_eq!(
            verify(&token, secret, Some("different-issuer")),
            Err(JwtError::WrongIssuer)
        );
    }

    #[test]
    fn alg_none_rejected() {
        // Unsigned `alg: none` token: must be refused on the header alone.
        let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"none","typ":"JWT"}"#);
        let claims = URL_SAFE_NO_PAD.encode(br#"{"sub":"attacker","exp":99999999999}"#);
        let token = format!("{header}.{claims}.");
        let result = verify(&token, b"any-secret", None);
        assert_eq!(result, Err(JwtError::UnsupportedAlg));
    }

    #[test]
    fn malformed_token_rejected() {
        assert_eq!(verify("not.a.jwt.too-many-parts", b"k", None), Err(JwtError::Malformed));
        assert_eq!(verify("only-one-part", b"k", None), Err(JwtError::Malformed));
        assert_eq!(verify("", b"k", None), Err(JwtError::Malformed));
    }

    #[test]
    fn looks_like_jwt_classifies() {
        assert!(looks_like_jwt("aaa.bbb.ccc"));
        assert!(!looks_like_jwt("pylon_abcdef"));
        assert!(!looks_like_jwt("aaa.bbb"));
        assert!(!looks_like_jwt(""));
        assert!(!looks_like_jwt("aaa..ccc"));
        assert!(looks_like_jwt("pk.key_abc.secret"));
    }

    /// A `pk.`-prefixed key has the same three-segment shape as a JWT, so
    /// `looks_like_jwt` alone cannot tell them apart. A dispatcher must check
    /// the `pk.` prefix first.
    #[test]
    fn pk_token_overlaps_jwt_shape_dispatcher_must_check_prefix_first() {
        let pk_like = "pk.key_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
        assert!(pk_like.starts_with("pk."));
        assert!(looks_like_jwt(pk_like));
    }
}