pub mod callback_server;
pub mod cimd;
pub mod discovery;
pub mod flow;
pub mod token_store;
pub use callback_server::{start_callback_server, OAuthCallbackResult};
pub use cimd::{resolve_client_id, stop_cimd_server};
pub use discovery::{
clear_all_discovered_oauth_cache, clear_discovered_oauth_cache, discover_oauth_from_mcp_server,
};
pub use flow::{
build_authorization_url, exchange_code_for_token, generate_pkce_pair, generate_state,
is_token_expired, refresh_access_token, OAuthTokenResponse, PkcePair,
};
pub use token_store::{
clear_token, get_valid_token, load_token, save_token, TokenMetadata, TokenResult,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use url::Url;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct OAuthConfig {
pub client_id: String,
#[serde(default)]
pub client_secret: String,
pub authorization_url: String,
pub token_url: String,
pub callback_url: String,
#[serde(default)]
pub scopes: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub state: Option<String>,
#[serde(default = "default_refresh_buffer")]
pub refresh_buffer_seconds: u64,
}
fn default_refresh_buffer() -> u64 {
300
}
impl OAuthConfig {
pub fn new(
client_id: String,
client_secret: String,
authorization_url: String,
token_url: String,
callback_url: String,
scopes: Vec<String>,
) -> Self {
Self {
client_id,
client_secret,
authorization_url,
token_url,
callback_url,
scopes,
state: None,
refresh_buffer_seconds: default_refresh_buffer(),
}
}
pub fn validate(&self) -> Result<(), String> {
if self.client_id.trim().is_empty() {
return Err("OAuth client_id cannot be empty".to_string());
}
let auth_url = Url::parse(&self.authorization_url).map_err(|e| {
format!(
"OAuth authorization_url is invalid: {}. Must be a valid URL",
e
)
})?;
if auth_url.scheme() != "https"
&& auth_url.host_str() != Some("localhost")
&& auth_url.host_str() != Some("127.0.0.1")
{
return Err(
"OAuth authorization_url must use HTTPS or http://localhost/http://127.0.0.1"
.to_string(),
);
}
let token_url = Url::parse(&self.token_url)
.map_err(|e| format!("OAuth token_url is invalid: {}. Must be a valid URL", e))?;
if token_url.scheme() != "https"
&& token_url.host_str() != Some("localhost")
&& token_url.host_str() != Some("127.0.0.1")
{
return Err(
"OAuth token_url must use HTTPS or http://localhost/http://127.0.0.1".to_string(),
);
}
let callback_url = Url::parse(&self.callback_url)
.map_err(|e| format!("OAuth callback_url is invalid: {}. Must be a valid URL", e))?;
if callback_url.scheme() != "http" && callback_url.scheme() != "https" {
return Err("OAuth callback_url must use HTTP or HTTPS".to_string());
}
for scope in &self.scopes {
if scope.trim().is_empty() {
return Err("OAuth scopes cannot contain empty strings".to_string());
}
}
if self.refresh_buffer_seconds < 60 {
return Err("OAuth refresh_buffer_seconds must be at least 60 seconds".to_string());
}
Ok(())
}
}
lazy_static::lazy_static! {
static ref OAUTH_LOCKS: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>> = Arc::new(Mutex::new(HashMap::new()));
}
pub async fn start_oauth_flow(config: &OAuthConfig) -> Result<OAuthCallbackResult, anyhow::Error> {
config
.validate()
.map_err(|e| anyhow::anyhow!("OAuth config validation failed: {}", e))?;
let pkce = generate_pkce_pair();
let state = generate_state();
let result = start_callback_server(config, state, pkce.code_verifier).await?;
Ok(result)
}
pub async fn get_access_token(
config: &OAuthConfig,
server_name: &str,
force_refresh: bool,
) -> Result<Option<String>, anyhow::Error> {
let lock = {
let mut locks = OAUTH_LOCKS.lock().await;
locks
.entry(server_name.to_string())
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone()
};
let _guard = lock.lock().await;
if !force_refresh {
crate::log_debug!(
"🔍 GET_ACCESS_TOKEN: Checking for existing valid token for server_name='{}'",
server_name
);
let valid_token = get_valid_token(server_name, config.refresh_buffer_seconds)
.await
.map_err(|e| anyhow::anyhow!("Failed to check token: {}", e))?;
if let Some(metadata) = valid_token {
crate::log_debug!(
"✅ GET_ACCESS_TOKEN: Found valid token, token_prefix='{}...'",
metadata.access_token.chars().take(10).collect::<String>()
);
return Ok(Some(metadata.access_token));
} else {
crate::log_debug!("⚠️ GET_ACCESS_TOKEN: No valid token found, starting OAuth flow");
}
}
let result = start_oauth_flow(config).await?;
match result {
OAuthCallbackResult::Success { access_token, .. } => Ok(Some(access_token)),
OAuthCallbackResult::Error { error, description } => Err(anyhow::anyhow!(
"OAuth failed: {} - {}",
error,
description.unwrap_or_default()
)),
OAuthCallbackResult::Cancelled => Ok(None),
OAuthCallbackResult::Timeout => Err(anyhow::anyhow!("OAuth timed out")),
}
}
pub async fn is_authenticated(server_name: &str, refresh_buffer_seconds: u64) -> bool {
get_valid_token(server_name, refresh_buffer_seconds)
.await
.map(|m| m.is_some())
.unwrap_or(false)
}