use std::collections::HashMap;
use async_trait::async_trait;
use serde::Serialize;
use tracing::info;
use crate::{
Error, Result,
context::{ClientCtx, ServerCtx},
error::ToolError,
schema::{
self, CallToolResult, ClientRequest, Cursor, ElicitRequestParams, ElicitResult,
GetPromptResult, InitializeResult, ListPromptsResult, ListResourceTemplatesResult,
ListResourcesResult, ListRootsResult, ListTasksResult, ListToolsResult, LoggingLevel,
ReadResourceResult, ServerRequest,
},
};
#[async_trait]
pub trait ClientHandler: Send + Sync {
async fn on_connect(&self, _context: &ClientCtx) -> Result<()> {
Ok(())
}
async fn on_shutdown(&self, _context: &ClientCtx) -> Result<()> {
Ok(())
}
async fn pong(&self, _context: &ClientCtx) -> Result<()> {
Ok(())
}
async fn create_message(
&self,
_context: &ClientCtx,
_method: &str,
_params: schema::CreateMessageParams,
) -> Result<schema::CreateMessageResult> {
Err(Error::InvalidRequest(
"create_message not implemented".into(),
))
}
async fn list_roots(&self, _context: &ClientCtx) -> Result<schema::ListRootsResult> {
Err(Error::InvalidRequest("list_roots not implemented".into()))
}
async fn elicit(
&self,
_context: &ClientCtx,
_params: ElicitRequestParams,
) -> Result<ElicitResult> {
Err(Error::InvalidRequest("elicit not implemented".into()))
}
async fn get_task(
&self,
_context: &ClientCtx,
_task_id: String,
) -> Result<schema::GetTaskResult> {
Err(Error::InvalidRequest("get_task not implemented".into()))
}
async fn get_task_payload(
&self,
_context: &ClientCtx,
_task_id: String,
) -> Result<schema::GetTaskPayloadResult> {
Err(Error::InvalidRequest(
"get_task_payload not implemented".into(),
))
}
async fn list_tasks(
&self,
_context: &ClientCtx,
_cursor: Option<Cursor>,
) -> Result<ListTasksResult> {
Ok(ListTasksResult::default())
}
async fn cancel_task(
&self,
_context: &ClientCtx,
_task_id: String,
) -> Result<schema::CancelTaskResult> {
Err(Error::InvalidRequest("cancel_task not implemented".into()))
}
async fn notification(
&self,
_context: &ClientCtx,
_notification: schema::ServerNotification,
) -> Result<()> {
Ok(())
}
async fn handle_request(
&self,
context: &ClientCtx,
request: ServerRequest,
method: &str,
) -> Result<serde_json::Value> {
match request {
ServerRequest::Ping { _meta: _ } => {
info!("Server sent ping request, sending pong");
empty_result(self.pong(context).await)
}
ServerRequest::CreateMessage(params) => {
serialize_result(self.create_message(context, method, *params).await)
}
ServerRequest::ListRoots { _meta: _ } => {
serialize_result(self.list_roots(context).await)
}
ServerRequest::Elicit(params) => serialize_result(self.elicit(context, *params).await),
ServerRequest::GetTask { task_id, _meta: _ } => {
serialize_result(self.get_task(context, task_id).await)
}
ServerRequest::GetTaskPayload { task_id, _meta: _ } => {
serialize_result(self.get_task_payload(context, task_id).await)
}
ServerRequest::ListTasks { cursor, _meta: _ } => {
serialize_result(self.list_tasks(context, cursor).await)
}
ServerRequest::CancelTask { task_id, _meta: _ } => {
serialize_result(self.cancel_task(context, task_id).await)
}
}
}
}
#[async_trait]
pub trait ServerHandler: Send + Sync {
async fn on_connect(&self, _context: &ServerCtx, _remote_addr: &str) -> Result<()> {
Ok(())
}
async fn on_shutdown(&self) -> Result<()> {
Ok(())
}
async fn initialize(
&self,
_context: &ServerCtx,
_protocol_version: String,
_capabilities: schema::ClientCapabilities,
_client_info: schema::Implementation,
) -> Result<InitializeResult>;
async fn pong(&self, _context: &ServerCtx) -> Result<()> {
Ok(())
}
async fn list_tools(
&self,
_context: &ServerCtx,
_cursor: Option<Cursor>,
) -> Result<ListToolsResult> {
Ok(ListToolsResult::default())
}
async fn call_tool(
&self,
_context: &ServerCtx,
name: String,
_arguments: Option<crate::Arguments>,
_task: Option<schema::TaskMetadata>,
) -> Result<schema::CallToolResult> {
Err(Error::ToolNotFound(name))
}
async fn list_resources(
&self,
_context: &ServerCtx,
_cursor: Option<Cursor>,
) -> Result<ListResourcesResult> {
Ok(ListResourcesResult::new())
}
async fn list_resource_templates(
&self,
_context: &ServerCtx,
_cursor: Option<Cursor>,
) -> Result<ListResourceTemplatesResult> {
Ok(ListResourceTemplatesResult {
resource_templates: vec![],
next_cursor: None,
})
}
async fn read_resource(&self, _context: &ServerCtx, uri: String) -> Result<ReadResourceResult> {
Err(Error::ResourceNotFound { uri })
}
async fn resources_subscribe(&self, _context: &ServerCtx, _uri: String) -> Result<()> {
Ok(())
}
async fn resources_unsubscribe(&self, _context: &ServerCtx, _uri: String) -> Result<()> {
Ok(())
}
async fn list_prompts(
&self,
_context: &ServerCtx,
_cursor: Option<Cursor>,
) -> Result<ListPromptsResult> {
Ok(ListPromptsResult::new())
}
async fn get_prompt(
&self,
_context: &ServerCtx,
name: String,
_arguments: Option<HashMap<String, String>>,
) -> Result<GetPromptResult> {
Err(Error::PromptNotFound(name))
}
async fn complete(
&self,
_context: &ServerCtx,
_reference: schema::Reference,
_argument: schema::ArgumentInfo,
_context_info: Option<schema::CompleteContext>,
) -> Result<schema::CompleteResult> {
Ok(schema::CompleteResult {
completion: schema::CompletionInfo {
values: vec![],
total: None,
has_more: None,
},
_meta: None,
})
}
async fn set_level(&self, _context: &ServerCtx, _level: LoggingLevel) -> Result<()> {
Ok(())
}
async fn list_roots(&self, _context: &ServerCtx) -> Result<ListRootsResult> {
Ok(ListRootsResult {
roots: vec![],
_meta: None,
})
}
async fn create_message(
&self,
_context: &ServerCtx,
_params: schema::CreateMessageParams,
) -> Result<schema::CreateMessageResult> {
Err(Error::MethodNotFound("sampling/createMessage".to_string()))
}
async fn get_task(
&self,
_context: &ServerCtx,
_task_id: String,
) -> Result<schema::GetTaskResult> {
Err(Error::InvalidRequest("get_task not implemented".into()))
}
async fn get_task_payload(
&self,
_context: &ServerCtx,
_task_id: String,
) -> Result<schema::GetTaskPayloadResult> {
Err(Error::InvalidRequest(
"get_task_payload not implemented".into(),
))
}
async fn list_tasks(
&self,
_context: &ServerCtx,
_cursor: Option<Cursor>,
) -> Result<ListTasksResult> {
Ok(ListTasksResult::default())
}
async fn cancel_task(
&self,
_context: &ServerCtx,
_task_id: String,
) -> Result<schema::CancelTaskResult> {
Err(Error::InvalidRequest("cancel_task not implemented".into()))
}
async fn notification(
&self,
_context: &ServerCtx,
_notification: schema::ClientNotification,
) -> Result<()> {
Ok(())
}
async fn handle_request(
&self,
context: &ServerCtx,
request: ClientRequest,
) -> Result<serde_json::Value> {
match request {
ClientRequest::Initialize {
protocol_version,
capabilities,
client_info,
_meta: _,
} => serialize_result(
self.initialize(context, protocol_version, *capabilities, client_info)
.await,
),
ClientRequest::Ping { .. } => {
info!("Server received ping request, sending automatic response");
empty_result(self.pong(context).await)
}
ClientRequest::ListTools { cursor, _meta: _ } => {
serialize_result(self.list_tools(context, cursor).await)
}
ClientRequest::CallTool {
name,
arguments,
task,
_meta: _,
} => {
let result = self.call_tool(context, name, arguments, task).await;
match result {
Ok(result) => serialize_result(Ok(result)),
Err(Error::InvalidParams(message)) => {
let tool_result: CallToolResult = ToolError::invalid_input(message).into();
serialize_result(Ok(tool_result))
}
Err(err) => Err(err),
}
}
ClientRequest::ListResources { cursor, _meta: _ } => {
serialize_result(self.list_resources(context, cursor).await)
}
ClientRequest::ListResourceTemplates { cursor, _meta: _ } => {
serialize_result(self.list_resource_templates(context, cursor).await)
}
ClientRequest::ReadResource { uri, _meta: _ } => {
serialize_result(self.read_resource(context, uri).await)
}
ClientRequest::Subscribe { uri, _meta: _ } => {
empty_result(self.resources_subscribe(context, uri).await)
}
ClientRequest::Unsubscribe { uri, _meta: _ } => {
empty_result(self.resources_unsubscribe(context, uri).await)
}
ClientRequest::ListPrompts { cursor, _meta: _ } => {
serialize_result(self.list_prompts(context, cursor).await)
}
ClientRequest::GetPrompt {
name,
arguments,
_meta: _,
} => serialize_result(self.get_prompt(context, name, arguments).await),
ClientRequest::Complete {
reference,
argument,
context: completion_context,
_meta: _,
} => serialize_result(
self.complete(context, reference, argument, completion_context)
.await,
),
ClientRequest::SetLevel { level, _meta: _ } => {
empty_result(self.set_level(context, level).await)
}
ClientRequest::GetTask { task_id, _meta: _ } => {
serialize_result(self.get_task(context, task_id).await)
}
ClientRequest::GetTaskPayload { task_id, _meta: _ } => {
serialize_result(self.get_task_payload(context, task_id).await)
}
ClientRequest::ListTasks { cursor, _meta: _ } => {
serialize_result(self.list_tasks(context, cursor).await)
}
ClientRequest::CancelTask { task_id, _meta: _ } => {
serialize_result(self.cancel_task(context, task_id).await)
}
}
}
}
/// Converts a handler outcome into a JSON value.
///
/// An `Ok` payload is serialized with `serde_json::to_value`, and a
/// serialization failure is converted into the crate error type. An `Err`
/// is returned unchanged.
fn serialize_result<T>(result: Result<T>) -> Result<serde_json::Value>
where
    T: Serialize,
{
    match result {
        Ok(payload) => serde_json::to_value(payload).map_err(Into::into),
        Err(err) => Err(err),
    }
}
/// Converts a unit handler outcome into a JSON value.
///
/// Success becomes an empty JSON object (`{}`); an `Err` is returned
/// unchanged.
fn empty_result(result: Result<()>) -> Result<serde_json::Value> {
    result?;
    Ok(serde_json::json!({}))
}