use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use turbomcp_protocol::types::{CreateMessageRequest, CreateMessageResult};
/// Boxed, pinned, `Send` future returned by every async trait method in this module.
///
/// Resolves to `Ok(T)` on success, or to a boxed `dyn std::error::Error + Send + Sync`
/// on failure. The `'a` lifetime bounds whatever the future is allowed to borrow.
pub type BoxSamplingFuture<'a, T> =
Pin<Box<dyn Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>> + Send + 'a>>;
/// Handles a `CreateMessageRequest` (from `turbomcp_protocol`) and produces a
/// `CreateMessageResult`.
///
/// Implementors must be `Send + Sync + Debug`.
pub trait SamplingHandler: Send + Sync + std::fmt::Debug {
/// Handle one `create_message` request.
///
/// `request_id` is an identifier supplied by the caller. What it is used for is not
/// established here; `DelegatingSamplingHandler` ignores it. NOTE(review): confirm
/// the intended contract.
///
/// The returned future may borrow `self` (elided `'_`). `request` is passed by
/// value, so the future can own it.
fn handle_create_message(
&self,
request_id: String,
request: CreateMessageRequest,
) -> BoxSamplingFuture<'_, CreateMessageResult>;
}
/// `SamplingHandler` that gates each request through a `UserInteractionHandler`
/// and forwards approved requests to one of several `LLMServerClient`s.
#[derive(Debug)]
pub struct DelegatingSamplingHandler {
// Candidate LLM backends. `select_llm_client` currently always picks the first one.
llm_clients: Vec<Arc<dyn LLMServerClient>>,
// Approves requests before dispatch and may replace responses afterwards.
user_handler: Arc<dyn UserInteractionHandler>,
}
/// A backend that can service `CreateMessageRequest`s.
///
/// Presumably this wraps a connection to an LLM server. NOTE(review): confirm
/// against the concrete implementations.
pub trait LLMServerClient: Send + Sync + std::fmt::Debug {
/// Send `request` to the backend and await its result.
///
/// Takes the request by value, so callers that still need it must clone it first.
fn create_message(
&self,
request: CreateMessageRequest,
) -> BoxSamplingFuture<'_, CreateMessageResult>;
/// Fetch descriptive metadata about the backend (see `ServerInfo`).
fn get_server_info(&self) -> BoxSamplingFuture<'_, ServerInfo>;
}
/// Human-in-the-loop hooks consulted before and after an LLM call.
///
/// NOTE(review): the returned futures use an elided `'_`. With a `&self` receiver,
/// elision binds that lifetime to `self` only, so the futures cannot borrow
/// `request`/`response`. Implementors must copy anything they need from them
/// before building the future.
pub trait UserInteractionHandler: Send + Sync + std::fmt::Debug {
/// Decide whether `request` may be sent to an LLM.
///
/// `DelegatingSamplingHandler` treats `false` as a user cancellation.
fn approve_request(&self, request: &CreateMessageRequest) -> BoxSamplingFuture<'_, bool>;
/// Review `response` to `request`.
///
/// `DelegatingSamplingHandler` returns the `Some` value in place of `response`,
/// and keeps `response` unchanged on `None`.
fn approve_response(
&self,
request: &CreateMessageRequest,
response: &CreateMessageResult,
) -> BoxSamplingFuture<'_, Option<CreateMessageResult>>;
}
/// Metadata reported by `LLMServerClient::get_server_info`.
#[derive(Debug, Clone)]
pub struct ServerInfo {
// Server name; presumably human-readable. TODO confirm.
pub name: String,
// Model identifiers the server offers; format not established here.
pub models: Vec<String>,
// Capability labels; the vocabulary is not defined in this file.
pub capabilities: Vec<String>,
}
impl SamplingHandler for DelegatingSamplingHandler {
    /// Run one sampling round-trip in three steps:
    /// 1. request approval,
    /// 2. the LLM call,
    /// 3. response approval.
    ///
    /// `_request_id` is accepted for the trait contract but not used.
    ///
    /// # Errors
    ///
    /// - Returns `HandlerError::UserCancelled` when the user handler rejects the request.
    /// - Propagates any error from the user handler, from client selection, or from the
    ///   selected LLM client.
    fn handle_create_message(
        &self,
        _request_id: String,
        request: CreateMessageRequest,
    ) -> BoxSamplingFuture<'_, CreateMessageResult> {
        Box::pin(async move {
            let may_proceed = self.user_handler.approve_request(&request).await?;
            if !may_proceed {
                let cancelled: Box<dyn std::error::Error + Send + Sync> =
                    Box::new(crate::handlers::HandlerError::UserCancelled);
                return Err(cancelled);
            }

            let backend = self.select_llm_client(&request).await?;
            // `create_message` consumes its argument, but `request` is still needed
            // for the response-approval step, so the backend receives a copy.
            let llm_output = backend.create_message(request.clone()).await?;

            // `Some(edited)` replaces the LLM output; `None` keeps it unchanged.
            let override_result = self
                .user_handler
                .approve_response(&request, &llm_output)
                .await?;
            Ok(override_result.unwrap_or(llm_output))
        })
    }
}
impl DelegatingSamplingHandler {
    /// Build a handler that dispatches to `llm_clients`, gated by `user_handler`.
    ///
    /// An empty `llm_clients` is accepted here. The problem only surfaces as an error
    /// from `select_llm_client` when a request is handled.
    pub fn new(
        llm_clients: Vec<Arc<dyn LLMServerClient>>,
        user_handler: Arc<dyn UserInteractionHandler>,
    ) -> Self {
        Self { llm_clients, user_handler }
    }

    /// Choose the LLM client that should serve a request.
    ///
    /// The request is currently ignored: the first configured client always wins.
    ///
    /// # Errors
    ///
    /// Returns `HandlerError::Configuration` when no clients are configured.
    async fn select_llm_client(
        &self,
        _request: &CreateMessageRequest,
    ) -> Result<Arc<dyn LLMServerClient>, Box<dyn std::error::Error + Send + Sync>> {
        self.llm_clients.first().map(Arc::clone).ok_or_else(|| {
            Box::new(crate::handlers::HandlerError::Configuration {
                message: "No LLM servers configured".to_string(),
            }) as Box<dyn std::error::Error + Send + Sync>
        })
    }
}
/// `UserInteractionHandler` that approves every request and never alters responses.
///
/// It has no state and never prompts anyone.
#[derive(Debug)]
pub struct AutoApprovingUserHandler;
impl UserInteractionHandler for AutoApprovingUserHandler {
    /// Approve unconditionally. The future is already complete when returned.
    fn approve_request(&self, _request: &CreateMessageRequest) -> BoxSamplingFuture<'_, bool> {
        Box::pin(std::future::ready(Ok(true)))
    }

    /// Accept the response as-is; `None` means "no replacement".
    fn approve_response(
        &self,
        _request: &CreateMessageRequest,
        _response: &CreateMessageResult,
    ) -> BoxSamplingFuture<'_, Option<CreateMessageResult>> {
        Box::pin(std::future::ready(Ok(None)))
    }
}