use serde::{Deserialize, Serialize};
use serde_json::Value;
#[cfg(not(feature = "std"))]
use alloc::{collections::BTreeMap as HashMap, string::String, vec::Vec};
#[cfg(feature = "std")]
use std::collections::HashMap;
use crate::content::{Role, SamplingContent, SamplingContentBlock};
use crate::definitions::Tool;
/// Task-related parameters that a request can carry in its `task` field.
///
/// Used as the type of the optional `task` field on the elicitation and
/// sampling request types in this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TaskMetadata {
/// Optional `ttl` value; the key is omitted from serialized output when `None`.
/// NOTE(review): presumably a time-to-live; the unit is not visible here — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub ttl: Option<u64>,
}
/// Serializable record describing a task and its current state.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Task {
/// Task identifier; serialized as `taskId`.
#[serde(rename = "taskId")]
pub task_id: String,
/// Current status of the task.
pub status: TaskStatus,
/// Optional status text; serialized as `statusMessage` and omitted when `None`.
#[serde(rename = "statusMessage", skip_serializing_if = "Option::is_none")]
pub status_message: Option<String>,
/// Serialized as `createdAt`.
/// NOTE(review): presumably a timestamp string; no format is enforced here — confirm.
#[serde(rename = "createdAt")]
pub created_at: String,
/// Serialized as `lastUpdatedAt`; stored as a plain string like `created_at`.
#[serde(rename = "lastUpdatedAt")]
pub last_updated_at: String,
/// Has no `skip_serializing_if`, so unlike the other optional fields a `None`
/// here is still written out (as `null` in JSON) instead of being omitted.
pub ttl: Option<u64>,
/// Serialized as `pollInterval` and omitted when `None`.
/// NOTE(review): presumably a suggested polling delay; unit not visible here — confirm.
#[serde(rename = "pollInterval", skip_serializing_if = "Option::is_none")]
pub poll_interval: Option<u64>,
}
/// Status values a [`Task`] can report.
///
/// `rename_all = "snake_case"` makes the serialized form the snake_case
/// variant name, e.g. `InputRequired` is written as `"input_required"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
/// Serialized as `"cancelled"`.
Cancelled,
/// Serialized as `"completed"`.
Completed,
/// Serialized as `"failed"`.
Failed,
/// Serialized as `"input_required"`.
InputRequired,
/// Serialized as `"working"`.
Working,
}
/// Renders a [`TaskStatus`] as the same snake_case text used in its serialized form.
impl core::fmt::Display for TaskStatus {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Pick the label first so there is a single write at the end.
        let label = match self {
            Self::Cancelled => "cancelled",
            Self::Completed => "completed",
            Self::Failed => "failed",
            Self::InputRequired => "input_required",
            Self::Working => "working",
        };
        f.write_str(label)
    }
}
/// Result payload that wraps a single [`Task`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CreateTaskResult {
/// The task carried by this result.
pub task: Task,
/// Free-form metadata map; serialized as `_meta` and omitted when `None`.
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, Value>>,
}
/// Result payload carrying a list of [`Task`]s and an optional cursor.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ListTasksResult {
/// The tasks in this result.
pub tasks: Vec<Task>,
/// Serialized as `nextCursor` and omitted when `None`.
/// NOTE(review): presumably an opaque pagination token — confirm against the caller.
#[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
/// Free-form metadata map; serialized as `_meta` and omitted when `None`.
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, Value>>,
}
/// Metadata that refers to a task by its identifier.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RelatedTaskMetadata {
/// Identifier of the referenced task; serialized as `taskId`.
#[serde(rename = "taskId")]
pub task_id: String,
}
/// Parameters of an elicitation request, in one of two modes.
///
/// `Serialize` and `Deserialize` are implemented by hand elsewhere in this
/// file: the variant is carried on the wire by a `mode` key (`"form"` or
/// `"url"`) placed alongside the fields of the wrapped struct.
#[derive(Debug, Clone, PartialEq)]
pub enum ElicitRequestParams {
/// Form-mode request; see [`ElicitRequestFormParams`].
Form(ElicitRequestFormParams),
/// URL-mode request; see [`ElicitRequestURLParams`].
Url(ElicitRequestURLParams),
}
impl ElicitRequestParams {
    /// Builds a form-mode request from a message and a requested schema.
    ///
    /// `task` and `meta` are left unset.
    #[must_use]
    pub fn form(message: impl Into<String>, requested_schema: Value) -> Self {
        let params = ElicitRequestFormParams {
            message: message.into(),
            requested_schema,
            task: None,
            meta: None,
        };
        Self::Form(params)
    }

    /// Builds a URL-mode request from a message, a URL and an elicitation id.
    ///
    /// `task` and `meta` are left unset.
    #[must_use]
    pub fn url(
        message: impl Into<String>,
        url: impl Into<String>,
        elicitation_id: impl Into<String>,
    ) -> Self {
        let params = ElicitRequestURLParams {
            message: message.into(),
            url: url.into(),
            elicitation_id: elicitation_id.into(),
            task: None,
            meta: None,
        };
        Self::Url(params)
    }

    /// Returns the message text, whichever mode the request is in.
    #[must_use]
    pub fn message(&self) -> &str {
        match self {
            Self::Form(ElicitRequestFormParams { message, .. })
            | Self::Url(ElicitRequestURLParams { message, .. }) => message.as_str(),
        }
    }

    /// Returns the attached task metadata, if any, for either mode.
    #[must_use]
    pub fn task(&self) -> Option<&TaskMetadata> {
        match self {
            Self::Form(ElicitRequestFormParams { task, .. })
            | Self::Url(ElicitRequestURLParams { task, .. }) => task.as_ref(),
        }
    }

    /// Returns the `_meta` map, if any, for either mode.
    #[must_use]
    pub fn meta(&self) -> Option<&HashMap<String, Value>> {
        match self {
            Self::Form(ElicitRequestFormParams { meta, .. })
            | Self::Url(ElicitRequestURLParams { meta, .. }) => meta.as_ref(),
        }
    }
}
impl Serialize for ElicitRequestParams {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
match self {
Self::Form(params) => {
let mut value = serde_json::to_value(params).map_err(serde::ser::Error::custom)?;
if let Some(obj) = value.as_object_mut() {
obj.insert("mode".into(), Value::String("form".into()));
}
value.serialize(serializer)
}
Self::Url(params) => {
let mut value = serde_json::to_value(params).map_err(serde::ser::Error::custom)?;
if let Some(obj) = value.as_object_mut() {
obj.insert("mode".into(), Value::String("url".into()));
}
value.serialize(serializer)
}
}
}
}
impl<'de> Deserialize<'de> for ElicitRequestParams {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let value = Value::deserialize(deserializer)?;
let mode = value.get("mode").and_then(|v| v.as_str()).unwrap_or("form");
match mode {
"url" => {
let params: ElicitRequestURLParams =
serde_json::from_value(value).map_err(serde::de::Error::custom)?;
Ok(Self::Url(params))
}
_ => {
let params: ElicitRequestFormParams =
serde_json::from_value(value).map_err(serde::de::Error::custom)?;
Ok(Self::Form(params))
}
}
}
}
/// Fields of a form-mode elicitation request (see [`ElicitRequestParams::Form`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ElicitRequestFormParams {
/// Message text of the request.
pub message: String,
/// Arbitrary JSON value serialized as `requestedSchema`.
/// NOTE(review): presumably a JSON-Schema object; not validated here — confirm.
#[serde(rename = "requestedSchema")]
pub requested_schema: Value,
/// Optional task metadata; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub task: Option<TaskMetadata>,
/// Free-form metadata map; serialized as `_meta` and omitted when `None`.
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, Value>>,
}
/// Fields of a URL-mode elicitation request (see [`ElicitRequestParams::Url`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ElicitRequestURLParams {
/// Message text of the request.
pub message: String,
/// URL carried by the request; stored as a plain string, not parsed here.
pub url: String,
/// Identifier of this elicitation; serialized as `elicitationId`.
#[serde(rename = "elicitationId")]
pub elicitation_id: String,
/// Optional task metadata; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub task: Option<TaskMetadata>,
/// Free-form metadata map; serialized as `_meta` and omitted when `None`.
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, Value>>,
}
/// Result of an elicitation: the chosen action plus optional content.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ElicitResult {
/// The action reported for the elicitation.
pub action: ElicitAction,
/// Optional arbitrary JSON content; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<Value>,
/// Free-form metadata map; serialized as `_meta` and omitted when `None`.
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, Value>>,
}
/// Action reported in an [`ElicitResult`].
///
/// `rename_all = "lowercase"` serializes each variant as its lowercase name.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum ElicitAction {
/// Serialized as `"accept"`.
Accept,
/// Serialized as `"decline"`.
Decline,
/// Serialized as `"cancel"`.
Cancel,
}
/// Renders an [`ElicitAction`] as the same lowercase text used in its serialized form.
impl core::fmt::Display for ElicitAction {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match self {
            Self::Accept => "accept",
            Self::Decline => "decline",
            Self::Cancel => "cancel",
        };
        f.write_str(label)
    }
}
/// Notification payload that identifies an elicitation by id.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ElicitationCompleteNotification {
/// Identifier of the elicitation; serialized as `elicitationId`.
#[serde(rename = "elicitationId")]
pub elicitation_id: String,
}
/// Request parameters for creating a sampled message.
///
/// Every `Option` field is omitted from serialized output when `None`.
/// `Default` yields an empty message list, `max_tokens` of 0 and all
/// optional fields unset, which allows struct-update construction.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct CreateMessageRequest {
/// Messages of the request; deserializes to an empty list when the key is absent.
#[serde(default)]
pub messages: Vec<SamplingMessage>,
/// Serialized as `maxTokens`; has no serde default, so the key is required on input.
#[serde(rename = "maxTokens")]
pub max_tokens: u32,
/// Optional model preferences; serialized as `modelPreferences`.
#[serde(rename = "modelPreferences", skip_serializing_if = "Option::is_none")]
pub model_preferences: Option<ModelPreferences>,
/// Optional system prompt text; serialized as `systemPrompt`.
#[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
pub system_prompt: Option<String>,
/// Optional context-inclusion setting; serialized as `includeContext`.
#[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
pub include_context: Option<IncludeContext>,
/// Optional sampling temperature.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Optional stop sequences; serialized as `stopSequences`.
#[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
/// Optional task metadata.
#[serde(skip_serializing_if = "Option::is_none")]
pub task: Option<TaskMetadata>,
/// Optional list of tool definitions.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Optional tool-choice setting; serialized as `toolChoice`.
#[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Optional arbitrary JSON value serialized as `metadata` (distinct from `_meta`).
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<Value>,
/// Free-form metadata map; serialized as `_meta`.
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, Value>>,
}
/// A single message in a sampling conversation: a role plus its content.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct SamplingMessage {
/// Role the message is attributed to.
pub role: Role,
/// Message content; the tests in this file deserialize it from either a
/// single content object or an array of content objects.
pub content: SamplingContentBlock,
/// Free-form metadata map; serialized as `_meta` and omitted when `None`.
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, Value>>,
}
impl SamplingMessage {
#[must_use]
pub fn user(text: impl Into<String>) -> Self {
Self {
role: Role::User,
content: SamplingContent::text(text).into(),
meta: None,
}
}
#[must_use]
pub fn assistant(text: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: SamplingContent::text(text).into(),
meta: None,
}
}
}
/// Optional preferences that can accompany a [`CreateMessageRequest`].
///
/// Every field is omitted from serialized output when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ModelPreferences {
/// Optional list of model hints.
#[serde(skip_serializing_if = "Option::is_none")]
pub hints: Option<Vec<ModelHint>>,
/// Serialized as `costPriority`.
/// NOTE(review): the valid range of the priority values is not enforced here — confirm.
#[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
pub cost_priority: Option<f64>,
/// Serialized as `speedPriority`.
#[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
pub speed_priority: Option<f64>,
/// Serialized as `intelligencePriority`.
#[serde(
rename = "intelligencePriority",
skip_serializing_if = "Option::is_none"
)]
pub intelligence_priority: Option<f64>,
}
/// A single model hint, used in [`ModelPreferences::hints`].
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ModelHint {
/// Optional name; omitted from serialized output when `None`, so an empty
/// hint serializes as `{}`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
/// Renders an [`IncludeContext`] as the same camelCase text used in its serialized form.
impl core::fmt::Display for IncludeContext {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match self {
            Self::AllServers => "allServers",
            Self::ThisServer => "thisServer",
            Self::None => "none",
        };
        f.write_str(label)
    }
}
/// Context-inclusion setting used by [`CreateMessageRequest::include_context`].
///
/// Each variant carries an explicit serde rename giving its wire name.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum IncludeContext {
/// Serialized as `"allServers"`.
#[serde(rename = "allServers")]
AllServers,
/// Serialized as `"thisServer"`.
#[serde(rename = "thisServer")]
ThisServer,
/// Serialized as `"none"`. This is the enum's own variant, not `Option::None`.
#[serde(rename = "none")]
None,
}
/// Tool-choice setting used by [`CreateMessageRequest::tool_choice`].
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ToolChoice {
/// Optional mode; omitted from serialized output when `None`, so the default
/// value serializes as `{}`.
#[serde(skip_serializing_if = "Option::is_none")]
pub mode: Option<ToolChoiceMode>,
}
/// Mode value of a [`ToolChoice`].
///
/// `rename_all = "lowercase"` serializes each variant as its lowercase name.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceMode {
/// Serialized as `"auto"`.
Auto,
/// Serialized as `"none"`. This is the enum's own variant, not `Option::None`.
None,
/// Serialized as `"required"`.
Required,
}
/// Renders a [`ToolChoiceMode`] as the same lowercase text used in its serialized form.
impl core::fmt::Display for ToolChoiceMode {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let label = match self {
            Self::Auto => "auto",
            Self::None => "none",
            Self::Required => "required",
        };
        f.write_str(label)
    }
}
/// Result of a message-creation request: the produced message and its origin.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CreateMessageResult {
/// Role the produced message is attributed to.
pub role: Role,
/// Content of the produced message.
pub content: SamplingContentBlock,
/// Model name, stored as a plain string.
pub model: String,
/// Optional free-form stop reason; serialized as `stopReason` and omitted when `None`.
#[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<String>,
/// Free-form metadata map; serialized as `_meta` and omitted when `None`.
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, Value>>,
}
/// Capability set advertised for the client side.
///
/// Every field is optional and omitted from serialized output when `None`,
/// so the default value serializes as `{}`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ClientCapabilities {
/// Roots-related capabilities.
#[serde(skip_serializing_if = "Option::is_none")]
pub roots: Option<RootsCapabilities>,
/// Sampling-related capabilities.
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling: Option<SamplingCapabilities>,
/// Elicitation-related capabilities.
#[serde(skip_serializing_if = "Option::is_none")]
pub elicitation: Option<ElicitationCapabilities>,
/// Task-related capabilities.
#[serde(skip_serializing_if = "Option::is_none")]
pub tasks: Option<ClientTasksCapabilities>,
/// Free-form map of extension entries keyed by name.
#[serde(skip_serializing_if = "Option::is_none")]
pub extensions: Option<HashMap<String, Value>>,
/// Free-form map of experimental entries keyed by name.
#[serde(skip_serializing_if = "Option::is_none")]
pub experimental: Option<HashMap<String, Value>>,
}
/// Elicitation capabilities: which request modes are supported.
///
/// Every field is omitted from serialized output when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ElicitationCapabilities {
/// Present when form-mode elicitation is declared.
#[serde(skip_serializing_if = "Option::is_none")]
pub form: Option<ElicitationFormCapabilities>,
/// Present when URL-mode elicitation is declared.
#[serde(skip_serializing_if = "Option::is_none")]
pub url: Option<ElicitationUrlCapabilities>,
/// Optional flag serialized as `schemaValidation`.
#[serde(rename = "schemaValidation", skip_serializing_if = "Option::is_none")]
pub schema_validation: Option<bool>,
}
impl ElicitationCapabilities {
    /// Capabilities declaring both form and URL modes; `schema_validation` unset.
    #[must_use]
    pub fn full() -> Self {
        Self {
            url: Some(ElicitationUrlCapabilities {}),
            ..Self::form_only()
        }
    }

    /// Capabilities declaring form mode only; `schema_validation` unset.
    #[must_use]
    pub fn form_only() -> Self {
        Self {
            form: Some(ElicitationFormCapabilities {}),
            url: None,
            schema_validation: None,
        }
    }

    /// Whether form mode is supported.
    ///
    /// True when `form` is declared, and also when neither `form` nor `url`
    /// is declared (an empty capability object counts as form support).
    #[must_use]
    pub fn supports_form(&self) -> bool {
        // When `form` is `None`, the answer depends only on `url` being absent.
        self.form.is_some() || self.url.is_none()
    }

    /// Whether URL mode is supported, i.e. `url` is declared.
    #[must_use]
    pub fn supports_url(&self) -> bool {
        self.url.is_some()
    }

    /// Returns `self` with `schema_validation` set to `Some(true)`.
    #[must_use]
    pub fn with_schema_validation(self) -> Self {
        Self {
            schema_validation: Some(true),
            ..self
        }
    }

    /// Returns `self` with `schema_validation` set to `Some(false)`.
    #[must_use]
    pub fn without_schema_validation(self) -> Self {
        Self {
            schema_validation: Some(false),
            ..self
        }
    }
}
/// Field-less marker for form-mode elicitation; as a braced struct it
/// serializes as an empty object.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ElicitationFormCapabilities {}
/// Field-less marker for URL-mode elicitation; as a braced struct it
/// serializes as an empty object.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ElicitationUrlCapabilities {}
/// Sampling capabilities; both fields are free-form maps omitted when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct SamplingCapabilities {
/// Free-form map serialized under `context`.
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<HashMap<String, Value>>,
/// Free-form map serialized under `tools`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<HashMap<String, Value>>,
}
/// Roots capabilities.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct RootsCapabilities {
/// Optional flag serialized as `listChanged` and omitted when `None`.
/// NOTE(review): presumably signals list-change notifications — confirm.
#[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
/// Task capabilities on the client side; every field is omitted when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ClientTasksCapabilities {
/// Present when task listing is declared.
#[serde(skip_serializing_if = "Option::is_none")]
pub list: Option<TasksListCapabilities>,
/// Present when task cancellation is declared.
#[serde(skip_serializing_if = "Option::is_none")]
pub cancel: Option<TasksCancelCapabilities>,
/// Per-request-type task capabilities.
#[serde(skip_serializing_if = "Option::is_none")]
pub requests: Option<ClientTasksRequestsCapabilities>,
}
/// Request types listed under client task capabilities; fields are omitted when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ClientTasksRequestsCapabilities {
/// Task capabilities for sampling requests.
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling: Option<TasksSamplingCapabilities>,
/// Task capabilities for elicitation requests.
#[serde(skip_serializing_if = "Option::is_none")]
pub elicitation: Option<TasksElicitationCapabilities>,
}
/// Task capabilities for sampling requests.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TasksSamplingCapabilities {
/// Marker serialized as `createMessage` and omitted when `None`.
#[serde(rename = "createMessage", skip_serializing_if = "Option::is_none")]
pub create_message: Option<TasksSamplingCreateMessageCapabilities>,
}
/// Field-less marker used by [`TasksSamplingCapabilities::create_message`];
/// serializes as an empty object.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TasksSamplingCreateMessageCapabilities {}
/// Task capabilities for elicitation requests.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TasksElicitationCapabilities {
/// Marker serialized as `create` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub create: Option<TasksElicitationCreateCapabilities>,
}
/// Field-less marker used by [`TasksElicitationCapabilities::create`];
/// serializes as an empty object.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TasksElicitationCreateCapabilities {}
/// Capability set advertised for the server side.
///
/// Every field is optional and omitted from serialized output when `None`.
/// The struct has no `elicitation` or `sampling` field; those appear only on
/// [`ClientCapabilities`].
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ServerCapabilities {
/// Tool-related capabilities.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<ToolsCapabilities>,
/// Resource-related capabilities.
#[serde(skip_serializing_if = "Option::is_none")]
pub resources: Option<ResourcesCapabilities>,
/// Prompt-related capabilities.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompts: Option<PromptsCapabilities>,
/// Present when logging is declared.
#[serde(skip_serializing_if = "Option::is_none")]
pub logging: Option<LoggingCapabilities>,
/// Present when completions are declared.
#[serde(skip_serializing_if = "Option::is_none")]
pub completions: Option<CompletionCapabilities>,
/// Task-related capabilities.
#[serde(skip_serializing_if = "Option::is_none")]
pub tasks: Option<ServerTasksCapabilities>,
/// Free-form map of extension entries keyed by name.
#[serde(skip_serializing_if = "Option::is_none")]
pub extensions: Option<HashMap<String, Value>>,
/// Free-form map of experimental entries keyed by name.
#[serde(skip_serializing_if = "Option::is_none")]
pub experimental: Option<HashMap<String, Value>>,
}
/// Tool capabilities.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ToolsCapabilities {
/// Optional flag serialized as `listChanged` and omitted when `None`.
/// NOTE(review): presumably signals list-change notifications — confirm.
#[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
/// Resource capabilities; both flags are omitted from serialized output when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ResourcesCapabilities {
/// Optional flag serialized as `subscribe`.
/// NOTE(review): presumably indicates subscription support — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub subscribe: Option<bool>,
/// Optional flag serialized as `listChanged`.
#[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
/// Prompt capabilities.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct PromptsCapabilities {
/// Optional flag serialized as `listChanged` and omitted when `None`.
/// NOTE(review): presumably signals list-change notifications — confirm.
#[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
pub list_changed: Option<bool>,
}
/// Field-less marker used by [`ServerCapabilities::logging`]; serializes as
/// an empty object.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct LoggingCapabilities {}
/// Field-less marker used by [`ServerCapabilities::completions`]; serializes
/// as an empty object.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct CompletionCapabilities {}
/// Task capabilities on the server side; every field is omitted when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ServerTasksCapabilities {
/// Present when task listing is declared.
#[serde(skip_serializing_if = "Option::is_none")]
pub list: Option<TasksListCapabilities>,
/// Present when task cancellation is declared.
#[serde(skip_serializing_if = "Option::is_none")]
pub cancel: Option<TasksCancelCapabilities>,
/// Per-request-type task capabilities.
#[serde(skip_serializing_if = "Option::is_none")]
pub requests: Option<ServerTasksRequestsCapabilities>,
}
/// Request types listed under server task capabilities.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ServerTasksRequestsCapabilities {
/// Task capabilities for tool requests; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<TasksToolsCapabilities>,
}
/// Task capabilities for tool requests.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TasksToolsCapabilities {
/// Marker serialized as `call` and omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub call: Option<TasksToolsCallCapabilities>,
}
/// Field-less marker used by [`TasksToolsCapabilities::call`]; serializes as
/// an empty object.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TasksToolsCallCapabilities {}
/// Field-less marker for the `list` entry of the client and server task
/// capabilities; serializes as an empty object.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TasksListCapabilities {}
/// Field-less marker for the `cancel` entry of the client and server task
/// capabilities; serializes as an empty object.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct TasksCancelCapabilities {}
/// Serde round-trip and wire-format tests for the types in this file.
#[cfg(test)]
mod tests {
    use super::*;

    /// `IncludeContext` uses camelCase wire names.
    #[test]
    fn test_include_context_serde() {
        let json = serde_json::to_string(&IncludeContext::ThisServer).unwrap();
        assert_eq!(json, "\"thisServer\"");
        let json = serde_json::to_string(&IncludeContext::AllServers).unwrap();
        assert_eq!(json, "\"allServers\"");
        let json = serde_json::to_string(&IncludeContext::None).unwrap();
        assert_eq!(json, "\"none\"");
        let parsed: IncludeContext = serde_json::from_str("\"thisServer\"").unwrap();
        assert_eq!(parsed, IncludeContext::ThisServer);
    }

    /// An unset `mode` is omitted entirely.
    #[test]
    fn test_tool_choice_mode_optional() {
        let tc = ToolChoice { mode: None };
        let json = serde_json::to_string(&tc).unwrap();
        assert_eq!(json, "{}");
        let tc = ToolChoice {
            mode: Some(ToolChoiceMode::Required),
        };
        let json = serde_json::to_string(&tc).unwrap();
        assert!(json.contains("\"required\""));
    }

    /// An unset hint `name` is omitted entirely.
    #[test]
    fn test_model_hint_name_optional() {
        let hint = ModelHint { name: None };
        let json = serde_json::to_string(&hint).unwrap();
        assert_eq!(json, "{}");
        let hint = ModelHint {
            name: Some("claude".into()),
        };
        let json = serde_json::to_string(&hint).unwrap();
        assert!(json.contains("\"claude\""));
    }

    #[test]
    fn test_task_status_serde() {
        let json = serde_json::to_string(&TaskStatus::InputRequired).unwrap();
        assert_eq!(json, "\"input_required\"");
        let json = serde_json::to_string(&TaskStatus::Working).unwrap();
        assert_eq!(json, "\"working\"");
    }

    /// `Default` allows struct-update construction of a request.
    #[test]
    fn test_create_message_request_default() {
        let req = CreateMessageRequest {
            messages: vec![SamplingMessage::user("hello")],
            max_tokens: 100,
            ..Default::default()
        };
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.max_tokens, 100);
        assert!(req.tools.is_none());
    }

    /// Message content deserializes from a single object or from an array.
    #[test]
    fn test_sampling_message_content_single_or_array() {
        let msg = SamplingMessage::user("hello");
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"text\":\"hello\""));
        let parsed: SamplingMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.content.as_text(), Some("hello"));
        let json_array = r#"{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]}"#;
        let parsed: SamplingMessage = serde_json::from_str(json_array).unwrap();
        match &parsed.content {
            SamplingContentBlock::Multiple(v) => assert_eq!(v.len(), 2),
            _ => panic!("Expected multiple content blocks"),
        }
    }

    /// Nested task capabilities serialize as nested objects.
    #[test]
    fn test_server_capabilities_structure() {
        let caps = ServerCapabilities {
            tasks: Some(ServerTasksCapabilities {
                list: Some(TasksListCapabilities {}),
                cancel: Some(TasksCancelCapabilities {}),
                requests: Some(ServerTasksRequestsCapabilities {
                    tools: Some(TasksToolsCapabilities {
                        call: Some(TasksToolsCallCapabilities {}),
                    }),
                }),
            }),
            extensions: Some(HashMap::from([(
                "trace".to_string(),
                serde_json::json!({"version": "1"}),
            )])),
            ..Default::default()
        };
        let json = serde_json::to_string(&caps).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert!(v["tasks"]["requests"]["tools"]["call"].is_object());
        assert!(v["extensions"]["trace"].is_object());
    }

    #[test]
    fn test_elicit_action_serde() {
        let cases = [
            (ElicitAction::Accept, "\"accept\""),
            (ElicitAction::Decline, "\"decline\""),
            (ElicitAction::Cancel, "\"cancel\""),
        ];
        for (action, expected) in cases {
            let json = serde_json::to_string(&action).unwrap();
            assert_eq!(json, expected);
            let parsed: ElicitAction = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, action);
        }
    }

    /// `content` survives a round trip and is omitted when `None`.
    #[test]
    fn test_elicit_result_round_trip() {
        let result = ElicitResult {
            action: ElicitAction::Accept,
            content: Some(serde_json::json!({"name": "test"})),
            meta: None,
        };
        let json = serde_json::to_string(&result).unwrap();
        let parsed: ElicitResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.action, ElicitAction::Accept);
        assert!(parsed.content.is_some());
        let decline = ElicitResult {
            action: ElicitAction::Decline,
            content: None,
            meta: None,
        };
        let json = serde_json::to_string(&decline).unwrap();
        assert!(!json.contains("\"content\""));
        let parsed: ElicitResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.action, ElicitAction::Decline);
        assert!(parsed.content.is_none());
    }

    /// Server capabilities never emit client-only `elicitation`/`sampling` keys.
    #[test]
    fn test_server_capabilities_no_elicitation_or_sampling() {
        let caps = ServerCapabilities::default();
        let json = serde_json::to_string(&caps).unwrap();
        assert!(!json.contains("elicitation"));
        assert!(!json.contains("sampling"));
        let caps = ServerCapabilities {
            tools: Some(ToolsCapabilities {
                list_changed: Some(true),
            }),
            resources: Some(ResourcesCapabilities {
                subscribe: Some(true),
                list_changed: Some(true),
            }),
            prompts: Some(PromptsCapabilities {
                list_changed: Some(true),
            }),
            logging: Some(LoggingCapabilities {}),
            completions: Some(CompletionCapabilities {}),
            tasks: Some(ServerTasksCapabilities::default()),
            extensions: Some(HashMap::from([(
                "trace".to_string(),
                serde_json::json!({"version": "1"}),
            )])),
            experimental: Some(HashMap::new()),
        };
        let json = serde_json::to_string(&caps).unwrap();
        assert!(!json.contains("elicitation"));
        assert!(!json.contains("sampling"));
        assert!(json.contains("extensions"));
    }

    /// Array content stays an array after a round trip.
    #[test]
    fn test_sampling_message_array_content_round_trip() {
        let json_array =
            r#"{"role":"user","content":[{"type":"text","text":"a"},{"type":"text","text":"b"}]}"#;
        let parsed: SamplingMessage = serde_json::from_str(json_array).unwrap();
        let re_serialized = serde_json::to_string(&parsed).unwrap();
        let re_parsed: Value = serde_json::from_str(&re_serialized).unwrap();
        assert!(re_parsed["content"].is_array());
        assert_eq!(re_parsed["content"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_tool_choice_mode_all_variants() {
        let cases = [
            (ToolChoiceMode::Auto, "\"auto\""),
            (ToolChoiceMode::None, "\"none\""),
            (ToolChoiceMode::Required, "\"required\""),
        ];
        for (mode, expected) in cases {
            let json = serde_json::to_string(&mode).unwrap();
            assert_eq!(json, expected);
            let parsed: ToolChoiceMode = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, mode);
        }
    }

    /// A payload without `mode` is parsed as form mode.
    #[test]
    fn test_elicit_request_params_form_without_mode() {
        let json = r#"{"message":"Enter name","requestedSchema":{"type":"object"}}"#;
        let parsed: ElicitRequestParams = serde_json::from_str(json).unwrap();
        match &parsed {
            ElicitRequestParams::Form(params) => {
                assert_eq!(params.message, "Enter name");
            }
            ElicitRequestParams::Url(_) => panic!("expected Form variant"),
        }
    }

    #[test]
    fn test_elicit_request_params_form_with_explicit_mode() {
        let json = r#"{"mode":"form","message":"Enter name","requestedSchema":{"type":"object"}}"#;
        let parsed: ElicitRequestParams = serde_json::from_str(json).unwrap();
        match &parsed {
            ElicitRequestParams::Form(params) => {
                assert_eq!(params.message, "Enter name");
            }
            ElicitRequestParams::Url(_) => panic!("expected Form variant"),
        }
    }

    #[test]
    fn test_elicit_request_params_url_mode() {
        let json = r#"{"mode":"url","message":"Authenticate","url":"https://example.com/auth","elicitationId":"e-123"}"#;
        let parsed: ElicitRequestParams = serde_json::from_str(json).unwrap();
        match &parsed {
            ElicitRequestParams::Url(params) => {
                assert_eq!(params.message, "Authenticate");
                assert_eq!(params.url, "https://example.com/auth");
                assert_eq!(params.elicitation_id, "e-123");
            }
            ElicitRequestParams::Form(_) => panic!("expected Url variant"),
        }
    }

    /// Serialization adds `mode: "form"` and the value round-trips unchanged.
    #[test]
    fn test_elicit_request_params_form_round_trip() {
        let params = ElicitRequestParams::Form(ElicitRequestFormParams {
            message: "Enter details".into(),
            requested_schema: serde_json::json!({"type": "object", "properties": {"name": {"type": "string"}}}),
            task: None,
            meta: None,
        });
        let json = serde_json::to_string(&params).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["mode"], "form");
        let parsed: ElicitRequestParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    /// Serialization adds `mode: "url"` and the value round-trips unchanged.
    #[test]
    fn test_elicit_request_params_url_round_trip() {
        let params = ElicitRequestParams::Url(ElicitRequestURLParams {
            message: "Please authenticate".into(),
            url: "https://example.com/oauth".into(),
            elicitation_id: "elicit-456".into(),
            task: None,
            meta: None,
        });
        let json = serde_json::to_string(&params).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["mode"], "url");
        let parsed: ElicitRequestParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn test_task_status_all_variants() {
        let cases = [
            (TaskStatus::Cancelled, "\"cancelled\""),
            (TaskStatus::Completed, "\"completed\""),
            (TaskStatus::Failed, "\"failed\""),
            (TaskStatus::InputRequired, "\"input_required\""),
            (TaskStatus::Working, "\"working\""),
        ];
        for (status, expected) in cases {
            let json = serde_json::to_string(&status).unwrap();
            assert_eq!(json, expected);
            let parsed: TaskStatus = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, status);
        }
    }
}