adk_rust_mcp_common/
server.rs1use crate::transport::Transport;
20use rmcp::{ServerHandler, ServiceExt};
21use thiserror::Error;
22use tokio::sync::oneshot;
23
24#[derive(Debug, Error)]
26pub enum ServerError {
27 #[error("Failed to bind to port {port}: {message}")]
29 BindFailed { port: u16, message: String },
30
31 #[error("Transport error: {0}")]
33 Transport(String),
34
35 #[error("Server shutdown")]
37 Shutdown,
38
39 #[error("IO error: {0}")]
41 Io(#[from] std::io::Error),
42}
43
44pub struct McpServerBuilder<H> {
49 handler: H,
50 transport: Transport,
51 shutdown_rx: Option<oneshot::Receiver<()>>,
52}
53
54impl<H> McpServerBuilder<H>
55where
56 H: ServerHandler + Clone + Send + Sync + 'static,
57{
58 pub fn new(handler: H) -> Self {
60 Self {
61 handler,
62 transport: Transport::default(),
63 shutdown_rx: None,
64 }
65 }
66
67 pub fn with_transport(mut self, transport: Transport) -> Self {
69 self.transport = transport;
70 self
71 }
72
73 pub fn with_shutdown(mut self, shutdown_rx: oneshot::Receiver<()>) -> Self {
78 self.shutdown_rx = Some(shutdown_rx);
79 self
80 }
81
82 pub async fn run(self) -> Result<(), ServerError> {
86 tracing::info!(transport = %self.transport, "Starting MCP server");
87
88 match self.transport {
89 Transport::Stdio => self.run_stdio().await,
90 Transport::Http { port } => self.run_http(port).await,
91 Transport::Sse { port } => self.run_sse(port).await,
92 }
93 }
94
95 async fn run_stdio(self) -> Result<(), ServerError> {
97 use rmcp::transport::io::stdio;
98
99 let transport = stdio();
100
101 let shutdown_future = async {
103 if let Some(rx) = self.shutdown_rx {
104 let _ = rx.await;
105 } else {
106 wait_for_shutdown_signal().await;
108 }
109 };
110
111 let service = self
113 .handler
114 .serve(transport)
115 .await
116 .map_err(|e| ServerError::Transport(e.to_string()))?;
117
118 tokio::select! {
119 result = service.waiting() => {
120 result.map_err(|e| ServerError::Transport(e.to_string()))?;
121 Ok(())
122 }
123 _ = shutdown_future => {
124 tracing::info!("Received shutdown signal, stopping server");
125 Ok(())
126 }
127 }
128 }
129
130 async fn run_http(self, port: u16) -> Result<(), ServerError> {
132 use rmcp::transport::streamable_http_server::{
133 session::local::LocalSessionManager, StreamableHttpService,
134 };
135
136 let handler = self.handler.clone();
137 let service = StreamableHttpService::new(
138 move || Ok(handler.clone()),
139 LocalSessionManager::default().into(),
140 Default::default(),
141 );
142
143 let router = axum::Router::new().nest_service("/mcp", service);
144
145 let bind_addr = format!("0.0.0.0:{}", port);
146 let tcp_listener = tokio::net::TcpListener::bind(&bind_addr)
147 .await
148 .map_err(|e| ServerError::BindFailed {
149 port,
150 message: e.to_string(),
151 })?;
152
153 tracing::info!(port, "HTTP server listening");
154
155 let shutdown_future = async {
157 if let Some(rx) = self.shutdown_rx {
158 let _ = rx.await;
159 } else {
160 wait_for_shutdown_signal().await;
161 }
162 };
163
164 axum::serve(tcp_listener, router)
165 .with_graceful_shutdown(shutdown_future)
166 .await
167 .map_err(|e| ServerError::Transport(e.to_string()))?;
168
169 tracing::info!("HTTP server stopped");
170 Ok(())
171 }
172
173 async fn run_sse(self, port: u16) -> Result<(), ServerError> {
178 self.run_http(port).await
181 }
182}
183
184async fn wait_for_shutdown_signal() {
186 #[cfg(unix)]
187 {
188 use tokio::signal::unix::{SignalKind, signal};
189
190 let mut sigterm =
191 signal(SignalKind::terminate()).expect("Failed to register SIGTERM handler");
192 let mut sigint =
193 signal(SignalKind::interrupt()).expect("Failed to register SIGINT handler");
194
195 tokio::select! {
196 _ = sigterm.recv() => {
197 tracing::info!("Received SIGTERM");
198 }
199 _ = sigint.recv() => {
200 tracing::info!("Received SIGINT");
201 }
202 }
203 }
204
205 #[cfg(not(unix))]
206 {
207 tokio::signal::ctrl_c()
208 .await
209 .expect("Failed to register Ctrl+C handler");
210 tracing::info!("Received Ctrl+C");
211 }
212}
213
214pub fn shutdown_channel() -> (oneshot::Sender<()>, oneshot::Receiver<()>) {
219 oneshot::channel()
220}