Skip to main content

fraiseql_auth/
oidc_server_client.rs

1//! Server-side OIDC client for PKCE authorization code flows.
2//!
3//! This is a minimal, runtime-facing client that:
4//! 1. Builds the OIDC `/authorize` redirect URL with PKCE parameters.
5//! 2. Exchanges the authorization code + `code_verifier` for tokens.
6//!
7//! It is intentionally separate from the more general [`crate::oauth::OAuth2Client`] and
8//! [`crate::oauth::OIDCClient`] types in `oauth`: those carry JWKS caches and session
9//! management state that the PKCE route handlers do not need.
10// The client secret is loaded from the environment at runtime and is NEVER
11// stored in the compiled schema or TOML config.
12
13use std::{fmt, sync::Arc};
14
15use serde::Deserialize;
16
17// ---------------------------------------------------------------------------
18// Resolved OIDC endpoints (cached in compiled schema)
19// ---------------------------------------------------------------------------
20
21/// OIDC endpoints fetched from the discovery document and cached in the
22/// compiled schema under `"auth_endpoints"`.
23#[derive(Debug, Clone, Deserialize)]
24pub struct OidcEndpoints {
25    /// The provider's `/authorize` URL.
26    pub authorization_endpoint: String,
27    /// The provider's `/token` URL.
28    pub token_endpoint:         String,
29}
30
31// ---------------------------------------------------------------------------
32// Token response from the provider
33// ---------------------------------------------------------------------------
34
35/// Minimal token response from the OIDC `/token` endpoint.
36#[derive(Debug, Deserialize)]
37pub struct OidcTokenResponse {
38    /// The access token.
39    pub access_token:  String,
40    /// The OpenID Connect identity token (if requested).
41    pub id_token:      Option<String>,
42    /// Seconds until the access token expires.
43    pub expires_in:    Option<u64>,
44    /// Refresh token (if the provider issued one).
45    pub refresh_token: Option<String>,
46}
47
48// ---------------------------------------------------------------------------
49// OidcServerClient
50// ---------------------------------------------------------------------------
51
/// Server-side OIDC client used for PKCE authorization-code exchange.
///
/// Built once at server startup from the compiled schema (see
/// [`OidcServerClient::from_compiled_schema`]). The client secret is read
/// from an environment variable at that point and is only ever held in
/// memory — it is not written to disk, and the `Debug` impl redacts it.
pub struct OidcServerClient {
    /// OAuth2 client identifier registered with the provider.
    client_id:              String,
    /// OAuth2 client secret. Deliberately private so it cannot be read
    /// through a field access; it must never be logged or serialized.
    client_secret:          String,
    /// `redirect_uri` sent both in the `/authorize` URL and the token request.
    server_redirect_uri:    String,
    /// The provider's `/authorize` URL.
    authorization_endpoint: String,
    /// The provider's `/token` URL.
    token_endpoint:         String,
}
65
66/// Custom `Debug` implementation that redacts the client secret.
67#[allow(clippy::missing_fields_in_debug)] // Reason: endpoint fields omitted to keep debug concise and avoid leaking config in logs
68impl fmt::Debug for OidcServerClient {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.debug_struct("OidcServerClient")
71            .field("client_id", &self.client_id)
72            .field("client_secret", &"[REDACTED]")
73            .field("authorization_endpoint", &self.authorization_endpoint)
74            .finish_non_exhaustive()
75    }
76}
77
78impl OidcServerClient {
79    /// Maximum byte size accepted from the OIDC token endpoint response.
80    ///
81    /// A well-formed token response is a few KiB at most.  1 MiB prevents a
82    /// malicious or compromised OIDC provider from exhausting server memory.
83    const MAX_OIDC_RESPONSE_BYTES: usize = 1024 * 1024;
84
85    /// Construct a client directly from resolved credentials and endpoints.
86    ///
87    /// Prefer [`Self::from_compiled_schema`] in production code.
88    /// This constructor exists for testing and direct wiring.
89    pub fn new(
90        client_id: impl Into<String>,
91        client_secret: impl Into<String>,
92        server_redirect_uri: impl Into<String>,
93        authorization_endpoint: impl Into<String>,
94        token_endpoint: impl Into<String>,
95    ) -> Self {
96        Self {
97            client_id:              client_id.into(),
98            client_secret:          client_secret.into(),
99            server_redirect_uri:    server_redirect_uri.into(),
100            authorization_endpoint: authorization_endpoint.into(),
101            token_endpoint:         token_endpoint.into(),
102        }
103    }
104
105    /// Build an `OidcServerClient` from the compiled schema JSON.
106    ///
107    /// Returns `None` if:
108    /// - `schema_json["auth"]` is absent, or
109    /// - the env var named by `client_secret_env` is not set, or
110    /// - the OIDC endpoint cache (`schema_json["auth_endpoints"]`) is absent.
111    ///
112    /// In all failure cases an explanatory `tracing::error!` is emitted so
113    /// operators can diagnose startup issues without reading source code.
114    pub fn from_compiled_schema(schema_json: &serde_json::Value) -> Option<Arc<Self>> {
115        // ── Load [auth] config ────────────────────────────────────────────
116        #[derive(Deserialize)]
117        struct AuthCfg {
118            client_id:           String,
119            client_secret_env:   String,
120            server_redirect_uri: String,
121        }
122
123        let auth_cfg: AuthCfg =
124            schema_json.get("auth").and_then(|v| serde_json::from_value(v.clone()).ok())?;
125
126        // ── Read client secret from env ───────────────────────────────────
127        let Ok(client_secret) = std::env::var(&auth_cfg.client_secret_env) else {
128            tracing::error!(
129                env_var = %auth_cfg.client_secret_env,
130                "PKCE init failed: env var for OIDC client secret is not set"
131            );
132            return None;
133        };
134
135        // ── Load cached endpoints ─────────────────────────────────────────
136        let Some(endpoints): Option<OidcEndpoints> = schema_json
137            .get("auth_endpoints")
138            .and_then(|v| serde_json::from_value(v.clone()).ok())
139        else {
140            tracing::error!(
141                "PKCE init failed: 'auth_endpoints' not found in compiled schema. \
142                 Re-compile the schema so that the CLI caches the OIDC discovery \
143                 document (authorization_endpoint, token_endpoint)."
144            );
145            return None;
146        };
147
148        Some(Arc::new(Self {
149            client_id: auth_cfg.client_id,
150            client_secret,
151            server_redirect_uri: auth_cfg.server_redirect_uri,
152            authorization_endpoint: endpoints.authorization_endpoint,
153            token_endpoint: endpoints.token_endpoint,
154        }))
155    }
156
157    /// Build the OIDC `/authorize` redirect URL with all required PKCE params.
158    ///
159    /// The `state`, `code_challenge`, and `redirect_uri` values are
160    /// percent-encoded so that base64-url characters (+, /, =) do not
161    /// break query string parsing on the provider side.
162    pub fn authorization_url(
163        &self,
164        state: &str,
165        code_challenge: &str,
166        code_challenge_method: &str,
167    ) -> String {
168        format!(
169            "{}?response_type=code\
170             &client_id={}\
171             &redirect_uri={}\
172             &scope=openid%20email%20profile\
173             &state={}\
174             &code_challenge={}\
175             &code_challenge_method={}",
176            self.authorization_endpoint,
177            urlencoding::encode(&self.client_id),
178            urlencoding::encode(&self.server_redirect_uri),
179            urlencoding::encode(state),
180            urlencoding::encode(code_challenge),
181            code_challenge_method,
182        )
183    }
184
185    // 1 MiB
186
187    /// Exchange an authorization code for tokens.
188    ///
189    /// Sends a `POST` to the provider's `/token` endpoint with the PKCE
190    /// `code_verifier` and all required OAuth2 fields.
191    ///
192    /// # Errors
193    ///
194    /// Returns an error if the HTTP request fails, the provider returns a
195    /// non-success status, the response exceeds `MAX_OIDC_RESPONSE_BYTES`, or
196    /// the response body cannot be parsed as JSON.
197    pub async fn exchange_code(
198        &self,
199        code: &str,
200        code_verifier: &str,
201        http: &reqwest::Client,
202    ) -> Result<OidcTokenResponse, anyhow::Error> {
203        let resp = http
204            .post(&self.token_endpoint)
205            .form(&[
206                ("grant_type", "authorization_code"),
207                ("code", code),
208                ("code_verifier", code_verifier),
209                ("redirect_uri", self.server_redirect_uri.as_str()),
210                ("client_id", self.client_id.as_str()),
211                ("client_secret", self.client_secret.as_str()),
212            ])
213            .send()
214            .await?;
215
216        let status = resp.status();
217
218        // Read body with error propagation — unwrap_or_default() would silently
219        // discard network errors and return an empty body, masking failures.
220        let body_bytes = resp
221            .bytes()
222            .await
223            .map_err(|e| anyhow::anyhow!("Failed to read token response: {e}"))?;
224
225        // Size guard BEFORE the status check: a compromised provider could exhaust
226        // memory by sending an oversized non-2xx response that bypassed a later cap.
227        anyhow::ensure!(
228            body_bytes.len() <= Self::MAX_OIDC_RESPONSE_BYTES,
229            "OIDC token response too large ({} bytes, max {})",
230            body_bytes.len(),
231            Self::MAX_OIDC_RESPONSE_BYTES
232        );
233
234        if !status.is_success() {
235            // Body is already bounded by the size check above — no need for .min().
236            let body = String::from_utf8_lossy(&body_bytes);
237            anyhow::bail!("token endpoint returned {status}: {body}");
238        }
239
240        Ok(serde_json::from_slice::<OidcTokenResponse>(&body_bytes)?)
241    }
242}
243
244// ---------------------------------------------------------------------------
245// Unit tests
246// ---------------------------------------------------------------------------
247
#[allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    // Reason: test module — wildcard keeps test boilerplate minimal
    use super::*;

    fn test_client() -> OidcServerClient {
        OidcServerClient::new(
            "test-client",
            "test-secret",
            "https://api.example.com/auth/callback",
            "https://provider.example.com/authorize",
            "https://provider.example.com/token",
        )
    }

    #[test]
    fn test_authorization_url_contains_required_pkce_params() {
        let client = test_client();
        let url = client.authorization_url("my_state", "my_challenge", "S256");
        assert!(url.contains("response_type=code"), "missing response_type");
        assert!(url.contains("client_id=test-client"), "missing client_id");
        assert!(url.contains("code_challenge=my_challenge"), "missing code_challenge");
        assert!(url.contains("code_challenge_method=S256"), "missing method");
        assert!(url.contains("state="), "missing state");
        assert!(url.contains("redirect_uri="), "missing redirect_uri");
    }

    #[test]
    fn oidc_response_cap_constant_is_reasonable() {
        assert_eq!(OidcServerClient::MAX_OIDC_RESPONSE_BYTES, 1024 * 1024);
    }

    #[test]
    fn oidc_response_cap_covers_error_path() {
        // The size guard fires BEFORE the status check, so both 2xx and
        // non-2xx responses are bounded. Verify the constant is sane.
        const { assert!(OidcServerClient::MAX_OIDC_RESPONSE_BYTES >= 64 * 1024) }
        const { assert!(OidcServerClient::MAX_OIDC_RESPONSE_BYTES <= 100 * 1024 * 1024) }
    }

    #[tokio::test]
    async fn oidc_oversized_error_response_is_rejected() {
        use wiremock::{
            Mock, MockServer, ResponseTemplate,
            matchers::{method, path},
        };

        let mock_server = MockServer::start().await;
        // Non-2xx response with oversized body — must be rejected before status check.
        let oversized = vec![b'e'; OidcServerClient::MAX_OIDC_RESPONSE_BYTES + 1];
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(400).set_body_bytes(oversized))
            .mount(&mock_server)
            .await;

        let client = OidcServerClient::new(
            "client_id",
            "client_secret",
            "https://example.com/callback",
            "https://example.com/auth",
            format!("{}/token", mock_server.uri()),
        );
        let http = reqwest::Client::new();
        let result = client.exchange_code("code", "verifier", &http).await;

        assert!(result.is_err(), "oversized error response must be rejected");
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("too large"), "error must mention size limit, got: {msg}");
    }

    #[tokio::test]
    async fn oidc_oversized_success_response_is_rejected() {
        use wiremock::{
            Mock, MockServer, ResponseTemplate,
            matchers::{method, path},
        };

        let mock_server = MockServer::start().await;
        let oversized = vec![b'x'; OidcServerClient::MAX_OIDC_RESPONSE_BYTES + 1];
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(oversized))
            .mount(&mock_server)
            .await;

        let client = OidcServerClient::new(
            "client_id",
            "client_secret",
            "https://example.com/callback",
            "https://example.com/auth",
            format!("{}/token", mock_server.uri()),
        );
        let http = reqwest::Client::new();
        let result = client.exchange_code("code", "verifier", &http).await;

        assert!(result.is_err(), "oversized success response must be rejected, got: {result:?}");
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("too large"), "error must mention size limit, got: {msg}");
    }

    #[test]
    fn test_authorization_url_includes_openid_scope() {
        let client = test_client();
        let url = client.authorization_url("s", "c", "S256");
        // scope must include "openid" (percent-encoded as openid%20email%20profile)
        assert!(url.contains("openid"), "authorization URL must request the openid scope: {url}");
    }

    #[test]
    fn test_authorization_url_state_is_percent_encoded() {
        // State tokens produced by encryption may contain +, / and = when
        // standard base64 is used (base64url-no-pad only emits [A-Za-z0-9_-]).
        // Ensure the value is encoded.
        let client = test_client();
        let state_with_spaces = "hello world+test";
        let url = client.authorization_url(state_with_spaces, "challenge", "S256");
        // The raw space must not appear unencoded
        let state_segment = url.split("state=").nth(1).unwrap().split('&').next().unwrap();
        assert!(!state_segment.contains(' '), "space in state must be percent-encoded");
        assert!(!state_segment.contains('+'), "plus in state must be percent-encoded");
    }

    #[test]
    fn test_from_compiled_schema_absent_auth_returns_none() {
        let json = serde_json::json!({});
        assert!(OidcServerClient::from_compiled_schema(&json).is_none());
    }

    #[test]
    fn test_from_compiled_schema_null_auth_returns_none() {
        let json = serde_json::json!({ "auth": null });
        assert!(OidcServerClient::from_compiled_schema(&json).is_none());
    }

    #[test]
    fn test_from_compiled_schema_invalid_auth_returns_none() {
        // `auth` present but missing the PKCE fields (client_secret_env, ...).
        let json = serde_json::json!({
            "auth": { "discovery_url": "https://example.com" },
            "auth_endpoints": {
                "authorization_endpoint": "https://example.com/auth",
                "token_endpoint":         "https://example.com/token"
            }
        });
        assert!(OidcServerClient::from_compiled_schema(&json).is_none());
    }

    #[test]
    fn test_from_compiled_schema_missing_env_var_returns_none() {
        const UNSET_VAR: &str = "__FRAISEQL_TEST_DEFINITELY_UNSET_42XYZ__";
        // The auth section and endpoints below are both valid, so the env var
        // lookup is the only thing that can make this return None. If the var
        // is somehow set in this environment, skip instead of asserting
        // something we cannot guarantee.
        if std::env::var_os(UNSET_VAR).is_some() {
            return;
        }
        let json = serde_json::json!({
            "auth": {
                "discovery_url":       "https://example.com",
                "client_id":           "x",
                "client_secret_env":   UNSET_VAR,
                "server_redirect_uri": "https://api.example.com/auth/callback"
            },
            "auth_endpoints": {
                "authorization_endpoint": "https://example.com/auth",
                "token_endpoint":         "https://example.com/token"
            }
        });
        assert!(
            OidcServerClient::from_compiled_schema(&json).is_none(),
            "missing client secret env var must return None"
        );
    }

    #[test]
    fn test_from_compiled_schema_missing_endpoints_returns_none() {
        // auth section present, env var set (via a known-present env var), but no auth_endpoints
        // cache. Use PATH which is always set in any Unix environment.
        let json = serde_json::json!({
            "auth": {
                "discovery_url":       "https://example.com",
                "client_id":           "x",
                "client_secret_env":   "PATH",
                "server_redirect_uri": "https://api.example.com/auth/callback"
            }
            // no "auth_endpoints" — this is what we're testing
        });
        assert!(
            OidcServerClient::from_compiled_schema(&json).is_none(),
            "missing auth_endpoints must return None"
        );
    }

    #[test]
    fn test_debug_redacts_client_secret() {
        let client = test_client();
        let debug_str = format!("{client:?}");
        assert!(
            !debug_str.contains("test-secret"),
            "Debug output must not expose the client secret: {debug_str}"
        );
        assert!(debug_str.contains("[REDACTED]"));
    }
}