1use super::auth::McpAuth;
7use adk_core::{AdkError, Result};
8use std::collections::HashMap;
9use std::time::Duration;
10
11#[derive(Clone)]
38pub struct McpHttpClientBuilder {
39 endpoint: String,
41 auth: McpAuth,
43 timeout: Duration,
45 headers: HashMap<String, String>,
47}
48
49impl McpHttpClientBuilder {
50 pub fn new(endpoint: impl Into<String>) -> Self {
56 Self {
57 endpoint: endpoint.into(),
58 auth: McpAuth::None,
59 timeout: Duration::from_secs(30),
60 headers: HashMap::new(),
61 }
62 }
63
64 pub fn with_auth(mut self, auth: McpAuth) -> Self {
73 self.auth = auth;
74 self
75 }
76
77 pub fn timeout(mut self, timeout: Duration) -> Self {
81 self.timeout = timeout;
82 self
83 }
84
85 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
87 self.headers.insert(key.into(), value.into());
88 self
89 }
90
91 pub fn endpoint(&self) -> &str {
93 &self.endpoint
94 }
95
96 pub fn get_timeout(&self) -> Duration {
98 self.timeout
99 }
100
101 pub fn get_auth(&self) -> &McpAuth {
103 &self.auth
104 }
105
106 #[cfg(feature = "http-transport")]
118 pub async fn connect(
119 self,
120 ) -> Result<super::McpToolset<impl rmcp::service::Service<rmcp::RoleClient>>> {
121 use rmcp::ServiceExt;
122 use rmcp::transport::streamable_http_client::{
123 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
124 };
125
126 let token = match &self.auth {
129 McpAuth::Bearer(token) => Some(token.clone()),
130 McpAuth::OAuth2(config) => {
131 let token = config
133 .get_or_refresh_token()
134 .await
135 .map_err(|e| AdkError::Tool(format!("OAuth2 authentication failed: {}", e)))?;
136 Some(token)
137 }
138 McpAuth::ApiKey { .. } => {
139 None
142 }
143 McpAuth::None => None,
144 };
145
146 let mut config = StreamableHttpClientTransportConfig::with_uri(self.endpoint.as_str());
148
149 if let Some(token) = token {
151 config = config.auth_header(token);
152 }
153
154 let transport = StreamableHttpClientTransport::from_config(config);
156
157 let client = ()
159 .serve(transport)
160 .await
161 .map_err(|e| AdkError::Tool(format!("Failed to connect to MCP server: {}", e)))?;
162
163 Ok(super::McpToolset::new(client))
164 }
165
166 #[cfg(not(feature = "http-transport"))]
168 pub async fn connect(self) -> Result<()> {
169 Err(AdkError::Tool(
170 "HTTP transport requires the 'http-transport' feature. \
171 Add `adk-tool = { features = [\"http-transport\"] }` to your Cargo.toml"
172 .to_string(),
173 ))
174 }
175}
176
177impl std::fmt::Debug for McpHttpClientBuilder {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("McpHttpClientBuilder")
180 .field("endpoint", &self.endpoint)
181 .field("auth", &self.auth)
182 .field("timeout", &self.timeout)
183 .field("headers", &self.headers.keys().collect::<Vec<_>>())
184 .finish()
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_new() {
        let builder = McpHttpClientBuilder::new("https://mcp.example.com");
        assert_eq!(builder.endpoint(), "https://mcp.example.com");
        assert_eq!(builder.get_timeout(), Duration::from_secs(30));
    }

    #[test]
    fn test_builder_with_auth() {
        let builder = McpHttpClientBuilder::new("https://mcp.example.com")
            .with_auth(McpAuth::bearer("test-token"));
        assert!(builder.get_auth().is_configured());
    }

    #[test]
    fn test_builder_timeout() {
        let builder =
            McpHttpClientBuilder::new("https://mcp.example.com").timeout(Duration::from_secs(60));
        assert_eq!(builder.get_timeout(), Duration::from_secs(60));
    }

    #[test]
    fn test_builder_headers() {
        let builder =
            McpHttpClientBuilder::new("https://mcp.example.com").header("X-Custom", "value");
        assert!(builder.headers.contains_key("X-Custom"));
    }

    /// Setting the same header twice keeps only the most recent value.
    #[test]
    fn test_builder_header_overwrites_existing_key() {
        let builder = McpHttpClientBuilder::new("https://mcp.example.com")
            .header("X-Custom", "first")
            .header("X-Custom", "second");
        assert_eq!(builder.headers.len(), 1);
        assert_eq!(
            builder.headers.get("X-Custom").map(String::as_str),
            Some("second")
        );
    }

    /// Chained configuration survives intact, including through `Clone`.
    #[test]
    fn test_builder_chaining_and_clone() {
        let builder = McpHttpClientBuilder::new("https://mcp.example.com")
            .with_auth(McpAuth::bearer("test-token"))
            .timeout(Duration::from_secs(5))
            .header("X-A", "1");
        let cloned = builder.clone();
        for b in [&builder, &cloned] {
            assert_eq!(b.endpoint(), "https://mcp.example.com");
            assert_eq!(b.get_timeout(), Duration::from_secs(5));
            assert!(b.get_auth().is_configured());
            assert_eq!(b.headers.get("X-A").map(String::as_str), Some("1"));
        }
    }

    /// `Debug` output lists header names but never header values.
    #[test]
    fn test_debug_omits_header_values() {
        let builder = McpHttpClientBuilder::new("https://mcp.example.com")
            .header("X-Secret-Header", "super-secret-value");
        let rendered = format!("{:?}", builder);
        assert!(rendered.contains("X-Secret-Header"));
        assert!(!rendered.contains("super-secret-value"));
    }
}