use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use crate::{
protocol::Protocol,
tools::{ToolHandler, ToolHandlerFn, Tools},
types::{
CallToolRequest, ListRequest, ProtocolVersion, Tool, ToolsListResponse,
LATEST_PROTOCOL_VERSION,
},
};
use super::{
protocol::ProtocolBuilder,
transport::Transport,
types::{
ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
ServerCapabilities,
},
};
use anyhow::Result;
use std::pin::Pin;
/// Per-connection state the server records about its client.
///
/// A fresh builder starts with both optional fields set to `None` and
/// `initialized` set to `false`.
#[derive(Clone)]
pub struct ClientConnection {
    /// Capabilities taken from the client's `initialize` request; `None`
    /// until that request has been handled.
    pub client_capabilities: Option<ClientCapabilities>,
    /// Client implementation details taken from the `initialize` request;
    /// `None` until that request has been handled.
    pub client_info: Option<Implementation>,
    /// Set to `true` once the `notifications/initialized` notification has
    /// been handled.
    pub initialized: bool,
}
/// Namespace type for the server entry points.
///
/// It carries no data; its associated functions create a
/// `ServerProtocolBuilder` and open a transport.
#[derive(Debug, Clone)]
pub struct Server;
impl Server {
    /// Creates a [`ServerProtocolBuilder`] for a server called `name` at
    /// `version`, configured to use `protocol_version`.
    pub fn builder(
        name: String,
        version: String,
        protocol_version: ProtocolVersion,
    ) -> ServerProtocolBuilder {
        let builder = ServerProtocolBuilder::new(name, version);
        builder.set_protocol_version(protocol_version)
    }

    /// Opens `transport` and returns whatever its `open` call resolves to.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `transport.open()`.
    pub async fn start<T: Transport>(transport: T) -> Result<()> {
        transport.open().await
    }
}
/// Builder that collects server settings and tools, then produces a
/// [`Protocol`] with the server's request and notification handlers installed.
pub struct ServerProtocolBuilder {
    /// Protocol version reported in the `initialize` response.
    protocol_version: ProtocolVersion,
    /// Underlying protocol builder the handlers are registered on.
    protocol_builder: ProtocolBuilder,
    /// Server name and version returned in the `initialize` response.
    server_info: Implementation,
    /// Server capabilities returned in the `initialize` response.
    capabilities: ServerCapabilities,
    /// Optional instructions returned in the `initialize` response.
    instructions: Option<String>,
    /// Registered tool handlers, keyed by tool name.
    tools: HashMap<String, ToolHandler>,
    /// Client state shared between the builder and every handler it installs.
    client_connection: Arc<RwLock<ClientConnection>>,
}
impl ServerProtocolBuilder {
    /// Creates a builder for a server identified by `name` and `version`.
    ///
    /// The builder starts with `LATEST_PROTOCOL_VERSION`, default
    /// capabilities, no instructions, no tools, and a connection state that
    /// has no client data and is not initialized.
    pub fn new(name: String, version: String) -> Self {
        ServerProtocolBuilder {
            protocol_version: LATEST_PROTOCOL_VERSION,
            protocol_builder: ProtocolBuilder::new(),
            server_info: Implementation { name, version },
            capabilities: ServerCapabilities::default(),
            instructions: None,
            tools: HashMap::new(),
            client_connection: Arc::new(RwLock::new(ClientConnection {
                client_capabilities: None,
                client_info: None,
                initialized: false,
            })),
        }
    }

    /// Replaces the protocol version reported in the `initialize` response.
    pub fn set_protocol_version(mut self, protocol_version: ProtocolVersion) -> Self {
        self.protocol_version = protocol_version;
        self
    }

    /// Replaces the capabilities reported in the `initialize` response.
    pub fn set_capabilities(mut self, capabilities: ServerCapabilities) -> Self {
        self.capabilities = capabilities;
        self
    }

    /// Sets the instructions reported in the `initialize` response.
    pub fn set_instructions(mut self, instructions: String) -> Self {
        self.instructions = Some(instructions);
        self
    }

    /// Clears any instructions previously set on the builder.
    pub fn remove_instructions(mut self) -> Self {
        self.instructions = None;
        self
    }

    /// Registers `tool` with handler `f`, keyed by the tool's name.
    ///
    /// A tool registered earlier under the same name is replaced.
    pub fn register_tool(mut self, tool: Tool, f: ToolHandlerFn) -> Self {
        self.tools.insert(
            tool.name.clone(),
            ToolHandler {
                tool,
                f: Box::new(f),
            },
        );
        self
    }

    /// Builds the handler for the `initialize` request.
    ///
    /// Each call stores the client's capabilities and implementation info in
    /// `state` and answers with the configured protocol version, capabilities,
    /// server info and instructions.
    ///
    /// NOTE(review): only `capabilities` and `client_info` are read from the
    /// request; no version negotiation with the client happens here.
    ///
    /// The returned future fails with "Lock poisoned" when `state` cannot be
    /// locked for writing.
    fn handle_init(
        protocol_version: ProtocolVersion,
        state: Arc<RwLock<ClientConnection>>,
        server_info: Implementation,
        capabilities: ServerCapabilities,
        instructions: Option<String>,
    ) -> impl Fn(
        InitializeRequest,
    )
        -> Pin<Box<dyn std::future::Future<Output = Result<InitializeResponse>> + Send>> {
        move |req| {
            // The handler is `Fn`, so every call needs its own copies to move
            // into the returned future.
            let state = state.clone();
            let server_info = server_info.clone();
            let capabilities = capabilities.clone();
            let instructions = instructions.clone();
            let protocol_version = protocol_version.clone();
            Box::pin(async move {
                let mut state = state
                    .write()
                    .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
                state.client_capabilities = Some(req.capabilities);
                state.client_info = Some(req.client_info);
                Ok(InitializeResponse {
                    protocol_version: protocol_version.as_str().to_string(),
                    capabilities,
                    server_info,
                    instructions,
                })
            })
        }
    }

    /// Builds the handler for the `notifications/initialized` notification.
    ///
    /// Each call sets the `initialized` flag in `state`. The returned future
    /// fails with "Lock poisoned" when `state` cannot be locked for writing.
    fn handle_initialized(
        state: Arc<RwLock<ClientConnection>>,
    ) -> impl Fn(()) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>> {
        move |_| {
            let state = state.clone();
            Box::pin(async move {
                let mut state = state
                    .write()
                    .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
                state.initialized = true;
                Ok(())
            })
        }
    }

    /// Checks that the shared connection state has its `initialized` flag set.
    ///
    /// The read guard is released before this function returns, so callers
    /// never hold the lock across an `.await`.
    ///
    /// # Errors
    ///
    /// Fails with "Lock poisoned" when the lock cannot be read, and with
    /// "Client not initialized" when the flag is still `false`.
    fn ensure_initialized(state: &RwLock<ClientConnection>) -> Result<()> {
        let connection = state
            .read()
            .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
        if connection.initialized {
            Ok(())
        } else {
            Err(anyhow::anyhow!("Client not initialized"))
        }
    }

    /// Returns a copy of the recorded client capabilities.
    ///
    /// Yields `None` when nothing has been recorded or the lock is poisoned.
    pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
        self.client_connection
            .read()
            .ok()?
            .client_capabilities
            .clone()
    }

    /// Returns a copy of the recorded client implementation info.
    ///
    /// Yields `None` when nothing has been recorded or the lock is poisoned.
    pub fn get_client_info(&self) -> Option<Implementation> {
        self.client_connection.read().ok()?.client_info.clone()
    }

    /// Reports whether the `initialized` flag is set.
    ///
    /// A poisoned lock is reported as `false`.
    pub fn is_initialized(&self) -> bool {
        self.client_connection
            .read()
            .ok()
            .map(|client_connection| client_connection.initialized)
            .unwrap_or(false)
    }

    /// Consumes the builder and produces the [`Protocol`].
    ///
    /// Registers, in this order, the `initialize` request handler, the
    /// `notifications/initialized` notification handler, and the `tools/list`
    /// and `tools/call` request handlers. Both tool handlers fail with
    /// "Client not initialized" until the `initialized` flag has been set.
    pub fn build(self) -> Protocol {
        let Self {
            protocol_version,
            protocol_builder,
            server_info,
            capabilities,
            instructions,
            tools,
            client_connection,
        } = self;

        // One shared tool registry; each handler owns its own `Arc` handle.
        let tools_list = Arc::new(Tools::new(tools));
        let tools_call = Arc::clone(&tools_list);

        let conn_for_init = Arc::clone(&client_connection);
        let conn_for_list = Arc::clone(&client_connection);
        let conn_for_call = Arc::clone(&client_connection);

        protocol_builder
            .request_handler(
                "initialize",
                Self::handle_init(
                    protocol_version,
                    conn_for_init,
                    server_info,
                    capabilities,
                    instructions,
                ),
            )
            .notification_handler(
                "notifications/initialized",
                Self::handle_initialized(client_connection),
            )
            .request_handler("tools/list", move |_req: ListRequest| {
                let tools_list = tools_list.clone();
                let conn = conn_for_list.clone();
                Box::pin(async move {
                    // `map` keeps the error type concrete without relying on
                    // `?` inference inside the async block.
                    Self::ensure_initialized(&conn).map(|()| ToolsListResponse {
                        tools: tools_list.list_tools(),
                        next_cursor: None,
                        meta: None,
                    })
                })
            })
            .request_handler("tools/call", move |req: CallToolRequest| {
                let tools_call = tools_call.clone();
                let conn = conn_for_call.clone();
                Box::pin(async move {
                    match Self::ensure_initialized(&conn) {
                        Ok(()) => tools_call.call_tool(req).await,
                        Err(err) => Err(err),
                    }
                })
            })
            .build()
    }
}