use crate::SDK_NAME;
use crate::types::notification::Notification;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
#[cfg(feature = "server")]
use crate::{
app::handler::{FromHandlerParams, HandlerParams},
app::options::McpOptions,
error::Error,
};
#[cfg(feature = "server")]
pub use request::FromRequest;
#[cfg(feature = "http-server")]
use {crate::auth::DefaultClaims, volga::headers::HeaderMap};
pub use capabilities::{
ClientCapabilities, CompletionsCapability, ElicitationCapability, ElicitationFormCapability,
ElicitationUrlCapability, LoggingCapability, PromptsCapability, ResourcesCapability,
RootsCapability, SamplingCapability, SamplingContextCapability, SamplingToolsCapability,
ServerCapabilities, ToolsCapability,
};
pub use completion::{Argument, CompleteRequestParams, CompleteResult, Completion};
pub use content::{
AudioContent, Content, EmbeddedResource, ImageContent, ResourceLink, TextContent, ToolResult,
ToolUse,
};
pub use cursor::{Cursor, Page, Pagination};
pub use helpers::{Json, Meta, PropertyType};
pub use reference::Reference;
pub use request::{Request, RequestId, RequestParamsMeta};
pub use response::{ErrorDetails, IntoResponse, Response};
#[cfg(feature = "tasks")]
pub use capabilities::{
ClientTaskRequestsCapability, ClientTasksCapability, ElicitationCreateTaskCapability,
ElicitationTaskCapability, SamplingCreateMessageTaskCapability, SamplingTaskCapability,
ServerTaskRequestsCapability, ServerTasksCapability, TaskCancellationCapability,
TaskListCapability, ToolsCallTaskCapability, ToolsTaskCapability,
};
pub use tool::{
CallToolRequestParams, CallToolResponse, ListToolsRequestParams, ListToolsResult, Tool,
ToolAnnotations, ToolSchema,
};
#[cfg(feature = "server")]
pub use tool::ToolHandler;
pub use elicitation::{
ElicitRequestFormParams, ElicitRequestParams, ElicitRequestUrlParams, ElicitResult,
ElicitationAction, ElicitationCompleteParams, ElicitationMode, UrlElicitationRequiredError,
};
pub use prompt::{
GetPromptRequestParams, GetPromptResult, ListPromptsRequestParams, ListPromptsResult, Prompt,
PromptArgument, PromptMessage,
};
pub use resource::{
BlobResourceContents, ListResourceTemplatesRequestParams, ListResourceTemplatesResult,
ListResourcesRequestParams, ListResourcesResult, ReadResourceRequestParams, ReadResourceResult,
Resource, ResourceContents, ResourceTemplate, SubscribeRequestParams, TextResourceContents,
UnsubscribeRequestParams, Uri,
};
pub use sampling::{
CreateMessageRequestParams, CreateMessageResult, SamplingMessage, StopReason, ToolChoice,
ToolChoiceMode,
};
pub use schema::{
BooleanSchema, NumberSchema, Schema, StringFormat, StringSchema, TitledMultiSelectEnumSchema,
TitledSingleSelectEnumSchema, UntitledMultiSelectEnumSchema, UntitledSingleSelectEnumSchema,
};
pub use icon::{Icon, IconSize, IconTheme};
#[cfg(feature = "tasks")]
pub use task::{
CancelTaskRequestParams, CreateTaskResult, GetTaskPayloadRequestParams, GetTaskRequestParams,
ListTasksRequestParams, ListTasksResult, RelatedTaskMetadata, Task, TaskMetadata, TaskPayload,
TaskStatus,
};
#[cfg(feature = "server")]
pub use prompt::PromptHandler;
pub use progress::ProgressToken;
pub use root::Root;
mod capabilities;
pub mod completion;
mod content;
pub mod cursor;
pub mod elicitation;
pub(crate) mod helpers;
mod icon;
pub mod notification;
mod progress;
pub mod prompt;
mod reference;
mod request;
pub mod resource;
mod response;
pub mod root;
pub mod sampling;
mod schema;
#[cfg(feature = "tasks")]
pub mod task;
pub mod tool;
/// JSON-RPC protocol version string (`"2.0"`), visible to the parent module.
pub(super) const JSONRPC_VERSION: &str = "2.0";
/// A single JSON-RPC payload: a request, a response, a notification, or a batch.
///
/// The enum is `#[serde(untagged)]`: no discriminator is written on the wire, and on
/// deserialization serde tries the variants in declaration order and keeps the first
/// one that matches. Do not reorder the variants without checking the wire behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Message {
/// A request message.
Request(Request),
/// A response message.
Response(Response),
/// A notification message.
Notification(Notification),
/// A batch of requests, responses and/or notifications (a JSON array on the wire).
Batch(MessageBatch),
}
/// A single, non-batch message; the element type stored inside a [`MessageBatch`].
///
/// Like [`Message`], this enum is `#[serde(untagged)]`, so variants are tried in
/// declaration order when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageEnvelope {
/// A request message.
Request(Request),
/// A response message.
Response(Response),
/// A notification message.
Notification(Notification),
}
/// A non-empty collection of [`MessageEnvelope`]s handled as one unit.
///
/// Only `items` takes part in (de)serialization (see the manual `Serialize` /
/// `Deserialize` impls); the remaining fields are in-process metadata.
#[derive(Debug, Clone)]
pub struct MessageBatch {
// Internal identifier of the batch; both constructors assign a random UUID.
pub(crate) id: RequestId,
// Session this batch is associated with, if any.
pub(crate) session_id: Option<uuid::Uuid>,
// Headers attached to the batch (only with the `http-server` feature).
#[cfg(feature = "http-server")]
pub(crate) headers: HeaderMap,
// Claims attached to the batch, if any (only with the `http-server` feature).
#[cfg(feature = "http-server")]
pub(crate) claims: Option<Box<DefaultClaims>>,
// The batched messages; both constructors reject an empty list.
items: Vec<MessageEnvelope>,
}
impl MessageBatch {
pub fn new(items: Vec<MessageEnvelope>) -> Result<Self, crate::error::Error> {
if items.is_empty() {
return Err(crate::error::Error::new(
crate::error::ErrorCode::InvalidRequest,
"batch must not be empty",
));
}
Ok(Self {
id: RequestId::Uuid(uuid::Uuid::new_v4()),
session_id: None,
#[cfg(feature = "http-server")]
headers: HeaderMap::with_capacity(8),
#[cfg(feature = "http-server")]
claims: None,
items,
})
}
pub(crate) fn full_id(&self) -> RequestId {
let id = self.id.clone();
if let Some(session_id) = self.session_id {
id.concat(RequestId::Uuid(session_id))
} else {
id
}
}
#[inline]
pub fn len(&self) -> usize {
self.items.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
#[inline]
pub fn iter(&self) -> impl Iterator<Item = &MessageEnvelope> {
self.items.iter()
}
#[inline]
#[cfg(any(feature = "http-server", feature = "http-client"))]
pub(crate) fn has_requests(&self) -> bool {
self.items
.iter()
.any(|e| matches!(e, MessageEnvelope::Request(_)))
}
#[inline]
#[cfg(feature = "http-server")]
pub(crate) fn has_error_responses(&self) -> bool {
self.items
.iter()
.any(|e| matches!(e, MessageEnvelope::Response(Response::Err(_))))
}
}
/// Consumes the batch and yields its envelopes by value, in their stored order.
impl IntoIterator for MessageBatch {
    type Item = MessageEnvelope;
    type IntoIter = std::vec::IntoIter<MessageEnvelope>;

    fn into_iter(self) -> Self::IntoIter {
        // Only the envelopes are iterated; the batch metadata is dropped here.
        let Self { items, .. } = self;
        items.into_iter()
    }
}
/// Serializes the batch as a plain sequence of its envelopes.
///
/// The batch metadata (id, session id, headers, claims) is not written.
impl Serialize for MessageBatch {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.collect_seq(&self.items)
    }
}
impl<'de> Deserialize<'de> for MessageBatch {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let raw = Vec::<serde_json::Value>::deserialize(deserializer)?;
if raw.is_empty() {
return Err(serde::de::Error::custom(
"JSON-RPC batch array must not be empty",
));
}
let items: Vec<MessageEnvelope> = raw
.into_iter()
.map(|value| {
let id = value
.get("id")
.and_then(|v| serde_json::from_value::<RequestId>(v.clone()).ok())
.unwrap_or(RequestId::Null);
match serde_json::from_value::<MessageEnvelope>(value) {
Ok(envelope) => envelope,
Err(_) => MessageEnvelope::Response(Response::error(
id,
crate::error::Error::new(
crate::error::ErrorCode::InvalidRequest,
"Invalid Request",
),
)),
}
})
.collect();
Ok(Self {
id: RequestId::Uuid(uuid::Uuid::new_v4()),
session_id: None,
#[cfg(feature = "http-server")]
headers: HeaderMap::with_capacity(8),
#[cfg(feature = "http-server")]
claims: None,
items,
})
}
}
/// Parameters carried by an initialize request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeRequestParams {
/// Protocol version string; serialized as `protocolVersion`.
#[serde(rename = "protocolVersion")]
pub protocol_ver: String,
/// Capabilities declared by the client, if provided.
pub capabilities: Option<ClientCapabilities>,
/// Name/version information about the client, if provided; serialized as `clientInfo`.
#[serde(rename = "clientInfo")]
pub client_info: Option<Implementation>,
}
/// Result payload produced in reply to an initialize request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeResult {
/// Protocol version string; serialized as `protocolVersion`.
#[serde(rename = "protocolVersion")]
pub protocol_ver: String,
/// Capabilities declared by the server.
pub capabilities: ServerCapabilities,
/// Name/version information about the server; serialized as `serverInfo`.
#[serde(rename = "serverInfo")]
pub server_info: Implementation,
/// Optional free-form instructions; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
}
/// Name, version and optional icons describing a client or server implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Implementation {
/// Implementation name.
pub name: String,
/// Implementation version.
pub version: String,
/// Optional icons for the implementation; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub icons: Option<Vec<Icon>>,
}
/// A conversation role; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// The user role (`"user"`).
User,
/// The assistant role (`"assistant"`).
Assistant,
}
/// Optional annotations attached to an object: audience, last-modified time and priority.
///
/// Every field is optional on input. The container-level `#[serde(default)]` fills any
/// field missing from the payload with its value from [`Annotations::default`] (empty
/// audience, no timestamp, priority `0.0`), so partially specified annotations no longer
/// fail to deserialize. Serialization output is unchanged.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct Annotations {
    /// Roles this object is intended for.
    audience: Vec<Role>,

    /// When the object was last modified; serialized as `lastModified` and omitted
    /// from the output when `None`.
    #[serde(rename = "lastModified", skip_serializing_if = "Option::is_none")]
    last_modified: Option<DateTime<Utc>>,

    /// Relative importance of the object.
    // NOTE(review): the valid range is not enforced here; confirm the expected bounds
    // against the protocol schema.
    priority: f32,
}
/// Formats the role as its lowercase wire name (`user` / `assistant`).
impl Display for Role {
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Role::User => "user",
            Role::Assistant => "assistant",
        };
        f.write_str(label)
    }
}
/// Converts a role name into a [`Role`].
///
/// Only the exact string `"assistant"` maps to [`Role::Assistant`]; every other input,
/// including unknown names, maps to [`Role::User`].
impl From<&str> for Role {
    #[inline]
    fn from(role: &str) -> Self {
        if role == "assistant" {
            Self::Assistant
        } else {
            Self::User
        }
    }
}
/// Converts an owned role name into a [`Role`].
///
/// Only the exact string `"assistant"` maps to [`Role::Assistant`]; every other input,
/// including unknown names, maps to [`Role::User`].
impl From<String> for Role {
    #[inline]
    fn from(role: String) -> Self {
        if role.as_str() == "assistant" {
            Self::Assistant
        } else {
            Self::User
        }
    }
}
/// Describes this SDK itself: the SDK name, the crate version taken from
/// `CARGO_PKG_VERSION` at compile time, and no icons.
impl Default for Implementation {
    fn default() -> Self {
        let name: String = SDK_NAME.into();
        let version: String = env!("CARGO_PKG_VERSION").into();

        Self {
            name,
            version,
            icons: None,
        }
    }
}
impl IntoResponse for InitializeResult {
#[inline]
fn into_response(self, req_id: RequestId) -> Response {
match serde_json::to_value(self) {
Ok(v) => Response::success(req_id, v),
Err(err) => Response::error(req_id, err.into()),
}
}
}
/// Extracts the initialize parameters from handler params by first extracting the
/// [`Request`] and then converting it with `from_request`.
#[cfg(feature = "server")]
impl FromHandlerParams for InitializeRequestParams {
    #[inline]
    fn from_params(params: &HandlerParams) -> Result<Self, Error> {
        Self::from_request(Request::from_params(params)?)
    }
}
/// Lifts a single envelope into the corresponding [`Message`] variant.
impl From<MessageEnvelope> for Message {
    fn from(envelope: MessageEnvelope) -> Self {
        match envelope {
            MessageEnvelope::Notification(notification) => Self::Notification(notification),
            MessageEnvelope::Response(response) => Self::Response(response),
            MessageEnvelope::Request(request) => Self::Request(request),
        }
    }
}
impl Message {
    /// Returns `true` if this message is a request.
    #[inline]
    pub fn is_request(&self) -> bool {
        matches!(self, Self::Request(..))
    }

    /// Returns `true` if this message is a response.
    #[inline]
    pub fn is_response(&self) -> bool {
        matches!(self, Self::Response(..))
    }

    /// Returns `true` if this message is a notification.
    #[inline]
    pub fn is_notification(&self) -> bool {
        matches!(self, Self::Notification(..))
    }

    /// Returns `true` if this message is a batch.
    #[inline]
    pub fn is_batch(&self) -> bool {
        matches!(self, Self::Batch(..))
    }

    /// Returns the message id.
    ///
    /// Requests and responses report their own id. Notifications and batches report
    /// `RequestId::default()`.
    #[inline]
    pub fn id(&self) -> RequestId {
        match self {
            Self::Notification(..) | Self::Batch(..) => RequestId::default(),
            Self::Request(request) => request.id(),
            Self::Response(response) => response.id().clone(),
        }
    }

    /// Returns the full id of the message, delegating to the wrapped value's `full_id`.
    pub fn full_id(&self) -> RequestId {
        match self {
            Self::Batch(batch) => batch.full_id(),
            Self::Notification(notification) => notification.full_id(),
            Self::Response(response) => response.full_id(),
            Self::Request(request) => request.full_id(),
        }
    }

    /// Returns the session id associated with the message, if any.
    #[inline]
    pub fn session_id(&self) -> Option<&uuid::Uuid> {
        match self {
            Self::Batch(batch) => batch.session_id.as_ref(),
            Self::Notification(notification) => notification.session_id.as_ref(),
            Self::Response(response) => response.session_id(),
            Self::Request(request) => request.session_id.as_ref(),
        }
    }

    /// Associates the message with session `id` and returns the updated message.
    pub fn set_session_id(self, id: uuid::Uuid) -> Self {
        match self {
            Self::Request(mut request) => {
                request.session_id = Some(id);
                Self::Request(request)
            }
            Self::Notification(mut notification) => {
                notification.session_id = Some(id);
                Self::Notification(notification)
            }
            Self::Response(response) => Self::Response(response.set_session_id(id)),
            Self::Batch(mut batch) => {
                batch.session_id = Some(id);
                Self::Batch(batch)
            }
        }
    }

    /// Attaches `headers` to the message and returns the updated message.
    ///
    /// Notifications are returned unchanged and `headers` is discarded for them.
    #[cfg(feature = "http-server")]
    pub fn set_headers(self, headers: HeaderMap) -> Self {
        match self {
            Self::Request(mut request) => {
                request.headers = headers;
                Self::Request(request)
            }
            Self::Response(response) => Self::Response(response.set_headers(headers)),
            Self::Batch(mut batch) => {
                batch.headers = headers;
                Self::Batch(batch)
            }
            notification @ Self::Notification(_) => notification,
        }
    }

    /// Attaches `claims` to a request or batch and returns the updated message.
    ///
    /// Responses and notifications are returned unchanged and `claims` is discarded.
    #[cfg(feature = "http-server")]
    pub(crate) fn set_claims(self, claims: DefaultClaims) -> Self {
        match self {
            Self::Request(mut request) => {
                request.claims = Some(Box::new(claims));
                Self::Request(request)
            }
            Self::Batch(mut batch) => {
                batch.claims = Some(Box::new(claims));
                Self::Batch(batch)
            }
            untouched @ (Self::Response(_) | Self::Notification(_)) => untouched,
        }
    }
}
impl Annotations {
#[inline]
pub fn from_json_str(json: &str) -> Self {
serde_json::from_str(json).expect("Annotations: Incorrect JSON string provided")
}
pub fn with_audience<T: Into<Role>>(mut self, role: T) -> Self {
self.audience.push(role.into());
self
}
pub fn with_priority(mut self, priority: f32) -> Self {
self.priority = priority;
self
}
pub fn with_last_modified(mut self, last_modified: DateTime<Utc>) -> Self {
self.last_modified = Some(last_modified);
self
}
}
impl Implementation {
    /// Replaces the icon list with the given icons and returns the updated value.
    ///
    /// Any previously set icons are discarded. An empty iterator yields `Some(vec![])`,
    /// not `None`.
    #[inline]
    pub fn with_icons(self, icons: impl IntoIterator<Item = Icon>) -> Self {
        let collected = icons.into_iter().collect::<Vec<_>>();
        Self {
            icons: Some(collected),
            ..self
        }
    }
}
#[cfg(feature = "server")]
impl InitializeResult {
    /// Builds the initialize result advertised by a server configured with `options`.
    ///
    /// Tools, resources, prompts and (with the `tasks` feature) tasks capabilities come
    /// from `options`. Logging and completions are always advertised with their default
    /// settings. No experimental capabilities and no instructions are set.
    pub(crate) fn new(options: &McpOptions) -> Self {
        let protocol_ver: String = options.protocol_ver().into();
        let capabilities = ServerCapabilities {
            tools: options.tools_capability(),
            resources: options.resources_capability(),
            prompts: options.prompts_capability(),
            logging: Some(LoggingCapability::default()),
            completions: Some(CompletionsCapability::default()),
            #[cfg(feature = "tasks")]
            tasks: options.tasks_capability(),
            experimental: None,
        };
        let server_info = options.implementation.clone();

        Self {
            protocol_ver,
            capabilities,
            server_info,
            instructions: None,
        }
    }
}
// Unit tests for envelope/batch (de)serialization.
#[cfg(test)]
mod tests {
use super::*;
// A well-formed JSON-RPC request object deserializes as the `Request` variant.
#[test]
fn message_envelope_deserializes_request() {
let json = r#"{"jsonrpc":"2.0","id":1,"method":"ping","params":null}"#;
let envelope: MessageEnvelope = serde_json::from_str(json).unwrap();
assert!(matches!(envelope, MessageEnvelope::Request(_)));
}
// `MessageBatch::new` must refuse an empty list of envelopes.
#[test]
fn message_batch_rejects_empty_vec() {
let err = MessageBatch::new(vec![]);
assert!(err.is_err());
}
// Deserializing an empty JSON array as a batch must fail.
#[test]
fn message_batch_rejects_empty_json_array() {
let err: Result<MessageBatch, _> = serde_json::from_str("[]");
assert!(err.is_err());
}
// A one-element array deserializes into a batch of length 1.
#[test]
fn message_batch_accepts_non_empty() {
let json = r#"[{"jsonrpc":"2.0","id":1,"method":"ping","params":null}]"#;
let batch: MessageBatch = serde_json::from_str(json).unwrap();
assert_eq!(batch.len(), 1);
}
// A top-level JSON array deserializes as `Message::Batch`.
#[test]
fn message_deserializes_batch() {
let json = r#"[{"jsonrpc":"2.0","id":1,"method":"ping","params":null}]"#;
let msg: Message = serde_json::from_str(json).unwrap();
assert!(matches!(msg, Message::Batch(_)));
}
// A malformed element without an `id` becomes an error response with a null id,
// while the valid element next to it is preserved.
#[test]
fn message_batch_emits_null_error_for_malformed_item_without_id() {
let json = r#"[
{"jsonrpc":"2.0","id":1,"method":"ping","params":null},
{"not":"valid json-rpc"}
]"#;
let batch: MessageBatch = serde_json::from_str(json).unwrap();
assert_eq!(batch.len(), 2);
let mut iter = batch.iter();
assert!(matches!(iter.next(), Some(MessageEnvelope::Request(_))));
let second = iter.next().expect("second item should be present");
let MessageEnvelope::Response(resp) = second else {
panic!("expected error response for malformed item, got {second:?}");
};
assert!(matches!(resp, Response::Err(_)));
let serialized = serde_json::to_string(resp).unwrap();
assert!(
serialized.contains(r#""id":null"#),
"expected null id, got: {serialized}"
);
}
// A malformed element that still carries an `id` becomes an error response that
// echoes that id.
#[test]
fn message_batch_produces_error_response_for_malformed_item_with_id() {
let json = r#"[
{"jsonrpc":"2.0","id":1,"method":"ping","params":null},
{"jsonrpc":"2.0","id":2,"params":"not-an-object-and-no-method"}
]"#;
let batch: MessageBatch = serde_json::from_str(json).unwrap();
assert_eq!(batch.len(), 2);
let mut iter = batch.iter();
assert!(matches!(iter.next(), Some(MessageEnvelope::Request(_))));
let second = iter.next().expect("second item should be present");
let MessageEnvelope::Response(resp) = second else {
panic!("expected error response for malformed item, got {second:?}");
};
assert!(matches!(resp, Response::Err(_)));
let serialized = serde_json::to_string(resp).unwrap();
assert!(
serialized.contains(r#""id":2"#),
"expected id 2, got: {serialized}"
);
}
// When every element is malformed and has no `id`, the batch still deserializes and
// every element becomes an error response with a null id.
#[test]
fn message_batch_all_malformed_without_ids_produces_null_error_responses() {
let json = r#"[{"not":"valid"},{"also":"not valid"}]"#;
let batch: MessageBatch = serde_json::from_str(json).unwrap();
assert_eq!(batch.len(), 2);
for envelope in batch.iter() {
let MessageEnvelope::Response(resp) = envelope else {
panic!("expected error response, got {envelope:?}");
};
assert!(matches!(resp, Response::Err(_)));
let serialized = serde_json::to_string(resp).unwrap();
assert!(
serialized.contains(r#""id":null"#),
"expected null id, got: {serialized}"
);
}
}
}