1use std::{fmt, sync::Arc};
14
15use serde::Deserialize;
16
17#[derive(Debug, Clone, Deserialize)]
24pub struct OidcEndpoints {
25 pub authorization_endpoint: String,
27 pub token_endpoint: String,
29}
30
31#[derive(Debug, Deserialize)]
37pub struct OidcTokenResponse {
38 pub access_token: String,
40 pub id_token: Option<String>,
42 pub expires_in: Option<u64>,
44 pub refresh_token: Option<String>,
46}
47
/// Server-side OIDC client for the PKCE authorization-code flow.
///
/// Holds the client credentials and the provider endpoint URLs. `Debug` is
/// implemented by hand so that `client_secret` is redacted.
pub struct OidcServerClient {
    /// OAuth client identifier, sent in the authorization URL and the token request.
    client_id: String,
    /// Client secret, sent only in the token request.
    client_secret: String,
    /// Redirect URI, sent in both the authorization URL and the token request.
    server_redirect_uri: String,
    /// Provider URL the authorization URL is built from.
    authorization_endpoint: String,
    /// Provider URL the authorization code is exchanged at.
    token_endpoint: String,
}
65
66#[allow(clippy::missing_fields_in_debug)] impl fmt::Debug for OidcServerClient {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 f.debug_struct("OidcServerClient")
71 .field("client_id", &self.client_id)
72 .field("client_secret", &"[REDACTED]")
73 .field("authorization_endpoint", &self.authorization_endpoint)
74 .finish_non_exhaustive()
75 }
76}
77
78impl OidcServerClient {
79 const MAX_OIDC_RESPONSE_BYTES: usize = 1024 * 1024;
84
85 pub fn new(
90 client_id: impl Into<String>,
91 client_secret: impl Into<String>,
92 server_redirect_uri: impl Into<String>,
93 authorization_endpoint: impl Into<String>,
94 token_endpoint: impl Into<String>,
95 ) -> Self {
96 Self {
97 client_id: client_id.into(),
98 client_secret: client_secret.into(),
99 server_redirect_uri: server_redirect_uri.into(),
100 authorization_endpoint: authorization_endpoint.into(),
101 token_endpoint: token_endpoint.into(),
102 }
103 }
104
105 pub fn from_compiled_schema(schema_json: &serde_json::Value) -> Option<Arc<Self>> {
115 #[derive(Deserialize)]
117 struct AuthCfg {
118 client_id: String,
119 client_secret_env: String,
120 server_redirect_uri: String,
121 }
122
123 let auth_cfg: AuthCfg =
124 schema_json.get("auth").and_then(|v| serde_json::from_value(v.clone()).ok())?;
125
126 let Ok(client_secret) = std::env::var(&auth_cfg.client_secret_env) else {
128 tracing::error!(
129 env_var = %auth_cfg.client_secret_env,
130 "PKCE init failed: env var for OIDC client secret is not set"
131 );
132 return None;
133 };
134
135 let Some(endpoints): Option<OidcEndpoints> = schema_json
137 .get("auth_endpoints")
138 .and_then(|v| serde_json::from_value(v.clone()).ok())
139 else {
140 tracing::error!(
141 "PKCE init failed: 'auth_endpoints' not found in compiled schema. \
142 Re-compile the schema so that the CLI caches the OIDC discovery \
143 document (authorization_endpoint, token_endpoint)."
144 );
145 return None;
146 };
147
148 Some(Arc::new(Self {
149 client_id: auth_cfg.client_id,
150 client_secret,
151 server_redirect_uri: auth_cfg.server_redirect_uri,
152 authorization_endpoint: endpoints.authorization_endpoint,
153 token_endpoint: endpoints.token_endpoint,
154 }))
155 }
156
157 pub fn authorization_url(
163 &self,
164 state: &str,
165 code_challenge: &str,
166 code_challenge_method: &str,
167 ) -> String {
168 format!(
169 "{}?response_type=code\
170 &client_id={}\
171 &redirect_uri={}\
172 &scope=openid%20email%20profile\
173 &state={}\
174 &code_challenge={}\
175 &code_challenge_method={}",
176 self.authorization_endpoint,
177 urlencoding::encode(&self.client_id),
178 urlencoding::encode(&self.server_redirect_uri),
179 urlencoding::encode(state),
180 urlencoding::encode(code_challenge),
181 code_challenge_method,
182 )
183 }
184
185 pub async fn exchange_code(
198 &self,
199 code: &str,
200 code_verifier: &str,
201 http: &reqwest::Client,
202 ) -> Result<OidcTokenResponse, anyhow::Error> {
203 let resp = http
204 .post(&self.token_endpoint)
205 .form(&[
206 ("grant_type", "authorization_code"),
207 ("code", code),
208 ("code_verifier", code_verifier),
209 ("redirect_uri", self.server_redirect_uri.as_str()),
210 ("client_id", self.client_id.as_str()),
211 ("client_secret", self.client_secret.as_str()),
212 ])
213 .send()
214 .await?;
215
216 let status = resp.status();
217
218 let body_bytes = resp
221 .bytes()
222 .await
223 .map_err(|e| anyhow::anyhow!("Failed to read token response: {e}"))?;
224
225 anyhow::ensure!(
228 body_bytes.len() <= Self::MAX_OIDC_RESPONSE_BYTES,
229 "OIDC token response too large ({} bytes, max {})",
230 body_bytes.len(),
231 Self::MAX_OIDC_RESPONSE_BYTES
232 );
233
234 if !status.is_success() {
235 let body = String::from_utf8_lossy(&body_bytes);
237 anyhow::bail!("token endpoint returned {status}: {body}");
238 }
239
240 Ok(serde_json::from_slice::<OidcTokenResponse>(&body_bytes)?)
241 }
242}
243
// `unwrap` is acceptable in tests: a panic is the desired failure mode.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    use super::*;

    fn test_client() -> OidcServerClient {
        OidcServerClient::new(
            "test-client",
            "test-secret",
            "https://api.example.com/auth/callback",
            "https://provider.example.com/authorize",
            "https://provider.example.com/token",
        )
    }

    #[test]
    fn test_authorization_url_contains_required_pkce_params() {
        let client = test_client();
        let url = client.authorization_url("my_state", "my_challenge", "S256");
        assert!(url.contains("response_type=code"), "missing response_type");
        assert!(url.contains("client_id=test-client"), "missing client_id");
        assert!(url.contains("code_challenge=my_challenge"), "missing code_challenge");
        assert!(url.contains("code_challenge_method=S256"), "missing method");
        assert!(url.contains("state="), "missing state");
        assert!(url.contains("redirect_uri="), "missing redirect_uri");
    }

    #[test]
    fn oidc_response_cap_constant_is_reasonable() {
        assert_eq!(OidcServerClient::MAX_OIDC_RESPONSE_BYTES, 1024 * 1024);
    }

    #[test]
    fn oidc_response_cap_covers_error_path() {
        // Compile-time bounds: big enough for real token/error responses,
        // small enough to actually protect memory.
        const { assert!(OidcServerClient::MAX_OIDC_RESPONSE_BYTES >= 64 * 1024) }
        const { assert!(OidcServerClient::MAX_OIDC_RESPONSE_BYTES <= 100 * 1024 * 1024) }
    }

    #[tokio::test]
    async fn oidc_oversized_error_response_is_rejected() {
        use wiremock::{
            Mock, MockServer, ResponseTemplate,
            matchers::{method, path},
        };

        let mock_server = MockServer::start().await;
        let oversized = vec![b'e'; OidcServerClient::MAX_OIDC_RESPONSE_BYTES + 1];
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(400).set_body_bytes(oversized))
            .mount(&mock_server)
            .await;

        let client = OidcServerClient::new(
            "client_id",
            "client_secret",
            "https://example.com/callback",
            "https://example.com/auth",
            format!("{}/token", mock_server.uri()),
        );
        let http = reqwest::Client::new();
        let result = client.exchange_code("code", "verifier", &http).await;

        assert!(result.is_err(), "oversized error response must be rejected");
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("too large"), "error must mention size limit, got: {msg}");
    }

    #[tokio::test]
    async fn oidc_oversized_success_response_is_rejected() {
        use wiremock::{
            Mock, MockServer, ResponseTemplate,
            matchers::{method, path},
        };

        let mock_server = MockServer::start().await;
        let oversized = vec![b'x'; OidcServerClient::MAX_OIDC_RESPONSE_BYTES + 1];
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(oversized))
            .mount(&mock_server)
            .await;

        let client = OidcServerClient::new(
            "client_id",
            "client_secret",
            "https://example.com/callback",
            "https://example.com/auth",
            format!("{}/token", mock_server.uri()),
        );
        let http = reqwest::Client::new();
        let result = client.exchange_code("code", "verifier", &http).await;

        assert!(result.is_err(), "oversized success response must be rejected, got: {result:?}");
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("too large"), "error must mention size limit, got: {msg}");
    }

    #[test]
    fn test_authorization_url_includes_openid_scope() {
        let client = test_client();
        let url = client.authorization_url("s", "c", "S256");
        assert!(url.contains("openid"), "authorization URL must request the openid scope: {url}");
    }

    #[test]
    fn test_authorization_url_state_is_percent_encoded() {
        let client = test_client();
        let state_with_spaces = "hello world+test";
        let url = client.authorization_url(state_with_spaces, "challenge", "S256");
        let state_segment = url.split("state=").nth(1).unwrap().split('&').next().unwrap();
        assert!(!state_segment.contains(' '), "space in state must be percent-encoded");
        assert!(!state_segment.contains('+'), "plus in state must be percent-encoded");
    }

    #[test]
    fn test_from_compiled_schema_absent_auth_returns_none() {
        let json = serde_json::json!({});
        assert!(OidcServerClient::from_compiled_schema(&json).is_none());
    }

    #[test]
    fn test_from_compiled_schema_missing_env_var_returns_none() {
        // `auth_endpoints` is valid here, so the only failure is the unset env var.
        let json = serde_json::json!({
            "auth": {
                "discovery_url": "https://example.com",
                "client_id": "x",
                "client_secret_env": "__FRAISEQL_TEST_DEFINITELY_UNSET_42XYZ__",
                "server_redirect_uri": "https://api.example.com/auth/callback"
            },
            "auth_endpoints": {
                "authorization_endpoint": "https://example.com/auth",
                "token_endpoint": "https://example.com/token"
            }
        });
        assert!(
            OidcServerClient::from_compiled_schema(&json).is_none(),
            "unset client-secret env var must return None"
        );
    }

    #[test]
    fn test_from_compiled_schema_missing_endpoints_returns_none() {
        // `PATH` is used as the secret env var because it is virtually always set,
        // so the only failure is the missing `auth_endpoints`.
        let json = serde_json::json!({
            "auth": {
                "discovery_url": "https://example.com",
                "client_id": "x",
                "client_secret_env": "PATH",
                "server_redirect_uri": "https://api.example.com/auth/callback"
            }
        });
        assert!(
            OidcServerClient::from_compiled_schema(&json).is_none(),
            "missing auth_endpoints must return None"
        );
    }

    #[test]
    fn test_debug_redacts_client_secret() {
        let client = test_client();
        let debug_str = format!("{client:?}");
        assert!(
            !debug_str.contains("test-secret"),
            "Debug output must not expose the client secret: {debug_str}"
        );
        assert!(debug_str.contains("[REDACTED]"));
    }
}
432}