1use crate::errors::BoxedError;
2use futures::future::BoxFuture;
3use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5
6#[derive(Clone)]
9#[allow(clippy::type_complexity)]
10pub enum MCPInit<TCtx>
11where
12 TCtx: Send + Sync + 'static,
13{
14 Params(MCPParams),
15 Func(Arc<dyn Fn(&TCtx) -> Result<MCPParams, BoxedError> + Send + Sync>),
16 AsyncFunc(
17 Arc<dyn Fn(&TCtx) -> BoxFuture<'static, Result<MCPParams, BoxedError>> + Send + Sync>,
18 ),
19}
20
21impl<TCtx> MCPInit<TCtx>
22where
23 TCtx: Send + Sync + 'static,
24{
25 #[must_use]
27 pub fn from_params(params: MCPParams) -> Self {
28 Self::Params(params)
29 }
30
31 pub fn from_fn<F>(func: F) -> Self
33 where
34 F: Fn(&TCtx) -> Result<MCPParams, BoxedError> + Send + Sync + 'static,
35 {
36 Self::Func(Arc::new(func))
37 }
38
39 pub fn from_async_fn<F, Fut>(func: F) -> Self
41 where
42 F: Fn(&TCtx) -> Fut + Send + Sync + 'static,
43 Fut: std::future::Future<Output = Result<MCPParams, BoxedError>> + Send + 'static,
44 {
45 Self::AsyncFunc(Arc::new(move |ctx| Box::pin(func(ctx))))
46 }
47
48 pub(crate) async fn resolve(&self, context: &TCtx) -> Result<MCPParams, BoxedError> {
50 match self {
51 Self::Params(params) => Ok(params.clone()),
52 Self::Func(func) => func(context),
53 Self::AsyncFunc(func) => func(context).await,
54 }
55 }
56}
57
58impl<TCtx> From<MCPParams> for MCPInit<TCtx>
59where
60 TCtx: Send + Sync + 'static,
61{
62 fn from(value: MCPParams) -> Self {
63 Self::from_params(value)
64 }
65}
66
67impl<TCtx, F> From<F> for MCPInit<TCtx>
68where
69 TCtx: Send + Sync + 'static,
70 F: Fn(&TCtx) -> Result<MCPParams, BoxedError> + Send + Sync + 'static,
71{
72 fn from(value: F) -> Self {
73 Self::from_fn(value)
74 }
75}
76
/// Connection parameters for an MCP server, one variant per transport.
///
/// Serialized as an internally tagged enum: a `"type"` field holds the
/// variant name in kebab-case (`"stdio"` or `"streamable-http"`), and the
/// variant's own fields appear alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum MCPParams {
    /// Parameters for the stdio transport; see [`MCPStdioParams`].
    Stdio(MCPStdioParams),
    /// Parameters for the streamable HTTP transport; see [`MCPStreamableHTTPParams`].
    StreamableHttp(MCPStreamableHTTPParams),
}
87
/// Parameters for the `stdio` transport.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPStdioParams {
    /// Command for the server. NOTE(review): presumably the program to launch;
    /// the code that consumes it is not in this file.
    pub command: String,
    /// Arguments accompanying `command`. Defaults to empty when absent from the
    /// input, and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub args: Vec<String>,
}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct MCPStreamableHTTPParams {
101 pub url: String,
103 #[serde(default, skip_serializing_if = "Option::is_none")]
106 pub authorization: Option<String>,
107}