Skip to main content

api_testing_core/
jwt.rs

1use base64::Engine;
2use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3
4use crate::Result;
5
/// Knobs controlling how a bearer token is checked as a JWT.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct JwtValidationOptions {
    /// Master switch; when `false` every token is accepted without being inspected.
    pub enabled: bool,
    /// When `true`, a token that is not a well-formed JWT (or carries a non-numeric
    /// `exp`/`nbf`) is an error; when `false` it only yields a warning. Expired and
    /// not-yet-valid tokens are errors either way.
    pub strict: bool,
    /// Clock-skew tolerance, in seconds, applied to the `exp` and `nbf` checks.
    /// Negative values are treated as zero.
    pub leeway_seconds: i64,
}

impl Default for JwtValidationOptions {
    /// Validation enabled, lenient about malformed tokens, and no clock-skew leeway.
    fn default() -> Self {
        let (enabled, strict, leeway_seconds) = (true, false, 0);
        Self {
            enabled,
            strict,
            leeway_seconds,
        }
    }
}
22
/// Non-fatal outcome of a bearer-token check.
///
/// Fatal problems are reported through the `Err` side of the `Result` returned by
/// `check_bearer_jwt` / `check_bearer_jwt_at`; this type only distinguishes a clean pass
/// from a pass that carries a warning.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum JwtCheck {
    /// The token passed every check, or checking was skipped (disabled or empty token).
    Ok,
    /// The token could not be validated but was let through in non-strict mode;
    /// the string is a human-readable explanation.
    Warn(String),
}
28
/// Internal classification of why a token failed validation.
#[derive(Clone, Debug, Eq, PartialEq)]
enum JwtErr {
    /// The token is not made of exactly three `.`-separated segments.
    NotJwt,
    /// The header or payload segment could not be decoded as expected.
    InvalidJwt,
    /// An `exp` claim is present but not numeric.
    ExpInvalid,
    /// An `nbf` claim is present but not numeric.
    NbfInvalid,
    /// `exp` lies before the (leeway-adjusted) current time.
    Expired { exp: i64, now: i64 },
    /// `nbf` lies after the (leeway-adjusted) current time.
    NotYetValid { nbf: i64, now: i64 },
}

impl JwtErr {
    /// Short machine-readable code for this failure, embedded in error/warning messages.
    /// Time-window failures also carry the offending claim and the clock value used.
    fn script_code(&self) -> String {
        match *self {
            JwtErr::Expired { exp, now } => format!("expired exp={exp} now={now}"),
            JwtErr::NotYetValid { nbf, now } => format!("nbf_in_future nbf={nbf} now={now}"),
            JwtErr::NotJwt => String::from("not_jwt"),
            JwtErr::InvalidJwt => String::from("invalid_jwt"),
            JwtErr::ExpInvalid => String::from("exp_invalid"),
            JwtErr::NbfInvalid => String::from("nbf_invalid"),
        }
    }
}
51
52fn unix_now_seconds() -> Result<i64> {
53    let now = std::time::SystemTime::now()
54        .duration_since(std::time::UNIX_EPOCH)?
55        .as_secs();
56    Ok(i64::try_from(now).unwrap_or(i64::MAX))
57}
58
59fn parse_numeric(value: &serde_json::Value) -> Option<i64> {
60    match value {
61        serde_json::Value::Bool(_) => None,
62        serde_json::Value::Number(n) => {
63            if let Some(i) = n.as_i64() {
64                return Some(i);
65            }
66            if let Some(u) = n.as_u64() {
67                return i64::try_from(u).ok();
68            }
69            n.as_f64().map(|f| f as i64)
70        }
71        serde_json::Value::String(s) => {
72            let trimmed = s.trim();
73            if trimmed.chars().all(|c| c.is_ascii_digit()) {
74                return trimmed.parse::<i64>().ok();
75            }
76            None
77        }
78        _ => None,
79    }
80}
81
82fn decode_json_segment(segment: &str) -> Result<serde_json::Value> {
83    let decoded = URL_SAFE_NO_PAD
84        .decode(segment.as_bytes())
85        .map_err(|_| anyhow::anyhow!("base64url decode failed"))?;
86    let v: serde_json::Value =
87        serde_json::from_slice(&decoded).map_err(|_| anyhow::anyhow!("json decode failed"))?;
88    Ok(v)
89}
90
91fn validate_jwt_at(token: &str, leeway_seconds: i64, now: i64) -> std::result::Result<(), JwtErr> {
92    let parts: Vec<&str> = token.trim().split('.').collect();
93    if parts.len() != 3 {
94        return Err(JwtErr::NotJwt);
95    }
96
97    let payload = match (decode_json_segment(parts[0]), decode_json_segment(parts[1])) {
98        (Ok(_header), Ok(payload)) => payload,
99        _ => return Err(JwtErr::InvalidJwt),
100    };
101
102    if let Some(exp) = payload.get("exp") {
103        let Some(exp_val) = parse_numeric(exp) else {
104            return Err(JwtErr::ExpInvalid);
105        };
106        if exp_val < (now - leeway_seconds) {
107            return Err(JwtErr::Expired { exp: exp_val, now });
108        }
109    }
110
111    if let Some(nbf) = payload.get("nbf") {
112        let Some(nbf_val) = parse_numeric(nbf) else {
113            return Err(JwtErr::NbfInvalid);
114        };
115        if nbf_val > (now + leeway_seconds) {
116            return Err(JwtErr::NotYetValid { nbf: nbf_val, now });
117        }
118    }
119
120    Ok(())
121}
122
123pub fn check_bearer_jwt_at(
124    token: &str,
125    label: &str,
126    opts: JwtValidationOptions,
127    now: i64,
128) -> Result<JwtCheck> {
129    let token = token.trim();
130    if token.is_empty() || !opts.enabled {
131        return Ok(JwtCheck::Ok);
132    }
133
134    let leeway_seconds = opts.leeway_seconds.max(0);
135    match validate_jwt_at(token, leeway_seconds, now) {
136        Ok(()) => Ok(JwtCheck::Ok),
137        Err(JwtErr::Expired { exp, now }) => {
138            let code = JwtErr::Expired { exp, now }.script_code();
139            anyhow::bail!("JWT expired for {label} ({code})");
140        }
141        Err(JwtErr::NotYetValid { nbf, now }) => {
142            let code = JwtErr::NotYetValid { nbf, now }.script_code();
143            anyhow::bail!("JWT not yet valid for {label} ({code})");
144        }
145        Err(
146            err @ (JwtErr::NotJwt | JwtErr::InvalidJwt | JwtErr::ExpInvalid | JwtErr::NbfInvalid),
147        ) => {
148            let code = err.script_code();
149            if opts.strict {
150                anyhow::bail!("invalid JWT for {label} ({code})");
151            }
152            Ok(JwtCheck::Warn(format!(
153                "token for {label} is not a valid JWT ({code}); skipping format validation"
154            )))
155        }
156    }
157}
158
159pub fn check_bearer_jwt(token: &str, label: &str, opts: JwtValidationOptions) -> Result<JwtCheck> {
160    let now = unix_now_seconds()?;
161    check_bearer_jwt_at(token, label, opts, now)
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Fixed "current time" so expectations never depend on the wall clock.
    const NOW: i64 = 1_700_000_000;

    /// JSON-serializes `value` and encodes it as unpadded base64url (one JWT segment).
    fn b64url_json(value: &serde_json::Value) -> String {
        let json = serde_json::to_vec(value).expect("json");
        URL_SAFE_NO_PAD.encode(json)
    }

    /// Assembles an unsigned (`alg: none`) token carrying `payload`. The signature is a
    /// placeholder because the validator never inspects it.
    fn make_jwt(payload: serde_json::Value) -> String {
        let header = serde_json::json!({"alg": "none", "typ": "JWT"});
        [b64url_json(&header), b64url_json(&payload), "sig".to_string()].join(".")
    }

    #[test]
    fn jwt_ok_when_claims_are_valid() {
        let token = make_jwt(serde_json::json!({"exp": NOW + 10, "nbf": NOW - 10}));
        let result = check_bearer_jwt_at(&token, "t", JwtValidationOptions::default(), NOW);
        assert_eq!(result.unwrap(), JwtCheck::Ok);
    }

    #[test]
    fn jwt_expired_is_hard_error_even_when_non_strict() {
        let token = make_jwt(serde_json::json!({"exp": NOW - 1}));
        let err = check_bearer_jwt_at(&token, "t", JwtValidationOptions::default(), NOW)
            .expect_err("expired token must be rejected");
        assert!(format!("{err:#}").contains("JWT expired"));
    }

    #[test]
    fn jwt_nbf_in_future_is_hard_error_even_when_non_strict() {
        let token = make_jwt(serde_json::json!({"nbf": NOW + 1}));
        let err = check_bearer_jwt_at(&token, "t", JwtValidationOptions::default(), NOW)
            .expect_err("not-yet-valid token must be rejected");
        assert!(format!("{err:#}").contains("JWT not yet valid"));
    }

    #[test]
    fn jwt_format_errors_warn_when_non_strict() {
        let lenient = JwtValidationOptions {
            strict: false,
            ..JwtValidationOptions::default()
        };

        let out = check_bearer_jwt_at("not.a.jwt", "t", lenient, NOW).unwrap();
        match out {
            JwtCheck::Warn(ref msg) => assert!(msg.contains("skipping format validation")),
            JwtCheck::Ok => panic!("expected warn, got {out:?}"),
        }
    }

    #[test]
    fn jwt_format_errors_fail_when_strict() {
        let strict = JwtValidationOptions {
            strict: true,
            ..JwtValidationOptions::default()
        };

        let err = check_bearer_jwt_at("not.a.jwt", "t", strict, NOW)
            .expect_err("strict mode must reject malformed tokens");
        assert!(format!("{err:#}").contains("invalid JWT"));
    }

    #[test]
    fn jwt_leeway_applies_to_exp_and_nbf() {
        let token = make_jwt(serde_json::json!({"exp": NOW - 5, "nbf": NOW + 5}));
        let skew_tolerant = JwtValidationOptions {
            leeway_seconds: 10,
            ..JwtValidationOptions::default()
        };

        let result = check_bearer_jwt_at(&token, "t", skew_tolerant, NOW);
        assert_eq!(result.unwrap(), JwtCheck::Ok);
    }

    #[test]
    fn jwt_validation_can_be_disabled() {
        let disabled = JwtValidationOptions {
            enabled: false,
            ..JwtValidationOptions::default()
        };

        let result = check_bearer_jwt_at("not.a.jwt", "t", disabled, NOW);
        assert_eq!(result.unwrap(), JwtCheck::Ok);
    }
}