// multistore_oidc_provider/lib.rs

//! OIDC provider for outbound authentication.
//!
//! This crate enables the proxy to act as its own OIDC identity provider:
//!
//! 1. **JWT signing** — mint JWTs signed with the proxy's RSA private key
//! 2. **JWKS serving** — expose the corresponding public key as a JWK set
//! 3. **OIDC discovery** — generate `.well-known/openid-configuration` responses
//! 4. **Credential exchange** — trade self-signed JWTs for cloud provider
//!    credentials (AWS STS, Azure AD, GCP STS)
//! 5. **Route handler** — [`route_handler::OidcRouterExt`] registers
//!    `.well-known` endpoint closures on a [`Router`](multistore::router::Router)
//!
//! The crate is runtime-agnostic: HTTP calls are abstracted behind an
//! [`HttpExchange`] trait so that each runtime (reqwest, Fetch API, etc.)
//! can provide its own implementation.

17pub mod backend_auth;
18pub mod cache;
19pub mod discovery;
20pub mod exchange;
21pub mod jwks;
22pub mod jwt;
23pub mod route_handler;
24
25use std::sync::Arc;
26
27use cache::CredentialCache;
28use exchange::CredentialExchange;
29use jwt::JwtSigner;
30
31/// Temporary cloud credentials obtained via token exchange.
32#[derive(Debug, Clone)]
33pub struct CloudCredentials {
34    /// AWS access key ID. Empty string for Azure/GCP (bearer-token-only providers).
35    pub access_key_id: String,
36    /// AWS secret access key. Empty string for Azure/GCP (bearer-token-only providers).
37    pub secret_access_key: String,
38    /// Session or bearer token. For Azure/GCP this is the sole credential.
39    pub session_token: String,
40    /// When these credentials expire.
41    pub expires_at: chrono::DateTime<chrono::Utc>,
42}
43
44/// HTTP client abstraction for outbound requests (STS token exchange).
45///
46/// Each runtime provides its own implementation — `reqwest` on native,
47/// `Fetch` on Cloudflare Workers.
48pub trait HttpExchange:
49    Clone + multistore::maybe_send::MaybeSend + multistore::maybe_send::MaybeSync + 'static
50{
51    /// Send a `POST` request with form-encoded body and return the response text.
52    fn post_form(
53        &self,
54        url: &str,
55        form: &[(&str, &str)],
56    ) -> impl std::future::Future<Output = Result<String, OidcProviderError>>
57           + multistore::maybe_send::MaybeSend;
58}
59
60/// Top-level provider that combines signing, exchange, and caching.
61pub struct OidcCredentialProvider<H: HttpExchange> {
62    signer: JwtSigner,
63    cache: CredentialCache,
64    http: H,
65    issuer: String,
66    audience: String,
67}
68
69impl<H: HttpExchange> OidcCredentialProvider<H> {
70    /// Create a new provider.
71    ///
72    /// * `signer`   — RSA JWT signer used to mint self-signed tokens.
73    /// * `http`     — runtime-specific HTTP client for outbound STS calls.
74    /// * `issuer`   — `iss` claim written into minted JWTs (must match OIDC discovery).
75    /// * `audience` — `aud` claim written into minted JWTs (must match the cloud provider's expected audience).
76    pub fn new(signer: JwtSigner, http: H, issuer: String, audience: String) -> Self {
77        Self {
78            signer,
79            cache: CredentialCache::new(),
80            http,
81            issuer,
82            audience,
83        }
84    }
85
86    /// Get credentials for a backend, using cached values when available.
87    ///
88    /// `exchange` describes how to trade the self-signed JWT for cloud
89    /// credentials (AWS, Azure, GCP). `cache_key` identifies the backend
90    /// for caching purposes (e.g. the role ARN).
91    pub async fn get_credentials<E: CredentialExchange<H>>(
92        &self,
93        cache_key: &str,
94        exchange: &E,
95        subject: &str,
96        extra_claims: &[(&str, &str)],
97    ) -> Result<Arc<CloudCredentials>, OidcProviderError> {
98        // Check cache first
99        if let Some(creds) = self.cache.get(cache_key) {
100            return Ok(creds);
101        }
102
103        // Mint a JWT
104        let token = self
105            .signer
106            .sign(subject, &self.issuer, &self.audience, extra_claims)?;
107
108        // Exchange it for cloud credentials
109        let creds: CloudCredentials = exchange.exchange(&self.http, &token).await?;
110        let creds = Arc::new(creds);
111
112        // Cache
113        self.cache.put(cache_key.to_string(), creds.clone());
114
115        Ok(creds)
116    }
117
118    /// Access the underlying signer (e.g. for JWKS generation).
119    pub fn signer(&self) -> &JwtSigner {
120        &self.signer
121    }
122}
123
124/// Errors produced by this crate.
125#[derive(Debug, thiserror::Error)]
126pub enum OidcProviderError {
127    #[error("RSA key error: {0}")]
128    KeyError(String),
129
130    #[error("JWT signing error: {0}")]
131    SigningError(String),
132
133    #[error("credential exchange failed: {0}")]
134    ExchangeError(String),
135
136    #[error("HTTP error: {0}")]
137    HttpError(String),
138}
139
140impl From<OidcProviderError> for multistore::error::ProxyError {
141    fn from(e: OidcProviderError) -> Self {
142        multistore::error::ProxyError::Internal(e.to_string())
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration, Utc};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Mock HTTP client that counts calls and answers every request with a
    /// canned AWS STS `AssumeRoleWithWebIdentity` response valid for one hour.
    #[derive(Clone)]
    struct MockHttp {
        call_count: Arc<AtomicUsize>,
    }

    impl MockHttp {
        fn new() -> Self {
            Self {
                call_count: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Number of `post_form` calls so far (shared across clones).
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl HttpExchange for MockHttp {
        async fn post_form(
            &self,
            _url: &str,
            _form: &[(&str, &str)],
        ) -> Result<String, OidcProviderError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let exp = (Utc::now() + Duration::hours(1)).to_rfc3339();
            Ok(format!(
                r#"<AssumeRoleWithWebIdentityResponse>
                    <AssumeRoleWithWebIdentityResult>
                        <Credentials>
                            <AccessKeyId>AKID_MOCK</AccessKeyId>
                            <SecretAccessKey>secret_mock</SecretAccessKey>
                            <SessionToken>token_mock</SessionToken>
                            <Expiration>{exp}</Expiration>
                        </Credentials>
                    </AssumeRoleWithWebIdentityResult>
                </AssumeRoleWithWebIdentityResponse>"#
            ))
        }
    }

    fn test_signer() -> JwtSigner {
        use rsa::pkcs8::EncodePrivateKey;
        let mut rng = rand::rngs::OsRng;
        let key = rsa::RsaPrivateKey::new(&mut rng, 2048).unwrap();
        let pem = key.to_pkcs8_pem(rsa::pkcs8::LineEnding::LF).unwrap();
        JwtSigner::from_pem(&pem, "test-kid".into(), 300).unwrap()
    }

    /// Provider wired to `http` with the issuer/audience shared by every test.
    fn test_provider(http: &MockHttp) -> OidcCredentialProvider<MockHttp> {
        OidcCredentialProvider::new(
            test_signer(),
            http.clone(),
            "https://issuer.example.com".into(),
            "sts.amazonaws.com".into(),
        )
    }

    #[tokio::test]
    async fn get_credentials_returns_fresh_on_first_call() {
        let http = MockHttp::new();
        let provider = test_provider(&http);
        let exchange = exchange::aws::AwsExchange::new("arn:aws:iam::123:role/Test".into());

        let creds = provider
            .get_credentials("role-a", &exchange, "my-sub", &[])
            .await
            .unwrap();

        assert_eq!(creds.access_key_id, "AKID_MOCK");
        assert_eq!(creds.secret_access_key, "secret_mock");
        assert_eq!(creds.session_token, "token_mock");
        assert_eq!(http.calls(), 1);
    }

    #[tokio::test]
    async fn get_credentials_uses_cache_on_second_call() {
        let http = MockHttp::new();
        let provider = test_provider(&http);
        let exchange = exchange::aws::AwsExchange::new("arn:aws:iam::123:role/Test".into());

        // First call — hits mock HTTP
        let creds1 = provider
            .get_credentials("role-a", &exchange, "sub", &[])
            .await
            .unwrap();
        assert_eq!(http.calls(), 1);

        // Second call — should use cache, no additional HTTP call
        let creds2 = provider
            .get_credentials("role-a", &exchange, "sub", &[])
            .await
            .unwrap();
        assert_eq!(http.calls(), 1);
        assert_eq!(creds1.access_key_id, creds2.access_key_id);
    }

    #[tokio::test]
    async fn same_cache_key_ignores_subject_on_hit() {
        // Pins the documented contract: the cache is keyed on `cache_key` only.
        let http = MockHttp::new();
        let provider = test_provider(&http);
        let exchange = exchange::aws::AwsExchange::new("arn:aws:iam::123:role/Test".into());

        provider
            .get_credentials("role-a", &exchange, "sub-a", &[])
            .await
            .unwrap();
        provider
            .get_credentials("role-a", &exchange, "sub-b", &[])
            .await
            .unwrap();

        assert_eq!(http.calls(), 1);
    }

    #[tokio::test]
    async fn different_cache_keys_make_separate_calls() {
        let http = MockHttp::new();
        let provider = test_provider(&http);
        let exchange = exchange::aws::AwsExchange::new("arn:aws:iam::123:role/Test".into());

        provider
            .get_credentials("role-a", &exchange, "sub", &[])
            .await
            .unwrap();
        provider
            .get_credentials("role-b", &exchange, "sub", &[])
            .await
            .unwrap();

        assert_eq!(http.calls(), 2);
    }

    #[test]
    fn signed_jwt_is_verifiable_via_jwks_public_key() {
        use base64::Engine;
        use rsa::pkcs1v15::VerifyingKey;
        use rsa::signature::Verifier;
        use rsa::{BigUint, RsaPublicKey};

        let signer = test_signer();

        // Sign a JWT
        let token = signer.sign("sub", "iss", "aud", &[]).unwrap();

        // Generate JWKS from the same signer
        let jwks_str = jwks::jwks_json(&[(signer.public_key(), signer.kid())]);
        let jwks: serde_json::Value = serde_json::from_str(&jwks_str).unwrap();

        // Extract public key from JWKS
        let key = &jwks["keys"][0];
        let b64 = &base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let n = BigUint::from_bytes_be(&b64.decode(key["n"].as_str().unwrap()).unwrap());
        let e = BigUint::from_bytes_be(&b64.decode(key["e"].as_str().unwrap()).unwrap());
        let reconstructed_key = RsaPublicKey::new(n, e).unwrap();

        // A compact JWS is exactly header.payload.signature; check before indexing.
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3, "JWT must have header.payload.signature");

        // Verify signature using the JWKS-derived key
        let signing_input = format!("{}.{}", parts[0], parts[1]);
        let sig_bytes = b64.decode(parts[2]).unwrap();
        let signature = rsa::pkcs1v15::Signature::try_from(sig_bytes.as_slice()).unwrap();

        let verifying_key = VerifyingKey::<sha2::Sha256>::new(reconstructed_key);
        verifying_key
            .verify(signing_input.as_bytes(), &signature)
            .expect("JWT signed by JwtSigner should be verifiable via JWKS public key");
    }

    #[test]
    fn error_converts_to_proxy_error() {
        let err = OidcProviderError::ExchangeError("test".into());
        let proxy_err: multistore::error::ProxyError = err.into();
        assert!(proxy_err.to_string().contains("test"));
        assert_eq!(proxy_err.status_code(), 500);
    }
}