ai_agent/services/mcp/
auth.rs1use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
10#[derive(Debug, Clone, Default, Serialize, Deserialize)]
12pub struct AuthConfig {
13 pub enabled: bool,
14 pub auth_type: Option<String>,
15 pub token: Option<String>,
16}
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize)]
20pub struct OAuthConfig {
21 pub client_id: Option<String>,
22 pub client_secret: Option<String>,
23 pub redirect_uri: Option<String>,
24 pub scopes: Vec<String>,
25}
26
27pub fn get_auth_headers(config: &AuthConfig) -> std::collections::HashMap<String, String> {
29 let mut headers = std::collections::HashMap::new();
30
31 if let Some(token) = &config.token {
32 headers.insert("Authorization".to_string(), format!("Bearer {}", token));
33 }
34
35 headers
36}
37
38pub fn is_auth_required(config: &AuthConfig) -> bool {
40 config.enabled && config.auth_type.is_some()
41}
42
43#[derive(Debug, Clone)]
49pub struct McpOAuthResult {
50 pub status: McpOAuthStatus,
52 pub message: String,
54 pub auth_url: Option<String>,
56}
57
/// Coarse outcome of an MCP OAuth attempt, carried in `McpOAuthResult::status`.
///
/// The enum is fieldless, so it derives `Copy`, `Eq` and `Hash` in addition to
/// `PartialEq`. It can be passed by value, compared without cloning, and used
/// as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum McpOAuthStatus {
    /// An authorization URL was produced (presumably for the user to visit — confirm).
    AuthUrl,
    /// Authentication is reported as complete.
    Authenticated,
    /// OAuth is reported as unsupported (presumably for the target server — confirm).
    Unsupported,
    /// The flow reported a failure.
    Error,
}
70
71pub type McpOAuthCallback = Arc<
74 dyn Fn(
75 String,
76 serde_json::Value,
77 Option<Arc<dyn Fn(String) + Send + Sync>>,
78 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<McpOAuthResult, crate::AgentError>> + Send + Sync>>
79 + Send
80 + Sync,
81>;
82
83static MCP_OAUTH_CALLBACK: once_cell::sync::Lazy<parking_lot::RwLock<Option<McpOAuthCallback>>> =
85 once_cell::sync::Lazy::new(Default::default);
86
87pub fn register_mcp_oauth_callback<F, Fut>(callback: F)
92where
93 F: Fn(
94 String,
95 serde_json::Value,
96 Option<Arc<dyn Fn(String) + Send + Sync>>,
97 ) -> Fut + Send + Sync + 'static,
98 Fut: std::future::Future<Output = Result<McpOAuthResult, crate::AgentError>> + Send + Sync + 'static,
99{
100 let wrapped: McpOAuthCallback = Arc::new(
101 move |server: String, config: serde_json::Value, on_url: Option<Arc<dyn Fn(String) + Send + Sync>>| {
102 Box::pin(callback(server, config, on_url))
103 },
104 );
105 *MCP_OAUTH_CALLBACK.write() = Some(wrapped);
106}
107
108pub async fn perform_mcp_oauth_flow(
112 server_name: String,
113 config: serde_json::Value,
114 on_auth_url: Option<Arc<dyn Fn(String) + Send + Sync>>,
115) -> Result<McpOAuthResult, crate::AgentError> {
116 let callback = MCP_OAUTH_CALLBACK.read().clone();
117 match callback {
118 Some(cb) => cb(server_name, config, on_auth_url).await,
119 None => Err(crate::AgentError::Tool(
120 "No MCP OAuth callback registered. Call register_mcp_oauth_callback() to enable OAuth.".to_string(),
121 )),
122 }
123}
124
125