1use crate::action_required_manager::ActionRequiredManager;
2use crate::agents::types::SharedProvider;
3use crate::session_context::SESSION_ID_HEADER;
4use rmcp::model::{
5 Content, CreateElicitationRequestParam, CreateElicitationResult, ElicitationAction, ErrorCode,
6 JsonObject,
7};
8use rmcp::{
10 model::{
11 CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification,
12 CancelledNotificationMethod, CancelledNotificationParam, ClientCapabilities, ClientInfo,
13 ClientRequest, CreateMessageRequestParam, CreateMessageResult, GetPromptRequest,
14 GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult,
15 ListPromptsRequest, ListPromptsResult, ListResourcesRequest, ListResourcesResult,
16 ListToolsRequest, ListToolsResult, LoggingMessageNotification,
17 LoggingMessageNotificationMethod, PaginatedRequestParam, ProgressNotification,
18 ProgressNotificationMethod, ProtocolVersion, ReadResourceRequest, ReadResourceRequestParam,
19 ReadResourceResult, RequestId, Role, SamplingMessage, ServerNotification, ServerResult,
20 },
21 service::{
22 ClientInitializeError, PeerRequestOptions, RequestContext, RequestHandle, RunningService,
23 ServiceRole,
24 },
25 transport::IntoTransport,
26 ClientHandler, ErrorData, Peer, RoleClient, ServiceError, ServiceExt,
27};
28use serde_json::Value;
29use std::{sync::Arc, time::Duration};
30use tokio::sync::{
31 mpsc::{self, Sender},
32 Mutex,
33};
34use tokio_util::sync::CancellationToken;
35
/// Type-erased error trait object that is both `Send` and `Sync`.
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;

/// Error type returned by the client operations in this module; an alias for
/// [`rmcp::ServiceError`].
pub type Error = rmcp::ServiceError;
39
/// Asynchronous interface over a server connection: resource, tool and prompt
/// requests plus a notification subscription.
///
/// Every request method takes a [`CancellationToken`]; the `McpClient`
/// implementation in this file stops waiting for the response when the token
/// is cancelled.
#[async_trait::async_trait]
pub trait McpClientTrait: Send + Sync {
    /// Lists resources, optionally continuing from the pagination cursor
    /// `next_cursor`.
    async fn list_resources(
        &self,
        next_cursor: Option<String>,
        cancel_token: CancellationToken,
    ) -> Result<ListResourcesResult, Error>;

    /// Reads the resource identified by `uri`.
    async fn read_resource(
        &self,
        uri: &str,
        cancel_token: CancellationToken,
    ) -> Result<ReadResourceResult, Error>;

    /// Lists tools, optionally continuing from the pagination cursor
    /// `next_cursor`.
    async fn list_tools(
        &self,
        next_cursor: Option<String>,
        cancel_token: CancellationToken,
    ) -> Result<ListToolsResult, Error>;

    /// Calls the tool `name` with the optional JSON-object `arguments`.
    async fn call_tool(
        &self,
        name: &str,
        arguments: Option<JsonObject>,
        cancel_token: CancellationToken,
    ) -> Result<CallToolResult, Error>;

    /// Lists prompts, optionally continuing from the pagination cursor
    /// `next_cursor`.
    async fn list_prompts(
        &self,
        next_cursor: Option<String>,
        cancel_token: CancellationToken,
    ) -> Result<ListPromptsResult, Error>;

    /// Fetches the prompt `name`, passing `arguments` along with the request.
    async fn get_prompt(
        &self,
        name: &str,
        arguments: Value,
        cancel_token: CancellationToken,
    ) -> Result<GetPromptResult, Error>;

    /// Returns a receiver on which server notifications are delivered.
    async fn subscribe(&self) -> mpsc::Receiver<ServerNotification>;

    /// Returns the server's initialization result, if one is available.
    fn get_info(&self) -> Option<&InitializeResult>;

    /// Optional hook; the default implementation returns `None`.
    // NOTE(review): the meaning of "moim" is not visible in this file — confirm
    // against the implementors that override it.
    async fn get_moim(&self) -> Option<String> {
        None
    }
}
88
/// Client-side handler state: the notification senders that incoming server
/// notifications are forwarded to, and the shared provider used to answer
/// `create_message` requests.
pub struct AsterClient {
    // Senders for every subscriber that should receive server notifications.
    notification_handlers: Arc<Mutex<Vec<Sender<ServerNotification>>>>,
    // Shared provider handle; locked and read when a message must be created.
    provider: SharedProvider,
}
93
94impl AsterClient {
95 pub fn new(
96 handlers: Arc<Mutex<Vec<Sender<ServerNotification>>>>,
97 provider: SharedProvider,
98 ) -> Self {
99 AsterClient {
100 notification_handlers: handlers,
101 provider,
102 }
103 }
104}
105
106impl ClientHandler for AsterClient {
107 async fn on_progress(
108 &self,
109 params: rmcp::model::ProgressNotificationParam,
110 context: rmcp::service::NotificationContext<rmcp::RoleClient>,
111 ) {
112 self.notification_handlers
113 .lock()
114 .await
115 .iter()
116 .for_each(|handler| {
117 let _ = handler.try_send(ServerNotification::ProgressNotification(
118 ProgressNotification {
119 params: params.clone(),
120 method: ProgressNotificationMethod,
121 extensions: context.extensions.clone(),
122 },
123 ));
124 });
125 }
126
127 async fn on_logging_message(
128 &self,
129 params: rmcp::model::LoggingMessageNotificationParam,
130 context: rmcp::service::NotificationContext<rmcp::RoleClient>,
131 ) {
132 self.notification_handlers
133 .lock()
134 .await
135 .iter()
136 .for_each(|handler| {
137 let _ = handler.try_send(ServerNotification::LoggingMessageNotification(
138 LoggingMessageNotification {
139 params: params.clone(),
140 method: LoggingMessageNotificationMethod,
141 extensions: context.extensions.clone(),
142 },
143 ));
144 });
145 }
146
147 async fn create_message(
148 &self,
149 params: CreateMessageRequestParam,
150 _context: RequestContext<RoleClient>,
151 ) -> Result<CreateMessageResult, ErrorData> {
152 let provider = self
153 .provider
154 .lock()
155 .await
156 .as_ref()
157 .ok_or(ErrorData::new(
158 ErrorCode::INTERNAL_ERROR,
159 "Could not use provider",
160 None,
161 ))?
162 .clone();
163
164 let provider_ready_messages: Vec<crate::conversation::message::Message> = params
165 .messages
166 .iter()
167 .map(|msg| {
168 let base = match msg.role {
169 Role::User => crate::conversation::message::Message::user(),
170 Role::Assistant => crate::conversation::message::Message::assistant(),
171 };
172
173 match msg.content.as_text() {
174 Some(text) => base.with_text(&text.text),
175 None => base.with_content(msg.content.clone().into()),
176 }
177 })
178 .collect();
179
180 let system_prompt = params
181 .system_prompt
182 .as_deref()
183 .unwrap_or("You are a general-purpose AI agent called aster");
184
185 let mut model_config = provider.get_model_config();
187
188 if let Some(prefs) = ¶ms.model_preferences {
191 if let Some(hints) = &prefs.hints {
193 for hint in hints {
194 if let Some(name) = &hint.name {
195 if !name.is_empty() {
198 model_config.model_name = name.clone();
199 break;
200 }
201 }
202 }
203 }
204 }
205
206 model_config = model_config.with_max_tokens(Some(params.max_tokens as i32));
208
209 if let Some(temperature) = params.temperature {
211 model_config = model_config.with_temperature(Some(temperature));
212 }
213
214 let (response, usage) = provider
216 .complete_with_model(&model_config, system_prompt, &provider_ready_messages, &[])
217 .await
218 .map_err(|e| {
219 ErrorData::new(
220 ErrorCode::INTERNAL_ERROR,
221 "Unexpected error while completing the prompt",
222 Some(Value::from(e.to_string())),
223 )
224 })?;
225
226 Ok(CreateMessageResult {
227 model: usage.model,
228 stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()),
229 message: SamplingMessage {
230 role: Role::Assistant,
231 content: if let Some(content) = response.content.first() {
238 match content {
239 crate::conversation::message::MessageContent::Text(text) => {
240 Content::text(&text.text)
241 }
242 crate::conversation::message::MessageContent::Image(img) => {
243 Content::image(&img.data, &img.mime_type)
244 }
245 _ => Content::text(""),
247 }
248 } else {
249 Content::text("")
250 },
251 },
252 })
253 }
254
255 async fn create_elicitation(
256 &self,
257 request: CreateElicitationRequestParam,
258 _context: RequestContext<RoleClient>,
259 ) -> Result<CreateElicitationResult, ErrorData> {
260 let schema_value = serde_json::to_value(&request.requested_schema).map_err(|e| {
261 ErrorData::new(
262 ErrorCode::INTERNAL_ERROR,
263 format!("Failed to serialize elicitation schema: {}", e),
264 None,
265 )
266 })?;
267
268 ActionRequiredManager::global()
269 .request_and_wait(
270 request.message.clone(),
271 schema_value,
272 Duration::from_secs(300),
273 )
274 .await
275 .map(|user_data| CreateElicitationResult {
276 action: ElicitationAction::Accept,
277 content: Some(user_data),
278 })
279 .map_err(|e| {
280 ErrorData::new(
281 ErrorCode::INTERNAL_ERROR,
282 format!("Elicitation request timed out or failed: {}", e),
283 None,
284 )
285 })
286 }
287
288 fn get_info(&self) -> ClientInfo {
289 ClientInfo {
290 protocol_version: ProtocolVersion::V_2025_03_26,
291 capabilities: ClientCapabilities::builder()
292 .enable_sampling()
293 .enable_elicitation()
294 .build(),
295 client_info: Implementation {
296 name: "aster".to_string(),
297 version: std::env::var("ASTER_MCP_CLIENT_VERSION")
298 .unwrap_or(env!("CARGO_PKG_VERSION").to_owned()),
299 icons: None,
300 title: None,
301 website_url: None,
302 },
303 }
304 }
305}
306
/// A connected client: the running `rmcp` service plus the state needed to
/// issue requests and hand out notification subscriptions.
pub struct McpClient {
    // Running service; locked while a request is being sent.
    client: Mutex<RunningService<RoleClient, AsterClient>>,
    // Senders registered through `subscribe`; shared with the `AsterClient` handler.
    notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<ServerNotification>>>>,
    // Peer info captured from the service when the connection was established.
    server_info: Option<InitializeResult>,
    // How long to wait for a response before the request is cancelled.
    timeout: std::time::Duration,
}
314
315impl McpClient {
316 pub async fn connect<T, E, A>(
317 transport: T,
318 timeout: std::time::Duration,
319 provider: SharedProvider,
320 ) -> Result<Self, ClientInitializeError>
321 where
322 T: IntoTransport<RoleClient, E, A>,
323 E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
324 {
325 let notification_subscribers =
326 Arc::new(Mutex::new(Vec::<mpsc::Sender<ServerNotification>>::new()));
327
328 let client = AsterClient::new(notification_subscribers.clone(), provider);
329 let client: rmcp::service::RunningService<rmcp::RoleClient, AsterClient> =
330 client.serve(transport).await?;
331 let server_info = client.peer_info().cloned();
332
333 Ok(Self {
334 client: Mutex::new(client),
335 notification_subscribers,
336 server_info,
337 timeout,
338 })
339 }
340
341 async fn send_request(
342 &self,
343 request: ClientRequest,
344 cancel_token: CancellationToken,
345 ) -> Result<ServerResult, Error> {
346 let handle = self
347 .client
348 .lock()
349 .await
350 .send_cancellable_request(request, PeerRequestOptions::no_options())
351 .await?;
352
353 await_response(handle, self.timeout, &cancel_token).await
354 }
355}
356
357async fn await_response(
358 handle: RequestHandle<RoleClient>,
359 timeout: Duration,
360 cancel_token: &CancellationToken,
361) -> Result<<RoleClient as ServiceRole>::PeerResp, ServiceError> {
362 let receiver = handle.rx;
363 let peer = handle.peer;
364 let request_id = handle.id;
365 tokio::select! {
366 result = receiver => {
367 result.map_err(|_e| ServiceError::TransportClosed)?
368 }
369 _ = tokio::time::sleep(timeout) => {
370 send_cancel_message(&peer, request_id, Some("timed out".to_owned())).await?;
371 Err(ServiceError::Timeout{timeout})
372 }
373 _ = cancel_token.cancelled() => {
374 send_cancel_message(&peer, request_id, Some("operation cancelled".to_owned())).await?;
375 Err(ServiceError::Cancelled { reason: None })
376 }
377 }
378}
379
380async fn send_cancel_message(
381 peer: &Peer<RoleClient>,
382 request_id: RequestId,
383 reason: Option<String>,
384) -> Result<(), ServiceError> {
385 peer.send_notification(
386 CancelledNotification {
387 params: CancelledNotificationParam { request_id, reason },
388 method: CancelledNotificationMethod,
389 extensions: Default::default(),
390 }
391 .into(),
392 )
393 .await
394}
395
396#[async_trait::async_trait]
397impl McpClientTrait for McpClient {
398 fn get_info(&self) -> Option<&InitializeResult> {
399 self.server_info.as_ref()
400 }
401
402 async fn list_resources(
403 &self,
404 cursor: Option<String>,
405 cancel_token: CancellationToken,
406 ) -> Result<ListResourcesResult, Error> {
407 let res = self
408 .send_request(
409 ClientRequest::ListResourcesRequest(ListResourcesRequest {
410 params: Some(PaginatedRequestParam { cursor }),
411 method: Default::default(),
412 extensions: inject_session_into_extensions(Default::default()),
413 }),
414 cancel_token,
415 )
416 .await?;
417
418 match res {
419 ServerResult::ListResourcesResult(result) => Ok(result),
420 _ => Err(ServiceError::UnexpectedResponse),
421 }
422 }
423
424 async fn read_resource(
425 &self,
426 uri: &str,
427 cancel_token: CancellationToken,
428 ) -> Result<ReadResourceResult, Error> {
429 let res = self
430 .send_request(
431 ClientRequest::ReadResourceRequest(ReadResourceRequest {
432 params: ReadResourceRequestParam {
433 uri: uri.to_string(),
434 },
435 method: Default::default(),
436 extensions: inject_session_into_extensions(Default::default()),
437 }),
438 cancel_token,
439 )
440 .await?;
441
442 match res {
443 ServerResult::ReadResourceResult(result) => Ok(result),
444 _ => Err(ServiceError::UnexpectedResponse),
445 }
446 }
447
448 async fn list_tools(
449 &self,
450 cursor: Option<String>,
451 cancel_token: CancellationToken,
452 ) -> Result<ListToolsResult, Error> {
453 let res = self
454 .send_request(
455 ClientRequest::ListToolsRequest(ListToolsRequest {
456 params: Some(PaginatedRequestParam { cursor }),
457 method: Default::default(),
458 extensions: inject_session_into_extensions(Default::default()),
459 }),
460 cancel_token,
461 )
462 .await?;
463
464 match res {
465 ServerResult::ListToolsResult(result) => Ok(result),
466 _ => Err(ServiceError::UnexpectedResponse),
467 }
468 }
469
470 async fn call_tool(
471 &self,
472 name: &str,
473 arguments: Option<JsonObject>,
474 cancel_token: CancellationToken,
475 ) -> Result<CallToolResult, Error> {
476 let res = self
477 .send_request(
478 ClientRequest::CallToolRequest(CallToolRequest {
479 params: CallToolRequestParam {
480 name: name.to_string().into(),
481 arguments,
482 },
483 method: Default::default(),
484 extensions: inject_session_into_extensions(Default::default()),
485 }),
486 cancel_token,
487 )
488 .await?;
489
490 match res {
491 ServerResult::CallToolResult(result) => Ok(result),
492 _ => Err(ServiceError::UnexpectedResponse),
493 }
494 }
495
496 async fn list_prompts(
497 &self,
498 cursor: Option<String>,
499 cancel_token: CancellationToken,
500 ) -> Result<ListPromptsResult, Error> {
501 let res = self
502 .send_request(
503 ClientRequest::ListPromptsRequest(ListPromptsRequest {
504 params: Some(PaginatedRequestParam { cursor }),
505 method: Default::default(),
506 extensions: inject_session_into_extensions(Default::default()),
507 }),
508 cancel_token,
509 )
510 .await?;
511
512 match res {
513 ServerResult::ListPromptsResult(result) => Ok(result),
514 _ => Err(ServiceError::UnexpectedResponse),
515 }
516 }
517
518 async fn get_prompt(
519 &self,
520 name: &str,
521 arguments: Value,
522 cancel_token: CancellationToken,
523 ) -> Result<GetPromptResult, Error> {
524 let arguments = match arguments {
525 Value::Object(map) => Some(map),
526 _ => None,
527 };
528 let res = self
529 .send_request(
530 ClientRequest::GetPromptRequest(GetPromptRequest {
531 params: GetPromptRequestParam {
532 name: name.to_string(),
533 arguments,
534 },
535 method: Default::default(),
536 extensions: inject_session_into_extensions(Default::default()),
537 }),
538 cancel_token,
539 )
540 .await?;
541
542 match res {
543 ServerResult::GetPromptResult(result) => Ok(result),
544 _ => Err(ServiceError::UnexpectedResponse),
545 }
546 }
547
548 async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
549 let (tx, rx) = mpsc::channel(16);
550 self.notification_subscribers.lock().await.push(tx);
551 rx
552 }
553}
554
555fn inject_session_into_extensions(
557 mut extensions: rmcp::model::Extensions,
558) -> rmcp::model::Extensions {
559 use rmcp::model::Meta;
560
561 if let Some(session_id) = crate::session_context::current_session_id() {
562 let mut meta_map = extensions
563 .get::<Meta>()
564 .map(|meta| meta.0.clone())
565 .unwrap_or_default();
566
567 meta_map.retain(|k, _| !k.eq_ignore_ascii_case(SESSION_ID_HEADER));
569
570 meta_map.insert(SESSION_ID_HEADER.to_string(), Value::String(session_id));
571
572 extensions.insert(Meta(meta_map));
573 }
574
575 extensions
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::Meta;

    /// With a session id in scope, the injected metadata holds exactly that id.
    #[tokio::test]
    async fn test_session_id_in_mcp_meta() {
        use serde_json::json;

        let session_id = "test-session-789";
        let scenario = async {
            let extensions = inject_session_into_extensions(Default::default());
            let expected = json!({ SESSION_ID_HEADER: session_id });

            let meta = extensions.get::<Meta>().unwrap();
            assert_eq!(&meta.0, expected.as_object().unwrap());
        };

        crate::session_context::with_session_id(Some(session_id.to_string()), scenario).await;
    }

    /// Without a session id in scope, no metadata is added.
    #[tokio::test]
    async fn test_no_session_id_in_mcp_when_absent() {
        let extensions = inject_session_into_extensions(Default::default());

        assert!(extensions.get::<Meta>().is_none());
    }

    /// Repeated injections within one session all carry the same id.
    #[tokio::test]
    async fn test_all_mcp_operations_include_session() {
        use serde_json::json;

        let session_id = "consistent-session-id";
        let scenario = async {
            let injected = [
                inject_session_into_extensions(Default::default()),
                inject_session_into_extensions(Default::default()),
                inject_session_into_extensions(Default::default()),
            ];
            let expected = json!({ SESSION_ID_HEADER: session_id });

            for ext in &injected {
                assert_eq!(
                    &ext.get::<Meta>().unwrap().0,
                    expected.as_object().unwrap()
                );
            }
        };

        crate::session_context::with_session_id(Some(session_id.to_string()), scenario).await;
    }

    /// Existing session-id keys in any letter case are replaced, while
    /// unrelated keys survive.
    #[tokio::test]
    async fn test_session_id_case_insensitive_replacement() {
        use rmcp::model::{Extensions, Meta};
        use serde_json::{from_value, json};

        let session_id = "new-session-id";
        let scenario = async {
            let stale_meta = from_value::<Meta>(json!({
                "ASTER-SESSION-ID": "old-session-1",
                "Aster-Session-Id": "old-session-2",
                "other-key": "preserve-me"
            }))
            .unwrap();

            let mut extensions = Extensions::new();
            extensions.insert(stale_meta);

            let extensions = inject_session_into_extensions(extensions);
            let expected = json!({
                SESSION_ID_HEADER: session_id,
                "other-key": "preserve-me"
            });

            let meta = extensions.get::<Meta>().unwrap();
            assert_eq!(&meta.0, expected.as_object().unwrap());
        };

        crate::session_context::with_session_id(Some(session_id.to_string()), scenario).await;
    }
}