Skip to main content

multistore_sts/
lib.rs

1//! OIDC/STS authentication for the S3 proxy gateway.
2//!
3//! This crate implements the `AssumeRoleWithWebIdentity` STS API, allowing
4//! workloads like GitHub Actions to exchange OIDC tokens for temporary S3
5//! credentials scoped to specific buckets and prefixes.
6//!
7//! # Integration
8//!
9//! Register STS routes via [`route_handler::StsRouterExt`]:
10//!
11//! ```rust,ignore
12//! use multistore_sts::route_handler::StsRouterExt;
13//!
14//! let router = Router::new()
15//!     .with_sts("/.sts", config, jwks_cache, token_key);
16//! ```
17//!
18//! # Flow
19//!
20//! 1. Client obtains a JWT from their OIDC provider (e.g., GitHub Actions ID token)
21//! 2. Client calls `AssumeRoleWithWebIdentity` with the JWT and desired role
22//! 3. This crate validates the JWT against the OIDC provider's JWKS
23//! 4. Checks trust policy (issuer, audience, subject conditions)
24//! 5. Mints temporary credentials (AccessKeyId/SecretAccessKey/SessionToken)
25//! 6. Returns credentials to the client
26//!
27//! The client then uses these credentials to sign S3 requests normally.
28
29pub mod jwks;
30pub mod request;
31pub mod responses;
32pub mod route_handler;
33pub mod sealed_token;
34pub mod sts;
35
36pub use jwks::JwksCache;
37use multistore::error::ProxyError;
38use multistore::registry::CredentialRegistry;
39use multistore::types::TemporaryCredentials;
40pub use request::try_parse_sts_request;
41use request::StsRequest;
42pub use responses::{build_sts_error_response, build_sts_response};
43pub use sealed_token::TokenKey;
44
45/// Try to handle an STS request. Returns `Some((status, xml))` if the query
46/// contained an STS action, or `None` if it wasn't an STS request.
47///
48/// Requires a `TokenKey` — minted credentials are encrypted into the session
49/// token itself, so no server-side storage is needed. If `token_key` is `None`
50/// and an STS request arrives, an error response is returned.
51pub async fn try_handle_sts<C: CredentialRegistry>(
52    query: Option<&str>,
53    config: &C,
54    jwks_cache: &JwksCache,
55    token_key: Option<&TokenKey>,
56) -> Option<(u16, String)> {
57    let sts_result = try_parse_sts_request(query)?;
58    let (status, xml) = match sts_result {
59        Ok(sts_request) => {
60            let Some(key) = token_key else {
61                tracing::error!("STS request received but SESSION_TOKEN_KEY is not configured");
62                return Some(build_sts_error_response(&ProxyError::ConfigError(
63                    "STS requires SESSION_TOKEN_KEY to be configured".into(),
64                )));
65            };
66            match assume_role_with_web_identity(config, &sts_request, "STSPRXY", jwks_cache, key)
67                .await
68            {
69                Ok(creds) => build_sts_response(&creds),
70                Err(e) => {
71                    tracing::warn!(error = %e, "STS request failed");
72                    build_sts_error_response(&e)
73                }
74            }
75        }
76        Err(e) => build_sts_error_response(&e),
77    };
78    Some((status, xml))
79}
80
81/// Decode JWT header and claims without signature verification.
82fn jwt_decode_unverified(
83    token: &str,
84) -> Result<(serde_json::Value, serde_json::Value), ProxyError> {
85    let mut parts = token.splitn(3, '.');
86    let header_b64 = parts
87        .next()
88        .ok_or_else(|| ProxyError::InvalidOidcToken("malformed JWT".into()))?;
89    let payload_b64 = parts
90        .next()
91        .ok_or_else(|| ProxyError::InvalidOidcToken("malformed JWT".into()))?;
92
93    Ok((
94        jwks::decode_jwt_segment(header_b64)?,
95        jwks::decode_jwt_segment(payload_b64)?,
96    ))
97}
98
99/// Validate an OIDC token and mint temporary credentials.
100///
101/// Credentials are encrypted into a self-contained session token via `token_key`.
102/// No server-side credential storage is needed.
103pub async fn assume_role_with_web_identity<C: CredentialRegistry>(
104    config: &C,
105    sts_request: &StsRequest,
106    key_prefix: &str,
107    jwks_cache: &JwksCache,
108    token_key: &TokenKey,
109) -> Result<TemporaryCredentials, ProxyError> {
110    // Look up the role
111    let role = config
112        .get_role(&sts_request.role_arn)
113        .await?
114        .ok_or_else(|| ProxyError::RoleNotFound(sts_request.role_arn.to_string()))?;
115
116    // Decode the JWT header and claims without verification to extract issuer and kid
117    let (header, insecure_claims) = jwt_decode_unverified(&sts_request.web_identity_token)?;
118
119    let issuer = insecure_claims
120        .get("iss")
121        .and_then(|v| v.as_str())
122        .ok_or_else(|| ProxyError::InvalidOidcToken("missing iss claim".into()))?;
123
124    // Verify the issuer is trusted
125    if !role.trusted_oidc_issuers.iter().any(|i| i == issuer) {
126        return Err(ProxyError::InvalidOidcToken(format!(
127            "untrusted issuer: {}",
128            issuer
129        )));
130    }
131
132    // Fail fast on unsupported algorithms before making any network requests
133    let alg = header.get("alg").and_then(|v| v.as_str()).unwrap_or("");
134    if alg != "RS256" {
135        return Err(ProxyError::InvalidOidcToken(format!(
136            "unsupported JWT algorithm: {}",
137            alg
138        )));
139    }
140
141    // Fetch JWKS (using cache) and verify the token
142    let jwks = jwks_cache.get_or_fetch(issuer).await?;
143    let kid = header
144        .get("kid")
145        .and_then(|v| v.as_str())
146        .ok_or_else(|| ProxyError::InvalidOidcToken("JWT missing kid".into()))?;
147
148    let key = jwks::find_key(&jwks, kid)?;
149    let claims = jwks::verify_token(&sts_request.web_identity_token, key, issuer, &role)?;
150
151    // Check subject conditions
152    let subject = claims.get("sub").and_then(|v| v.as_str()).unwrap_or("");
153
154    if !role.subject_conditions.is_empty() {
155        let matches = role
156            .subject_conditions
157            .iter()
158            .any(|pattern| subject_matches(subject, pattern));
159        if !matches {
160            return Err(ProxyError::InvalidOidcToken(format!(
161                "subject '{}' does not match any conditions",
162                subject
163            )));
164        }
165    }
166
167    // Mint temporary credentials (AWS enforces 900s minimum)
168    const MIN_SESSION_DURATION_SECS: u64 = 900;
169    let duration = sts_request
170        .duration_seconds
171        .unwrap_or(3600)
172        .clamp(MIN_SESSION_DURATION_SECS, role.max_session_duration_secs);
173
174    let mut creds = sts::mint_temporary_credentials(&role, subject, duration, key_prefix, &claims);
175
176    // Encrypt the full credentials into the session token — stateless, no storage needed
177    creds.session_token = token_key.seal(&creds)?;
178
179    Ok(creds)
180}
181
/// Glob-style comparison of an OIDC subject against a condition pattern.
///
/// `*` in `pattern` stands for any (possibly empty) run of characters; every
/// other character must match literally. A pattern without `*` must equal the
/// subject exactly.
fn subject_matches(subject: &str, pattern: &str) -> bool {
    // No wildcard at all: plain equality.
    let Some((prefix, after_first_star)) = pattern.split_once('*') else {
        return subject == pattern;
    };

    // Everything before the first `*` is anchored at the start.
    let Some(mut rest) = subject.strip_prefix(prefix) else {
        return false;
    };

    // Everything after the last `*` is anchored at the end; whatever lies
    // between the first and last `*` is a list of floating literals.
    let (floating, suffix) = after_first_star
        .rsplit_once('*')
        .unwrap_or(("", after_first_star));

    // Consume each floating literal at its leftmost occurrence so the suffix
    // check below only sees text that no earlier literal has claimed.
    for literal in floating.split('*').filter(|piece| !piece.is_empty()) {
        let Some(at) = rest.find(literal) else {
            return false;
        };
        rest = &rest[at + literal.len()..];
    }

    // An empty suffix (pattern ends in `*`) matches any remainder.
    rest.ends_with(suffix)
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs every `(subject, pattern, expected)` row through
    /// `subject_matches` and names the offending row on failure.
    fn check(cases: &[(&str, &str, bool)]) {
        for &(subject, pattern, expected) in cases {
            assert_eq!(
                subject_matches(subject, pattern),
                expected,
                "subject {subject:?} against pattern {pattern:?}"
            );
        }
    }

    const GITHUB_SUBJECT: &str = "repo:org/repo:ref:refs/heads/main";

    #[test]
    fn test_subject_matching() {
        check(&[
            // Trailing wildcard
            (GITHUB_SUBJECT, "repo:org/repo:*", true),
            // Match-all wildcard
            (GITHUB_SUBJECT, "*", true),
            // Exact match (no wildcards)
            (GITHUB_SUBJECT, "repo:org/repo:ref:refs/heads/main", true),
            // Wrong prefix
            (GITHUB_SUBJECT, "repo:other/*", false),
            // Multiple wildcards
            (GITHUB_SUBJECT, "repo:org/*:ref:refs/heads/*", true),
        ]);
    }

    #[test]
    fn test_subject_matching_exact() {
        check(&[
            ("abc", "abc", true),
            ("abc", "abcd", false),
            ("abcd", "abc", false),
            ("", "abc", false),
            ("", "", true),
        ]);
    }

    #[test]
    fn test_subject_matching_leading_wildcard() {
        check(&[
            ("anything", "*", true),
            ("", "*", true),
            ("foo", "*foo", true),
            ("xfoo", "*foo", true),
            ("foox", "*foo", false),
        ]);
    }

    #[test]
    fn test_subject_matching_trailing_wildcard() {
        check(&[
            ("foo", "foo*", true),
            ("foobar", "foo*", true),
            ("xfoo", "foo*", false),
        ]);
    }

    #[test]
    fn test_subject_matching_middle_wildcard() {
        check(&[
            ("foobar", "foo*bar", true),
            ("fooXbar", "foo*bar", true),
            ("fooXYZbar", "foo*bar", true),
            ("fooXbaz", "foo*bar", false),
            ("xfoobar", "foo*bar", false),
        ]);
    }

    #[test]
    fn test_subject_matching_multiple_wildcards() {
        check(&[
            // Two wildcards with repeated literal
            ("axbb", "a*b*b", true),
            ("axb", "a*b*b", false),
            // Wildcard must not overlap with suffix
            ("abc", "a*bc*c", false),
            ("abcc", "a*bc*c", true),
            // Multiple wildcards requiring non-greedy left-to-right match
            ("aab", "*a*ab", true),
            ("xab", "*a*ab", false),
            // Repeated pattern in subject
            ("xababab", "*ab*ab", true),
            ("xab", "*ab*ab", false),
        ]);
    }

    #[test]
    fn test_subject_matching_double_wildcard() {
        check(&[("anything", "**", true), ("", "**", true)]);
    }

    #[test]
    fn test_subject_matching_empty_subject() {
        check(&[("", "*", true), ("", "a", false), ("", "", true)]);
    }
}