use http::HeaderName;
use rmcp::{
model::{ClientCapabilities, ClientInfo},
transport::{sse::SseTransport, stdio},
ServiceExt,
};
use std::{collections::HashMap, error::Error as StdError, str::FromStr};
use tracing::info;
use crate::proxy_handler::ProxyHandler;
/// Configuration for [`run_sse_client`]: where the remote SSE endpoint lives and which
/// extra HTTP headers to send with it.
///
/// `Debug` is intentionally not derived because header values may carry credentials
/// (e.g. bearer tokens) that should not end up in logs.
#[derive(Clone, Default)]
pub struct SseClientConfig {
    /// URL passed to the SSE transport when connecting to the remote server.
    pub url: String,
    /// Header name -> value pairs installed as default headers on the underlying HTTP
    /// client. Every name must be a valid HTTP header name and every value a valid
    /// header value, otherwise [`run_sse_client`] returns an error.
    pub headers: HashMap<String, String>,
}
/// Connects to a remote MCP server over SSE and re-exposes it on this process's stdio
/// through a [`ProxyHandler`], running until the stdio-side service finishes.
///
/// The headers in `config.headers` are installed as default headers on the HTTP client
/// used by the SSE transport, so they accompany every request it makes.
///
/// # Errors
///
/// Returns an error if a configured header name or value is invalid (the message names the
/// offending header but never includes its value), if the HTTP client cannot be built, if
/// the SSE transport fails to start, if the MCP initialization with the remote server fails,
/// or if serving / waiting on the stdio proxy fails.
pub async fn run_sse_client(config: SseClientConfig) -> Result<(), Box<dyn StdError>> {
    info!("Running SSE client with URL: {}", config.url);

    let mut headers = reqwest::header::HeaderMap::with_capacity(config.headers.len());
    for (key, value) in config.headers {
        let name = HeaderName::from_str(&key)
            .map_err(|e| format!("invalid header name {key:?}: {e}"))?;
        // The value is deliberately omitted from the message: it may be a credential.
        let value = value
            .parse::<reqwest::header::HeaderValue>()
            .map_err(|e| format!("invalid value for header {key:?}: {e}"))?;
        headers.insert(name, value);
    }

    let http_client = reqwest::Client::builder()
        .default_headers(headers)
        .build()?;
    let transport = SseTransport::start_with_client(&config.url, http_client).await?;

    let client_info = ClientInfo {
        protocol_version: Default::default(),
        capabilities: ClientCapabilities::builder()
            .enable_experimental()
            .enable_roots()
            .enable_roots_list_changed()
            .enable_sampling()
            .build(),
        ..Default::default()
    };

    // Perform the MCP handshake with the remote server over SSE.
    let mcp_client = client_info.serve(transport).await?;
    let server_info = mcp_client.peer_info();
    info!("Connected to server: {}", server_info.server_info.name);

    // Expose the connected remote server to the local peer on stdin/stdout.
    let proxy_handler = ProxyHandler::new(mcp_client);
    let server = proxy_handler.serve(stdio()).await?;
    server.waiting().await?;
    Ok(())
}