Skip to main content

multistore_sts/
route_handler.rs

1//! Route handler for STS `AssumeRoleWithWebIdentity` requests.
2//!
3//! Intercepts STS queries before they reach the proxy dispatch pipeline
4//! and delegates to [`try_handle_sts`].
5
6use crate::{try_handle_sts, JwksCache, TokenKey};
7use multistore::registry::CredentialRegistry;
8use multistore::route_handler::{ProxyResult, RequestInfo, RouteHandler, RouteHandlerFuture};
9use multistore::router::Router;
10
11/// Handler that intercepts `AssumeRoleWithWebIdentity` STS requests.
12struct StsHandler<C> {
13    config: C,
14    cache: JwksCache,
15    key: Option<TokenKey>,
16}
17
18impl<C: CredentialRegistry> RouteHandler for StsHandler<C> {
19    fn handle<'a>(&'a self, req: &'a RequestInfo<'a>) -> RouteHandlerFuture<'a> {
20        Box::pin(async move {
21            let (status, xml) =
22                try_handle_sts(req.query, &self.config, &self.cache, self.key.as_ref()).await?;
23            Some(ProxyResult::xml(status, xml))
24        })
25    }
26}
27
28/// Extension trait for registering STS routes on a [`Router`].
29pub trait StsRouterExt {
30    /// Register the STS handler on the given `path`.
31    ///
32    /// STS requests are identified by query parameters
33    /// (`Action=AssumeRoleWithWebIdentity`), not by path, so any path
34    /// can be used (e.g. `"/"` or `"/.sts"`).
35    fn with_sts<C: CredentialRegistry + 'static>(
36        self,
37        path: &str,
38        config: C,
39        cache: JwksCache,
40        key: Option<TokenKey>,
41    ) -> Self;
42}
43
44impl StsRouterExt for Router {
45    fn with_sts<C: CredentialRegistry + 'static>(
46        self,
47        path: &str,
48        config: C,
49        cache: JwksCache,
50        key: Option<TokenKey>,
51    ) -> Self {
52        self.route(path, StsHandler { config, cache, key })
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use multistore::error::ProxyError;
    use multistore::types::{RoleConfig, StoredCredential};

    /// Registry stub whose lookups always report "nothing found".
    #[derive(Clone)]
    struct EmptyRegistry;

    impl CredentialRegistry for EmptyRegistry {
        async fn get_credential(
            &self,
            _access_key_id: &str,
        ) -> Result<Option<StoredCredential>, ProxyError> {
            Ok(None)
        }

        async fn get_role(&self, _role_id: &str) -> Result<Option<RoleConfig>, ProxyError> {
            Ok(None)
        }
    }

    /// Builds a router whose only route is the STS handler mounted on `/`.
    fn test_router() -> Router {
        let ttl = std::time::Duration::from_secs(60);
        let cache = JwksCache::new(reqwest::Client::new(), ttl);
        Router::new().with_sts("/", EmptyRegistry, cache, None)
    }

    /// Dispatches `GET /?<query>` through a fresh test router and reports
    /// whether any handler produced a response.
    async fn root_get_is_handled(query: &str) -> bool {
        let router = test_router();
        let headers = http::HeaderMap::new();
        let req = RequestInfo::new(&http::Method::GET, "/", Some(query), &headers, None);
        // Bind before returning so the dispatch borrow ends before the locals drop.
        let handled = router.dispatch(&req).await.is_some();
        handled
    }

    #[tokio::test]
    async fn sts_query_on_root_path_is_handled() {
        let query = "Action=AssumeRoleWithWebIdentity&RoleArn=test&WebIdentityToken=tok";
        assert!(
            root_get_is_handled(query).await,
            "STS request to / must be intercepted by the router"
        );
    }

    #[tokio::test]
    async fn non_sts_query_on_root_path_falls_through() {
        assert!(
            !root_get_is_handled("prefix=foo/").await,
            "non-STS request to / must fall through"
        );
    }
}