use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use asupersync::time::wall_now;
use asupersync::types::CancelReason;
use asupersync::{Budget, CancelKind, Cx, Outcome, RegionId, TaskId};
use crate::{AuthContext, SessionState};
/// Sink for progress notifications emitted while a request is being handled.
///
/// The `Send + Sync` bound lets one sender be shared as `Arc<dyn NotificationSender>`
/// (see `ProgressReporter`).
pub trait NotificationSender: Send + Sync {
/// Emits one progress update.
///
/// * `progress` - the current progress value.
/// * `total` - the value `progress` is measured against, when known.
/// * `message` - an optional human-readable status line.
fn send_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>);
}
pub trait SamplingSender: Send + Sync {
fn create_message(
&self,
request: SamplingRequest,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = crate::McpResult<SamplingResponse>> + Send + '_>,
>;
}
/// A request asking the client to generate a completion via a [`SamplingSender`].
#[derive(Debug, Clone)]
pub struct SamplingRequest {
    /// Conversation so far, in order.
    pub messages: Vec<SamplingRequestMessage>,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Optional system prompt.
    pub system_prompt: Option<String>,
    /// Optional sampling temperature.
    pub temperature: Option<f64>,
    /// Stop sequences; empty means none were supplied.
    pub stop_sequences: Vec<String>,
    /// Model-name hints; empty means none were supplied.
    pub model_hints: Vec<String>,
}

impl SamplingRequest {
    /// Builds a request from an explicit message list; every optional setting is unset.
    #[must_use]
    pub fn new(messages: Vec<SamplingRequestMessage>, max_tokens: u32) -> Self {
        let stop_sequences = Vec::new();
        let model_hints = Vec::new();
        Self {
            messages,
            max_tokens,
            system_prompt: None,
            temperature: None,
            stop_sequences,
            model_hints,
        }
    }

    /// Convenience constructor: a single user message containing `text`.
    #[must_use]
    pub fn prompt(text: impl Into<String>, max_tokens: u32) -> Self {
        let only_message = SamplingRequestMessage::user(text);
        Self::new(vec![only_message], max_tokens)
    }

    /// Returns the request with its system prompt set to `prompt`.
    #[must_use]
    pub fn with_system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Returns the request with its temperature set to `temp`.
    #[must_use]
    pub fn with_temperature(self, temp: f64) -> Self {
        Self {
            temperature: Some(temp),
            ..self
        }
    }

    /// Returns the request with its stop sequences replaced by `sequences`.
    #[must_use]
    pub fn with_stop_sequences(self, sequences: Vec<String>) -> Self {
        Self {
            stop_sequences: sequences,
            ..self
        }
    }

    /// Returns the request with its model hints replaced by `hints`.
    #[must_use]
    pub fn with_model_hints(self, hints: Vec<String>) -> Self {
        Self {
            model_hints: hints,
            ..self
        }
    }
}

/// One turn of a sampling conversation.
#[derive(Debug, Clone)]
pub struct SamplingRequestMessage {
    /// Who produced this turn.
    pub role: SamplingRole,
    /// Plain-text content of the turn.
    pub text: String,
}

impl SamplingRequestMessage {
    /// Shared constructor for the role-specific helpers below.
    fn from_role(role: SamplingRole, text: impl Into<String>) -> Self {
        let text = text.into();
        Self { role, text }
    }

    /// A turn authored by the user.
    #[must_use]
    pub fn user(text: impl Into<String>) -> Self {
        Self::from_role(SamplingRole::User, text)
    }

    /// A turn authored by the assistant.
    #[must_use]
    pub fn assistant(text: impl Into<String>) -> Self {
        Self::from_role(SamplingRole::Assistant, text)
    }
}

/// Author of a [`SamplingRequestMessage`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingRole {
    /// The end user.
    User,
    /// The model.
    Assistant,
}
/// The client's answer to a [`SamplingRequest`].
#[derive(Debug, Clone)]
pub struct SamplingResponse {
    /// Generated text.
    pub text: String,
    /// Name of the model that produced `text`.
    pub model: String,
    /// Why generation stopped.
    pub stop_reason: SamplingStopReason,
}

impl SamplingResponse {
    /// Builds a response whose stop reason is the default, [`SamplingStopReason::EndTurn`].
    #[must_use]
    pub fn new(text: impl Into<String>, model: impl Into<String>) -> Self {
        let text = text.into();
        let model = model.into();
        Self {
            text,
            model,
            stop_reason: SamplingStopReason::default(),
        }
    }
}

/// Reason a sampling completion ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SamplingStopReason {
    /// The model finished its turn (the default).
    #[default]
    EndTurn,
    /// One of the request's stop sequences was produced.
    StopSequence,
    /// The `max_tokens` limit was reached.
    MaxTokens,
}
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpSamplingSender;
impl SamplingSender for NoOpSamplingSender {
fn create_message(
&self,
_request: SamplingRequest,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = crate::McpResult<SamplingResponse>> + Send + '_>,
> {
Box::pin(async {
Err(crate::McpError::new(
crate::McpErrorCode::InvalidRequest,
"Sampling not supported: client does not have sampling capability",
))
})
}
}
pub trait ElicitationSender: Send + Sync {
fn elicit(
&self,
request: ElicitationRequest,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = crate::McpResult<ElicitationResponse>> + Send + '_>,
>;
}
/// A request asking the client to collect information from the user.
///
/// Form requests carry a `schema`; URL requests carry a `url` and an `elicitation_id`.
/// The constructors set exactly the fields their mode uses.
#[derive(Debug, Clone)]
pub struct ElicitationRequest {
    /// Which kind of elicitation this is.
    pub mode: ElicitationMode,
    /// Message shown to the user.
    pub message: String,
    /// JSON schema of the requested data (form mode only).
    pub schema: Option<serde_json::Value>,
    /// URL the user is sent to (URL mode only).
    pub url: Option<String>,
    /// Identifier for the URL-mode interaction (URL mode only).
    pub elicitation_id: Option<String>,
}

impl ElicitationRequest {
    /// Builds a form-mode request for data described by `schema`.
    #[must_use]
    pub fn form(message: impl Into<String>, schema: serde_json::Value) -> Self {
        let message = message.into();
        Self {
            mode: ElicitationMode::Form,
            message,
            schema: Some(schema),
            url: None,
            elicitation_id: None,
        }
    }

    /// Builds a URL-mode request pointing the user at `url`.
    #[must_use]
    pub fn url(
        message: impl Into<String>,
        url: impl Into<String>,
        elicitation_id: impl Into<String>,
    ) -> Self {
        let message = message.into();
        let url = url.into();
        let elicitation_id = elicitation_id.into();
        Self {
            mode: ElicitationMode::Url,
            message,
            schema: None,
            url: Some(url),
            elicitation_id: Some(elicitation_id),
        }
    }
}

/// Kind of elicitation requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElicitationMode {
    /// Structured data described by a JSON schema.
    Form,
    /// The user is directed to a URL.
    Url,
}
/// The user's answer to an [`ElicitationRequest`].
#[derive(Debug, Clone)]
pub struct ElicitationResponse {
    /// What the user chose to do.
    pub action: ElicitationAction,
    /// Submitted values, keyed by field name. The constructors set this only for form
    /// acceptance.
    pub content: Option<std::collections::HashMap<String, serde_json::Value>>,
}

impl ElicitationResponse {
    /// Shared constructor for the public helpers below.
    fn from_parts(
        action: ElicitationAction,
        content: Option<std::collections::HashMap<String, serde_json::Value>>,
    ) -> Self {
        Self { action, content }
    }

    /// The user accepted and submitted `content`.
    #[must_use]
    pub fn accept(content: std::collections::HashMap<String, serde_json::Value>) -> Self {
        Self::from_parts(ElicitationAction::Accept, Some(content))
    }

    /// The user accepted a URL-mode request (no content is returned).
    #[must_use]
    pub fn accept_url() -> Self {
        Self::from_parts(ElicitationAction::Accept, None)
    }

    /// The user declined.
    #[must_use]
    pub fn decline() -> Self {
        Self::from_parts(ElicitationAction::Decline, None)
    }

    /// The user cancelled.
    #[must_use]
    pub fn cancel() -> Self {
        Self::from_parts(ElicitationAction::Cancel, None)
    }

    /// `true` when the action is [`ElicitationAction::Accept`].
    #[must_use]
    pub fn is_accepted(&self) -> bool {
        self.action == ElicitationAction::Accept
    }

    /// `true` when the action is [`ElicitationAction::Decline`].
    #[must_use]
    pub fn is_declined(&self) -> bool {
        self.action == ElicitationAction::Decline
    }

    /// `true` when the action is [`ElicitationAction::Cancel`].
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.action == ElicitationAction::Cancel
    }

    /// Looks up `key` in the submitted content, if there is any.
    fn field(&self, key: &str) -> Option<&serde_json::Value> {
        self.content.as_ref().and_then(|values| values.get(key))
    }

    /// Returns `key` as a string. Yields `None` if there is no content, the key is missing,
    /// or the value is not a JSON string.
    #[must_use]
    pub fn get_string(&self, key: &str) -> Option<&str> {
        self.field(key).and_then(serde_json::Value::as_str)
    }

    /// Returns `key` as a bool. Yields `None` if there is no content, the key is missing,
    /// or the value is not a JSON bool.
    #[must_use]
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        self.field(key).and_then(serde_json::Value::as_bool)
    }

    /// Returns `key` as an `i64`. Yields `None` if there is no content, the key is missing,
    /// or the value is not representable as `i64`.
    #[must_use]
    pub fn get_int(&self, key: &str) -> Option<i64> {
        self.field(key).and_then(serde_json::Value::as_i64)
    }
}

/// What the user did with an elicitation request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElicitationAction {
    /// Accepted (and, for forms, submitted data).
    Accept,
    /// Explicitly declined.
    Decline,
    /// Cancelled / dismissed.
    Cancel,
}
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpElicitationSender;
impl ElicitationSender for NoOpElicitationSender {
fn elicit(
&self,
_request: ElicitationRequest,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = crate::McpResult<ElicitationResponse>> + Send + '_>,
> {
Box::pin(async {
Err(crate::McpError::new(
crate::McpErrorCode::InvalidRequest,
"Elicitation not supported: client does not have elicitation capability",
))
})
}
}
/// Maximum nesting depth for resource reads issued from handlers via
/// `McpContext::read_resource`. A read is rejected once the context's depth reaches it.
pub const MAX_RESOURCE_READ_DEPTH: u32 = 10;

/// One piece of content returned by a resource read.
///
/// The constructors set exactly one of `text` / `blob`.
#[derive(Debug, Clone)]
pub struct ResourceContentItem {
    /// URI of the resource this content belongs to.
    pub uri: String,
    /// MIME type of the content, if known.
    pub mime_type: Option<String>,
    /// Text payload, for textual content.
    pub text: Option<String>,
    /// Binary payload carried as a string (presumably base64; confirm with the protocol
    /// layer).
    pub blob: Option<String>,
}

impl ResourceContentItem {
    /// Text content with MIME type `text/plain`.
    #[must_use]
    pub fn text(uri: impl Into<String>, text: impl Into<String>) -> Self {
        Self {
            uri: uri.into(),
            mime_type: Some("text/plain".to_string()),
            text: Some(text.into()),
            blob: None,
        }
    }

    /// JSON text content with MIME type `application/json`. `text` is stored as-is and
    /// not validated.
    #[must_use]
    pub fn json(uri: impl Into<String>, text: impl Into<String>) -> Self {
        Self {
            uri: uri.into(),
            mime_type: Some("application/json".to_string()),
            text: Some(text.into()),
            blob: None,
        }
    }

    /// Blob content with an explicit MIME type.
    #[must_use]
    pub fn blob(
        uri: impl Into<String>,
        mime_type: impl Into<String>,
        blob: impl Into<String>,
    ) -> Self {
        Self {
            uri: uri.into(),
            mime_type: Some(mime_type.into()),
            text: None,
            blob: Some(blob.into()),
        }
    }

    /// Borrows the text payload, if any.
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        self.text.as_deref()
    }

    /// Borrows the blob payload, if any.
    #[must_use]
    pub fn as_blob(&self) -> Option<&str> {
        self.blob.as_deref()
    }

    /// `true` when a text payload is present.
    #[must_use]
    pub fn is_text(&self) -> bool {
        self.text.is_some()
    }

    /// `true` when a blob payload is present.
    #[must_use]
    pub fn is_blob(&self) -> bool {
        self.blob.is_some()
    }
}

/// All content items returned by one resource read.
#[derive(Debug, Clone)]
pub struct ResourceReadResult {
    /// Content items, in the order the reader returned them.
    pub contents: Vec<ResourceContentItem>,
}

impl ResourceReadResult {
    /// Wraps an explicit list of content items.
    #[must_use]
    pub fn new(contents: Vec<ResourceContentItem>) -> Self {
        Self { contents }
    }

    /// A result holding a single `text/plain` item.
    #[must_use]
    pub fn text(uri: impl Into<String>, text: impl Into<String>) -> Self {
        Self {
            contents: vec![ResourceContentItem::text(uri, text)],
        }
    }

    /// Returns the text of the first content item that carries text.
    ///
    /// Items without text (for example blobs) are skipped. A `[blob, text]` result
    /// therefore still yields its text instead of `None`, which callers would otherwise
    /// misreport as "no text content".
    #[must_use]
    pub fn first_text(&self) -> Option<&str> {
        self.contents.iter().find_map(ResourceContentItem::as_text)
    }

    /// Returns the blob of the first content item that carries a blob, skipping text items.
    #[must_use]
    pub fn first_blob(&self) -> Option<&str> {
        self.contents.iter().find_map(ResourceContentItem::as_blob)
    }
}
/// Reads resources on behalf of handlers; this is what `McpContext::read_resource` calls.
pub trait ResourceReader: Send + Sync {
/// Reads the resource identified by `uri`.
///
/// `McpContext::read_resource` passes its own `Cx`, a clone of its request auth as `auth`,
/// and its current read depth plus one as `depth`. Presumably implementations propagate
/// `depth` into any nested context so the recursion guard keeps working; confirm.
fn read_resource(
&self,
cx: &Cx,
uri: &str,
auth: Option<AuthContext>,
depth: u32,
) -> Pin<Box<dyn Future<Output = crate::McpResult<ResourceReadResult>> + Send + '_>>;
}
/// Maximum nesting depth for tool calls issued from handlers via `McpContext::call_tool`.
/// A call is rejected once the context's depth reaches it.
pub const MAX_TOOL_CALL_DEPTH: u32 = 10;

/// One piece of content produced by a tool call.
#[derive(Debug, Clone)]
pub enum ToolContentItem {
    /// Plain text.
    Text {
        text: String,
    },
    /// Image data with its MIME type (encoding of `data` not established here).
    Image {
        data: String,
        mime_type: String,
    },
    /// Audio data with its MIME type (encoding of `data` not established here).
    Audio {
        data: String,
        mime_type: String,
    },
    /// An embedded resource reference with optional inline text or blob.
    Resource {
        uri: String,
        mime_type: Option<String>,
        text: Option<String>,
        blob: Option<String>,
    },
}

impl ToolContentItem {
    /// Builds a [`ToolContentItem::Text`].
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text { text: text.into() }
    }

    /// Borrows the text of a `Text` item; every other variant yields `None`.
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        // Exhaustive on purpose: adding a variant should force a decision here.
        match self {
            Self::Text { text } => Some(text),
            Self::Image { .. } | Self::Audio { .. } | Self::Resource { .. } => None,
        }
    }

    /// `true` for the `Text` variant.
    #[must_use]
    pub fn is_text(&self) -> bool {
        matches!(self, Self::Text { .. })
    }
}

/// Outcome of a tool call: its content plus whether the tool reported an error.
#[derive(Debug, Clone)]
pub struct ToolCallResult {
    /// Content items, in the order the tool produced them.
    pub content: Vec<ToolContentItem>,
    /// `true` when the tool reported failure; `content` then typically holds the message.
    pub is_error: bool,
}

impl ToolCallResult {
    /// A successful result with the given content.
    #[must_use]
    pub fn success(content: Vec<ToolContentItem>) -> Self {
        Self {
            content,
            is_error: false,
        }
    }

    /// A successful result holding a single text item.
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        Self {
            content: vec![ToolContentItem::text(text)],
            is_error: false,
        }
    }

    /// An error result whose single text item is `message`.
    #[must_use]
    pub fn error(message: impl Into<String>) -> Self {
        Self {
            content: vec![ToolContentItem::text(message)],
            is_error: true,
        }
    }

    /// Returns the text of the first `Text` item, skipping any leading non-text items.
    ///
    /// Without the skip, a result such as `[image, text]` would report no text. Callers
    /// like `McpContext::call_tool_text` would then produce a misleading error.
    #[must_use]
    pub fn first_text(&self) -> Option<&str> {
        self.content.iter().find_map(ToolContentItem::as_text)
    }
}
/// Invokes tools on behalf of handlers; this is what `McpContext::call_tool` calls.
pub trait ToolCaller: Send + Sync {
/// Calls tool `name` with JSON `args`.
///
/// `McpContext::call_tool` passes its own `Cx`, a clone of its request auth as `auth`,
/// and its current call depth plus one as `depth`. Presumably implementations propagate
/// `depth` into any nested context so the recursion guard keeps working; confirm.
fn call_tool(
&self,
cx: &Cx,
name: &str,
args: serde_json::Value,
auth: Option<AuthContext>,
depth: u32,
) -> Pin<Box<dyn Future<Output = crate::McpResult<ToolCallResult>> + Send + '_>>;
}
/// Capabilities advertised by the connected client.
#[derive(Debug, Clone, Default)]
pub struct ClientCapabilityInfo {
    /// Client accepts sampling requests.
    pub sampling: bool,
    /// Client supports elicitation in at least one mode.
    pub elicitation: bool,
    /// Client supports form-mode elicitation.
    pub elicitation_form: bool,
    /// Client supports URL-mode elicitation.
    pub elicitation_url: bool,
    /// Client exposes roots.
    pub roots: bool,
    /// Client sends notifications when its roots list changes.
    pub roots_list_changed: bool,
}

impl ClientCapabilityInfo {
    /// Returns a value with every capability flag cleared.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks sampling as supported.
    #[must_use]
    pub fn with_sampling(self) -> Self {
        Self {
            sampling: true,
            ..self
        }
    }

    /// Sets the per-mode elicitation flags. The umbrella `elicitation` flag is true when
    /// either mode is.
    #[must_use]
    pub fn with_elicitation(self, form: bool, url: bool) -> Self {
        Self {
            elicitation: form || url,
            elicitation_form: form,
            elicitation_url: url,
            ..self
        }
    }

    /// Marks roots as supported, recording whether list-change notifications are sent.
    #[must_use]
    pub fn with_roots(self, list_changed: bool) -> Self {
        Self {
            roots: true,
            roots_list_changed: list_changed,
            ..self
        }
    }
}
/// Capabilities the server itself advertises.
#[derive(Debug, Clone, Default)]
pub struct ServerCapabilityInfo {
    /// Server exposes tools.
    pub tools: bool,
    /// Server exposes resources.
    pub resources: bool,
    /// Server supports resource subscriptions.
    pub resources_subscribe: bool,
    /// Server exposes prompts.
    pub prompts: bool,
    /// Server supports logging.
    pub logging: bool,
}

impl ServerCapabilityInfo {
    /// Returns a value with every capability flag cleared.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks tools as supported.
    #[must_use]
    pub fn with_tools(self) -> Self {
        Self {
            tools: true,
            ..self
        }
    }

    /// Marks resources as supported, recording whether subscriptions are available.
    #[must_use]
    pub fn with_resources(self, subscribe: bool) -> Self {
        Self {
            resources: true,
            resources_subscribe: subscribe,
            ..self
        }
    }

    /// Marks prompts as supported.
    #[must_use]
    pub fn with_prompts(self) -> Self {
        Self {
            prompts: true,
            ..self
        }
    }

    /// Marks logging as supported.
    #[must_use]
    pub fn with_logging(self) -> Self {
        Self {
            logging: true,
            ..self
        }
    }
}
/// A [`NotificationSender`] that discards every progress update.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpNotificationSender;

impl NotificationSender for NoOpNotificationSender {
    fn send_progress(&self, _: f64, _: Option<f64>, _: Option<&str>) {
        // Intentionally empty: updates are dropped.
    }
}
/// Cloneable handle that forwards progress updates to a shared [`NotificationSender`].
#[derive(Clone)]
pub struct ProgressReporter {
    sender: Arc<dyn NotificationSender>,
}

impl ProgressReporter {
    /// Wraps `sender`; clones of the reporter share the same sender.
    pub fn new(sender: Arc<dyn NotificationSender>) -> Self {
        Self { sender }
    }

    /// Reports `progress` with no known total.
    pub fn report(&self, progress: f64, message: Option<&str>) {
        self.forward(progress, None, message);
    }

    /// Reports `progress` out of `total`.
    pub fn report_with_total(&self, progress: f64, total: f64, message: Option<&str>) {
        self.forward(progress, Some(total), message);
    }

    /// Single funnel through which every update reaches the sender.
    fn forward(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
        self.sender.send_progress(progress, total, message);
    }
}

impl std::fmt::Debug for ProgressReporter {
    /// Prints `ProgressReporter { .. }`; the sender is an opaque trait object.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ProgressReporter");
        builder.finish_non_exhaustive()
    }
}
/// Per-request context handed to MCP handlers.
///
/// Wraps the asupersync [`Cx`] of the request and carries optional collaborators:
/// progress reporting, session state, sampling/elicitation senders, and a resource
/// reader / tool caller with their recursion depths. It also holds capability info and a
/// request-local auth slot.
#[derive(Clone)]
pub struct McpContext {
    /// Capability context of the running request.
    cx: Cx,
    /// Identifier of the request being handled.
    request_id: u64,
    /// Destination for progress updates, if the request asked for them.
    progress_reporter: Option<ProgressReporter>,
    /// Session-scoped state, if a session is attached.
    state: Option<SessionState>,
    /// Request-local auth. The `Arc` means clones of this context share one slot.
    auth: Arc<Mutex<Option<AuthContext>>>,
    /// Sender for client sampling requests.
    sampling_sender: Option<Arc<dyn SamplingSender>>,
    /// Sender for client elicitation requests.
    elicitation_sender: Option<Arc<dyn ElicitationSender>>,
    /// Reader for nested resource reads.
    resource_reader: Option<Arc<dyn ResourceReader>>,
    /// Current nesting depth of resource reads.
    resource_read_depth: u32,
    /// Caller for nested tool calls.
    tool_caller: Option<Arc<dyn ToolCaller>>,
    /// Current nesting depth of tool calls.
    tool_call_depth: u32,
    /// Capabilities advertised by the client, if known.
    client_capabilities: Option<ClientCapabilityInfo>,
    /// Capabilities advertised by this server, if known.
    server_capabilities: Option<ServerCapabilityInfo>,
}

impl std::fmt::Debug for McpContext {
    /// The auth slot, session state, and trait-object collaborators are printed only as
    /// presence flags, so neither credentials nor opaque senders appear in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Recover from a poisoned lock: Debug output must never panic.
        let has_auth = self
            .auth
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .is_some();

        let mut out = f.debug_struct("McpContext");
        out.field("cx", &self.cx)
            .field("request_id", &self.request_id)
            .field("progress_reporter", &self.progress_reporter)
            .field("state", &self.state.is_some())
            .field("auth", &has_auth)
            .field("sampling_sender", &self.sampling_sender.is_some())
            .field("elicitation_sender", &self.elicitation_sender.is_some())
            .field("resource_reader", &self.resource_reader.is_some())
            .field("resource_read_depth", &self.resource_read_depth)
            .field("tool_caller", &self.tool_caller.is_some())
            .field("tool_call_depth", &self.tool_call_depth)
            .field("client_capabilities", &self.client_capabilities)
            .field("server_capabilities", &self.server_capabilities);
        out.finish()
    }
}
impl McpContext {
    /// Creates a context for request `request_id` with nothing optional attached.
    ///
    /// No session state, progress reporter, sampling/elicitation sender, resource reader,
    /// tool caller, or capability info is set, and both recursion depths start at `0`. A
    /// fresh auth slot is allocated, so separately constructed contexts never share auth,
    /// even when built from the same `Cx`. Clones of one context do share it.
    #[must_use]
    pub fn new(cx: Cx, request_id: u64) -> Self {
        Self {
            cx,
            request_id,
            progress_reporter: None,
            state: None,
            auth: Arc::new(Mutex::new(None)),
            sampling_sender: None,
            elicitation_sender: None,
            resource_reader: None,
            resource_read_depth: 0,
            tool_caller: None,
            tool_call_depth: 0,
            client_capabilities: None,
            server_capabilities: None,
        }
    }

    /// Like [`new`](Self::new), but bound to the given session state.
    #[must_use]
    pub fn with_state(cx: Cx, request_id: u64, state: SessionState) -> Self {
        // All constructors start from `new`, so a newly added field cannot be initialised
        // differently in one of them.
        Self {
            state: Some(state),
            ..Self::new(cx, request_id)
        }
    }

    /// Like [`new`](Self::new), but forwarding progress updates to `reporter`.
    #[must_use]
    pub fn with_progress(cx: Cx, request_id: u64, reporter: ProgressReporter) -> Self {
        Self {
            progress_reporter: Some(reporter),
            ..Self::new(cx, request_id)
        }
    }

    /// Like [`new`](Self::new), with both session state and a progress reporter.
    #[must_use]
    pub fn with_state_and_progress(
        cx: Cx,
        request_id: u64,
        state: SessionState,
        reporter: ProgressReporter,
    ) -> Self {
        Self {
            progress_reporter: Some(reporter),
            state: Some(state),
            ..Self::new(cx, request_id)
        }
    }

    /// Attaches the sender used by [`sample`](Self::sample) and
    /// [`sample_with_request`](Self::sample_with_request).
    #[must_use]
    pub fn with_sampling(mut self, sender: Arc<dyn SamplingSender>) -> Self {
        self.sampling_sender = Some(sender);
        self
    }

    /// Attaches the sender used by [`elicit_form`](Self::elicit_form),
    /// [`elicit_url`](Self::elicit_url) and [`elicit_with_request`](Self::elicit_with_request).
    #[must_use]
    pub fn with_elicitation(mut self, sender: Arc<dyn ElicitationSender>) -> Self {
        self.elicitation_sender = Some(sender);
        self
    }

    /// Attaches the reader used by [`read_resource`](Self::read_resource) and its helpers.
    #[must_use]
    pub fn with_resource_reader(mut self, reader: Arc<dyn ResourceReader>) -> Self {
        self.resource_reader = Some(reader);
        self
    }

    /// Sets the current resource-read nesting depth. [`read_resource`](Self::read_resource)
    /// refuses to read once this reaches [`MAX_RESOURCE_READ_DEPTH`].
    #[must_use]
    pub fn with_resource_read_depth(mut self, depth: u32) -> Self {
        self.resource_read_depth = depth;
        self
    }

    /// Attaches the caller used by [`call_tool`](Self::call_tool) and its helpers.
    #[must_use]
    pub fn with_tool_caller(mut self, caller: Arc<dyn ToolCaller>) -> Self {
        self.tool_caller = Some(caller);
        self
    }

    /// Sets the current tool-call nesting depth. [`call_tool`](Self::call_tool) refuses to
    /// call once this reaches [`MAX_TOOL_CALL_DEPTH`].
    #[must_use]
    pub fn with_tool_call_depth(mut self, depth: u32) -> Self {
        self.tool_call_depth = depth;
        self
    }

    /// Records the capabilities the client advertised.
    #[must_use]
    pub fn with_client_capabilities(mut self, capabilities: ClientCapabilityInfo) -> Self {
        self.client_capabilities = Some(capabilities);
        self
    }

    /// Records the capabilities this server advertised.
    #[must_use]
    pub fn with_server_capabilities(mut self, capabilities: ServerCapabilityInfo) -> Self {
        self.server_capabilities = Some(capabilities);
        self
    }

    /// Returns `true` if a progress reporter is attached.
    #[must_use]
    pub fn has_progress_reporter(&self) -> bool {
        self.progress_reporter.is_some()
    }

    /// Reports progress without a total. Does nothing when no reporter is attached.
    pub fn report_progress(&self, progress: f64, message: Option<&str>) {
        if let Some(reporter) = &self.progress_reporter {
            reporter.report(progress, message);
        }
    }

    /// Reports `progress` out of `total`. Does nothing when no reporter is attached.
    pub fn report_progress_with_total(&self, progress: f64, total: f64, message: Option<&str>) {
        if let Some(reporter) = &self.progress_reporter {
            reporter.report_with_total(progress, total, message);
        }
    }

    /// Returns the id of the request this context belongs to.
    #[must_use]
    pub fn request_id(&self) -> u64 {
        self.request_id
    }

    /// Returns the region id of the underlying `Cx`.
    #[must_use]
    pub fn region_id(&self) -> RegionId {
        self.cx.region_id()
    }

    /// Returns the task id of the underlying `Cx`.
    #[must_use]
    pub fn task_id(&self) -> TaskId {
        self.cx.task_id()
    }

    /// Returns the current budget of the underlying `Cx`.
    #[must_use]
    pub fn budget(&self) -> Budget {
        self.cx.budget()
    }

    /// Returns `true` when any of these holds:
    ///
    /// * cancellation was requested on the underlying `Cx`;
    /// * the budget is exhausted;
    /// * the budget's deadline has passed according to `wall_now()`.
    ///
    /// This only queries state; unlike [`checkpoint`](Self::checkpoint) it never signals
    /// cancellation itself.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        let budget = self.cx.budget();
        self.cx.is_cancel_requested()
            || budget.is_exhausted()
            || budget.is_past_deadline(wall_now())
    }

    /// Cooperative cancellation point for long-running handlers.
    ///
    /// # Errors
    ///
    /// Returns [`CancelledError`] in any of these cases:
    ///
    /// * `Cx::checkpoint` fails;
    /// * the budget is exhausted;
    /// * the budget deadline has passed. In this case the `Cx` is also cancelled with
    ///   `CancelKind::Deadline` before returning.
    pub fn checkpoint(&self) -> Result<(), CancelledError> {
        self.cx.checkpoint().map_err(|_| CancelledError)?;
        let budget = self.cx.budget();
        if budget.is_exhausted() {
            // NOTE(review): unlike the deadline branch below, exhaustion is reported without
            // signalling cancellation on the `Cx`. Confirm this asymmetry is intended.
            return Err(CancelledError);
        }
        if budget.is_past_deadline(wall_now()) {
            // Presumably marks the `Cx` as deadline-cancelled so later checks observe it.
            self.cx.cancel_fast(CancelKind::Deadline);
            return Err(CancelledError);
        }
        Ok(())
    }

    /// Runs `f` through `Cx::masked` and returns its result. Presumably `f` is shielded
    /// from cancellation while it runs; see asupersync's `Cx::masked`.
    pub fn masked<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        self.cx.masked(f)
    }

    /// Forwards `message` to `Cx::trace`.
    pub fn trace(&self, message: &str) {
        self.cx.trace(message);
    }

    /// Borrows the underlying asupersync `Cx`.
    #[must_use]
    pub fn cx(&self) -> &Cx {
        &self.cx
    }

    /// Reads `key` from session state as `T`.
    ///
    /// Returns `None` when no session state is attached. Otherwise returns whatever
    /// `SessionState::get` yields for `key`.
    #[must_use]
    pub fn get_state<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
        self.state.as_ref()?.get(key)
    }

    /// Returns a clone of the request-local auth context, if one has been set.
    ///
    /// A poisoned lock is recovered rather than propagated.
    #[must_use]
    pub fn auth(&self) -> Option<AuthContext> {
        self.auth
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }

    /// Stores `auth` in the request-local slot, replacing any previous value.
    ///
    /// The slot is shared with clones of this context but is never written to session
    /// state. The return value is currently always `true`.
    pub fn set_auth(&self, auth: AuthContext) -> bool {
        *self
            .auth
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(auth);
        true
    }

    /// Builder form of [`set_auth`](Self::set_auth).
    #[must_use]
    pub fn with_auth(self, auth: AuthContext) -> Self {
        let _ = self.set_auth(auth);
        self
    }

    /// Writes `value` under `key` in session state.
    ///
    /// Returns `false` when no session state is attached. Otherwise returns the result of
    /// `SessionState::set`.
    pub fn set_state<T: serde::Serialize>(&self, key: impl Into<String>, value: T) -> bool {
        match &self.state {
            Some(state) => state.set(key, value),
            None => false,
        }
    }

    /// Removes `key` from session state and returns the removed value, if any.
    /// Returns `None` when no session state is attached.
    pub fn remove_state(&self, key: &str) -> Option<serde_json::Value> {
        self.state.as_ref()?.remove(key)
    }

    /// Returns `true` if session state is attached and contains `key`.
    #[must_use]
    pub fn has_state(&self, key: &str) -> bool {
        self.state.as_ref().is_some_and(|s| s.contains(key))
    }

    /// Returns `true` if session state is attached.
    #[must_use]
    pub fn has_session_state(&self) -> bool {
        self.state.is_some()
    }

    /// Borrows the client's advertised capabilities, if known.
    #[must_use]
    pub fn client_capabilities(&self) -> Option<&ClientCapabilityInfo> {
        self.client_capabilities.as_ref()
    }

    /// Borrows this server's advertised capabilities, if known.
    #[must_use]
    pub fn server_capabilities(&self) -> Option<&ServerCapabilityInfo> {
        self.server_capabilities.as_ref()
    }

    /// Whether the client advertised sampling (`false` when capabilities are unknown).
    #[must_use]
    pub fn client_supports_sampling(&self) -> bool {
        self.client_capabilities
            .as_ref()
            .is_some_and(|c| c.sampling)
    }

    /// Whether the client advertised any elicitation mode (`false` when unknown).
    #[must_use]
    pub fn client_supports_elicitation(&self) -> bool {
        self.client_capabilities
            .as_ref()
            .is_some_and(|c| c.elicitation)
    }

    /// Whether the client advertised form-mode elicitation (`false` when unknown).
    #[must_use]
    pub fn client_supports_elicitation_form(&self) -> bool {
        self.client_capabilities
            .as_ref()
            .is_some_and(|c| c.elicitation_form)
    }

    /// Whether the client advertised URL-mode elicitation (`false` when unknown).
    #[must_use]
    pub fn client_supports_elicitation_url(&self) -> bool {
        self.client_capabilities
            .as_ref()
            .is_some_and(|c| c.elicitation_url)
    }

    /// Whether the client advertised roots (`false` when unknown).
    #[must_use]
    pub fn client_supports_roots(&self) -> bool {
        self.client_capabilities.as_ref().is_some_and(|c| c.roots)
    }

    /// Session-state keys under which the per-session disabled sets are stored.
    const DISABLED_TOOLS_KEY: &'static str = "fastmcp.disabled_tools";
    const DISABLED_RESOURCES_KEY: &'static str = "fastmcp.disabled_resources";
    const DISABLED_PROMPTS_KEY: &'static str = "fastmcp.disabled_prompts";

    /// Disables tool `name` for this session.
    ///
    /// Returns `false` without session state. Otherwise returns the result of storing the
    /// updated set.
    pub fn disable_tool(&self, name: impl Into<String>) -> bool {
        self.add_to_disabled_set(Self::DISABLED_TOOLS_KEY, name.into())
    }

    /// Re-enables tool `name` for this session.
    ///
    /// Returns `false` without session state. Otherwise returns the result of storing the
    /// updated set.
    pub fn enable_tool(&self, name: &str) -> bool {
        self.remove_from_disabled_set(Self::DISABLED_TOOLS_KEY, name)
    }

    /// `false` only if tool `name` is in this session's disabled set.
    #[must_use]
    pub fn is_tool_enabled(&self, name: &str) -> bool {
        !self.is_in_disabled_set(Self::DISABLED_TOOLS_KEY, name)
    }

    /// Disables resource `uri` for this session (see [`disable_tool`](Self::disable_tool)).
    pub fn disable_resource(&self, uri: impl Into<String>) -> bool {
        self.add_to_disabled_set(Self::DISABLED_RESOURCES_KEY, uri.into())
    }

    /// Re-enables resource `uri` for this session (see [`enable_tool`](Self::enable_tool)).
    pub fn enable_resource(&self, uri: &str) -> bool {
        self.remove_from_disabled_set(Self::DISABLED_RESOURCES_KEY, uri)
    }

    /// `false` only if resource `uri` is in this session's disabled set.
    #[must_use]
    pub fn is_resource_enabled(&self, uri: &str) -> bool {
        !self.is_in_disabled_set(Self::DISABLED_RESOURCES_KEY, uri)
    }

    /// Disables prompt `name` for this session (see [`disable_tool`](Self::disable_tool)).
    pub fn disable_prompt(&self, name: impl Into<String>) -> bool {
        self.add_to_disabled_set(Self::DISABLED_PROMPTS_KEY, name.into())
    }

    /// Re-enables prompt `name` for this session (see [`enable_tool`](Self::enable_tool)).
    pub fn enable_prompt(&self, name: &str) -> bool {
        self.remove_from_disabled_set(Self::DISABLED_PROMPTS_KEY, name)
    }

    /// `false` only if prompt `name` is in this session's disabled set.
    #[must_use]
    pub fn is_prompt_enabled(&self, name: &str) -> bool {
        !self.is_in_disabled_set(Self::DISABLED_PROMPTS_KEY, name)
    }

    /// Snapshot of the disabled tool names (empty without session state).
    #[must_use]
    pub fn disabled_tools(&self) -> std::collections::HashSet<String> {
        self.get_disabled_set(Self::DISABLED_TOOLS_KEY)
    }

    /// Snapshot of the disabled resource URIs (empty without session state).
    #[must_use]
    pub fn disabled_resources(&self) -> std::collections::HashSet<String> {
        self.get_disabled_set(Self::DISABLED_RESOURCES_KEY)
    }

    /// Snapshot of the disabled prompt names (empty without session state).
    #[must_use]
    pub fn disabled_prompts(&self) -> std::collections::HashSet<String> {
        self.get_disabled_set(Self::DISABLED_PROMPTS_KEY)
    }

    /// Loads the set at `key` (empty if absent), inserts `name`, and writes it back.
    fn add_to_disabled_set(&self, key: &str, name: String) -> bool {
        let Some(state) = self.state.as_ref() else {
            return false;
        };
        let mut set: std::collections::HashSet<String> = state.get(key).unwrap_or_default();
        set.insert(name);
        state.set(key, set)
    }

    /// Loads the set at `key` (empty if absent), removes `name`, and writes it back.
    fn remove_from_disabled_set(&self, key: &str, name: &str) -> bool {
        let Some(state) = self.state.as_ref() else {
            return false;
        };
        let mut set: std::collections::HashSet<String> = state.get(key).unwrap_or_default();
        set.remove(name);
        state.set(key, set)
    }

    /// `true` if session state exists and the set at `key` contains `name`.
    fn is_in_disabled_set(&self, key: &str, name: &str) -> bool {
        let Some(state) = self.state.as_ref() else {
            return false;
        };
        let set: std::collections::HashSet<String> = state.get(key).unwrap_or_default();
        set.contains(name)
    }

    /// The set stored at `key`, or an empty set if absent or without session state.
    fn get_disabled_set(&self, key: &str) -> std::collections::HashSet<String> {
        self.state
            .as_ref()
            .and_then(|s| s.get(key))
            .unwrap_or_default()
    }

    /// Returns `true` if a sampling sender is attached.
    #[must_use]
    pub fn can_sample(&self) -> bool {
        self.sampling_sender.is_some()
    }

    /// Asks the client to complete a single user `prompt`.
    ///
    /// # Errors
    ///
    /// See [`sample_with_request`](Self::sample_with_request).
    pub async fn sample(
        &self,
        prompt: impl Into<String>,
        max_tokens: u32,
    ) -> crate::McpResult<SamplingResponse> {
        let request = SamplingRequest::prompt(prompt, max_tokens);
        self.sample_with_request(request).await
    }

    /// Sends `request` through the attached sampling sender.
    ///
    /// # Errors
    ///
    /// Returns `InvalidRequest` when no sampling sender is attached. Otherwise propagates
    /// the sender's error.
    pub async fn sample_with_request(
        &self,
        request: SamplingRequest,
    ) -> crate::McpResult<SamplingResponse> {
        let sender = self.sampling_sender.as_ref().ok_or_else(|| {
            crate::McpError::new(
                crate::McpErrorCode::InvalidRequest,
                "Sampling not available: client does not support sampling capability",
            )
        })?;
        sender.create_message(request).await
    }

    /// Returns `true` if an elicitation sender is attached.
    #[must_use]
    pub fn can_elicit(&self) -> bool {
        self.elicitation_sender.is_some()
    }

    /// Requests form-mode elicitation of data matching `schema`.
    ///
    /// # Errors
    ///
    /// See [`elicit_with_request`](Self::elicit_with_request).
    pub async fn elicit_form(
        &self,
        message: impl Into<String>,
        schema: serde_json::Value,
    ) -> crate::McpResult<ElicitationResponse> {
        let request = ElicitationRequest::form(message, schema);
        self.elicit_with_request(request).await
    }

    /// Requests URL-mode elicitation.
    ///
    /// # Errors
    ///
    /// See [`elicit_with_request`](Self::elicit_with_request).
    pub async fn elicit_url(
        &self,
        message: impl Into<String>,
        url: impl Into<String>,
        elicitation_id: impl Into<String>,
    ) -> crate::McpResult<ElicitationResponse> {
        let request = ElicitationRequest::url(message, url, elicitation_id);
        self.elicit_with_request(request).await
    }

    /// Sends `request` through the attached elicitation sender.
    ///
    /// # Errors
    ///
    /// Returns `InvalidRequest` when no elicitation sender is attached. Otherwise
    /// propagates the sender's error.
    pub async fn elicit_with_request(
        &self,
        request: ElicitationRequest,
    ) -> crate::McpResult<ElicitationResponse> {
        let sender = self.elicitation_sender.as_ref().ok_or_else(|| {
            crate::McpError::new(
                crate::McpErrorCode::InvalidRequest,
                "Elicitation not available: client does not support elicitation capability",
            )
        })?;
        sender.elicit(request).await
    }

    /// Returns `true` if a resource reader is attached.
    #[must_use]
    pub fn can_read_resources(&self) -> bool {
        self.resource_reader.is_some()
    }

    /// Returns the current resource-read nesting depth.
    #[must_use]
    pub fn resource_read_depth(&self) -> u32 {
        self.resource_read_depth
    }

    /// Reads `uri` through the attached reader at depth `resource_read_depth + 1`. The
    /// request's auth is forwarded.
    ///
    /// # Errors
    ///
    /// Returns `InternalError` in either of these cases:
    ///
    /// * no reader is attached;
    /// * the current depth has reached [`MAX_RESOURCE_READ_DEPTH`] (recursion guard).
    ///
    /// Otherwise propagates the reader's error.
    pub async fn read_resource(&self, uri: &str) -> crate::McpResult<ResourceReadResult> {
        let reader = self.resource_reader.as_ref().ok_or_else(|| {
            crate::McpError::new(
                crate::McpErrorCode::InternalError,
                "Resource reading not available: no router attached to context",
            )
        })?;
        if self.resource_read_depth >= MAX_RESOURCE_READ_DEPTH {
            return Err(crate::McpError::new(
                crate::McpErrorCode::InternalError,
                format!(
                    "Maximum resource read depth ({MAX_RESOURCE_READ_DEPTH}) exceeded; possible infinite recursion"
                ),
            ));
        }
        // Cannot overflow: the guard above bounds the depth by MAX_RESOURCE_READ_DEPTH.
        reader
            .read_resource(&self.cx, uri, self.auth(), self.resource_read_depth + 1)
            .await
    }

    /// Reads `uri` and returns [`ResourceReadResult::first_text`] as an owned `String`.
    ///
    /// # Errors
    ///
    /// Propagates [`read_resource`](Self::read_resource) errors. Returns `InternalError`
    /// when `first_text` yields `None`.
    pub async fn read_resource_text(&self, uri: &str) -> crate::McpResult<String> {
        let result = self.read_resource(uri).await?;
        result.first_text().map(String::from).ok_or_else(|| {
            crate::McpError::new(
                crate::McpErrorCode::InternalError,
                format!("Resource '{uri}' has no text content"),
            )
        })
    }

    /// Reads `uri` as text and deserializes it from JSON into `T`.
    ///
    /// # Errors
    ///
    /// Propagates [`read_resource_text`](Self::read_resource_text) errors. Returns
    /// `InternalError` when the text is not valid JSON for `T`.
    pub async fn read_resource_json<T: serde::de::DeserializeOwned>(
        &self,
        uri: &str,
    ) -> crate::McpResult<T> {
        let text = self.read_resource_text(uri).await?;
        serde_json::from_str(&text).map_err(|e| {
            crate::McpError::new(
                crate::McpErrorCode::InternalError,
                format!("Failed to parse resource '{uri}' as JSON: {e}"),
            )
        })
    }

    /// Returns `true` if a tool caller is attached.
    #[must_use]
    pub fn can_call_tools(&self) -> bool {
        self.tool_caller.is_some()
    }

    /// Returns the current tool-call nesting depth.
    #[must_use]
    pub fn tool_call_depth(&self) -> u32 {
        self.tool_call_depth
    }

    /// Calls tool `name` with `args` at depth `tool_call_depth + 1`. The request's auth is
    /// forwarded.
    ///
    /// # Errors
    ///
    /// Returns `InternalError` in either of these cases:
    ///
    /// * no caller is attached;
    /// * the current depth has reached [`MAX_TOOL_CALL_DEPTH`] (recursion guard).
    ///
    /// Otherwise propagates the caller's error. A tool that reports failure via
    /// `is_error` still yields `Ok`.
    pub async fn call_tool(
        &self,
        name: &str,
        args: serde_json::Value,
    ) -> crate::McpResult<ToolCallResult> {
        let caller = self.tool_caller.as_ref().ok_or_else(|| {
            crate::McpError::new(
                crate::McpErrorCode::InternalError,
                "Tool calling not available: no router attached to context",
            )
        })?;
        if self.tool_call_depth >= MAX_TOOL_CALL_DEPTH {
            return Err(crate::McpError::new(
                crate::McpErrorCode::InternalError,
                format!(
                    "Maximum tool call depth ({MAX_TOOL_CALL_DEPTH}) exceeded calling '{name}'; possible infinite recursion"
                ),
            ));
        }
        // Cannot overflow: the guard above bounds the depth by MAX_TOOL_CALL_DEPTH.
        caller
            .call_tool(&self.cx, name, args, self.auth(), self.tool_call_depth + 1)
            .await
    }

    /// Calls tool `name` and returns [`ToolCallResult::first_text`] as an owned `String`.
    ///
    /// # Errors
    ///
    /// Propagates [`call_tool`](Self::call_tool) errors. Returns `InternalError` in either
    /// of these cases:
    ///
    /// * the tool set `is_error`; the tool's text, if any, is included in the message;
    /// * `first_text` yields `None`.
    pub async fn call_tool_text(
        &self,
        name: &str,
        args: serde_json::Value,
    ) -> crate::McpResult<String> {
        let result = self.call_tool(name, args).await?;
        if result.is_error {
            let error_msg = result.first_text().unwrap_or("Tool returned an error");
            return Err(crate::McpError::new(
                crate::McpErrorCode::InternalError,
                format!("Tool '{name}' failed: {error_msg}"),
            ));
        }
        result.first_text().map(String::from).ok_or_else(|| {
            crate::McpError::new(
                crate::McpErrorCode::InternalError,
                format!("Tool '{name}' returned no text content"),
            )
        })
    }

    /// Calls tool `name` and deserializes its text result from JSON into `T`.
    ///
    /// # Errors
    ///
    /// Propagates [`call_tool_text`](Self::call_tool_text) errors. Returns `InternalError`
    /// when the text is not valid JSON for `T`.
    pub async fn call_tool_json<T: serde::de::DeserializeOwned>(
        &self,
        name: &str,
        args: serde_json::Value,
    ) -> crate::McpResult<T> {
        let text = self.call_tool_text(name, args).await?;
        serde_json::from_str(&text).map_err(|e| {
            crate::McpError::new(
                crate::McpErrorCode::InternalError,
                format!("Failed to parse tool '{name}' result as JSON: {e}"),
            )
        })
    }

    /// Delegates to `crate::combinator::join_all` with this context's `Cx`.
    pub async fn join_all<T: Send + 'static>(
        &self,
        futures: Vec<crate::combinator::BoxFuture<'_, T>>,
    ) -> Vec<T> {
        crate::combinator::join_all(&self.cx, futures).await
    }

    /// Delegates to `crate::combinator::race` with this context's `Cx`.
    pub async fn race<T: Send + 'static>(
        &self,
        futures: Vec<crate::combinator::BoxFuture<'_, T>>,
    ) -> crate::McpResult<T> {
        crate::combinator::race(&self.cx, futures).await
    }

    /// Delegates to `crate::combinator::quorum` with this context's `Cx`.
    pub async fn quorum<T: Send + 'static>(
        &self,
        required: usize,
        futures: Vec<crate::combinator::BoxFuture<'_, crate::McpResult<T>>>,
    ) -> crate::McpResult<crate::combinator::QuorumResult<T>> {
        crate::combinator::quorum(&self.cx, required, futures).await
    }

    /// Delegates to `crate::combinator::first_ok` with this context's `Cx`.
    pub async fn first_ok<T: Send + 'static>(
        &self,
        futures: Vec<crate::combinator::BoxFuture<'_, crate::McpResult<T>>>,
    ) -> crate::McpResult<T> {
        crate::combinator::first_ok(&self.cx, futures).await
    }
}
/// Error returned by `McpContext::checkpoint` when the request should stop.
///
/// It displays as `request cancelled`.
#[derive(Debug, Clone, Copy)]
pub struct CancelledError;

impl std::fmt::Display for CancelledError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("request cancelled")
    }
}

impl std::error::Error for CancelledError {}
/// Converts a `Result` into an asupersync [`Outcome`].
pub trait IntoOutcome<T, E> {
/// Performs the conversion, consuming `self`.
fn into_outcome(self) -> Outcome<T, E>;
}
/// Direct mapping: `Ok(v)` becomes `Outcome::Ok(v)` and `Err(e)` becomes `Outcome::Err(e)`.
impl<T, E> IntoOutcome<T, E> for Result<T, E> {
fn into_outcome(self) -> Outcome<T, E> {
match self {
Ok(v) => Outcome::Ok(v),
Err(e) => Outcome::Err(e),
}
}
}
/// Maps a [`CancelledError`] result into an `Outcome` with any error type `E`.
///
/// `Err(CancelledError)` becomes `Outcome::Cancelled` with reason
/// `CancelReason::user("request cancelled")`; `Ok(v)` becomes `Outcome::Ok(v)`.
///
/// NOTE(review): `E` is never constructed here. The `E: Default` bound presumably exists so
/// this impl does not overlap the blanket impl above when `E = CancelledError`, which does
/// not implement `Default`. Confirm before removing the bound or deriving `Default` on
/// `CancelledError`.
impl<T, E> IntoOutcome<T, E> for Result<T, CancelledError>
where
E: Default,
{
fn into_outcome(self) -> Outcome<T, E> {
match self {
Ok(v) => Outcome::Ok(v),
Err(CancelledError) => Outcome::Cancelled(CancelReason::user("request cancelled")),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mcp_context_creation() {
let cx = Cx::for_testing();
let ctx = McpContext::new(cx, 42);
assert_eq!(ctx.request_id(), 42);
}
#[test]
fn test_mcp_context_not_cancelled_initially() {
let cx = Cx::for_testing();
let ctx = McpContext::new(cx, 1);
assert!(!ctx.is_cancelled());
}
#[test]
fn test_mcp_context_checkpoint_success() {
let cx = Cx::for_testing();
let ctx = McpContext::new(cx, 1);
assert!(ctx.checkpoint().is_ok());
}
#[test]
fn test_mcp_context_checkpoint_cancelled() {
let cx = Cx::for_testing();
cx.set_cancel_requested(true);
let ctx = McpContext::new(cx, 1);
assert!(ctx.checkpoint().is_err());
}
#[test]
fn test_mcp_context_checkpoint_budget_exhausted() {
let cx = Cx::for_testing_with_budget(Budget::ZERO);
let ctx = McpContext::new(cx, 1);
assert!(ctx.checkpoint().is_err());
}
#[test]
fn test_mcp_context_masked_section() {
let cx = Cx::for_testing();
let ctx = McpContext::new(cx, 1);
let result = ctx.masked(|| 42);
assert_eq!(result, 42);
}
#[test]
fn test_mcp_context_budget() {
let cx = Cx::for_testing();
let ctx = McpContext::new(cx, 1);
let budget = ctx.budget();
assert!(!budget.is_exhausted());
}
#[test]
fn test_cancelled_error_display() {
let err = CancelledError;
assert_eq!(err.to_string(), "request cancelled");
}
#[test]
fn test_into_outcome_ok() {
let result: Result<i32, CancelledError> = Ok(42);
let outcome: Outcome<i32, CancelledError> = result.into_outcome();
assert!(matches!(outcome, Outcome::Ok(42)));
}
#[test]
fn test_into_outcome_cancelled() {
let result: Result<i32, CancelledError> = Err(CancelledError);
let outcome: Outcome<i32, ()> = result.into_outcome();
assert!(matches!(outcome, Outcome::Cancelled(_)));
}
#[test]
fn test_mcp_context_no_progress_reporter_by_default() {
let cx = Cx::for_testing();
let ctx = McpContext::new(cx, 1);
assert!(!ctx.has_progress_reporter());
}
#[test]
fn test_mcp_context_with_progress_reporter() {
let cx = Cx::for_testing();
let sender = Arc::new(NoOpNotificationSender);
let reporter = ProgressReporter::new(sender);
let ctx = McpContext::with_progress(cx, 1, reporter);
assert!(ctx.has_progress_reporter());
}
#[test]
fn test_report_progress_without_reporter() {
let cx = Cx::for_testing();
let ctx = McpContext::new(cx, 1);
ctx.report_progress(0.5, Some("test"));
ctx.report_progress_with_total(5.0, 10.0, None);
}
#[test]
fn test_report_progress_with_reporter() {
use std::sync::atomic::{AtomicU32, Ordering};
struct CountingSender {
count: AtomicU32,
}
impl NotificationSender for CountingSender {
fn send_progress(&self, _progress: f64, _total: Option<f64>, _message: Option<&str>) {
self.count.fetch_add(1, Ordering::SeqCst);
}
}
let cx = Cx::for_testing();
let sender = Arc::new(CountingSender {
count: AtomicU32::new(0),
});
let reporter = ProgressReporter::new(sender.clone());
let ctx = McpContext::with_progress(cx, 1, reporter);
ctx.report_progress(0.25, Some("step 1"));
ctx.report_progress(0.5, None);
ctx.report_progress_with_total(3.0, 4.0, Some("step 3"));
assert_eq!(sender.count.load(Ordering::SeqCst), 3);
}
#[test]
fn test_progress_reporter_debug() {
let sender = Arc::new(NoOpNotificationSender);
let reporter = ProgressReporter::new(sender);
let debug = format!("{reporter:?}");
assert!(debug.contains("ProgressReporter"));
}
#[test]
fn test_noop_notification_sender() {
let sender = NoOpNotificationSender;
sender.send_progress(0.5, Some(1.0), Some("test"));
}
#[test]
fn test_mcp_context_no_session_state_by_default() {
let cx = Cx::for_testing();
let ctx = McpContext::new(cx, 1);
assert!(!ctx.has_session_state());
}
#[test]
fn test_mcp_context_with_session_state() {
let cx = Cx::for_testing();
let state = SessionState::new();
let ctx = McpContext::with_state(cx, 1, state);
assert!(ctx.has_session_state());
}
#[test]
fn test_mcp_context_get_set_state() {
let cx = Cx::for_testing();
let state = SessionState::new();
let ctx = McpContext::with_state(cx, 1, state);
assert!(ctx.set_state("counter", 42));
let value: Option<i32> = ctx.get_state("counter");
assert_eq!(value, Some(42));
}
#[test]
fn test_mcp_context_state_not_available() {
let cx = Cx::for_testing();
let ctx = McpContext::new(cx, 1);
assert!(!ctx.set_state("key", "value"));
let value: Option<String> = ctx.get_state("key");
assert!(value.is_none());
}
#[test]
fn test_mcp_context_has_state() {
    let cx = Cx::for_testing();
    let state = SessionState::new();
    let ctx = McpContext::with_state(cx, 1, state);

    // Nothing has been stored yet.
    assert!(!ctx.has_state("missing"));

    // `set_state` reports whether the write happened. Assert it so a storage
    // failure is reported at its source rather than as a `has_state` miss below.
    assert!(ctx.set_state("present", true));
    assert!(ctx.has_state("present"));

    // Storing one key must not make unrelated keys appear.
    assert!(!ctx.has_state("missing"));
}
#[test]
fn test_mcp_context_remove_state() {
    let cx = Cx::for_testing();
    let state = SessionState::new();
    let ctx = McpContext::with_state(cx, 1, state);

    // Assert the write itself succeeded so a storage failure is reported at its
    // source instead of being masked as a removal problem.
    assert!(ctx.set_state("key", "value"));
    assert!(ctx.has_state("key"));

    // Removing an existing key returns `Some(..)` ...
    let removed = ctx.remove_state("key");
    assert!(removed.is_some());

    // ... and leaves the key absent for both presence checks and typed reads.
    assert!(!ctx.has_state("key"));
    let after: Option<String> = ctx.get_state("key");
    assert!(after.is_none());
}
#[test]
fn test_mcp_context_with_state_and_progress() {
    // Session state and a progress reporter can be attached by one constructor.
    let test_cx = Cx::for_testing();
    let session = SessionState::new();
    let noop = Arc::new(NoOpNotificationSender);
    let progress = ProgressReporter::new(noop);
    let ctx = McpContext::with_state_and_progress(test_cx, 1, session, progress);

    let has_state = ctx.has_session_state();
    assert!(has_state);
    let has_progress = ctx.has_progress_reporter();
    assert!(has_progress);
}
#[test]
fn test_mcp_context_auth_is_request_local() {
    let test_cx = Cx::for_testing();
    let session = SessionState::new();
    let ctx = McpContext::with_state(test_cx, 1, session.clone());

    // Auth attached to the request is readable through the context ...
    let attached = ctx.set_auth(AuthContext::with_subject("alice"));
    assert!(attached);
    let subject = ctx.auth().and_then(|auth| auth.subject);
    assert_eq!(subject.as_deref(), Some("alice"));

    // ... but must never be written through to the shared session store.
    let persisted: Option<AuthContext> = session.get(crate::AUTH_STATE_KEY);
    assert!(
        persisted.is_none(),
        "request auth must not be persisted into session state"
    );
}
#[test]
fn test_mcp_context_clones_share_request_auth() {
    let original = McpContext::new(Cx::for_testing(), 1);
    let twin = original.clone();

    // Auth set through a clone must be observable from the original handle.
    assert!(twin.set_auth(AuthContext::with_subject("bob")));
    let seen = original.auth().and_then(|auth| auth.subject);
    assert_eq!(seen.as_deref(), Some("bob"));
}
#[test]
fn test_new_mcp_contexts_do_not_share_request_auth_even_with_same_cx() {
    let shared_cx = Cx::for_testing();
    let session = SessionState::new();

    // Two independently constructed contexts over the same Cx, request id and session.
    let authed = McpContext::with_state(shared_cx.clone(), 7, session.clone());
    let other = McpContext::with_state(shared_cx, 7, session);

    assert!(authed.set_auth(AuthContext::with_subject("carol")));
    let leaked = other.auth();
    assert!(leaked.is_none());
}
#[test]
fn test_new_mcp_contexts_do_not_share_request_auth_across_requests() {
    let session = SessionState::new();
    let request_a = McpContext::with_state(Cx::for_testing(), 7, session.clone());
    let request_b = McpContext::with_state(Cx::for_testing(), 8, session);

    assert!(request_a.set_auth(AuthContext::with_subject("dave")));
    let subject_a = request_a.auth().and_then(|auth| auth.subject);
    assert_eq!(subject_a.as_deref(), Some("dave"));

    // A sibling request that shares the session must stay unauthenticated.
    assert!(request_b.auth().is_none());
}
#[test]
fn test_mcp_context_drop_does_not_leak_request_auth() {
    let test_cx = Cx::for_testing();

    // Authenticate a context, then explicitly drop it.
    let scoped = McpContext::new(test_cx.clone(), 9);
    assert!(scoped.set_auth(AuthContext::with_subject("erin")));
    drop(scoped);

    let fresh = McpContext::new(test_cx, 9);
    assert!(
        fresh.auth().is_none(),
        "fresh contexts must start without inherited request auth"
    );
}
#[test]
fn test_mcp_context_tools_enabled_by_default() {
    let ctx = McpContext::with_state(Cx::for_testing(), 1, SessionState::new());
    // Nothing has been disabled, so arbitrary tool names report as enabled.
    for &tool in &["any_tool", "another_tool"] {
        assert!(ctx.is_tool_enabled(tool));
    }
}
#[test]
fn test_mcp_context_disable_enable_tool() {
    let ctx = McpContext::with_state(Cx::for_testing(), 1, SessionState::new());
    let target = "my_tool";

    assert!(ctx.is_tool_enabled(target));

    // Disabling affects only the named tool.
    assert!(ctx.disable_tool(target));
    assert!(!ctx.is_tool_enabled(target));
    assert!(ctx.is_tool_enabled("other_tool"));

    // Re-enabling restores it.
    assert!(ctx.enable_tool(target));
    assert!(ctx.is_tool_enabled(target));
}
#[test]
fn test_mcp_context_disable_enable_resource() {
    let ctx = McpContext::with_state(Cx::for_testing(), 1, SessionState::new());
    let secret = "file://secret";

    assert!(ctx.is_resource_enabled(secret));

    // Disabling hides only the targeted URI.
    assert!(ctx.disable_resource(secret));
    assert!(!ctx.is_resource_enabled(secret));
    assert!(ctx.is_resource_enabled("file://public"));

    // Re-enabling makes it available again.
    assert!(ctx.enable_resource(secret));
    assert!(ctx.is_resource_enabled(secret));
}
#[test]
fn test_mcp_context_disable_enable_prompt() {
    let ctx = McpContext::with_state(Cx::for_testing(), 1, SessionState::new());
    let admin = "admin_prompt";

    assert!(ctx.is_prompt_enabled(admin));

    // Disabling hides only the targeted prompt.
    assert!(ctx.disable_prompt(admin));
    assert!(!ctx.is_prompt_enabled(admin));
    assert!(ctx.is_prompt_enabled("user_prompt"));

    // Re-enabling makes it available again.
    assert!(ctx.enable_prompt(admin));
    assert!(ctx.is_prompt_enabled(admin));
}
#[test]
fn test_mcp_context_disable_multiple_tools() {
    let cx = Cx::for_testing();
    let state = SessionState::new();
    let ctx = McpContext::with_state(cx, 1, state);
    let targets = ["tool1", "tool2", "tool3"];

    // Each disable must report success. Previously these results were ignored,
    // so a rejected write would only surface indirectly in the checks below.
    for &name in &targets {
        assert!(ctx.disable_tool(name), "disabling {name} should succeed");
    }

    for &name in &targets {
        assert!(!ctx.is_tool_enabled(name));
    }
    // Tools that were never disabled are unaffected.
    assert!(ctx.is_tool_enabled("tool4"));

    // The disabled set contains exactly the targeted tools.
    let disabled = ctx.disabled_tools();
    assert_eq!(disabled.len(), targets.len());
    for &name in &targets {
        assert!(disabled.contains(name));
    }
}
#[test]
fn test_mcp_context_disabled_sets_empty_by_default() {
    let ctx = McpContext::with_state(Cx::for_testing(), 1, SessionState::new());
    // A fresh session has nothing disabled in any category.
    let tools = ctx.disabled_tools();
    assert!(tools.is_empty());
    let resources = ctx.disabled_resources();
    assert!(resources.is_empty());
    let prompts = ctx.disabled_prompts();
    assert!(prompts.is_empty());
}
#[test]
fn test_mcp_context_enable_disable_no_state() {
    // Without a session store, the toggles report failure ...
    let ctx = McpContext::new(Cx::for_testing(), 1);
    let disabled = ctx.disable_tool("tool");
    assert!(!disabled);
    let enabled = ctx.enable_tool("tool");
    assert!(!enabled);
    // ... and the tool keeps its default enabled status.
    assert!(ctx.is_tool_enabled("tool"));
}
#[test]
fn test_mcp_context_disabled_state_persists_across_contexts() {
    let state = SessionState::new();
    {
        let cx = Cx::for_testing();
        let ctx = McpContext::with_state(cx, 1, state.clone());
        // Assert the write landed. Otherwise the check in the next request would
        // fail for a reason unrelated to cross-context persistence.
        assert!(ctx.disable_tool("shared_tool"));
    }
    {
        let cx = Cx::for_testing();
        // Last use of `state`, so move it instead of cloning.
        let ctx = McpContext::with_state(cx, 2, state);
        assert!(!ctx.is_tool_enabled("shared_tool"));
        // Only the tool disabled by the earlier request is affected.
        assert!(ctx.is_tool_enabled("other_tool"));
    }
}
#[test]
fn test_mcp_context_no_capabilities_by_default() {
    let ctx = McpContext::new(Cx::for_testing(), 1);

    // No capability info is attached unless explicitly provided ...
    assert!(ctx.client_capabilities().is_none());
    assert!(ctx.server_capabilities().is_none());

    // ... so no client feature is reported as supported.
    assert!(!ctx.client_supports_sampling());
    assert!(!ctx.client_supports_elicitation());
    assert!(!ctx.client_supports_roots());
}
#[test]
fn test_mcp_context_with_client_capabilities() {
    let test_cx = Cx::for_testing();
    // Sampling + form-only elicitation + roots with list-changed notifications.
    let client_caps = ClientCapabilityInfo::new()
        .with_sampling()
        .with_elicitation(true, false)
        .with_roots(true);
    let ctx = McpContext::new(test_cx, 1).with_client_capabilities(client_caps);

    assert!(ctx.client_capabilities().is_some());
    assert!(ctx.client_supports_sampling());
    assert!(ctx.client_supports_elicitation());
    // Form was enabled, URL was not.
    assert!(ctx.client_supports_elicitation_form());
    assert!(!ctx.client_supports_elicitation_url());
    assert!(ctx.client_supports_roots());
}
#[test]
fn test_mcp_context_with_server_capabilities() {
    let test_cx = Cx::for_testing();
    let server_info = ServerCapabilityInfo::new()
        .with_tools()
        .with_resources(true)
        .with_prompts()
        .with_logging();
    let ctx = McpContext::new(test_cx, 1).with_server_capabilities(server_info);

    // Every builder call above should be reflected in the attached info.
    let caps = ctx.server_capabilities().unwrap();
    let advertised = [
        caps.tools,
        caps.resources,
        caps.resources_subscribe,
        caps.prompts,
        caps.logging,
    ];
    assert!(advertised.iter().all(|&enabled| enabled));
}
#[test]
fn test_client_capability_info_builders() {
    // `new` starts with every capability off.
    let blank = ClientCapabilityInfo::new();
    assert!(!blank.sampling);
    assert!(!blank.elicitation);
    assert!(!blank.roots);
    assert!(blank.with_sampling().sampling);

    // Elicitation with both form and URL modes.
    let elicit = ClientCapabilityInfo::new().with_elicitation(true, true);
    assert!(elicit.elicitation);
    assert!(elicit.elicitation_form);
    assert!(elicit.elicitation_url);

    // Roots enabled without list-changed notifications.
    let roots = ClientCapabilityInfo::new().with_roots(false);
    assert!(roots.roots);
    assert!(!roots.roots_list_changed);
}
#[test]
fn test_server_capability_info_builders() {
    // `new` starts with every capability off.
    let blank = ServerCapabilityInfo::new();
    assert!(!blank.tools);
    assert!(!blank.resources);
    assert!(!blank.prompts);
    assert!(!blank.logging);

    // `with_resources(false)` enables resources without subscription support.
    let full = blank
        .with_tools()
        .with_resources(false)
        .with_prompts()
        .with_logging();
    assert!(full.tools);
    assert!(full.resources);
    assert!(!full.resources_subscribe);
    assert!(full.prompts);
    assert!(full.logging);
}
#[test]
fn test_resource_content_item_text() {
    // The text constructor tags the item as text/plain and carries no blob.
    let text_item = ResourceContentItem::text("test://uri", "hello");
    assert_eq!(text_item.uri, "test://uri");
    assert_eq!(text_item.mime_type.as_deref(), Some("text/plain"));
    assert_eq!(text_item.as_text(), Some("hello"));
    assert_eq!(text_item.as_blob(), None);
    assert!(text_item.is_text());
    assert!(!text_item.is_blob());
}
#[test]
fn test_resource_content_item_json() {
    // JSON items are text items tagged application/json.
    let raw = r#"{"key":"val"}"#;
    let json_item = ResourceContentItem::json("data://config", raw);
    assert_eq!(json_item.uri, "data://config");
    assert_eq!(json_item.mime_type.as_deref(), Some("application/json"));
    assert_eq!(json_item.as_text(), Some(raw));
    assert!(json_item.is_text());
    assert!(!json_item.is_blob());
}
#[test]
fn test_resource_content_item_blob() {
    // Blob items keep the caller's MIME type and expose only blob data.
    let payload = "AQID";
    let blob_item =
        ResourceContentItem::blob("binary://data", "application/octet-stream", payload);
    assert_eq!(blob_item.uri, "binary://data");
    assert_eq!(blob_item.mime_type.as_deref(), Some("application/octet-stream"));
    assert_eq!(blob_item.as_text(), None);
    assert_eq!(blob_item.as_blob(), Some(payload));
    assert!(!blob_item.is_text());
    assert!(blob_item.is_blob());
}
#[test]
fn test_resource_read_result_text() {
    // A text read result wraps exactly one text item.
    let read = ResourceReadResult::text("test://doc", "content");
    assert_eq!(read.first_text(), Some("content"));
    assert_eq!(read.first_blob(), None);
    assert_eq!(read.contents.len(), 1);
}
#[test]
fn test_resource_read_result_new_multiple() {
    let items = vec![
        ResourceContentItem::text("a://1", "first"),
        ResourceContentItem::blob("b://2", "image/png", "base64data"),
    ];
    let read = ResourceReadResult::new(items);

    assert_eq!(read.contents.len(), 2);
    assert_eq!(read.first_text(), Some("first"));
    // A blob later in the list is not reported when the first entry is text.
    assert!(read.first_blob().is_none());
}
#[test]
fn test_resource_read_result_empty() {
    // With no contents there is neither text nor blob to report.
    let read = ResourceReadResult::new(Vec::new());
    assert_eq!(read.first_text(), None);
    assert_eq!(read.first_blob(), None);
}
#[test]
fn test_resource_read_result_blob_first() {
    // A lone blob item yields blob data and no text.
    let only_blob = ResourceContentItem::blob("b://1", "image/png", "data");
    let read = ResourceReadResult::new(vec![only_blob]);
    assert_eq!(read.first_text(), None);
    assert_eq!(read.first_blob(), Some("data"));
}
#[test]
fn test_tool_content_item_text() {
    // The text constructor yields a text item exposing its string.
    let text_item = ToolContentItem::text("hello");
    let extracted = text_item.as_text();
    assert_eq!(extracted, Some("hello"));
    assert!(text_item.is_text());
}
#[test]
fn test_tool_content_item_image() {
    // Image content is not text.
    let image = ToolContentItem::Image {
        data: String::from("base64img"),
        mime_type: String::from("image/png"),
    };
    assert_eq!(image.as_text(), None);
    assert!(!image.is_text());
}
#[test]
fn test_tool_content_item_audio() {
    // Audio content is not text.
    let audio = ToolContentItem::Audio {
        data: String::from("base64audio"),
        mime_type: String::from("audio/wav"),
    };
    assert_eq!(audio.as_text(), None);
    assert!(!audio.is_text());
}
#[test]
fn test_tool_content_item_resource() {
    // An embedded resource is not a text item, even when it carries text.
    let embedded = ToolContentItem::Resource {
        uri: String::from("file://test"),
        mime_type: Some(String::from("text/plain")),
        text: Some(String::from("embedded")),
        blob: None,
    };
    assert_eq!(embedded.as_text(), None);
    assert!(!embedded.is_text());
}
#[test]
fn test_tool_call_result_success() {
    // A success result keeps every item in order and is not flagged as an error.
    let items = vec![ToolContentItem::text("item1"), ToolContentItem::text("item2")];
    let outcome = ToolCallResult::success(items);
    assert!(!outcome.is_error);
    assert_eq!(outcome.content.len(), 2);
    assert_eq!(outcome.first_text(), Some("item1"));
}
#[test]
fn test_tool_call_result_text() {
    // The text shortcut produces a single-item success result.
    let outcome = ToolCallResult::text("simple output");
    assert!(!outcome.is_error);
    assert_eq!(outcome.content.len(), 1);
    assert_eq!(outcome.first_text(), Some("simple output"));
}
#[test]
fn test_tool_call_result_error() {
    // Error results are flagged and expose the message as text.
    let failure = ToolCallResult::error("something failed");
    assert!(failure.is_error);
    assert_eq!(failure.first_text(), Some("something failed"));
}
#[test]
fn test_tool_call_result_empty() {
    // An empty success result has no text to report.
    let outcome = ToolCallResult::success(Vec::new());
    assert!(!outcome.is_error);
    assert_eq!(outcome.first_text(), None);
}
#[test]
fn test_elicitation_response_accept() {
    let fields = std::collections::HashMap::from([
        ("name".to_string(), serde_json::json!("Alice")),
        ("age".to_string(), serde_json::json!(30)),
        ("active".to_string(), serde_json::json!(true)),
    ]);
    let resp = ElicitationResponse::accept(fields);

    // An accepted response is neither declined nor cancelled.
    assert!(resp.is_accepted());
    assert!(!resp.is_declined());
    assert!(!resp.is_cancelled());

    // Typed getters return values of the matching JSON type.
    assert_eq!(resp.get_string("name"), Some("Alice"));
    assert_eq!(resp.get_int("age"), Some(30));
    assert_eq!(resp.get_bool("active"), Some(true));
}
#[test]
fn test_elicitation_response_accept_url() {
    // URL-mode acceptance carries no form content.
    let resp = ElicitationResponse::accept_url();
    assert!(resp.is_accepted());
    assert!(resp.content.is_none());
    assert_eq!(resp.get_string("anything"), None);
}
#[test]
fn test_elicitation_response_decline() {
    // A declined response is only "declined" and carries no data.
    let resp = ElicitationResponse::decline();
    assert!(resp.is_declined());
    assert!(!resp.is_accepted());
    assert!(!resp.is_cancelled());
    assert_eq!(resp.get_string("key"), None);
}
#[test]
fn test_elicitation_response_cancel() {
    // A cancelled response is only "cancelled".
    let resp = ElicitationResponse::cancel();
    assert!(resp.is_cancelled());
    assert!(!resp.is_accepted());
    assert!(!resp.is_declined());
}
#[test]
fn test_elicitation_response_missing_key() {
    let fields =
        std::collections::HashMap::from([("exists".to_string(), serde_json::json!("value"))]);
    let resp = ElicitationResponse::accept(fields);

    // Every typed getter yields None for an absent key.
    let absent = "missing";
    assert_eq!(resp.get_string(absent), None);
    assert_eq!(resp.get_bool(absent), None);
    assert_eq!(resp.get_int(absent), None);
}
#[test]
fn test_elicitation_response_type_mismatch() {
    let fields = std::collections::HashMap::from([("num".to_string(), serde_json::json!(42))]);
    let resp = ElicitationResponse::accept(fields);

    // A numeric value is only visible through the integer getter.
    assert_eq!(resp.get_string("num"), None);
    assert_eq!(resp.get_bool("num"), None);
    assert_eq!(resp.get_int("num"), Some(42));
}
#[test]
fn test_can_sample_false_by_default() {
    // No sampling sender is attached by `new`.
    let ctx = McpContext::new(Cx::for_testing(), 1);
    let sampling = ctx.can_sample();
    assert!(!sampling);
}
#[test]
fn test_can_elicit_false_by_default() {
    // No elicitation sender is attached by `new`.
    let ctx = McpContext::new(Cx::for_testing(), 1);
    let eliciting = ctx.can_elicit();
    assert!(!eliciting);
}
#[test]
fn test_can_read_resources_false_by_default() {
    // Nested resource reads are unavailable on a plain context.
    let ctx = McpContext::new(Cx::for_testing(), 1);
    let reads = ctx.can_read_resources();
    assert!(!reads);
}
#[test]
fn test_can_call_tools_false_by_default() {
    // Nested tool calls are unavailable on a plain context.
    let ctx = McpContext::new(Cx::for_testing(), 1);
    let calls = ctx.can_call_tools();
    assert!(!calls);
}
#[test]
fn test_resource_read_depth_default() {
    // A fresh context starts at resource-read depth zero.
    let depth = McpContext::new(Cx::for_testing(), 1).resource_read_depth();
    assert_eq!(depth, 0);
}
#[test]
fn test_tool_call_depth_default() {
    // A fresh context starts at tool-call depth zero.
    let depth = McpContext::new(Cx::for_testing(), 1).tool_call_depth();
    assert_eq!(depth, 0);
}
#[test]
fn sampling_request_builder_chain() {
    let req = SamplingRequest::prompt("hello", 100)
        .with_system_prompt("You are helpful")
        .with_temperature(0.7)
        .with_stop_sequences(vec![String::from("STOP")])
        .with_model_hints(vec![String::from("gpt-4")]);

    // `prompt` yields a single user message plus the token budget.
    assert_eq!(req.messages.len(), 1);
    assert_eq!(req.max_tokens, 100);

    // Each builder step populates its own field.
    assert_eq!(req.system_prompt.as_deref(), Some("You are helpful"));
    assert_eq!(req.temperature, Some(0.7));
    assert_eq!(req.stop_sequences, ["STOP"]);
    assert_eq!(req.model_hints, ["gpt-4"]);
}
#[test]
fn sampling_request_message_roles() {
    // Each constructor stamps its role and keeps the text verbatim.
    let from_user = SamplingRequestMessage::user("hi");
    let from_assistant = SamplingRequestMessage::assistant("hello");

    assert_eq!(from_user.role, SamplingRole::User);
    assert_eq!(from_user.text, "hi");

    assert_eq!(from_assistant.role, SamplingRole::Assistant);
    assert_eq!(from_assistant.text, "hello");
}
#[test]
fn sampling_response_new_default_stop_reason() {
    let resp = SamplingResponse::new("output", "model-1");
    assert_eq!(resp.text, "output");
    assert_eq!(resp.model, "model-1");

    // A freshly built response reports EndTurn, which is also the enum default.
    let default_reason = SamplingStopReason::default();
    assert_eq!(resp.stop_reason, SamplingStopReason::EndTurn);
    assert_eq!(default_reason, SamplingStopReason::EndTurn);
}
#[test]
fn noop_sampling_sender_returns_error() {
    // The no-op sender must reject this sampling request.
    let noop = NoOpSamplingSender;
    let request = SamplingRequest::prompt("test", 10);
    let outcome = crate::block_on(noop.create_message(request));
    assert!(outcome.is_err());
}
#[test]
fn noop_elicitation_sender_returns_error() {
    // The no-op sender must reject this elicitation request.
    let noop = NoOpElicitationSender;
    let request = ElicitationRequest::form("msg", serde_json::json!({}));
    let outcome = crate::block_on(noop.elicit(request));
    assert!(outcome.is_err());
}
#[test]
fn elicitation_request_form_constructor() {
    let schema = serde_json::json!({"type": "string"});
    let req = ElicitationRequest::form("Enter name", schema);

    assert_eq!(req.mode, ElicitationMode::Form);
    assert_eq!(req.message, "Enter name");
    // Form mode carries a schema and leaves the URL-mode fields unset.
    assert!(req.schema.is_some());
    assert!(req.url.is_none());
    assert!(req.elicitation_id.is_none());
}
#[test]
fn elicitation_request_url_constructor() {
    let req = ElicitationRequest::url("Login", "https://example.com", "id-1");

    assert_eq!(req.mode, ElicitationMode::Url);
    assert_eq!(req.message, "Login");
    // URL mode carries the URL and id but no schema.
    assert_eq!(req.url.as_deref(), Some("https://example.com"));
    assert_eq!(req.elicitation_id.as_deref(), Some("id-1"));
    assert!(req.schema.is_none());
}
#[test]
fn mcp_context_with_sampling_enables_can_sample() {
    // Attaching any sampling sender turns the capability on.
    let test_cx = Cx::for_testing();
    let sampler = Arc::new(NoOpSamplingSender);
    let ctx = McpContext::new(test_cx, 1).with_sampling(sampler);
    let sampling = ctx.can_sample();
    assert!(sampling);
}
#[test]
fn mcp_context_with_elicitation_enables_can_elicit() {
    // Attaching any elicitation sender turns the capability on.
    let test_cx = Cx::for_testing();
    let elicitor = Arc::new(NoOpElicitationSender);
    let ctx = McpContext::new(test_cx, 1).with_elicitation(elicitor);
    let eliciting = ctx.can_elicit();
    assert!(eliciting);
}
#[test]
fn mcp_context_depth_setters() {
    // Each depth builder sets only its own counter.
    let ctx = McpContext::new(Cx::for_testing(), 1)
        .with_resource_read_depth(3)
        .with_tool_call_depth(5);
    let (read_depth, call_depth) = (ctx.resource_read_depth(), ctx.tool_call_depth());
    assert_eq!(read_depth, 3);
    assert_eq!(call_depth, 5);
}
#[test]
fn mcp_context_debug_includes_request_id() {
    // The Debug output exposes the request id as a named field.
    let ctx = McpContext::new(Cx::for_testing(), 99);
    let rendered = format!("{:?}", ctx);
    assert!(rendered.contains("request_id: 99"));
}
#[test]
fn mcp_context_cx_and_trace() {
    // Smoke test: the underlying Cx is reachable and tracing does not panic.
    let ctx = McpContext::new(Cx::for_testing(), 1);
    let _ = ctx.cx();
    ctx.trace("test event");
}
}