1use crate::error::Error;
2
3const MAX_RESPONSE_SIZE: usize = 100 * 1024; static PATH_CHAR_RE: std::sync::LazyLock<regex::Regex> =
6 std::sync::LazyLock::new(|| regex::Regex::new(r"^[a-zA-Z0-9\-._~/]+$").unwrap());
7
8pub fn validate_issuer(issuer: &str) -> Result<(), Error> {
9 if issuer.is_empty() || issuer.chars().count() > 255 {
10 return Err(Error::Unauthenticated(
11 "issuer empty or exceeds 255 characters".into(),
12 ));
13 }
14
15 let parsed = url::Url::parse(issuer)
16 .map_err(|_| Error::Unauthenticated("issuer is not a valid URL".into()))?;
17
18 match parsed.scheme() {
19 "https" => {}
20 "http" => match parsed.host() {
21 Some(url::Host::Domain("localhost")) => {}
22 Some(url::Host::Ipv4(ip)) if ip == std::net::Ipv4Addr::LOCALHOST => {}
23 Some(url::Host::Ipv6(ip)) if ip == std::net::Ipv6Addr::LOCALHOST => {}
24 _ => {
25 return Err(Error::Unauthenticated("issuer must use HTTPS".into()));
26 }
27 },
28 _ => {
29 return Err(Error::Unauthenticated("issuer must use HTTPS".into()));
30 }
31 }
32
33 if parsed.query().is_some() || parsed.fragment().is_some() {
35 return Err(Error::Unauthenticated(
36 "issuer must not contain query or fragment".into(),
37 ));
38 }
39 if issuer.contains('?') || issuer.contains('#') {
40 return Err(Error::Unauthenticated(
41 "issuer must not contain query or fragment".into(),
42 ));
43 }
44
45 if parsed.host_str().is_none() || parsed.host_str() == Some("") {
46 return Err(Error::Unauthenticated("issuer must have a host".into()));
47 }
48
49 if !parsed.username().is_empty() || parsed.password().is_some() {
50 return Err(Error::Unauthenticated(
51 "issuer must not contain userinfo".into(),
52 ));
53 }
54
55 let raw_host = {
58 let after_scheme = issuer
59 .strip_prefix(parsed.scheme())
60 .and_then(|s| s.strip_prefix("://"))
61 .unwrap_or("");
62 let host_part = if let Some(pos) = after_scheme.find('/') {
63 &after_scheme[..pos]
64 } else {
65 after_scheme
66 };
67 if host_part.starts_with('[') {
68 host_part.to_owned()
70 } else if let Some(pos) = host_part.rfind(':') {
71 host_part[..pos].to_owned()
72 } else {
73 host_part.to_owned()
74 }
75 };
76 for ch in raw_host.chars() {
77 if ch as u32 > 127 {
78 return Err(Error::Unauthenticated(
79 "issuer hostname must be ASCII-only".into(),
80 ));
81 }
82 if ch.is_control() || ch.is_whitespace() {
83 return Err(Error::Unauthenticated(
84 "issuer hostname contains invalid characters".into(),
85 ));
86 }
87 }
88
89 let raw_path = issuer
92 .strip_prefix(parsed.scheme())
93 .and_then(|s| s.strip_prefix("://"))
94 .and_then(|s| s.find('/').map(|pos| &s[pos..]))
95 .unwrap_or("");
96 let path = if raw_path.is_empty() {
97 parsed.path()
98 } else {
99 raw_path
100 };
101 if !path.is_empty() && path != "/" {
102 if !path.starts_with('/') {
103 return Err(Error::Unauthenticated(
104 "issuer path must start with /".into(),
105 ));
106 }
107 if path.contains("..") {
108 return Err(Error::Unauthenticated(
109 "issuer path must not contain ..".into(),
110 ));
111 }
112 if path.contains("//") {
113 return Err(Error::Unauthenticated(
114 "issuer path must not contain //".into(),
115 ));
116 }
117 if path.contains("~~") {
118 return Err(Error::Unauthenticated(
119 "issuer path must not contain ~~".into(),
120 ));
121 }
122 if path.ends_with('~') {
123 return Err(Error::Unauthenticated(
124 "issuer path must not end with ~".into(),
125 ));
126 }
127
128 if !PATH_CHAR_RE.is_match(path) {
129 return Err(Error::Unauthenticated(
130 "issuer path contains invalid characters".into(),
131 ));
132 }
133
134 for segment in path.split('/') {
135 if segment.is_empty() {
136 continue;
137 }
138 if segment == "." || segment == ".." || segment == "~" {
139 return Err(Error::Unauthenticated(
140 "issuer path contains invalid segment".into(),
141 ));
142 }
143 if segment.len() > 150 {
144 return Err(Error::Unauthenticated(
145 "issuer path segment exceeds 150 characters".into(),
146 ));
147 }
148 }
149 }
150
151 Ok(())
152}
153
154const SUBJECT_REJECT_CHARS: &str = "\"'`\\<>;&$(){}[]";
155const AUDIENCE_REJECT_CHARS: &str = "\"'`\\<>;|&$(){}[]@";
156
157pub fn validate_subject(value: &str) -> Result<(), Error> {
158 validate_claim_string(value, SUBJECT_REJECT_CHARS, "subject")
159}
160
161pub fn validate_audience(value: &str) -> Result<(), Error> {
162 validate_claim_string(value, AUDIENCE_REJECT_CHARS, "audience")
163}
164
165fn validate_claim_string(value: &str, reject_chars: &str, field: &str) -> Result<(), Error> {
166 if value.is_empty() {
167 return Err(Error::Unauthenticated(format!("{field} must not be empty")));
168 }
169 if value.chars().count() > 255 {
170 return Err(Error::Unauthenticated(format!(
171 "{field} exceeds 255 characters"
172 )));
173 }
174 for ch in value.chars() {
175 if (ch as u32) <= 0x1f {
176 return Err(Error::Unauthenticated(format!(
177 "{field} contains control characters"
178 )));
179 }
180 if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
181 return Err(Error::Unauthenticated(format!(
182 "{field} contains whitespace"
183 )));
184 }
185 if reject_chars.contains(ch) {
186 return Err(Error::Unauthenticated(format!(
187 "{field} contains invalid character"
188 )));
189 }
190 if !ch.is_alphanumeric() && !ch.is_ascii_punctuation() && ch as u32 > 127 {
191 if !is_printable(ch) {
193 return Err(Error::Unauthenticated(format!(
194 "{field} contains non-printable character"
195 )));
196 }
197 }
198 }
199 Ok(())
200}
201
/// Returns `false` for control characters and for U+FFFD (the Unicode
/// replacement character), `true` for everything else.
fn is_printable(ch: char) -> bool {
    let is_replacement = ch == char::REPLACEMENT_CHARACTER;
    !(is_replacement || ch.is_control())
}
205
/// The subset of an OpenID Provider discovery document
/// (`<issuer>/.well-known/openid-configuration`) that this module reads.
/// Other fields in the document are ignored during deserialization.
#[derive(Debug, serde::Deserialize)]
pub(crate) struct OidcDiscoveryDocument {
    /// The provider's issuer identifier. It must equal the issuer that was
    /// queried, ignoring a trailing `/`.
    pub(crate) issuer: String,
    /// URL from which the provider's JSON Web Key Set is fetched.
    pub(crate) jwks_uri: String,
}
211
/// Per-issuer verification material produced by OIDC discovery and cached by
/// `OidcVerifier`.
#[derive(Debug, Clone)]
pub(crate) struct OidcProvider {
    /// Keys fetched from the issuer's `jwks_uri`, used to verify token signatures.
    pub(crate) jwks: jsonwebtoken::jwk::JwkSet,
}
216
/// Verifies JWTs by discovering the signing keys of the token's issuer via
/// OIDC discovery.
pub struct OidcVerifier {
    /// HTTP client used for discovery-document and JWKS requests.
    http: reqwest::Client,
    /// Discovered providers, keyed by the issuer string exactly as it appears
    /// in the token's `iss` claim.
    cache: moka::future::Cache<String, std::sync::Arc<OidcProvider>>,
    /// When `Some`, only these issuers (stored with trailing `/` removed) are
    /// accepted; when `None`, any issuer passing `validate_issuer` is.
    allowed_issuers: Option<std::collections::HashSet<String>>,
}
222
223impl Default for OidcVerifier {
224 fn default() -> Self {
225 Self::new(None)
226 }
227}
228
/// Claims extracted from a JWT: the three this module names explicitly, plus
/// every other claim kept verbatim.
#[derive(Debug, serde::Deserialize)]
pub struct TokenClaims {
    /// Issuer (`iss`).
    pub iss: String,
    /// Subject (`sub`). `OidcVerifier::verify` does not validate it itself.
    pub sub: String,
    /// Audience (`aud`), a single string or an array of strings.
    /// `OidcVerifier::verify` does not check it (`validate_aud = false`).
    pub aud: OneOrMany,
    /// All remaining claims (for example `exp` and provider-specific claims).
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
237
/// A claim encoded either as a single string or as an array of strings, as
/// `aud` may be (RFC 7519 §4.1.3).
#[derive(Debug, serde::Deserialize)]
#[serde(untagged)]
pub enum OneOrMany {
    /// A single string value.
    One(String),
    /// An array of string values.
    Many(Vec<String>),
}
244
245impl OneOrMany {
246 pub fn iter(&self) -> impl Iterator<Item = &str> {
247 let slice: &[String] = match self {
248 OneOrMany::One(s) => std::slice::from_ref(s),
249 OneOrMany::Many(v) => v.as_slice(),
250 };
251 slice.iter().map(|s| s.as_str())
252 }
253}
254
255impl OidcVerifier {
256 pub fn new(allowed_issuer_urls: Option<Vec<String>>) -> Self {
257 let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
258 let url_str = attempt.url().to_string();
259 if validate_issuer(&url_str).is_err() {
260 attempt.error(std::io::Error::new(
261 std::io::ErrorKind::PermissionDenied,
262 format!("redirect to invalid issuer URL: {url_str}"),
263 ))
264 } else {
265 attempt.follow()
266 }
267 });
268
269 let http = reqwest::Client::builder()
270 .connect_timeout(std::time::Duration::from_secs(10))
271 .timeout(std::time::Duration::from_secs(30))
272 .redirect(redirect_policy)
273 .user_agent(format!("sts-cat/{}", env!("CARGO_PKG_VERSION")))
274 .build()
275 .expect("failed to build OIDC HTTP client");
276
277 let cache = moka::future::Cache::builder()
278 .max_capacity(100)
279 .time_to_live(std::time::Duration::from_secs(900))
280 .build();
281
282 let allowed_issuers = allowed_issuer_urls.map(|urls| {
283 urls.into_iter()
284 .map(|u| u.trim_end_matches('/').to_owned())
285 .collect()
286 });
287
288 Self {
289 http,
290 cache,
291 allowed_issuers,
292 }
293 }
294
295 #[tracing::instrument(skip_all, fields(issuer))]
296 async fn discover(&self, issuer: &str) -> Result<std::sync::Arc<OidcProvider>, Error> {
297 if let Some(provider) = self.cache.get(issuer).await {
298 return Ok(provider);
299 }
300
301 let provider = self.discover_with_retry(issuer).await?;
302 let provider = std::sync::Arc::new(provider);
303 self.cache.insert(issuer.to_owned(), provider.clone()).await;
304 Ok(provider)
305 }
306
307 #[tracing::instrument(skip_all, fields(issuer))]
308 async fn discover_with_retry(&self, issuer: &str) -> Result<OidcProvider, Error> {
309 use backon::Retryable as _;
310
311 let discover_fn = || async { self.discover_once(issuer).await };
312
313 discover_fn
314 .retry(
315 backon::ExponentialBuilder::default()
316 .with_min_delay(std::time::Duration::from_secs(1))
317 .with_max_delay(std::time::Duration::from_secs(30))
318 .with_factor(2.0)
319 .with_jitter()
320 .with_max_times(6),
321 )
322 .when(|e| !is_permanent_error(e))
323 .await
324 }
325
326 #[tracing::instrument(skip_all, fields(issuer))]
327 async fn discover_once(&self, issuer: &str) -> Result<OidcProvider, Error> {
328 let discovery_url = format!(
329 "{}/.well-known/openid-configuration",
330 issuer.trim_end_matches('/')
331 );
332
333 let resp = self
334 .http
335 .get(&discovery_url)
336 .send()
337 .await
338 .map_err(Error::OidcDiscovery)?;
339
340 let status = resp.status();
341 if !status.is_success() {
342 return Err(Error::OidcHttpError(status.as_u16()));
343 }
344
345 let body = read_limited_body(resp, MAX_RESPONSE_SIZE, Error::OidcDiscovery).await?;
346 let doc: OidcDiscoveryDocument =
347 serde_json::from_slice(&body).map_err(|e| Error::Internal(Box::new(e)))?;
348
349 let expected = issuer.trim_end_matches('/');
350 let actual = doc.issuer.trim_end_matches('/');
351 if expected != actual {
352 return Err(Error::Unauthenticated(
353 "OIDC discovery issuer mismatch".into(),
354 ));
355 }
356
357 let jwks_resp = self
358 .http
359 .get(&doc.jwks_uri)
360 .send()
361 .await
362 .map_err(Error::OidcDiscovery)?;
363
364 if !jwks_resp.status().is_success() {
365 return Err(Error::OidcHttpError(jwks_resp.status().as_u16()));
366 }
367
368 let jwks_body =
369 read_limited_body(jwks_resp, MAX_RESPONSE_SIZE, Error::OidcDiscovery).await?;
370 let jwks: jsonwebtoken::jwk::JwkSet =
371 serde_json::from_slice(&jwks_body).map_err(|e| Error::Internal(Box::new(e)))?;
372
373 Ok(OidcProvider { jwks })
374 }
375
376 #[tracing::instrument(skip_all)]
377 pub async fn verify(&self, token: &str) -> Result<TokenClaims, Error> {
378 let header = jsonwebtoken::decode_header(token)?;
379
380 let mut validation = jsonwebtoken::Validation::default();
382 validation.insecure_disable_signature_validation();
383 validation.validate_aud = false;
384 validation.validate_exp = false;
385
386 let unverified: jsonwebtoken::TokenData<TokenClaims> = jsonwebtoken::decode(
387 token,
388 &jsonwebtoken::DecodingKey::from_secret(&[]),
389 &validation,
390 )?;
391
392 let issuer = &unverified.claims.iss;
393
394 validate_issuer(issuer)?;
395
396 if let Some(ref allowed) = self.allowed_issuers {
397 let normalized = issuer.trim_end_matches('/');
398 if !allowed.contains(normalized) {
399 return Err(Error::Unauthenticated("issuer not in allowed list".into()));
400 }
401 }
402
403 let provider = self.discover(issuer).await?;
404
405 let kid = header.kid.as_deref();
406 let decoding_key = find_decoding_key(&provider.jwks, kid, &header.alg)?;
407
408 let mut verification = jsonwebtoken::Validation::new(header.alg);
409 verification.validate_aud = false; verification.set_issuer(&[issuer]);
411
412 let token_data: jsonwebtoken::TokenData<TokenClaims> =
413 jsonwebtoken::decode(token, &decoding_key, &verification)?;
414
415 Ok(token_data.claims)
416 }
417}
418
419fn find_decoding_key(
420 jwks: &jsonwebtoken::jwk::JwkSet,
421 kid: Option<&str>,
422 alg: &jsonwebtoken::Algorithm,
423) -> Result<jsonwebtoken::DecodingKey, Error> {
424 let jwk = if let Some(kid) = kid {
425 jwks.find(kid).ok_or_else(|| {
426 Error::Unauthenticated(format!("no matching key found for kid: {kid}"))
427 })?
428 } else {
429 let alg_str = format!("{alg:?}");
430 jwks.keys
431 .iter()
432 .find(|k| {
433 k.common
434 .key_algorithm
435 .is_some_and(|ka| format!("{ka:?}") == alg_str)
436 })
437 .or_else(|| jwks.keys.first())
438 .ok_or_else(|| Error::Unauthenticated("no keys in JWKS".into()))?
439 };
440
441 jsonwebtoken::DecodingKey::from_jwk(jwk)
442 .map_err(|e| Error::Unauthenticated(format!("invalid JWK: {e}")))
443}
444
445fn is_permanent_error(e: &Error) -> bool {
446 match e {
447 Error::OidcHttpError(code) => matches!(
449 code,
450 400 | 401 | 403 | 404 | 405 | 406 | 410 | 415 | 422 | 501
451 ),
452 Error::OidcDiscovery(_) => false, _ => true, }
455}
456
457pub(crate) async fn read_limited_body(
458 resp: reqwest::Response,
459 limit: usize,
460 map_err: impl Fn(reqwest::Error) -> Error,
461) -> Result<Vec<u8>, Error> {
462 if let Some(len) = resp.content_length()
463 && len as usize > limit
464 {
465 return Err(Error::Unauthenticated(format!(
466 "response too large: {len} bytes (limit: {limit})"
467 )));
468 }
469
470 use futures_util::StreamExt as _;
471 let initial_capacity = resp
472 .content_length()
473 .map_or(4096, |len| (len as usize).min(limit));
474 let mut stream = resp.bytes_stream();
475 let mut buf = Vec::with_capacity(initial_capacity);
476 while let Some(chunk) = stream.next().await {
477 let chunk = chunk.map_err(&map_err)?;
478 if buf.len() + chunk.len() > limit {
479 return Err(Error::Unauthenticated(format!(
480 "response too large (limit: {limit})"
481 )));
482 }
483 buf.extend_from_slice(&chunk);
484 }
485 Ok(buf)
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails (naming the offending input) unless every issuer is accepted.
    fn assert_all_accepted(issuers: &[&str]) {
        for issuer in issuers {
            assert!(
                validate_issuer(issuer).is_ok(),
                "expected issuer {issuer:?} to be accepted"
            );
        }
    }

    /// Fails (naming the offending input) unless every issuer is rejected.
    fn assert_all_rejected(issuers: &[&str]) {
        for issuer in issuers {
            assert!(
                validate_issuer(issuer).is_err(),
                "expected issuer {issuer:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_validate_issuer_valid() {
        assert_all_accepted(&[
            "https://accounts.google.com",
            "https://token.actions.githubusercontent.com",
            "https://example.com/path/to/issuer",
            "http://localhost",
            "http://127.0.0.1",
            "http://[::1]",
        ]);
    }

    #[test]
    fn test_validate_issuer_rejects_http_non_localhost() {
        assert_all_rejected(&["http://example.com"]);
    }

    #[test]
    fn test_validate_issuer_rejects_query_fragment() {
        assert_all_rejected(&["https://example.com?foo=bar", "https://example.com#frag"]);
    }

    #[test]
    fn test_validate_issuer_rejects_userinfo() {
        assert_all_rejected(&["https://user:pass@example.com"]);
    }

    #[test]
    fn test_validate_issuer_rejects_path_traversal() {
        assert_all_rejected(&["https://example.com/..", "https://example.com/a/../b"]);
    }

    #[test]
    fn test_validate_issuer_rejects_double_slash() {
        assert_all_rejected(&["https://example.com//path"]);
    }

    #[test]
    fn test_validate_issuer_rejects_tilde_issues() {
        assert_all_rejected(&[
            "https://example.com/path~",
            "https://example.com/~~path",
            "https://example.com/~",
        ]);
    }

    #[test]
    fn test_validate_issuer_rejects_dot_segment() {
        assert_all_rejected(&["https://example.com/."]);
    }

    #[test]
    fn test_validate_issuer_rejects_long_segment() {
        let issuer = format!("https://example.com/{}", "a".repeat(151));
        assert_all_rejected(&[issuer.as_str()]);
    }

    #[test]
    fn test_validate_issuer_rejects_non_ascii_host() {
        assert_all_rejected(&["https://exämple.com"]);
    }

    #[test]
    fn test_validate_subject_valid() {
        for subject in [
            "repo:org/repo:ref:refs/heads/main",
            "user@example.com",
            "simple-subject",
            "pipe|separated",
        ] {
            assert!(
                validate_subject(subject).is_ok(),
                "expected subject {subject:?} to be accepted"
            );
        }
    }

    #[test]
    fn test_validate_subject_rejects() {
        for subject in [
            "",
            "has space",
            "has\"quote",
            "has'quote",
            "has\\backslash",
            "has<bracket",
            "has[bracket]",
        ] {
            assert!(
                validate_subject(subject).is_err(),
                "expected subject {subject:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_validate_audience_valid() {
        for audience in ["https://example.com", "my-audience"] {
            assert!(
                validate_audience(audience).is_ok(),
                "expected audience {audience:?} to be accepted"
            );
        }
    }

    #[test]
    fn test_validate_audience_more_restrictive_than_subject() {
        // (value, accepted as subject?, accepted as audience?)
        let cases = [
            ("user@example.com", true, false),
            ("pipe|value", true, false),
            ("has[bracket]", false, false),
        ];
        for (value, subject_ok, audience_ok) in cases {
            assert_eq!(validate_subject(value).is_ok(), subject_ok, "subject {value:?}");
            assert_eq!(validate_audience(value).is_ok(), audience_ok, "audience {value:?}");
        }
    }
}