Skip to main content

sts_cat/
trust_policy.rs

1use crate::error::Error;
2
3#[derive(Debug, serde::Deserialize)]
4pub struct TrustPolicy {
5    pub issuer: Option<String>,
6    pub issuer_pattern: Option<String>,
7    pub subject: Option<String>,
8    pub subject_pattern: Option<String>,
9    pub audience: Option<String>,
10    pub audience_pattern: Option<String>,
11    pub claim_pattern: Option<std::collections::HashMap<String, String>>,
12    pub max_token_lifetime: Option<u64>,
13    pub permissions: Permissions,
14    pub repositories: Option<Vec<String>>,
15}
16
17#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
18pub struct Permissions {
19    #[serde(flatten)]
20    pub inner: std::collections::HashMap<String, String>,
21}
22
23pub struct CompiledTrustPolicy {
24    issuer: StringMatcher,
25    subject: StringMatcher,
26    audience: AudienceMatch,
27    claim_patterns: Vec<(String, regex::Regex)>,
28    max_token_lifetime: Option<u64>,
29    pub permissions: Permissions,
30    pub repositories: Option<Vec<String>>,
31}
32
33enum StringMatcher {
34    Exact(String),
35    Pattern(regex::Regex),
36}
37
38enum AudienceMatch {
39    Exact(String),
40    Pattern(regex::Regex),
41    Identifier,
42}
43
/// Identity extracted from a token that satisfied a trust policy.
pub struct Actor {
    /// The token's `iss` claim.
    pub issuer: String,
    /// The token's `sub` claim.
    pub subject: String,
    /// `(claim name, stringified value)` for every `claim_pattern` entry, in check order.
    pub matched_claims: Vec<(String, String)>,
}
49
50impl StringMatcher {
51    fn compile(
52        exact: Option<String>,
53        pattern: Option<String>,
54        field_name: &str,
55    ) -> Result<Self, Error> {
56        use serde::de::Error as _;
57        match (exact, pattern) {
58            (Some(exact), None) => Ok(StringMatcher::Exact(exact)),
59            (None, Some(pattern)) => {
60                let re = compile_anchored_regex(&pattern)?;
61                Ok(StringMatcher::Pattern(re))
62            }
63            (Some(_), Some(_)) => Err(Error::PolicyParse(toml::de::Error::custom(format!(
64                "cannot specify both {field_name} and {field_name}_pattern"
65            )))),
66            (None, None) => Err(Error::PolicyParse(toml::de::Error::custom(format!(
67                "must specify either {field_name} or {field_name}_pattern"
68            )))),
69        }
70    }
71
72    fn check(&self, value: &str, field_name: &str) -> Result<(), Error> {
73        match self {
74            StringMatcher::Exact(expected) => {
75                if value != expected {
76                    return Err(Error::PermissionDenied(format!(
77                        "{field_name} did not match"
78                    )));
79                }
80            }
81            StringMatcher::Pattern(re) => {
82                if !re.is_match(value) {
83                    return Err(Error::PermissionDenied(format!(
84                        "{field_name} did not match pattern"
85                    )));
86                }
87            }
88        }
89        Ok(())
90    }
91}
92
93impl AudienceMatch {
94    fn compile(exact: Option<String>, pattern: Option<String>) -> Result<Self, Error> {
95        use serde::de::Error as _;
96        match (exact, pattern) {
97            (Some(exact), None) => Ok(AudienceMatch::Exact(exact)),
98            (None, Some(pattern)) => {
99                let re = compile_anchored_regex(&pattern)?;
100                Ok(AudienceMatch::Pattern(re))
101            }
102            (None, None) => Ok(AudienceMatch::Identifier),
103            (Some(_), Some(_)) => Err(Error::PolicyParse(toml::de::Error::custom(
104                "cannot specify both audience and audience_pattern",
105            ))),
106        }
107    }
108
109    fn check<'a>(
110        &self,
111        audiences: impl Iterator<Item = &'a str>,
112        identifier: &str,
113    ) -> Result<(), Error> {
114        match self {
115            AudienceMatch::Exact(expected) => {
116                if !audiences.into_iter().any(|a| a == expected) {
117                    return Err(Error::PermissionDenied("audience did not match".into()));
118                }
119            }
120            AudienceMatch::Pattern(re) => {
121                if !audiences.into_iter().any(|a| re.is_match(a)) {
122                    return Err(Error::PermissionDenied(
123                        "audience did not match pattern".into(),
124                    ));
125                }
126            }
127            AudienceMatch::Identifier => {
128                if !audiences.into_iter().any(|a| a == identifier) {
129                    return Err(Error::PermissionDenied(
130                        "audience did not match identifier".into(),
131                    ));
132                }
133            }
134        }
135        Ok(())
136    }
137}
138
139fn compile_anchored_regex(pattern: &str) -> Result<regex::Regex, Error> {
140    Ok(regex::Regex::new(&format!("^(?:{pattern})$"))?)
141}
142
143fn compile_claim_patterns(
144    patterns: Option<std::collections::HashMap<String, String>>,
145) -> Result<Vec<(String, regex::Regex)>, Error> {
146    let Some(patterns) = patterns else {
147        return Ok(Vec::new());
148    };
149    let mut compiled = Vec::with_capacity(patterns.len());
150    for (name, pattern) in patterns {
151        let re = compile_anchored_regex(&pattern)?;
152        compiled.push((name, re));
153    }
154    Ok(compiled)
155}
156
157impl TrustPolicy {
158    pub fn parse(toml_str: &str) -> Result<Self, Error> {
159        let policy: TrustPolicy = toml::from_str(toml_str)?;
160        Ok(policy)
161    }
162
163    pub fn compile(self, is_org_level: bool) -> Result<CompiledTrustPolicy, Error> {
164        if !is_org_level && self.repositories.is_some() {
165            return Err(Error::PermissionDenied(
166                "repositories field is not allowed in repository-level trust policies".into(),
167            ));
168        }
169
170        Ok(CompiledTrustPolicy {
171            issuer: StringMatcher::compile(self.issuer, self.issuer_pattern, "issuer")?,
172            subject: StringMatcher::compile(self.subject, self.subject_pattern, "subject")?,
173            audience: AudienceMatch::compile(self.audience, self.audience_pattern)?,
174            claim_patterns: compile_claim_patterns(self.claim_pattern)?,
175            max_token_lifetime: self.max_token_lifetime,
176            permissions: self.permissions,
177            repositories: self.repositories,
178        })
179    }
180}
181
182impl CompiledTrustPolicy {
183    pub fn check_token(
184        &self,
185        claims: &crate::oidc::TokenClaims,
186        identifier: &str,
187    ) -> Result<Actor, Error> {
188        // Defense-in-depth: validate claim format strings before pattern matching
189        crate::oidc::validate_issuer(&claims.iss)?;
190        crate::oidc::validate_subject(&claims.sub)?;
191        for aud in claims.aud.iter() {
192            crate::oidc::validate_audience(aud)?;
193        }
194
195        self.issuer.check(&claims.iss, "issuer")?;
196        self.subject.check(&claims.sub, "subject")?;
197        self.audience.check(claims.aud.iter(), identifier)?;
198        self.check_token_lifetime(claims)?;
199        let matched_claims = self.check_claim_patterns(claims)?;
200
201        Ok(Actor {
202            issuer: claims.iss.clone(),
203            subject: claims.sub.clone(),
204            matched_claims,
205        })
206    }
207
208    fn check_token_lifetime(&self, claims: &crate::oidc::TokenClaims) -> Result<(), Error> {
209        let Some(max) = self.max_token_lifetime else {
210            return Ok(());
211        };
212
213        let extra = &claims.extra;
214
215        let exp = extra.get("exp").and_then(|v| v.as_u64()).ok_or_else(|| {
216            Error::PermissionDenied("max_token_lifetime is set but token has no 'exp' claim".into())
217        })?;
218
219        let start = extra
220            .get("nbf")
221            .and_then(|v| v.as_u64())
222            .or_else(|| extra.get("iat").and_then(|v| v.as_u64()))
223            .ok_or_else(|| {
224                Error::PermissionDenied(
225                    "max_token_lifetime is set but token has neither 'nbf' nor 'iat' claim".into(),
226                )
227            })?;
228
229        let lifetime = exp.saturating_sub(start);
230        if lifetime > max {
231            return Err(Error::PermissionDenied(format!(
232                "token lifetime {lifetime}s exceeds max_token_lifetime {max}s"
233            )));
234        }
235
236        Ok(())
237    }
238
239    fn check_claim_patterns(
240        &self,
241        claims: &crate::oidc::TokenClaims,
242    ) -> Result<Vec<(String, String)>, Error> {
243        let mut matched = Vec::new();
244        for (claim_name, pattern) in &self.claim_patterns {
245            let value = claims.extra.get(claim_name).ok_or_else(|| {
246                Error::PermissionDenied(format!("required claim '{claim_name}' not present"))
247            })?;
248
249            let string_value = match value {
250                serde_json::Value::String(s) => s.clone(),
251                serde_json::Value::Bool(b) => b.to_string(),
252                _ => {
253                    return Err(Error::PermissionDenied(format!(
254                        "claim '{claim_name}' is not a string or boolean"
255                    )));
256                }
257            };
258
259            if !pattern.is_match(&string_value) {
260                return Err(Error::PermissionDenied(format!(
261                    "claim '{claim_name}' did not match pattern"
262                )));
263            }
264
265            matched.push((claim_name.clone(), string_value));
266        }
267        Ok(matched)
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    const GITHUB_ISSUER: &str = "https://token.actions.githubusercontent.com";
    const MAIN_SUBJECT: &str = "repo:myorg/myrepo:ref:refs/heads/main";
    const STS_AUDIENCE: &str = "sts.example.com";

    /// Parses and compiles `toml`, panicking on any error.
    fn compiled(toml: &str, is_org_level: bool) -> CompiledTrustPolicy {
        TrustPolicy::parse(toml)
            .unwrap()
            .compile(is_org_level)
            .unwrap()
    }

    /// Builds token claims whose single audience is `sts.example.com`.
    fn claims_for(
        iss: &str,
        sub: &str,
        extra: HashMap<String, serde_json::Value>,
    ) -> crate::oidc::TokenClaims {
        crate::oidc::TokenClaims {
            iss: iss.into(),
            sub: sub.into(),
            aud: crate::oidc::OneOrMany::One(STS_AUDIENCE.into()),
            extra,
        }
    }

    #[test]
    fn test_parse_basic_policy() {
        let policy = TrustPolicy::parse(
            r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject = "repo:myorg/myrepo:ref:refs/heads/main"

            [permissions]
            contents = "read"
        "#,
        )
        .unwrap();

        assert_eq!(policy.issuer.as_deref(), Some(GITHUB_ISSUER));
        assert!(policy.issuer_pattern.is_none());
    }

    #[test]
    fn test_parse_pattern_policy() {
        let policy = TrustPolicy::parse(
            r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject_pattern = "repo:myorg/.*:ref:refs/heads/main"
            repositories = ["repo-a", "repo-b"]

            [permissions]
            contents = "read"
        "#,
        )
        .unwrap();

        assert!(policy.subject_pattern.is_some());
        assert!(policy.repositories.is_some());
    }

    #[test]
    fn test_compile_rejects_both_issuer_and_pattern() {
        let toml = r#"
            issuer = "https://example.com"
            issuer_pattern = "https://.*"
            subject = "sub"

            [permissions]
            contents = "read"
        "#;
        let outcome = TrustPolicy::parse(toml).unwrap().compile(false);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_compile_rejects_neither_issuer() {
        let toml = r#"
            subject = "sub"

            [permissions]
            contents = "read"
        "#;
        let outcome = TrustPolicy::parse(toml).unwrap().compile(false);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_compile_rejects_repositories_on_repo_level() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"
            repositories = ["repo-a"]

            [permissions]
            contents = "read"
        "#;
        let fresh = || TrustPolicy::parse(toml).unwrap();

        assert!(fresh().compile(false).is_err());
        assert!(fresh().compile(true).is_ok());
    }

    #[test]
    fn test_compile_audience_fallback_to_identifier() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"

            [permissions]
            contents = "read"
        "#;
        assert!(matches!(
            compiled(toml, false).audience,
            AudienceMatch::Identifier
        ));
    }

    #[test]
    fn test_check_token_exact_match() {
        let policy = compiled(
            r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject = "repo:myorg/myrepo:ref:refs/heads/main"

            [permissions]
            contents = "read"
        "#,
            false,
        );
        let claims = claims_for(GITHUB_ISSUER, MAIN_SUBJECT, HashMap::new());

        let actor = policy.check_token(&claims, STS_AUDIENCE).unwrap();
        assert_eq!(actor.issuer, GITHUB_ISSUER);
        assert_eq!(actor.subject, MAIN_SUBJECT);
    }

    #[test]
    fn test_check_token_pattern_match() {
        let policy = compiled(
            r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject_pattern = "repo:myorg/.*:ref:refs/heads/main"

            [permissions]
            contents = "read"
        "#,
            true,
        );
        let claims = claims_for(
            GITHUB_ISSUER,
            "repo:myorg/some-repo:ref:refs/heads/main",
            HashMap::new(),
        );

        assert!(policy.check_token(&claims, STS_AUDIENCE).is_ok());
    }

    #[test]
    fn test_check_token_subject_mismatch() {
        let policy = compiled(
            r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject = "repo:myorg/myrepo:ref:refs/heads/main"

            [permissions]
            contents = "read"
        "#,
            false,
        );
        let claims = claims_for(
            GITHUB_ISSUER,
            "repo:myorg/other-repo:ref:refs/heads/main",
            HashMap::new(),
        );

        assert!(policy.check_token(&claims, STS_AUDIENCE).is_err());
    }

    #[test]
    fn test_check_token_audience_identifier_fallback() {
        let policy = compiled(
            r#"
            issuer = "https://example.com"
            subject = "sub"

            [permissions]
            contents = "read"
        "#,
            false,
        );

        let addressed_to_us = claims_for("https://example.com", "sub", HashMap::new());
        assert!(policy.check_token(&addressed_to_us, STS_AUDIENCE).is_ok());

        let addressed_elsewhere = crate::oidc::TokenClaims {
            aud: crate::oidc::OneOrMany::One("other.example.com".into()),
            ..claims_for("https://example.com", "sub", HashMap::new())
        };
        assert!(policy
            .check_token(&addressed_elsewhere, STS_AUDIENCE)
            .is_err());
    }

    #[test]
    fn test_check_token_claim_pattern_bool_coercion() {
        let policy = compiled(
            r#"
            issuer = "https://example.com"
            subject = "sub"

            [permissions]
            contents = "read"

            [claim_pattern]
            email_verified = "true"
        "#,
            false,
        );
        let extra = HashMap::from([(
            "email_verified".to_string(),
            serde_json::Value::Bool(true),
        )]);
        let claims = claims_for("https://example.com", "sub", extra);

        assert!(policy.check_token(&claims, STS_AUDIENCE).is_ok());
    }

    #[test]
    fn test_check_token_rejects_numeric_claim() {
        let policy = compiled(
            r#"
            issuer = "https://example.com"
            subject = "sub"

            [permissions]
            contents = "read"

            [claim_pattern]
            some_number = "42"
        "#,
            false,
        );
        let extra = HashMap::from([(
            "some_number".to_string(),
            serde_json::Value::Number(42.into()),
        )]);
        let claims = claims_for("https://example.com", "sub", extra);

        assert!(policy.check_token(&claims, STS_AUDIENCE).is_err());
    }

    #[test]
    fn test_check_token_audience_multi() {
        let policy = compiled(
            r#"
            issuer = "https://example.com"
            subject = "sub"
            audience = "my-audience"

            [permissions]
            contents = "read"
        "#,
            false,
        );
        let claims = crate::oidc::TokenClaims {
            aud: crate::oidc::OneOrMany::Many(vec!["other-aud".into(), "my-audience".into()]),
            ..claims_for("https://example.com", "sub", HashMap::new())
        };

        assert!(policy.check_token(&claims, STS_AUDIENCE).is_ok());
    }

    #[test]
    fn test_pattern_alternation_fully_anchored() {
        let policy = compiled(
            r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject_pattern = "repo:myorg/a|repo:myorg/b"

            [permissions]
            contents = "read"
        "#,
            false,
        );

        for (subject, should_pass) in [
            ("repo:myorg/a", true),
            ("repo:myorg/b", true),
            // Must not match partial strings
            ("repo:myorg/a-extra", false),
            ("prefix-repo:myorg/b", false),
        ] {
            let claims = claims_for(GITHUB_ISSUER, subject, HashMap::new());
            let passed = policy.check_token(&claims, STS_AUDIENCE).is_ok();
            assert_eq!(passed, should_pass, "unexpected outcome for subject {subject:?}");
        }
    }

    /// Builds claims carrying only the given time claims (in seconds) in `extra`.
    fn make_claims_with_times(
        exp: Option<u64>,
        iat: Option<u64>,
        nbf: Option<u64>,
    ) -> crate::oidc::TokenClaims {
        let extra = [("exp", exp), ("iat", iat), ("nbf", nbf)]
            .iter()
            .filter_map(|&(name, value)| {
                value.map(|seconds| (name.to_string(), serde_json::Value::from(seconds)))
            })
            .collect();
        claims_for("https://example.com", "sub", extra)
    }

    /// Compiles a minimal repo-level policy, optionally setting `max_token_lifetime`.
    fn compile_lifetime_policy(max_token_lifetime: Option<u64>) -> CompiledTrustPolicy {
        let lifetime_line = match max_token_lifetime {
            Some(v) => format!("max_token_lifetime = {v}"),
            None => String::new(),
        };
        let toml = format!(
            r#"
            issuer = "https://example.com"
            subject = "sub"
            {lifetime_line}

            [permissions]
            contents = "read"
        "#
        );
        compiled(&toml, false)
    }

    #[test]
    fn test_max_token_lifetime_within_limit() {
        let policy = compile_lifetime_policy(Some(600));
        let claims = make_claims_with_times(Some(1_000_600), Some(1_000_000), None);
        assert!(policy.check_token(&claims, STS_AUDIENCE).is_ok());
    }

    #[test]
    fn test_max_token_lifetime_exceeded() {
        let policy = compile_lifetime_policy(Some(600));
        let claims = make_claims_with_times(Some(1_000_601), Some(1_000_000), None);
        assert!(policy.check_token(&claims, STS_AUDIENCE).is_err());
    }

    #[test]
    fn test_max_token_lifetime_prefers_nbf() {
        let policy = compile_lifetime_policy(Some(600));
        // nbf=1000000, iat=999000, exp=1000500 → lifetime from nbf = 500s (ok), from iat = 1500s (would fail)
        let claims = make_claims_with_times(Some(1_000_500), Some(999_000), Some(1_000_000));
        assert!(policy.check_token(&claims, STS_AUDIENCE).is_ok());
    }

    #[test]
    fn test_max_token_lifetime_not_set() {
        let policy = compile_lifetime_policy(None);
        let claims = make_claims_with_times(Some(2_000_000), Some(1_000_000), None);
        assert!(policy.check_token(&claims, STS_AUDIENCE).is_ok());
    }

    #[test]
    fn test_max_token_lifetime_missing_exp() {
        let policy = compile_lifetime_policy(Some(600));
        let claims = make_claims_with_times(None, Some(1_000_000), None);
        assert!(policy.check_token(&claims, STS_AUDIENCE).is_err());
    }

    #[test]
    fn test_max_token_lifetime_missing_iat_and_nbf() {
        let policy = compile_lifetime_policy(Some(600));
        let claims = make_claims_with_times(Some(1_000_600), None, None);
        assert!(policy.check_token(&claims, STS_AUDIENCE).is_err());
    }
}