use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Basic authentication settings.
///
/// `Debug` is implemented by hand (below) so that the secret `token` is never
/// written to logs via `{:?}`; only whether a token is present is shown.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Whether authentication is switched on (checked by `is_auth_required`).
    pub enabled: bool,
    /// Name of the authentication scheme; an opaque string in this module.
    pub auth_type: Option<String>,
    /// Credential that `get_auth_headers` sends as a Bearer token.
    pub token: Option<String>,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            .field("enabled", &self.enabled)
            .field("auth_type", &self.auth_type)
            // Redact the credential: report presence only, never the value.
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// OAuth client registration settings.
///
/// `Debug` is implemented by hand (below) so that `client_secret` is never
/// written to logs via `{:?}`; only whether a secret is present is shown.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct OAuthConfig {
    /// OAuth client identifier, if configured.
    pub client_id: Option<String>,
    /// OAuth client secret, if configured (sensitive).
    pub client_secret: Option<String>,
    /// Redirect URI registered for the client, if configured.
    pub redirect_uri: Option<String>,
    /// Requested OAuth scopes.
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("client_id", &self.client_id)
            // Redact the secret: report presence only, never the value.
            .field("client_secret", &self.client_secret.as_ref().map(|_| "<redacted>"))
            .field("redirect_uri", &self.redirect_uri)
            .field("scopes", &self.scopes)
            .finish()
    }
}
/// Builds the HTTP headers needed to authenticate a request described by `config`.
///
/// Returns a map with a single `Authorization: Bearer <token>` entry when
/// `config.token` holds a non-empty value, and an empty map otherwise.
///
/// Only `token` is consulted; `enabled` and `auth_type` are not inspected here,
/// so a Bearer header is produced whenever a token is set.
pub fn get_auth_headers(config: &AuthConfig) -> std::collections::HashMap<String, String> {
    config
        .token
        .as_deref()
        // An empty token would yield the malformed header `Bearer ` (RFC 6750
        // requires at least one token character), so send nothing instead.
        .filter(|token| !token.is_empty())
        .map(|token| ("Authorization".to_string(), format!("Bearer {}", token)))
        .into_iter()
        .collect()
}
pub fn is_auth_required(config: &AuthConfig) -> bool {
config.enabled && config.auth_type.is_some()
}
/// Outcome returned by an MCP OAuth flow (see `perform_mcp_oauth_flow`).
#[derive(Clone, Debug)]
pub struct McpOAuthResult {
    /// Category of the outcome.
    pub status: McpOAuthStatus,
    /// Free-form message accompanying `status` (presumably human-readable).
    pub message: String,
    /// Authorization URL, when the flow produced one. Presumably populated
    /// together with `McpOAuthStatus::AuthUrl`; confirm with callback implementations.
    pub auth_url: Option<String>,
}
/// Category of an MCP OAuth flow outcome, carried in `McpOAuthResult::status`.
///
/// Variant meanings below are inferred from their names; confirm against the
/// registered callback implementations.
///
/// `Eq` and `Hash` are derived so the status can be used as a set member or map
/// key; all variants are fieldless, so the derived impls are trivially correct.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum McpOAuthStatus {
    /// An authorization URL was produced (presumably the user must visit it).
    AuthUrl,
    /// Authentication presumably completed successfully.
    Authenticated,
    /// OAuth is presumably not supported for the target server.
    Unsupported,
    /// The flow presumably failed.
    Error,
}
/// Type-erased handler that runs an MCP OAuth flow.
///
/// Arguments, in order: the MCP server name, the server configuration as raw
/// JSON, and an optional hook that receives a `String` (presumably an
/// authorization URL to surface to the user; confirm with implementations).
///
/// The returned future is required to be `Send` but deliberately *not* `Sync`.
/// A future is only ever driven through `Pin<&mut Self>`, so `Sync` adds nothing.
/// Demanding it would also reject ordinary `async` blocks that hold non-`Sync`
/// values (for example many HTTP client futures) across an `.await`.
pub type McpOAuthCallback = Arc<
    dyn Fn(
            String,
            serde_json::Value,
            Option<Arc<dyn Fn(String) + Send + Sync>>,
        ) -> std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<McpOAuthResult, crate::AgentError>>
                    + Send,
            >,
        > + Send
        + Sync,
>;

/// Process-wide slot holding the currently registered OAuth handler, if any.
/// Starts out empty (`None`).
static MCP_OAUTH_CALLBACK: once_cell::sync::Lazy<parking_lot::RwLock<Option<McpOAuthCallback>>> =
    once_cell::sync::Lazy::new(Default::default);

/// Installs `callback` as the handler used by `perform_mcp_oauth_flow`.
///
/// Any previously registered handler is replaced.
///
/// `callback` may be any closure or `async fn` taking the server name, the JSON
/// config and the optional URL hook, and returning a `Send + 'static` future.
/// The future does not need to be `Sync`.
pub fn register_mcp_oauth_callback<F, Fut>(callback: F)
where
    F: Fn(String, serde_json::Value, Option<Arc<dyn Fn(String) + Send + Sync>>) -> Fut
        + Send
        + Sync
        + 'static,
    Fut: std::future::Future<Output = Result<McpOAuthResult, crate::AgentError>> + Send + 'static,
{
    // Box the concrete future so every registrant fits the single erased type.
    let wrapped: McpOAuthCallback = Arc::new(
        move |server: String,
              config: serde_json::Value,
              on_url: Option<Arc<dyn Fn(String) + Send + Sync>>| {
            Box::pin(callback(server, config, on_url))
        },
    );
    *MCP_OAUTH_CALLBACK.write() = Some(wrapped);
}
/// Runs the registered MCP OAuth handler for `server_name`.
///
/// Forwards `config` and `on_auth_url` unchanged to the handler installed via
/// `register_mcp_oauth_callback` and returns whatever it resolves to.
///
/// # Errors
///
/// Returns `AgentError::Tool` if no handler has been registered. Otherwise it
/// propagates the handler's own result unchanged.
pub async fn perform_mcp_oauth_flow(
    server_name: String,
    config: serde_json::Value,
    on_auth_url: Option<Arc<dyn Fn(String) + Send + Sync>>,
) -> Result<McpOAuthResult, crate::AgentError> {
    // Clone the Arc out within this statement so the read guard is dropped
    // before the handler's future is awaited below.
    let handler = MCP_OAUTH_CALLBACK.read().clone().ok_or_else(|| {
        crate::AgentError::Tool(
            "No MCP OAuth callback registered. Call register_mcp_oauth_callback() to enable OAuth.".to_string(),
        )
    })?;
    handler(server_name, config, on_auth_url).await
}