use super::auth::McpAuth;
use super::elicitation::ElicitationHandler;
use adk_core::{AdkError, Result};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Builder for an MCP client that talks to a server over streamable HTTP.
///
/// Create one with [`McpHttpClientBuilder::new`], refine it with the chained setters
/// (`with_auth`, `timeout`, `header`, `with_elicitation_handler`), then finish with
/// `connect` or `connect_with_elicitation`.
#[derive(Clone)]
pub struct McpHttpClientBuilder {
    /// URI of the MCP server; passed to the transport configuration on connect.
    endpoint: String,
    /// Authentication scheme to apply when connecting (`McpAuth::None` by default).
    auth: McpAuth,
    /// Timeout recorded by the builder (30 seconds by default).
    timeout: Duration,
    /// Extra HTTP headers recorded by the builder, keyed by header name.
    headers: HashMap<String, String>,
    /// Handler for server-initiated elicitation; required by `connect_with_elicitation`.
    elicitation_handler: Option<Arc<dyn ElicitationHandler>>,
}
impl McpHttpClientBuilder {
    /// Creates a builder for `endpoint`.
    ///
    /// Defaults: no authentication, a 30 second timeout, no extra headers and no
    /// elicitation handler.
    pub fn new(endpoint: impl Into<String>) -> Self {
        Self {
            endpoint: endpoint.into(),
            auth: McpAuth::None,
            timeout: Duration::from_secs(30),
            headers: HashMap::new(),
            elicitation_handler: None,
        }
    }

    /// Sets the authentication scheme used when connecting.
    ///
    /// Bearer and OAuth2 credentials are sent as the transport's auth header.
    /// NOTE(review): `McpAuth::ApiKey` is accepted here but is currently *not* sent by
    /// `connect` / `connect_with_elicitation`.
    #[must_use]
    pub fn with_auth(mut self, auth: McpAuth) -> Self {
        self.auth = auth;
        self
    }

    /// Records a timeout on the builder, readable via [`Self::get_timeout`].
    ///
    /// NOTE(review): neither `connect` method currently passes this value to the transport.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Records an extra HTTP header. A later call with the same key replaces the value.
    ///
    /// NOTE(review): headers are stored but not currently forwarded by either `connect`
    /// method.
    #[must_use]
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }

    /// Installs the handler that [`Self::connect_with_elicitation`] wires into the client.
    #[must_use]
    pub fn with_elicitation_handler(mut self, handler: Arc<dyn ElicitationHandler>) -> Self {
        self.elicitation_handler = Some(handler);
        self
    }

    /// Returns the configured server endpoint.
    pub fn endpoint(&self) -> &str {
        &self.endpoint
    }

    /// Returns the configured timeout.
    pub fn get_timeout(&self) -> Duration {
        self.timeout
    }

    /// Returns the configured authentication scheme.
    pub fn get_auth(&self) -> &McpAuth {
        &self.auth
    }

    /// Builds the streamable-HTTP transport configuration shared by both connect paths.
    ///
    /// Bearer tokens are used as-is. OAuth2 tokens come from `get_or_refresh_token`.
    /// `ApiKey` and `None` produce no auth header.
    ///
    /// # Errors
    /// Returns a `Tool`/`Unauthorized` error with code `mcp.oauth.token_fetch` when the
    /// OAuth2 token cannot be obtained.
    #[cfg(feature = "http-transport")]
    async fn transport_config(
        &self,
    ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig> {
        use adk_core::{ErrorCategory, ErrorComponent};
        use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;

        let token = match &self.auth {
            McpAuth::Bearer(token) => Some(token.clone()),
            McpAuth::OAuth2(oauth) => Some(oauth.get_or_refresh_token().await.map_err(|e| {
                AdkError::new(
                    ErrorComponent::Tool,
                    ErrorCategory::Unauthorized,
                    "mcp.oauth.token_fetch",
                    format!("OAuth2 authentication failed: {e}"),
                )
            })?),
            // NOTE(review): API-key auth is not forwarded; requests go out without it.
            McpAuth::ApiKey { .. } | McpAuth::None => None,
        };

        let config = StreamableHttpClientTransportConfig::with_uri(self.endpoint.as_str());
        Ok(match token {
            Some(token) => config.auth_header(token),
            None => config,
        })
    }

    /// Connects to the MCP server over streamable HTTP without elicitation support.
    ///
    /// # Errors
    /// Fails if an OAuth2 token cannot be obtained, or if serving the client over the
    /// transport fails.
    #[cfg(feature = "http-transport")]
    pub async fn connect(
        self,
    ) -> Result<super::McpToolset<impl rmcp::service::Service<rmcp::RoleClient>>> {
        use rmcp::ServiceExt;
        use rmcp::transport::streamable_http_client::StreamableHttpClientTransport;

        let transport = StreamableHttpClientTransport::from_config(self.transport_config().await?);
        let client = ()
            .serve(transport)
            .await
            .map_err(|e| AdkError::tool(format!("Failed to connect to MCP server: {e}")))?;
        Ok(super::McpToolset::new(client))
    }

    /// Stub used when the `http-transport` feature is disabled; always returns an error.
    #[cfg(not(feature = "http-transport"))]
    pub async fn connect(self) -> Result<()> {
        Err(AdkError::tool(
            "HTTP transport requires the 'http-transport' feature. \
             Add `adk-tool = { features = [\"http-transport\"] }` to your Cargo.toml",
        ))
    }

    /// Connects like [`Self::connect`], but routes elicitation requests to the handler set
    /// via [`Self::with_elicitation_handler`].
    ///
    /// # Errors
    /// - Fails immediately, before any token fetch, when no elicitation handler is set.
    /// - Otherwise fails under the same conditions as `connect`.
    #[cfg(feature = "http-transport")]
    pub async fn connect_with_elicitation(
        self,
    ) -> Result<super::McpToolset<impl rmcp::service::Service<rmcp::RoleClient>>> {
        use rmcp::ServiceExt;
        use rmcp::transport::streamable_http_client::StreamableHttpClientTransport;

        // Validate the handler first so a missing handler never triggers token work.
        // Cloning the Arc (cheap) avoids partially moving `self` before `transport_config`.
        let handler = self.elicitation_handler.clone().ok_or_else(|| {
            AdkError::tool(
                "connect_with_elicitation requires with_elicitation_handler to be called first",
            )
        })?;
        let transport = StreamableHttpClientTransport::from_config(self.transport_config().await?);
        let adk_handler = super::elicitation::AdkClientHandler::new(handler);
        let client = adk_handler
            .serve(transport)
            .await
            .map_err(|e| AdkError::tool(format!("failed to connect to MCP server: {e}")))?;
        Ok(super::McpToolset::new(client))
    }

    /// Stub used when the `http-transport` feature is disabled; always returns an error.
    #[cfg(not(feature = "http-transport"))]
    pub async fn connect_with_elicitation(self) -> Result<()> {
        Err(AdkError::tool(
            "HTTP transport requires the 'http-transport' feature. \
             Add `adk-tool = { features = [\"http-transport\"] }` to your Cargo.toml",
        ))
    }
}
/// Hand-written `Debug` that prints header *names* only, never header values (which may
/// carry credentials).
impl std::fmt::Debug for McpHttpClientBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sort so the rendering does not depend on HashMap iteration order.
        let mut header_names: Vec<&String> = self.headers.keys().collect();
        header_names.sort_unstable();
        f.debug_struct("McpHttpClientBuilder")
            .field("endpoint", &self.endpoint)
            // NOTE(review): relies on `McpAuth`'s own Debug impl to redact secrets — confirm.
            .field("auth", &self.auth)
            .field("timeout", &self.timeout)
            .field("headers", &header_names)
            // The handler is a trait object that may not implement Debug; show presence only.
            .field("elicitation_handler", &self.elicitation_handler.is_some())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_new() {
        let builder = McpHttpClientBuilder::new("https://mcp.example.com");
        assert_eq!(builder.endpoint(), "https://mcp.example.com");
        assert_eq!(builder.get_timeout(), Duration::from_secs(30));
        assert!(builder.headers.is_empty());
        assert!(builder.elicitation_handler.is_none());
    }

    #[test]
    fn test_builder_with_auth() {
        let builder = McpHttpClientBuilder::new("https://mcp.example.com")
            .with_auth(McpAuth::bearer("test-token"));
        assert!(builder.get_auth().is_configured());
    }

    #[test]
    fn test_builder_timeout() {
        let builder =
            McpHttpClientBuilder::new("https://mcp.example.com").timeout(Duration::from_secs(60));
        assert_eq!(builder.get_timeout(), Duration::from_secs(60));
    }

    #[test]
    fn test_builder_headers() {
        let builder =
            McpHttpClientBuilder::new("https://mcp.example.com").header("X-Custom", "value");
        assert!(builder.headers.contains_key("X-Custom"));
    }

    #[test]
    fn test_builder_header_same_key_overwrites() {
        let builder = McpHttpClientBuilder::new("https://mcp.example.com")
            .header("X-Custom", "first")
            .header("X-Custom", "second");
        assert_eq!(builder.headers.len(), 1);
        assert_eq!(builder.headers.get("X-Custom").map(String::as_str), Some("second"));
    }

    #[test]
    fn test_debug_hides_header_values() {
        let builder = McpHttpClientBuilder::new("https://mcp.example.com")
            .header("Authorization", "super-secret-value");
        let rendered = format!("{builder:?}");
        assert!(rendered.contains("Authorization"));
        assert!(!rendered.contains("super-secret-value"));
    }
}