use std::sync::Arc;
use async_trait::async_trait;
use crate::core::{DynTool, ReadonlyContext};
use crate::error::Result;
use crate::tools::Toolset;
use crate::mcp::client::McpClient;
use crate::mcp::http::McpHttpParams;
use crate::mcp::stdio::McpStdioParams;
use crate::mcp::tool::McpTool;
/// Selects which tools, identified by name, the policy applies to.
///
/// The default policy is [`ConfirmationPolicy::None`].
#[derive(Debug, Clone, Default)]
pub enum ConfirmationPolicy {
    /// Applies to no tool.
    #[default]
    None,
    /// Applies to every tool, whatever its name.
    All,
    /// Applies only to tools whose name is contained in the set.
    Named(std::collections::HashSet<String>),
}

impl ConfirmationPolicy {
    /// Returns `true` when this policy applies to the tool called `name`.
    fn applies_to(&self, name: &str) -> bool {
        // Only the `Named` variant needs to inspect the tool name.
        if let Self::Named(names) = self {
            return names.contains(name);
        }
        // Of the two remaining variants, `All` always applies and `None` never does.
        matches!(self, Self::All)
    }
}
/// A `Toolset` implementation whose tools are obtained from an `McpClient`.
pub struct McpToolset {
// Shared client used to list tools; a clone of this handle is passed to every `McpTool` created.
client: Arc<McpClient>,
// Policy consulted per tool name to decide the value passed to `with_require_confirmation`.
confirmation: ConfirmationPolicy,
// Tool list stored after a successful `list_tools` call and returned by later calls.
cached: tokio::sync::OnceCell<Vec<Arc<dyn DynTool>>>,
}
// Hand-written `Debug`: prints only the type name and marks the struct as
// non-exhaustive, so none of the fields are formatted.
impl std::fmt::Debug for McpToolset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("McpToolset");
        builder.finish_non_exhaustive()
    }
}
impl McpToolset {
    /// Wraps an existing client in a toolset.
    ///
    /// The toolset starts with the default [`ConfirmationPolicy`] and an empty
    /// tool cache.
    pub fn from_client(client: Arc<McpClient>) -> Self {
        Self {
            client,
            confirmation: ConfirmationPolicy::default(),
            cached: tokio::sync::OnceCell::new(),
        }
    }

    /// Replaces the confirmation policy and returns the updated toolset.
    ///
    /// Any previously cached tool list is discarded: cached tools had their
    /// confirmation flag fixed under the old policy, so keeping them would
    /// silently ignore the new one.
    #[must_use]
    pub fn with_confirmation_policy(mut self, policy: ConfirmationPolicy) -> Self {
        self.confirmation = policy;
        // Force the next `list_tools` call to rebuild the tools under `policy`.
        self.cached = tokio::sync::OnceCell::new();
        self
    }

    /// Creates a toolset backed by a client obtained from `McpClient::spawn`.
    ///
    /// # Errors
    /// Returns the error produced by `McpClient::spawn`.
    pub async fn stdio(params: McpStdioParams) -> Result<Self> {
        let client = McpClient::spawn(params).await?;
        Ok(Self::from_client(Arc::new(client)))
    }

    /// Creates a toolset backed by a client obtained from `McpClient::http`.
    ///
    /// # Errors
    /// Returns the error produced by `McpClient::http`.
    pub async fn http(params: McpHttpParams) -> Result<Self> {
        let client = McpClient::http(params).await?;
        Ok(Self::from_client(Arc::new(client)))
    }
}
#[async_trait]
impl Toolset for McpToolset {
    /// Returns the tools exposed by the client, fetching them on first use.
    ///
    /// The list is built at most once per toolset: callers that arrive while
    /// the first fetch is still running wait for it instead of issuing their
    /// own request, and every caller receives clones of the same cached tools.
    /// A failed fetch is not cached, so a later call tries again.
    ///
    /// # Errors
    /// Propagates the error returned by the client's `list_tools` call.
    async fn list_tools(&self, _ctx: &ReadonlyContext) -> Result<Vec<Arc<dyn DynTool>>> {
        let tools = self
            .cached
            .get_or_try_init(|| self.load_remote_tools())
            .await?;
        Ok(tools.clone())
    }
}

impl McpToolset {
    /// Fetches the tool descriptors from the client and wraps each in an `McpTool`.
    ///
    /// Each tool's confirmation flag is taken from the toolset's policy, looked
    /// up by the descriptor's name.
    ///
    /// # Errors
    /// Propagates the error returned by the client's `list_tools` call.
    async fn load_remote_tools(&self) -> Result<Vec<Arc<dyn DynTool>>> {
        let descs = self.client.list_tools().await?;
        Ok(descs
            .into_iter()
            .map(|d| {
                let confirm = self.confirmation.applies_to(&d.name);
                Arc::new(McpTool::new(d, self.client.clone()).with_require_confirmation(confirm))
                    as Arc<dyn DynTool>
            })
            .collect())
    }
}