1use crate::server::ActixServer;
2use actix_web::dev::ServerHandle;
3use rust_mcp_sdk::session_store::SessionStore;
4use rust_mcp_sdk::task_store::{ClientTaskStore, ServerTaskStore, TaskStatusPoller};
5use rust_mcp_sdk::{
6 error::SdkResult,
7 mcp_http::McpAppState,
8 schema::{
9 schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient},
10 CreateMessageRequestParams, CreateMessageResult, ElicitRequestParams, ElicitResult,
11 GenericResult, GetTaskParams, GetTaskResult, InitializeRequestParams, ListRootsResult,
12 LoggingMessageNotificationParams, NotificationParams, RequestParams,
13 ResourceUpdatedNotificationParams,
14 },
15 McpServer,
16};
17use rust_mcp_sdk::{
18 schema::{
19 schema_utils::{ClientTaskResult, CustomNotification, CustomRequest},
20 CancelTaskParams, CancelTaskResult, CancelledNotificationParams, CreateTaskResult,
21 ElicitCompleteParams, GetTaskPayloadParams, ProgressNotificationParams, RpcError,
22 TaskStatusNotificationParams,
23 },
24 SessionId,
25};
26use std::io;
27use std::sync::Arc;
28use std::time::Duration;
29use tokio::task::JoinHandle;
30
/// Handle to a running actix-web based MCP server.
///
/// An instance is produced by `ActixRuntime::create`, which binds the HTTP
/// server, spawns it onto the Tokio runtime and stores the pieces needed to
/// talk to sessions and to stop or await the server later.
pub struct ActixRuntime {
    /// Shared application state; gives access to the session store and the
    /// optional server/client task stores.
    pub(crate) state: Arc<McpAppState>,
    /// Join handle of the spawned server future; it resolves with the server's
    /// `io::Result` once the server has finished running.
    pub(crate) server_task: JoinHandle<io::Result<()>>,
    /// actix handle used to request a stop of the running server.
    pub(crate) server_handle: ServerHandle,
}
40
41impl ActixRuntime {
42 pub async fn create(server: ActixServer) -> SdkResult<Self> {
44 let addr = server
45 .options()
46 .resolve_server_address()
47 .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal { description: e })?;
48
49 let state = server.state();
50 let info = server.server_info(Some(addr)).unwrap_or_default();
51 tracing::info!("{}", info);
52
53 let state_clone = state.clone();
54 let handler = server.handler.clone();
55 let mount_options = server.options().resolve_mount_options();
56
57 let srv = actix_web::HttpServer::new(move || {
58 actix_web::App::new().service(crate::mcp_scope(
59 state_clone.clone(),
60 handler.clone(),
61 &mount_options,
62 ))
63 });
64
65 #[cfg(feature = "ssl")]
66 let srv = if server.options().enable_ssl {
67 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
68 let config = load_rustls_config(
69 server
70 .options()
71 .ssl_cert_path
72 .as_deref()
73 .unwrap_or_default(),
74 server.options().ssl_key_path.as_deref().unwrap_or_default(),
75 )
76 .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal {
77 description: e.to_string(),
78 })?;
79 srv.bind_rustls_0_23(addr, config).map_err(|e| {
80 rust_mcp_sdk::error::McpSdkError::Internal {
81 description: e.to_string(),
82 }
83 })?
84 } else {
85 srv.bind(addr)
86 .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal {
87 description: e.to_string(),
88 })?
89 };
90
91 #[cfg(not(feature = "ssl"))]
92 let srv = srv
93 .bind(addr)
94 .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal {
95 description: e.to_string(),
96 })?;
97
98 let srv = srv.run();
99
100 let server_handle = srv.handle();
101 let server_task = tokio::spawn(srv);
102
103 use futures::StreamExt;
105 if let Some(task_store) = state.task_store.clone() {
106 if let Some(mut stream) = task_store.subscribe() {
107 let state_clone = state.clone();
108 tokio::spawn(async move {
109 while let Some((params, session_id_opt)) = stream.next().await {
110 if let Some(session_id) = session_id_opt.as_ref() {
111 if let Some(transport) = state_clone.session_store.get(session_id).await
112 {
113 let _ = transport.notify_task_status(params).await;
114 }
115 }
116 }
117 });
118 }
119 }
120
121 if let Some(client_task_store) = state.client_task_store.clone() {
123 let session_store = state.session_store.clone();
124 let callback = task_poller_callback(Arc::clone(&client_task_store), session_store);
125 let _ = client_task_store.start_task_polling(callback);
126 }
127
128 Ok(Self {
129 state,
130 server_task,
131 server_handle,
132 })
133 }
134
135 pub fn graceful_shutdown(&self, _timeout: Option<Duration>) {
137 let handle = self.server_handle.clone();
138 tokio::spawn(async move {
139 let _ = handle.stop(true).await;
140 });
141 }
142
143 pub async fn await_server(self) -> SdkResult<()> {
145 self.server_task
146 .await
147 .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal {
148 description: e.to_string(),
149 })?
150 .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal {
151 description: e.to_string(),
152 })
153 }
154
155 pub async fn sessions(&self) -> Vec<String> {
157 self.state.session_store.keys().await
158 }
159
160 pub async fn runtime_by_session(
162 &self,
163 session_id: &SessionId,
164 ) -> Result<Arc<ServerRuntime>, rust_mcp_sdk::error::McpSdkError> {
165 self.state
166 .session_store
167 .get(session_id)
168 .await
169 .ok_or_else(|| rust_mcp_sdk::error::McpSdkError::Internal {
170 description: format!("Session not found: {}", session_id),
171 })
172 }
173
174 pub async fn send_request(
177 &self,
178 session_id: &SessionId,
179 request: RequestFromServer,
180 timeout: Option<Duration>,
181 ) -> SdkResult<ResultFromClient> {
182 let runtime = self.runtime_by_session(session_id).await?;
183 runtime.request(request, timeout).await
184 }
185
186 pub async fn send_notification(
187 &self,
188 session_id: &SessionId,
189 notification: NotificationFromServer,
190 ) -> SdkResult<()> {
191 let runtime = self.runtime_by_session(session_id).await?;
192 runtime.send_notification(notification).await
193 }
194
195 pub async fn client_info(
196 &self,
197 session_id: &SessionId,
198 ) -> SdkResult<Option<InitializeRequestParams>> {
199 let runtime = self.runtime_by_session(session_id).await?;
200 Ok(runtime.client_info())
201 }
202
203 pub async fn request_elicitation(
204 &self,
205 session_id: &SessionId,
206 params: ElicitRequestParams,
207 ) -> SdkResult<ElicitResult> {
208 self.runtime_by_session(session_id)
209 .await?
210 .request_elicitation(params)
211 .await
212 }
213
214 pub async fn request_root_list(
215 &self,
216 session_id: &SessionId,
217 params: Option<RequestParams>,
218 ) -> SdkResult<ListRootsResult> {
219 self.runtime_by_session(session_id)
220 .await?
221 .request_root_list(params)
222 .await
223 }
224
225 pub async fn ping(
226 &self,
227 session_id: &SessionId,
228 params: Option<RequestParams>,
229 timeout: Option<Duration>,
230 ) -> SdkResult<rust_mcp_sdk::schema::Result> {
231 self.runtime_by_session(session_id)
232 .await?
233 .ping(params, timeout)
234 .await
235 }
236
237 pub async fn request_message_creation(
238 &self,
239 session_id: &SessionId,
240 params: CreateMessageRequestParams,
241 ) -> SdkResult<CreateMessageResult> {
242 self.runtime_by_session(session_id)
243 .await?
244 .request_message_creation(params)
245 .await
246 }
247
248 pub async fn request_get_task(
249 &self,
250 session_id: &SessionId,
251 params: GetTaskParams,
252 ) -> SdkResult<GetTaskResult> {
253 self.runtime_by_session(session_id)
254 .await?
255 .request_get_task(params)
256 .await
257 }
258
259 pub async fn request_custom(
260 &self,
261 session_id: &SessionId,
262 params: CustomRequest,
263 ) -> SdkResult<GenericResult> {
264 self.runtime_by_session(session_id)
265 .await?
266 .request_custom(params)
267 .await
268 }
269
270 pub async fn notify_log_message(
273 &self,
274 session_id: &SessionId,
275 params: LoggingMessageNotificationParams,
276 ) -> SdkResult<()> {
277 self.runtime_by_session(session_id)
278 .await?
279 .notify_log_message(params)
280 .await
281 }
282
283 pub async fn notify_tool_list_changed(
284 &self,
285 session_id: &SessionId,
286 params: Option<NotificationParams>,
287 ) -> SdkResult<()> {
288 self.runtime_by_session(session_id)
289 .await?
290 .notify_tool_list_changed(params)
291 .await
292 }
293
294 pub async fn notify_resource_updated(
295 &self,
296 session_id: &SessionId,
297 params: ResourceUpdatedNotificationParams,
298 ) -> SdkResult<()> {
299 self.runtime_by_session(session_id)
300 .await?
301 .notify_resource_updated(params)
302 .await
303 }
304
305 pub async fn notify_resource_list_changed(
306 &self,
307 session_id: &SessionId,
308 params: Option<NotificationParams>,
309 ) -> SdkResult<()> {
310 self.runtime_by_session(session_id)
311 .await?
312 .notify_resource_list_changed(params)
313 .await
314 }
315
316 pub async fn notify_prompt_list_changed(
317 &self,
318 session_id: &SessionId,
319 params: Option<NotificationParams>,
320 ) -> SdkResult<()> {
321 self.runtime_by_session(session_id)
322 .await?
323 .notify_prompt_list_changed(params)
324 .await
325 }
326
327 pub async fn notify_task_status(
328 &self,
329 session_id: &SessionId,
330 params: TaskStatusNotificationParams,
331 ) -> SdkResult<()> {
332 self.runtime_by_session(session_id)
333 .await?
334 .notify_task_status(params)
335 .await
336 }
337
338 pub async fn notify_cancellation(
339 &self,
340 session_id: &SessionId,
341 params: CancelledNotificationParams,
342 ) -> SdkResult<()> {
343 self.runtime_by_session(session_id)
344 .await?
345 .notify_cancellation(params)
346 .await
347 }
348
349 pub async fn notify_progress(
350 &self,
351 session_id: &SessionId,
352 params: ProgressNotificationParams,
353 ) -> SdkResult<()> {
354 self.runtime_by_session(session_id)
355 .await?
356 .notify_progress(params)
357 .await
358 }
359
360 pub async fn notify_elicitation_completed(
361 &self,
362 session_id: &SessionId,
363 params: ElicitCompleteParams,
364 ) -> SdkResult<()> {
365 self.runtime_by_session(session_id)
366 .await?
367 .notify_elicitation_completed(params)
368 .await
369 }
370
371 pub async fn notify_custom(
372 &self,
373 session_id: &SessionId,
374 params: CustomNotification,
375 ) -> SdkResult<()> {
376 self.runtime_by_session(session_id)
377 .await?
378 .notify_custom(params)
379 .await
380 }
381
382 pub async fn request_elicitation_task(
385 &self,
386 session_id: &SessionId,
387 params: ElicitRequestParams,
388 ) -> SdkResult<CreateTaskResult> {
389 self.runtime_by_session(session_id)
390 .await?
391 .request_elicitation_task(params)
392 .await
393 }
394
395 pub async fn request_get_task_payload(
396 &self,
397 session_id: &SessionId,
398 params: GetTaskPayloadParams,
399 ) -> SdkResult<ClientTaskResult> {
400 self.runtime_by_session(session_id)
401 .await?
402 .request_get_task_payload(params)
403 .await
404 }
405
406 pub async fn request_task_cancellation(
407 &self,
408 session_id: &SessionId,
409 params: CancelTaskParams,
410 ) -> SdkResult<CancelTaskResult> {
411 self.runtime_by_session(session_id)
412 .await?
413 .request_task_cancellation(params)
414 .await
415 }
416
417 pub fn task_store(&self) -> Option<Arc<ServerTaskStore>> {
420 self.state.task_store.clone()
421 }
422
423 pub fn client_task_store(&self) -> Option<Arc<ClientTaskStore>> {
424 self.state.client_task_store.clone()
425 }
426
427 #[deprecated(since = "0.8.0", note = "Use `request_root_list()` instead.")]
430 pub async fn list_roots(
431 &self,
432 session_id: &SessionId,
433 params: Option<RequestParams>,
434 ) -> SdkResult<ListRootsResult> {
435 self.request_root_list(session_id, params).await
436 }
437
438 #[deprecated(since = "0.8.0", note = "Use `request_elicitation()` instead.")]
439 pub async fn elicit_input(
440 &self,
441 session_id: &SessionId,
442 params: ElicitRequestParams,
443 ) -> SdkResult<ElicitResult> {
444 self.request_elicitation(session_id, params).await
445 }
446
447 #[deprecated(since = "0.8.0", note = "Use `request_message_creation()` instead.")]
448 pub async fn create_message(
449 &self,
450 session_id: &SessionId,
451 params: CreateMessageRequestParams,
452 ) -> SdkResult<CreateMessageResult> {
453 self.request_message_creation(session_id, params).await
454 }
455
456 #[deprecated(since = "0.8.0", note = "Use `notify_tool_list_changed()` instead.")]
457 pub async fn send_tool_list_changed(
458 &self,
459 session_id: &SessionId,
460 params: Option<NotificationParams>,
461 ) -> SdkResult<()> {
462 self.notify_tool_list_changed(session_id, params).await
463 }
464
465 #[deprecated(since = "0.8.0", note = "Use `notify_resource_updated()` instead.")]
466 pub async fn send_resource_updated(
467 &self,
468 session_id: &SessionId,
469 params: ResourceUpdatedNotificationParams,
470 ) -> SdkResult<()> {
471 self.notify_resource_updated(session_id, params).await
472 }
473
474 #[deprecated(
475 since = "0.8.0",
476 note = "Use `notify_resource_list_changed()` instead."
477 )]
478 pub async fn send_resource_list_changed(
479 &self,
480 session_id: &SessionId,
481 params: Option<NotificationParams>,
482 ) -> SdkResult<()> {
483 self.notify_resource_list_changed(session_id, params).await
484 }
485
486 #[deprecated(since = "0.8.0", note = "Use `notify_prompt_list_changed()` instead.")]
487 pub async fn send_prompt_list_changed(
488 &self,
489 session_id: &SessionId,
490 params: Option<NotificationParams>,
491 ) -> SdkResult<()> {
492 self.notify_prompt_list_changed(session_id, params).await
493 }
494
495 #[deprecated(since = "0.8.0", note = "Use `notify_log_message()` instead.")]
496 pub async fn send_logging_message(
497 &self,
498 session_id: &SessionId,
499 params: LoggingMessageNotificationParams,
500 ) -> SdkResult<()> {
501 self.notify_log_message(session_id, params).await
502 }
503}
504
505fn task_poller_callback(
506 client_task_store: Arc<ClientTaskStore>,
507 session_store: Arc<dyn SessionStore>,
508) -> TaskStatusPoller {
509 let session_store = session_store.clone();
510 let task_store_clone = client_task_store.clone();
511
512 let callback: TaskStatusPoller = Box::new(move |task_id, session_id| {
513 let session_store_clone = session_store.clone();
514 let task_store_clone = task_store_clone.clone();
515 Box::pin(async move {
516 let Some(session) = session_id.as_ref() else {
517 return Err(RpcError::invalid_request()
518 .with_message("No session id provided!".to_string())
519 .into());
520 };
521
522 let Some(runtime) = session_store_clone.get(session).await else {
523 return Err(RpcError::invalid_request()
524 .with_message("Invalid or broken session!".to_string())
525 .into());
526 };
527
528 runtime
529 .poll_task_status(task_id, session_id, task_store_clone)
530 .await
531 })
532 });
533 callback
534}
535
536use async_trait::async_trait;
537use rust_mcp_sdk::mcp_server::ServerRuntime;
538use rust_mcp_sdk::McpHttpServer;
539
540#[async_trait]
541impl McpHttpServer for ActixRuntime {
542 async fn graceful_shutdown(&self) {
543 self.graceful_shutdown(None);
544 }
545
546 async fn sessions(&self) -> Vec<SessionId> {
547 ActixRuntime::sessions(self).await
548 }
549
550 async fn runtime_by_session(&self, id: &SessionId) -> SdkResult<Arc<ServerRuntime>> {
551 ActixRuntime::runtime_by_session(self, id).await
552 }
553}
554
/// Builds a rustls server configuration (no client authentication) from a PEM
/// certificate chain file and a PEM private key file.
///
/// # Errors
///
/// * Any I/O error from opening or reading either file is returned as-is.
/// * `ErrorKind::InvalidInput` is returned when the certificate file cannot be
///   parsed, contains no certificate, the key file contains no private key, or
///   rustls rejects the certificate/key pair.
#[cfg(feature = "ssl")]
fn load_rustls_config(cert_path: &str, key_path: &str) -> std::io::Result<rustls::ServerConfig> {
    use std::fs::File;
    use std::io::BufReader;

    let cert_file = File::open(cert_path)?;
    let mut cert_reader = BufReader::new(cert_file);
    let certs: Vec<rustls::pki_types::CertificateDer> = rustls_pemfile::certs(&mut cert_reader)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;

    // A PEM file without any certificate section parses to an empty chain.
    // Reject it here with a clear message rather than handing an empty chain
    // to rustls.
    if certs.is_empty() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "no certificate found",
        ));
    }

    let key_file = File::open(key_path)?;
    let mut key_reader = BufReader::new(key_file);
    let key = rustls_pemfile::private_key(&mut key_reader)?.ok_or_else(|| {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, "no private key found")
    })?;

    let config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;

    Ok(config)
}
579
/// Tests that only make sense when TLS support is compiled in.
#[cfg(all(test, feature = "ssl"))]
mod ssl_tests {
    /// Installing the default crypto provider more than once must not panic.
    /// The outcome of each attempt is deliberately discarded, so this only
    /// checks that repeated installation is harmless.
    #[test]
    fn install_crypto_provider_idempotent() {
        for _attempt in 0..2 {
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        }
    }
}