1use super::auth::McpAuth;
7use super::elicitation::ElicitationHandler;
8use adk_core::{AdkError, Result};
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::Duration;
12
13#[derive(Clone)]
40pub struct McpHttpClientBuilder {
41 endpoint: String,
43 auth: McpAuth,
45 timeout: Duration,
47 headers: HashMap<String, String>,
49 elicitation_handler: Option<Arc<dyn ElicitationHandler>>,
51}
52
53impl McpHttpClientBuilder {
54 pub fn new(endpoint: impl Into<String>) -> Self {
60 Self {
61 endpoint: endpoint.into(),
62 auth: McpAuth::None,
63 timeout: Duration::from_secs(30),
64 headers: HashMap::new(),
65 elicitation_handler: None,
66 }
67 }
68
69 pub fn with_auth(mut self, auth: McpAuth) -> Self {
78 self.auth = auth;
79 self
80 }
81
82 pub fn timeout(mut self, timeout: Duration) -> Self {
86 self.timeout = timeout;
87 self
88 }
89
90 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
92 self.headers.insert(key.into(), value.into());
93 self
94 }
95
96 pub fn with_elicitation_handler(mut self, handler: Arc<dyn ElicitationHandler>) -> Self {
101 self.elicitation_handler = Some(handler);
102 self
103 }
104
105 pub fn endpoint(&self) -> &str {
107 &self.endpoint
108 }
109
110 pub fn get_timeout(&self) -> Duration {
112 self.timeout
113 }
114
115 pub fn get_auth(&self) -> &McpAuth {
117 &self.auth
118 }
119
120 #[cfg(feature = "http-transport")]
132 pub async fn connect(
133 self,
134 ) -> Result<super::McpToolset<impl rmcp::service::Service<rmcp::RoleClient>>> {
135 use adk_core::{ErrorCategory, ErrorComponent};
136 use rmcp::ServiceExt;
137 use rmcp::transport::streamable_http_client::{
138 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
139 };
140
141 let token = match &self.auth {
144 McpAuth::Bearer(token) => Some(token.clone()),
145 McpAuth::OAuth2(config) => {
146 let token = config.get_or_refresh_token().await.map_err(|e| {
148 AdkError::new(
149 ErrorComponent::Tool,
150 ErrorCategory::Unauthorized,
151 "mcp.oauth.token_fetch",
152 format!("OAuth2 authentication failed: {e}"),
153 )
154 })?;
155 Some(token)
156 }
157 McpAuth::ApiKey { .. } => {
158 None
161 }
162 McpAuth::None => None,
163 };
164
165 let mut config = StreamableHttpClientTransportConfig::with_uri(self.endpoint.as_str());
167
168 if let Some(token) = token {
170 config = config.auth_header(token);
171 }
172
173 let transport = StreamableHttpClientTransport::from_config(config);
175
176 let client = ()
178 .serve(transport)
179 .await
180 .map_err(|e| AdkError::tool(format!("Failed to connect to MCP server: {e}")))?;
181
182 Ok(super::McpToolset::new(client))
183 }
184
185 #[cfg(not(feature = "http-transport"))]
187 pub async fn connect(self) -> Result<()> {
188 Err(AdkError::tool(
189 "HTTP transport requires the 'http-transport' feature. \
190 Add `adk-tool = { features = [\"http-transport\"] }` to your Cargo.toml",
191 ))
192 }
193
194 #[cfg(feature = "http-transport")]
215 pub async fn connect_with_elicitation(
216 self,
217 ) -> Result<super::McpToolset<impl rmcp::service::Service<rmcp::RoleClient>>> {
218 use adk_core::{ErrorCategory, ErrorComponent};
219 use rmcp::ServiceExt;
220 use rmcp::transport::streamable_http_client::{
221 StreamableHttpClientTransport, StreamableHttpClientTransportConfig,
222 };
223
224 let handler = self.elicitation_handler.ok_or_else(|| {
225 AdkError::tool(
226 "connect_with_elicitation requires with_elicitation_handler to be called first",
227 )
228 })?;
229
230 let token = match &self.auth {
232 McpAuth::Bearer(token) => Some(token.clone()),
233 McpAuth::OAuth2(config) => {
234 let token = config.get_or_refresh_token().await.map_err(|e| {
235 AdkError::new(
236 ErrorComponent::Tool,
237 ErrorCategory::Unauthorized,
238 "mcp.oauth.token_fetch",
239 format!("OAuth2 authentication failed: {e}"),
240 )
241 })?;
242 Some(token)
243 }
244 McpAuth::ApiKey { .. } => None,
245 McpAuth::None => None,
246 };
247
248 let mut config = StreamableHttpClientTransportConfig::with_uri(self.endpoint.as_str());
249 if let Some(token) = token {
250 config = config.auth_header(token);
251 }
252
253 let transport = StreamableHttpClientTransport::from_config(config);
254 let adk_handler = super::elicitation::AdkClientHandler::new(handler);
255 let client = adk_handler
256 .serve(transport)
257 .await
258 .map_err(|e| AdkError::tool(format!("failed to connect to MCP server: {e}")))?;
259
260 Ok(super::McpToolset::new(client))
261 }
262
263 #[cfg(not(feature = "http-transport"))]
265 pub async fn connect_with_elicitation(self) -> Result<()> {
266 Err(AdkError::tool(
267 "HTTP transport requires the 'http-transport' feature. \
268 Add `adk-tool = { features = [\"http-transport\"] }` to your Cargo.toml",
269 ))
270 }
271}
272
273impl std::fmt::Debug for McpHttpClientBuilder {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.debug_struct("McpHttpClientBuilder")
276 .field("endpoint", &self.endpoint)
277 .field("auth", &self.auth)
278 .field("timeout", &self.timeout)
279 .field("headers", &self.headers.keys().collect::<Vec<_>>())
280 .finish()
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    const ENDPOINT: &str = "https://mcp.example.com";

    #[test]
    fn test_builder_new() {
        let builder = McpHttpClientBuilder::new(ENDPOINT);
        assert_eq!(builder.endpoint(), ENDPOINT);
        assert_eq!(builder.get_timeout(), Duration::from_secs(30));
        assert!(builder.headers.is_empty());
    }

    #[test]
    fn test_builder_with_auth() {
        let builder = McpHttpClientBuilder::new(ENDPOINT).with_auth(McpAuth::bearer("test-token"));
        assert!(builder.get_auth().is_configured());
    }

    #[test]
    fn test_builder_timeout() {
        let builder = McpHttpClientBuilder::new(ENDPOINT).timeout(Duration::from_secs(60));
        assert_eq!(builder.get_timeout(), Duration::from_secs(60));
    }

    #[test]
    fn test_builder_headers() {
        let builder = McpHttpClientBuilder::new(ENDPOINT).header("X-Custom", "value");
        assert!(builder.headers.contains_key("X-Custom"));
    }

    #[test]
    fn test_builder_header_overwrites_existing_key() {
        let builder = McpHttpClientBuilder::new(ENDPOINT)
            .header("X-Custom", "first")
            .header("X-Custom", "second");
        assert_eq!(builder.headers.len(), 1);
        assert_eq!(
            builder.headers.get("X-Custom").map(String::as_str),
            Some("second")
        );
    }

    #[test]
    fn test_builder_new_has_no_elicitation_handler() {
        let builder = McpHttpClientBuilder::new(ENDPOINT);
        assert!(builder.elicitation_handler.is_none());
    }

    #[test]
    fn test_debug_hides_header_values() {
        let builder =
            McpHttpClientBuilder::new(ENDPOINT).header("X-Api-Secret", "super-secret-value");
        let rendered = format!("{builder:?}");
        assert!(rendered.contains("X-Api-Secret"));
        assert!(!rendered.contains("super-secret-value"));
    }

    #[test]
    fn test_builder_clone_preserves_config() {
        let original = McpHttpClientBuilder::new(ENDPOINT)
            .timeout(Duration::from_secs(5))
            .header("X-Custom", "value");
        let cloned = original.clone();
        assert_eq!(cloned.endpoint(), original.endpoint());
        assert_eq!(cloned.get_timeout(), Duration::from_secs(5));
        assert_eq!(cloned.headers, original.headers);
    }
}