Skip to main content

multistore_sts/
lib.rs

//! OIDC/STS authentication for the S3 proxy gateway.
//!
//! This crate implements the `AssumeRoleWithWebIdentity` STS API, allowing
//! workloads such as GitHub Actions to exchange OIDC tokens for temporary S3
//! credentials scoped to specific buckets and prefixes.
//!
//! # Integration
//!
//! Register the STS routes via [`route_handler::StsRouterExt`]:
//!
//! ```rust,ignore
//! use multistore_sts::route_handler::StsRouterExt;
//!
//! let router = Router::new()
//!     .with_sts(config, jwks_cache, token_key);
//! ```
//!
//! # Flow
//!
//! 1. The client obtains a JWT from its OIDC provider (e.g., a GitHub Actions ID token).
//! 2. The client calls `AssumeRoleWithWebIdentity` with the JWT and the desired role.
//! 3. This crate validates the JWT against the OIDC provider's JWKS.
//! 4. It checks the role's trust policy (issuer, audience, subject conditions).
//! 5. It mints temporary credentials (AccessKeyId/SecretAccessKey/SessionToken),
//!    sealing them into the session token so no server-side state is kept.
//! 6. It returns the credentials to the client.
//!
//! The client then uses these credentials to sign S3 requests normally.
28
29pub mod jwks;
30pub mod request;
31pub mod responses;
32pub mod route_handler;
33pub mod sealed_token;
34pub mod sts;
35
36use base64::Engine;
37pub use jwks::JwksCache;
38use multistore::error::ProxyError;
39use multistore::registry::CredentialRegistry;
40use multistore::types::TemporaryCredentials;
41pub use request::try_parse_sts_request;
42use request::StsRequest;
43pub use responses::{build_sts_error_response, build_sts_response};
44pub use sealed_token::TokenKey;
45
46/// Try to handle an STS request. Returns `Some((status, xml))` if the query
47/// contained an STS action, or `None` if it wasn't an STS request.
48///
49/// Requires a `TokenKey` — minted credentials are encrypted into the session
50/// token itself, so no server-side storage is needed. If `token_key` is `None`
51/// and an STS request arrives, an error response is returned.
52pub async fn try_handle_sts<C: CredentialRegistry>(
53    query: Option<&str>,
54    config: &C,
55    jwks_cache: &JwksCache,
56    token_key: Option<&TokenKey>,
57) -> Option<(u16, String)> {
58    let sts_result = try_parse_sts_request(query)?;
59    let (status, xml) = match sts_result {
60        Ok(sts_request) => {
61            let Some(key) = token_key else {
62                tracing::error!("STS request received but SESSION_TOKEN_KEY is not configured");
63                return Some(build_sts_error_response(&ProxyError::ConfigError(
64                    "STS requires SESSION_TOKEN_KEY to be configured".into(),
65                )));
66            };
67            match assume_role_with_web_identity(config, &sts_request, "STSPRXY", jwks_cache, key)
68                .await
69            {
70                Ok(creds) => build_sts_response(&creds),
71                Err(e) => {
72                    tracing::warn!(error = %e, "STS request failed");
73                    build_sts_error_response(&e)
74                }
75            }
76        }
77        Err(e) => build_sts_error_response(&e),
78    };
79    Some((status, xml))
80}
81
82/// Decode JWT header and claims without signature verification.
83fn jwt_decode_unverified(
84    token: &str,
85) -> Result<(serde_json::Value, serde_json::Value), ProxyError> {
86    let mut parts = token.splitn(3, '.');
87    let header_b64 = parts
88        .next()
89        .ok_or_else(|| ProxyError::InvalidOidcToken("malformed JWT".into()))?;
90    let payload_b64 = parts
91        .next()
92        .ok_or_else(|| ProxyError::InvalidOidcToken("malformed JWT".into()))?;
93
94    let decode = |s: &str| -> Result<serde_json::Value, ProxyError> {
95        let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
96            .decode(s)
97            .map_err(|e| ProxyError::InvalidOidcToken(format!("base64url decode error: {}", e)))?;
98        serde_json::from_slice(&bytes)
99            .map_err(|e| ProxyError::InvalidOidcToken(format!("invalid JWT JSON: {}", e)))
100    };
101
102    Ok((decode(header_b64)?, decode(payload_b64)?))
103}
104
105/// Validate an OIDC token and mint temporary credentials.
106///
107/// Credentials are encrypted into a self-contained session token via `token_key`.
108/// No server-side credential storage is needed.
109pub async fn assume_role_with_web_identity<C: CredentialRegistry>(
110    config: &C,
111    sts_request: &StsRequest,
112    key_prefix: &str,
113    jwks_cache: &JwksCache,
114    token_key: &TokenKey,
115) -> Result<TemporaryCredentials, ProxyError> {
116    // Look up the role
117    let role = config
118        .get_role(&sts_request.role_arn)
119        .await?
120        .ok_or_else(|| ProxyError::RoleNotFound(sts_request.role_arn.to_string()))?;
121
122    // Decode the JWT header and claims without verification to extract issuer and kid
123    let (header, insecure_claims) = jwt_decode_unverified(&sts_request.web_identity_token)?;
124
125    let issuer = insecure_claims
126        .get("iss")
127        .and_then(|v| v.as_str())
128        .ok_or_else(|| ProxyError::InvalidOidcToken("missing iss claim".into()))?;
129
130    // Verify the issuer is trusted
131    if !role.trusted_oidc_issuers.iter().any(|i| i == issuer) {
132        return Err(ProxyError::InvalidOidcToken(format!(
133            "untrusted issuer: {}",
134            issuer
135        )));
136    }
137
138    // Fail fast on unsupported algorithms before making any network requests
139    let alg = header.get("alg").and_then(|v| v.as_str()).unwrap_or("");
140    if alg != "RS256" {
141        return Err(ProxyError::InvalidOidcToken(format!(
142            "unsupported JWT algorithm: {}",
143            alg
144        )));
145    }
146
147    // Fetch JWKS (using cache) and verify the token
148    let jwks = jwks_cache.get_or_fetch(issuer).await?;
149    let kid = header
150        .get("kid")
151        .and_then(|v| v.as_str())
152        .ok_or_else(|| ProxyError::InvalidOidcToken("JWT missing kid".into()))?;
153
154    let key = jwks::find_key(&jwks, kid)?;
155    let claims = jwks::verify_token(&sts_request.web_identity_token, key, issuer, &role)?;
156
157    // Check subject conditions
158    let subject = claims.get("sub").and_then(|v| v.as_str()).unwrap_or("");
159
160    if !role.subject_conditions.is_empty() {
161        let matches = role
162            .subject_conditions
163            .iter()
164            .any(|pattern| subject_matches(subject, pattern));
165        if !matches {
166            return Err(ProxyError::InvalidOidcToken(format!(
167                "subject '{}' does not match any conditions",
168                subject
169            )));
170        }
171    }
172
173    // Mint temporary credentials (AWS enforces 900s minimum)
174    const MIN_SESSION_DURATION_SECS: u64 = 900;
175    let duration = sts_request
176        .duration_seconds
177        .unwrap_or(3600)
178        .clamp(MIN_SESSION_DURATION_SECS, role.max_session_duration_secs);
179
180    let mut creds = sts::mint_temporary_credentials(&role, subject, duration, key_prefix, &claims);
181
182    // Encrypt the full credentials into the session token — stateless, no storage needed
183    creds.session_token = token_key.seal(&creds)?;
184
185    Ok(creds)
186}
187
/// Glob-style matcher for role subject conditions.
///
/// `*` in `pattern` matches any (possibly empty) run of characters; every
/// other character must match literally and case-sensitively. A pattern with
/// no `*` requires exact equality. The literal before the first `*` anchors
/// the start of `subject`, and the literal after the last `*` anchors the end.
fn subject_matches(subject: &str, pattern: &str) -> bool {
    // A lone `*` accepts everything, including the empty subject.
    if pattern == "*" {
        return true;
    }

    let segments: Vec<&str> = pattern.split('*').collect();

    // Fewer than two segments means the pattern contains no `*`.
    let [head, middle @ .., tail] = segments.as_slice() else {
        return subject == pattern;
    };

    // The literal before the first `*` must open the subject.
    let Some(mut cursor) = subject.strip_prefix(*head) else {
        return false;
    };

    // Interior literals are consumed at their leftmost occurrence, in order.
    // Taking the leftmost hit leaves the longest tail for the final check.
    for segment in middle.iter().filter(|s| !s.is_empty()) {
        let Some(pos) = cursor.find(*segment) else {
            return false;
        };
        cursor = &cursor[pos + segment.len()..];
    }

    // The literal after the last `*` must close what remains, without
    // overlapping anything already consumed.
    tail.is_empty() || cursor.ends_with(*tail)
}
229
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_subject_matching() {
        // Trailing wildcard
        assert!(subject_matches(
            "repo:org/repo:ref:refs/heads/main",
            "repo:org/repo:*"
        ));

        // Match-all wildcard
        assert!(subject_matches("repo:org/repo:ref:refs/heads/main", "*"));

        // Exact match (no wildcards)
        assert!(subject_matches(
            "repo:org/repo:ref:refs/heads/main",
            "repo:org/repo:ref:refs/heads/main"
        ));

        // Wrong prefix
        assert!(!subject_matches(
            "repo:org/repo:ref:refs/heads/main",
            "repo:other/*"
        ));

        // Multiple wildcards
        assert!(subject_matches(
            "repo:org/repo:ref:refs/heads/main",
            "repo:org/*:ref:refs/heads/*"
        ));
    }

    #[test]
    fn test_subject_matching_exact() {
        assert!(subject_matches("abc", "abc"));
        assert!(!subject_matches("abc", "abcd"));
        assert!(!subject_matches("abcd", "abc"));
        assert!(!subject_matches("", "abc"));
        assert!(subject_matches("", ""));
    }

    #[test]
    fn test_subject_matching_leading_wildcard() {
        assert!(subject_matches("anything", "*"));
        assert!(subject_matches("", "*"));
        assert!(subject_matches("foo", "*foo"));
        assert!(subject_matches("xfoo", "*foo"));
        assert!(!subject_matches("foox", "*foo"));
    }

    #[test]
    fn test_subject_matching_trailing_wildcard() {
        assert!(subject_matches("foo", "foo*"));
        assert!(subject_matches("foobar", "foo*"));
        assert!(!subject_matches("xfoo", "foo*"));
    }

    #[test]
    fn test_subject_matching_middle_wildcard() {
        assert!(subject_matches("foobar", "foo*bar"));
        assert!(subject_matches("fooXbar", "foo*bar"));
        assert!(subject_matches("fooXYZbar", "foo*bar"));
        assert!(!subject_matches("fooXbaz", "foo*bar"));
        assert!(!subject_matches("xfoobar", "foo*bar"));
    }

    #[test]
    fn test_subject_matching_multiple_wildcards() {
        // Two wildcards with repeated literal
        assert!(subject_matches("axbb", "a*b*b"));
        assert!(!subject_matches("axb", "a*b*b"));

        // Wildcard must not overlap with suffix
        assert!(!subject_matches("abc", "a*bc*c"));
        assert!(subject_matches("abcc", "a*bc*c"));

        // Multiple wildcards requiring non-greedy left-to-right match
        assert!(subject_matches("aab", "*a*ab"));
        assert!(!subject_matches("xab", "*a*ab"));

        // Repeated pattern in subject
        assert!(subject_matches("xababab", "*ab*ab"));
        assert!(!subject_matches("xab", "*ab*ab"));
    }

    #[test]
    fn test_subject_matching_double_wildcard() {
        assert!(subject_matches("anything", "**"));
        assert!(subject_matches("", "**"));
    }

    #[test]
    fn test_subject_matching_empty_subject() {
        assert!(subject_matches("", "*"));
        assert!(!subject_matches("", "a"));
        assert!(subject_matches("", ""));
    }

    #[test]
    fn test_subject_matching_rejects_lookalike_subjects() {
        // A trailing wildcard must not admit a different repository whose
        // name merely starts with the trusted one.
        assert!(!subject_matches(
            "repo:org/repo-evil:ref:refs/heads/main",
            "repo:org/repo:*"
        ));
        assert!(!subject_matches(
            "repo:org/repository:ref:refs/heads/main",
            "repo:org/repo:*"
        ));

        // The literal after the last wildcard anchors the end of the subject.
        assert!(!subject_matches(
            "repo:org/x:ref:refs/heads/main-evil",
            "repo:org/*:ref:refs/heads/main"
        ));

        // Matching is case-sensitive.
        assert!(!subject_matches(
            "repo:Org/Repo:ref:refs/heads/main",
            "repo:org/repo:*"
        ));
    }

    #[test]
    fn test_subject_matching_consecutive_wildcards_in_middle() {
        assert!(subject_matches("ab", "a**b"));
        assert!(subject_matches("axyb", "a**b"));
        assert!(!subject_matches("a", "a**b"));
        assert!(!subject_matches("b", "a**b"));
    }

    #[test]
    fn test_subject_matching_multibyte_utf8() {
        // Slicing after `find` must stay on char boundaries for multi-byte text.
        assert!(subject_matches("é-x-ü", "é*ü"));
        assert!(subject_matches("ééé", "é*é*é"));
        assert!(!subject_matches("éé", "é*é*é"));
        assert!(subject_matches("日本語", "*本*"));
        assert!(!subject_matches("日本", "*語*"));
    }
}