Skip to main content

rust_mcp_actix/
runtime.rs

1use crate::server::ActixServer;
2use actix_web::dev::ServerHandle;
3use rust_mcp_sdk::session_store::SessionStore;
4use rust_mcp_sdk::task_store::{ClientTaskStore, ServerTaskStore, TaskStatusPoller};
5use rust_mcp_sdk::{
6    error::SdkResult,
7    mcp_http::McpAppState,
8    schema::{
9        schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient},
10        CreateMessageRequestParams, CreateMessageResult, ElicitRequestParams, ElicitResult,
11        GenericResult, GetTaskParams, GetTaskResult, InitializeRequestParams, ListRootsResult,
12        LoggingMessageNotificationParams, NotificationParams, RequestParams,
13        ResourceUpdatedNotificationParams,
14    },
15    McpServer,
16};
17use rust_mcp_sdk::{
18    schema::{
19        schema_utils::{ClientTaskResult, CustomNotification, CustomRequest},
20        CancelTaskParams, CancelTaskResult, CancelledNotificationParams, CreateTaskResult,
21        ElicitCompleteParams, GetTaskPayloadParams, ProgressNotificationParams, RpcError,
22        TaskStatusNotificationParams,
23    },
24    SessionId,
25};
26use std::io;
27use std::sync::Arc;
28use std::time::Duration;
29use tokio::task::JoinHandle;
30
31/// Runtime handle for a running Actix MCP server.
32///
33/// Provides session management, graceful shutdown, and per-session request/notification
34/// methods. Implements [`McpHttpServer`] for framework-agnostic usage.
35pub struct ActixRuntime {
36    pub(crate) state: Arc<McpAppState>,
37    pub(crate) server_task: JoinHandle<io::Result<()>>,
38    pub(crate) server_handle: ServerHandle,
39}
40
41impl ActixRuntime {
42    /// Creates and starts a new runtime from an `ActixServer`.
43    pub async fn create(server: ActixServer) -> SdkResult<Self> {
44        let addr = server
45            .options()
46            .resolve_server_address()
47            .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal { description: e })?;
48
49        let state = server.state();
50        let info = server.server_info(Some(addr)).unwrap_or_default();
51        tracing::info!("{}", info);
52
53        let state_clone = state.clone();
54        let handler = server.handler.clone();
55        let mount_options = server.options().resolve_mount_options();
56
57        let srv = actix_web::HttpServer::new(move || {
58            actix_web::App::new().service(crate::mcp_scope(
59                state_clone.clone(),
60                handler.clone(),
61                &mount_options,
62            ))
63        });
64
65        #[cfg(feature = "ssl")]
66        let srv = if server.options().enable_ssl {
67            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
68            let config = load_rustls_config(
69                server
70                    .options()
71                    .ssl_cert_path
72                    .as_deref()
73                    .unwrap_or_default(),
74                server.options().ssl_key_path.as_deref().unwrap_or_default(),
75            )
76            .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal {
77                description: e.to_string(),
78            })?;
79            srv.bind_rustls_0_23(addr, config).map_err(|e| {
80                rust_mcp_sdk::error::McpSdkError::Internal {
81                    description: e.to_string(),
82                }
83            })?
84        } else {
85            srv.bind(addr)
86                .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal {
87                    description: e.to_string(),
88                })?
89        };
90
91        #[cfg(not(feature = "ssl"))]
92        let srv = srv
93            .bind(addr)
94            .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal {
95                description: e.to_string(),
96            })?;
97
98        let srv = srv.run();
99
100        let server_handle = srv.handle();
101        let server_task = tokio::spawn(srv);
102
103        // Task store notification forwarding
104        use futures::StreamExt;
105        if let Some(task_store) = state.task_store.clone() {
106            if let Some(mut stream) = task_store.subscribe() {
107                let state_clone = state.clone();
108                tokio::spawn(async move {
109                    while let Some((params, session_id_opt)) = stream.next().await {
110                        if let Some(session_id) = session_id_opt.as_ref() {
111                            if let Some(transport) = state_clone.session_store.get(session_id).await
112                            {
113                                let _ = transport.notify_task_status(params).await;
114                            }
115                        }
116                    }
117                });
118            }
119        }
120
121        // Task polling for server-initiated tasks
122        if let Some(client_task_store) = state.client_task_store.clone() {
123            let session_store = state.session_store.clone();
124            let callback = task_poller_callback(Arc::clone(&client_task_store), session_store);
125            let _ = client_task_store.start_task_polling(callback);
126        }
127
128        Ok(Self {
129            state,
130            server_task,
131            server_handle,
132        })
133    }
134
135    /// Gracefully stops the server.
136    pub fn graceful_shutdown(&self, _timeout: Option<Duration>) {
137        let handle = self.server_handle.clone();
138        tokio::spawn(async move {
139            let _ = handle.stop(true).await;
140        });
141    }
142
143    /// Awaits server completion (typically until shutdown).
144    pub async fn await_server(self) -> SdkResult<()> {
145        self.server_task
146            .await
147            .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal {
148                description: e.to_string(),
149            })?
150            .map_err(|e| rust_mcp_sdk::error::McpSdkError::Internal {
151                description: e.to_string(),
152            })
153    }
154
155    /// Returns all active session IDs.
156    pub async fn sessions(&self) -> Vec<String> {
157        self.state.session_store.keys().await
158    }
159
160    /// Returns the runtime for a given session.
161    pub async fn runtime_by_session(
162        &self,
163        session_id: &SessionId,
164    ) -> Result<Arc<ServerRuntime>, rust_mcp_sdk::error::McpSdkError> {
165        self.state
166            .session_store
167            .get(session_id)
168            .await
169            .ok_or_else(|| rust_mcp_sdk::error::McpSdkError::Internal {
170                description: format!("Session not found: {}", session_id),
171            })
172    }
173
174    // --- Request methods ---
175
176    pub async fn send_request(
177        &self,
178        session_id: &SessionId,
179        request: RequestFromServer,
180        timeout: Option<Duration>,
181    ) -> SdkResult<ResultFromClient> {
182        let runtime = self.runtime_by_session(session_id).await?;
183        runtime.request(request, timeout).await
184    }
185
186    pub async fn send_notification(
187        &self,
188        session_id: &SessionId,
189        notification: NotificationFromServer,
190    ) -> SdkResult<()> {
191        let runtime = self.runtime_by_session(session_id).await?;
192        runtime.send_notification(notification).await
193    }
194
195    pub async fn client_info(
196        &self,
197        session_id: &SessionId,
198    ) -> SdkResult<Option<InitializeRequestParams>> {
199        let runtime = self.runtime_by_session(session_id).await?;
200        Ok(runtime.client_info())
201    }
202
203    pub async fn request_elicitation(
204        &self,
205        session_id: &SessionId,
206        params: ElicitRequestParams,
207    ) -> SdkResult<ElicitResult> {
208        self.runtime_by_session(session_id)
209            .await?
210            .request_elicitation(params)
211            .await
212    }
213
214    pub async fn request_root_list(
215        &self,
216        session_id: &SessionId,
217        params: Option<RequestParams>,
218    ) -> SdkResult<ListRootsResult> {
219        self.runtime_by_session(session_id)
220            .await?
221            .request_root_list(params)
222            .await
223    }
224
225    pub async fn ping(
226        &self,
227        session_id: &SessionId,
228        params: Option<RequestParams>,
229        timeout: Option<Duration>,
230    ) -> SdkResult<rust_mcp_sdk::schema::Result> {
231        self.runtime_by_session(session_id)
232            .await?
233            .ping(params, timeout)
234            .await
235    }
236
237    pub async fn request_message_creation(
238        &self,
239        session_id: &SessionId,
240        params: CreateMessageRequestParams,
241    ) -> SdkResult<CreateMessageResult> {
242        self.runtime_by_session(session_id)
243            .await?
244            .request_message_creation(params)
245            .await
246    }
247
248    pub async fn request_get_task(
249        &self,
250        session_id: &SessionId,
251        params: GetTaskParams,
252    ) -> SdkResult<GetTaskResult> {
253        self.runtime_by_session(session_id)
254            .await?
255            .request_get_task(params)
256            .await
257    }
258
259    pub async fn request_custom(
260        &self,
261        session_id: &SessionId,
262        params: CustomRequest,
263    ) -> SdkResult<GenericResult> {
264        self.runtime_by_session(session_id)
265            .await?
266            .request_custom(params)
267            .await
268    }
269
270    // --- Notification methods ---
271
272    pub async fn notify_log_message(
273        &self,
274        session_id: &SessionId,
275        params: LoggingMessageNotificationParams,
276    ) -> SdkResult<()> {
277        self.runtime_by_session(session_id)
278            .await?
279            .notify_log_message(params)
280            .await
281    }
282
283    pub async fn notify_tool_list_changed(
284        &self,
285        session_id: &SessionId,
286        params: Option<NotificationParams>,
287    ) -> SdkResult<()> {
288        self.runtime_by_session(session_id)
289            .await?
290            .notify_tool_list_changed(params)
291            .await
292    }
293
294    pub async fn notify_resource_updated(
295        &self,
296        session_id: &SessionId,
297        params: ResourceUpdatedNotificationParams,
298    ) -> SdkResult<()> {
299        self.runtime_by_session(session_id)
300            .await?
301            .notify_resource_updated(params)
302            .await
303    }
304
305    pub async fn notify_resource_list_changed(
306        &self,
307        session_id: &SessionId,
308        params: Option<NotificationParams>,
309    ) -> SdkResult<()> {
310        self.runtime_by_session(session_id)
311            .await?
312            .notify_resource_list_changed(params)
313            .await
314    }
315
316    pub async fn notify_prompt_list_changed(
317        &self,
318        session_id: &SessionId,
319        params: Option<NotificationParams>,
320    ) -> SdkResult<()> {
321        self.runtime_by_session(session_id)
322            .await?
323            .notify_prompt_list_changed(params)
324            .await
325    }
326
327    pub async fn notify_task_status(
328        &self,
329        session_id: &SessionId,
330        params: TaskStatusNotificationParams,
331    ) -> SdkResult<()> {
332        self.runtime_by_session(session_id)
333            .await?
334            .notify_task_status(params)
335            .await
336    }
337
338    pub async fn notify_cancellation(
339        &self,
340        session_id: &SessionId,
341        params: CancelledNotificationParams,
342    ) -> SdkResult<()> {
343        self.runtime_by_session(session_id)
344            .await?
345            .notify_cancellation(params)
346            .await
347    }
348
349    pub async fn notify_progress(
350        &self,
351        session_id: &SessionId,
352        params: ProgressNotificationParams,
353    ) -> SdkResult<()> {
354        self.runtime_by_session(session_id)
355            .await?
356            .notify_progress(params)
357            .await
358    }
359
360    pub async fn notify_elicitation_completed(
361        &self,
362        session_id: &SessionId,
363        params: ElicitCompleteParams,
364    ) -> SdkResult<()> {
365        self.runtime_by_session(session_id)
366            .await?
367            .notify_elicitation_completed(params)
368            .await
369    }
370
371    pub async fn notify_custom(
372        &self,
373        session_id: &SessionId,
374        params: CustomNotification,
375    ) -> SdkResult<()> {
376        self.runtime_by_session(session_id)
377            .await?
378            .notify_custom(params)
379            .await
380    }
381
382    // --- Additional request methods (parity with AxumRuntime) ---
383
384    pub async fn request_elicitation_task(
385        &self,
386        session_id: &SessionId,
387        params: ElicitRequestParams,
388    ) -> SdkResult<CreateTaskResult> {
389        self.runtime_by_session(session_id)
390            .await?
391            .request_elicitation_task(params)
392            .await
393    }
394
395    pub async fn request_get_task_payload(
396        &self,
397        session_id: &SessionId,
398        params: GetTaskPayloadParams,
399    ) -> SdkResult<ClientTaskResult> {
400        self.runtime_by_session(session_id)
401            .await?
402            .request_get_task_payload(params)
403            .await
404    }
405
406    pub async fn request_task_cancellation(
407        &self,
408        session_id: &SessionId,
409        params: CancelTaskParams,
410    ) -> SdkResult<CancelTaskResult> {
411        self.runtime_by_session(session_id)
412            .await?
413            .request_task_cancellation(params)
414            .await
415    }
416
417    // --- Getters ---
418
419    pub fn task_store(&self) -> Option<Arc<ServerTaskStore>> {
420        self.state.task_store.clone()
421    }
422
423    pub fn client_task_store(&self) -> Option<Arc<ClientTaskStore>> {
424        self.state.client_task_store.clone()
425    }
426
427    // --- Deprecated aliases ---
428
429    #[deprecated(since = "0.8.0", note = "Use `request_root_list()` instead.")]
430    pub async fn list_roots(
431        &self,
432        session_id: &SessionId,
433        params: Option<RequestParams>,
434    ) -> SdkResult<ListRootsResult> {
435        self.request_root_list(session_id, params).await
436    }
437
438    #[deprecated(since = "0.8.0", note = "Use `request_elicitation()` instead.")]
439    pub async fn elicit_input(
440        &self,
441        session_id: &SessionId,
442        params: ElicitRequestParams,
443    ) -> SdkResult<ElicitResult> {
444        self.request_elicitation(session_id, params).await
445    }
446
447    #[deprecated(since = "0.8.0", note = "Use `request_message_creation()` instead.")]
448    pub async fn create_message(
449        &self,
450        session_id: &SessionId,
451        params: CreateMessageRequestParams,
452    ) -> SdkResult<CreateMessageResult> {
453        self.request_message_creation(session_id, params).await
454    }
455
456    #[deprecated(since = "0.8.0", note = "Use `notify_tool_list_changed()` instead.")]
457    pub async fn send_tool_list_changed(
458        &self,
459        session_id: &SessionId,
460        params: Option<NotificationParams>,
461    ) -> SdkResult<()> {
462        self.notify_tool_list_changed(session_id, params).await
463    }
464
465    #[deprecated(since = "0.8.0", note = "Use `notify_resource_updated()` instead.")]
466    pub async fn send_resource_updated(
467        &self,
468        session_id: &SessionId,
469        params: ResourceUpdatedNotificationParams,
470    ) -> SdkResult<()> {
471        self.notify_resource_updated(session_id, params).await
472    }
473
474    #[deprecated(
475        since = "0.8.0",
476        note = "Use `notify_resource_list_changed()` instead."
477    )]
478    pub async fn send_resource_list_changed(
479        &self,
480        session_id: &SessionId,
481        params: Option<NotificationParams>,
482    ) -> SdkResult<()> {
483        self.notify_resource_list_changed(session_id, params).await
484    }
485
486    #[deprecated(since = "0.8.0", note = "Use `notify_prompt_list_changed()` instead.")]
487    pub async fn send_prompt_list_changed(
488        &self,
489        session_id: &SessionId,
490        params: Option<NotificationParams>,
491    ) -> SdkResult<()> {
492        self.notify_prompt_list_changed(session_id, params).await
493    }
494
495    #[deprecated(since = "0.8.0", note = "Use `notify_log_message()` instead.")]
496    pub async fn send_logging_message(
497        &self,
498        session_id: &SessionId,
499        params: LoggingMessageNotificationParams,
500    ) -> SdkResult<()> {
501        self.notify_log_message(session_id, params).await
502    }
503}
504
505fn task_poller_callback(
506    client_task_store: Arc<ClientTaskStore>,
507    session_store: Arc<dyn SessionStore>,
508) -> TaskStatusPoller {
509    let session_store = session_store.clone();
510    let task_store_clone = client_task_store.clone();
511
512    let callback: TaskStatusPoller = Box::new(move |task_id, session_id| {
513        let session_store_clone = session_store.clone();
514        let task_store_clone = task_store_clone.clone();
515        Box::pin(async move {
516            let Some(session) = session_id.as_ref() else {
517                return Err(RpcError::invalid_request()
518                    .with_message("No session id provided!".to_string())
519                    .into());
520            };
521
522            let Some(runtime) = session_store_clone.get(session).await else {
523                return Err(RpcError::invalid_request()
524                    .with_message("Invalid or broken session!".to_string())
525                    .into());
526            };
527
528            runtime
529                .poll_task_status(task_id, session_id, task_store_clone)
530                .await
531        })
532    });
533    callback
534}
535
536use async_trait::async_trait;
537use rust_mcp_sdk::mcp_server::ServerRuntime;
538use rust_mcp_sdk::McpHttpServer;
539
540#[async_trait]
541impl McpHttpServer for ActixRuntime {
542    async fn graceful_shutdown(&self) {
543        self.graceful_shutdown(None);
544    }
545
546    async fn sessions(&self) -> Vec<SessionId> {
547        ActixRuntime::sessions(self).await
548    }
549
550    async fn runtime_by_session(&self, id: &SessionId) -> SdkResult<Arc<ServerRuntime>> {
551        ActixRuntime::runtime_by_session(self, id).await
552    }
553}
554
555#[cfg(feature = "ssl")]
556fn load_rustls_config(cert_path: &str, key_path: &str) -> std::io::Result<rustls::ServerConfig> {
557    use std::fs::File;
558    use std::io::BufReader;
559
560    let cert_file = File::open(cert_path)?;
561    let mut cert_reader = BufReader::new(cert_file);
562    let certs: Vec<rustls::pki_types::CertificateDer> = rustls_pemfile::certs(&mut cert_reader)
563        .collect::<Result<Vec<_>, _>>()
564        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
565
566    let key_file = File::open(key_path)?;
567    let mut key_reader = BufReader::new(key_file);
568    let key = rustls_pemfile::private_key(&mut key_reader)?.ok_or_else(|| {
569        std::io::Error::new(std::io::ErrorKind::InvalidInput, "no private key found")
570    })?;
571
572    let config = rustls::ServerConfig::builder()
573        .with_no_client_auth()
574        .with_single_cert(certs, key)
575        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
576
577    Ok(config)
578}
579
580#[cfg(all(test, feature = "ssl"))]
581mod ssl_tests {
582    #[test]
583    fn install_crypto_provider_idempotent() {
584        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
585        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
586    }
587}