1use std::future::Future;
7use std::pin::Pin;
8use std::sync::{Arc, Mutex};
9
10use asupersync::time::wall_now;
11use asupersync::types::CancelReason;
12use asupersync::{Budget, CancelKind, Cx, Outcome, RegionId, TaskId};
13
14use crate::{AuthContext, SessionState};
15
/// Sink for progress notifications.
///
/// The `Send + Sync` supertraits let an implementation be shared behind an
/// `Arc<dyn NotificationSender>`, which is how `ProgressReporter` holds it.
pub trait NotificationSender: Send + Sync {
    /// Delivers a single progress update.
    ///
    /// * `progress` - the current progress value.
    /// * `total` - the total the progress is measured against, if known.
    /// * `message` - optional status text accompanying the update.
    fn send_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>);
}
34
35pub trait SamplingSender: Send + Sync {
45 fn create_message(
56 &self,
57 request: SamplingRequest,
58 ) -> std::pin::Pin<
59 Box<dyn std::future::Future<Output = crate::McpResult<SamplingResponse>> + Send + '_>,
60 >;
61}
62
63#[derive(Debug, Clone)]
65pub struct SamplingRequest {
66 pub messages: Vec<SamplingRequestMessage>,
68 pub max_tokens: u32,
70 pub system_prompt: Option<String>,
72 pub temperature: Option<f64>,
74 pub stop_sequences: Vec<String>,
76 pub model_hints: Vec<String>,
78}
79
80impl SamplingRequest {
81 #[must_use]
83 pub fn new(messages: Vec<SamplingRequestMessage>, max_tokens: u32) -> Self {
84 Self {
85 messages,
86 max_tokens,
87 system_prompt: None,
88 temperature: None,
89 stop_sequences: Vec::new(),
90 model_hints: Vec::new(),
91 }
92 }
93
94 #[must_use]
96 pub fn prompt(text: impl Into<String>, max_tokens: u32) -> Self {
97 Self::new(vec![SamplingRequestMessage::user(text)], max_tokens)
98 }
99
100 #[must_use]
102 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
103 self.system_prompt = Some(prompt.into());
104 self
105 }
106
107 #[must_use]
109 pub fn with_temperature(mut self, temp: f64) -> Self {
110 self.temperature = Some(temp);
111 self
112 }
113
114 #[must_use]
116 pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
117 self.stop_sequences = sequences;
118 self
119 }
120
121 #[must_use]
123 pub fn with_model_hints(mut self, hints: Vec<String>) -> Self {
124 self.model_hints = hints;
125 self
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct SamplingRequestMessage {
132 pub role: SamplingRole,
134 pub text: String,
136}
137
138impl SamplingRequestMessage {
139 #[must_use]
141 pub fn user(text: impl Into<String>) -> Self {
142 Self {
143 role: SamplingRole::User,
144 text: text.into(),
145 }
146 }
147
148 #[must_use]
150 pub fn assistant(text: impl Into<String>) -> Self {
151 Self {
152 role: SamplingRole::Assistant,
153 text: text.into(),
154 }
155 }
156}
157
/// Author role attached to a sampling message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingRole {
    /// The message comes from the user side of the conversation.
    User,
    /// The message comes from the assistant side of the conversation.
    Assistant,
}
166
167#[derive(Debug, Clone)]
169pub struct SamplingResponse {
170 pub text: String,
172 pub model: String,
174 pub stop_reason: SamplingStopReason,
176}
177
178impl SamplingResponse {
179 #[must_use]
181 pub fn new(text: impl Into<String>, model: impl Into<String>) -> Self {
182 Self {
183 text: text.into(),
184 model: model.into(),
185 stop_reason: SamplingStopReason::EndTurn,
186 }
187 }
188}
189
/// Reason a sampling response stopped generating.
///
/// The default is [`SamplingStopReason::EndTurn`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SamplingStopReason {
    /// Generation ended at the end of the turn (the default).
    #[default]
    EndTurn,
    /// Generation ended because of a stop sequence.
    StopSequence,
    /// Generation ended because of the token limit.
    MaxTokens,
}
201
202#[derive(Debug, Clone, Copy, Default)]
206pub struct NoOpSamplingSender;
207
208impl SamplingSender for NoOpSamplingSender {
209 fn create_message(
210 &self,
211 _request: SamplingRequest,
212 ) -> std::pin::Pin<
213 Box<dyn std::future::Future<Output = crate::McpResult<SamplingResponse>> + Send + '_>,
214 > {
215 Box::pin(async {
216 Err(crate::McpError::new(
217 crate::McpErrorCode::InvalidRequest,
218 "Sampling not supported: client does not have sampling capability",
219 ))
220 })
221 }
222}
223
224pub trait ElicitationSender: Send + Sync {
234 fn elicit(
245 &self,
246 request: ElicitationRequest,
247 ) -> std::pin::Pin<
248 Box<dyn std::future::Future<Output = crate::McpResult<ElicitationResponse>> + Send + '_>,
249 >;
250}
251
252#[derive(Debug, Clone)]
254pub struct ElicitationRequest {
255 pub mode: ElicitationMode,
257 pub message: String,
259 pub schema: Option<serde_json::Value>,
261 pub url: Option<String>,
263 pub elicitation_id: Option<String>,
265}
266
267impl ElicitationRequest {
268 #[must_use]
270 pub fn form(message: impl Into<String>, schema: serde_json::Value) -> Self {
271 Self {
272 mode: ElicitationMode::Form,
273 message: message.into(),
274 schema: Some(schema),
275 url: None,
276 elicitation_id: None,
277 }
278 }
279
280 #[must_use]
282 pub fn url(
283 message: impl Into<String>,
284 url: impl Into<String>,
285 elicitation_id: impl Into<String>,
286 ) -> Self {
287 Self {
288 mode: ElicitationMode::Url,
289 message: message.into(),
290 schema: None,
291 url: Some(url.into()),
292 elicitation_id: Some(elicitation_id.into()),
293 }
294 }
295}
296
/// Kind of elicitation being requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElicitationMode {
    /// Form elicitation (the request carries a schema).
    Form,
    /// URL elicitation (the request carries a URL and an elicitation id).
    Url,
}
305
306#[derive(Debug, Clone)]
308pub struct ElicitationResponse {
309 pub action: ElicitationAction,
311 pub content: Option<std::collections::HashMap<String, serde_json::Value>>,
313}
314
315impl ElicitationResponse {
316 #[must_use]
318 pub fn accept(content: std::collections::HashMap<String, serde_json::Value>) -> Self {
319 Self {
320 action: ElicitationAction::Accept,
321 content: Some(content),
322 }
323 }
324
325 #[must_use]
327 pub fn accept_url() -> Self {
328 Self {
329 action: ElicitationAction::Accept,
330 content: None,
331 }
332 }
333
334 #[must_use]
336 pub fn decline() -> Self {
337 Self {
338 action: ElicitationAction::Decline,
339 content: None,
340 }
341 }
342
343 #[must_use]
345 pub fn cancel() -> Self {
346 Self {
347 action: ElicitationAction::Cancel,
348 content: None,
349 }
350 }
351
352 #[must_use]
354 pub fn is_accepted(&self) -> bool {
355 matches!(self.action, ElicitationAction::Accept)
356 }
357
358 #[must_use]
360 pub fn is_declined(&self) -> bool {
361 matches!(self.action, ElicitationAction::Decline)
362 }
363
364 #[must_use]
366 pub fn is_cancelled(&self) -> bool {
367 matches!(self.action, ElicitationAction::Cancel)
368 }
369
370 #[must_use]
372 pub fn get_string(&self, key: &str) -> Option<&str> {
373 self.content.as_ref()?.get(key)?.as_str()
374 }
375
376 #[must_use]
378 pub fn get_bool(&self, key: &str) -> Option<bool> {
379 self.content.as_ref()?.get(key)?.as_bool()
380 }
381
382 #[must_use]
384 pub fn get_int(&self, key: &str) -> Option<i64> {
385 self.content.as_ref()?.get(key)?.as_i64()
386 }
387}
388
/// Action reported by an [`ElicitationResponse`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElicitationAction {
    /// The request was accepted.
    Accept,
    /// The request was declined.
    Decline,
    /// The request was cancelled.
    Cancel,
}
399
400#[derive(Debug, Clone, Copy, Default)]
404pub struct NoOpElicitationSender;
405
406impl ElicitationSender for NoOpElicitationSender {
407 fn elicit(
408 &self,
409 _request: ElicitationRequest,
410 ) -> std::pin::Pin<
411 Box<dyn std::future::Future<Output = crate::McpResult<ElicitationResponse>> + Send + '_>,
412 > {
413 Box::pin(async {
414 Err(crate::McpError::new(
415 crate::McpErrorCode::InvalidRequest,
416 "Elicitation not supported: client does not have elicitation capability",
417 ))
418 })
419 }
420}
421
/// Nesting limit for resource reads issued through `McpContext::read_resource`.
///
/// A context whose read depth has reached this value refuses further reads
/// with an error that mentions possible infinite recursion.
pub const MAX_RESOURCE_READ_DEPTH: u32 = 10;
428
/// One piece of content returned from a resource read.
///
/// The constructors fill exactly one of `text` and `blob`.
#[derive(Debug, Clone)]
pub struct ResourceContentItem {
    /// URI the content belongs to.
    pub uri: String,
    /// MIME type of the content, if known.
    pub mime_type: Option<String>,
    /// Textual content, when the item is text.
    pub text: Option<String>,
    /// Blob content carried as a string, when the item is a blob.
    /// NOTE(review): presumably base64-encoded; the encoding is not visible
    /// here, so confirm against the producer.
    pub blob: Option<String>,
}

impl ResourceContentItem {
    /// Creates a text item with MIME type `text/plain`.
    #[must_use]
    pub fn text(uri: impl Into<String>, text: impl Into<String>) -> Self {
        let uri = uri.into();
        let mime_type = Some(String::from("text/plain"));
        let text = Some(text.into());
        Self {
            uri,
            mime_type,
            text,
            blob: None,
        }
    }

    /// Creates a text item with MIME type `application/json`.
    ///
    /// `text` is stored as given; it is not parsed or validated here.
    #[must_use]
    pub fn json(uri: impl Into<String>, text: impl Into<String>) -> Self {
        let uri = uri.into();
        let mime_type = Some(String::from("application/json"));
        let text = Some(text.into());
        Self {
            uri,
            mime_type,
            text,
            blob: None,
        }
    }

    /// Creates a blob item with the given MIME type.
    #[must_use]
    pub fn blob(
        uri: impl Into<String>,
        mime_type: impl Into<String>,
        blob: impl Into<String>,
    ) -> Self {
        let uri = uri.into();
        let mime_type = Some(mime_type.into());
        let blob = Some(blob.into());
        Self {
            uri,
            mime_type,
            text: None,
            blob,
        }
    }

    /// Returns the text content, if any.
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        match &self.text {
            Some(text) => Some(text.as_str()),
            None => None,
        }
    }

    /// Returns the blob content, if any.
    #[must_use]
    pub fn as_blob(&self) -> Option<&str> {
        match &self.blob {
            Some(blob) => Some(blob.as_str()),
            None => None,
        }
    }

    /// Returns `true` when the item has text content.
    #[must_use]
    pub fn is_text(&self) -> bool {
        self.as_text().is_some()
    }

    /// Returns `true` when the item has blob content.
    #[must_use]
    pub fn is_blob(&self) -> bool {
        self.as_blob().is_some()
    }
}
507
508#[derive(Debug, Clone)]
510pub struct ResourceReadResult {
511 pub contents: Vec<ResourceContentItem>,
513}
514
515impl ResourceReadResult {
516 #[must_use]
518 pub fn new(contents: Vec<ResourceContentItem>) -> Self {
519 Self { contents }
520 }
521
522 #[must_use]
524 pub fn text(uri: impl Into<String>, text: impl Into<String>) -> Self {
525 Self {
526 contents: vec![ResourceContentItem::text(uri, text)],
527 }
528 }
529
530 #[must_use]
532 pub fn first_text(&self) -> Option<&str> {
533 self.contents.first().and_then(|c| c.as_text())
534 }
535
536 #[must_use]
538 pub fn first_blob(&self) -> Option<&str> {
539 self.contents.first().and_then(|c| c.as_blob())
540 }
541}
542
543pub trait ResourceReader: Send + Sync {
552 fn read_resource(
565 &self,
566 cx: &Cx,
567 uri: &str,
568 auth: Option<AuthContext>,
569 depth: u32,
570 ) -> Pin<Box<dyn Future<Output = crate::McpResult<ResourceReadResult>> + Send + '_>>;
571}
572
/// Nesting limit for tool calls issued through `McpContext::call_tool`.
///
/// A context whose call depth has reached this value refuses further calls
/// with an error that mentions possible infinite recursion.
pub const MAX_TOOL_CALL_DEPTH: u32 = 10;
579
/// One piece of content returned from a tool call.
#[derive(Debug, Clone)]
pub enum ToolContentItem {
    /// Plain text content.
    Text {
        /// The text itself.
        text: String,
    },
    /// Image content.
    Image {
        /// Image payload carried as a string.
        /// NOTE(review): presumably base64-encoded; confirm with the producer.
        data: String,
        /// MIME type of the image.
        mime_type: String,
    },
    /// Audio content.
    Audio {
        /// Audio payload carried as a string.
        /// NOTE(review): presumably base64-encoded; confirm with the producer.
        data: String,
        /// MIME type of the audio.
        mime_type: String,
    },
    /// Embedded resource content.
    Resource {
        /// URI of the resource.
        uri: String,
        /// MIME type of the resource, if known.
        mime_type: Option<String>,
        /// Textual content of the resource, if any.
        text: Option<String>,
        /// Blob content of the resource carried as a string, if any.
        blob: Option<String>,
    },
}

impl ToolContentItem {
    /// Creates a [`ToolContentItem::Text`] item.
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text { text: text.into() }
    }

    /// Returns the text of a [`ToolContentItem::Text`] item and `None` for
    /// every other variant (including `Resource`, even when it carries text).
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        // Variants are listed explicitly instead of using `_` so that adding
        // a new variant forces a decision here at compile time.
        match self {
            Self::Text { text } => Some(text.as_str()),
            Self::Image { .. } | Self::Audio { .. } | Self::Resource { .. } => None,
        }
    }

    /// Returns `true` when the item is a [`ToolContentItem::Text`].
    #[must_use]
    pub fn is_text(&self) -> bool {
        matches!(self, Self::Text { .. })
    }
}
640
641#[derive(Debug, Clone)]
643pub struct ToolCallResult {
644 pub content: Vec<ToolContentItem>,
646 pub is_error: bool,
648}
649
650impl ToolCallResult {
651 #[must_use]
653 pub fn success(content: Vec<ToolContentItem>) -> Self {
654 Self {
655 content,
656 is_error: false,
657 }
658 }
659
660 #[must_use]
662 pub fn text(text: impl Into<String>) -> Self {
663 Self {
664 content: vec![ToolContentItem::text(text)],
665 is_error: false,
666 }
667 }
668
669 #[must_use]
671 pub fn error(message: impl Into<String>) -> Self {
672 Self {
673 content: vec![ToolContentItem::text(message)],
674 is_error: true,
675 }
676 }
677
678 #[must_use]
680 pub fn first_text(&self) -> Option<&str> {
681 self.content.first().and_then(|c| c.as_text())
682 }
683}
684
685pub trait ToolCaller: Send + Sync {
694 fn call_tool(
708 &self,
709 cx: &Cx,
710 name: &str,
711 args: serde_json::Value,
712 auth: Option<AuthContext>,
713 depth: u32,
714 ) -> Pin<Box<dyn Future<Output = crate::McpResult<ToolCallResult>> + Send + '_>>;
715}
716
/// Flags describing which capabilities a client has.
///
/// All flags default to `false`.
#[derive(Debug, Clone, Default)]
pub struct ClientCapabilityInfo {
    /// The client supports sampling.
    pub sampling: bool,
    /// The client supports elicitation in at least one mode.
    pub elicitation: bool,
    /// The client supports form elicitation.
    pub elicitation_form: bool,
    /// The client supports URL elicitation.
    pub elicitation_url: bool,
    /// The client supports roots.
    pub roots: bool,
    /// The client supports roots list-changed notifications.
    pub roots_list_changed: bool,
}

impl ClientCapabilityInfo {
    /// Creates an info value with every flag cleared.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the info with `sampling` set.
    #[must_use]
    pub fn with_sampling(self) -> Self {
        Self {
            sampling: true,
            ..self
        }
    }

    /// Returns the info with both elicitation mode flags set as given.
    ///
    /// The general `elicitation` flag becomes `true` when at least one mode
    /// is enabled and `false` otherwise.
    #[must_use]
    pub fn with_elicitation(self, form: bool, url: bool) -> Self {
        Self {
            elicitation: form || url,
            elicitation_form: form,
            elicitation_url: url,
            ..self
        }
    }

    /// Returns the info with `roots` set and `roots_list_changed` set to
    /// `list_changed`.
    #[must_use]
    pub fn with_roots(self, list_changed: bool) -> Self {
        Self {
            roots: true,
            roots_list_changed: list_changed,
            ..self
        }
    }
}
772
/// Flags describing which capabilities a server has.
///
/// All flags default to `false`.
#[derive(Debug, Clone, Default)]
pub struct ServerCapabilityInfo {
    /// The server supports tools.
    pub tools: bool,
    /// The server supports resources.
    pub resources: bool,
    /// The server supports resource subscriptions.
    pub resources_subscribe: bool,
    /// The server supports prompts.
    pub prompts: bool,
    /// The server supports logging.
    pub logging: bool,
}

impl ServerCapabilityInfo {
    /// Creates an info value with every flag cleared.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the info with `tools` set.
    #[must_use]
    pub fn with_tools(self) -> Self {
        Self {
            tools: true,
            ..self
        }
    }

    /// Returns the info with `resources` set and `resources_subscribe` set to
    /// `subscribe`.
    #[must_use]
    pub fn with_resources(self, subscribe: bool) -> Self {
        Self {
            resources: true,
            resources_subscribe: subscribe,
            ..self
        }
    }

    /// Returns the info with `prompts` set.
    #[must_use]
    pub fn with_prompts(self) -> Self {
        Self {
            prompts: true,
            ..self
        }
    }

    /// Returns the info with `logging` set.
    #[must_use]
    pub fn with_logging(self) -> Self {
        Self {
            logging: true,
            ..self
        }
    }
}
826
827#[derive(Debug, Clone, Copy, Default)]
829pub struct NoOpNotificationSender;
830
831impl NotificationSender for NoOpNotificationSender {
832 fn send_progress(&self, _progress: f64, _total: Option<f64>, _message: Option<&str>) {
833 }
835}
836
837#[derive(Clone)]
842pub struct ProgressReporter {
843 sender: Arc<dyn NotificationSender>,
844}
845
846impl ProgressReporter {
847 pub fn new(sender: Arc<dyn NotificationSender>) -> Self {
849 Self { sender }
850 }
851
852 pub fn report(&self, progress: f64, message: Option<&str>) {
859 self.sender.send_progress(progress, None, message);
860 }
861
862 pub fn report_with_total(&self, progress: f64, total: f64, message: Option<&str>) {
870 self.sender.send_progress(progress, Some(total), message);
871 }
872}
873
874impl std::fmt::Debug for ProgressReporter {
875 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
876 f.debug_struct("ProgressReporter").finish_non_exhaustive()
877 }
878}
879
880#[derive(Clone)]
918pub struct McpContext {
919 cx: Cx,
921 request_id: u64,
923 progress_reporter: Option<ProgressReporter>,
925 state: Option<SessionState>,
927 auth: Arc<Mutex<Option<AuthContext>>>,
929 sampling_sender: Option<Arc<dyn SamplingSender>>,
931 elicitation_sender: Option<Arc<dyn ElicitationSender>>,
933 resource_reader: Option<Arc<dyn ResourceReader>>,
935 resource_read_depth: u32,
937 tool_caller: Option<Arc<dyn ToolCaller>>,
939 tool_call_depth: u32,
941 client_capabilities: Option<ClientCapabilityInfo>,
943 server_capabilities: Option<ServerCapabilityInfo>,
945}
946
947impl std::fmt::Debug for McpContext {
948 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
949 f.debug_struct("McpContext")
950 .field("cx", &self.cx)
951 .field("request_id", &self.request_id)
952 .field("progress_reporter", &self.progress_reporter)
953 .field("state", &self.state.is_some())
954 .field(
955 "auth",
956 &self
957 .auth
958 .lock()
959 .unwrap_or_else(std::sync::PoisonError::into_inner)
960 .is_some(),
961 )
962 .field("sampling_sender", &self.sampling_sender.is_some())
963 .field("elicitation_sender", &self.elicitation_sender.is_some())
964 .field("resource_reader", &self.resource_reader.is_some())
965 .field("resource_read_depth", &self.resource_read_depth)
966 .field("tool_caller", &self.tool_caller.is_some())
967 .field("tool_call_depth", &self.tool_call_depth)
968 .field("client_capabilities", &self.client_capabilities)
969 .field("server_capabilities", &self.server_capabilities)
970 .finish()
971 }
972}
973
974impl McpContext {
975 #[must_use]
980 pub fn new(cx: Cx, request_id: u64) -> Self {
981 Self {
982 cx,
983 request_id,
984 progress_reporter: None,
985 state: None,
986 auth: Arc::new(Mutex::new(None)),
987 sampling_sender: None,
988 elicitation_sender: None,
989 resource_reader: None,
990 resource_read_depth: 0,
991 tool_caller: None,
992 tool_call_depth: 0,
993 client_capabilities: None,
994 server_capabilities: None,
995 }
996 }
997
998 #[must_use]
1002 pub fn with_state(cx: Cx, request_id: u64, state: SessionState) -> Self {
1003 Self {
1004 cx,
1005 request_id,
1006 progress_reporter: None,
1007 state: Some(state),
1008 auth: Arc::new(Mutex::new(None)),
1009 sampling_sender: None,
1010 elicitation_sender: None,
1011 resource_reader: None,
1012 resource_read_depth: 0,
1013 tool_caller: None,
1014 tool_call_depth: 0,
1015 client_capabilities: None,
1016 server_capabilities: None,
1017 }
1018 }
1019
1020 #[must_use]
1025 pub fn with_progress(cx: Cx, request_id: u64, reporter: ProgressReporter) -> Self {
1026 Self {
1027 cx,
1028 request_id,
1029 progress_reporter: Some(reporter),
1030 state: None,
1031 auth: Arc::new(Mutex::new(None)),
1032 sampling_sender: None,
1033 elicitation_sender: None,
1034 resource_reader: None,
1035 resource_read_depth: 0,
1036 tool_caller: None,
1037 tool_call_depth: 0,
1038 client_capabilities: None,
1039 server_capabilities: None,
1040 }
1041 }
1042
1043 #[must_use]
1045 pub fn with_state_and_progress(
1046 cx: Cx,
1047 request_id: u64,
1048 state: SessionState,
1049 reporter: ProgressReporter,
1050 ) -> Self {
1051 Self {
1052 cx,
1053 request_id,
1054 progress_reporter: Some(reporter),
1055 state: Some(state),
1056 auth: Arc::new(Mutex::new(None)),
1057 sampling_sender: None,
1058 elicitation_sender: None,
1059 resource_reader: None,
1060 resource_read_depth: 0,
1061 tool_caller: None,
1062 tool_call_depth: 0,
1063 client_capabilities: None,
1064 server_capabilities: None,
1065 }
1066 }
1067
1068 #[must_use]
1073 pub fn with_sampling(mut self, sender: Arc<dyn SamplingSender>) -> Self {
1074 self.sampling_sender = Some(sender);
1075 self
1076 }
1077
1078 #[must_use]
1083 pub fn with_elicitation(mut self, sender: Arc<dyn ElicitationSender>) -> Self {
1084 self.elicitation_sender = Some(sender);
1085 self
1086 }
1087
1088 #[must_use]
1093 pub fn with_resource_reader(mut self, reader: Arc<dyn ResourceReader>) -> Self {
1094 self.resource_reader = Some(reader);
1095 self
1096 }
1097
1098 #[must_use]
1103 pub fn with_resource_read_depth(mut self, depth: u32) -> Self {
1104 self.resource_read_depth = depth;
1105 self
1106 }
1107
1108 #[must_use]
1113 pub fn with_tool_caller(mut self, caller: Arc<dyn ToolCaller>) -> Self {
1114 self.tool_caller = Some(caller);
1115 self
1116 }
1117
1118 #[must_use]
1123 pub fn with_tool_call_depth(mut self, depth: u32) -> Self {
1124 self.tool_call_depth = depth;
1125 self
1126 }
1127
1128 #[must_use]
1133 pub fn with_client_capabilities(mut self, capabilities: ClientCapabilityInfo) -> Self {
1134 self.client_capabilities = Some(capabilities);
1135 self
1136 }
1137
1138 #[must_use]
1143 pub fn with_server_capabilities(mut self, capabilities: ServerCapabilityInfo) -> Self {
1144 self.server_capabilities = Some(capabilities);
1145 self
1146 }
1147
1148 #[must_use]
1150 pub fn has_progress_reporter(&self) -> bool {
1151 self.progress_reporter.is_some()
1152 }
1153
1154 pub fn report_progress(&self, progress: f64, message: Option<&str>) {
1177 if let Some(ref reporter) = self.progress_reporter {
1178 reporter.report(progress, message);
1179 }
1180 }
1181
1182 pub fn report_progress_with_total(&self, progress: f64, total: f64, message: Option<&str>) {
1205 if let Some(ref reporter) = self.progress_reporter {
1206 reporter.report_with_total(progress, total, message);
1207 }
1208 }
1209
1210 #[must_use]
1215 pub fn request_id(&self) -> u64 {
1216 self.request_id
1217 }
1218
1219 #[must_use]
1225 pub fn region_id(&self) -> RegionId {
1226 self.cx.region_id()
1227 }
1228
1229 #[must_use]
1231 pub fn task_id(&self) -> TaskId {
1232 self.cx.task_id()
1233 }
1234
1235 #[must_use]
1241 pub fn budget(&self) -> Budget {
1242 self.cx.budget()
1243 }
1244
1245 #[must_use]
1250 pub fn is_cancelled(&self) -> bool {
1251 let budget = self.cx.budget();
1252 self.cx.is_cancel_requested()
1253 || budget.is_exhausted()
1254 || budget.is_past_deadline(wall_now())
1255 }
1256
1257 pub fn checkpoint(&self) -> Result<(), CancelledError> {
1279 self.cx.checkpoint().map_err(|_| CancelledError)?;
1280 let budget = self.cx.budget();
1281 if budget.is_exhausted() {
1282 return Err(CancelledError);
1283 }
1284 if budget.is_past_deadline(wall_now()) {
1285 self.cx.cancel_fast(CancelKind::Deadline);
1288 return Err(CancelledError);
1289 }
1290 Ok(())
1291 }
1292
1293 pub fn masked<F, R>(&self, f: F) -> R
1309 where
1310 F: FnOnce() -> R,
1311 {
1312 self.cx.masked(f)
1313 }
1314
1315 pub fn trace(&self, message: &str) {
1320 self.cx.trace(message);
1321 }
1322
1323 #[must_use]
1328 pub fn cx(&self) -> &Cx {
1329 &self.cx
1330 }
1331
1332 #[must_use]
1355 pub fn get_state<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
1356 self.state.as_ref()?.get(key)
1357 }
1358
1359 #[must_use]
1361 pub fn auth(&self) -> Option<AuthContext> {
1362 self.auth
1363 .lock()
1364 .unwrap_or_else(std::sync::PoisonError::into_inner)
1365 .clone()
1366 }
1367
1368 pub fn set_auth(&self, auth: AuthContext) -> bool {
1372 *self
1373 .auth
1374 .lock()
1375 .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(auth);
1376 true
1377 }
1378
1379 #[must_use]
1381 pub fn with_auth(self, auth: AuthContext) -> Self {
1382 let _ = self.set_auth(auth);
1383 self
1384 }
1385
1386 pub fn set_state<T: serde::Serialize>(&self, key: impl Into<String>, value: T) -> bool {
1403 match &self.state {
1404 Some(state) => state.set(key, value),
1405 None => false,
1406 }
1407 }
1408
1409 pub fn remove_state(&self, key: &str) -> Option<serde_json::Value> {
1415 self.state.as_ref()?.remove(key)
1416 }
1417
1418 #[must_use]
1422 pub fn has_state(&self, key: &str) -> bool {
1423 self.state.as_ref().is_some_and(|s| s.contains(key))
1424 }
1425
1426 #[must_use]
1428 pub fn has_session_state(&self) -> bool {
1429 self.state.is_some()
1430 }
1431
1432 #[must_use]
1441 pub fn client_capabilities(&self) -> Option<&ClientCapabilityInfo> {
1442 self.client_capabilities.as_ref()
1443 }
1444
1445 #[must_use]
1449 pub fn server_capabilities(&self) -> Option<&ServerCapabilityInfo> {
1450 self.server_capabilities.as_ref()
1451 }
1452
1453 #[must_use]
1458 pub fn client_supports_sampling(&self) -> bool {
1459 self.client_capabilities
1460 .as_ref()
1461 .is_some_and(|c| c.sampling)
1462 }
1463
1464 #[must_use]
1469 pub fn client_supports_elicitation(&self) -> bool {
1470 self.client_capabilities
1471 .as_ref()
1472 .is_some_and(|c| c.elicitation)
1473 }
1474
1475 #[must_use]
1477 pub fn client_supports_elicitation_form(&self) -> bool {
1478 self.client_capabilities
1479 .as_ref()
1480 .is_some_and(|c| c.elicitation_form)
1481 }
1482
1483 #[must_use]
1485 pub fn client_supports_elicitation_url(&self) -> bool {
1486 self.client_capabilities
1487 .as_ref()
1488 .is_some_and(|c| c.elicitation_url)
1489 }
1490
1491 #[must_use]
1496 pub fn client_supports_roots(&self) -> bool {
1497 self.client_capabilities.as_ref().is_some_and(|c| c.roots)
1498 }
1499
1500 const DISABLED_TOOLS_KEY: &'static str = "fastmcp.disabled_tools";
1506 const DISABLED_RESOURCES_KEY: &'static str = "fastmcp.disabled_resources";
1508 const DISABLED_PROMPTS_KEY: &'static str = "fastmcp.disabled_prompts";
1510
1511 pub fn disable_tool(&self, name: impl Into<String>) -> bool {
1529 self.add_to_disabled_set(Self::DISABLED_TOOLS_KEY, name.into())
1530 }
1531
1532 pub fn enable_tool(&self, name: &str) -> bool {
1536 self.remove_from_disabled_set(Self::DISABLED_TOOLS_KEY, name)
1537 }
1538
1539 #[must_use]
1543 pub fn is_tool_enabled(&self, name: &str) -> bool {
1544 !self.is_in_disabled_set(Self::DISABLED_TOOLS_KEY, name)
1545 }
1546
1547 pub fn disable_resource(&self, uri: impl Into<String>) -> bool {
1554 self.add_to_disabled_set(Self::DISABLED_RESOURCES_KEY, uri.into())
1555 }
1556
1557 pub fn enable_resource(&self, uri: &str) -> bool {
1561 self.remove_from_disabled_set(Self::DISABLED_RESOURCES_KEY, uri)
1562 }
1563
1564 #[must_use]
1568 pub fn is_resource_enabled(&self, uri: &str) -> bool {
1569 !self.is_in_disabled_set(Self::DISABLED_RESOURCES_KEY, uri)
1570 }
1571
1572 pub fn disable_prompt(&self, name: impl Into<String>) -> bool {
1579 self.add_to_disabled_set(Self::DISABLED_PROMPTS_KEY, name.into())
1580 }
1581
1582 pub fn enable_prompt(&self, name: &str) -> bool {
1586 self.remove_from_disabled_set(Self::DISABLED_PROMPTS_KEY, name)
1587 }
1588
1589 #[must_use]
1593 pub fn is_prompt_enabled(&self, name: &str) -> bool {
1594 !self.is_in_disabled_set(Self::DISABLED_PROMPTS_KEY, name)
1595 }
1596
1597 #[must_use]
1599 pub fn disabled_tools(&self) -> std::collections::HashSet<String> {
1600 self.get_disabled_set(Self::DISABLED_TOOLS_KEY)
1601 }
1602
1603 #[must_use]
1605 pub fn disabled_resources(&self) -> std::collections::HashSet<String> {
1606 self.get_disabled_set(Self::DISABLED_RESOURCES_KEY)
1607 }
1608
1609 #[must_use]
1611 pub fn disabled_prompts(&self) -> std::collections::HashSet<String> {
1612 self.get_disabled_set(Self::DISABLED_PROMPTS_KEY)
1613 }
1614
1615 fn add_to_disabled_set(&self, key: &str, name: String) -> bool {
1617 let Some(state) = self.state.as_ref() else {
1618 return false;
1619 };
1620 let mut set: std::collections::HashSet<String> = state.get(key).unwrap_or_default();
1621 set.insert(name);
1622 state.set(key, set)
1623 }
1624
1625 fn remove_from_disabled_set(&self, key: &str, name: &str) -> bool {
1627 let Some(state) = self.state.as_ref() else {
1628 return false;
1629 };
1630 let mut set: std::collections::HashSet<String> = state.get(key).unwrap_or_default();
1631 set.remove(name);
1632 state.set(key, set)
1633 }
1634
1635 fn is_in_disabled_set(&self, key: &str, name: &str) -> bool {
1637 let Some(state) = self.state.as_ref() else {
1638 return false;
1639 };
1640 let set: std::collections::HashSet<String> = state.get(key).unwrap_or_default();
1641 set.contains(name)
1642 }
1643
1644 fn get_disabled_set(&self, key: &str) -> std::collections::HashSet<String> {
1646 self.state
1647 .as_ref()
1648 .and_then(|s| s.get(key))
1649 .unwrap_or_default()
1650 }
1651
1652 #[must_use]
1661 pub fn can_sample(&self) -> bool {
1662 self.sampling_sender.is_some()
1663 }
1664
1665 pub async fn sample(
1690 &self,
1691 prompt: impl Into<String>,
1692 max_tokens: u32,
1693 ) -> crate::McpResult<SamplingResponse> {
1694 let request = SamplingRequest::prompt(prompt, max_tokens);
1695 self.sample_with_request(request).await
1696 }
1697
1698 pub async fn sample_with_request(
1730 &self,
1731 request: SamplingRequest,
1732 ) -> crate::McpResult<SamplingResponse> {
1733 let sender = self.sampling_sender.as_ref().ok_or_else(|| {
1734 crate::McpError::new(
1735 crate::McpErrorCode::InvalidRequest,
1736 "Sampling not available: client does not support sampling capability",
1737 )
1738 })?;
1739
1740 sender.create_message(request).await
1741 }
1742
1743 #[must_use]
1752 pub fn can_elicit(&self) -> bool {
1753 self.elicitation_sender.is_some()
1754 }
1755
1756 pub async fn elicit_form(
1794 &self,
1795 message: impl Into<String>,
1796 schema: serde_json::Value,
1797 ) -> crate::McpResult<ElicitationResponse> {
1798 let request = ElicitationRequest::form(message, schema);
1799 self.elicit_with_request(request).await
1800 }
1801
1802 pub async fn elicit_url(
1836 &self,
1837 message: impl Into<String>,
1838 url: impl Into<String>,
1839 elicitation_id: impl Into<String>,
1840 ) -> crate::McpResult<ElicitationResponse> {
1841 let request = ElicitationRequest::url(message, url, elicitation_id);
1842 self.elicit_with_request(request).await
1843 }
1844
1845 pub async fn elicit_with_request(
1857 &self,
1858 request: ElicitationRequest,
1859 ) -> crate::McpResult<ElicitationResponse> {
1860 let sender = self.elicitation_sender.as_ref().ok_or_else(|| {
1861 crate::McpError::new(
1862 crate::McpErrorCode::InvalidRequest,
1863 "Elicitation not available: client does not support elicitation capability",
1864 )
1865 })?;
1866
1867 sender.elicit(request).await
1868 }
1869
1870 #[must_use]
1879 pub fn can_read_resources(&self) -> bool {
1880 self.resource_reader.is_some()
1881 }
1882
1883 #[must_use]
1887 pub fn resource_read_depth(&self) -> u32 {
1888 self.resource_read_depth
1889 }
1890
1891 pub async fn read_resource(&self, uri: &str) -> crate::McpResult<ResourceReadResult> {
1920 let reader = self.resource_reader.as_ref().ok_or_else(|| {
1922 crate::McpError::new(
1923 crate::McpErrorCode::InternalError,
1924 "Resource reading not available: no router attached to context",
1925 )
1926 })?;
1927
1928 if self.resource_read_depth >= MAX_RESOURCE_READ_DEPTH {
1930 return Err(crate::McpError::new(
1931 crate::McpErrorCode::InternalError,
1932 format!(
1933 "Maximum resource read depth ({}) exceeded; possible infinite recursion",
1934 MAX_RESOURCE_READ_DEPTH
1935 ),
1936 ));
1937 }
1938
1939 reader
1941 .read_resource(&self.cx, uri, self.auth(), self.resource_read_depth + 1)
1942 .await
1943 }
1944
1945 pub async fn read_resource_text(&self, uri: &str) -> crate::McpResult<String> {
1963 let result = self.read_resource(uri).await?;
1964 result.first_text().map(String::from).ok_or_else(|| {
1965 crate::McpError::new(
1966 crate::McpErrorCode::InternalError,
1967 format!("Resource '{}' has no text content", uri),
1968 )
1969 })
1970 }
1971
1972 pub async fn read_resource_json<T: serde::de::DeserializeOwned>(
1996 &self,
1997 uri: &str,
1998 ) -> crate::McpResult<T> {
1999 let text = self.read_resource_text(uri).await?;
2000 serde_json::from_str(&text).map_err(|e| {
2001 crate::McpError::new(
2002 crate::McpErrorCode::InternalError,
2003 format!("Failed to parse resource '{}' as JSON: {}", uri, e),
2004 )
2005 })
2006 }
2007
2008 #[must_use]
2017 pub fn can_call_tools(&self) -> bool {
2018 self.tool_caller.is_some()
2019 }
2020
2021 #[must_use]
2025 pub fn tool_call_depth(&self) -> u32 {
2026 self.tool_call_depth
2027 }
2028
2029 pub async fn call_tool(
2057 &self,
2058 name: &str,
2059 args: serde_json::Value,
2060 ) -> crate::McpResult<ToolCallResult> {
2061 let caller = self.tool_caller.as_ref().ok_or_else(|| {
2063 crate::McpError::new(
2064 crate::McpErrorCode::InternalError,
2065 "Tool calling not available: no router attached to context",
2066 )
2067 })?;
2068
2069 if self.tool_call_depth >= MAX_TOOL_CALL_DEPTH {
2071 return Err(crate::McpError::new(
2072 crate::McpErrorCode::InternalError,
2073 format!(
2074 "Maximum tool call depth ({}) exceeded calling '{}'; possible infinite recursion",
2075 MAX_TOOL_CALL_DEPTH, name
2076 ),
2077 ));
2078 }
2079
2080 caller
2082 .call_tool(&self.cx, name, args, self.auth(), self.tool_call_depth + 1)
2083 .await
2084 }
2085
2086 pub async fn call_tool_text(
2105 &self,
2106 name: &str,
2107 args: serde_json::Value,
2108 ) -> crate::McpResult<String> {
2109 let result = self.call_tool(name, args).await?;
2110
2111 if result.is_error {
2113 let error_msg = result.first_text().unwrap_or("Tool returned an error");
2114 return Err(crate::McpError::new(
2115 crate::McpErrorCode::InternalError,
2116 format!("Tool '{}' failed: {}", name, error_msg),
2117 ));
2118 }
2119
2120 result.first_text().map(String::from).ok_or_else(|| {
2121 crate::McpError::new(
2122 crate::McpErrorCode::InternalError,
2123 format!("Tool '{}' returned no text content", name),
2124 )
2125 })
2126 }
2127
2128 pub async fn call_tool_json<T: serde::de::DeserializeOwned>(
2153 &self,
2154 name: &str,
2155 args: serde_json::Value,
2156 ) -> crate::McpResult<T> {
2157 let text = self.call_tool_text(name, args).await?;
2158 serde_json::from_str(&text).map_err(|e| {
2159 crate::McpError::new(
2160 crate::McpErrorCode::InternalError,
2161 format!("Failed to parse tool '{}' result as JSON: {}", name, e),
2162 )
2163 })
2164 }
2165
2166 pub async fn join_all<T: Send + 'static>(
2186 &self,
2187 futures: Vec<crate::combinator::BoxFuture<'_, T>>,
2188 ) -> Vec<T> {
2189 crate::combinator::join_all(&self.cx, futures).await
2190 }
2191
2192 pub async fn race<T: Send + 'static>(
2207 &self,
2208 futures: Vec<crate::combinator::BoxFuture<'_, T>>,
2209 ) -> crate::McpResult<T> {
2210 crate::combinator::race(&self.cx, futures).await
2211 }
2212
2213 pub async fn quorum<T: Send + 'static>(
2229 &self,
2230 required: usize,
2231 futures: Vec<crate::combinator::BoxFuture<'_, crate::McpResult<T>>>,
2232 ) -> crate::McpResult<crate::combinator::QuorumResult<T>> {
2233 crate::combinator::quorum(&self.cx, required, futures).await
2234 }
2235
2236 pub async fn first_ok<T: Send + 'static>(
2251 &self,
2252 futures: Vec<crate::combinator::BoxFuture<'_, crate::McpResult<T>>>,
2253 ) -> crate::McpResult<T> {
2254 crate::combinator::first_ok(&self.cx, futures).await
2255 }
2256}
2257
/// Marker error signalling that a request was cancelled.
///
/// It carries no data. NOTE(review): `Default` is intentionally not derived
/// here — the `IntoOutcome` impl for `Result<T, CancelledError>` in this file
/// is bounded on `E: Default`, and giving this type a `Default` impl would
/// make it overlap with the blanket `Result<T, E>` impl.
#[derive(Clone, Copy, Debug)]
pub struct CancelledError;
2265
2266impl std::fmt::Display for CancelledError {
2267 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2268 write!(f, "request cancelled")
2269 }
2270}
2271
2272impl std::error::Error for CancelledError {}
2273
2274pub trait IntoOutcome<T, E> {
2279 fn into_outcome(self) -> Outcome<T, E>;
2281}
2282
2283impl<T, E> IntoOutcome<T, E> for Result<T, E> {
2284 fn into_outcome(self) -> Outcome<T, E> {
2285 match self {
2286 Ok(v) => Outcome::Ok(v),
2287 Err(e) => Outcome::Err(e),
2288 }
2289 }
2290}
2291
2292impl<T, E> IntoOutcome<T, E> for Result<T, CancelledError>
2293where
2294 E: Default,
2295{
2296 fn into_outcome(self) -> Outcome<T, E> {
2297 match self {
2298 Ok(v) => Outcome::Ok(v),
2299 Err(CancelledError) => Outcome::Cancelled(CancelReason::user("request cancelled")),
2300 }
2301 }
2302}
2303
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Shared fixtures
    // ------------------------------------------------------------------

    /// Context over a testing `Cx`, request id 1, without session state.
    fn bare_ctx() -> McpContext {
        McpContext::new(Cx::for_testing(), 1)
    }

    /// Context over a testing `Cx`, request id 1, with a fresh session state.
    fn stateful_ctx() -> McpContext {
        McpContext::with_state(Cx::for_testing(), 1, SessionState::new())
    }

    /// Progress reporter backed by the no-op notification sender.
    fn noop_reporter() -> ProgressReporter {
        ProgressReporter::new(Arc::new(NoOpNotificationSender))
    }

    // ------------------------------------------------------------------
    // Construction, cancellation and budget
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_creation() {
        let ctx = McpContext::new(Cx::for_testing(), 42);

        assert_eq!(ctx.request_id(), 42);
    }

    #[test]
    fn test_mcp_context_not_cancelled_initially() {
        assert!(!bare_ctx().is_cancelled());
    }

    #[test]
    fn test_mcp_context_checkpoint_success() {
        assert!(bare_ctx().checkpoint().is_ok());
    }

    #[test]
    fn test_mcp_context_checkpoint_cancelled() {
        // Cancellation is requested before the context wraps the `Cx`.
        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);

        let ctx = McpContext::new(cx, 1);
        assert!(ctx.checkpoint().is_err());
    }

    #[test]
    fn test_mcp_context_checkpoint_budget_exhausted() {
        let ctx = McpContext::new(Cx::for_testing_with_budget(Budget::ZERO), 1);

        assert!(ctx.checkpoint().is_err());
    }

    #[test]
    fn test_mcp_context_masked_section() {
        let ctx = bare_ctx();

        let produced = ctx.masked(|| 42);
        assert_eq!(produced, 42);
    }

    #[test]
    fn test_mcp_context_budget() {
        let ctx = bare_ctx();

        let remaining = ctx.budget();
        assert!(!remaining.is_exhausted());
    }

    // ------------------------------------------------------------------
    // CancelledError / IntoOutcome
    // ------------------------------------------------------------------

    #[test]
    fn test_cancelled_error_display() {
        assert_eq!(CancelledError.to_string(), "request cancelled");
    }

    #[test]
    fn test_into_outcome_ok() {
        let input: Result<i32, CancelledError> = Ok(42);
        let converted: Outcome<i32, CancelledError> = input.into_outcome();
        assert!(matches!(converted, Outcome::Ok(42)));
    }

    #[test]
    fn test_into_outcome_cancelled() {
        let input: Result<i32, CancelledError> = Err(CancelledError);
        let converted: Outcome<i32, ()> = input.into_outcome();
        assert!(matches!(converted, Outcome::Cancelled(_)));
    }

    // ------------------------------------------------------------------
    // Progress reporting
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_no_progress_reporter_by_default() {
        assert!(!bare_ctx().has_progress_reporter());
    }

    #[test]
    fn test_mcp_context_with_progress_reporter() {
        let ctx = McpContext::with_progress(Cx::for_testing(), 1, noop_reporter());
        assert!(ctx.has_progress_reporter());
    }

    #[test]
    fn test_report_progress_without_reporter() {
        // Both calls must simply return when no reporter is attached.
        let ctx = bare_ctx();
        ctx.report_progress(0.5, Some("test"));
        ctx.report_progress_with_total(5.0, 10.0, None);
    }

    #[test]
    fn test_report_progress_with_reporter() {
        use std::sync::atomic::{AtomicU32, Ordering};

        /// Sender that only counts how many notifications it was handed.
        struct TallySender {
            sent: AtomicU32,
        }

        impl NotificationSender for TallySender {
            fn send_progress(&self, _progress: f64, _total: Option<f64>, _message: Option<&str>) {
                self.sent.fetch_add(1, Ordering::SeqCst);
            }
        }

        let tally = Arc::new(TallySender {
            sent: AtomicU32::new(0),
        });
        let reporter = ProgressReporter::new(tally.clone());
        let ctx = McpContext::with_progress(Cx::for_testing(), 1, reporter);

        ctx.report_progress(0.25, Some("step 1"));
        ctx.report_progress(0.5, None);
        ctx.report_progress_with_total(3.0, 4.0, Some("step 3"));

        assert_eq!(tally.sent.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_progress_reporter_debug() {
        let rendered = format!("{:?}", noop_reporter());
        assert!(rendered.contains("ProgressReporter"));
    }

    #[test]
    fn test_noop_notification_sender() {
        NoOpNotificationSender.send_progress(0.5, Some(1.0), Some("test"));
    }

    // ------------------------------------------------------------------
    // Session state
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_no_session_state_by_default() {
        assert!(!bare_ctx().has_session_state());
    }

    #[test]
    fn test_mcp_context_with_session_state() {
        assert!(stateful_ctx().has_session_state());
    }

    #[test]
    fn test_mcp_context_get_set_state() {
        let ctx = stateful_ctx();

        assert!(ctx.set_state("counter", 42));

        let stored: Option<i32> = ctx.get_state("counter");
        assert_eq!(stored, Some(42));
    }

    #[test]
    fn test_mcp_context_state_not_available() {
        let ctx = bare_ctx();

        // Without session state the write is rejected and reads see nothing.
        assert!(!ctx.set_state("key", "value"));

        let stored: Option<String> = ctx.get_state("key");
        assert!(stored.is_none());
    }

    #[test]
    fn test_mcp_context_has_state() {
        let ctx = stateful_ctx();

        assert!(!ctx.has_state("missing"));

        ctx.set_state("present", true);
        assert!(ctx.has_state("present"));
    }

    #[test]
    fn test_mcp_context_remove_state() {
        let ctx = stateful_ctx();

        ctx.set_state("key", "value");
        assert!(ctx.has_state("key"));

        let taken = ctx.remove_state("key");
        assert!(taken.is_some());
        assert!(!ctx.has_state("key"));
    }

    #[test]
    fn test_mcp_context_with_state_and_progress() {
        let ctx = McpContext::with_state_and_progress(
            Cx::for_testing(),
            1,
            SessionState::new(),
            noop_reporter(),
        );

        assert!(ctx.has_session_state());
        assert!(ctx.has_progress_reporter());
    }

    // ------------------------------------------------------------------
    // Request-local auth
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_auth_is_request_local() {
        let session = SessionState::new();
        let ctx = McpContext::with_state(Cx::for_testing(), 1, session.clone());

        assert!(ctx.set_auth(AuthContext::with_subject("alice")));

        assert_eq!(
            ctx.auth().and_then(|auth| auth.subject),
            Some(String::from("alice"))
        );
        let persisted: Option<AuthContext> = session.get(crate::AUTH_STATE_KEY);
        assert!(
            persisted.is_none(),
            "request auth must not be persisted into session state"
        );
    }

    #[test]
    fn test_mcp_context_clones_share_request_auth() {
        let original = bare_ctx();
        let duplicate = original.clone();

        // Auth is set through the clone and observed through the original.
        assert!(duplicate.set_auth(AuthContext::with_subject("bob")));

        assert_eq!(
            original.auth().and_then(|auth| auth.subject),
            Some(String::from("bob"))
        );
    }

    #[test]
    fn test_new_mcp_contexts_do_not_share_request_auth_even_with_same_cx() {
        let cx = Cx::for_testing();
        let session = SessionState::new();
        let first = McpContext::with_state(cx.clone(), 7, session.clone());
        let second = McpContext::with_state(cx, 7, session);

        assert!(first.set_auth(AuthContext::with_subject("carol")));

        assert!(second.auth().is_none());
    }

    #[test]
    fn test_new_mcp_contexts_do_not_share_request_auth_across_requests() {
        let session = SessionState::new();
        let first = McpContext::with_state(Cx::for_testing(), 7, session.clone());
        let second = McpContext::with_state(Cx::for_testing(), 8, session);

        assert!(first.set_auth(AuthContext::with_subject("dave")));

        assert_eq!(
            first.auth().and_then(|auth| auth.subject),
            Some(String::from("dave"))
        );
        assert!(second.auth().is_none());
    }

    #[test]
    fn test_mcp_context_drop_does_not_leak_request_auth() {
        let cx = Cx::for_testing();

        {
            // This context is dropped at the end of the scope.
            let short_lived = McpContext::new(cx.clone(), 9);
            assert!(short_lived.set_auth(AuthContext::with_subject("erin")));
        }

        assert!(
            McpContext::new(cx, 9).auth().is_none(),
            "fresh contexts must start without inherited request auth"
        );
    }

    // ------------------------------------------------------------------
    // Enabling / disabling tools, resources and prompts
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_tools_enabled_by_default() {
        let ctx = stateful_ctx();

        assert!(ctx.is_tool_enabled("any_tool"));
        assert!(ctx.is_tool_enabled("another_tool"));
    }

    #[test]
    fn test_mcp_context_disable_enable_tool() {
        let ctx = stateful_ctx();

        assert!(ctx.is_tool_enabled("my_tool"));

        // Disabling affects only the named tool.
        assert!(ctx.disable_tool("my_tool"));
        assert!(!ctx.is_tool_enabled("my_tool"));
        assert!(ctx.is_tool_enabled("other_tool"));

        // Re-enabling restores it.
        assert!(ctx.enable_tool("my_tool"));
        assert!(ctx.is_tool_enabled("my_tool"));
    }

    #[test]
    fn test_mcp_context_disable_enable_resource() {
        let ctx = stateful_ctx();

        assert!(ctx.is_resource_enabled("file://secret"));

        // Disabling affects only the named resource.
        assert!(ctx.disable_resource("file://secret"));
        assert!(!ctx.is_resource_enabled("file://secret"));
        assert!(ctx.is_resource_enabled("file://public"));

        // Re-enabling restores it.
        assert!(ctx.enable_resource("file://secret"));
        assert!(ctx.is_resource_enabled("file://secret"));
    }

    #[test]
    fn test_mcp_context_disable_enable_prompt() {
        let ctx = stateful_ctx();

        assert!(ctx.is_prompt_enabled("admin_prompt"));

        // Disabling affects only the named prompt.
        assert!(ctx.disable_prompt("admin_prompt"));
        assert!(!ctx.is_prompt_enabled("admin_prompt"));
        assert!(ctx.is_prompt_enabled("user_prompt"));

        // Re-enabling restores it.
        assert!(ctx.enable_prompt("admin_prompt"));
        assert!(ctx.is_prompt_enabled("admin_prompt"));
    }

    #[test]
    fn test_mcp_context_disable_multiple_tools() {
        let ctx = stateful_ctx();
        let blocked = ["tool1", "tool2", "tool3"];

        for name in blocked {
            ctx.disable_tool(name);
        }

        for name in blocked {
            assert!(!ctx.is_tool_enabled(name));
        }
        assert!(ctx.is_tool_enabled("tool4"));

        let disabled = ctx.disabled_tools();
        assert_eq!(disabled.len(), 3);
        for name in blocked {
            assert!(disabled.contains(name));
        }
    }

    #[test]
    fn test_mcp_context_disabled_sets_empty_by_default() {
        let ctx = stateful_ctx();

        assert!(ctx.disabled_tools().is_empty());
        assert!(ctx.disabled_resources().is_empty());
        assert!(ctx.disabled_prompts().is_empty());
    }

    #[test]
    fn test_mcp_context_enable_disable_no_state() {
        let ctx = bare_ctx();

        // Without session state neither toggle takes effect...
        assert!(!ctx.disable_tool("tool"));
        assert!(!ctx.enable_tool("tool"));

        // ...and the tool is still reported as enabled.
        assert!(ctx.is_tool_enabled("tool"));
    }

    #[test]
    fn test_mcp_context_disabled_state_persists_across_contexts() {
        let session = SessionState::new();

        {
            // First request disables the tool.
            let writer = McpContext::with_state(Cx::for_testing(), 1, session.clone());
            writer.disable_tool("shared_tool");
        }

        {
            // A later request on the same session observes it as disabled.
            let reader = McpContext::with_state(Cx::for_testing(), 2, session.clone());
            assert!(!reader.is_tool_enabled("shared_tool"));
        }
    }

    // ------------------------------------------------------------------
    // Capabilities
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_no_capabilities_by_default() {
        let ctx = bare_ctx();

        assert!(ctx.client_capabilities().is_none());
        assert!(ctx.server_capabilities().is_none());
        assert!(!ctx.client_supports_sampling());
        assert!(!ctx.client_supports_elicitation());
        assert!(!ctx.client_supports_roots());
    }

    #[test]
    fn test_mcp_context_with_client_capabilities() {
        let client = ClientCapabilityInfo::new()
            .with_sampling()
            .with_elicitation(true, false)
            .with_roots(true);

        let ctx = bare_ctx().with_client_capabilities(client);

        assert!(ctx.client_capabilities().is_some());
        assert!(ctx.client_supports_sampling());
        assert!(ctx.client_supports_elicitation());
        assert!(ctx.client_supports_elicitation_form());
        assert!(!ctx.client_supports_elicitation_url());
        assert!(ctx.client_supports_roots());
    }

    #[test]
    fn test_mcp_context_with_server_capabilities() {
        let server = ServerCapabilityInfo::new()
            .with_tools()
            .with_resources(true)
            .with_prompts()
            .with_logging();

        let ctx = bare_ctx().with_server_capabilities(server);

        let seen = ctx.server_capabilities().unwrap();
        assert!(seen.tools);
        assert!(seen.resources);
        assert!(seen.resources_subscribe);
        assert!(seen.prompts);
        assert!(seen.logging);
    }

    #[test]
    fn test_client_capability_info_builders() {
        let empty = ClientCapabilityInfo::new();
        assert!(!empty.sampling);
        assert!(!empty.elicitation);
        assert!(!empty.roots);

        let sampling = empty.with_sampling();
        assert!(sampling.sampling);

        let elicitation = ClientCapabilityInfo::new().with_elicitation(true, true);
        assert!(elicitation.elicitation);
        assert!(elicitation.elicitation_form);
        assert!(elicitation.elicitation_url);

        let roots = ClientCapabilityInfo::new().with_roots(false);
        assert!(roots.roots);
        assert!(!roots.roots_list_changed);
    }

    #[test]
    fn test_server_capability_info_builders() {
        let empty = ServerCapabilityInfo::new();
        assert!(!empty.tools);
        assert!(!empty.resources);
        assert!(!empty.prompts);
        assert!(!empty.logging);

        let full = empty
            .with_tools()
            .with_resources(false)
            .with_prompts()
            .with_logging();
        assert!(full.tools);
        assert!(full.resources);
        assert!(!full.resources_subscribe);
        assert!(full.prompts);
        assert!(full.logging);
    }

    // ------------------------------------------------------------------
    // ResourceContentItem
    // ------------------------------------------------------------------

    #[test]
    fn test_resource_content_item_text() {
        let entry = ResourceContentItem::text("test://uri", "hello");
        assert_eq!(entry.uri, "test://uri");
        assert_eq!(entry.mime_type.as_deref(), Some("text/plain"));
        assert_eq!(entry.as_text(), Some("hello"));
        assert!(entry.as_blob().is_none());
        assert!(entry.is_text());
        assert!(!entry.is_blob());
    }

    #[test]
    fn test_resource_content_item_json() {
        let entry = ResourceContentItem::json("data://config", r#"{"key":"val"}"#);
        assert_eq!(entry.uri, "data://config");
        assert_eq!(entry.mime_type.as_deref(), Some("application/json"));
        assert_eq!(entry.as_text(), Some(r#"{"key":"val"}"#));
        assert!(entry.is_text());
        assert!(!entry.is_blob());
    }

    #[test]
    fn test_resource_content_item_blob() {
        let entry = ResourceContentItem::blob("binary://data", "application/octet-stream", "AQID");
        assert_eq!(entry.uri, "binary://data");
        assert_eq!(entry.mime_type.as_deref(), Some("application/octet-stream"));
        assert!(entry.as_text().is_none());
        assert_eq!(entry.as_blob(), Some("AQID"));
        assert!(!entry.is_text());
        assert!(entry.is_blob());
    }

    // ------------------------------------------------------------------
    // ResourceReadResult
    // ------------------------------------------------------------------

    #[test]
    fn test_resource_read_result_text() {
        let read = ResourceReadResult::text("test://doc", "content");
        assert_eq!(read.first_text(), Some("content"));
        assert!(read.first_blob().is_none());
        assert_eq!(read.contents.len(), 1);
    }

    #[test]
    fn test_resource_read_result_new_multiple() {
        let read = ResourceReadResult::new(vec![
            ResourceContentItem::text("a://1", "first"),
            ResourceContentItem::blob("b://2", "image/png", "base64data"),
        ]);
        assert_eq!(read.contents.len(), 2);
        assert_eq!(read.first_text(), Some("first"));
        // The blob is the second item, and `first_blob` reports none here.
        assert!(read.first_blob().is_none());
    }

    #[test]
    fn test_resource_read_result_empty() {
        let read = ResourceReadResult::new(vec![]);
        assert!(read.first_text().is_none());
        assert!(read.first_blob().is_none());
    }

    #[test]
    fn test_resource_read_result_blob_first() {
        let only_blob = ResourceContentItem::blob("b://1", "image/png", "data");
        let read = ResourceReadResult::new(vec![only_blob]);
        assert!(read.first_text().is_none());
        assert_eq!(read.first_blob(), Some("data"));
    }

    // ------------------------------------------------------------------
    // ToolContentItem
    // ------------------------------------------------------------------

    #[test]
    fn test_tool_content_item_text() {
        let entry = ToolContentItem::text("hello");
        assert_eq!(entry.as_text(), Some("hello"));
        assert!(entry.is_text());
    }

    #[test]
    fn test_tool_content_item_image() {
        let entry = ToolContentItem::Image {
            data: String::from("base64img"),
            mime_type: String::from("image/png"),
        };
        assert!(entry.as_text().is_none());
        assert!(!entry.is_text());
    }

    #[test]
    fn test_tool_content_item_audio() {
        let entry = ToolContentItem::Audio {
            data: String::from("base64audio"),
            mime_type: String::from("audio/wav"),
        };
        assert!(entry.as_text().is_none());
        assert!(!entry.is_text());
    }

    #[test]
    fn test_tool_content_item_resource() {
        // An embedded resource carrying text is still not a text item.
        let entry = ToolContentItem::Resource {
            uri: String::from("file://test"),
            mime_type: Some(String::from("text/plain")),
            text: Some(String::from("embedded")),
            blob: None,
        };
        assert!(entry.as_text().is_none());
        assert!(!entry.is_text());
    }

    // ------------------------------------------------------------------
    // ToolCallResult
    // ------------------------------------------------------------------

    #[test]
    fn test_tool_call_result_success() {
        let outcome = ToolCallResult::success(vec![
            ToolContentItem::text("item1"),
            ToolContentItem::text("item2"),
        ]);
        assert!(!outcome.is_error);
        assert_eq!(outcome.content.len(), 2);
        assert_eq!(outcome.first_text(), Some("item1"));
    }

    #[test]
    fn test_tool_call_result_text() {
        let outcome = ToolCallResult::text("simple output");
        assert!(!outcome.is_error);
        assert_eq!(outcome.content.len(), 1);
        assert_eq!(outcome.first_text(), Some("simple output"));
    }

    #[test]
    fn test_tool_call_result_error() {
        let outcome = ToolCallResult::error("something failed");
        assert!(outcome.is_error);
        assert_eq!(outcome.first_text(), Some("something failed"));
    }

    #[test]
    fn test_tool_call_result_empty() {
        let outcome = ToolCallResult::success(vec![]);
        assert!(!outcome.is_error);
        assert!(outcome.first_text().is_none());
    }

    // ------------------------------------------------------------------
    // ElicitationResponse
    // ------------------------------------------------------------------

    #[test]
    fn test_elicitation_response_accept() {
        let answers = std::collections::HashMap::from([
            ("name".to_string(), serde_json::json!("Alice")),
            ("age".to_string(), serde_json::json!(30)),
            ("active".to_string(), serde_json::json!(true)),
        ]);

        let reply = ElicitationResponse::accept(answers);
        assert!(reply.is_accepted());
        assert!(!reply.is_declined());
        assert!(!reply.is_cancelled());
        assert_eq!(reply.get_string("name"), Some("Alice"));
        assert_eq!(reply.get_int("age"), Some(30));
        assert_eq!(reply.get_bool("active"), Some(true));
    }

    #[test]
    fn test_elicitation_response_accept_url() {
        let reply = ElicitationResponse::accept_url();
        assert!(reply.is_accepted());
        assert!(reply.content.is_none());
        assert!(reply.get_string("anything").is_none());
    }

    #[test]
    fn test_elicitation_response_decline() {
        let reply = ElicitationResponse::decline();
        assert!(!reply.is_accepted());
        assert!(reply.is_declined());
        assert!(!reply.is_cancelled());
        assert!(reply.get_string("key").is_none());
    }

    #[test]
    fn test_elicitation_response_cancel() {
        let reply = ElicitationResponse::cancel();
        assert!(!reply.is_accepted());
        assert!(!reply.is_declined());
        assert!(reply.is_cancelled());
    }

    #[test]
    fn test_elicitation_response_missing_key() {
        let answers = std::collections::HashMap::from([(
            "exists".to_string(),
            serde_json::json!("value"),
        )]);
        let reply = ElicitationResponse::accept(answers);

        assert!(reply.get_string("missing").is_none());
        assert!(reply.get_bool("missing").is_none());
        assert!(reply.get_int("missing").is_none());
    }

    #[test]
    fn test_elicitation_response_type_mismatch() {
        let answers =
            std::collections::HashMap::from([("num".to_string(), serde_json::json!(42))]);
        let reply = ElicitationResponse::accept(answers);

        // A number is neither a string nor a bool, but it is readable as an int.
        assert!(reply.get_string("num").is_none());
        assert!(reply.get_bool("num").is_none());
        assert_eq!(reply.get_int("num"), Some(42));
    }

    // ------------------------------------------------------------------
    // Defaults for optional facilities and recursion depths
    // ------------------------------------------------------------------

    #[test]
    fn test_can_sample_false_by_default() {
        assert!(!bare_ctx().can_sample());
    }

    #[test]
    fn test_can_elicit_false_by_default() {
        assert!(!bare_ctx().can_elicit());
    }

    #[test]
    fn test_can_read_resources_false_by_default() {
        assert!(!bare_ctx().can_read_resources());
    }

    #[test]
    fn test_can_call_tools_false_by_default() {
        assert!(!bare_ctx().can_call_tools());
    }

    #[test]
    fn test_resource_read_depth_default() {
        assert_eq!(bare_ctx().resource_read_depth(), 0);
    }

    #[test]
    fn test_tool_call_depth_default() {
        assert_eq!(bare_ctx().tool_call_depth(), 0);
    }

    // ------------------------------------------------------------------
    // Sampling and elicitation request types
    // ------------------------------------------------------------------

    #[test]
    fn sampling_request_builder_chain() {
        let built = SamplingRequest::prompt("hello", 100)
            .with_system_prompt("You are helpful")
            .with_temperature(0.7)
            .with_stop_sequences(vec!["STOP".into()])
            .with_model_hints(vec!["gpt-4".into()]);

        assert_eq!(built.messages.len(), 1);
        assert_eq!(built.max_tokens, 100);
        assert_eq!(built.system_prompt.as_deref(), Some("You are helpful"));
        assert_eq!(built.temperature, Some(0.7));
        assert_eq!(built.stop_sequences, vec!["STOP"]);
        assert_eq!(built.model_hints, vec!["gpt-4"]);
    }

    #[test]
    fn sampling_request_message_roles() {
        let from_user = SamplingRequestMessage::user("hi");
        assert_eq!(from_user.role, SamplingRole::User);
        assert_eq!(from_user.text, "hi");

        let from_assistant = SamplingRequestMessage::assistant("hello");
        assert_eq!(from_assistant.role, SamplingRole::Assistant);
        assert_eq!(from_assistant.text, "hello");
    }

    #[test]
    fn sampling_response_new_default_stop_reason() {
        let reply = SamplingResponse::new("output", "model-1");
        assert_eq!(reply.text, "output");
        assert_eq!(reply.model, "model-1");
        assert_eq!(reply.stop_reason, SamplingStopReason::EndTurn);
        assert_eq!(SamplingStopReason::default(), SamplingStopReason::EndTurn);
    }

    #[test]
    fn noop_sampling_sender_returns_error() {
        let request = SamplingRequest::prompt("test", 10);
        let outcome = crate::block_on(NoOpSamplingSender.create_message(request));
        assert!(outcome.is_err());
    }

    #[test]
    fn noop_elicitation_sender_returns_error() {
        let request = ElicitationRequest::form("msg", serde_json::json!({}));
        let outcome = crate::block_on(NoOpElicitationSender.elicit(request));
        assert!(outcome.is_err());
    }

    #[test]
    fn elicitation_request_form_constructor() {
        let request = ElicitationRequest::form("Enter name", serde_json::json!({"type": "string"}));
        assert_eq!(request.mode, ElicitationMode::Form);
        assert_eq!(request.message, "Enter name");
        assert!(request.schema.is_some());
        assert!(request.url.is_none());
        assert!(request.elicitation_id.is_none());
    }

    #[test]
    fn elicitation_request_url_constructor() {
        let request = ElicitationRequest::url("Login", "https://example.com", "id-1");
        assert_eq!(request.mode, ElicitationMode::Url);
        assert_eq!(request.message, "Login");
        assert_eq!(request.url.as_deref(), Some("https://example.com"));
        assert_eq!(request.elicitation_id.as_deref(), Some("id-1"));
        assert!(request.schema.is_none());
    }

    // ------------------------------------------------------------------
    // Builder-style context configuration and diagnostics
    // ------------------------------------------------------------------

    #[test]
    fn mcp_context_with_sampling_enables_can_sample() {
        let ctx = bare_ctx().with_sampling(Arc::new(NoOpSamplingSender));
        assert!(ctx.can_sample());
    }

    #[test]
    fn mcp_context_with_elicitation_enables_can_elicit() {
        let ctx = bare_ctx().with_elicitation(Arc::new(NoOpElicitationSender));
        assert!(ctx.can_elicit());
    }

    #[test]
    fn mcp_context_depth_setters() {
        let ctx = bare_ctx()
            .with_resource_read_depth(3)
            .with_tool_call_depth(5);
        assert_eq!(ctx.resource_read_depth(), 3);
        assert_eq!(ctx.tool_call_depth(), 5);
    }

    #[test]
    fn mcp_context_debug_includes_request_id() {
        let ctx = McpContext::new(Cx::for_testing(), 99);
        let rendered = format!("{ctx:?}");
        assert!(rendered.contains("request_id: 99"));
    }

    #[test]
    fn mcp_context_cx_and_trace() {
        // Smoke test: both accessors must be callable on a plain context.
        let ctx = bare_ctx();
        let _ = ctx.cx();
        ctx.trace("test event");
    }
}