use crate::PROTOCOL_VERSIONS;
use crate::client::notification_handler::NotificationsHandler;
use crate::transport::{StdIoClient, TransportProto, stdio::options::StdIoOptions};
use crate::types::elicitation::ElicitationHandler;
use crate::types::sampling::SamplingHandler;
use crate::types::{
ElicitationCapability, Implementation, Root, RootsCapability, SamplingCapability, Uri,
};
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "tasks")]
use crate::types::ClientTasksCapability;
#[cfg(feature = "http-client")]
use crate::transport::http::HttpClient;
/// Timeout, in whole seconds, that [`McpOptions::default`] starts with.
///
/// Override it per instance with `McpOptions::with_timeout`.
const DEFAULT_REQUEST_TIMEOUT: u64 = 10;
/// Configuration for an MCP client: its implementation info, timeout,
/// advertised capabilities, request handlers, protocol version override,
/// transport, and registered roots.
///
/// Configured through the builder-style `with_*` methods and the `add_*`
/// methods in `impl McpOptions`.
///
/// NOTE(review): field declaration order also determines drop order (e.g. the
/// transport relative to the handlers); keep it stable unless that is intended.
pub struct McpOptions {
/// Implementation info; `name` and `version` are set by `with_name` / `with_version`.
pub(crate) implementation: Implementation,
/// Timeout; `Default` sets it to `DEFAULT_REQUEST_TIMEOUT` seconds, `with_timeout` overrides it.
pub(super) timeout: Duration,
/// Roots capability configured explicitly via `with_roots`, if any.
pub(super) roots_capability: Option<RootsCapability>,
/// Sampling capability configured explicitly via `with_sampling`, if any.
pub(super) sampling_capability: Option<SamplingCapability>,
/// Elicitation capability configured explicitly via `with_elicitation`, if any.
pub(super) elicitation_capability: Option<ElicitationCapability>,
/// Tasks capability configured explicitly via `with_tasks`, if any.
#[cfg(feature = "tasks")]
pub(super) tasks_capability: Option<ClientTasksCapability>,
/// Handler for sampling requests, installed via `add_sampling_handler`.
pub(super) sampling_handler: Option<SamplingHandler>,
/// Handler for elicitation requests, installed via `add_elicitation_handler`.
pub(super) elicitation_handler: Option<ElicitationHandler>,
/// Shared notification handler; not set in this file (presumably installed elsewhere in the crate).
pub(super) notification_handler: Option<Arc<NotificationsHandler>>,
/// Protocol version override; when `None`, `protocol_ver()` falls back to the last entry of `PROTOCOL_VERSIONS`.
protocol_ver: Option<&'static str>,
/// Configured transport; moved out (leaving `None`) by `transport()`.
proto: Option<TransportProto>,
/// Registered roots, keyed by their URI.
roots: HashMap<Uri, Root>,
}
impl Debug for McpOptions {
    /// Writes the configuration part of the options as a debug struct.
    ///
    /// The transport and the sampling/elicitation/notification handlers are
    /// deliberately left out of the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("McpOptions");

        builder.field("implementation", &self.implementation);
        builder.field("timeout", &self.timeout);
        builder.field("roots_capability", &self.roots_capability);
        builder.field("elicitation_capability", &self.elicitation_capability);
        builder.field("sampling_capability", &self.sampling_capability);
        builder.field("protocol_ver", &self.protocol_ver);
        builder.field("roots", &self.roots);

        // Only present when the `tasks` feature is compiled in.
        #[cfg(feature = "tasks")]
        builder.field("tasks_capability", &self.tasks_capability);

        builder.finish()
    }
}
impl Default for McpOptions {
#[inline]
fn default() -> Self {
Self {
timeout: Duration::from_secs(DEFAULT_REQUEST_TIMEOUT),
implementation: Default::default(),
roots: Default::default(),
roots_capability: None,
sampling_capability: None,
elicitation_capability: None,
#[cfg(feature = "tasks")]
tasks_capability: None,
proto: None,
protocol_ver: None,
sampling_handler: None,
elicitation_handler: None,
notification_handler: None,
}
}
}
impl McpOptions {
    /// Selects the stdio transport.
    ///
    /// `command` and `args` are forwarded to [`StdIoOptions::new`]. The result
    /// is wrapped in a [`StdIoClient`]. Any previously configured transport is
    /// replaced.
    pub fn with_stdio<T>(mut self, command: &'static str, args: T) -> Self
    where
        T: IntoIterator<Item = &'static str>,
    {
        let options = StdIoOptions::new(command, args);
        self.proto = Some(TransportProto::StdioClient(StdIoClient::new(options)));
        self
    }

    /// Selects the HTTP transport.
    ///
    /// `config` receives `HttpClient::default()` and returns the client to
    /// use. Any previously configured transport is replaced.
    #[cfg(feature = "http-client")]
    pub fn with_http<F: FnOnce(HttpClient) -> HttpClient>(mut self, config: F) -> Self {
        self.proto = Some(TransportProto::HttpClient(config(HttpClient::default())));
        self
    }

    /// Selects the HTTP transport with `HttpClient::default()` unchanged.
    #[cfg(feature = "http-client")]
    pub fn with_default_http(self) -> Self {
        self.with_http(|http| http)
    }

    /// Sets the name reported in the client's [`Implementation`] info.
    pub fn with_name(mut self, name: &str) -> Self {
        self.implementation.name = name.into();
        self
    }

    /// Sets the version reported in the client's [`Implementation`] info.
    pub fn with_version(mut self, ver: &str) -> Self {
        self.implementation.version = ver.into();
        self
    }

    /// Overrides the MCP protocol version returned by `protocol_ver()`.
    pub fn with_mcp_version(mut self, ver: &'static str) -> Self {
        self.protocol_ver = Some(ver);
        self
    }

    /// Explicitly configures the roots capability.
    ///
    /// `config` receives `RootsCapability::default()` and returns the
    /// capability to advertise.
    pub fn with_roots<T>(mut self, config: T) -> Self
    where
        T: FnOnce(RootsCapability) -> RootsCapability,
    {
        self.roots_capability = Some(config(Default::default()));
        self
    }

    /// Explicitly configures the sampling capability.
    ///
    /// `config` receives `SamplingCapability::default()` and returns the
    /// capability to advertise.
    pub fn with_sampling<T>(mut self, config: T) -> Self
    where
        T: FnOnce(SamplingCapability) -> SamplingCapability,
    {
        self.sampling_capability = Some(config(Default::default()));
        self
    }

    /// Explicitly configures the elicitation capability.
    ///
    /// `config` receives `ElicitationCapability::default()` and returns the
    /// capability to advertise.
    pub fn with_elicitation<T>(mut self, config: T) -> Self
    where
        T: FnOnce(ElicitationCapability) -> ElicitationCapability,
    {
        self.elicitation_capability = Some(config(Default::default()));
        self
    }

    /// Explicitly configures the tasks capability.
    ///
    /// `config` receives `ClientTasksCapability::default()` and returns the
    /// capability to advertise.
    #[cfg(feature = "tasks")]
    pub fn with_tasks<T>(mut self, config: T) -> Self
    where
        T: FnOnce(ClientTasksCapability) -> ClientTasksCapability,
    {
        self.tasks_capability = Some(config(Default::default()));
        self
    }

    /// Replaces the timeout (default: `DEFAULT_REQUEST_TIMEOUT` seconds).
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Returns the protocol version to negotiate.
    ///
    /// Uses the `with_mcp_version` override if set. Otherwise it uses the last
    /// entry of `PROTOCOL_VERSIONS`, presumably the newest supported version.
    ///
    /// # Panics
    ///
    /// Panics if no override is set and `PROTOCOL_VERSIONS` is empty. That is
    /// an invariant violation in the crate, not a runtime condition.
    #[inline]
    pub(crate) fn protocol_ver(&self) -> &'static str {
        self.protocol_ver.unwrap_or_else(|| {
            PROTOCOL_VERSIONS
                .last()
                .copied()
                .expect("PROTOCOL_VERSIONS must contain at least one version")
        })
    }

    /// Moves the configured transport out of the options.
    ///
    /// Returns `TransportProto::default()` when none was configured. After
    /// this call the stored transport is `None`, so a second call also yields
    /// the default.
    pub(crate) fn transport(&mut self) -> TransportProto {
        self.proto.take().unwrap_or_default()
    }

    /// Registers `root` under its URI and returns a mutable reference to the
    /// stored entry.
    ///
    /// If a root with the same URI is already registered, the existing root is
    /// kept and returned, and `root` is discarded.
    /// NOTE(review): `add_roots` instead *replaces* existing entries; confirm
    /// which semantics are intended.
    pub fn add_root(&mut self, root: Root) -> &mut Root {
        self.roots.entry(root.uri.clone()).or_insert(root)
    }

    /// Registers every item of `roots` (converted via `Into<Root>`) under its
    /// URI.
    ///
    /// An item whose URI is already registered replaces the existing entry.
    pub fn add_roots<T, I>(&mut self, roots: I) -> &mut Self
    where
        T: Into<Root>,
        I: IntoIterator<Item = T>,
    {
        self.roots.extend(roots.into_iter().map(|item| {
            let root: Root = item.into();
            (root.uri.clone(), root)
        }));
        self
    }

    /// Returns clones of all registered roots, in unspecified (hash map)
    /// order.
    pub fn roots(&self) -> Vec<Root> {
        self.roots.values().cloned().collect()
    }

    /// Installs the handler for sampling requests, replacing any previous one.
    pub(crate) fn add_sampling_handler(&mut self, handler: SamplingHandler) {
        self.sampling_handler = Some(handler);
    }

    /// Installs the handler for elicitation requests, replacing any previous
    /// one.
    pub(crate) fn add_elicitation_handler(&mut self, handler: ElicitationHandler) {
        self.elicitation_handler = Some(handler);
    }

    /// Returns the roots capability to advertise.
    ///
    /// The explicit `with_roots` configuration wins. Otherwise a default
    /// capability is returned when at least one root is registered. With
    /// neither, the result is `None`.
    pub(crate) fn roots_capability(&self) -> Option<RootsCapability> {
        self.roots_capability
            .clone()
            .or_else(|| (!self.roots.is_empty()).then(Default::default))
    }

    /// Returns the sampling capability to advertise.
    ///
    /// The explicit `with_sampling` configuration wins. Otherwise a default
    /// capability is returned only when a sampling handler is installed. With
    /// neither, the result is `None`.
    pub(crate) fn sampling_capability(&self) -> Option<SamplingCapability> {
        self.sampling_capability
            .clone()
            // Advertise only when a handler exists to serve sampling requests
            // (previously inverted: it advertised exactly when no handler was set).
            .or_else(|| self.sampling_handler.is_some().then(Default::default))
    }

    /// Returns the elicitation capability to advertise.
    ///
    /// The explicit `with_elicitation` configuration wins. Otherwise a default
    /// capability is returned only when an elicitation handler is installed.
    /// With neither, the result is `None`.
    pub(crate) fn elicitation_capability(&self) -> Option<ElicitationCapability> {
        self.elicitation_capability
            .clone()
            // Advertise only when a handler exists to serve elicitation requests
            // (previously inverted: it advertised exactly when no handler was set).
            .or_else(|| self.elicitation_handler.is_some().then(Default::default))
    }

    /// Returns the explicitly configured tasks capability, if any.
    #[cfg(feature = "tasks")]
    pub(crate) fn tasks_capability(&self) -> Option<ClientTasksCapability> {
        self.tasks_capability.clone()
    }
}