multistore_oidc_provider/
lib.rs1pub mod backend_auth;
18pub mod cache;
19pub mod discovery;
20pub mod exchange;
21pub mod jwks;
22pub mod jwt;
23pub mod route_handler;
24
25use std::sync::Arc;
26
27use cache::CredentialCache;
28use exchange::CredentialExchange;
29use jwt::JwtSigner;
30
31#[derive(Debug, Clone)]
33pub struct CloudCredentials {
34 pub access_key_id: String,
36 pub secret_access_key: String,
38 pub session_token: String,
40 pub expires_at: chrono::DateTime<chrono::Utc>,
42}
43
44pub trait HttpExchange:
49 Clone + multistore::maybe_send::MaybeSend + multistore::maybe_send::MaybeSync + 'static
50{
51 fn post_form(
53 &self,
54 url: &str,
55 form: &[(&str, &str)],
56 ) -> impl std::future::Future<Output = Result<String, OidcProviderError>>
57 + multistore::maybe_send::MaybeSend;
58}
59
60pub struct OidcCredentialProvider<H: HttpExchange> {
62 signer: JwtSigner,
63 cache: CredentialCache,
64 http: H,
65 issuer: String,
66 audience: String,
67}
68
69impl<H: HttpExchange> OidcCredentialProvider<H> {
70 pub fn new(signer: JwtSigner, http: H, issuer: String, audience: String) -> Self {
77 Self {
78 signer,
79 cache: CredentialCache::new(),
80 http,
81 issuer,
82 audience,
83 }
84 }
85
86 pub async fn get_credentials<E: CredentialExchange<H>>(
92 &self,
93 cache_key: &str,
94 exchange: &E,
95 subject: &str,
96 extra_claims: &[(&str, &str)],
97 ) -> Result<Arc<CloudCredentials>, OidcProviderError> {
98 if let Some(creds) = self.cache.get(cache_key) {
100 return Ok(creds);
101 }
102
103 let token = self
105 .signer
106 .sign(subject, &self.issuer, &self.audience, extra_claims)?;
107
108 let creds: CloudCredentials = exchange.exchange(&self.http, &token).await?;
110 let creds = Arc::new(creds);
111
112 self.cache.put(cache_key.to_string(), creds.clone());
114
115 Ok(creds)
116 }
117
118 pub fn signer(&self) -> &JwtSigner {
120 &self.signer
121 }
122}
123
124#[derive(Debug, thiserror::Error)]
126pub enum OidcProviderError {
127 #[error("RSA key error: {0}")]
128 KeyError(String),
129
130 #[error("JWT signing error: {0}")]
131 SigningError(String),
132
133 #[error("credential exchange failed: {0}")]
134 ExchangeError(String),
135
136 #[error("HTTP error: {0}")]
137 HttpError(String),
138}
139
140impl From<OidcProviderError> for multistore::error::ProxyError {
141 fn from(e: OidcProviderError) -> Self {
142 multistore::error::ProxyError::Internal(e.to_string())
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration, Utc};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::OnceLock;

    /// `HttpExchange` test double: counts every request and always answers
    /// with an `AssumeRoleWithWebIdentityResponse` XML body whose credentials
    /// expire one hour from now.
    #[derive(Clone)]
    struct MockHttp {
        /// Shared between clones so the test can observe calls made through
        /// the provider's own copy.
        call_count: Arc<AtomicUsize>,
    }

    impl MockHttp {
        fn new() -> Self {
            Self {
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Number of `post_form` calls observed so far.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl HttpExchange for MockHttp {
        async fn post_form(
            &self,
            _url: &str,
            _form: &[(&str, &str)],
        ) -> Result<String, OidcProviderError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let exp = (Utc::now() + Duration::hours(1)).to_rfc3339();
            Ok(format!(
                r#"<AssumeRoleWithWebIdentityResponse>
  <AssumeRoleWithWebIdentityResult>
    <Credentials>
      <AccessKeyId>AKID_MOCK</AccessKeyId>
      <SecretAccessKey>secret_mock</SecretAccessKey>
      <SessionToken>token_mock</SessionToken>
      <Expiration>{exp}</Expiration>
    </Credentials>
  </AssumeRoleWithWebIdentityResult>
</AssumeRoleWithWebIdentityResponse>"#
            ))
        }
    }

    /// Build a `JwtSigner` for tests.
    ///
    /// Generating a 2048-bit RSA key is expensive (especially in unoptimised
    /// test builds), so the PEM is generated once per test binary and shared;
    /// every call still returns its own signer built from that PEM.
    fn test_signer() -> JwtSigner {
        static PEM: OnceLock<String> = OnceLock::new();
        let pem: &String = PEM.get_or_init(|| {
            use rsa::pkcs8::EncodePrivateKey;
            let mut rng = rand::rngs::OsRng;
            let key = rsa::RsaPrivateKey::new(&mut rng, 2048).unwrap();
            let pem = key.to_pkcs8_pem(rsa::pkcs8::LineEnding::LF).unwrap();
            pem.as_str().to_owned()
        });
        JwtSigner::from_pem(pem, "test-kid".into(), 300).unwrap()
    }

    /// Provider wired to `http`, using the issuer/audience shared by all tests.
    fn test_provider(http: &MockHttp) -> OidcCredentialProvider<MockHttp> {
        OidcCredentialProvider::new(
            test_signer(),
            http.clone(),
            "https://issuer.example.com".into(),
            "sts.amazonaws.com".into(),
        )
    }

    #[tokio::test]
    async fn get_credentials_returns_fresh_on_first_call() {
        let http = MockHttp::new();
        let provider = test_provider(&http);

        let exchange = exchange::aws::AwsExchange::new("arn:aws:iam::123:role/Test".into());
        let creds = provider
            .get_credentials("role-a", &exchange, "my-sub", &[])
            .await
            .unwrap();

        assert_eq!(creds.access_key_id, "AKID_MOCK");
        assert_eq!(http.calls(), 1);
    }

    #[tokio::test]
    async fn get_credentials_uses_cache_on_second_call() {
        let http = MockHttp::new();
        let provider = test_provider(&http);

        let exchange = exchange::aws::AwsExchange::new("arn:aws:iam::123:role/Test".into());

        let creds1 = provider
            .get_credentials("role-a", &exchange, "sub", &[])
            .await
            .unwrap();
        assert_eq!(http.calls(), 1);

        // Same key again: must be served from the cache without another request.
        let creds2 = provider
            .get_credentials("role-a", &exchange, "sub", &[])
            .await
            .unwrap();
        assert_eq!(http.calls(), 1);
        assert_eq!(creds1.access_key_id, creds2.access_key_id);
    }

    #[tokio::test]
    async fn different_cache_keys_make_separate_calls() {
        let http = MockHttp::new();
        let provider = test_provider(&http);

        let exchange = exchange::aws::AwsExchange::new("arn:aws:iam::123:role/Test".into());

        provider
            .get_credentials("role-a", &exchange, "sub", &[])
            .await
            .unwrap();
        provider
            .get_credentials("role-b", &exchange, "sub", &[])
            .await
            .unwrap();

        assert_eq!(http.calls(), 2);
    }

    #[test]
    fn signed_jwt_is_verifiable_via_jwks_public_key() {
        use base64::Engine;
        use rsa::pkcs1v15::VerifyingKey;
        use rsa::signature::Verifier;
        use rsa::{BigUint, RsaPublicKey};

        let signer = test_signer();

        let token = signer.sign("sub", "iss", "aud", &[]).unwrap();

        // Rebuild the public key purely from the published JWKS document.
        let jwks_str = jwks::jwks_json(&[(signer.public_key(), signer.kid())]);
        let jwks: serde_json::Value = serde_json::from_str(&jwks_str).unwrap();

        let key = &jwks["keys"][0];
        let b64 = &base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let n = BigUint::from_bytes_be(&b64.decode(key["n"].as_str().unwrap()).unwrap());
        let e = BigUint::from_bytes_be(&b64.decode(key["e"].as_str().unwrap()).unwrap());
        let reconstructed_key = RsaPublicKey::new(n, e).unwrap();

        // A compact JWS is exactly `header.payload.signature`.
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3, "JWT must have three dot-separated segments");
        let signing_input = format!("{}.{}", parts[0], parts[1]);
        let sig_bytes = b64.decode(parts[2]).unwrap();
        let signature = rsa::pkcs1v15::Signature::try_from(sig_bytes.as_slice()).unwrap();

        let verifying_key = VerifyingKey::<sha2::Sha256>::new(reconstructed_key);
        verifying_key
            .verify(signing_input.as_bytes(), &signature)
            .expect("JWT signed by JwtSigner should be verifiable via JWKS public key");
    }

    #[test]
    fn error_converts_to_proxy_error() {
        let err = OidcProviderError::ExchangeError("test".into());
        let proxy_err: multistore::error::ProxyError = err.into();
        assert!(proxy_err.to_string().contains("test"));
        assert_eq!(proxy_err.status_code(), 500);
    }
}
315}