multistore_sts/
route_handler.rs1use crate::{try_handle_sts, JwksCache, TokenKey};
7use multistore::registry::CredentialRegistry;
8use multistore::route_handler::{ProxyResult, RequestInfo, RouteHandler, RouteHandlerFuture};
9use multistore::router::Router;
10
11struct StsHandler<C> {
13 config: C,
14 cache: JwksCache,
15 key: Option<TokenKey>,
16}
17
18impl<C: CredentialRegistry> RouteHandler for StsHandler<C> {
19 fn handle<'a>(&'a self, req: &'a RequestInfo<'a>) -> RouteHandlerFuture<'a> {
20 Box::pin(async move {
21 let (status, xml) =
22 try_handle_sts(req.query, &self.config, &self.cache, self.key.as_ref()).await?;
23 Some(ProxyResult::xml(status, xml))
24 })
25 }
26}
27
28pub trait StsRouterExt {
30 fn with_sts<C: CredentialRegistry + 'static>(
36 self,
37 path: &str,
38 config: C,
39 cache: JwksCache,
40 key: Option<TokenKey>,
41 ) -> Self;
42}
43
44impl StsRouterExt for Router {
45 fn with_sts<C: CredentialRegistry + 'static>(
46 self,
47 path: &str,
48 config: C,
49 cache: JwksCache,
50 key: Option<TokenKey>,
51 ) -> Self {
52 self.route(path, StsHandler { config, cache, key })
53 }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use multistore::error::ProxyError;
    use multistore::types::{RoleConfig, StoredCredential};

    /// Registry stub that knows no credentials and no roles: both lookups
    /// always succeed with `None`.
    #[derive(Clone)]
    struct EmptyRegistry;

    impl CredentialRegistry for EmptyRegistry {
        async fn get_credential(
            &self,
            _access_key_id: &str,
        ) -> Result<Option<StoredCredential>, ProxyError> {
            Ok(None)
        }

        async fn get_role(&self, _role_id: &str) -> Result<Option<RoleConfig>, ProxyError> {
            Ok(None)
        }
    }

    /// Router with the STS handler mounted on `/`, backed by the empty
    /// registry, a fresh JWKS cache, and no token key.
    fn test_router() -> Router {
        let client = reqwest::Client::new();
        let cache = JwksCache::new(client, std::time::Duration::from_secs(60));
        Router::new().with_sts("/", EmptyRegistry, cache, None)
    }

    /// Sends `GET /?<query>` with empty headers through a fresh test router
    /// and reports whether dispatch produced a result.
    async fn root_get_is_intercepted(query: &str) -> bool {
        let router = test_router();
        let headers = http::HeaderMap::new();
        let req = RequestInfo::new(&http::Method::GET, "/", Some(query), &headers, None);
        router.dispatch(&req).await.is_some()
    }

    #[tokio::test]
    async fn sts_query_on_root_path_is_handled() {
        let intercepted = root_get_is_intercepted(
            "Action=AssumeRoleWithWebIdentity&RoleArn=test&WebIdentityToken=tok",
        )
        .await;
        assert!(
            intercepted,
            "STS request to / must be intercepted by the router"
        );
    }

    #[tokio::test]
    async fn non_sts_query_on_root_path_falls_through() {
        let intercepted = root_get_is_intercepted("prefix=foo/").await;
        assert!(!intercepted, "non-STS request to / must fall through");
    }
}