1use http::HeaderName;
2use rmcp::{
6 model::{ClientCapabilities, ClientInfo},
7 transport::{sse::SseTransport, stdio},
8 ServiceExt,
9};
10use std::{collections::HashMap, error::Error as StdError, str::FromStr};
11use tracing::info;
12
13use crate::proxy_handler::ProxyHandler;
14
/// Connection settings consumed by [`run_sse_client`].
///
/// Note: the derived `Debug` output includes every header value verbatim, so
/// avoid logging a whole config if `headers` may carry credentials.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseClientConfig {
    /// URL handed to the SSE transport when the connection is started.
    pub url: String,
    /// Header name/value pairs installed as default headers on the HTTP client.
    pub headers: HashMap<String, String>,
}
20
21pub async fn run_sse_client(config: SseClientConfig) -> Result<(), Box<dyn StdError>> {
25 info!("Running SSE client with URL: {}", config.url);
26
27 let mut headers = reqwest::header::HeaderMap::new();
29 for (key, value) in config.headers {
30 headers.insert(HeaderName::from_str(&key)?, value.parse()?);
31 }
32
33 let client = reqwest::Client::builder()
35 .default_headers(headers)
36 .build()?;
37
38 let transport = SseTransport::start_with_client(&config.url, client).await?;
40
41 let client_info = ClientInfo {
43 protocol_version: Default::default(),
44 capabilities: ClientCapabilities::builder()
45 .enable_experimental()
46 .enable_roots()
47 .enable_roots_list_changed()
48 .enable_sampling()
49 .build(),
50 ..Default::default()
51 };
52
53 let client = client_info.serve(transport).await?;
55
56 let server_info = client.peer_info();
58 info!("Connected to server: {}", server_info.server_info.name);
59
60 let proxy_handler = ProxyHandler::new(client);
62
63 let stdio_transport = stdio();
65
66 let server = proxy_handler.serve(stdio_transport).await?;
68
69 server.waiting().await?;
71
72 Ok(())
73}