1use crate::error::Error;
2
/// Client IP address (rendered as a string) carried as a request extension so
/// handlers can include it in audit logs. `Default` yields an empty string.
#[derive(Default)]
#[derive(Clone)]
pub(crate) struct RemoteAddr(String);
5
/// JSON request body accepted by `POST /token` (see `handle_exchange`).
#[derive(serde::Deserialize)]
pub struct ExchangeRequest {
    /// Target of the requested token: `"owner/repo"` for a single repository,
    /// or `"owner"` for the organization-level (`.github`) scope.
    pub scope: String,
    /// Name of the trust policy to evaluate; used as the file stem of the
    /// policy file under the configured policy path prefix.
    pub identity: String,
}
11
/// JSON response body returned by `POST /token` on success.
#[derive(serde::Serialize)]
pub struct ExchangeResponse {
    /// The newly created GitHub installation token. It is serialized in clear
    /// text (enabled by `secrecy::SerializableSecret` on the inner type) so the
    /// caller receives it.
    pub token: GitHubToken,
}
16
17#[derive(
20 Clone, Debug, serde::Serialize, serde::Deserialize, zeroize::Zeroize, zeroize::ZeroizeOnDrop,
21)]
22pub struct GitHubTokenInner(String);
23impl secrecy::SerializableSecret for GitHubTokenInner {}
24impl secrecy::CloneableSecret for GitHubTokenInner {}
25
26impl GitHubTokenInner {
27 pub fn as_str(&self) -> &str {
28 &self.0
29 }
30}
31
32impl From<String> for GitHubTokenInner {
33 fn from(s: String) -> Self {
34 Self(s)
35 }
36}
37
/// A `GitHubTokenInner` held inside `secrecy::SecretBox`, so the raw value is
/// only reachable through `secrecy::ExposeSecret::expose_secret`.
pub type GitHubToken = secrecy::SecretBox<GitHubTokenInner>;
39
/// Application state shared by all request handlers; `AppState::build`
/// wraps it in an `Arc` and `build_router` installs it as axum router state.
pub struct AppState {
    /// Service configuration.
    pub config: crate::config::Config,
    /// Client used for GitHub App API calls (installation lookup, policy
    /// content, installation-token creation).
    pub github: crate::github::GitHubClient,
    /// Verifier for the OIDC bearer tokens presented by callers.
    pub oidc: crate::oidc::OidcVerifier,
    /// Per-owner override of the repository holding org-level trust policies.
    /// Looked up with the lowercased owner; presumably `parse_org_repos`
    /// lowercases the keys — NOTE(review): confirm in `crate::config`.
    org_repos: std::collections::HashMap<String, String>,
    /// Cache of compiled trust policies; the key is built in `handle_exchange`
    /// from the owner, the policy repository and the identity. Capacity and TTL
    /// are set in `AppState::build`.
    policy_cache: moka::future::Cache<
        (String, String, String),
        std::sync::Arc<crate::trust_policy::CompiledTrustPolicy>,
    >,
    /// Cache of GitHub App installation IDs keyed by lowercased owner.
    installation_cache: moka::future::Cache<String, u64>,
}
51
52impl AppState {
53 pub async fn build(
54 config: crate::config::Config,
55 ) -> Result<std::sync::Arc<Self>, anyhow::Error> {
56 let signer = config.build_signer().await?;
57 let github =
58 crate::github::GitHubClient::new(&config.github_api_url, &config.github_app_id, signer);
59 let oidc = crate::oidc::OidcVerifier::new(config.allowed_issuer_urls.clone());
60 let org_repos = config.parse_org_repos()?;
61
62 let policy_cache = moka::future::Cache::builder()
63 .max_capacity(200)
64 .time_to_live(std::time::Duration::from_secs(300))
65 .build();
66
67 let installation_cache = moka::future::Cache::builder()
68 .max_capacity(200)
69 .time_to_live(std::time::Duration::from_secs(3600))
70 .build();
71
72 Ok(std::sync::Arc::new(Self {
73 config,
74 github,
75 oidc,
76 org_repos,
77 policy_cache,
78 installation_cache,
79 }))
80 }
81}
82
83async fn handle_exchange(
84 axum::extract::State(state): axum::extract::State<std::sync::Arc<AppState>>,
85 axum::Extension(remote_addr): axum::Extension<RemoteAddr>,
86 headers: axum::http::HeaderMap,
87 bearer: Result<
88 axum_extra::TypedHeader<headers::Authorization<headers::authorization::Bearer>>,
89 axum_extra::typed_header::TypedHeaderRejection,
90 >,
91 axum::Json(req): axum::Json<ExchangeRequest>,
92) -> Result<axum::Json<ExchangeResponse>, Error> {
93 let remote_addr = remote_addr.0;
94 let xff = headers
95 .get("x-forwarded-for")
96 .and_then(|v| v.to_str().ok())
97 .unwrap_or("")
98 .to_owned();
99
100 let axum_extra::TypedHeader(authorization) = bearer
101 .map_err(|_| Error::Unauthenticated("missing or invalid Authorization header".into()))?;
102 let bearer_token = authorization.token();
103
104 if req.scope.is_empty() {
105 return Err(Error::BadRequest("scope must not be empty".into()));
106 }
107 if req.identity.is_empty() {
108 return Err(Error::BadRequest("identity must not be empty".into()));
109 }
110 if !is_valid_name(&req.identity) {
111 return Err(Error::BadRequest("invalid identity format".into()));
112 }
113
114 let (owner, mut repo, is_org_level) = parse_scope(&req.scope)?;
115 let owner = owner.to_ascii_lowercase();
116 if is_org_level && let Some(override_repo) = state.org_repos.get(&owner) {
117 repo = override_repo.clone();
118 }
119 let claims = state.oidc.verify(bearer_token).await?;
120
121 let installation_id = if let Some(id) = state.installation_cache.get(&owner).await {
122 id
123 } else {
124 let id = state.github.get_installation_id(&owner).await?;
125 state.installation_cache.insert(owner.clone(), id).await;
126 id
127 };
128
129 let policy_path = format!(
130 "{}/{}{}",
131 state.config.policy_path_prefix, req.identity, state.config.policy_file_extension
132 );
133
134 let cache_key = (owner.clone(), repo.clone(), req.identity.clone());
135 let compiled = if let Some(cached) = state.policy_cache.get(&cache_key).await {
136 cached
137 } else {
138 let content = state
139 .github
140 .get_trust_policy_content(installation_id, &owner, &repo, &policy_path)
141 .await?;
142 let policy = crate::trust_policy::TrustPolicy::parse(&content)?;
143 let compiled = std::sync::Arc::new(policy.compile(is_org_level)?);
144 state.policy_cache.insert(cache_key, compiled.clone()).await;
145 compiled
146 };
147
148 let actor = match compiled.check_token(&claims, &state.config.identifier) {
149 Ok(actor) => actor,
150 Err(e) => {
151 tracing::warn!(
152 event = "exchange_denied",
153 scope = %req.scope,
154 identity = %req.identity,
155 issuer = %claims.iss,
156 subject = %claims.sub,
157 remote_addr = %remote_addr,
158 xff = %xff,
159 reason = %e,
160 );
161 return Err(e);
162 }
163 };
164
165 tracing::info!(
166 event = "exchange_authorized",
167 scope = %req.scope,
168 identity = %req.identity,
169 issuer = %actor.issuer,
170 subject = %actor.subject,
171 remote_addr = %remote_addr,
172 xff = %xff,
173 installation_id = installation_id,
174 policy_path = %policy_path,
175 );
176
177 let repositories = if is_org_level {
178 compiled.repositories.clone().unwrap_or_default()
179 } else {
180 vec![repo.clone()]
181 };
182
183 let token = state
184 .github
185 .create_installation_token(installation_id, &compiled.permissions, &repositories)
186 .await?;
187
188 use secrecy::ExposeSecret as _;
189 use sha2::Digest as _;
190 let token_hash = hex::encode(sha2::Sha256::digest(
191 token.expose_secret().as_str().as_bytes(),
192 ));
193 tracing::info!(
194 event = "exchange_success",
195 scope = %req.scope,
196 identity = %req.identity,
197 issuer = %actor.issuer,
198 subject = %actor.subject,
199 remote_addr = %remote_addr,
200 xff = %xff,
201 installation_id = installation_id,
202 token_sha256 = %token_hash,
203 );
204
205 Ok(axum::Json(ExchangeResponse { token }))
206}
207
/// JSON body returned by `GET /healthz`.
#[derive(serde::Serialize)]
pub(crate) struct HealthResponse {
    /// Set to `true` by `handle_healthz`.
    ok: bool,
}
212
213pub(crate) async fn handle_healthz() -> axum::Json<HealthResponse> {
214 axum::Json(HealthResponse { ok: true })
215}
216
/// Returns `true` if `s` is a non-empty name made only of ASCII letters,
/// digits, `.`, `_` and `-`, and is not one of the dot-segments `.` or `..`.
///
/// This validates owner/repo names and policy identities before they are
/// handed to the GitHub client and used to build the policy file path. `.` and
/// `..` consist solely of allowed characters, but URL/path normalization may
/// resolve them as "current"/"parent" segments, so they are rejected.
fn is_valid_name(s: &str) -> bool {
    if s.is_empty() || s == "." || s == ".." {
        return false;
    }
    // Byte-wise check: every non-ASCII UTF-8 byte is >= 0x80 and therefore
    // rejected, matching the ASCII-only character class `[a-zA-Z0-9._-]`.
    s.bytes()
        .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'_' | b'-'))
}
222
223fn parse_scope(scope: &str) -> Result<(String, String, bool), Error> {
224 if let Some((owner, repo)) = scope.split_once('/') {
225 if owner.is_empty() || repo.is_empty() {
226 return Err(Error::BadRequest("invalid scope format".into()));
227 }
228 if !is_valid_name(owner) || !is_valid_name(repo) {
229 return Err(Error::BadRequest("invalid scope format".into()));
230 }
231 let is_org_level = repo == ".github";
232 Ok((owner.to_owned(), repo.to_owned(), is_org_level))
233 } else {
234 if scope.is_empty() {
236 return Err(Error::BadRequest("invalid scope format".into()));
237 }
238 if !is_valid_name(scope) {
239 return Err(Error::BadRequest("invalid scope format".into()));
240 }
241 Ok((scope.to_owned(), ".github".to_owned(), true))
242 }
243}
244
245fn make_request_span(req: &axum::extract::Request) -> tracing::Span {
246 let remote_addr = req
247 .extensions()
248 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
249 .map(|ci| ci.0.ip().to_string())
250 .unwrap_or_default();
251 let xff = req
252 .headers()
253 .get("x-forwarded-for")
254 .and_then(|v| v.to_str().ok())
255 .unwrap_or("")
256 .to_owned();
257 tracing::info_span!(
258 "request",
259 method = %req.method(),
260 uri = %req.uri(),
261 version = ?req.version(),
262 remote_addr = %remote_addr,
263 xff = %xff,
264 )
265}
266
267pub fn build_router(state: std::sync::Arc<AppState>) -> axum::Router {
268 let server_header = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
269 axum::Router::new()
270 .route("/token", axum::routing::post(handle_exchange))
271 .route("/healthz", axum::routing::get(handle_healthz))
272 .layer(axum::Extension(RemoteAddr::default()))
273 .layer(axum::middleware::from_fn(
274 |mut req: axum::extract::Request, next: axum::middleware::Next| async move {
275 if let Some(ci) = req
276 .extensions()
277 .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
278 {
279 let addr = RemoteAddr(ci.0.ip().to_string());
280 req.extensions_mut().insert(addr);
281 }
282 next.run(req).await
283 },
284 ))
285 .layer(
286 tower_http::trace::TraceLayer::new_for_http()
287 .make_span_with(make_request_span as fn(&axum::extract::Request) -> tracing::Span)
288 .on_response(
289 tower_http::trace::DefaultOnResponse::new().level(tracing::Level::INFO),
290 ),
291 )
292 .layer(axum::middleware::from_fn(
293 move |req, next: axum::middleware::Next| {
294 let val = server_header.clone();
295 async move {
296 let mut resp = next.run(req).await;
297 resp.headers_mut().insert(
298 axum::http::header::SERVER,
299 axum::http::HeaderValue::from_str(&val).unwrap(),
300 );
301 resp
302 }
303 },
304 ))
305 .with_state(state)
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `scope` parses successfully into the expected parts.
    fn assert_scope(scope: &str, owner: &str, repo: &str, org_level: bool) {
        let (got_owner, got_repo, got_org_level) = parse_scope(scope).unwrap();
        assert_eq!(got_owner, owner);
        assert_eq!(got_repo, repo);
        assert_eq!(got_org_level, org_level);
    }

    /// Asserts that every scope in `scopes` is rejected by `parse_scope`.
    fn assert_scopes_rejected(scopes: &[&str]) {
        for scope in scopes {
            assert!(parse_scope(scope).is_err(), "{scope:?} should be rejected");
        }
    }

    #[test]
    fn test_parse_scope_repo() {
        assert_scope("myorg/myrepo", "myorg", "myrepo", false);
    }

    #[test]
    fn test_parse_scope_org() {
        assert_scope("myorg", "myorg", ".github", true);
    }

    #[test]
    fn test_parse_scope_org_dotgithub() {
        assert_scope("myorg/.github", "myorg", ".github", true);
    }

    #[test]
    fn test_parse_scope_empty() {
        assert_scopes_rejected(&["", "/repo", "owner/"]);
    }

    #[test]
    fn test_parse_scope_rejects_invalid_chars() {
        assert_scopes_rejected(&["org/../evil", "org/repo name", "org/<script>", "org\0evil"]);
    }

    #[test]
    fn test_is_valid_name() {
        for good in ["my-repo", "my.repo", "my_repo", ".github"] {
            assert!(is_valid_name(good), "{good:?} should be valid");
        }
        for bad in ["../etc/passwd", "repo name", "repo/name", ""] {
            assert!(!is_valid_name(bad), "{bad:?} should be invalid");
        }
    }
}