1use base64::Engine;
2use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3
4use crate::Result;
5
/// Knobs controlling how bearer tokens are checked as JWTs.
///
/// The [`Default`] value turns validation on, keeps it lenient (format
/// problems produce a warning rather than an error) and allows no clock skew.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct JwtValidationOptions {
    /// When `false`, tokens are accepted without being inspected at all.
    pub enabled: bool,
    /// When `true`, a token that cannot be parsed as a JWT (or whose
    /// `exp`/`nbf` claims cannot be read) is an error instead of a warning.
    pub strict: bool,
    /// Allowed clock skew, in seconds, applied to the `exp` and `nbf` checks.
    /// `check_bearer_jwt_at` treats negative values as zero.
    pub leeway_seconds: i64,
}

impl Default for JwtValidationOptions {
    fn default() -> Self {
        // Validation on, lenient about malformed tokens, no clock skew.
        let leeway_seconds = 0;
        Self {
            leeway_seconds,
            strict: false,
            enabled: true,
        }
    }
}
22
/// Non-fatal outcome of a bearer-token JWT check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JwtCheck {
    /// The token passed every check performed (or checking was skipped).
    Ok,
    /// The token was accepted, with a human-readable warning message.
    Warn(String),
}
28
/// Internal reason a token failed JWT validation.
#[derive(Debug, Clone, PartialEq, Eq)]
enum JwtErr {
    /// The token does not have exactly three `.`-separated segments.
    NotJwt,
    /// The header or payload segment is malformed, e.g. not base64url JSON.
    InvalidJwt,
    /// An `exp` claim is present but could not be read as an integer time.
    ExpInvalid,
    /// An `nbf` claim is present but could not be read as an integer time.
    NbfInvalid,
    /// The `exp` claim says the token has expired (leeway already applied).
    Expired { exp: i64, now: i64 },
    /// The `nbf` claim says the token is not valid yet (leeway already applied).
    NotYetValid { nbf: i64, now: i64 },
}

impl JwtErr {
    /// Short machine-friendly code for this error, embedded in the messages
    /// produced by `check_bearer_jwt_at`.
    ///
    /// Timing failures also carry the offending claim and the reference time,
    /// e.g. `expired exp=5 now=9`.
    fn script_code(&self) -> String {
        let fixed = match self {
            Self::Expired { exp, now } => return format!("expired exp={exp} now={now}"),
            Self::NotYetValid { nbf, now } => {
                return format!("nbf_in_future nbf={nbf} now={now}");
            }
            Self::NotJwt => "not_jwt",
            Self::InvalidJwt => "invalid_jwt",
            Self::ExpInvalid => "exp_invalid",
            Self::NbfInvalid => "nbf_invalid",
        };
        fixed.to_owned()
    }
}
51
52fn unix_now_seconds() -> Result<i64> {
53 let now = std::time::SystemTime::now()
54 .duration_since(std::time::UNIX_EPOCH)?
55 .as_secs();
56 Ok(i64::try_from(now).unwrap_or(i64::MAX))
57}
58
59fn parse_numeric(value: &serde_json::Value) -> Option<i64> {
60 match value {
61 serde_json::Value::Bool(_) => None,
62 serde_json::Value::Number(n) => {
63 if let Some(i) = n.as_i64() {
64 return Some(i);
65 }
66 if let Some(u) = n.as_u64() {
67 return i64::try_from(u).ok();
68 }
69 n.as_f64().map(|f| f as i64)
70 }
71 serde_json::Value::String(s) => {
72 let trimmed = s.trim();
73 if trimmed.chars().all(|c| c.is_ascii_digit()) {
74 return trimmed.parse::<i64>().ok();
75 }
76 None
77 }
78 _ => None,
79 }
80}
81
82fn decode_json_segment(segment: &str) -> Result<serde_json::Value> {
83 let decoded = URL_SAFE_NO_PAD
84 .decode(segment.as_bytes())
85 .map_err(|_| anyhow::anyhow!("base64url decode failed"))?;
86 let v: serde_json::Value =
87 serde_json::from_slice(&decoded).map_err(|_| anyhow::anyhow!("json decode failed"))?;
88 Ok(v)
89}
90
91fn validate_jwt_at(token: &str, leeway_seconds: i64, now: i64) -> std::result::Result<(), JwtErr> {
92 let parts: Vec<&str> = token.trim().split('.').collect();
93 if parts.len() != 3 {
94 return Err(JwtErr::NotJwt);
95 }
96
97 let payload = match (decode_json_segment(parts[0]), decode_json_segment(parts[1])) {
98 (Ok(_header), Ok(payload)) => payload,
99 _ => return Err(JwtErr::InvalidJwt),
100 };
101
102 if let Some(exp) = payload.get("exp") {
103 let Some(exp_val) = parse_numeric(exp) else {
104 return Err(JwtErr::ExpInvalid);
105 };
106 if exp_val < (now - leeway_seconds) {
107 return Err(JwtErr::Expired { exp: exp_val, now });
108 }
109 }
110
111 if let Some(nbf) = payload.get("nbf") {
112 let Some(nbf_val) = parse_numeric(nbf) else {
113 return Err(JwtErr::NbfInvalid);
114 };
115 if nbf_val > (now + leeway_seconds) {
116 return Err(JwtErr::NotYetValid { nbf: nbf_val, now });
117 }
118 }
119
120 Ok(())
121}
122
123pub fn check_bearer_jwt_at(
124 token: &str,
125 label: &str,
126 opts: JwtValidationOptions,
127 now: i64,
128) -> Result<JwtCheck> {
129 let token = token.trim();
130 if token.is_empty() || !opts.enabled {
131 return Ok(JwtCheck::Ok);
132 }
133
134 let leeway_seconds = opts.leeway_seconds.max(0);
135 match validate_jwt_at(token, leeway_seconds, now) {
136 Ok(()) => Ok(JwtCheck::Ok),
137 Err(JwtErr::Expired { exp, now }) => {
138 let code = JwtErr::Expired { exp, now }.script_code();
139 anyhow::bail!("JWT expired for {label} ({code})");
140 }
141 Err(JwtErr::NotYetValid { nbf, now }) => {
142 let code = JwtErr::NotYetValid { nbf, now }.script_code();
143 anyhow::bail!("JWT not yet valid for {label} ({code})");
144 }
145 Err(
146 err @ (JwtErr::NotJwt | JwtErr::InvalidJwt | JwtErr::ExpInvalid | JwtErr::NbfInvalid),
147 ) => {
148 let code = err.script_code();
149 if opts.strict {
150 anyhow::bail!("invalid JWT for {label} ({code})");
151 }
152 Ok(JwtCheck::Warn(format!(
153 "token for {label} is not a valid JWT ({code}); skipping format validation"
154 )))
155 }
156 }
157}
158
159pub fn check_bearer_jwt(token: &str, label: &str, opts: JwtValidationOptions) -> Result<JwtCheck> {
160 let now = unix_now_seconds()?;
161 check_bearer_jwt_at(token, label, opts, now)
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Fixed reference time shared by every test so results are deterministic.
    const NOW: i64 = 1_700_000_000;

    fn b64url_json(value: &serde_json::Value) -> String {
        let bytes = serde_json::to_vec(value).expect("json");
        URL_SAFE_NO_PAD.encode(bytes)
    }

    /// Builds an unsigned (`alg: none`) token around `payload`. The signature
    /// segment is a placeholder; validation does not inspect it.
    fn make_jwt(payload: serde_json::Value) -> String {
        let header = serde_json::json!({"alg":"none","typ":"JWT"});
        format!("{}.{}.sig", b64url_json(&header), b64url_json(&payload))
    }

    #[test]
    fn jwt_ok_when_claims_are_valid() {
        let token = make_jwt(serde_json::json!({"exp": NOW + 10, "nbf": NOW - 10}));
        let out = check_bearer_jwt_at(&token, "t", JwtValidationOptions::default(), NOW).unwrap();
        assert_eq!(out, JwtCheck::Ok);
    }

    #[test]
    fn jwt_expired_is_hard_error_even_when_non_strict() {
        let token = make_jwt(serde_json::json!({"exp": NOW - 1}));
        let err =
            check_bearer_jwt_at(&token, "t", JwtValidationOptions::default(), NOW).unwrap_err();
        assert!(format!("{err:#}").contains("JWT expired"));
    }

    #[test]
    fn jwt_nbf_in_future_is_hard_error_even_when_non_strict() {
        let token = make_jwt(serde_json::json!({"nbf": NOW + 1}));
        let err =
            check_bearer_jwt_at(&token, "t", JwtValidationOptions::default(), NOW).unwrap_err();
        assert!(format!("{err:#}").contains("JWT not yet valid"));
    }

    #[test]
    fn jwt_format_errors_warn_when_non_strict() {
        let opts = JwtValidationOptions {
            strict: false,
            ..JwtValidationOptions::default()
        };

        let out = check_bearer_jwt_at("not.a.jwt", "t", opts, NOW).unwrap();
        match out {
            JwtCheck::Warn(msg) => assert!(msg.contains("skipping format validation")),
            other => panic!("expected warn, got {other:?}"),
        }
    }

    #[test]
    fn jwt_format_errors_fail_when_strict() {
        let opts = JwtValidationOptions {
            strict: true,
            ..JwtValidationOptions::default()
        };

        let err = check_bearer_jwt_at("not.a.jwt", "t", opts, NOW).unwrap_err();
        assert!(format!("{err:#}").contains("invalid JWT"));
    }

    #[test]
    fn jwt_leeway_applies_to_exp_and_nbf() {
        let token = make_jwt(serde_json::json!({"exp": NOW - 5, "nbf": NOW + 5}));
        let opts = JwtValidationOptions {
            leeway_seconds: 10,
            ..JwtValidationOptions::default()
        };

        let out = check_bearer_jwt_at(&token, "t", opts, NOW).unwrap();
        assert_eq!(out, JwtCheck::Ok);
    }

    #[test]
    fn jwt_validation_can_be_disabled() {
        let opts = JwtValidationOptions {
            enabled: false,
            ..JwtValidationOptions::default()
        };

        let out = check_bearer_jwt_at("not.a.jwt", "t", opts, NOW).unwrap();
        assert_eq!(out, JwtCheck::Ok);
    }

    #[test]
    fn jwt_empty_token_is_ok_without_validation() {
        let out = check_bearer_jwt_at("   ", "t", JwtValidationOptions::default(), NOW).unwrap();
        assert_eq!(out, JwtCheck::Ok);
    }

    #[test]
    fn jwt_negative_leeway_is_treated_as_zero() {
        // With the leeway clamped to 0 this token is still valid; a negative
        // leeway applied verbatim would push the effective clock past `exp`.
        let token = make_jwt(serde_json::json!({"exp": NOW + 50}));
        let opts = JwtValidationOptions {
            leeway_seconds: -100,
            ..JwtValidationOptions::default()
        };

        let out = check_bearer_jwt_at(&token, "t", opts, NOW).unwrap();
        assert_eq!(out, JwtCheck::Ok);
    }

    #[test]
    fn jwt_numeric_string_claims_are_accepted() {
        let token = make_jwt(serde_json::json!({
            "exp": (NOW + 10).to_string(),
            "nbf": (NOW - 10).to_string()
        }));
        let out = check_bearer_jwt_at(&token, "t", JwtValidationOptions::default(), NOW).unwrap();
        assert_eq!(out, JwtCheck::Ok);
    }

    #[test]
    fn jwt_non_numeric_exp_warns_when_non_strict_and_fails_when_strict() {
        let token = make_jwt(serde_json::json!({"exp": "soon"}));

        let out = check_bearer_jwt_at(&token, "t", JwtValidationOptions::default(), NOW).unwrap();
        match out {
            JwtCheck::Warn(msg) => assert!(msg.contains("exp_invalid")),
            other => panic!("expected warn, got {other:?}"),
        }

        let strict = JwtValidationOptions {
            strict: true,
            ..JwtValidationOptions::default()
        };
        let err = check_bearer_jwt_at(&token, "t", strict, NOW).unwrap_err();
        assert!(format!("{err:#}").contains("exp_invalid"));
    }
}