1use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9
10use asupersync::time::wall_now;
11use asupersync::types::CancelReason;
12use asupersync::{Budget, CancelKind, Cx, Outcome, RegionId, TaskId};
13
14use crate::{AUTH_STATE_KEY, AuthContext, SessionState};
15
/// Sink for progress notifications emitted while a request is handled.
///
/// The `Send + Sync` supertraits allow one sender to be shared across tasks
/// (for example behind an `Arc`).
pub trait NotificationSender: Send + Sync {
    /// Emits a single progress update.
    ///
    /// * `progress` - the current progress value.
    /// * `total` - the expected final value, when known.
    /// * `message` - an optional human-readable status message.
    fn send_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>);
}
34
35pub trait SamplingSender: Send + Sync {
45 fn create_message(
56 &self,
57 request: SamplingRequest,
58 ) -> std::pin::Pin<
59 Box<dyn std::future::Future<Output = crate::McpResult<SamplingResponse>> + Send + '_>,
60 >;
61}
62
/// Parameters for asking the client to generate a message (sampling).
///
/// Start with [`SamplingRequest::new`] or [`SamplingRequest::prompt`] and
/// refine it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct SamplingRequest {
    /// Conversation sent for sampling, in order.
    pub messages: Vec<SamplingRequestMessage>,
    /// Maximum number of tokens the client is asked to generate.
    pub max_tokens: u32,
    /// Optional system prompt; `None` unless set via `with_system_prompt`.
    pub system_prompt: Option<String>,
    /// Optional sampling temperature; `None` unless set via `with_temperature`.
    pub temperature: Option<f64>,
    /// Stop sequences; empty unless set via `with_stop_sequences`.
    pub stop_sequences: Vec<String>,
    /// Model-name hints; empty unless set via `with_model_hints`.
    pub model_hints: Vec<String>,
}

impl SamplingRequest {
    /// Creates a request from `messages` with every optional setting unset.
    #[must_use]
    pub fn new(messages: Vec<SamplingRequestMessage>, max_tokens: u32) -> Self {
        Self {
            messages,
            max_tokens,
            system_prompt: None,
            temperature: None,
            stop_sequences: vec![],
            model_hints: vec![],
        }
    }

    /// Creates a request whose only message is a user message containing `text`.
    #[must_use]
    pub fn prompt(text: impl Into<String>, max_tokens: u32) -> Self {
        let only_message = SamplingRequestMessage::user(text);
        Self::new(vec![only_message], max_tokens)
    }

    /// Sets the system prompt.
    #[must_use]
    pub fn with_system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Sets the sampling temperature (stored as given, without validation).
    #[must_use]
    pub fn with_temperature(self, temp: f64) -> Self {
        Self {
            temperature: Some(temp),
            ..self
        }
    }

    /// Replaces the stop sequences.
    #[must_use]
    pub fn with_stop_sequences(self, sequences: Vec<String>) -> Self {
        Self {
            stop_sequences: sequences,
            ..self
        }
    }

    /// Replaces the model hints.
    #[must_use]
    pub fn with_model_hints(self, hints: Vec<String>) -> Self {
        Self {
            model_hints: hints,
            ..self
        }
    }
}

/// One text message in a [`SamplingRequest`] conversation.
#[derive(Debug, Clone)]
pub struct SamplingRequestMessage {
    /// Who authored the message.
    pub role: SamplingRole,
    /// Message text.
    pub text: String,
}

impl SamplingRequestMessage {
    /// Creates a message authored by the user.
    #[must_use]
    pub fn user(text: impl Into<String>) -> Self {
        Self::with_role(SamplingRole::User, text.into())
    }

    /// Creates a message authored by the assistant.
    #[must_use]
    pub fn assistant(text: impl Into<String>) -> Self {
        Self::with_role(SamplingRole::Assistant, text.into())
    }

    /// Shared constructor for the role-specific helpers above.
    fn with_role(role: SamplingRole, text: String) -> Self {
        Self { role, text }
    }
}

/// Author of a [`SamplingRequestMessage`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingRole {
    /// Message from the user.
    User,
    /// Message from the assistant.
    Assistant,
}
166
/// Result of a sampling request.
#[derive(Debug, Clone)]
pub struct SamplingResponse {
    /// Generated text.
    pub text: String,
    /// Name of the model that produced the text.
    pub model: String,
    /// Why generation stopped.
    pub stop_reason: SamplingStopReason,
}

impl SamplingResponse {
    /// Creates a response whose stop reason is the default, [`SamplingStopReason::EndTurn`].
    #[must_use]
    pub fn new(text: impl Into<String>, model: impl Into<String>) -> Self {
        let text = text.into();
        let model = model.into();
        Self {
            text,
            model,
            stop_reason: SamplingStopReason::default(),
        }
    }
}

/// Reason a sampling response ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SamplingStopReason {
    /// Default reason; used by [`SamplingResponse::new`].
    #[default]
    EndTurn,
    /// Presumably generation hit one of the request's stop sequences.
    StopSequence,
    /// Presumably generation reached the request's `max_tokens` limit.
    MaxTokens,
}
201
202#[derive(Debug, Clone, Copy, Default)]
206pub struct NoOpSamplingSender;
207
208impl SamplingSender for NoOpSamplingSender {
209 fn create_message(
210 &self,
211 _request: SamplingRequest,
212 ) -> std::pin::Pin<
213 Box<dyn std::future::Future<Output = crate::McpResult<SamplingResponse>> + Send + '_>,
214 > {
215 Box::pin(async {
216 Err(crate::McpError::new(
217 crate::McpErrorCode::InvalidRequest,
218 "Sampling not supported: client does not have sampling capability",
219 ))
220 })
221 }
222}
223
224pub trait ElicitationSender: Send + Sync {
234 fn elicit(
245 &self,
246 request: ElicitationRequest,
247 ) -> std::pin::Pin<
248 Box<dyn std::future::Future<Output = crate::McpResult<ElicitationResponse>> + Send + '_>,
249 >;
250}
251
252#[derive(Debug, Clone)]
254pub struct ElicitationRequest {
255 pub mode: ElicitationMode,
257 pub message: String,
259 pub schema: Option<serde_json::Value>,
261 pub url: Option<String>,
263 pub elicitation_id: Option<String>,
265}
266
267impl ElicitationRequest {
268 #[must_use]
270 pub fn form(message: impl Into<String>, schema: serde_json::Value) -> Self {
271 Self {
272 mode: ElicitationMode::Form,
273 message: message.into(),
274 schema: Some(schema),
275 url: None,
276 elicitation_id: None,
277 }
278 }
279
280 #[must_use]
282 pub fn url(
283 message: impl Into<String>,
284 url: impl Into<String>,
285 elicitation_id: impl Into<String>,
286 ) -> Self {
287 Self {
288 mode: ElicitationMode::Url,
289 message: message.into(),
290 schema: None,
291 url: Some(url.into()),
292 elicitation_id: Some(elicitation_id.into()),
293 }
294 }
295}
296
297#[derive(Debug, Clone, Copy, PartialEq, Eq)]
299pub enum ElicitationMode {
300 Form,
302 Url,
304}
305
306#[derive(Debug, Clone)]
308pub struct ElicitationResponse {
309 pub action: ElicitationAction,
311 pub content: Option<std::collections::HashMap<String, serde_json::Value>>,
313}
314
315impl ElicitationResponse {
316 #[must_use]
318 pub fn accept(content: std::collections::HashMap<String, serde_json::Value>) -> Self {
319 Self {
320 action: ElicitationAction::Accept,
321 content: Some(content),
322 }
323 }
324
325 #[must_use]
327 pub fn accept_url() -> Self {
328 Self {
329 action: ElicitationAction::Accept,
330 content: None,
331 }
332 }
333
334 #[must_use]
336 pub fn decline() -> Self {
337 Self {
338 action: ElicitationAction::Decline,
339 content: None,
340 }
341 }
342
343 #[must_use]
345 pub fn cancel() -> Self {
346 Self {
347 action: ElicitationAction::Cancel,
348 content: None,
349 }
350 }
351
352 #[must_use]
354 pub fn is_accepted(&self) -> bool {
355 matches!(self.action, ElicitationAction::Accept)
356 }
357
358 #[must_use]
360 pub fn is_declined(&self) -> bool {
361 matches!(self.action, ElicitationAction::Decline)
362 }
363
364 #[must_use]
366 pub fn is_cancelled(&self) -> bool {
367 matches!(self.action, ElicitationAction::Cancel)
368 }
369
370 #[must_use]
372 pub fn get_string(&self, key: &str) -> Option<&str> {
373 self.content.as_ref()?.get(key)?.as_str()
374 }
375
376 #[must_use]
378 pub fn get_bool(&self, key: &str) -> Option<bool> {
379 self.content.as_ref()?.get(key)?.as_bool()
380 }
381
382 #[must_use]
384 pub fn get_int(&self, key: &str) -> Option<i64> {
385 self.content.as_ref()?.get(key)?.as_i64()
386 }
387}
388
389#[derive(Debug, Clone, Copy, PartialEq, Eq)]
391pub enum ElicitationAction {
392 Accept,
394 Decline,
396 Cancel,
398}
399
400#[derive(Debug, Clone, Copy, Default)]
404pub struct NoOpElicitationSender;
405
406impl ElicitationSender for NoOpElicitationSender {
407 fn elicit(
408 &self,
409 _request: ElicitationRequest,
410 ) -> std::pin::Pin<
411 Box<dyn std::future::Future<Output = crate::McpResult<ElicitationResponse>> + Send + '_>,
412 > {
413 Box::pin(async {
414 Err(crate::McpError::new(
415 crate::McpErrorCode::InvalidRequest,
416 "Elicitation not supported: client does not have elicitation capability",
417 ))
418 })
419 }
420}
421
/// Maximum nesting depth for resource reads issued from a handler
/// (`McpContext::read_resource` refuses to go deeper).
pub const MAX_RESOURCE_READ_DEPTH: u32 = 10;

/// One content entry of a resource read: either text or a blob.
#[derive(Debug, Clone)]
pub struct ResourceContentItem {
    /// URI the content belongs to.
    pub uri: String,
    /// MIME type of the content, if known.
    pub mime_type: Option<String>,
    /// Text payload, for textual content.
    pub text: Option<String>,
    /// Blob payload, stored as a string (NOTE(review): presumably base64 — confirm).
    pub blob: Option<String>,
}

impl ResourceContentItem {
    /// Builds a textual item with the given MIME type.
    fn textual(uri: String, mime_type: &str, text: String) -> Self {
        Self {
            uri,
            mime_type: Some(mime_type.to_string()),
            text: Some(text),
            blob: None,
        }
    }

    /// Text content with MIME type `text/plain`.
    #[must_use]
    pub fn text(uri: impl Into<String>, text: impl Into<String>) -> Self {
        Self::textual(uri.into(), "text/plain", text.into())
    }

    /// Text content with MIME type `application/json` (the text is not validated).
    #[must_use]
    pub fn json(uri: impl Into<String>, text: impl Into<String>) -> Self {
        Self::textual(uri.into(), "application/json", text.into())
    }

    /// Blob content with an explicit MIME type.
    #[must_use]
    pub fn blob(
        uri: impl Into<String>,
        mime_type: impl Into<String>,
        blob: impl Into<String>,
    ) -> Self {
        let uri = uri.into();
        let mime_type = mime_type.into();
        let blob = blob.into();
        Self {
            uri,
            mime_type: Some(mime_type),
            text: None,
            blob: Some(blob),
        }
    }

    /// Borrowed text payload, if any.
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        self.text.as_deref()
    }

    /// Borrowed blob payload, if any.
    #[must_use]
    pub fn as_blob(&self) -> Option<&str> {
        self.blob.as_deref()
    }

    /// `true` when a text payload is present.
    #[must_use]
    pub fn is_text(&self) -> bool {
        self.as_text().is_some()
    }

    /// `true` when a blob payload is present.
    #[must_use]
    pub fn is_blob(&self) -> bool {
        self.as_blob().is_some()
    }
}

/// Result of reading a resource: its content items in order.
#[derive(Debug, Clone)]
pub struct ResourceReadResult {
    /// Content items returned by the read.
    pub contents: Vec<ResourceContentItem>,
}

impl ResourceReadResult {
    /// Wraps `contents`.
    #[must_use]
    pub fn new(contents: Vec<ResourceContentItem>) -> Self {
        Self { contents }
    }

    /// A result with a single `text/plain` item.
    #[must_use]
    pub fn text(uri: impl Into<String>, text: impl Into<String>) -> Self {
        Self::new(vec![ResourceContentItem::text(uri, text)])
    }

    /// Text of the first item; `None` if empty or the first item has no text.
    #[must_use]
    pub fn first_text(&self) -> Option<&str> {
        self.contents.first()?.as_text()
    }

    /// Blob of the first item; `None` if empty or the first item has no blob.
    #[must_use]
    pub fn first_blob(&self) -> Option<&str> {
        self.contents.first()?.as_blob()
    }
}
542
543pub trait ResourceReader: Send + Sync {
552 fn read_resource(
565 &self,
566 cx: &Cx,
567 uri: &str,
568 depth: u32,
569 ) -> Pin<Box<dyn Future<Output = crate::McpResult<ResourceReadResult>> + Send + '_>>;
570}
571
/// Maximum nesting depth for tool calls issued from a handler
/// (`McpContext::call_tool` refuses to go deeper).
pub const MAX_TOOL_CALL_DEPTH: u32 = 10;

/// One content entry in a tool call result.
#[derive(Debug, Clone)]
pub enum ToolContentItem {
    /// Plain text.
    Text {
        /// The text.
        text: String,
    },
    /// Image payload.
    Image {
        /// Encoded image data (NOTE(review): presumably base64 — confirm).
        data: String,
        /// MIME type of the image.
        mime_type: String,
    },
    /// Audio payload.
    Audio {
        /// Encoded audio data (NOTE(review): presumably base64 — confirm).
        data: String,
        /// MIME type of the audio.
        mime_type: String,
    },
    /// Embedded resource content.
    Resource {
        /// URI of the resource.
        uri: String,
        /// MIME type, if known.
        mime_type: Option<String>,
        /// Text payload, if textual.
        text: Option<String>,
        /// Blob payload, if binary.
        blob: Option<String>,
    },
}

impl ToolContentItem {
    /// Creates a [`ToolContentItem::Text`] item.
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text { text: text.into() }
    }

    /// Text of a `Text` item; `None` for every other variant.
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text { text } = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// `true` only for the `Text` variant.
    #[must_use]
    pub fn is_text(&self) -> bool {
        self.as_text().is_some()
    }
}

/// Result of a tool call: its content items plus an error flag.
#[derive(Debug, Clone)]
pub struct ToolCallResult {
    /// Content items produced by the tool.
    pub content: Vec<ToolContentItem>,
    /// Whether the tool reported an error.
    pub is_error: bool,
}

impl ToolCallResult {
    /// A successful result carrying `content`.
    #[must_use]
    pub fn success(content: Vec<ToolContentItem>) -> Self {
        Self {
            content,
            is_error: false,
        }
    }

    /// A successful result with a single text item.
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        Self::success(vec![ToolContentItem::text(text)])
    }

    /// An error result whose single text item is `message`.
    #[must_use]
    pub fn error(message: impl Into<String>) -> Self {
        Self {
            is_error: true,
            ..Self::text(message)
        }
    }

    /// Text of the first item; `None` if empty or the first item is not text.
    #[must_use]
    pub fn first_text(&self) -> Option<&str> {
        self.content.first()?.as_text()
    }
}
683
684pub trait ToolCaller: Send + Sync {
693 fn call_tool(
707 &self,
708 cx: &Cx,
709 name: &str,
710 args: serde_json::Value,
711 depth: u32,
712 ) -> Pin<Box<dyn Future<Output = crate::McpResult<ToolCallResult>> + Send + '_>>;
713}
714
/// Summary of the capabilities a client declared; every flag defaults to `false`.
#[derive(Debug, Clone, Default)]
pub struct ClientCapabilityInfo {
    /// Client supports sampling.
    pub sampling: bool,
    /// Client supports elicitation in at least one mode.
    pub elicitation: bool,
    /// Client supports form-mode elicitation.
    pub elicitation_form: bool,
    /// Client supports URL-mode elicitation.
    pub elicitation_url: bool,
    /// Client supports roots.
    pub roots: bool,
    /// Client sends roots list-changed notifications.
    pub roots_list_changed: bool,
}

impl ClientCapabilityInfo {
    /// All capabilities off.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Turns sampling on.
    #[must_use]
    pub fn with_sampling(self) -> Self {
        Self {
            sampling: true,
            ..self
        }
    }

    /// Sets the elicitation modes; `elicitation` becomes `form || url`.
    #[must_use]
    pub fn with_elicitation(self, form: bool, url: bool) -> Self {
        Self {
            elicitation: form || url,
            elicitation_form: form,
            elicitation_url: url,
            ..self
        }
    }

    /// Turns roots on and records whether list-changed notifications are supported.
    #[must_use]
    pub fn with_roots(self, list_changed: bool) -> Self {
        Self {
            roots: true,
            roots_list_changed: list_changed,
            ..self
        }
    }
}
770
/// Summary of the capabilities this server exposes; every flag defaults to `false`.
#[derive(Debug, Clone, Default)]
pub struct ServerCapabilityInfo {
    /// Server exposes tools.
    pub tools: bool,
    /// Server exposes resources.
    pub resources: bool,
    /// Server supports resource subscriptions.
    pub resources_subscribe: bool,
    /// Server exposes prompts.
    pub prompts: bool,
    /// Server supports logging.
    pub logging: bool,
}

impl ServerCapabilityInfo {
    /// All capabilities off.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Turns tools on.
    #[must_use]
    pub fn with_tools(self) -> Self {
        Self {
            tools: true,
            ..self
        }
    }

    /// Turns resources on and records subscription support.
    #[must_use]
    pub fn with_resources(self, subscribe: bool) -> Self {
        Self {
            resources: true,
            resources_subscribe: subscribe,
            ..self
        }
    }

    /// Turns prompts on.
    #[must_use]
    pub fn with_prompts(self) -> Self {
        Self {
            prompts: true,
            ..self
        }
    }

    /// Turns logging on.
    #[must_use]
    pub fn with_logging(self) -> Self {
        Self {
            logging: true,
            ..self
        }
    }
}
824
825#[derive(Debug, Clone, Copy, Default)]
827pub struct NoOpNotificationSender;
828
829impl NotificationSender for NoOpNotificationSender {
830 fn send_progress(&self, _progress: f64, _total: Option<f64>, _message: Option<&str>) {
831 }
833}
834
835#[derive(Clone)]
840pub struct ProgressReporter {
841 sender: Arc<dyn NotificationSender>,
842}
843
844impl ProgressReporter {
845 pub fn new(sender: Arc<dyn NotificationSender>) -> Self {
847 Self { sender }
848 }
849
850 pub fn report(&self, progress: f64, message: Option<&str>) {
857 self.sender.send_progress(progress, None, message);
858 }
859
860 pub fn report_with_total(&self, progress: f64, total: f64, message: Option<&str>) {
868 self.sender.send_progress(progress, Some(total), message);
869 }
870}
871
872impl std::fmt::Debug for ProgressReporter {
873 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
874 f.debug_struct("ProgressReporter").finish_non_exhaustive()
875 }
876}
877
878#[derive(Clone)]
916pub struct McpContext {
917 cx: Cx,
919 request_id: u64,
921 progress_reporter: Option<ProgressReporter>,
923 state: Option<SessionState>,
925 sampling_sender: Option<Arc<dyn SamplingSender>>,
927 elicitation_sender: Option<Arc<dyn ElicitationSender>>,
929 resource_reader: Option<Arc<dyn ResourceReader>>,
931 resource_read_depth: u32,
933 tool_caller: Option<Arc<dyn ToolCaller>>,
935 tool_call_depth: u32,
937 client_capabilities: Option<ClientCapabilityInfo>,
939 server_capabilities: Option<ServerCapabilityInfo>,
941}
942
943impl std::fmt::Debug for McpContext {
944 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
945 f.debug_struct("McpContext")
946 .field("cx", &self.cx)
947 .field("request_id", &self.request_id)
948 .field("progress_reporter", &self.progress_reporter)
949 .field("state", &self.state.is_some())
950 .field("sampling_sender", &self.sampling_sender.is_some())
951 .field("elicitation_sender", &self.elicitation_sender.is_some())
952 .field("resource_reader", &self.resource_reader.is_some())
953 .field("resource_read_depth", &self.resource_read_depth)
954 .field("tool_caller", &self.tool_caller.is_some())
955 .field("tool_call_depth", &self.tool_call_depth)
956 .field("client_capabilities", &self.client_capabilities)
957 .field("server_capabilities", &self.server_capabilities)
958 .finish()
959 }
960}
961
962impl McpContext {
963 #[must_use]
968 pub fn new(cx: Cx, request_id: u64) -> Self {
969 Self {
970 cx,
971 request_id,
972 progress_reporter: None,
973 state: None,
974 sampling_sender: None,
975 elicitation_sender: None,
976 resource_reader: None,
977 resource_read_depth: 0,
978 tool_caller: None,
979 tool_call_depth: 0,
980 client_capabilities: None,
981 server_capabilities: None,
982 }
983 }
984
985 #[must_use]
989 pub fn with_state(cx: Cx, request_id: u64, state: SessionState) -> Self {
990 Self {
991 cx,
992 request_id,
993 progress_reporter: None,
994 state: Some(state),
995 sampling_sender: None,
996 elicitation_sender: None,
997 resource_reader: None,
998 resource_read_depth: 0,
999 tool_caller: None,
1000 tool_call_depth: 0,
1001 client_capabilities: None,
1002 server_capabilities: None,
1003 }
1004 }
1005
1006 #[must_use]
1011 pub fn with_progress(cx: Cx, request_id: u64, reporter: ProgressReporter) -> Self {
1012 Self {
1013 cx,
1014 request_id,
1015 progress_reporter: Some(reporter),
1016 state: None,
1017 sampling_sender: None,
1018 elicitation_sender: None,
1019 resource_reader: None,
1020 resource_read_depth: 0,
1021 tool_caller: None,
1022 tool_call_depth: 0,
1023 client_capabilities: None,
1024 server_capabilities: None,
1025 }
1026 }
1027
1028 #[must_use]
1030 pub fn with_state_and_progress(
1031 cx: Cx,
1032 request_id: u64,
1033 state: SessionState,
1034 reporter: ProgressReporter,
1035 ) -> Self {
1036 Self {
1037 cx,
1038 request_id,
1039 progress_reporter: Some(reporter),
1040 state: Some(state),
1041 sampling_sender: None,
1042 elicitation_sender: None,
1043 resource_reader: None,
1044 resource_read_depth: 0,
1045 tool_caller: None,
1046 tool_call_depth: 0,
1047 client_capabilities: None,
1048 server_capabilities: None,
1049 }
1050 }
1051
1052 #[must_use]
1057 pub fn with_sampling(mut self, sender: Arc<dyn SamplingSender>) -> Self {
1058 self.sampling_sender = Some(sender);
1059 self
1060 }
1061
1062 #[must_use]
1067 pub fn with_elicitation(mut self, sender: Arc<dyn ElicitationSender>) -> Self {
1068 self.elicitation_sender = Some(sender);
1069 self
1070 }
1071
1072 #[must_use]
1077 pub fn with_resource_reader(mut self, reader: Arc<dyn ResourceReader>) -> Self {
1078 self.resource_reader = Some(reader);
1079 self
1080 }
1081
1082 #[must_use]
1087 pub fn with_resource_read_depth(mut self, depth: u32) -> Self {
1088 self.resource_read_depth = depth;
1089 self
1090 }
1091
1092 #[must_use]
1097 pub fn with_tool_caller(mut self, caller: Arc<dyn ToolCaller>) -> Self {
1098 self.tool_caller = Some(caller);
1099 self
1100 }
1101
1102 #[must_use]
1107 pub fn with_tool_call_depth(mut self, depth: u32) -> Self {
1108 self.tool_call_depth = depth;
1109 self
1110 }
1111
1112 #[must_use]
1117 pub fn with_client_capabilities(mut self, capabilities: ClientCapabilityInfo) -> Self {
1118 self.client_capabilities = Some(capabilities);
1119 self
1120 }
1121
1122 #[must_use]
1127 pub fn with_server_capabilities(mut self, capabilities: ServerCapabilityInfo) -> Self {
1128 self.server_capabilities = Some(capabilities);
1129 self
1130 }
1131
1132 #[must_use]
1134 pub fn has_progress_reporter(&self) -> bool {
1135 self.progress_reporter.is_some()
1136 }
1137
1138 pub fn report_progress(&self, progress: f64, message: Option<&str>) {
1161 if let Some(ref reporter) = self.progress_reporter {
1162 reporter.report(progress, message);
1163 }
1164 }
1165
1166 pub fn report_progress_with_total(&self, progress: f64, total: f64, message: Option<&str>) {
1189 if let Some(ref reporter) = self.progress_reporter {
1190 reporter.report_with_total(progress, total, message);
1191 }
1192 }
1193
1194 #[must_use]
1199 pub fn request_id(&self) -> u64 {
1200 self.request_id
1201 }
1202
1203 #[must_use]
1209 pub fn region_id(&self) -> RegionId {
1210 self.cx.region_id()
1211 }
1212
1213 #[must_use]
1215 pub fn task_id(&self) -> TaskId {
1216 self.cx.task_id()
1217 }
1218
1219 #[must_use]
1225 pub fn budget(&self) -> Budget {
1226 self.cx.budget()
1227 }
1228
1229 #[must_use]
1234 pub fn is_cancelled(&self) -> bool {
1235 let budget = self.cx.budget();
1236 self.cx.is_cancel_requested()
1237 || budget.is_exhausted()
1238 || budget.is_past_deadline(wall_now())
1239 }
1240
1241 pub fn checkpoint(&self) -> Result<(), CancelledError> {
1263 self.cx.checkpoint().map_err(|_| CancelledError)?;
1264 let budget = self.cx.budget();
1265 if budget.is_exhausted() {
1266 return Err(CancelledError);
1267 }
1268 if budget.is_past_deadline(wall_now()) {
1269 self.cx.cancel_fast(CancelKind::Deadline);
1272 return Err(CancelledError);
1273 }
1274 Ok(())
1275 }
1276
1277 pub fn masked<F, R>(&self, f: F) -> R
1293 where
1294 F: FnOnce() -> R,
1295 {
1296 self.cx.masked(f)
1297 }
1298
1299 pub fn trace(&self, message: &str) {
1304 self.cx.trace(message);
1305 }
1306
1307 #[must_use]
1312 pub fn cx(&self) -> &Cx {
1313 &self.cx
1314 }
1315
1316 #[must_use]
1339 pub fn get_state<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
1340 self.state.as_ref()?.get(key)
1341 }
1342
1343 #[must_use]
1345 pub fn auth(&self) -> Option<AuthContext> {
1346 self.state.as_ref()?.get(AUTH_STATE_KEY)
1347 }
1348
1349 pub fn set_auth(&self, auth: AuthContext) -> bool {
1353 let Some(state) = self.state.as_ref() else {
1354 return false;
1355 };
1356 state.set(AUTH_STATE_KEY, auth)
1357 }
1358
1359 pub fn set_state<T: serde::Serialize>(&self, key: impl Into<String>, value: T) -> bool {
1376 match &self.state {
1377 Some(state) => state.set(key, value),
1378 None => false,
1379 }
1380 }
1381
1382 pub fn remove_state(&self, key: &str) -> Option<serde_json::Value> {
1388 self.state.as_ref()?.remove(key)
1389 }
1390
1391 #[must_use]
1395 pub fn has_state(&self, key: &str) -> bool {
1396 self.state.as_ref().is_some_and(|s| s.contains(key))
1397 }
1398
1399 #[must_use]
1401 pub fn has_session_state(&self) -> bool {
1402 self.state.is_some()
1403 }
1404
1405 #[must_use]
1414 pub fn client_capabilities(&self) -> Option<&ClientCapabilityInfo> {
1415 self.client_capabilities.as_ref()
1416 }
1417
1418 #[must_use]
1422 pub fn server_capabilities(&self) -> Option<&ServerCapabilityInfo> {
1423 self.server_capabilities.as_ref()
1424 }
1425
1426 #[must_use]
1431 pub fn client_supports_sampling(&self) -> bool {
1432 self.client_capabilities
1433 .as_ref()
1434 .is_some_and(|c| c.sampling)
1435 }
1436
1437 #[must_use]
1442 pub fn client_supports_elicitation(&self) -> bool {
1443 self.client_capabilities
1444 .as_ref()
1445 .is_some_and(|c| c.elicitation)
1446 }
1447
1448 #[must_use]
1450 pub fn client_supports_elicitation_form(&self) -> bool {
1451 self.client_capabilities
1452 .as_ref()
1453 .is_some_and(|c| c.elicitation_form)
1454 }
1455
1456 #[must_use]
1458 pub fn client_supports_elicitation_url(&self) -> bool {
1459 self.client_capabilities
1460 .as_ref()
1461 .is_some_and(|c| c.elicitation_url)
1462 }
1463
1464 #[must_use]
1469 pub fn client_supports_roots(&self) -> bool {
1470 self.client_capabilities.as_ref().is_some_and(|c| c.roots)
1471 }
1472
1473 const DISABLED_TOOLS_KEY: &'static str = "fastmcp.disabled_tools";
1479 const DISABLED_RESOURCES_KEY: &'static str = "fastmcp.disabled_resources";
1481 const DISABLED_PROMPTS_KEY: &'static str = "fastmcp.disabled_prompts";
1483
1484 pub fn disable_tool(&self, name: impl Into<String>) -> bool {
1502 self.add_to_disabled_set(Self::DISABLED_TOOLS_KEY, name.into())
1503 }
1504
1505 pub fn enable_tool(&self, name: &str) -> bool {
1509 self.remove_from_disabled_set(Self::DISABLED_TOOLS_KEY, name)
1510 }
1511
1512 #[must_use]
1516 pub fn is_tool_enabled(&self, name: &str) -> bool {
1517 !self.is_in_disabled_set(Self::DISABLED_TOOLS_KEY, name)
1518 }
1519
1520 pub fn disable_resource(&self, uri: impl Into<String>) -> bool {
1527 self.add_to_disabled_set(Self::DISABLED_RESOURCES_KEY, uri.into())
1528 }
1529
1530 pub fn enable_resource(&self, uri: &str) -> bool {
1534 self.remove_from_disabled_set(Self::DISABLED_RESOURCES_KEY, uri)
1535 }
1536
1537 #[must_use]
1541 pub fn is_resource_enabled(&self, uri: &str) -> bool {
1542 !self.is_in_disabled_set(Self::DISABLED_RESOURCES_KEY, uri)
1543 }
1544
1545 pub fn disable_prompt(&self, name: impl Into<String>) -> bool {
1552 self.add_to_disabled_set(Self::DISABLED_PROMPTS_KEY, name.into())
1553 }
1554
1555 pub fn enable_prompt(&self, name: &str) -> bool {
1559 self.remove_from_disabled_set(Self::DISABLED_PROMPTS_KEY, name)
1560 }
1561
1562 #[must_use]
1566 pub fn is_prompt_enabled(&self, name: &str) -> bool {
1567 !self.is_in_disabled_set(Self::DISABLED_PROMPTS_KEY, name)
1568 }
1569
1570 #[must_use]
1572 pub fn disabled_tools(&self) -> std::collections::HashSet<String> {
1573 self.get_disabled_set(Self::DISABLED_TOOLS_KEY)
1574 }
1575
1576 #[must_use]
1578 pub fn disabled_resources(&self) -> std::collections::HashSet<String> {
1579 self.get_disabled_set(Self::DISABLED_RESOURCES_KEY)
1580 }
1581
1582 #[must_use]
1584 pub fn disabled_prompts(&self) -> std::collections::HashSet<String> {
1585 self.get_disabled_set(Self::DISABLED_PROMPTS_KEY)
1586 }
1587
1588 fn add_to_disabled_set(&self, key: &str, name: String) -> bool {
1590 let Some(state) = self.state.as_ref() else {
1591 return false;
1592 };
1593 let mut set: std::collections::HashSet<String> = state.get(key).unwrap_or_default();
1594 set.insert(name);
1595 state.set(key, set)
1596 }
1597
1598 fn remove_from_disabled_set(&self, key: &str, name: &str) -> bool {
1600 let Some(state) = self.state.as_ref() else {
1601 return false;
1602 };
1603 let mut set: std::collections::HashSet<String> = state.get(key).unwrap_or_default();
1604 set.remove(name);
1605 state.set(key, set)
1606 }
1607
1608 fn is_in_disabled_set(&self, key: &str, name: &str) -> bool {
1610 let Some(state) = self.state.as_ref() else {
1611 return false;
1612 };
1613 let set: std::collections::HashSet<String> = state.get(key).unwrap_or_default();
1614 set.contains(name)
1615 }
1616
1617 fn get_disabled_set(&self, key: &str) -> std::collections::HashSet<String> {
1619 self.state
1620 .as_ref()
1621 .and_then(|s| s.get(key))
1622 .unwrap_or_default()
1623 }
1624
1625 #[must_use]
1634 pub fn can_sample(&self) -> bool {
1635 self.sampling_sender.is_some()
1636 }
1637
1638 pub async fn sample(
1663 &self,
1664 prompt: impl Into<String>,
1665 max_tokens: u32,
1666 ) -> crate::McpResult<SamplingResponse> {
1667 let request = SamplingRequest::prompt(prompt, max_tokens);
1668 self.sample_with_request(request).await
1669 }
1670
1671 pub async fn sample_with_request(
1703 &self,
1704 request: SamplingRequest,
1705 ) -> crate::McpResult<SamplingResponse> {
1706 let sender = self.sampling_sender.as_ref().ok_or_else(|| {
1707 crate::McpError::new(
1708 crate::McpErrorCode::InvalidRequest,
1709 "Sampling not available: client does not support sampling capability",
1710 )
1711 })?;
1712
1713 sender.create_message(request).await
1714 }
1715
1716 #[must_use]
1725 pub fn can_elicit(&self) -> bool {
1726 self.elicitation_sender.is_some()
1727 }
1728
1729 pub async fn elicit_form(
1767 &self,
1768 message: impl Into<String>,
1769 schema: serde_json::Value,
1770 ) -> crate::McpResult<ElicitationResponse> {
1771 let request = ElicitationRequest::form(message, schema);
1772 self.elicit_with_request(request).await
1773 }
1774
1775 pub async fn elicit_url(
1809 &self,
1810 message: impl Into<String>,
1811 url: impl Into<String>,
1812 elicitation_id: impl Into<String>,
1813 ) -> crate::McpResult<ElicitationResponse> {
1814 let request = ElicitationRequest::url(message, url, elicitation_id);
1815 self.elicit_with_request(request).await
1816 }
1817
1818 pub async fn elicit_with_request(
1830 &self,
1831 request: ElicitationRequest,
1832 ) -> crate::McpResult<ElicitationResponse> {
1833 let sender = self.elicitation_sender.as_ref().ok_or_else(|| {
1834 crate::McpError::new(
1835 crate::McpErrorCode::InvalidRequest,
1836 "Elicitation not available: client does not support elicitation capability",
1837 )
1838 })?;
1839
1840 sender.elicit(request).await
1841 }
1842
1843 #[must_use]
1852 pub fn can_read_resources(&self) -> bool {
1853 self.resource_reader.is_some()
1854 }
1855
1856 #[must_use]
1860 pub fn resource_read_depth(&self) -> u32 {
1861 self.resource_read_depth
1862 }
1863
1864 pub async fn read_resource(&self, uri: &str) -> crate::McpResult<ResourceReadResult> {
1893 let reader = self.resource_reader.as_ref().ok_or_else(|| {
1895 crate::McpError::new(
1896 crate::McpErrorCode::InternalError,
1897 "Resource reading not available: no router attached to context",
1898 )
1899 })?;
1900
1901 if self.resource_read_depth >= MAX_RESOURCE_READ_DEPTH {
1903 return Err(crate::McpError::new(
1904 crate::McpErrorCode::InternalError,
1905 format!(
1906 "Maximum resource read depth ({}) exceeded; possible infinite recursion",
1907 MAX_RESOURCE_READ_DEPTH
1908 ),
1909 ));
1910 }
1911
1912 reader
1914 .read_resource(&self.cx, uri, self.resource_read_depth + 1)
1915 .await
1916 }
1917
1918 pub async fn read_resource_text(&self, uri: &str) -> crate::McpResult<String> {
1936 let result = self.read_resource(uri).await?;
1937 result.first_text().map(String::from).ok_or_else(|| {
1938 crate::McpError::new(
1939 crate::McpErrorCode::InternalError,
1940 format!("Resource '{}' has no text content", uri),
1941 )
1942 })
1943 }
1944
1945 pub async fn read_resource_json<T: serde::de::DeserializeOwned>(
1969 &self,
1970 uri: &str,
1971 ) -> crate::McpResult<T> {
1972 let text = self.read_resource_text(uri).await?;
1973 serde_json::from_str(&text).map_err(|e| {
1974 crate::McpError::new(
1975 crate::McpErrorCode::InternalError,
1976 format!("Failed to parse resource '{}' as JSON: {}", uri, e),
1977 )
1978 })
1979 }
1980
1981 #[must_use]
1990 pub fn can_call_tools(&self) -> bool {
1991 self.tool_caller.is_some()
1992 }
1993
1994 #[must_use]
1998 pub fn tool_call_depth(&self) -> u32 {
1999 self.tool_call_depth
2000 }
2001
2002 pub async fn call_tool(
2030 &self,
2031 name: &str,
2032 args: serde_json::Value,
2033 ) -> crate::McpResult<ToolCallResult> {
2034 let caller = self.tool_caller.as_ref().ok_or_else(|| {
2036 crate::McpError::new(
2037 crate::McpErrorCode::InternalError,
2038 "Tool calling not available: no router attached to context",
2039 )
2040 })?;
2041
2042 if self.tool_call_depth >= MAX_TOOL_CALL_DEPTH {
2044 return Err(crate::McpError::new(
2045 crate::McpErrorCode::InternalError,
2046 format!(
2047 "Maximum tool call depth ({}) exceeded calling '{}'; possible infinite recursion",
2048 MAX_TOOL_CALL_DEPTH, name
2049 ),
2050 ));
2051 }
2052
2053 caller
2055 .call_tool(&self.cx, name, args, self.tool_call_depth + 1)
2056 .await
2057 }
2058
2059 pub async fn call_tool_text(
2078 &self,
2079 name: &str,
2080 args: serde_json::Value,
2081 ) -> crate::McpResult<String> {
2082 let result = self.call_tool(name, args).await?;
2083
2084 if result.is_error {
2086 let error_msg = result.first_text().unwrap_or("Tool returned an error");
2087 return Err(crate::McpError::new(
2088 crate::McpErrorCode::InternalError,
2089 format!("Tool '{}' failed: {}", name, error_msg),
2090 ));
2091 }
2092
2093 result.first_text().map(String::from).ok_or_else(|| {
2094 crate::McpError::new(
2095 crate::McpErrorCode::InternalError,
2096 format!("Tool '{}' returned no text content", name),
2097 )
2098 })
2099 }
2100
2101 pub async fn call_tool_json<T: serde::de::DeserializeOwned>(
2126 &self,
2127 name: &str,
2128 args: serde_json::Value,
2129 ) -> crate::McpResult<T> {
2130 let text = self.call_tool_text(name, args).await?;
2131 serde_json::from_str(&text).map_err(|e| {
2132 crate::McpError::new(
2133 crate::McpErrorCode::InternalError,
2134 format!("Failed to parse tool '{}' result as JSON: {}", name, e),
2135 )
2136 })
2137 }
2138
2139 pub async fn join_all<T: Send + 'static>(
2159 &self,
2160 futures: Vec<crate::combinator::BoxFuture<'_, T>>,
2161 ) -> Vec<T> {
2162 crate::combinator::join_all(&self.cx, futures).await
2163 }
2164
2165 pub async fn race<T: Send + 'static>(
2180 &self,
2181 futures: Vec<crate::combinator::BoxFuture<'_, T>>,
2182 ) -> crate::McpResult<T> {
2183 crate::combinator::race(&self.cx, futures).await
2184 }
2185
2186 pub async fn quorum<T: Send + 'static>(
2202 &self,
2203 required: usize,
2204 futures: Vec<crate::combinator::BoxFuture<'_, crate::McpResult<T>>>,
2205 ) -> crate::McpResult<crate::combinator::QuorumResult<T>> {
2206 crate::combinator::quorum(&self.cx, required, futures).await
2207 }
2208
2209 pub async fn first_ok<T: Send + 'static>(
2224 &self,
2225 futures: Vec<crate::combinator::BoxFuture<'_, crate::McpResult<T>>>,
2226 ) -> crate::McpResult<T> {
2227 crate::combinator::first_ok(&self.cx, futures).await
2228 }
2229}
2230
/// Error returned by `McpContext::checkpoint` when the request should stop.
///
/// Displays as `request cancelled`.
#[derive(Clone, Copy, Debug)]
pub struct CancelledError;

impl std::fmt::Display for CancelledError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("request cancelled")
    }
}

impl std::error::Error for CancelledError {}
2246
2247pub trait IntoOutcome<T, E> {
2252 fn into_outcome(self) -> Outcome<T, E>;
2254}
2255
2256impl<T, E> IntoOutcome<T, E> for Result<T, E> {
2257 fn into_outcome(self) -> Outcome<T, E> {
2258 match self {
2259 Ok(v) => Outcome::Ok(v),
2260 Err(e) => Outcome::Err(e),
2261 }
2262 }
2263}
2264
2265impl<T, E> IntoOutcome<T, E> for Result<T, CancelledError>
2266where
2267 E: Default,
2268{
2269 fn into_outcome(self) -> Outcome<T, E> {
2270 match self {
2271 Ok(v) => Outcome::Ok(v),
2272 Err(CancelledError) => Outcome::Cancelled(CancelReason::user("request cancelled")),
2273 }
2274 }
2275}
2276
#[cfg(test)]
mod tests {
    //! Unit tests for `McpContext` and the request/response helper types
    //! defined alongside it.

    use super::*;

    // ------------------------------------------------------------------
    // Shared fixtures
    // ------------------------------------------------------------------

    /// Context over a fresh testing `Cx`, request id 1, built with `McpContext::new`.
    fn bare_ctx() -> McpContext {
        McpContext::new(Cx::for_testing(), 1)
    }

    /// Context over a fresh testing `Cx`, request id 1, with a brand-new `SessionState`.
    fn stateful_ctx() -> McpContext {
        McpContext::with_state(Cx::for_testing(), 1, SessionState::new())
    }

    // ------------------------------------------------------------------
    // Construction, cancellation, and budget
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_creation() {
        let ctx = McpContext::new(Cx::for_testing(), 42);
        assert_eq!(ctx.request_id(), 42);
    }

    #[test]
    fn test_mcp_context_not_cancelled_initially() {
        assert!(!bare_ctx().is_cancelled());
    }

    #[test]
    fn test_mcp_context_checkpoint_success() {
        // A fresh context with the default testing budget passes a checkpoint.
        assert!(bare_ctx().checkpoint().is_ok());
    }

    #[test]
    fn test_mcp_context_checkpoint_cancelled() {
        // Request cancellation on the Cx first, then wrap it in a context.
        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);
        let ctx = McpContext::new(cx, 1);
        assert!(ctx.checkpoint().is_err());
    }

    #[test]
    fn test_mcp_context_checkpoint_budget_exhausted() {
        // A zero budget makes the checkpoint fail without any cancel request.
        let ctx = McpContext::new(Cx::for_testing_with_budget(Budget::ZERO), 1);
        assert!(ctx.checkpoint().is_err());
    }

    #[test]
    fn test_mcp_context_masked_section() {
        let ctx = bare_ctx();
        let answer = ctx.masked(|| 42);
        assert_eq!(answer, 42);
    }

    #[test]
    fn test_mcp_context_budget() {
        let ctx = bare_ctx();
        let remaining = ctx.budget();
        assert!(!remaining.is_exhausted());
    }

    // ------------------------------------------------------------------
    // CancelledError / IntoOutcome
    // ------------------------------------------------------------------

    #[test]
    fn test_cancelled_error_display() {
        assert_eq!(CancelledError.to_string(), "request cancelled");
    }

    #[test]
    fn test_into_outcome_ok() {
        let input: Result<i32, CancelledError> = Ok(42);
        let converted: Outcome<i32, CancelledError> = input.into_outcome();
        assert!(matches!(converted, Outcome::Ok(42)));
    }

    #[test]
    fn test_into_outcome_cancelled() {
        // Targeting an error type that is `Default` selects the cancellation impl.
        let input: Result<i32, CancelledError> = Err(CancelledError);
        let converted: Outcome<i32, ()> = input.into_outcome();
        assert!(matches!(converted, Outcome::Cancelled(_)));
    }

    // ------------------------------------------------------------------
    // Progress reporting
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_no_progress_reporter_by_default() {
        assert!(!bare_ctx().has_progress_reporter());
    }

    #[test]
    fn test_mcp_context_with_progress_reporter() {
        let reporter = ProgressReporter::new(Arc::new(NoOpNotificationSender));
        let ctx = McpContext::with_progress(Cx::for_testing(), 1, reporter);
        assert!(ctx.has_progress_reporter());
    }

    #[test]
    fn test_report_progress_without_reporter() {
        // With no reporter attached these calls must simply not panic.
        let ctx = bare_ctx();
        ctx.report_progress(0.5, Some("test"));
        ctx.report_progress_with_total(5.0, 10.0, None);
    }

    #[test]
    fn test_report_progress_with_reporter() {
        use std::sync::atomic::{AtomicU32, Ordering};

        /// Notification sender that only tallies how often it was invoked.
        struct TallySender {
            calls: AtomicU32,
        }

        impl NotificationSender for TallySender {
            fn send_progress(&self, _progress: f64, _total: Option<f64>, _message: Option<&str>) {
                self.calls.fetch_add(1, Ordering::SeqCst);
            }
        }

        let tally = Arc::new(TallySender {
            calls: AtomicU32::new(0),
        });
        let reporter = ProgressReporter::new(tally.clone());
        let ctx = McpContext::with_progress(Cx::for_testing(), 1, reporter);

        ctx.report_progress(0.25, Some("step 1"));
        ctx.report_progress(0.5, None);
        ctx.report_progress_with_total(3.0, 4.0, Some("step 3"));

        // Each of the three calls above must reach the sender.
        assert_eq!(tally.calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_progress_reporter_debug() {
        let reporter = ProgressReporter::new(Arc::new(NoOpNotificationSender));
        let rendered = format!("{reporter:?}");
        assert!(rendered.contains("ProgressReporter"));
    }

    #[test]
    fn test_noop_notification_sender() {
        // Smoke test: the no-op sender accepts arbitrary arguments without panicking.
        let sink = NoOpNotificationSender;
        sink.send_progress(0.5, Some(1.0), Some("test"));
    }

    // ------------------------------------------------------------------
    // Session state
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_no_session_state_by_default() {
        assert!(!bare_ctx().has_session_state());
    }

    #[test]
    fn test_mcp_context_with_session_state() {
        assert!(stateful_ctx().has_session_state());
    }

    #[test]
    fn test_mcp_context_get_set_state() {
        let ctx = stateful_ctx();
        assert!(ctx.set_state("counter", 42));

        let stored: Option<i32> = ctx.get_state("counter");
        assert_eq!(stored, Some(42));
    }

    #[test]
    fn test_mcp_context_state_not_available() {
        // Without session state, writes report failure and reads come back empty.
        let ctx = bare_ctx();
        assert!(!ctx.set_state("key", "value"));

        let stored: Option<String> = ctx.get_state("key");
        assert!(stored.is_none());
    }

    #[test]
    fn test_mcp_context_has_state() {
        let ctx = stateful_ctx();
        assert!(!ctx.has_state("missing"));

        ctx.set_state("present", true);
        assert!(ctx.has_state("present"));
    }

    #[test]
    fn test_mcp_context_remove_state() {
        let ctx = stateful_ctx();
        ctx.set_state("key", "value");
        assert!(ctx.has_state("key"));

        let evicted = ctx.remove_state("key");
        assert!(evicted.is_some());
        assert!(!ctx.has_state("key"));
    }

    #[test]
    fn test_mcp_context_with_state_and_progress() {
        let reporter = ProgressReporter::new(Arc::new(NoOpNotificationSender));
        let ctx = McpContext::with_state_and_progress(
            Cx::for_testing(),
            1,
            SessionState::new(),
            reporter,
        );

        assert!(ctx.has_session_state());
        assert!(ctx.has_progress_reporter());
    }

    // ------------------------------------------------------------------
    // Dynamic enable/disable of tools, resources, and prompts
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_tools_enabled_by_default() {
        let ctx = stateful_ctx();
        for tool in ["any_tool", "another_tool"] {
            assert!(ctx.is_tool_enabled(tool));
        }
    }

    #[test]
    fn test_mcp_context_disable_enable_tool() {
        let ctx = stateful_ctx();
        assert!(ctx.is_tool_enabled("my_tool"));

        // Disabling affects only the named tool.
        assert!(ctx.disable_tool("my_tool"));
        assert!(!ctx.is_tool_enabled("my_tool"));
        assert!(ctx.is_tool_enabled("other_tool"));

        assert!(ctx.enable_tool("my_tool"));
        assert!(ctx.is_tool_enabled("my_tool"));
    }

    #[test]
    fn test_mcp_context_disable_enable_resource() {
        let ctx = stateful_ctx();
        assert!(ctx.is_resource_enabled("file://secret"));

        // Disabling affects only the named resource URI.
        assert!(ctx.disable_resource("file://secret"));
        assert!(!ctx.is_resource_enabled("file://secret"));
        assert!(ctx.is_resource_enabled("file://public"));

        assert!(ctx.enable_resource("file://secret"));
        assert!(ctx.is_resource_enabled("file://secret"));
    }

    #[test]
    fn test_mcp_context_disable_enable_prompt() {
        let ctx = stateful_ctx();
        assert!(ctx.is_prompt_enabled("admin_prompt"));

        // Disabling affects only the named prompt.
        assert!(ctx.disable_prompt("admin_prompt"));
        assert!(!ctx.is_prompt_enabled("admin_prompt"));
        assert!(ctx.is_prompt_enabled("user_prompt"));

        assert!(ctx.enable_prompt("admin_prompt"));
        assert!(ctx.is_prompt_enabled("admin_prompt"));
    }

    #[test]
    fn test_mcp_context_disable_multiple_tools() {
        let ctx = stateful_ctx();
        let targets = ["tool1", "tool2", "tool3"];

        for tool in targets {
            ctx.disable_tool(tool);
        }

        for tool in targets {
            assert!(!ctx.is_tool_enabled(tool));
        }
        assert!(ctx.is_tool_enabled("tool4"));

        let disabled = ctx.disabled_tools();
        assert_eq!(disabled.len(), targets.len());
        for tool in targets {
            assert!(disabled.contains(tool));
        }
    }

    #[test]
    fn test_mcp_context_disabled_sets_empty_by_default() {
        let ctx = stateful_ctx();
        assert!(ctx.disabled_tools().is_empty());
        assert!(ctx.disabled_resources().is_empty());
        assert!(ctx.disabled_prompts().is_empty());
    }

    #[test]
    fn test_mcp_context_enable_disable_no_state() {
        // Without session state, toggles report failure and the tool stays enabled.
        let ctx = bare_ctx();
        assert!(!ctx.disable_tool("tool"));
        assert!(!ctx.enable_tool("tool"));
        assert!(ctx.is_tool_enabled("tool"));
    }

    #[test]
    fn test_mcp_context_disabled_state_persists_across_contexts() {
        let shared = SessionState::new();

        // The first request disables the tool through its own context...
        {
            let first = McpContext::with_state(Cx::for_testing(), 1, shared.clone());
            first.disable_tool("shared_tool");
        }

        // ...and a later request over the same session still sees it disabled.
        {
            let second = McpContext::with_state(Cx::for_testing(), 2, shared.clone());
            assert!(!second.is_tool_enabled("shared_tool"));
        }
    }

    // ------------------------------------------------------------------
    // Client / server capabilities
    // ------------------------------------------------------------------

    #[test]
    fn test_mcp_context_no_capabilities_by_default() {
        let ctx = bare_ctx();
        assert!(ctx.client_capabilities().is_none());
        assert!(ctx.server_capabilities().is_none());
        assert!(!ctx.client_supports_sampling());
        assert!(!ctx.client_supports_elicitation());
        assert!(!ctx.client_supports_roots());
    }

    #[test]
    fn test_mcp_context_with_client_capabilities() {
        let ctx = bare_ctx().with_client_capabilities(
            ClientCapabilityInfo::new()
                .with_sampling()
                .with_elicitation(true, false)
                .with_roots(true),
        );

        assert!(ctx.client_capabilities().is_some());
        assert!(ctx.client_supports_sampling());
        assert!(ctx.client_supports_elicitation());
        // Form mode was requested; URL mode was not.
        assert!(ctx.client_supports_elicitation_form());
        assert!(!ctx.client_supports_elicitation_url());
        assert!(ctx.client_supports_roots());
    }

    #[test]
    fn test_mcp_context_with_server_capabilities() {
        let ctx = bare_ctx().with_server_capabilities(
            ServerCapabilityInfo::new()
                .with_tools()
                .with_resources(true)
                .with_prompts()
                .with_logging(),
        );

        let caps = ctx.server_capabilities().unwrap();
        assert!(caps.tools);
        assert!(caps.resources);
        assert!(caps.resources_subscribe);
        assert!(caps.prompts);
        assert!(caps.logging);
    }

    #[test]
    fn test_client_capability_info_builders() {
        let empty = ClientCapabilityInfo::new();
        assert!(!empty.sampling);
        assert!(!empty.elicitation);
        assert!(!empty.roots);

        let sampling = empty.with_sampling();
        assert!(sampling.sampling);

        let elicitation = ClientCapabilityInfo::new().with_elicitation(true, true);
        assert!(elicitation.elicitation);
        assert!(elicitation.elicitation_form);
        assert!(elicitation.elicitation_url);

        // `with_roots(false)` still advertises roots, only without list-changed.
        let roots = ClientCapabilityInfo::new().with_roots(false);
        assert!(roots.roots);
        assert!(!roots.roots_list_changed);
    }

    #[test]
    fn test_server_capability_info_builders() {
        let empty = ServerCapabilityInfo::new();
        assert!(!empty.tools);
        assert!(!empty.resources);
        assert!(!empty.prompts);
        assert!(!empty.logging);

        let full = empty
            .with_tools()
            .with_resources(false)
            .with_prompts()
            .with_logging();
        assert!(full.tools);
        assert!(full.resources);
        assert!(!full.resources_subscribe);
        assert!(full.prompts);
        assert!(full.logging);
    }

    // ------------------------------------------------------------------
    // ResourceContentItem / ResourceReadResult
    // ------------------------------------------------------------------

    #[test]
    fn test_resource_content_item_text() {
        let entry = ResourceContentItem::text("test://uri", "hello");
        assert_eq!(entry.uri, "test://uri");
        assert_eq!(entry.mime_type.as_deref(), Some("text/plain"));
        assert_eq!(entry.as_text(), Some("hello"));
        assert!(entry.as_blob().is_none());
        assert!(entry.is_text());
        assert!(!entry.is_blob());
    }

    #[test]
    fn test_resource_content_item_json() {
        let entry = ResourceContentItem::json("data://config", r#"{"key":"val"}"#);
        assert_eq!(entry.uri, "data://config");
        assert_eq!(entry.mime_type.as_deref(), Some("application/json"));
        assert_eq!(entry.as_text(), Some(r#"{"key":"val"}"#));
        assert!(entry.is_text());
        assert!(!entry.is_blob());
    }

    #[test]
    fn test_resource_content_item_blob() {
        let entry =
            ResourceContentItem::blob("binary://data", "application/octet-stream", "AQID");
        assert_eq!(entry.uri, "binary://data");
        assert_eq!(entry.mime_type.as_deref(), Some("application/octet-stream"));
        assert!(entry.as_text().is_none());
        assert_eq!(entry.as_blob(), Some("AQID"));
        assert!(!entry.is_text());
        assert!(entry.is_blob());
    }

    #[test]
    fn test_resource_read_result_text() {
        let read = ResourceReadResult::text("test://doc", "content");
        assert_eq!(read.first_text(), Some("content"));
        assert!(read.first_blob().is_none());
        assert_eq!(read.contents.len(), 1);
    }

    #[test]
    fn test_resource_read_result_new_multiple() {
        let read = ResourceReadResult::new(vec![
            ResourceContentItem::text("a://1", "first"),
            ResourceContentItem::blob("b://2", "image/png", "base64data"),
        ]);
        assert_eq!(read.contents.len(), 2);
        assert_eq!(read.first_text(), Some("first"));
        // The blob sits in the second slot, yet `first_blob` reports None:
        // it apparently inspects only the first content item.
        assert!(read.first_blob().is_none());
    }

    #[test]
    fn test_resource_read_result_empty() {
        let read = ResourceReadResult::new(vec![]);
        assert!(read.first_text().is_none());
        assert!(read.first_blob().is_none());
    }

    #[test]
    fn test_resource_read_result_blob_first() {
        let only_blob = ResourceContentItem::blob("b://1", "image/png", "data");
        let read = ResourceReadResult::new(vec![only_blob]);
        assert!(read.first_text().is_none());
        assert_eq!(read.first_blob(), Some("data"));
    }

    // ------------------------------------------------------------------
    // ToolContentItem / ToolCallResult
    // ------------------------------------------------------------------

    #[test]
    fn test_tool_content_item_text() {
        let piece = ToolContentItem::text("hello");
        assert_eq!(piece.as_text(), Some("hello"));
        assert!(piece.is_text());
    }

    #[test]
    fn test_tool_content_item_image() {
        let piece = ToolContentItem::Image {
            data: "base64img".to_string(),
            mime_type: "image/png".to_string(),
        };
        assert!(piece.as_text().is_none());
        assert!(!piece.is_text());
    }

    #[test]
    fn test_tool_content_item_audio() {
        let piece = ToolContentItem::Audio {
            data: "base64audio".to_string(),
            mime_type: "audio/wav".to_string(),
        };
        assert!(piece.as_text().is_none());
        assert!(!piece.is_text());
    }

    #[test]
    fn test_tool_content_item_resource() {
        // An embedded resource carrying text is still not a `Text` item.
        let piece = ToolContentItem::Resource {
            uri: "file://test".to_string(),
            mime_type: Some("text/plain".to_string()),
            text: Some("embedded".to_string()),
            blob: None,
        };
        assert!(piece.as_text().is_none());
        assert!(!piece.is_text());
    }

    #[test]
    fn test_tool_call_result_success() {
        let outcome = ToolCallResult::success(vec![
            ToolContentItem::text("item1"),
            ToolContentItem::text("item2"),
        ]);
        assert!(!outcome.is_error);
        assert_eq!(outcome.content.len(), 2);
        assert_eq!(outcome.first_text(), Some("item1"));
    }

    #[test]
    fn test_tool_call_result_text() {
        let outcome = ToolCallResult::text("simple output");
        assert!(!outcome.is_error);
        assert_eq!(outcome.content.len(), 1);
        assert_eq!(outcome.first_text(), Some("simple output"));
    }

    #[test]
    fn test_tool_call_result_error() {
        let outcome = ToolCallResult::error("something failed");
        assert!(outcome.is_error);
        assert_eq!(outcome.first_text(), Some("something failed"));
    }

    #[test]
    fn test_tool_call_result_empty() {
        let outcome = ToolCallResult::success(vec![]);
        assert!(!outcome.is_error);
        assert!(outcome.first_text().is_none());
    }

    // ------------------------------------------------------------------
    // ElicitationResponse
    // ------------------------------------------------------------------

    #[test]
    fn test_elicitation_response_accept() {
        let fields = std::collections::HashMap::from([
            ("name".to_string(), serde_json::json!("Alice")),
            ("age".to_string(), serde_json::json!(30)),
            ("active".to_string(), serde_json::json!(true)),
        ]);

        let reply = ElicitationResponse::accept(fields);
        assert!(reply.is_accepted());
        assert!(!reply.is_declined());
        assert!(!reply.is_cancelled());
        assert_eq!(reply.get_string("name"), Some("Alice"));
        assert_eq!(reply.get_int("age"), Some(30));
        assert_eq!(reply.get_bool("active"), Some(true));
    }

    #[test]
    fn test_elicitation_response_accept_url() {
        // URL-mode acceptance carries no form content.
        let reply = ElicitationResponse::accept_url();
        assert!(reply.is_accepted());
        assert!(reply.content.is_none());
        assert!(reply.get_string("anything").is_none());
    }

    #[test]
    fn test_elicitation_response_decline() {
        let reply = ElicitationResponse::decline();
        assert!(!reply.is_accepted());
        assert!(reply.is_declined());
        assert!(!reply.is_cancelled());
        assert!(reply.get_string("key").is_none());
    }

    #[test]
    fn test_elicitation_response_cancel() {
        let reply = ElicitationResponse::cancel();
        assert!(!reply.is_accepted());
        assert!(!reply.is_declined());
        assert!(reply.is_cancelled());
    }

    #[test]
    fn test_elicitation_response_missing_key() {
        let fields =
            std::collections::HashMap::from([("exists".to_string(), serde_json::json!("value"))]);
        let reply = ElicitationResponse::accept(fields);

        assert!(reply.get_string("missing").is_none());
        assert!(reply.get_bool("missing").is_none());
        assert!(reply.get_int("missing").is_none());
    }

    #[test]
    fn test_elicitation_response_type_mismatch() {
        let fields = std::collections::HashMap::from([("num".to_string(), serde_json::json!(42))]);
        let reply = ElicitationResponse::accept(fields);

        // A numeric value is not readable as a string or a bool...
        assert!(reply.get_string("num").is_none());
        assert!(reply.get_bool("num").is_none());
        // ...but is readable as an integer.
        assert_eq!(reply.get_int("num"), Some(42));
    }

    // ------------------------------------------------------------------
    // Default capability / depth accessors
    // ------------------------------------------------------------------

    #[test]
    fn test_can_sample_false_by_default() {
        assert!(!bare_ctx().can_sample());
    }

    #[test]
    fn test_can_elicit_false_by_default() {
        assert!(!bare_ctx().can_elicit());
    }

    #[test]
    fn test_can_read_resources_false_by_default() {
        assert!(!bare_ctx().can_read_resources());
    }

    #[test]
    fn test_can_call_tools_false_by_default() {
        assert!(!bare_ctx().can_call_tools());
    }

    #[test]
    fn test_resource_read_depth_default() {
        assert_eq!(bare_ctx().resource_read_depth(), 0);
    }

    #[test]
    fn test_tool_call_depth_default() {
        assert_eq!(bare_ctx().tool_call_depth(), 0);
    }

    // ------------------------------------------------------------------
    // Sampling and elicitation request/response types
    // ------------------------------------------------------------------

    #[test]
    fn sampling_request_builder_chain() {
        let request = SamplingRequest::prompt("hello", 100)
            .with_system_prompt("You are helpful")
            .with_temperature(0.7)
            .with_stop_sequences(vec!["STOP".into()])
            .with_model_hints(vec!["gpt-4".into()]);

        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.max_tokens, 100);
        assert_eq!(request.system_prompt.as_deref(), Some("You are helpful"));
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.stop_sequences, vec!["STOP"]);
        assert_eq!(request.model_hints, vec!["gpt-4"]);
    }

    #[test]
    fn sampling_request_message_roles() {
        let from_user = SamplingRequestMessage::user("hi");
        assert_eq!(from_user.role, SamplingRole::User);
        assert_eq!(from_user.text, "hi");

        let from_assistant = SamplingRequestMessage::assistant("hello");
        assert_eq!(from_assistant.role, SamplingRole::Assistant);
        assert_eq!(from_assistant.text, "hello");
    }

    #[test]
    fn sampling_response_new_default_stop_reason() {
        let response = SamplingResponse::new("output", "model-1");
        assert_eq!(response.text, "output");
        assert_eq!(response.model, "model-1");
        assert_eq!(response.stop_reason, SamplingStopReason::EndTurn);
        assert_eq!(SamplingStopReason::default(), SamplingStopReason::EndTurn);
    }

    #[test]
    fn noop_sampling_sender_returns_error() {
        let sender = NoOpSamplingSender;
        let request = SamplingRequest::prompt("test", 10);
        let reply = crate::block_on(sender.create_message(request));
        assert!(reply.is_err());
    }

    #[test]
    fn noop_elicitation_sender_returns_error() {
        let sender = NoOpElicitationSender;
        let request = ElicitationRequest::form("msg", serde_json::json!({}));
        let reply = crate::block_on(sender.elicit(request));
        assert!(reply.is_err());
    }

    #[test]
    fn elicitation_request_form_constructor() {
        let request =
            ElicitationRequest::form("Enter name", serde_json::json!({"type": "string"}));
        assert_eq!(request.mode, ElicitationMode::Form);
        assert_eq!(request.message, "Enter name");
        assert!(request.schema.is_some());
        assert!(request.url.is_none());
        assert!(request.elicitation_id.is_none());
    }

    #[test]
    fn elicitation_request_url_constructor() {
        let request = ElicitationRequest::url("Login", "https://example.com", "id-1");
        assert_eq!(request.mode, ElicitationMode::Url);
        assert_eq!(request.message, "Login");
        assert_eq!(request.url.as_deref(), Some("https://example.com"));
        assert_eq!(request.elicitation_id.as_deref(), Some("id-1"));
        assert!(request.schema.is_none());
    }

    // ------------------------------------------------------------------
    // McpContext builder setters and miscellany
    // ------------------------------------------------------------------

    #[test]
    fn mcp_context_with_sampling_enables_can_sample() {
        let ctx = bare_ctx().with_sampling(Arc::new(NoOpSamplingSender));
        assert!(ctx.can_sample());
    }

    #[test]
    fn mcp_context_with_elicitation_enables_can_elicit() {
        let ctx = bare_ctx().with_elicitation(Arc::new(NoOpElicitationSender));
        assert!(ctx.can_elicit());
    }

    #[test]
    fn mcp_context_depth_setters() {
        let ctx = bare_ctx()
            .with_resource_read_depth(3)
            .with_tool_call_depth(5);
        assert_eq!(ctx.resource_read_depth(), 3);
        assert_eq!(ctx.tool_call_depth(), 5);
    }

    #[test]
    fn mcp_context_debug_includes_request_id() {
        let ctx = McpContext::new(Cx::for_testing(), 99);
        let rendered = format!("{ctx:?}");
        assert!(rendered.contains("request_id: 99"));
    }

    #[test]
    fn mcp_context_cx_and_trace() {
        // Smoke test: accessing the Cx and emitting a trace must not panic.
        let ctx = bare_ctx();
        let _ = ctx.cx();
        ctx.trace("test event");
    }
}