Skip to main content

aster/agents/
mcp_client.rs

1use crate::action_required_manager::ActionRequiredManager;
2use crate::agents::types::SharedProvider;
3use crate::session_context::SESSION_ID_HEADER;
4use rmcp::model::{
5    Content, CreateElicitationRequestParam, CreateElicitationResult, ElicitationAction, ErrorCode,
6    JsonObject,
7};
// MCP client implementation for Aster
9use rmcp::{
10    model::{
11        CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification,
12        CancelledNotificationMethod, CancelledNotificationParam, ClientCapabilities, ClientInfo,
13        ClientRequest, CreateMessageRequestParam, CreateMessageResult, GetPromptRequest,
14        GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult,
15        ListPromptsRequest, ListPromptsResult, ListResourcesRequest, ListResourcesResult,
16        ListToolsRequest, ListToolsResult, LoggingMessageNotification,
17        LoggingMessageNotificationMethod, PaginatedRequestParam, ProgressNotification,
18        ProgressNotificationMethod, ProtocolVersion, ReadResourceRequest, ReadResourceRequestParam,
19        ReadResourceResult, RequestId, Role, SamplingMessage, ServerNotification, ServerResult,
20    },
21    service::{
22        ClientInitializeError, PeerRequestOptions, RequestContext, RequestHandle, RunningService,
23        ServiceRole,
24    },
25    transport::IntoTransport,
26    ClientHandler, ErrorData, Peer, RoleClient, ServiceError, ServiceExt,
27};
28use serde_json::Value;
29use std::{sync::Arc, time::Duration};
30use tokio::sync::{
31    mpsc::{self, Sender},
32    Mutex,
33};
34use tokio_util::sync::CancellationToken;
35
/// Boxed trait-object error that is safe to move and share across threads (`Send + Sync`).
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
37
38pub type Error = rmcp::ServiceError;
39
40#[async_trait::async_trait]
41pub trait McpClientTrait: Send + Sync {
42    async fn list_resources(
43        &self,
44        next_cursor: Option<String>,
45        cancel_token: CancellationToken,
46    ) -> Result<ListResourcesResult, Error>;
47
48    async fn read_resource(
49        &self,
50        uri: &str,
51        cancel_token: CancellationToken,
52    ) -> Result<ReadResourceResult, Error>;
53
54    async fn list_tools(
55        &self,
56        next_cursor: Option<String>,
57        cancel_token: CancellationToken,
58    ) -> Result<ListToolsResult, Error>;
59
60    async fn call_tool(
61        &self,
62        name: &str,
63        arguments: Option<JsonObject>,
64        cancel_token: CancellationToken,
65    ) -> Result<CallToolResult, Error>;
66
67    async fn list_prompts(
68        &self,
69        next_cursor: Option<String>,
70        cancel_token: CancellationToken,
71    ) -> Result<ListPromptsResult, Error>;
72
73    async fn get_prompt(
74        &self,
75        name: &str,
76        arguments: Value,
77        cancel_token: CancellationToken,
78    ) -> Result<GetPromptResult, Error>;
79
80    async fn subscribe(&self) -> mpsc::Receiver<ServerNotification>;
81
82    fn get_info(&self) -> Option<&InitializeResult>;
83
84    async fn get_moim(&self) -> Option<String> {
85        None
86    }
87}
88
89pub struct AsterClient {
90    notification_handlers: Arc<Mutex<Vec<Sender<ServerNotification>>>>,
91    provider: SharedProvider,
92}
93
94impl AsterClient {
95    pub fn new(
96        handlers: Arc<Mutex<Vec<Sender<ServerNotification>>>>,
97        provider: SharedProvider,
98    ) -> Self {
99        AsterClient {
100            notification_handlers: handlers,
101            provider,
102        }
103    }
104}
105
106impl ClientHandler for AsterClient {
107    async fn on_progress(
108        &self,
109        params: rmcp::model::ProgressNotificationParam,
110        context: rmcp::service::NotificationContext<rmcp::RoleClient>,
111    ) {
112        self.notification_handlers
113            .lock()
114            .await
115            .iter()
116            .for_each(|handler| {
117                let _ = handler.try_send(ServerNotification::ProgressNotification(
118                    ProgressNotification {
119                        params: params.clone(),
120                        method: ProgressNotificationMethod,
121                        extensions: context.extensions.clone(),
122                    },
123                ));
124            });
125    }
126
127    async fn on_logging_message(
128        &self,
129        params: rmcp::model::LoggingMessageNotificationParam,
130        context: rmcp::service::NotificationContext<rmcp::RoleClient>,
131    ) {
132        self.notification_handlers
133            .lock()
134            .await
135            .iter()
136            .for_each(|handler| {
137                let _ = handler.try_send(ServerNotification::LoggingMessageNotification(
138                    LoggingMessageNotification {
139                        params: params.clone(),
140                        method: LoggingMessageNotificationMethod,
141                        extensions: context.extensions.clone(),
142                    },
143                ));
144            });
145    }
146
147    async fn create_message(
148        &self,
149        params: CreateMessageRequestParam,
150        _context: RequestContext<RoleClient>,
151    ) -> Result<CreateMessageResult, ErrorData> {
152        let provider = self
153            .provider
154            .lock()
155            .await
156            .as_ref()
157            .ok_or(ErrorData::new(
158                ErrorCode::INTERNAL_ERROR,
159                "Could not use provider",
160                None,
161            ))?
162            .clone();
163
164        let provider_ready_messages: Vec<crate::conversation::message::Message> = params
165            .messages
166            .iter()
167            .map(|msg| {
168                let base = match msg.role {
169                    Role::User => crate::conversation::message::Message::user(),
170                    Role::Assistant => crate::conversation::message::Message::assistant(),
171                };
172
173                match msg.content.as_text() {
174                    Some(text) => base.with_text(&text.text),
175                    None => base.with_content(msg.content.clone().into()),
176                }
177            })
178            .collect();
179
180        let system_prompt = params
181            .system_prompt
182            .as_deref()
183            .unwrap_or("You are a general-purpose AI agent called aster");
184
185        // Build model config with sampling parameters
186        let mut model_config = provider.get_model_config();
187
188        // Apply model preferences if provided
189        // MCP model preferences include hints (model name patterns) and priority scores
190        if let Some(prefs) = &params.model_preferences {
191            // Try to find a matching model from hints
192            if let Some(hints) = &prefs.hints {
193                for hint in hints {
194                    if let Some(name) = &hint.name {
195                        // Use the hint name as the model name if it looks like a valid model
196                        // The hint name can be a full model name or a pattern
197                        if !name.is_empty() {
198                            model_config.model_name = name.clone();
199                            break;
200                        }
201                    }
202                }
203            }
204        }
205
206        // Apply maxTokens from the request (required field in MCP sampling)
207        model_config = model_config.with_max_tokens(Some(params.max_tokens as i32));
208
209        // Apply temperature if provided in the request
210        if let Some(temperature) = params.temperature {
211            model_config = model_config.with_temperature(Some(temperature));
212        }
213
214        // Use complete_with_model to apply the custom model config
215        let (response, usage) = provider
216            .complete_with_model(&model_config, system_prompt, &provider_ready_messages, &[])
217            .await
218            .map_err(|e| {
219                ErrorData::new(
220                    ErrorCode::INTERNAL_ERROR,
221                    "Unexpected error while completing the prompt",
222                    Some(Value::from(e.to_string())),
223                )
224            })?;
225
226        Ok(CreateMessageResult {
227            model: usage.model,
228            stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()),
229            message: SamplingMessage {
230                role: Role::Assistant,
231                // TODO(alexhancock): MCP sampling currently only supports one content on each SamplingMessage
232                // https://modelcontextprotocol.io/specification/draft/client/sampling#messages
233                // This doesn't mesh well with aster's approach which has Vec<MessageContent>
234                // There is a proposal to MCP which is agreed to go in the next version to have SamplingMessages support multiple content parts
235                // https://github.com/modelcontextprotocol/modelcontextprotocol/pull/198
236                // Until that is formalized, we can take the first message content from the provider and use it
237                content: if let Some(content) = response.content.first() {
238                    match content {
239                        crate::conversation::message::MessageContent::Text(text) => {
240                            Content::text(&text.text)
241                        }
242                        crate::conversation::message::MessageContent::Image(img) => {
243                            Content::image(&img.data, &img.mime_type)
244                        }
245                        // TODO(alexhancock) - Content::Audio? aster's messages don't currently have it
246                        _ => Content::text(""),
247                    }
248                } else {
249                    Content::text("")
250                },
251            },
252        })
253    }
254
255    async fn create_elicitation(
256        &self,
257        request: CreateElicitationRequestParam,
258        _context: RequestContext<RoleClient>,
259    ) -> Result<CreateElicitationResult, ErrorData> {
260        let schema_value = serde_json::to_value(&request.requested_schema).map_err(|e| {
261            ErrorData::new(
262                ErrorCode::INTERNAL_ERROR,
263                format!("Failed to serialize elicitation schema: {}", e),
264                None,
265            )
266        })?;
267
268        ActionRequiredManager::global()
269            .request_and_wait(
270                request.message.clone(),
271                schema_value,
272                Duration::from_secs(300),
273            )
274            .await
275            .map(|user_data| CreateElicitationResult {
276                action: ElicitationAction::Accept,
277                content: Some(user_data),
278            })
279            .map_err(|e| {
280                ErrorData::new(
281                    ErrorCode::INTERNAL_ERROR,
282                    format!("Elicitation request timed out or failed: {}", e),
283                    None,
284                )
285            })
286    }
287
288    fn get_info(&self) -> ClientInfo {
289        ClientInfo {
290            protocol_version: ProtocolVersion::V_2025_03_26,
291            capabilities: ClientCapabilities::builder()
292                .enable_sampling()
293                .enable_elicitation()
294                .build(),
295            client_info: Implementation {
296                name: "aster".to_string(),
297                version: std::env::var("ASTER_MCP_CLIENT_VERSION")
298                    .unwrap_or(env!("CARGO_PKG_VERSION").to_owned()),
299                icons: None,
300                title: None,
301                website_url: None,
302            },
303        }
304    }
305}
306
307/// The MCP client is the interface for MCP operations.
308pub struct McpClient {
309    client: Mutex<RunningService<RoleClient, AsterClient>>,
310    notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<ServerNotification>>>>,
311    server_info: Option<InitializeResult>,
312    timeout: std::time::Duration,
313}
314
315impl McpClient {
316    pub async fn connect<T, E, A>(
317        transport: T,
318        timeout: std::time::Duration,
319        provider: SharedProvider,
320    ) -> Result<Self, ClientInitializeError>
321    where
322        T: IntoTransport<RoleClient, E, A>,
323        E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
324    {
325        let notification_subscribers =
326            Arc::new(Mutex::new(Vec::<mpsc::Sender<ServerNotification>>::new()));
327
328        let client = AsterClient::new(notification_subscribers.clone(), provider);
329        let client: rmcp::service::RunningService<rmcp::RoleClient, AsterClient> =
330            client.serve(transport).await?;
331        let server_info = client.peer_info().cloned();
332
333        Ok(Self {
334            client: Mutex::new(client),
335            notification_subscribers,
336            server_info,
337            timeout,
338        })
339    }
340
341    async fn send_request(
342        &self,
343        request: ClientRequest,
344        cancel_token: CancellationToken,
345    ) -> Result<ServerResult, Error> {
346        let handle = self
347            .client
348            .lock()
349            .await
350            .send_cancellable_request(request, PeerRequestOptions::no_options())
351            .await?;
352
353        await_response(handle, self.timeout, &cancel_token).await
354    }
355}
356
357async fn await_response(
358    handle: RequestHandle<RoleClient>,
359    timeout: Duration,
360    cancel_token: &CancellationToken,
361) -> Result<<RoleClient as ServiceRole>::PeerResp, ServiceError> {
362    let receiver = handle.rx;
363    let peer = handle.peer;
364    let request_id = handle.id;
365    tokio::select! {
366        result = receiver => {
367            result.map_err(|_e| ServiceError::TransportClosed)?
368        }
369        _ = tokio::time::sleep(timeout) => {
370            send_cancel_message(&peer, request_id, Some("timed out".to_owned())).await?;
371            Err(ServiceError::Timeout{timeout})
372        }
373        _ = cancel_token.cancelled() => {
374            send_cancel_message(&peer, request_id, Some("operation cancelled".to_owned())).await?;
375            Err(ServiceError::Cancelled { reason: None })
376        }
377    }
378}
379
380async fn send_cancel_message(
381    peer: &Peer<RoleClient>,
382    request_id: RequestId,
383    reason: Option<String>,
384) -> Result<(), ServiceError> {
385    peer.send_notification(
386        CancelledNotification {
387            params: CancelledNotificationParam { request_id, reason },
388            method: CancelledNotificationMethod,
389            extensions: Default::default(),
390        }
391        .into(),
392    )
393    .await
394}
395
396#[async_trait::async_trait]
397impl McpClientTrait for McpClient {
398    fn get_info(&self) -> Option<&InitializeResult> {
399        self.server_info.as_ref()
400    }
401
402    async fn list_resources(
403        &self,
404        cursor: Option<String>,
405        cancel_token: CancellationToken,
406    ) -> Result<ListResourcesResult, Error> {
407        let res = self
408            .send_request(
409                ClientRequest::ListResourcesRequest(ListResourcesRequest {
410                    params: Some(PaginatedRequestParam { cursor }),
411                    method: Default::default(),
412                    extensions: inject_session_into_extensions(Default::default()),
413                }),
414                cancel_token,
415            )
416            .await?;
417
418        match res {
419            ServerResult::ListResourcesResult(result) => Ok(result),
420            _ => Err(ServiceError::UnexpectedResponse),
421        }
422    }
423
424    async fn read_resource(
425        &self,
426        uri: &str,
427        cancel_token: CancellationToken,
428    ) -> Result<ReadResourceResult, Error> {
429        let res = self
430            .send_request(
431                ClientRequest::ReadResourceRequest(ReadResourceRequest {
432                    params: ReadResourceRequestParam {
433                        uri: uri.to_string(),
434                    },
435                    method: Default::default(),
436                    extensions: inject_session_into_extensions(Default::default()),
437                }),
438                cancel_token,
439            )
440            .await?;
441
442        match res {
443            ServerResult::ReadResourceResult(result) => Ok(result),
444            _ => Err(ServiceError::UnexpectedResponse),
445        }
446    }
447
448    async fn list_tools(
449        &self,
450        cursor: Option<String>,
451        cancel_token: CancellationToken,
452    ) -> Result<ListToolsResult, Error> {
453        let res = self
454            .send_request(
455                ClientRequest::ListToolsRequest(ListToolsRequest {
456                    params: Some(PaginatedRequestParam { cursor }),
457                    method: Default::default(),
458                    extensions: inject_session_into_extensions(Default::default()),
459                }),
460                cancel_token,
461            )
462            .await?;
463
464        match res {
465            ServerResult::ListToolsResult(result) => Ok(result),
466            _ => Err(ServiceError::UnexpectedResponse),
467        }
468    }
469
470    async fn call_tool(
471        &self,
472        name: &str,
473        arguments: Option<JsonObject>,
474        cancel_token: CancellationToken,
475    ) -> Result<CallToolResult, Error> {
476        let res = self
477            .send_request(
478                ClientRequest::CallToolRequest(CallToolRequest {
479                    params: CallToolRequestParam {
480                        name: name.to_string().into(),
481                        arguments,
482                    },
483                    method: Default::default(),
484                    extensions: inject_session_into_extensions(Default::default()),
485                }),
486                cancel_token,
487            )
488            .await?;
489
490        match res {
491            ServerResult::CallToolResult(result) => Ok(result),
492            _ => Err(ServiceError::UnexpectedResponse),
493        }
494    }
495
496    async fn list_prompts(
497        &self,
498        cursor: Option<String>,
499        cancel_token: CancellationToken,
500    ) -> Result<ListPromptsResult, Error> {
501        let res = self
502            .send_request(
503                ClientRequest::ListPromptsRequest(ListPromptsRequest {
504                    params: Some(PaginatedRequestParam { cursor }),
505                    method: Default::default(),
506                    extensions: inject_session_into_extensions(Default::default()),
507                }),
508                cancel_token,
509            )
510            .await?;
511
512        match res {
513            ServerResult::ListPromptsResult(result) => Ok(result),
514            _ => Err(ServiceError::UnexpectedResponse),
515        }
516    }
517
518    async fn get_prompt(
519        &self,
520        name: &str,
521        arguments: Value,
522        cancel_token: CancellationToken,
523    ) -> Result<GetPromptResult, Error> {
524        let arguments = match arguments {
525            Value::Object(map) => Some(map),
526            _ => None,
527        };
528        let res = self
529            .send_request(
530                ClientRequest::GetPromptRequest(GetPromptRequest {
531                    params: GetPromptRequestParam {
532                        name: name.to_string(),
533                        arguments,
534                    },
535                    method: Default::default(),
536                    extensions: inject_session_into_extensions(Default::default()),
537                }),
538                cancel_token,
539            )
540            .await?;
541
542        match res {
543            ServerResult::GetPromptResult(result) => Ok(result),
544            _ => Err(ServiceError::UnexpectedResponse),
545        }
546    }
547
548    async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
549        let (tx, rx) = mpsc::channel(16);
550        self.notification_subscribers.lock().await.push(tx);
551        rx
552    }
553}
554
555/// Replaces session ID, case-insensitively, in Extensions._meta.
556fn inject_session_into_extensions(
557    mut extensions: rmcp::model::Extensions,
558) -> rmcp::model::Extensions {
559    use rmcp::model::Meta;
560
561    if let Some(session_id) = crate::session_context::current_session_id() {
562        let mut meta_map = extensions
563            .get::<Meta>()
564            .map(|meta| meta.0.clone())
565            .unwrap_or_default();
566
567        // JsonObject is case-sensitive, so we use retain for case-insensitive removal
568        meta_map.retain(|k, _| !k.eq_ignore_ascii_case(SESSION_ID_HEADER));
569
570        meta_map.insert(SESSION_ID_HEADER.to_string(), Value::String(session_id));
571
572        extensions.insert(Meta(meta_map));
573    }
574
575    extensions
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{Extensions, Meta};
    use serde_json::{from_value, json};

    /// The `_meta` map expected when the session ID is its only entry.
    fn meta_with_only_session(session_id: &str) -> JsonObject {
        match json!({ SESSION_ID_HEADER: session_id }) {
            Value::Object(map) => map,
            other => panic!("expected a JSON object, got {other}"),
        }
    }

    #[tokio::test]
    async fn test_session_id_in_mcp_meta() {
        let session_id = "test-session-789";
        crate::session_context::with_session_id(Some(session_id.to_string()), async {
            let extensions = inject_session_into_extensions(Default::default());
            let meta = extensions.get::<Meta>().unwrap();

            assert_eq!(meta.0, meta_with_only_session(session_id));
        })
        .await;
    }

    #[tokio::test]
    async fn test_no_session_id_in_mcp_when_absent() {
        let extensions = inject_session_into_extensions(Default::default());

        assert!(extensions.get::<Meta>().is_none());
    }

    #[tokio::test]
    async fn test_all_mcp_operations_include_session() {
        let session_id = "consistent-session-id";
        crate::session_context::with_session_id(Some(session_id.to_string()), async {
            let expected = meta_with_only_session(session_id);

            // Repeated injections within one session must all produce the same metadata.
            for _ in 0..3 {
                let extensions = inject_session_into_extensions(Default::default());
                assert_eq!(extensions.get::<Meta>().unwrap().0, expected);
            }
        })
        .await;
    }

    #[tokio::test]
    async fn test_session_id_case_insensitive_replacement() {
        let session_id = "new-session-id";
        crate::session_context::with_session_id(Some(session_id.to_string()), async {
            let stale_meta: Meta = from_value(json!({
                "ASTER-SESSION-ID": "old-session-1",
                "Aster-Session-Id": "old-session-2",
                "other-key": "preserve-me"
            }))
            .unwrap();
            let mut extensions = Extensions::new();
            extensions.insert(stale_meta);

            let extensions = inject_session_into_extensions(extensions);

            let mut expected = meta_with_only_session(session_id);
            expected.insert("other-key".to_string(), json!("preserve-me"));
            assert_eq!(extensions.get::<Meta>().unwrap().0, expected);
        })
        .await;
    }
}