Skip to main content

sts_cat/
config.rs

1#[derive(Debug, Clone, clap::Parser)]
2pub struct Config {
3    #[arg(long, env = "STS_CAT_GITHUB_APP_ID")]
4    pub github_app_id: String,
5
6    #[arg(
7        long,
8        default_value = "https://api.github.com",
9        env = "STS_CAT_GITHUB_API_URL"
10    )]
11    pub github_api_url: String,
12
13    #[arg(long, env = "STS_CAT_IDENTIFIER")]
14    pub identifier: String,
15
16    #[arg(long, default_value = "0.0.0.0", env = "HOST")]
17    pub host: String,
18
19    #[arg(long, default_value_t = 8080, env = "PORT")]
20    pub port: u16,
21
22    #[arg(long, env = "STS_CAT_LOG_JSON")]
23    pub log_json: bool,
24
25    #[arg(long, env = "STS_CAT_KEY_SOURCE")]
26    pub key_source: KeySource,
27
28    #[arg(long, env = "STS_CAT_KEY_FILE", required_if_eq("key_source", "file"))]
29    pub key_file: Option<std::path::PathBuf>,
30
31    #[arg(long, env = "STS_CAT_KEY_ENV", required_if_eq("key_source", "env"))]
32    pub key_env: Option<String>,
33
34    #[cfg(feature = "aws-kms")]
35    #[arg(
36        long,
37        env = "STS_CAT_AWS_KMS_KEY_ARN",
38        required_if_eq("key_source", "aws-kms")
39    )]
40    pub aws_kms_key_arn: Option<String>,
41
42    #[arg(
43        long,
44        default_value = ".github/sts-cat",
45        env = "STS_CAT_POLICY_PATH_PREFIX"
46    )]
47    pub policy_path_prefix: String,
48
49    #[arg(
50        long,
51        default_value = ".sts.toml",
52        env = "STS_CAT_POLICY_FILE_EXTENSION"
53    )]
54    pub policy_file_extension: String,
55
56    #[arg(long, env = "STS_CAT_ALLOWED_ISSUER_URLS", value_delimiter = ',')]
57    pub allowed_issuer_urls: Option<Vec<String>>,
58
59    #[arg(long, env = "STS_CAT_ORG_REPO", value_delimiter = ',')]
60    pub org_repo: Option<Vec<String>>,
61}
62
63#[derive(Debug, Clone, clap::ValueEnum)]
64pub enum KeySource {
65    File,
66    Env,
67    #[cfg(feature = "aws-kms")]
68    AwsKms,
69}
70
71impl Config {
72    pub fn parse_org_repos(
73        &self,
74    ) -> Result<std::collections::HashMap<String, String>, anyhow::Error> {
75        let mut map = std::collections::HashMap::new();
76        if let Some(ref entries) = self.org_repo {
77            for entry in entries {
78                let (org, repo) = entry.split_once('/').ok_or_else(|| {
79                    anyhow::anyhow!(
80                        "invalid --org-repo value '{entry}': expected format 'org/repo'"
81                    )
82                })?;
83                if org.is_empty() || repo.is_empty() {
84                    anyhow::bail!(
85                        "invalid --org-repo value '{entry}': org and repo must not be empty"
86                    );
87                }
88                map.insert(org.to_ascii_lowercase(), repo.to_owned());
89            }
90        }
91        Ok(map)
92    }
93
94    pub async fn build_signer(
95        &self,
96    ) -> Result<std::sync::Arc<dyn crate::signer::Signer>, anyhow::Error> {
97        match &self.key_source {
98            KeySource::File => {
99                let path = self.key_file.as_ref().unwrap();
100                let pem = std::fs::read(path)?;
101                Ok(std::sync::Arc::new(
102                    crate::signer::raw::RawSigner::from_pem(&pem)?,
103                ))
104            }
105            KeySource::Env => {
106                let env_name = self.key_env.as_ref().unwrap();
107                let pem = std::env::var(env_name)
108                    .map_err(|_| anyhow::anyhow!("env var {env_name} not set"))?;
109                Ok(std::sync::Arc::new(
110                    crate::signer::raw::RawSigner::from_pem(pem.as_bytes())?,
111                ))
112            }
113            #[cfg(feature = "aws-kms")]
114            KeySource::AwsKms => {
115                let arn = self.aws_kms_key_arn.as_ref().unwrap();
116                Ok(std::sync::Arc::new(
117                    crate::signer::aws_kms::AwsKmsSigner::new(arn.clone()).await?,
118                ))
119            }
120        }
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a `Config` with fixed placeholder values everywhere except
    /// `org_repo`, which is the only field these tests vary.
    fn config_with_org_repo(org_repo: Option<Vec<String>>) -> Config {
        Config {
            org_repo,
            allowed_issuer_urls: None,
            policy_file_extension: String::from(".sts.toml"),
            policy_path_prefix: String::from(".github/sts-cat"),
            #[cfg(feature = "aws-kms")]
            aws_kms_key_arn: None,
            key_env: None,
            key_file: Some(std::path::PathBuf::from("/dev/null")),
            key_source: KeySource::File,
            log_json: false,
            port: 8080,
            host: String::from("0.0.0.0"),
            identifier: String::from("example.com"),
            github_api_url: String::from("https://api.github.com"),
            github_app_id: String::from("123"),
        }
    }

    /// Runs `parse_org_repos` over the given `--org-repo` entries.
    fn parse_entries(
        entries: &[&str],
    ) -> Result<std::collections::HashMap<String, String>, anyhow::Error> {
        let owned: Vec<String> = entries.iter().map(|e| (*e).to_owned()).collect();
        config_with_org_repo(Some(owned)).parse_org_repos()
    }

    #[test]
    fn test_parse_org_repos_none() {
        let parsed = config_with_org_repo(None)
            .parse_org_repos()
            .expect("an unset --org-repo must parse");
        assert_eq!(parsed.len(), 0);
    }

    #[test]
    fn test_parse_org_repos_single() {
        let parsed = parse_entries(&["myorg/policies"]).expect("valid entry");
        assert_eq!(parsed.get("myorg").map(String::as_str), Some("policies"));
    }

    #[test]
    fn test_parse_org_repos_multiple() {
        let parsed = parse_entries(&["myorg/policies", "other/infra"]).expect("valid entries");
        for (org, repo) in [("myorg", "policies"), ("other", "infra")] {
            assert_eq!(parsed.get(org).map(String::as_str), Some(repo));
        }
    }

    #[test]
    fn test_parse_org_repos_lowercases_org() {
        let parsed = parse_entries(&["MyOrg/policies"]).expect("valid entry");
        assert_eq!(parsed.get("myorg").map(String::as_str), Some("policies"));
        assert!(!parsed.contains_key("MyOrg"));
    }

    #[test]
    fn test_parse_org_repos_rejects_no_slash() {
        assert!(parse_entries(&["myorg"]).is_err());
    }

    #[test]
    fn test_parse_org_repos_rejects_empty_parts() {
        for bad in ["/repo", "org/"] {
            assert!(parse_entries(&[bad]).is_err(), "expected '{bad}' to be rejected");
        }
    }
}