rmcp/handler/server/router/
prompt.rs1use std::{borrow::Cow, sync::Arc};
2
3use crate::{
4 handler::server::prompt::{DynGetPromptHandler, GetPromptHandler, PromptContext},
5 model::{GetPromptResult, Prompt},
6 service::{MaybeBoxFuture, MaybeSend},
7};
8
9#[non_exhaustive]
10pub struct PromptRoute<S> {
11 #[allow(clippy::type_complexity)]
12 pub get: Arc<DynGetPromptHandler<S>>,
13 pub attr: crate::model::Prompt,
14}
15
16impl<S> std::fmt::Debug for PromptRoute<S> {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 f.debug_struct("PromptRoute")
19 .field("name", &self.attr.name)
20 .field("description", &self.attr.description)
21 .field("arguments", &self.attr.arguments)
22 .finish()
23 }
24}
25
26impl<S> Clone for PromptRoute<S> {
27 fn clone(&self) -> Self {
28 Self {
29 get: self.get.clone(),
30 attr: self.attr.clone(),
31 }
32 }
33}
34
35impl<S: MaybeSend + 'static> PromptRoute<S> {
36 pub fn new<H, A: 'static>(attr: impl Into<Prompt>, handler: H) -> Self
37 where
38 H: GetPromptHandler<S, A> + MaybeSend + Clone + 'static,
39 {
40 Self {
41 get: Arc::new(move |context: PromptContext<S>| {
42 let handler = handler.clone();
43 handler.handle(context)
44 }),
45 attr: attr.into(),
46 }
47 }
48
49 pub fn new_dyn<H>(attr: impl Into<Prompt>, handler: H) -> Self
50 where
51 H: for<'a> Fn(
52 PromptContext<'a, S>,
53 ) -> MaybeBoxFuture<'a, Result<GetPromptResult, crate::ErrorData>>
54 + MaybeSend
55 + 'static,
56 {
57 Self {
58 get: Arc::new(handler),
59 attr: attr.into(),
60 }
61 }
62
63 pub fn name(&self) -> &str {
64 &self.attr.name
65 }
66}
67
68pub trait IntoPromptRoute<S, A> {
69 fn into_prompt_route(self) -> PromptRoute<S>;
70}
71
72impl<S, H, A, P> IntoPromptRoute<S, A> for (P, H)
73where
74 S: MaybeSend + 'static,
75 A: 'static,
76 H: GetPromptHandler<S, A> + MaybeSend + Clone + 'static,
77 P: Into<Prompt>,
78{
79 fn into_prompt_route(self) -> PromptRoute<S> {
80 PromptRoute::new(self.0.into(), self.1)
81 }
82}
83
84impl<S> IntoPromptRoute<S, ()> for PromptRoute<S>
85where
86 S: MaybeSend + 'static,
87{
88 fn into_prompt_route(self) -> PromptRoute<S> {
89 self
90 }
91}
92
/// Marker type selecting the [`IntoPromptRoute`] impl for zero-argument
/// functions that return a [`PromptRoute`] (presumably the functions emitted
/// by the prompt attribute macro — confirm against the macro crate).
#[expect(
    clippy::exhaustive_structs,
    reason = "intentionally exhaustive"
)]
pub struct PromptAttrGenerateFunctionAdapter;
96
97impl<S, F> IntoPromptRoute<S, PromptAttrGenerateFunctionAdapter> for F
98where
99 S: MaybeSend + 'static,
100 F: Fn() -> PromptRoute<S>,
101{
102 fn into_prompt_route(self) -> PromptRoute<S> {
103 (self)()
104 }
105}
106
107#[derive(Debug)]
108#[non_exhaustive]
109pub struct PromptRouter<S> {
110 #[allow(clippy::type_complexity)]
111 pub map: std::collections::HashMap<Cow<'static, str>, PromptRoute<S>>,
112}
113
114impl<S> Default for PromptRouter<S> {
115 fn default() -> Self {
116 Self {
117 map: std::collections::HashMap::new(),
118 }
119 }
120}
121
122impl<S> Clone for PromptRouter<S> {
123 fn clone(&self) -> Self {
124 Self {
125 map: self.map.clone(),
126 }
127 }
128}
129
130impl<S> IntoIterator for PromptRouter<S> {
131 type Item = PromptRoute<S>;
132 type IntoIter = std::collections::hash_map::IntoValues<Cow<'static, str>, PromptRoute<S>>;
133
134 fn into_iter(self) -> Self::IntoIter {
135 self.map.into_values()
136 }
137}
138
139impl<S> PromptRouter<S>
140where
141 S: MaybeSend + 'static,
142{
143 pub fn new() -> Self {
144 Self {
145 map: std::collections::HashMap::new(),
146 }
147 }
148
149 pub fn with_route<R, A: 'static>(mut self, route: R) -> Self
150 where
151 R: IntoPromptRoute<S, A>,
152 {
153 self.add_route(route.into_prompt_route());
154 self
155 }
156
157 pub fn add_route(&mut self, item: PromptRoute<S>) {
158 self.map.insert(item.attr.name.clone().into(), item);
159 }
160
161 pub fn merge(&mut self, other: PromptRouter<S>) {
162 for item in other.map.into_values() {
163 self.add_route(item);
164 }
165 }
166
167 pub fn remove_route(&mut self, name: &str) {
168 self.map.remove(name);
169 }
170
171 pub fn has_route(&self, name: &str) -> bool {
172 self.map.contains_key(name)
173 }
174
175 pub async fn get_prompt(
176 &self,
177 context: PromptContext<'_, S>,
178 ) -> Result<GetPromptResult, crate::ErrorData> {
179 let item = self.map.get(context.name.as_str()).ok_or_else(|| {
180 crate::ErrorData::invalid_params(
181 format!("prompt '{}' not found", context.name),
182 Some(serde_json::json!({
183 "available_prompts": self.list_all().iter().map(|p| &p.name).collect::<Vec<_>>()
184 })),
185 )
186 })?;
187 (item.get)(context).await
188 }
189
190 pub fn list_all(&self) -> Vec<crate::model::Prompt> {
191 let mut prompts: Vec<_> = self.map.values().map(|item| item.attr.clone()).collect();
192 prompts.sort_by(|a, b| a.name.cmp(&b.name));
193 prompts
194 }
195}
196
197impl<S> std::ops::Add<PromptRouter<S>> for PromptRouter<S>
198where
199 S: MaybeSend + 'static,
200{
201 type Output = Self;
202
203 fn add(mut self, other: PromptRouter<S>) -> Self::Output {
204 self.merge(other);
205 self
206 }
207}
208
209impl<S> std::ops::AddAssign<PromptRouter<S>> for PromptRouter<S>
210where
211 S: MaybeSend + 'static,
212{
213 fn add_assign(&mut self, other: PromptRouter<S>) {
214 self.merge(other);
215 }
216}