Skip to main content

sts_cat/
oidc.rs

1use crate::error::Error;
2
/// Upper bound, in bytes, on an HTTP response body accepted by this module.
/// It is passed to `read_limited_body` for both the discovery document and
/// the JWKS download.
const MAX_RESPONSE_SIZE: usize = 100 * 1024; // 100 KiB
4
/// Allow-list for issuer paths: one or more ASCII letters, digits, or any of
/// `-`, `.`, `_`, `~`, `/`. Compiled lazily on first use; the pattern is a
/// fixed literal, so the `unwrap` only fires if the pattern itself is broken.
static PATH_CHAR_RE: std::sync::LazyLock<regex::Regex> =
    std::sync::LazyLock::new(|| regex::Regex::new(r"^[a-zA-Z0-9\-._~/]+$").unwrap());
7
8pub fn validate_issuer(issuer: &str) -> Result<(), Error> {
9    if issuer.is_empty() || issuer.chars().count() > 255 {
10        return Err(Error::Unauthenticated(
11            "issuer empty or exceeds 255 characters".into(),
12        ));
13    }
14
15    let parsed = url::Url::parse(issuer)
16        .map_err(|_| Error::Unauthenticated("issuer is not a valid URL".into()))?;
17
18    match parsed.scheme() {
19        "https" => {}
20        "http" => match parsed.host() {
21            Some(url::Host::Domain("localhost")) => {}
22            Some(url::Host::Ipv4(ip)) if ip == std::net::Ipv4Addr::LOCALHOST => {}
23            Some(url::Host::Ipv6(ip)) if ip == std::net::Ipv6Addr::LOCALHOST => {}
24            _ => {
25                return Err(Error::Unauthenticated("issuer must use HTTPS".into()));
26            }
27        },
28        _ => {
29            return Err(Error::Unauthenticated("issuer must use HTTPS".into()));
30        }
31    }
32
33    // Check both parsed and raw: url::Url may normalize away certain encodings
34    if parsed.query().is_some() || parsed.fragment().is_some() {
35        return Err(Error::Unauthenticated(
36            "issuer must not contain query or fragment".into(),
37        ));
38    }
39    if issuer.contains('?') || issuer.contains('#') {
40        return Err(Error::Unauthenticated(
41            "issuer must not contain query or fragment".into(),
42        ));
43    }
44
45    if parsed.host_str().is_none() || parsed.host_str() == Some("") {
46        return Err(Error::Unauthenticated("issuer must have a host".into()));
47    }
48
49    if !parsed.username().is_empty() || parsed.password().is_some() {
50        return Err(Error::Unauthenticated(
51            "issuer must not contain userinfo".into(),
52        ));
53    }
54
55    // ASCII-only hostname — check the raw input string because url::Url
56    // converts IDN to punycode (e.g. exämple.com → xn--exmple-cua.com)
57    let raw_host = {
58        let after_scheme = issuer
59            .strip_prefix(parsed.scheme())
60            .and_then(|s| s.strip_prefix("://"))
61            .unwrap_or("");
62        let host_part = if let Some(pos) = after_scheme.find('/') {
63            &after_scheme[..pos]
64        } else {
65            after_scheme
66        };
67        if host_part.starts_with('[') {
68            // IPv6: take everything including brackets
69            host_part.to_owned()
70        } else if let Some(pos) = host_part.rfind(':') {
71            host_part[..pos].to_owned()
72        } else {
73            host_part.to_owned()
74        }
75    };
76    for ch in raw_host.chars() {
77        if ch as u32 > 127 {
78            return Err(Error::Unauthenticated(
79                "issuer hostname must be ASCII-only".into(),
80            ));
81        }
82        if ch.is_control() || ch.is_whitespace() {
83            return Err(Error::Unauthenticated(
84                "issuer hostname contains invalid characters".into(),
85            ));
86        }
87    }
88
89    // Path validation — use the raw issuer string to extract the path,
90    // since url::Url normalizes away `.` and `..` segments.
91    let raw_path = issuer
92        .strip_prefix(parsed.scheme())
93        .and_then(|s| s.strip_prefix("://"))
94        .and_then(|s| s.find('/').map(|pos| &s[pos..]))
95        .unwrap_or("");
96    let path = if raw_path.is_empty() {
97        parsed.path()
98    } else {
99        raw_path
100    };
101    if !path.is_empty() && path != "/" {
102        if !path.starts_with('/') {
103            return Err(Error::Unauthenticated(
104                "issuer path must start with /".into(),
105            ));
106        }
107        if path.contains("..") {
108            return Err(Error::Unauthenticated(
109                "issuer path must not contain ..".into(),
110            ));
111        }
112        if path.contains("//") {
113            return Err(Error::Unauthenticated(
114                "issuer path must not contain //".into(),
115            ));
116        }
117        if path.contains("~~") {
118            return Err(Error::Unauthenticated(
119                "issuer path must not contain ~~".into(),
120            ));
121        }
122        if path.ends_with('~') {
123            return Err(Error::Unauthenticated(
124                "issuer path must not end with ~".into(),
125            ));
126        }
127
128        if !PATH_CHAR_RE.is_match(path) {
129            return Err(Error::Unauthenticated(
130                "issuer path contains invalid characters".into(),
131            ));
132        }
133
134        for segment in path.split('/') {
135            if segment.is_empty() {
136                continue;
137            }
138            if segment == "." || segment == ".." || segment == "~" {
139                return Err(Error::Unauthenticated(
140                    "issuer path contains invalid segment".into(),
141                ));
142            }
143            if segment.len() > 150 {
144                return Err(Error::Unauthenticated(
145                    "issuer path segment exceeds 150 characters".into(),
146                ));
147            }
148        }
149    }
150
151    Ok(())
152}
153
/// Characters that make a subject value invalid: quotes, backtick, backslash,
/// angle brackets, `;`, `&`, `$` and every bracket/brace/parenthesis.
const SUBJECT_REJECT_CHARS: &str = "\"'`\\<>;&$(){}[]";
/// Characters that make an audience value invalid: the subject set plus `|`
/// and `@`.
const AUDIENCE_REJECT_CHARS: &str = "\"'`\\<>;|&$(){}[]@";
156
/// Validates a subject value using the shared claim rules and
/// `SUBJECT_REJECT_CHARS`; error messages are prefixed with `subject`.
///
/// # Errors
///
/// Returns whatever error `validate_claim_string` reports for `value`.
pub fn validate_subject(value: &str) -> Result<(), Error> {
    validate_claim_string(value, SUBJECT_REJECT_CHARS, "subject")
}
160
/// Validates an audience value using the shared claim rules and
/// `AUDIENCE_REJECT_CHARS` (which additionally forbids `|` and `@`); error
/// messages are prefixed with `audience`.
///
/// # Errors
///
/// Returns whatever error `validate_claim_string` reports for `value`.
pub fn validate_audience(value: &str) -> Result<(), Error> {
    validate_claim_string(value, AUDIENCE_REJECT_CHARS, "audience")
}
164
165fn validate_claim_string(value: &str, reject_chars: &str, field: &str) -> Result<(), Error> {
166    if value.is_empty() {
167        return Err(Error::Unauthenticated(format!("{field} must not be empty")));
168    }
169    if value.chars().count() > 255 {
170        return Err(Error::Unauthenticated(format!(
171            "{field} exceeds 255 characters"
172        )));
173    }
174    for ch in value.chars() {
175        if (ch as u32) <= 0x1f {
176            return Err(Error::Unauthenticated(format!(
177                "{field} contains control characters"
178            )));
179        }
180        if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
181            return Err(Error::Unauthenticated(format!(
182                "{field} contains whitespace"
183            )));
184        }
185        if reject_chars.contains(ch) {
186            return Err(Error::Unauthenticated(format!(
187                "{field} contains invalid character"
188            )));
189        }
190        if !ch.is_alphanumeric() && !ch.is_ascii_punctuation() && ch as u32 > 127 {
191            // Approximate Go's unicode.IsPrint (categories L, M, N, P, S, Zs)
192            if !is_printable(ch) {
193                return Err(Error::Unauthenticated(format!(
194                    "{field} contains non-printable character"
195                )));
196            }
197        }
198    }
199    Ok(())
200}
201
/// Coarse printability test: a character counts as printable unless it is a
/// control character or the Unicode replacement character U+FFFD.
fn is_printable(ch: char) -> bool {
    let is_replacement = ch == char::REPLACEMENT_CHARACTER;
    !(ch.is_control() || is_replacement)
}
205
/// The fields this module reads from a `.well-known/openid-configuration`
/// response; other fields in the JSON document are not captured here.
#[derive(Debug, serde::Deserialize)]
pub(crate) struct OidcDiscoveryDocument {
    /// Issuer the provider claims to be; compared with the requested issuer
    /// (ignoring trailing slashes) during discovery.
    pub(crate) issuer: String,
    /// URL from which the provider's JWK set is downloaded.
    pub(crate) jwks_uri: String,
}
211
/// Result of a successful discovery for one issuer; this is the value stored
/// in the verifier's cache.
#[derive(Debug, Clone)]
pub(crate) struct OidcProvider {
    /// Key set downloaded from the discovery document's `jwks_uri`.
    pub(crate) jwks: jsonwebtoken::jwk::JwkSet,
}
216
/// Verifies tokens against keys obtained through OIDC discovery.
pub struct OidcVerifier {
    /// HTTP client used for the discovery and JWKS requests.
    http: reqwest::Client,
    /// Discovered providers, keyed by the issuer string taken from the token.
    cache: moka::future::Cache<String, std::sync::Arc<OidcProvider>>,
    /// Optional issuer allow-list (entries stored without trailing slashes).
    /// `None` disables the allow-list check.
    allowed_issuers: Option<std::collections::HashSet<String>>,
}
222
/// The default verifier is `OidcVerifier::new(None)`, i.e. one without an
/// issuer allow-list.
impl Default for OidcVerifier {
    fn default() -> Self {
        Self::new(None)
    }
}
228
/// Claims deserialized from a token payload.
#[derive(Debug, serde::Deserialize)]
pub struct TokenClaims {
    /// Issuer (`iss`) claim.
    pub iss: String,
    /// Subject (`sub`) claim.
    pub sub: String,
    /// Audience (`aud`) claim; either a single string or a list of strings.
    pub aud: OneOrMany,
    /// Every other claim in the payload, collected via `serde(flatten)`.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
237
/// A JSON value that is either one string or an array of strings.
/// `serde(untagged)` selects the variant from the shape of the input.
#[derive(Debug, serde::Deserialize)]
#[serde(untagged)]
pub enum OneOrMany {
    /// A single string value.
    One(String),
    /// An array of string values.
    Many(Vec<String>),
}
244
245impl OneOrMany {
246    pub fn iter(&self) -> impl Iterator<Item = &str> {
247        let slice: &[String] = match self {
248            OneOrMany::One(s) => std::slice::from_ref(s),
249            OneOrMany::Many(v) => v.as_slice(),
250        };
251        slice.iter().map(|s| s.as_str())
252    }
253}
254
255impl OidcVerifier {
256    pub fn new(allowed_issuer_urls: Option<Vec<String>>) -> Self {
257        let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
258            let url_str = attempt.url().to_string();
259            if validate_issuer(&url_str).is_err() {
260                attempt.error(std::io::Error::new(
261                    std::io::ErrorKind::PermissionDenied,
262                    format!("redirect to invalid issuer URL: {url_str}"),
263                ))
264            } else {
265                attempt.follow()
266            }
267        });
268
269        let http = reqwest::Client::builder()
270            .connect_timeout(std::time::Duration::from_secs(10))
271            .timeout(std::time::Duration::from_secs(30))
272            .redirect(redirect_policy)
273            .user_agent(format!("sts-cat/{}", env!("CARGO_PKG_VERSION")))
274            .build()
275            .expect("failed to build OIDC HTTP client");
276
277        let cache = moka::future::Cache::builder()
278            .max_capacity(100)
279            .time_to_live(std::time::Duration::from_secs(900))
280            .build();
281
282        let allowed_issuers = allowed_issuer_urls.map(|urls| {
283            urls.into_iter()
284                .map(|u| u.trim_end_matches('/').to_owned())
285                .collect()
286        });
287
288        Self {
289            http,
290            cache,
291            allowed_issuers,
292        }
293    }
294
295    #[tracing::instrument(skip_all, fields(issuer))]
296    async fn discover(&self, issuer: &str) -> Result<std::sync::Arc<OidcProvider>, Error> {
297        if let Some(provider) = self.cache.get(issuer).await {
298            return Ok(provider);
299        }
300
301        let provider = self.discover_with_retry(issuer).await?;
302        let provider = std::sync::Arc::new(provider);
303        self.cache.insert(issuer.to_owned(), provider.clone()).await;
304        Ok(provider)
305    }
306
307    #[tracing::instrument(skip_all, fields(issuer))]
308    async fn discover_with_retry(&self, issuer: &str) -> Result<OidcProvider, Error> {
309        use backon::Retryable as _;
310
311        let discover_fn = || async { self.discover_once(issuer).await };
312
313        discover_fn
314            .retry(
315                backon::ExponentialBuilder::default()
316                    .with_min_delay(std::time::Duration::from_secs(1))
317                    .with_max_delay(std::time::Duration::from_secs(30))
318                    .with_factor(2.0)
319                    .with_jitter()
320                    .with_max_times(6),
321            )
322            .when(|e| !is_permanent_error(e))
323            .await
324    }
325
326    #[tracing::instrument(skip_all, fields(issuer))]
327    async fn discover_once(&self, issuer: &str) -> Result<OidcProvider, Error> {
328        let discovery_url = format!(
329            "{}/.well-known/openid-configuration",
330            issuer.trim_end_matches('/')
331        );
332
333        let resp = self
334            .http
335            .get(&discovery_url)
336            .send()
337            .await
338            .map_err(Error::OidcDiscovery)?;
339
340        let status = resp.status();
341        if !status.is_success() {
342            return Err(Error::OidcHttpError(status.as_u16()));
343        }
344
345        let body = read_limited_body(resp, MAX_RESPONSE_SIZE, Error::OidcDiscovery).await?;
346        let doc: OidcDiscoveryDocument =
347            serde_json::from_slice(&body).map_err(|e| Error::Internal(Box::new(e)))?;
348
349        let expected = issuer.trim_end_matches('/');
350        let actual = doc.issuer.trim_end_matches('/');
351        if expected != actual {
352            return Err(Error::Unauthenticated(
353                "OIDC discovery issuer mismatch".into(),
354            ));
355        }
356
357        let jwks_resp = self
358            .http
359            .get(&doc.jwks_uri)
360            .send()
361            .await
362            .map_err(Error::OidcDiscovery)?;
363
364        if !jwks_resp.status().is_success() {
365            return Err(Error::OidcHttpError(jwks_resp.status().as_u16()));
366        }
367
368        let jwks_body =
369            read_limited_body(jwks_resp, MAX_RESPONSE_SIZE, Error::OidcDiscovery).await?;
370        let jwks: jsonwebtoken::jwk::JwkSet =
371            serde_json::from_slice(&jwks_body).map_err(|e| Error::Internal(Box::new(e)))?;
372
373        Ok(OidcProvider { jwks })
374    }
375
376    #[tracing::instrument(skip_all)]
377    pub async fn verify(&self, token: &str) -> Result<TokenClaims, Error> {
378        let header = jsonwebtoken::decode_header(token)?;
379
380        // Extract issuer without signature verification to discover the OIDC provider
381        let mut validation = jsonwebtoken::Validation::default();
382        validation.insecure_disable_signature_validation();
383        validation.validate_aud = false;
384        validation.validate_exp = false;
385
386        let unverified: jsonwebtoken::TokenData<TokenClaims> = jsonwebtoken::decode(
387            token,
388            &jsonwebtoken::DecodingKey::from_secret(&[]),
389            &validation,
390        )?;
391
392        let issuer = &unverified.claims.iss;
393
394        validate_issuer(issuer)?;
395
396        if let Some(ref allowed) = self.allowed_issuers {
397            let normalized = issuer.trim_end_matches('/');
398            if !allowed.contains(normalized) {
399                return Err(Error::Unauthenticated("issuer not in allowed list".into()));
400            }
401        }
402
403        let provider = self.discover(issuer).await?;
404
405        let kid = header.kid.as_deref();
406        let decoding_key = find_decoding_key(&provider.jwks, kid, &header.alg)?;
407
408        let mut verification = jsonwebtoken::Validation::new(header.alg);
409        verification.validate_aud = false; // Audience checked later by trust policy
410        verification.set_issuer(&[issuer]);
411
412        let token_data: jsonwebtoken::TokenData<TokenClaims> =
413            jsonwebtoken::decode(token, &decoding_key, &verification)?;
414
415        Ok(token_data.claims)
416    }
417}
418
419fn find_decoding_key(
420    jwks: &jsonwebtoken::jwk::JwkSet,
421    kid: Option<&str>,
422    alg: &jsonwebtoken::Algorithm,
423) -> Result<jsonwebtoken::DecodingKey, Error> {
424    let jwk = if let Some(kid) = kid {
425        jwks.find(kid).ok_or_else(|| {
426            Error::Unauthenticated(format!("no matching key found for kid: {kid}"))
427        })?
428    } else {
429        let alg_str = format!("{alg:?}");
430        jwks.keys
431            .iter()
432            .find(|k| {
433                k.common
434                    .key_algorithm
435                    .is_some_and(|ka| format!("{ka:?}") == alg_str)
436            })
437            .or_else(|| jwks.keys.first())
438            .ok_or_else(|| Error::Unauthenticated("no keys in JWKS".into()))?
439    };
440
441    jsonwebtoken::DecodingKey::from_jwk(jwk)
442        .map_err(|e| Error::Unauthenticated(format!("invalid JWK: {e}")))
443}
444
445fn is_permanent_error(e: &Error) -> bool {
446    match e {
447        // HTTP 4xx (except 408, 429) and 501 are permanent
448        Error::OidcHttpError(code) => matches!(
449            code,
450            400 | 401 | 403 | 404 | 405 | 406 | 410 | 415 | 422 | 501
451        ),
452        Error::OidcDiscovery(_) => false, // Network errors are transient
453        _ => true,                        // Parse errors etc. are permanent
454    }
455}
456
457pub(crate) async fn read_limited_body(
458    resp: reqwest::Response,
459    limit: usize,
460    map_err: impl Fn(reqwest::Error) -> Error,
461) -> Result<Vec<u8>, Error> {
462    if let Some(len) = resp.content_length()
463        && len as usize > limit
464    {
465        return Err(Error::Unauthenticated(format!(
466            "response too large: {len} bytes (limit: {limit})"
467        )));
468    }
469
470    use futures_util::StreamExt as _;
471    let initial_capacity = resp
472        .content_length()
473        .map_or(4096, |len| (len as usize).min(limit));
474    let mut stream = resp.bytes_stream();
475    let mut buf = Vec::with_capacity(initial_capacity);
476    while let Some(chunk) = stream.next().await {
477        let chunk = chunk.map_err(&map_err)?;
478        if buf.len() + chunk.len() > limit {
479            return Err(Error::Unauthenticated(format!(
480                "response too large (limit: {limit})"
481            )));
482        }
483        buf.extend_from_slice(&chunk);
484    }
485    Ok(buf)
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `validate_issuer` rejects every entry of `issuers`.
    fn assert_issuers_rejected(issuers: &[&str]) {
        for issuer in issuers {
            assert!(
                validate_issuer(issuer).is_err(),
                "issuer should be rejected: {issuer}"
            );
        }
    }

    /// Asserts that `validate_subject` rejects every entry of `subjects`.
    fn assert_subjects_rejected(subjects: &[&str]) {
        for subject in subjects {
            assert!(
                validate_subject(subject).is_err(),
                "subject should be rejected: {subject:?}"
            );
        }
    }

    #[test]
    fn test_validate_issuer_valid() {
        let accepted = [
            "https://accounts.google.com",
            "https://token.actions.githubusercontent.com",
            "https://example.com/path/to/issuer",
            "http://localhost",
            "http://127.0.0.1",
            "http://[::1]",
        ];
        for issuer in accepted {
            assert!(
                validate_issuer(issuer).is_ok(),
                "issuer should be accepted: {issuer}"
            );
        }
    }

    #[test]
    fn test_validate_issuer_rejects_http_non_localhost() {
        assert_issuers_rejected(&["http://example.com"]);
    }

    #[test]
    fn test_validate_issuer_rejects_query_fragment() {
        assert_issuers_rejected(&["https://example.com?foo=bar", "https://example.com#frag"]);
    }

    #[test]
    fn test_validate_issuer_rejects_userinfo() {
        assert_issuers_rejected(&["https://user:pass@example.com"]);
    }

    #[test]
    fn test_validate_issuer_rejects_path_traversal() {
        assert_issuers_rejected(&["https://example.com/..", "https://example.com/a/../b"]);
    }

    #[test]
    fn test_validate_issuer_rejects_double_slash() {
        assert_issuers_rejected(&["https://example.com//path"]);
    }

    #[test]
    fn test_validate_issuer_rejects_tilde_issues() {
        assert_issuers_rejected(&[
            "https://example.com/path~",
            "https://example.com/~~path",
            "https://example.com/~",
        ]);
    }

    #[test]
    fn test_validate_issuer_rejects_dot_segment() {
        assert_issuers_rejected(&["https://example.com/."]);
    }

    #[test]
    fn test_validate_issuer_rejects_long_segment() {
        // One character over the 150-character segment limit.
        let issuer = format!("https://example.com/{}", "a".repeat(151));
        assert_issuers_rejected(&[issuer.as_str()]);
    }

    #[test]
    fn test_validate_issuer_rejects_non_ascii_host() {
        assert_issuers_rejected(&["https://exämple.com"]);
    }

    #[test]
    fn test_validate_subject_valid() {
        let accepted = [
            "repo:org/repo:ref:refs/heads/main",
            "user@example.com",
            "simple-subject",
            "pipe|separated",
        ];
        for subject in accepted {
            assert!(
                validate_subject(subject).is_ok(),
                "subject should be accepted: {subject}"
            );
        }
    }

    #[test]
    fn test_validate_subject_rejects() {
        assert_subjects_rejected(&[
            "",
            "has space",
            "has\"quote",
            "has'quote",
            "has\\backslash",
            "has<bracket",
            "has[bracket]",
        ]);
    }

    #[test]
    fn test_validate_audience_valid() {
        for audience in ["https://example.com", "my-audience"] {
            assert!(
                validate_audience(audience).is_ok(),
                "audience should be accepted: {audience}"
            );
        }
    }

    #[test]
    fn test_validate_audience_more_restrictive_than_subject() {
        // Subject allows these, audience rejects them
        for value in ["user@example.com", "pipe|value"] {
            assert!(validate_subject(value).is_ok(), "subject: {value}");
            assert!(validate_audience(value).is_err(), "audience: {value}");
        }

        // Both validators reject brackets.
        assert!(validate_subject("has[bracket]").is_err());
        assert!(validate_audience("has[bracket]").is_err());
    }
}