use crate::error::{Error, ErrorCode, Result};
use crate::server::auth::oauth2::OidcDiscoveryMetadata;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// HTTP client that fetches an issuer's OpenID Connect discovery document
/// (`<issuer>/.well-known/openid-configuration`), retrying some failures.
#[derive(Debug)]
pub struct OidcDiscoveryClient {
    /// HTTP client shared by every discovery request.
    client: reqwest::Client,
    /// Attempt budget for `discover`. It counts the initial request, so `3` means at most
    /// three requests in total rather than three additional retries.
    max_retries: usize,
    /// Pause inserted between consecutive attempts.
    retry_delay: Duration,
}

impl Default for OidcDiscoveryClient {
    /// Same as [`OidcDiscoveryClient::new`].
    fn default() -> Self {
        OidcDiscoveryClient::new()
    }
}
impl OidcDiscoveryClient {
    /// Creates a client that makes up to 3 attempts, 500 ms apart.
    pub fn new() -> Self {
        Self {
            client: reqwest::Client::new(),
            max_retries: 3,
            retry_delay: Duration::from_millis(500),
        }
    }

    /// Creates a client with a custom attempt budget and inter-attempt delay.
    ///
    /// `max_retries` counts the initial request (it is a total attempt count). A value of
    /// `0` is treated as `1`, so discovery always contacts the issuer at least once.
    pub fn with_settings(max_retries: usize, retry_delay: Duration) -> Self {
        Self {
            client: reqwest::Client::new(),
            max_retries,
            retry_delay,
        }
    }

    /// Fetches and parses `<issuer_url>/.well-known/openid-configuration`.
    ///
    /// Trailing slashes on `issuer_url` are removed before the well-known path is appended.
    /// A failed attempt is retried (after `retry_delay`) only while the attempt budget is not
    /// exhausted and [`Self::should_retry`] accepts the error.
    ///
    /// # Errors
    ///
    /// Returns the error from the last attempt made, as produced by `fetch_discovery`.
    pub async fn discover(&self, issuer_url: &str) -> Result<OidcDiscoveryMetadata> {
        let discovery_url = format!(
            "{}/.well-known/openid-configuration",
            issuer_url.trim_end_matches('/')
        );

        // Clamp to at least one attempt. Without it, `max_retries == 0` would fail
        // discovery without ever sending a request.
        let max_attempts = self.max_retries.max(1);
        let mut attempt = 1;
        loop {
            match self.fetch_discovery(&discovery_url).await {
                Ok(metadata) => return Ok(metadata),
                Err(e) if attempt < max_attempts && self.should_retry(&e) => {
                    attempt += 1;
                    tokio::time::sleep(self.retry_delay).await;
                },
                Err(e) => return Err(e),
            }
        }
    }

    /// Performs a single GET of the discovery document and decodes it as JSON.
    ///
    /// Transport failures and non-success statuses map to `INTERNAL_ERROR`. Body decoding
    /// failures map to `PARSE_ERROR`.
    async fn fetch_discovery(&self, url: &str) -> Result<OidcDiscoveryMetadata> {
        let response = self
            .client
            .get(url)
            .header("Accept", "application/json")
            .send()
            .await
            .map_err(|e| {
                Error::protocol(
                    ErrorCode::INTERNAL_ERROR,
                    format!("Failed to fetch discovery document: {}", e),
                )
            })?;

        let status = response.status();
        if !status.is_success() {
            return Err(Error::protocol(
                ErrorCode::INTERNAL_ERROR,
                format!("Discovery endpoint returned status: {}", status),
            ));
        }

        response.json::<OidcDiscoveryMetadata>().await.map_err(|e| {
            Error::protocol(
                ErrorCode::PARSE_ERROR,
                format!("Failed to parse discovery document: {}", e),
            )
        })
    }

    /// Decides whether a failed discovery attempt is worth repeating.
    ///
    /// Returns `true` for `Error::Timeout`. For any other error, it returns `true` when the
    /// error's `Display` text contains `CORS`, `network`, `timeout` or `connection`
    /// (case-sensitive substring match).
    ///
    /// NOTE(review): a non-success status renders as "Discovery endpoint returned status: …",
    /// which matches none of these substrings, so e.g. a 503 is never retried. Confirm this
    /// is intended.
    /// NOTE(review): transport errors are retried only if the rendered reqwest error text
    /// happens to contain one of the substrings. Verify against the reqwest version in use.
    // Kept as a method (rather than an associated fn) so retry policy can later depend on
    // client configuration without an API change.
    #[allow(clippy::unused_self)]
    fn should_retry(&self, error: &Error) -> bool {
        if matches!(error, Error::Timeout(_)) {
            return true;
        }
        let error_str = error.to_string();
        error_str.contains("CORS")
            || error_str.contains("network")
            || error_str.contains("timeout")
            || error_str.contains("connection")
    }
}
/// Successful reply from an OAuth 2.0 token endpoint.
///
/// Optional members are left out of the serialized JSON when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// Token type reported by the server (e.g. `Bearer`).
    pub token_type: String,
    /// Access-token lifetime, if the server reported one. RFC 6749 §5.1 defines it in seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_in: Option<u64>,
    /// Refresh token, if one was issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Granted scope string, if the server returned one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
}
/// Client for OAuth 2.0 token-endpoint calls: authorization-code exchange and refresh.
#[derive(Debug)]
pub struct TokenExchangeClient {
    /// HTTP client shared by every token request.
    client: reqwest::Client,
}

impl Default for TokenExchangeClient {
    /// Same as [`TokenExchangeClient::new`].
    fn default() -> Self {
        TokenExchangeClient::new()
    }
}
impl TokenExchangeClient {
pub fn new() -> Self {
Self {
client: reqwest::Client::new(),
}
}
pub async fn exchange_code(
&self,
token_endpoint: &str,
code: &str,
client_id: &str,
client_secret: Option<&str>,
redirect_uri: &str,
code_verifier: Option<&str>,
) -> Result<TokenResponse> {
let mut params = vec![
("grant_type", "authorization_code"),
("code", code),
("client_id", client_id),
("redirect_uri", redirect_uri),
];
if let Some(verifier) = code_verifier {
params.push(("code_verifier", verifier));
}
let mut request = self.client
.post(token_endpoint)
.header("Accept", "application/json") .form(¶ms);
if let Some(secret) = client_secret {
request = request.basic_auth(client_id, Some(secret));
}
let response = request.send().await.map_err(|e| {
Error::protocol(
ErrorCode::INTERNAL_ERROR,
format!("Failed to exchange authorization code: {}", e),
)
})?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(Error::protocol(
ErrorCode::INVALID_REQUEST,
format!("Token exchange failed: {}", error_text),
));
}
response.json::<TokenResponse>().await.map_err(|e| {
Error::protocol(
ErrorCode::PARSE_ERROR,
format!("Failed to parse token response: {}", e),
)
})
}
pub async fn refresh_token(
&self,
token_endpoint: &str,
refresh_token: &str,
client_id: &str,
client_secret: Option<&str>,
scope: Option<&str>,
) -> Result<TokenResponse> {
let mut params = vec![
("grant_type", "refresh_token"),
("refresh_token", refresh_token),
("client_id", client_id),
];
if let Some(s) = scope {
params.push(("scope", s));
}
let mut request = self.client
.post(token_endpoint)
.header("Accept", "application/json") .form(¶ms);
if let Some(secret) = client_secret {
request = request.basic_auth(client_id, Some(secret));
}
let response = request.send().await.map_err(|e| {
Error::protocol(
ErrorCode::INTERNAL_ERROR,
format!("Failed to refresh token: {}", e),
)
})?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(Error::protocol(
ErrorCode::INVALID_REQUEST,
format!("Token refresh failed: {}", error_text),
));
}
response.json::<TokenResponse>().await.map_err(|e| {
Error::protocol(
ErrorCode::PARSE_ERROR,
format!("Failed to parse token response: {}", e),
)
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_discovery_url_construction() {
        // Smoke check: constructing a client must not panic.
        let _client = OidcDiscoveryClient::new();
        let cases = [
            ("https://example.com", "https://example.com/.well-known/openid-configuration"),
            ("https://example.com/", "https://example.com/.well-known/openid-configuration"),
            (
                "https://auth.example.com/oauth",
                "https://auth.example.com/oauth/.well-known/openid-configuration",
            ),
        ];
        for &(issuer, expected) in &cases {
            let built = format!(
                "{}/.well-known/openid-configuration",
                issuer.trim_end_matches('/')
            );
            assert_eq!(built, expected, "issuer = {}", issuer);
        }
    }

    #[test]
    fn test_should_retry_logic() {
        let client = OidcDiscoveryClient::new();
        let retryable = [
            Error::protocol(ErrorCode::INTERNAL_ERROR, "CORS policy blocked the request"),
            Error::protocol(ErrorCode::INTERNAL_ERROR, "network connection failed"),
            Error::Timeout(5000),
        ];
        for err in &retryable {
            assert!(client.should_retry(err), "expected a retry for: {}", err);
        }
        let parse_error = Error::protocol(ErrorCode::PARSE_ERROR, "Invalid JSON");
        assert!(!client.should_retry(&parse_error));
    }

    #[test]
    fn test_discovery_client_with_settings() {
        let delay = Duration::from_secs(2);
        let client = OidcDiscoveryClient::with_settings(5, delay);
        assert_eq!(client.max_retries, 5);
        assert_eq!(client.retry_delay, delay);
    }

    #[test]
    fn test_token_response_serialization() {
        let original = TokenResponse {
            access_token: String::from("test_token"),
            token_type: String::from("Bearer"),
            expires_in: Some(3600),
            refresh_token: Some(String::from("refresh_token")),
            scope: Some(String::from("openid profile")),
        };
        let json = serde_json::to_string(&original).expect("TokenResponse should serialize");
        for needle in ["test_token", "Bearer"] {
            assert!(json.contains(needle));
        }
        let round_tripped: TokenResponse =
            serde_json::from_str(&json).expect("serialized TokenResponse should deserialize");
        assert_eq!(round_tripped.access_token, "test_token");
        assert_eq!(round_tripped.expires_in, Some(3600));
    }

    #[test]
    fn test_oidc_discovery_metadata_defaults() {
        // Minimal document: optional endpoints are absent and must decode as `None`.
        let json = r#"{
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code"],
"scopes_supported": ["openid", "profile"],
"token_endpoint_auth_methods_supported": ["client_secret_basic"],
"code_challenge_methods_supported": ["S256"]
}"#;
        let metadata: OidcDiscoveryMetadata =
            serde_json::from_str(json).expect("minimal discovery document should parse");
        assert_eq!(metadata.issuer, "https://auth.example.com");
        assert!(metadata.jwks_uri.is_none());
        assert!(metadata.userinfo_endpoint.is_none());
    }
}