Skip to main content

sts_cat/
trust_policy.rs

1use crate::error::Error;
2
3#[derive(Debug, serde::Deserialize)]
4pub struct TrustPolicy {
5    pub issuer: Option<String>,
6    pub issuer_pattern: Option<String>,
7    pub subject: Option<String>,
8    pub subject_pattern: Option<String>,
9    pub audience: Option<String>,
10    pub audience_pattern: Option<String>,
11    pub claim_pattern: Option<std::collections::HashMap<String, String>>,
12    pub permissions: Permissions,
13    pub repositories: Option<Vec<String>>,
14}
15
16#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
17pub struct Permissions {
18    #[serde(flatten)]
19    pub inner: std::collections::HashMap<String, String>,
20}
21
22pub struct CompiledTrustPolicy {
23    issuer: StringMatcher,
24    subject: StringMatcher,
25    audience: AudienceMatch,
26    claim_patterns: Vec<(String, regex::Regex)>,
27    pub permissions: Permissions,
28    pub repositories: Option<Vec<String>>,
29}
30
31enum StringMatcher {
32    Exact(String),
33    Pattern(regex::Regex),
34}
35
36enum AudienceMatch {
37    Exact(String),
38    Pattern(regex::Regex),
39    Identifier,
40}
41
/// Identity of a caller whose token satisfied a `CompiledTrustPolicy`.
pub struct Actor {
    /// The token's `iss` claim.
    pub issuer: String,
    /// The token's `sub` claim.
    pub subject: String,
    /// `(claim name, stringified value)` for every claim checked through the
    /// policy's `claim_pattern` table.
    pub matched_claims: Vec<(String, String)>,
}
47
48impl StringMatcher {
49    fn compile(
50        exact: Option<String>,
51        pattern: Option<String>,
52        field_name: &str,
53    ) -> Result<Self, Error> {
54        use serde::de::Error as _;
55        match (exact, pattern) {
56            (Some(exact), None) => Ok(StringMatcher::Exact(exact)),
57            (None, Some(pattern)) => {
58                let re = compile_anchored_regex(&pattern)?;
59                Ok(StringMatcher::Pattern(re))
60            }
61            (Some(_), Some(_)) => Err(Error::PolicyParse(toml::de::Error::custom(format!(
62                "cannot specify both {field_name} and {field_name}_pattern"
63            )))),
64            (None, None) => Err(Error::PolicyParse(toml::de::Error::custom(format!(
65                "must specify either {field_name} or {field_name}_pattern"
66            )))),
67        }
68    }
69
70    fn check(&self, value: &str, field_name: &str) -> Result<(), Error> {
71        match self {
72            StringMatcher::Exact(expected) => {
73                if value != expected {
74                    return Err(Error::PermissionDenied(format!(
75                        "{field_name} did not match"
76                    )));
77                }
78            }
79            StringMatcher::Pattern(re) => {
80                if !re.is_match(value) {
81                    return Err(Error::PermissionDenied(format!(
82                        "{field_name} did not match pattern"
83                    )));
84                }
85            }
86        }
87        Ok(())
88    }
89}
90
91impl AudienceMatch {
92    fn compile(exact: Option<String>, pattern: Option<String>) -> Result<Self, Error> {
93        use serde::de::Error as _;
94        match (exact, pattern) {
95            (Some(exact), None) => Ok(AudienceMatch::Exact(exact)),
96            (None, Some(pattern)) => {
97                let re = compile_anchored_regex(&pattern)?;
98                Ok(AudienceMatch::Pattern(re))
99            }
100            (None, None) => Ok(AudienceMatch::Identifier),
101            (Some(_), Some(_)) => Err(Error::PolicyParse(toml::de::Error::custom(
102                "cannot specify both audience and audience_pattern",
103            ))),
104        }
105    }
106
107    fn check<'a>(
108        &self,
109        audiences: impl Iterator<Item = &'a str>,
110        identifier: &str,
111    ) -> Result<(), Error> {
112        match self {
113            AudienceMatch::Exact(expected) => {
114                if !audiences.into_iter().any(|a| a == expected) {
115                    return Err(Error::PermissionDenied("audience did not match".into()));
116                }
117            }
118            AudienceMatch::Pattern(re) => {
119                if !audiences.into_iter().any(|a| re.is_match(a)) {
120                    return Err(Error::PermissionDenied(
121                        "audience did not match pattern".into(),
122                    ));
123                }
124            }
125            AudienceMatch::Identifier => {
126                if !audiences.into_iter().any(|a| a == identifier) {
127                    return Err(Error::PermissionDenied(
128                        "audience did not match identifier".into(),
129                    ));
130                }
131            }
132        }
133        Ok(())
134    }
135}
136
137fn compile_anchored_regex(pattern: &str) -> Result<regex::Regex, Error> {
138    Ok(regex::Regex::new(&format!("^(?:{pattern})$"))?)
139}
140
141fn compile_claim_patterns(
142    patterns: Option<std::collections::HashMap<String, String>>,
143) -> Result<Vec<(String, regex::Regex)>, Error> {
144    let Some(patterns) = patterns else {
145        return Ok(Vec::new());
146    };
147    let mut compiled = Vec::with_capacity(patterns.len());
148    for (name, pattern) in patterns {
149        let re = compile_anchored_regex(&pattern)?;
150        compiled.push((name, re));
151    }
152    Ok(compiled)
153}
154
155impl TrustPolicy {
156    pub fn parse(toml_str: &str) -> Result<Self, Error> {
157        let policy: TrustPolicy = toml::from_str(toml_str)?;
158        Ok(policy)
159    }
160
161    pub fn compile(self, is_org_level: bool) -> Result<CompiledTrustPolicy, Error> {
162        if !is_org_level && self.repositories.is_some() {
163            return Err(Error::PermissionDenied(
164                "repositories field is not allowed in repository-level trust policies".into(),
165            ));
166        }
167
168        Ok(CompiledTrustPolicy {
169            issuer: StringMatcher::compile(self.issuer, self.issuer_pattern, "issuer")?,
170            subject: StringMatcher::compile(self.subject, self.subject_pattern, "subject")?,
171            audience: AudienceMatch::compile(self.audience, self.audience_pattern)?,
172            claim_patterns: compile_claim_patterns(self.claim_pattern)?,
173            permissions: self.permissions,
174            repositories: self.repositories,
175        })
176    }
177}
178
179impl CompiledTrustPolicy {
180    pub fn check_token(
181        &self,
182        claims: &crate::oidc::TokenClaims,
183        identifier: &str,
184    ) -> Result<Actor, Error> {
185        // Defense-in-depth: validate claim format strings before pattern matching
186        crate::oidc::validate_issuer(&claims.iss)?;
187        crate::oidc::validate_subject(&claims.sub)?;
188        for aud in claims.aud.iter() {
189            crate::oidc::validate_audience(aud)?;
190        }
191
192        self.issuer.check(&claims.iss, "issuer")?;
193        self.subject.check(&claims.sub, "subject")?;
194        self.audience.check(claims.aud.iter(), identifier)?;
195        let matched_claims = self.check_claim_patterns(claims)?;
196
197        Ok(Actor {
198            issuer: claims.iss.clone(),
199            subject: claims.sub.clone(),
200            matched_claims,
201        })
202    }
203
204    fn check_claim_patterns(
205        &self,
206        claims: &crate::oidc::TokenClaims,
207    ) -> Result<Vec<(String, String)>, Error> {
208        let mut matched = Vec::new();
209        for (claim_name, pattern) in &self.claim_patterns {
210            let value = claims.extra.get(claim_name).ok_or_else(|| {
211                Error::PermissionDenied(format!("required claim '{claim_name}' not present"))
212            })?;
213
214            let string_value = match value {
215                serde_json::Value::String(s) => s.clone(),
216                serde_json::Value::Bool(b) => b.to_string(),
217                _ => {
218                    return Err(Error::PermissionDenied(format!(
219                        "claim '{claim_name}' is not a string or boolean"
220                    )));
221                }
222            };
223
224            if !pattern.is_match(&string_value) {
225                return Err(Error::PermissionDenied(format!(
226                    "claim '{claim_name}' did not match pattern"
227                )));
228            }
229
230            matched.push((claim_name.clone(), string_value));
231        }
232        Ok(matched)
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    const GITHUB_ISSUER: &str = "https://token.actions.githubusercontent.com";
    const IDENTIFIER: &str = "sts.example.com";

    /// Parses and compiles `toml`, panicking with context on failure.
    fn compile_policy(toml: &str, is_org_level: bool) -> CompiledTrustPolicy {
        TrustPolicy::parse(toml)
            .expect("policy should parse")
            .compile(is_org_level)
            .expect("policy should compile")
    }

    /// Token claims with a single audience and no extra claims.
    fn token(iss: &str, sub: &str, aud: &str) -> crate::oidc::TokenClaims {
        crate::oidc::TokenClaims {
            iss: iss.into(),
            sub: sub.into(),
            aud: crate::oidc::OneOrMany::One(aud.into()),
            extra: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_parse_basic_policy() {
        let toml = r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject = "repo:myorg/myrepo:ref:refs/heads/main"

            [permissions]
            contents = "read"
        "#;
        let policy = TrustPolicy::parse(toml).unwrap();
        assert_eq!(policy.issuer.as_deref(), Some(GITHUB_ISSUER));
        assert!(policy.issuer_pattern.is_none());
    }

    #[test]
    fn test_parse_pattern_policy() {
        let toml = r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject_pattern = "repo:myorg/.*:ref:refs/heads/main"
            repositories = ["repo-a", "repo-b"]

            [permissions]
            contents = "read"
        "#;
        let policy = TrustPolicy::parse(toml).unwrap();
        assert!(policy.subject_pattern.is_some());
        assert!(policy.repositories.is_some());
    }

    #[test]
    fn test_compile_rejects_both_issuer_and_pattern() {
        let toml = r#"
            issuer = "https://example.com"
            issuer_pattern = "https://.*"
            subject = "sub"

            [permissions]
            contents = "read"
        "#;
        let policy = TrustPolicy::parse(toml).unwrap();
        assert!(policy.compile(false).is_err());
    }

    #[test]
    fn test_compile_rejects_neither_issuer() {
        let toml = r#"
            subject = "sub"

            [permissions]
            contents = "read"
        "#;
        let policy = TrustPolicy::parse(toml).unwrap();
        assert!(policy.compile(false).is_err());
    }

    #[test]
    fn test_compile_rejects_both_audience_and_pattern() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"
            audience = "my-audience"
            audience_pattern = "my-.*"

            [permissions]
            contents = "read"
        "#;
        let policy = TrustPolicy::parse(toml).unwrap();
        assert!(policy.compile(false).is_err());
    }

    #[test]
    fn test_compile_rejects_repositories_on_repo_level() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"
            repositories = ["repo-a"]

            [permissions]
            contents = "read"
        "#;
        let policy = TrustPolicy::parse(toml).unwrap();
        assert!(policy.compile(false).is_err());
        let policy2 = TrustPolicy::parse(toml).unwrap();
        assert!(policy2.compile(true).is_ok());
    }

    #[test]
    fn test_compile_audience_fallback_to_identifier() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"

            [permissions]
            contents = "read"
        "#;
        let compiled = compile_policy(toml, false);
        assert!(matches!(compiled.audience, AudienceMatch::Identifier));
    }

    #[test]
    fn test_check_token_exact_match() {
        let toml = r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject = "repo:myorg/myrepo:ref:refs/heads/main"

            [permissions]
            contents = "read"
        "#;
        let compiled = compile_policy(toml, false);
        let claims = token(
            GITHUB_ISSUER,
            "repo:myorg/myrepo:ref:refs/heads/main",
            IDENTIFIER,
        );

        let actor = compiled.check_token(&claims, IDENTIFIER).unwrap();
        assert_eq!(actor.issuer, GITHUB_ISSUER);
        assert_eq!(actor.subject, "repo:myorg/myrepo:ref:refs/heads/main");
        assert!(actor.matched_claims.is_empty());
    }

    #[test]
    fn test_check_token_pattern_match() {
        let toml = r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject_pattern = "repo:myorg/.*:ref:refs/heads/main"

            [permissions]
            contents = "read"
        "#;
        let compiled = compile_policy(toml, true);
        let claims = token(
            GITHUB_ISSUER,
            "repo:myorg/some-repo:ref:refs/heads/main",
            IDENTIFIER,
        );

        assert!(compiled.check_token(&claims, IDENTIFIER).is_ok());
    }

    #[test]
    fn test_check_token_issuer_pattern() {
        let toml = r#"
            issuer_pattern = 'https://(example|other)\.com'
            subject = "sub"

            [permissions]
            contents = "read"
        "#;
        let compiled = compile_policy(toml, false);

        assert!(compiled
            .check_token(&token("https://example.com", "sub", IDENTIFIER), IDENTIFIER)
            .is_ok());
        assert!(compiled
            .check_token(&token("https://evil.com", "sub", IDENTIFIER), IDENTIFIER)
            .is_err());
    }

    #[test]
    fn test_check_token_subject_mismatch() {
        let toml = r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject = "repo:myorg/myrepo:ref:refs/heads/main"

            [permissions]
            contents = "read"
        "#;
        let compiled = compile_policy(toml, false);
        let claims = token(
            GITHUB_ISSUER,
            "repo:myorg/other-repo:ref:refs/heads/main",
            IDENTIFIER,
        );

        assert!(compiled.check_token(&claims, IDENTIFIER).is_err());
    }

    #[test]
    fn test_check_token_audience_identifier_fallback() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"

            [permissions]
            contents = "read"
        "#;
        let compiled = compile_policy(toml, false);

        let claims = token("https://example.com", "sub", IDENTIFIER);
        assert!(compiled.check_token(&claims, IDENTIFIER).is_ok());

        let claims2 = token("https://example.com", "sub", "other.example.com");
        assert!(compiled.check_token(&claims2, IDENTIFIER).is_err());
    }

    #[test]
    fn test_check_token_audience_pattern() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"
            audience_pattern = 'sts\.example\.(com|net)'

            [permissions]
            contents = "read"
        "#;
        let compiled = compile_policy(toml, false);

        // With an explicit pattern, the identifier argument is not consulted.
        let claims = token("https://example.com", "sub", "sts.example.com");
        assert!(compiled.check_token(&claims, "unrelated-identifier").is_ok());

        let claims2 = token("https://example.com", "sub", "sts.example.org");
        assert!(compiled.check_token(&claims2, IDENTIFIER).is_err());
    }

    #[test]
    fn test_check_token_claim_pattern_bool_coercion() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"

            [permissions]
            contents = "read"

            [claim_pattern]
            email_verified = "true"
        "#;
        let compiled = compile_policy(toml, false);

        let mut claims = token("https://example.com", "sub", IDENTIFIER);
        claims
            .extra
            .insert("email_verified".into(), serde_json::Value::Bool(true));

        assert!(compiled.check_token(&claims, IDENTIFIER).is_ok());
    }

    #[test]
    fn test_check_token_string_claim_pattern() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"

            [permissions]
            contents = "read"

            [claim_pattern]
            ref = "refs/heads/.*"
        "#;
        let compiled = compile_policy(toml, false);

        let mut claims = token("https://example.com", "sub", IDENTIFIER);
        claims.extra.insert(
            "ref".into(),
            serde_json::Value::String("refs/heads/main".into()),
        );
        let actor = compiled.check_token(&claims, IDENTIFIER).unwrap();
        assert_eq!(
            actor.matched_claims,
            vec![("ref".to_string(), "refs/heads/main".to_string())]
        );

        let mut claims2 = token("https://example.com", "sub", IDENTIFIER);
        claims2.extra.insert(
            "ref".into(),
            serde_json::Value::String("refs/tags/v1".into()),
        );
        assert!(compiled.check_token(&claims2, IDENTIFIER).is_err());
    }

    #[test]
    fn test_check_token_missing_required_claim() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"

            [permissions]
            contents = "read"

            [claim_pattern]
            email_verified = "true"
        "#;
        let compiled = compile_policy(toml, false);

        let claims = token("https://example.com", "sub", IDENTIFIER);
        assert!(compiled.check_token(&claims, IDENTIFIER).is_err());
    }

    #[test]
    fn test_check_token_rejects_numeric_claim() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"

            [permissions]
            contents = "read"

            [claim_pattern]
            some_number = "42"
        "#;
        let compiled = compile_policy(toml, false);

        let mut claims = token("https://example.com", "sub", IDENTIFIER);
        claims
            .extra
            .insert("some_number".into(), serde_json::Value::Number(42.into()));

        assert!(compiled.check_token(&claims, IDENTIFIER).is_err());
    }

    #[test]
    fn test_check_token_audience_multi() {
        let toml = r#"
            issuer = "https://example.com"
            subject = "sub"
            audience = "my-audience"

            [permissions]
            contents = "read"
        "#;
        let compiled = compile_policy(toml, false);

        let mut claims = token("https://example.com", "sub", "unused");
        claims.aud =
            crate::oidc::OneOrMany::Many(vec!["other-aud".into(), "my-audience".into()]);

        assert!(compiled.check_token(&claims, IDENTIFIER).is_ok());
    }

    #[test]
    fn test_pattern_alternation_fully_anchored() {
        let toml = r#"
            issuer = "https://token.actions.githubusercontent.com"
            subject_pattern = "repo:myorg/a|repo:myorg/b"

            [permissions]
            contents = "read"
        "#;
        let compiled = compile_policy(toml, false);
        let check = |sub: &str| compiled.check_token(&token(GITHUB_ISSUER, sub, IDENTIFIER), IDENTIFIER);

        assert!(check("repo:myorg/a").is_ok());
        assert!(check("repo:myorg/b").is_ok());
        // Must not match partial strings
        assert!(check("repo:myorg/a-extra").is_err());
        assert!(check("prefix-repo:myorg/b").is_err());
    }
}