Skip to main content

sts_cat/
exchange.rs

1use crate::error::Error;
2
/// Textual IP address of the peer that opened the connection.
///
/// Carried through request extensions so handlers can log it. The `Default`
/// value is the empty string, used when no connection info is available.
#[derive(Default, Clone)]
pub(crate) struct RemoteAddr(String);
5
6#[derive(serde::Deserialize)]
7pub struct ExchangeRequest {
8    pub scope: String,
9    pub identity: String,
10}
11
12#[derive(serde::Serialize)]
13pub struct ExchangeResponse {
14    pub token: GitHubToken,
15}
16
17/// Inner type for GitHub installation access tokens.
18/// Implements SerializableSecret intentionally — sts-cat is a token vending service.
19#[derive(
20    Clone, Debug, serde::Serialize, serde::Deserialize, zeroize::Zeroize, zeroize::ZeroizeOnDrop,
21)]
22pub struct GitHubTokenInner(String);
23impl secrecy::SerializableSecret for GitHubTokenInner {}
24impl secrecy::CloneableSecret for GitHubTokenInner {}
25
26impl GitHubTokenInner {
27    pub fn as_str(&self) -> &str {
28        &self.0
29    }
30}
31
32impl From<String> for GitHubTokenInner {
33    fn from(s: String) -> Self {
34        Self(s)
35    }
36}
37
38pub type GitHubToken = secrecy::SecretBox<GitHubTokenInner>;
39
40pub struct AppState {
41    pub config: crate::config::Config,
42    pub github: crate::github::GitHubClient,
43    pub oidc: crate::oidc::OidcVerifier,
44    org_repos: std::collections::HashMap<String, String>,
45    policy_cache: moka::future::Cache<
46        (String, String, String),
47        std::sync::Arc<crate::trust_policy::CompiledTrustPolicy>,
48    >,
49    installation_cache: moka::future::Cache<String, u64>,
50}
51
52impl AppState {
53    pub async fn build(
54        config: crate::config::Config,
55    ) -> Result<std::sync::Arc<Self>, anyhow::Error> {
56        let signer = config.build_signer().await?;
57        let github =
58            crate::github::GitHubClient::new(&config.github_api_url, &config.github_app_id, signer);
59        let oidc = crate::oidc::OidcVerifier::new(config.allowed_issuer_urls.clone());
60        let org_repos = config.parse_org_repos()?;
61
62        let policy_cache = moka::future::Cache::builder()
63            .max_capacity(200)
64            .time_to_live(std::time::Duration::from_secs(300))
65            .build();
66
67        let installation_cache = moka::future::Cache::builder()
68            .max_capacity(200)
69            .time_to_live(std::time::Duration::from_secs(3600))
70            .build();
71
72        Ok(std::sync::Arc::new(Self {
73            config,
74            github,
75            oidc,
76            org_repos,
77            policy_cache,
78            installation_cache,
79        }))
80    }
81}
82
83async fn handle_exchange(
84    axum::extract::State(state): axum::extract::State<std::sync::Arc<AppState>>,
85    axum::Extension(remote_addr): axum::Extension<RemoteAddr>,
86    headers: axum::http::HeaderMap,
87    bearer: Result<
88        axum_extra::TypedHeader<headers::Authorization<headers::authorization::Bearer>>,
89        axum_extra::typed_header::TypedHeaderRejection,
90    >,
91    axum::Json(req): axum::Json<ExchangeRequest>,
92) -> Result<axum::Json<ExchangeResponse>, Error> {
93    let remote_addr = remote_addr.0;
94    let xff = headers
95        .get("x-forwarded-for")
96        .and_then(|v| v.to_str().ok())
97        .unwrap_or("")
98        .to_owned();
99
100    let axum_extra::TypedHeader(authorization) = bearer
101        .map_err(|_| Error::Unauthenticated("missing or invalid Authorization header".into()))?;
102    let bearer_token = authorization.token();
103
104    if req.scope.is_empty() {
105        return Err(Error::BadRequest("scope must not be empty".into()));
106    }
107    if req.identity.is_empty() {
108        return Err(Error::BadRequest("identity must not be empty".into()));
109    }
110    if !is_valid_name(&req.identity) {
111        return Err(Error::BadRequest("invalid identity format".into()));
112    }
113
114    let (owner, mut repo, is_org_level) = parse_scope(&req.scope)?;
115    let owner = owner.to_ascii_lowercase();
116    if is_org_level && let Some(override_repo) = state.org_repos.get(&owner) {
117        repo = override_repo.clone();
118    }
119    let claims = state.oidc.verify(bearer_token).await?;
120
121    let installation_id = if let Some(id) = state.installation_cache.get(&owner).await {
122        id
123    } else {
124        let id = state.github.get_installation_id(&owner).await?;
125        state.installation_cache.insert(owner.clone(), id).await;
126        id
127    };
128
129    let policy_path = format!(
130        "{}/{}{}",
131        state.config.policy_path_prefix, req.identity, state.config.policy_file_extension
132    );
133
134    let cache_key = (owner.clone(), repo.clone(), req.identity.clone());
135    let compiled = if let Some(cached) = state.policy_cache.get(&cache_key).await {
136        cached
137    } else {
138        let content = state
139            .github
140            .get_trust_policy_content(installation_id, &owner, &repo, &policy_path)
141            .await?;
142        let policy = crate::trust_policy::TrustPolicy::parse(&content)?;
143        let compiled = std::sync::Arc::new(policy.compile(is_org_level)?);
144        state.policy_cache.insert(cache_key, compiled.clone()).await;
145        compiled
146    };
147
148    let actor = match compiled.check_token(&claims, &state.config.identifier) {
149        Ok(actor) => actor,
150        Err(e) => {
151            tracing::warn!(
152                event = "exchange_denied",
153                scope = %req.scope,
154                identity = %req.identity,
155                issuer = %claims.iss,
156                subject = %claims.sub,
157                remote_addr = %remote_addr,
158                xff = %xff,
159                reason = %e,
160            );
161            return Err(e);
162        }
163    };
164
165    tracing::info!(
166        event = "exchange_authorized",
167        scope = %req.scope,
168        identity = %req.identity,
169        issuer = %actor.issuer,
170        subject = %actor.subject,
171        remote_addr = %remote_addr,
172        xff = %xff,
173        installation_id = installation_id,
174        policy_path = %policy_path,
175    );
176
177    let repositories = if is_org_level {
178        compiled.repositories.clone().unwrap_or_default()
179    } else {
180        vec![repo.clone()]
181    };
182
183    let token = state
184        .github
185        .create_installation_token(installation_id, &compiled.permissions, &repositories)
186        .await?;
187
188    use secrecy::ExposeSecret as _;
189    use sha2::Digest as _;
190    let token_hash = hex::encode(sha2::Sha256::digest(
191        token.expose_secret().as_str().as_bytes(),
192    ));
193    tracing::info!(
194        event = "exchange_success",
195        scope = %req.scope,
196        identity = %req.identity,
197        issuer = %actor.issuer,
198        subject = %actor.subject,
199        remote_addr = %remote_addr,
200        xff = %xff,
201        installation_id = installation_id,
202        token_sha256 = %token_hash,
203    );
204
205    Ok(axum::Json(ExchangeResponse { token }))
206}
207
208#[derive(serde::Serialize)]
209pub(crate) struct HealthResponse {
210    ok: bool,
211}
212
213pub(crate) async fn handle_healthz() -> axum::Json<HealthResponse> {
214    axum::Json(HealthResponse { ok: true })
215}
216
/// Returns `true` when `s` is acceptable as an owner, repository or identity
/// name.
///
/// A valid name is non-empty and consists only of ASCII letters, digits, `.`,
/// `_` and `-`. The dot segments `.` and `..` are rejected even though they fit
/// that alphabet: the identity is interpolated into the policy file path (and
/// owner/repo presumably end up as path segments of GitHub API requests), where
/// a dot segment would refer to a different location than the caller named.
fn is_valid_name(s: &str) -> bool {
    if s.is_empty() || s == "." || s == ".." {
        return false;
    }
    s.bytes()
        .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'_' | b'-'))
}
222
223fn parse_scope(scope: &str) -> Result<(String, String, bool), Error> {
224    if let Some((owner, repo)) = scope.split_once('/') {
225        if owner.is_empty() || repo.is_empty() {
226            return Err(Error::BadRequest("invalid scope format".into()));
227        }
228        if !is_valid_name(owner) || !is_valid_name(repo) {
229            return Err(Error::BadRequest("invalid scope format".into()));
230        }
231        let is_org_level = repo == ".github";
232        Ok((owner.to_owned(), repo.to_owned(), is_org_level))
233    } else {
234        // Org-level scope: "org" → reads from ".github" repo
235        if scope.is_empty() {
236            return Err(Error::BadRequest("invalid scope format".into()));
237        }
238        if !is_valid_name(scope) {
239            return Err(Error::BadRequest("invalid scope format".into()));
240        }
241        Ok((scope.to_owned(), ".github".to_owned(), true))
242    }
243}
244
245fn make_request_span(req: &axum::extract::Request) -> tracing::Span {
246    let remote_addr = req
247        .extensions()
248        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
249        .map(|ci| ci.0.ip().to_string())
250        .unwrap_or_default();
251    let xff = req
252        .headers()
253        .get("x-forwarded-for")
254        .and_then(|v| v.to_str().ok())
255        .unwrap_or("")
256        .to_owned();
257    tracing::info_span!(
258        "request",
259        method = %req.method(),
260        uri = %req.uri(),
261        version = ?req.version(),
262        remote_addr = %remote_addr,
263        xff = %xff,
264    )
265}
266
267pub fn build_router(state: std::sync::Arc<AppState>) -> axum::Router {
268    let server_header = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
269    axum::Router::new()
270        .route("/token", axum::routing::post(handle_exchange))
271        .route("/healthz", axum::routing::get(handle_healthz))
272        .layer(axum::Extension(RemoteAddr::default()))
273        .layer(axum::middleware::from_fn(
274            |mut req: axum::extract::Request, next: axum::middleware::Next| async move {
275                if let Some(ci) = req
276                    .extensions()
277                    .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
278                {
279                    let addr = RemoteAddr(ci.0.ip().to_string());
280                    req.extensions_mut().insert(addr);
281                }
282                next.run(req).await
283            },
284        ))
285        .layer(
286            tower_http::trace::TraceLayer::new_for_http()
287                .make_span_with(make_request_span as fn(&axum::extract::Request) -> tracing::Span)
288                .on_response(
289                    tower_http::trace::DefaultOnResponse::new().level(tracing::Level::INFO),
290                ),
291        )
292        .layer(axum::middleware::from_fn(
293            move |req, next: axum::middleware::Next| {
294                let val = server_header.clone();
295                async move {
296                    let mut resp = next.run(req).await;
297                    resp.headers_mut().insert(
298                        axum::http::header::SERVER,
299                        axum::http::HeaderValue::from_str(&val).unwrap(),
300                    );
301                    resp
302                }
303            },
304        ))
305        .with_state(state)
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every scope in `scopes` is rejected by `parse_scope`.
    fn assert_all_rejected(scopes: &[&str]) {
        for scope in scopes {
            assert!(
                parse_scope(scope).is_err(),
                "scope {scope:?} should be rejected"
            );
        }
    }

    // `owner/repo` is a plain repository scope.
    #[test]
    fn test_parse_scope_repo() {
        let parsed = parse_scope("myorg/myrepo").unwrap();
        assert_eq!(parsed, ("myorg".to_owned(), "myrepo".to_owned(), false));
    }

    // A bare owner is org-level and resolves to the `.github` repository.
    #[test]
    fn test_parse_scope_org() {
        let parsed = parse_scope("myorg").unwrap();
        assert_eq!(parsed, ("myorg".to_owned(), ".github".to_owned(), true));
    }

    // Naming `.github` explicitly is also treated as org-level.
    #[test]
    fn test_parse_scope_org_dotgithub() {
        let parsed = parse_scope("myorg/.github").unwrap();
        assert_eq!(parsed, ("myorg".to_owned(), ".github".to_owned(), true));
    }

    #[test]
    fn test_parse_scope_empty() {
        assert_all_rejected(&["", "/repo", "owner/"]);
    }

    #[test]
    fn test_parse_scope_rejects_invalid_chars() {
        assert_all_rejected(&["org/../evil", "org/repo name", "org/<script>", "org\0evil"]);
    }

    #[test]
    fn test_is_valid_name() {
        for name in ["my-repo", "my.repo", "my_repo", ".github"] {
            assert!(is_valid_name(name), "name {name:?} should be accepted");
        }
        for name in ["../etc/passwd", "repo name", "repo/name", ""] {
            assert!(!is_valid_name(name), "name {name:?} should be rejected");
        }
    }
}