1pub mod jwks;
30pub mod request;
31pub mod responses;
32pub mod route_handler;
33pub mod sealed_token;
34pub mod sts;
35
36use base64::Engine;
37pub use jwks::JwksCache;
38use multistore::error::ProxyError;
39use multistore::registry::CredentialRegistry;
40use multistore::types::TemporaryCredentials;
41pub use request::try_parse_sts_request;
42use request::StsRequest;
43pub use responses::{build_sts_error_response, build_sts_response};
44pub use sealed_token::TokenKey;
45
46pub async fn try_handle_sts<C: CredentialRegistry>(
53 query: Option<&str>,
54 config: &C,
55 jwks_cache: &JwksCache,
56 token_key: Option<&TokenKey>,
57) -> Option<(u16, String)> {
58 let sts_result = try_parse_sts_request(query)?;
59 let (status, xml) = match sts_result {
60 Ok(sts_request) => {
61 let Some(key) = token_key else {
62 tracing::error!("STS request received but SESSION_TOKEN_KEY is not configured");
63 return Some(build_sts_error_response(&ProxyError::ConfigError(
64 "STS requires SESSION_TOKEN_KEY to be configured".into(),
65 )));
66 };
67 match assume_role_with_web_identity(config, &sts_request, "STSPRXY", jwks_cache, key)
68 .await
69 {
70 Ok(creds) => build_sts_response(&creds),
71 Err(e) => {
72 tracing::warn!(error = %e, "STS request failed");
73 build_sts_error_response(&e)
74 }
75 }
76 }
77 Err(e) => build_sts_error_response(&e),
78 };
79 Some((status, xml))
80}
81
82fn jwt_decode_unverified(
84 token: &str,
85) -> Result<(serde_json::Value, serde_json::Value), ProxyError> {
86 let mut parts = token.splitn(3, '.');
87 let header_b64 = parts
88 .next()
89 .ok_or_else(|| ProxyError::InvalidOidcToken("malformed JWT".into()))?;
90 let payload_b64 = parts
91 .next()
92 .ok_or_else(|| ProxyError::InvalidOidcToken("malformed JWT".into()))?;
93
94 let decode = |s: &str| -> Result<serde_json::Value, ProxyError> {
95 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
96 .decode(s)
97 .map_err(|e| ProxyError::InvalidOidcToken(format!("base64url decode error: {}", e)))?;
98 serde_json::from_slice(&bytes)
99 .map_err(|e| ProxyError::InvalidOidcToken(format!("invalid JWT JSON: {}", e)))
100 };
101
102 Ok((decode(header_b64)?, decode(payload_b64)?))
103}
104
105pub async fn assume_role_with_web_identity<C: CredentialRegistry>(
110 config: &C,
111 sts_request: &StsRequest,
112 key_prefix: &str,
113 jwks_cache: &JwksCache,
114 token_key: &TokenKey,
115) -> Result<TemporaryCredentials, ProxyError> {
116 let role = config
118 .get_role(&sts_request.role_arn)
119 .await?
120 .ok_or_else(|| ProxyError::RoleNotFound(sts_request.role_arn.to_string()))?;
121
122 let (header, insecure_claims) = jwt_decode_unverified(&sts_request.web_identity_token)?;
124
125 let issuer = insecure_claims
126 .get("iss")
127 .and_then(|v| v.as_str())
128 .ok_or_else(|| ProxyError::InvalidOidcToken("missing iss claim".into()))?;
129
130 if !role.trusted_oidc_issuers.iter().any(|i| i == issuer) {
132 return Err(ProxyError::InvalidOidcToken(format!(
133 "untrusted issuer: {}",
134 issuer
135 )));
136 }
137
138 let alg = header.get("alg").and_then(|v| v.as_str()).unwrap_or("");
140 if alg != "RS256" {
141 return Err(ProxyError::InvalidOidcToken(format!(
142 "unsupported JWT algorithm: {}",
143 alg
144 )));
145 }
146
147 let jwks = jwks_cache.get_or_fetch(issuer).await?;
149 let kid = header
150 .get("kid")
151 .and_then(|v| v.as_str())
152 .ok_or_else(|| ProxyError::InvalidOidcToken("JWT missing kid".into()))?;
153
154 let key = jwks::find_key(&jwks, kid)?;
155 let claims = jwks::verify_token(&sts_request.web_identity_token, key, issuer, &role)?;
156
157 let subject = claims.get("sub").and_then(|v| v.as_str()).unwrap_or("");
159
160 if !role.subject_conditions.is_empty() {
161 let matches = role
162 .subject_conditions
163 .iter()
164 .any(|pattern| subject_matches(subject, pattern));
165 if !matches {
166 return Err(ProxyError::InvalidOidcToken(format!(
167 "subject '{}' does not match any conditions",
168 subject
169 )));
170 }
171 }
172
173 const MIN_SESSION_DURATION_SECS: u64 = 900;
175 let duration = sts_request
176 .duration_seconds
177 .unwrap_or(3600)
178 .clamp(MIN_SESSION_DURATION_SECS, role.max_session_duration_secs);
179
180 let mut creds = sts::mint_temporary_credentials(&role, subject, duration, key_prefix, &claims);
181
182 creds.session_token = token_key.seal(&creds)?;
184
185 Ok(creds)
186}
187
/// Reports whether `subject` matches the glob `pattern`.
///
/// `*` matches any run of characters, including an empty one. Every other
/// character matches only itself, and the comparison is case-sensitive. A
/// pattern without `*` must equal `subject` exactly.
///
/// The text before the first `*` must be a prefix of `subject`, and the text
/// after the last `*` must be a suffix of what is left. The literal pieces in
/// between are matched left to right at their earliest position, which is
/// enough to decide a match for `*`-only globs.
fn subject_matches(subject: &str, pattern: &str) -> bool {
    let Some((head, after_head)) = pattern.split_once('*') else {
        return subject == pattern;
    };

    // With a single `*`, there are no middle pieces and everything after it is the tail.
    let (middle, tail) = after_head.rsplit_once('*').unwrap_or(("", after_head));

    let Some(mut rest) = subject.strip_prefix(head) else {
        return false;
    };

    for piece in middle.split('*').filter(|p| !p.is_empty()) {
        match rest.find(piece) {
            Some(pos) => rest = &rest[pos + piece.len()..],
            None => return false,
        }
    }

    rest.ends_with(tail)
}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `subject_matches(subject, pattern) == expected` for every case.
    /// The panic message names the case that failed.
    fn check(cases: &[(&str, &str, bool)]) {
        for &(subject, pattern, expected) in cases {
            assert_eq!(
                subject_matches(subject, pattern),
                expected,
                "subject_matches({subject:?}, {pattern:?})"
            );
        }
    }

    #[test]
    fn test_subject_matching() {
        const SUB: &str = "repo:org/repo:ref:refs/heads/main";
        check(&[
            // Trailing wildcard on a GitHub-style subject.
            (SUB, "repo:org/repo:*", true),
            // Bare wildcard accepts anything.
            (SUB, "*", true),
            // Exact match.
            (SUB, "repo:org/repo:ref:refs/heads/main", true),
            // Different org is rejected.
            (SUB, "repo:other/*", false),
            // Wildcards in the middle and at the end.
            (SUB, "repo:org/*:ref:refs/heads/*", true),
        ]);
    }

    #[test]
    fn test_subject_matching_exact() {
        check(&[
            ("abc", "abc", true),
            ("abc", "abcd", false),
            ("abcd", "abc", false),
            ("", "abc", false),
            ("", "", true),
        ]);
    }

    #[test]
    fn test_subject_matching_leading_wildcard() {
        check(&[
            ("anything", "*", true),
            ("", "*", true),
            ("foo", "*foo", true),
            ("xfoo", "*foo", true),
            ("foox", "*foo", false),
        ]);
    }

    #[test]
    fn test_subject_matching_trailing_wildcard() {
        check(&[
            ("foo", "foo*", true),
            ("foobar", "foo*", true),
            ("xfoo", "foo*", false),
        ]);
    }

    #[test]
    fn test_subject_matching_middle_wildcard() {
        check(&[
            ("foobar", "foo*bar", true),
            ("fooXbar", "foo*bar", true),
            ("fooXYZbar", "foo*bar", true),
            ("fooXbaz", "foo*bar", false),
            ("xfoobar", "foo*bar", false),
        ]);
    }

    #[test]
    fn test_subject_matching_multiple_wildcards() {
        check(&[
            // The middle "b" and the final "b" need distinct characters.
            ("axbb", "a*b*b", true),
            ("axb", "a*b*b", false),
            // The middle "bc" consumes the only "c", leaving none for the suffix.
            ("abc", "a*bc*c", false),
            ("abcc", "a*bc*c", true),
            // Earliest-match of the middle piece still leaves room for the suffix.
            ("aab", "*a*ab", true),
            ("xab", "*a*ab", false),
            ("xababab", "*ab*ab", true),
            ("xab", "*ab*ab", false),
        ]);
    }

    #[test]
    fn test_subject_matching_double_wildcard() {
        check(&[("anything", "**", true), ("", "**", true)]);
    }

    #[test]
    fn test_subject_matching_consecutive_wildcards() {
        // Runs of `*` behave like a single `*`.
        check(&[
            ("abc", "a***c", true),
            ("ac", "a**c", true),
            ("ab", "a**c", false),
        ]);
    }

    #[test]
    fn test_subject_matching_empty_subject() {
        check(&[("", "*", true), ("", "a", false), ("", "", true)]);
    }

    #[test]
    fn test_subject_matching_prefix_suffix_overlap() {
        // The prefix and suffix must not share characters of the subject.
        check(&[
            ("a", "a*a", false),
            ("aa", "a*a", true),
            ("ab", "ab*b", false),
            ("abb", "ab*b", true),
            ("repo:org/repo", "repo:org/repo:*", false),
        ]);
    }

    #[test]
    fn test_subject_matching_unicode() {
        // Slicing after multi-byte characters must stay on char boundaries.
        check(&[
            ("ünïcødé", "ü*é", true),
            ("é", "é*é", false),
            ("日本語", "日*語", true),
            ("日本", "*本*", true),
        ]);
    }

    #[test]
    fn test_subject_matching_case_sensitive() {
        check(&[("Repo:Org", "repo:org", false), ("REPO:x", "repo:*", false)]);
    }
}