// rmcp/handler/server/router/prompt.rs

1use std::{borrow::Cow, sync::Arc};
2
3use crate::{
4    handler::server::prompt::{DynGetPromptHandler, GetPromptHandler, PromptContext},
5    model::{GetPromptResult, Prompt},
6    service::{MaybeBoxFuture, MaybeSend},
7};
8
9#[non_exhaustive]
10pub struct PromptRoute<S> {
11    #[allow(clippy::type_complexity)]
12    pub get: Arc<DynGetPromptHandler<S>>,
13    pub attr: crate::model::Prompt,
14}
15
16impl<S> std::fmt::Debug for PromptRoute<S> {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("PromptRoute")
19            .field("name", &self.attr.name)
20            .field("description", &self.attr.description)
21            .field("arguments", &self.attr.arguments)
22            .finish()
23    }
24}
25
26impl<S> Clone for PromptRoute<S> {
27    fn clone(&self) -> Self {
28        Self {
29            get: self.get.clone(),
30            attr: self.attr.clone(),
31        }
32    }
33}
34
35impl<S: MaybeSend + 'static> PromptRoute<S> {
36    pub fn new<H, A: 'static>(attr: impl Into<Prompt>, handler: H) -> Self
37    where
38        H: GetPromptHandler<S, A> + MaybeSend + Clone + 'static,
39    {
40        Self {
41            get: Arc::new(move |context: PromptContext<S>| {
42                let handler = handler.clone();
43                handler.handle(context)
44            }),
45            attr: attr.into(),
46        }
47    }
48
49    pub fn new_dyn<H>(attr: impl Into<Prompt>, handler: H) -> Self
50    where
51        H: for<'a> Fn(
52                PromptContext<'a, S>,
53            ) -> MaybeBoxFuture<'a, Result<GetPromptResult, crate::ErrorData>>
54            + MaybeSend
55            + 'static,
56    {
57        Self {
58            get: Arc::new(handler),
59            attr: attr.into(),
60        }
61    }
62
63    pub fn name(&self) -> &str {
64        &self.attr.name
65    }
66}
67
68pub trait IntoPromptRoute<S, A> {
69    fn into_prompt_route(self) -> PromptRoute<S>;
70}
71
72impl<S, H, A, P> IntoPromptRoute<S, A> for (P, H)
73where
74    S: MaybeSend + 'static,
75    A: 'static,
76    H: GetPromptHandler<S, A> + MaybeSend + Clone + 'static,
77    P: Into<Prompt>,
78{
79    fn into_prompt_route(self) -> PromptRoute<S> {
80        PromptRoute::new(self.0.into(), self.1)
81    }
82}
83
84impl<S> IntoPromptRoute<S, ()> for PromptRoute<S>
85where
86    S: MaybeSend + 'static,
87{
88    fn into_prompt_route(self) -> PromptRoute<S> {
89        self
90    }
91}
92
/// Marker type selecting the [`IntoPromptRoute`] implementation for the
/// zero-argument functions generated by the `#[prompt]` macro, which return a
/// ready-made [`PromptRoute`] when called.
#[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")]
pub struct PromptAttrGenerateFunctionAdapter;
96
97impl<S, F> IntoPromptRoute<S, PromptAttrGenerateFunctionAdapter> for F
98where
99    S: MaybeSend + 'static,
100    F: Fn() -> PromptRoute<S>,
101{
102    fn into_prompt_route(self) -> PromptRoute<S> {
103        (self)()
104    }
105}
106
107#[derive(Debug)]
108#[non_exhaustive]
109pub struct PromptRouter<S> {
110    #[allow(clippy::type_complexity)]
111    pub map: std::collections::HashMap<Cow<'static, str>, PromptRoute<S>>,
112}
113
114impl<S> Default for PromptRouter<S> {
115    fn default() -> Self {
116        Self {
117            map: std::collections::HashMap::new(),
118        }
119    }
120}
121
122impl<S> Clone for PromptRouter<S> {
123    fn clone(&self) -> Self {
124        Self {
125            map: self.map.clone(),
126        }
127    }
128}
129
130impl<S> IntoIterator for PromptRouter<S> {
131    type Item = PromptRoute<S>;
132    type IntoIter = std::collections::hash_map::IntoValues<Cow<'static, str>, PromptRoute<S>>;
133
134    fn into_iter(self) -> Self::IntoIter {
135        self.map.into_values()
136    }
137}
138
139impl<S> PromptRouter<S>
140where
141    S: MaybeSend + 'static,
142{
143    pub fn new() -> Self {
144        Self {
145            map: std::collections::HashMap::new(),
146        }
147    }
148
149    pub fn with_route<R, A: 'static>(mut self, route: R) -> Self
150    where
151        R: IntoPromptRoute<S, A>,
152    {
153        self.add_route(route.into_prompt_route());
154        self
155    }
156
157    pub fn add_route(&mut self, item: PromptRoute<S>) {
158        self.map.insert(item.attr.name.clone().into(), item);
159    }
160
161    pub fn merge(&mut self, other: PromptRouter<S>) {
162        for item in other.map.into_values() {
163            self.add_route(item);
164        }
165    }
166
167    pub fn remove_route(&mut self, name: &str) {
168        self.map.remove(name);
169    }
170
171    pub fn has_route(&self, name: &str) -> bool {
172        self.map.contains_key(name)
173    }
174
175    pub async fn get_prompt(
176        &self,
177        context: PromptContext<'_, S>,
178    ) -> Result<GetPromptResult, crate::ErrorData> {
179        let item = self.map.get(context.name.as_str()).ok_or_else(|| {
180            crate::ErrorData::invalid_params(
181                format!("prompt '{}' not found", context.name),
182                Some(serde_json::json!({
183                    "available_prompts": self.list_all().iter().map(|p| &p.name).collect::<Vec<_>>()
184                })),
185            )
186        })?;
187        (item.get)(context).await
188    }
189
190    pub fn list_all(&self) -> Vec<crate::model::Prompt> {
191        let mut prompts: Vec<_> = self.map.values().map(|item| item.attr.clone()).collect();
192        prompts.sort_by(|a, b| a.name.cmp(&b.name));
193        prompts
194    }
195}
196
197impl<S> std::ops::Add<PromptRouter<S>> for PromptRouter<S>
198where
199    S: MaybeSend + 'static,
200{
201    type Output = Self;
202
203    fn add(mut self, other: PromptRouter<S>) -> Self::Output {
204        self.merge(other);
205        self
206    }
207}
208
209impl<S> std::ops::AddAssign<PromptRouter<S>> for PromptRouter<S>
210where
211    S: MaybeSend + 'static,
212{
213    fn add_assign(&mut self, other: PromptRouter<S>) {
214        self.merge(other);
215    }
216}