1use std::ops::{Deref, DerefMut};
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use super::{
7 ClientNotification, ClientRequest, CustomNotification, CustomRequest, Extensions, JsonObject,
8 JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest,
9};
10
11pub trait GetMeta {
12 fn get_meta_mut(&mut self) -> &mut Meta;
13 fn get_meta(&self) -> &Meta;
14}
15
16pub trait GetExtensions {
17 fn extensions(&self) -> &Extensions;
18 fn extensions_mut(&mut self) -> &mut Extensions;
19}
20
21pub trait RequestParamsMeta {
26 fn meta(&self) -> Option<&Meta>;
28 fn meta_mut(&mut self) -> &mut Option<Meta>;
30 fn set_meta(&mut self, meta: Meta) {
32 *self.meta_mut() = Some(meta);
33 }
34 fn progress_token(&self) -> Option<ProgressToken> {
36 self.meta().and_then(|m| m.get_progress_token())
37 }
38 fn set_progress_token(&mut self, token: ProgressToken) {
40 match self.meta_mut() {
41 Some(meta) => meta.set_progress_token(token),
42 none => {
43 let mut meta = Meta::new();
44 meta.set_progress_token(token);
45 *none = Some(meta);
46 }
47 }
48 }
49}
50
51pub trait TaskAugmentedRequestParamsMeta: RequestParamsMeta {
56 fn task(&self) -> Option<&JsonObject>;
58 fn task_mut(&mut self) -> &mut Option<JsonObject>;
60 fn set_task(&mut self, task: JsonObject) {
62 *self.task_mut() = Some(task);
63 }
64}
65
66impl GetExtensions for CustomNotification {
67 fn extensions(&self) -> &Extensions {
68 &self.extensions
69 }
70 fn extensions_mut(&mut self) -> &mut Extensions {
71 &mut self.extensions
72 }
73}
74
75impl GetMeta for CustomNotification {
76 fn get_meta_mut(&mut self) -> &mut Meta {
77 self.extensions_mut().get_or_insert_default()
78 }
79 fn get_meta(&self) -> &Meta {
80 self.extensions()
81 .get::<Meta>()
82 .unwrap_or(Meta::static_empty())
83 }
84}
85
86impl GetExtensions for CustomRequest {
87 fn extensions(&self) -> &Extensions {
88 &self.extensions
89 }
90 fn extensions_mut(&mut self) -> &mut Extensions {
91 &mut self.extensions
92 }
93}
94
95impl GetMeta for CustomRequest {
96 fn get_meta_mut(&mut self) -> &mut Meta {
97 self.extensions_mut().get_or_insert_default()
98 }
99 fn get_meta(&self) -> &Meta {
100 self.extensions()
101 .get::<Meta>()
102 .unwrap_or(Meta::static_empty())
103 }
104}
105
// Implements `GetExtensions` and `GetMeta` for an enum whose variants each wrap a
// value that has an `extensions` field.
//
// Usage: `variant_extension! { EnumName { VariantA VariantB ... } }`. Every variant
// of the enum must be listed, otherwise the generated `match` is non-exhaustive
// and compilation fails.
macro_rules! variant_extension {
    (
        $name: ident {
            $($arm: ident)*
        }
    ) => {
        impl GetExtensions for $name {
            fn extensions(&self) -> &Extensions {
                match self {
                    $( $name::$arm(inner) => &inner.extensions, )*
                }
            }
            fn extensions_mut(&mut self) -> &mut Extensions {
                match self {
                    $( $name::$arm(inner) => &mut inner.extensions, )*
                }
            }
        }
        impl GetMeta for $name {
            fn get_meta_mut(&mut self) -> &mut Meta {
                let meta: &mut Meta = self.extensions_mut().get_or_insert_default();
                meta
            }
            fn get_meta(&self) -> &Meta {
                match self.extensions().get::<Meta>() {
                    Some(meta) => meta,
                    None => Meta::static_empty(),
                }
            }
        }
    };
}
138
// `GetExtensions` / `GetMeta` for `ClientRequest`. Every variant must be listed
// here, because the macro generates an exhaustive `match` over these names.
variant_extension! {
    ClientRequest {
        PingRequest
        InitializeRequest
        CompleteRequest
        SetLevelRequest
        GetPromptRequest
        ListPromptsRequest
        ListResourcesRequest
        ListResourceTemplatesRequest
        ReadResourceRequest
        SubscribeRequest
        UnsubscribeRequest
        CallToolRequest
        ListToolsRequest
        CustomRequest
        GetTaskInfoRequest
        ListTasksRequest
        GetTaskResultRequest
        CancelTaskRequest
    }
}

// `GetExtensions` / `GetMeta` for `ServerRequest` (all variants must be listed).
variant_extension! {
    ServerRequest {
        PingRequest
        CreateMessageRequest
        ListRootsRequest
        CreateElicitationRequest
        CustomRequest
    }
}

// `GetExtensions` / `GetMeta` for `ClientNotification` (all variants must be listed).
variant_extension! {
    ClientNotification {
        CancelledNotification
        ProgressNotification
        InitializedNotification
        RootsListChangedNotification
        CustomNotification
    }
}

// `GetExtensions` / `GetMeta` for `ServerNotification` (all variants must be listed).
variant_extension! {
    ServerNotification {
        CancelledNotification
        ProgressNotification
        LoggingMessageNotification
        ResourceUpdatedNotification
        ResourceListChangedNotification
        ToolListChangedNotification
        PromptListChangedNotification
        ElicitationCompletionNotification
        CustomNotification
    }
}
195#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
196#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
197#[serde(transparent)]
198pub struct Meta(pub JsonObject);
199const PROGRESS_TOKEN_FIELD: &str = "progressToken";
200impl Meta {
201 pub fn new() -> Self {
202 Self(JsonObject::new())
203 }
204
205 pub fn with_progress_token(token: ProgressToken) -> Self {
207 let mut meta = Self::new();
208 meta.set_progress_token(token);
209 meta
210 }
211
212 pub(crate) fn static_empty() -> &'static Self {
213 static EMPTY: std::sync::OnceLock<Meta> = std::sync::OnceLock::new();
214 EMPTY.get_or_init(Default::default)
215 }
216
217 pub fn get_progress_token(&self) -> Option<ProgressToken> {
218 self.0.get(PROGRESS_TOKEN_FIELD).and_then(|v| match v {
219 Value::String(s) => Some(ProgressToken(NumberOrString::String(s.to_string().into()))),
220 Value::Number(n) => {
221 if let Some(i) = n.as_i64() {
222 Some(ProgressToken(NumberOrString::Number(i)))
223 } else if let Some(u) = n.as_u64() {
224 if u <= i64::MAX as u64 {
225 Some(ProgressToken(NumberOrString::Number(u as i64)))
226 } else {
227 None
228 }
229 } else {
230 None
231 }
232 }
233 _ => None,
234 })
235 }
236
237 pub fn set_progress_token(&mut self, token: ProgressToken) {
238 match token.0 {
239 NumberOrString::String(ref s) => self.0.insert(
240 PROGRESS_TOKEN_FIELD.to_string(),
241 Value::String(s.to_string()),
242 ),
243 NumberOrString::Number(n) => self
244 .0
245 .insert(PROGRESS_TOKEN_FIELD.to_string(), Value::Number(n.into())),
246 };
247 }
248
249 pub fn extend(&mut self, other: Meta) {
250 for (k, v) in other.0.into_iter() {
251 self.0.insert(k, v);
252 }
253 }
254}
255
256impl Deref for Meta {
257 type Target = JsonObject;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263
264impl DerefMut for Meta {
265 fn deref_mut(&mut self) -> &mut Self::Target {
266 &mut self.0
267 }
268}
269
270impl<Req, Resp, Noti> JsonRpcMessage<Req, Resp, Noti>
271where
272 Req: GetExtensions,
273 Noti: GetExtensions,
274{
275 pub fn insert_extension<T: Clone + Send + Sync + 'static>(&mut self, value: T) {
276 match self {
277 JsonRpcMessage::Request(json_rpc_request) => {
278 json_rpc_request.request.extensions_mut().insert(value);
279 }
280 JsonRpcMessage::Notification(json_rpc_notification) => {
281 json_rpc_notification
282 .notification
283 .extensions_mut()
284 .insert(value);
285 }
286 _ => {}
287 }
288 }
289}