1pub mod jwks;
30pub mod request;
31pub mod responses;
32pub mod route_handler;
33pub mod sealed_token;
34pub mod sts;
35
36pub use jwks::JwksCache;
37use multistore::error::ProxyError;
38use multistore::registry::CredentialRegistry;
39use multistore::types::TemporaryCredentials;
40pub use request::try_parse_sts_request;
41use request::StsRequest;
42pub use responses::{build_sts_error_response, build_sts_response};
43pub use sealed_token::TokenKey;
44
45pub async fn try_handle_sts<C: CredentialRegistry>(
52 query: Option<&str>,
53 config: &C,
54 jwks_cache: &JwksCache,
55 token_key: Option<&TokenKey>,
56) -> Option<(u16, String)> {
57 let sts_result = try_parse_sts_request(query)?;
58 let (status, xml) = match sts_result {
59 Ok(sts_request) => {
60 let Some(key) = token_key else {
61 tracing::error!("STS request received but SESSION_TOKEN_KEY is not configured");
62 return Some(build_sts_error_response(&ProxyError::ConfigError(
63 "STS requires SESSION_TOKEN_KEY to be configured".into(),
64 )));
65 };
66 match assume_role_with_web_identity(config, &sts_request, "STSPRXY", jwks_cache, key)
67 .await
68 {
69 Ok(creds) => build_sts_response(&creds),
70 Err(e) => {
71 tracing::warn!(error = %e, "STS request failed");
72 build_sts_error_response(&e)
73 }
74 }
75 }
76 Err(e) => build_sts_error_response(&e),
77 };
78 Some((status, xml))
79}
80
81fn jwt_decode_unverified(
83 token: &str,
84) -> Result<(serde_json::Value, serde_json::Value), ProxyError> {
85 let mut parts = token.splitn(3, '.');
86 let header_b64 = parts
87 .next()
88 .ok_or_else(|| ProxyError::InvalidOidcToken("malformed JWT".into()))?;
89 let payload_b64 = parts
90 .next()
91 .ok_or_else(|| ProxyError::InvalidOidcToken("malformed JWT".into()))?;
92
93 Ok((
94 jwks::decode_jwt_segment(header_b64)?,
95 jwks::decode_jwt_segment(payload_b64)?,
96 ))
97}
98
99pub async fn assume_role_with_web_identity<C: CredentialRegistry>(
104 config: &C,
105 sts_request: &StsRequest,
106 key_prefix: &str,
107 jwks_cache: &JwksCache,
108 token_key: &TokenKey,
109) -> Result<TemporaryCredentials, ProxyError> {
110 let role = config
112 .get_role(&sts_request.role_arn)
113 .await?
114 .ok_or_else(|| ProxyError::RoleNotFound(sts_request.role_arn.to_string()))?;
115
116 let (header, insecure_claims) = jwt_decode_unverified(&sts_request.web_identity_token)?;
118
119 let issuer = insecure_claims
120 .get("iss")
121 .and_then(|v| v.as_str())
122 .ok_or_else(|| ProxyError::InvalidOidcToken("missing iss claim".into()))?;
123
124 if !role.trusted_oidc_issuers.iter().any(|i| i == issuer) {
126 return Err(ProxyError::InvalidOidcToken(format!(
127 "untrusted issuer: {}",
128 issuer
129 )));
130 }
131
132 let alg = header.get("alg").and_then(|v| v.as_str()).unwrap_or("");
134 if alg != "RS256" {
135 return Err(ProxyError::InvalidOidcToken(format!(
136 "unsupported JWT algorithm: {}",
137 alg
138 )));
139 }
140
141 let jwks = jwks_cache.get_or_fetch(issuer).await?;
143 let kid = header
144 .get("kid")
145 .and_then(|v| v.as_str())
146 .ok_or_else(|| ProxyError::InvalidOidcToken("JWT missing kid".into()))?;
147
148 let key = jwks::find_key(&jwks, kid)?;
149 let claims = jwks::verify_token(&sts_request.web_identity_token, key, issuer, &role)?;
150
151 let subject = claims.get("sub").and_then(|v| v.as_str()).unwrap_or("");
153
154 if !role.subject_conditions.is_empty() {
155 let matches = role
156 .subject_conditions
157 .iter()
158 .any(|pattern| subject_matches(subject, pattern));
159 if !matches {
160 return Err(ProxyError::InvalidOidcToken(format!(
161 "subject '{}' does not match any conditions",
162 subject
163 )));
164 }
165 }
166
167 const MIN_SESSION_DURATION_SECS: u64 = 900;
169 let duration = sts_request
170 .duration_seconds
171 .unwrap_or(3600)
172 .clamp(MIN_SESSION_DURATION_SECS, role.max_session_duration_secs);
173
174 let mut creds = sts::mint_temporary_credentials(&role, subject, duration, key_prefix, &claims);
175
176 creds.session_token = token_key.seal(&creds)?;
178
179 Ok(creds)
180}
181
/// Returns whether `subject` matches the glob `pattern`, where `*` matches any run of
/// characters (including none) and every other character matches itself literally.
///
/// A pattern without `*` requires exact equality. The literal piece before the first `*`
/// must be a prefix of `subject`, and the piece after the last `*` must be a suffix.
/// The pieces in between must appear in order without overlapping. Each one is matched
/// at its leftmost occurrence, which leaves the most room for what follows.
fn subject_matches(subject: &str, pattern: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    if !pattern.contains('*') {
        return subject == pattern;
    }

    // At least one '*' is present, so `split` yields two or more pieces.
    let mut pieces = pattern.split('*');
    let head = pieces.next().unwrap_or("");
    let tail = pieces.next_back().unwrap_or("");

    let Some(mut rest) = subject.strip_prefix(head) else {
        return false;
    };

    // Empty pieces come from consecutive '*'s and constrain nothing.
    for needle in pieces.filter(|piece| !piece.is_empty()) {
        let Some((_, after)) = rest.split_once(needle) else {
            return false;
        };
        rest = after;
    }

    // An empty tail (pattern ends in '*') accepts whatever is left.
    rest.ends_with(tail)
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs every `(subject, pattern, expected)` row through `subject_matches`,
    /// naming the offending row when an assertion fails.
    fn assert_cases(cases: &[(&str, &str, bool)]) {
        for &(subject, pattern, expected) in cases {
            assert_eq!(
                subject_matches(subject, pattern),
                expected,
                "subject_matches({subject:?}, {pattern:?})"
            );
        }
    }

    #[test]
    fn test_subject_matching() {
        const SUB: &str = "repo:org/repo:ref:refs/heads/main";
        assert_cases(&[
            // Trailing wildcard after a fixed repository prefix.
            (SUB, "repo:org/repo:*", true),
            // A bare "*" accepts any subject.
            (SUB, "*", true),
            // Exact subject, no wildcard.
            (SUB, "repo:org/repo:ref:refs/heads/main", true),
            // A prefix naming another org must not match.
            (SUB, "repo:other/*", false),
            // Wildcards for both the repository name and the branch.
            (SUB, "repo:org/*:ref:refs/heads/*", true),
        ]);
    }

    #[test]
    fn test_subject_matching_exact() {
        assert_cases(&[
            ("abc", "abc", true),
            ("abc", "abcd", false),
            ("abcd", "abc", false),
            ("", "abc", false),
            ("", "", true),
        ]);
    }

    #[test]
    fn test_subject_matching_leading_wildcard() {
        assert_cases(&[
            ("anything", "*", true),
            ("", "*", true),
            ("foo", "*foo", true),
            ("xfoo", "*foo", true),
            ("foox", "*foo", false),
        ]);
    }

    #[test]
    fn test_subject_matching_trailing_wildcard() {
        assert_cases(&[
            ("foo", "foo*", true),
            ("foobar", "foo*", true),
            ("xfoo", "foo*", false),
        ]);
    }

    #[test]
    fn test_subject_matching_middle_wildcard() {
        assert_cases(&[
            ("foobar", "foo*bar", true),
            ("fooXbar", "foo*bar", true),
            ("fooXYZbar", "foo*bar", true),
            ("fooXbaz", "foo*bar", false),
            ("xfoobar", "foo*bar", false),
        ]);
    }

    #[test]
    fn test_subject_matching_multiple_wildcards() {
        assert_cases(&[
            // The middle and trailing "b" pieces need two distinct 'b's.
            ("axbb", "a*b*b", true),
            ("axb", "a*b*b", false),
            // "bc" and the trailing "c" must not overlap.
            ("abc", "a*bc*c", false),
            ("abcc", "a*bc*c", true),
            // A middle piece followed by a suffix that shares its first character.
            ("aab", "*a*ab", true),
            ("xab", "*a*ab", false),
            // A repeated piece needs two separate occurrences.
            ("xababab", "*ab*ab", true),
            ("xab", "*ab*ab", false),
        ]);
    }

    #[test]
    fn test_subject_matching_double_wildcard() {
        assert_cases(&[("anything", "**", true), ("", "**", true)]);
    }

    #[test]
    fn test_subject_matching_empty_subject() {
        assert_cases(&[("", "*", true), ("", "a", false), ("", "", true)]);
    }
}