use std::{fmt, sync::Arc};
use serde::Deserialize;
/// Provider endpoints read from the compiled schema's `auth_endpoints` section,
/// which the CLI fills in from the OIDC discovery document.
#[derive(Clone, Debug, Deserialize)]
pub struct OidcEndpoints {
    /// Base URL used when building the authorization (login) URL.
    pub authorization_endpoint: String,
    /// URL the authorization code is POSTed to in exchange for tokens.
    pub token_endpoint: String,
}
#[derive(Debug, Deserialize)]
pub struct OidcTokenResponse {
pub access_token: String,
pub id_token: Option<String>,
pub expires_in: Option<u64>,
pub refresh_token: Option<String>,
}
/// Server-side OIDC client: builds PKCE authorization URLs and redeems
/// authorization codes at the provider's token endpoint.
pub struct OidcServerClient {
    client_id: String,
    /// Confidential client secret; never printed by `Debug`.
    client_secret: String,
    /// Redirect URI sent in both the authorization request and the token request.
    server_redirect_uri: String,
    authorization_endpoint: String,
    token_endpoint: String,
}

// `client_secret` is intentionally not read below (a placeholder is printed
// instead), which the pedantic `missing_fields_in_debug` lint may flag.
#[allow(clippy::missing_fields_in_debug)]
impl fmt::Debug for OidcServerClient {
    /// Prints every field except the client secret, which is redacted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OidcServerClient")
            .field("client_id", &self.client_id)
            .field("client_secret", &"[REDACTED]")
            .field("server_redirect_uri", &self.server_redirect_uri)
            .field("authorization_endpoint", &self.authorization_endpoint)
            .field("token_endpoint", &self.token_endpoint)
            .finish()
    }
}
impl OidcServerClient {
const MAX_OIDC_RESPONSE_BYTES: usize = 1024 * 1024;
pub fn new(
client_id: impl Into<String>,
client_secret: impl Into<String>,
server_redirect_uri: impl Into<String>,
authorization_endpoint: impl Into<String>,
token_endpoint: impl Into<String>,
) -> Self {
Self {
client_id: client_id.into(),
client_secret: client_secret.into(),
server_redirect_uri: server_redirect_uri.into(),
authorization_endpoint: authorization_endpoint.into(),
token_endpoint: token_endpoint.into(),
}
}
pub fn from_compiled_schema(schema_json: &serde_json::Value) -> Option<Arc<Self>> {
#[derive(Deserialize)]
struct AuthCfg {
client_id: String,
client_secret_env: String,
server_redirect_uri: String,
}
let auth_cfg: AuthCfg =
schema_json.get("auth").and_then(|v| serde_json::from_value(v.clone()).ok())?;
let Ok(client_secret) = std::env::var(&auth_cfg.client_secret_env) else {
tracing::error!(
env_var = %auth_cfg.client_secret_env,
"PKCE init failed: env var for OIDC client secret is not set"
);
return None;
};
let Some(endpoints): Option<OidcEndpoints> = schema_json
.get("auth_endpoints")
.and_then(|v| serde_json::from_value(v.clone()).ok())
else {
tracing::error!(
"PKCE init failed: 'auth_endpoints' not found in compiled schema. \
Re-compile the schema so that the CLI caches the OIDC discovery \
document (authorization_endpoint, token_endpoint)."
);
return None;
};
Some(Arc::new(Self {
client_id: auth_cfg.client_id,
client_secret,
server_redirect_uri: auth_cfg.server_redirect_uri,
authorization_endpoint: endpoints.authorization_endpoint,
token_endpoint: endpoints.token_endpoint,
}))
}
pub fn authorization_url(
&self,
state: &str,
code_challenge: &str,
code_challenge_method: &str,
) -> String {
format!(
"{}?response_type=code\
&client_id={}\
&redirect_uri={}\
&scope=openid%20email%20profile\
&state={}\
&code_challenge={}\
&code_challenge_method={}",
self.authorization_endpoint,
urlencoding::encode(&self.client_id),
urlencoding::encode(&self.server_redirect_uri),
urlencoding::encode(state),
urlencoding::encode(code_challenge),
code_challenge_method,
)
}
pub async fn exchange_code(
&self,
code: &str,
code_verifier: &str,
http: &reqwest::Client,
) -> Result<OidcTokenResponse, anyhow::Error> {
let resp = http
.post(&self.token_endpoint)
.form(&[
("grant_type", "authorization_code"),
("code", code),
("code_verifier", code_verifier),
("redirect_uri", self.server_redirect_uri.as_str()),
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
])
.send()
.await?;
let status = resp.status();
let body_bytes = resp
.bytes()
.await
.map_err(|e| anyhow::anyhow!("Failed to read token response: {e}"))?;
anyhow::ensure!(
body_bytes.len() <= Self::MAX_OIDC_RESPONSE_BYTES,
"OIDC token response too large ({} bytes, max {})",
body_bytes.len(),
Self::MAX_OIDC_RESPONSE_BYTES
);
if !status.is_success() {
let body = String::from_utf8_lossy(&body_bytes);
anyhow::bail!("token endpoint returned {status}: {body}");
}
Ok(serde_json::from_slice::<OidcTokenResponse>(&body_bytes)?)
}
}
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    use super::*;

    /// Client with fixed, non-secret test configuration.
    fn test_client() -> OidcServerClient {
        OidcServerClient::new(
            "test-client",
            "test-secret",
            "https://api.example.com/auth/callback",
            "https://provider.example.com/authorize",
            "https://provider.example.com/token",
        )
    }

    #[test]
    fn test_authorization_url_contains_required_pkce_params() {
        let client = test_client();
        let url = client.authorization_url("my_state", "my_challenge", "S256");
        assert!(url.contains("response_type=code"), "missing response_type");
        assert!(url.contains("client_id=test-client"), "missing client_id");
        assert!(url.contains("code_challenge=my_challenge"), "missing code_challenge");
        assert!(url.contains("code_challenge_method=S256"), "missing method");
        assert!(url.contains("state="), "missing state");
        assert!(url.contains("redirect_uri="), "missing redirect_uri");
    }

    #[test]
    fn oidc_response_cap_constant_is_reasonable() {
        assert_eq!(OidcServerClient::MAX_OIDC_RESPONSE_BYTES, 1024 * 1024);
    }

    #[test]
    fn oidc_response_cap_covers_error_path() {
        const { assert!(OidcServerClient::MAX_OIDC_RESPONSE_BYTES >= 64 * 1024) }
        const { assert!(OidcServerClient::MAX_OIDC_RESPONSE_BYTES <= 100 * 1024 * 1024) }
    }

    #[tokio::test]
    async fn oidc_oversized_error_response_is_rejected() {
        use wiremock::{
            Mock, MockServer, ResponseTemplate,
            matchers::{method, path},
        };
        let mock_server = MockServer::start().await;
        let oversized = vec![b'e'; OidcServerClient::MAX_OIDC_RESPONSE_BYTES + 1];
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(400).set_body_bytes(oversized))
            .mount(&mock_server)
            .await;
        let client = OidcServerClient::new(
            "client_id",
            "client_secret",
            "https://example.com/callback",
            "https://example.com/auth",
            format!("{}/token", mock_server.uri()),
        );
        let http = reqwest::Client::new();
        let result = client.exchange_code("code", "verifier", &http).await;
        assert!(result.is_err(), "oversized error response must be rejected");
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("too large"), "error must mention size limit, got: {msg}");
    }

    #[tokio::test]
    async fn oidc_oversized_success_response_is_rejected() {
        use wiremock::{
            Mock, MockServer, ResponseTemplate,
            matchers::{method, path},
        };
        let mock_server = MockServer::start().await;
        let oversized = vec![b'x'; OidcServerClient::MAX_OIDC_RESPONSE_BYTES + 1];
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(oversized))
            .mount(&mock_server)
            .await;
        let client = OidcServerClient::new(
            "client_id",
            "client_secret",
            "https://example.com/callback",
            "https://example.com/auth",
            format!("{}/token", mock_server.uri()),
        );
        let http = reqwest::Client::new();
        let result = client.exchange_code("code", "verifier", &http).await;
        assert!(result.is_err(), "oversized success response must be rejected, got: {result:?}");
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("too large"), "error must mention size limit, got: {msg}");
    }

    #[test]
    fn test_authorization_url_includes_openid_scope() {
        let client = test_client();
        let url = client.authorization_url("s", "c", "S256");
        assert!(url.contains("openid"), "authorization URL must request the openid scope: {url}");
    }

    #[test]
    fn test_authorization_url_state_is_percent_encoded() {
        let client = test_client();
        let state_with_spaces = "hello world+test";
        let url = client.authorization_url(state_with_spaces, "challenge", "S256");
        let state_segment = url.split("state=").nth(1).unwrap().split('&').next().unwrap();
        assert!(!state_segment.contains(' '), "space in state must be percent-encoded");
        assert!(!state_segment.contains('+'), "plus in state must be percent-encoded");
    }

    #[test]
    fn test_from_compiled_schema_absent_auth_returns_none() {
        let json = serde_json::json!({});
        assert!(OidcServerClient::from_compiled_schema(&json).is_none());
    }

    #[test]
    fn test_from_compiled_schema_null_auth_returns_none() {
        let json = serde_json::json!({ "auth": null });
        assert!(OidcServerClient::from_compiled_schema(&json).is_none());
    }

    #[test]
    fn test_from_compiled_schema_missing_env_var_returns_none() {
        let json = serde_json::json!({
            "auth": {
                "discovery_url": "https://example.com",
                "client_id": "x",
                "client_secret_env": "__FRAISEQL_TEST_DEFINITELY_UNSET_42XYZ__",
                "server_redirect_uri": "https://api.example.com/auth/callback"
            },
            "auth_endpoints": {
                "authorization_endpoint": "https://example.com/auth",
                "token_endpoint": "https://example.com/token"
            }
        });
        // Previously this result was discarded, so the test could never fail.
        assert!(
            OidcServerClient::from_compiled_schema(&json).is_none(),
            "unset client secret env var must return None"
        );
    }

    #[test]
    fn test_from_compiled_schema_missing_endpoints_returns_none() {
        let json = serde_json::json!({
            "auth": {
                "discovery_url": "https://example.com",
                "client_id": "x",
                "client_secret_env": "PATH",
                "server_redirect_uri": "https://api.example.com/auth/callback"
            }
        });
        assert!(
            OidcServerClient::from_compiled_schema(&json).is_none(),
            "missing auth_endpoints must return None"
        );
    }

    #[test]
    fn test_from_compiled_schema_valid_config_returns_client() {
        // PATH is used as a stand-in secret because it is set (and non-empty)
        // in practically every test environment.
        let json = serde_json::json!({
            "auth": {
                "discovery_url": "https://example.com",
                "client_id": "x",
                "client_secret_env": "PATH",
                "server_redirect_uri": "https://api.example.com/auth/callback"
            },
            "auth_endpoints": {
                "authorization_endpoint": "https://example.com/auth",
                "token_endpoint": "https://example.com/token"
            }
        });
        let client = OidcServerClient::from_compiled_schema(&json)
            .expect("complete configuration must produce a client");
        assert_eq!(client.client_id, "x");
        assert_eq!(client.authorization_endpoint, "https://example.com/auth");
        assert_eq!(client.token_endpoint, "https://example.com/token");
    }

    #[test]
    fn test_debug_redacts_client_secret() {
        let client = test_client();
        let debug_str = format!("{client:?}");
        assert!(
            !debug_str.contains("test-secret"),
            "Debug output must not expose the client secret: {debug_str}"
        );
        assert!(debug_str.contains("[REDACTED]"));
    }
}