1use std::ops::{Deref, DerefMut};
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use super::{
7 ClientNotification, ClientRequest, CustomNotification, CustomRequest, Extensions, JsonObject,
8 JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest,
9};
10
11pub trait GetMeta {
12 fn get_meta_mut(&mut self) -> &mut Meta;
13 fn get_meta(&self) -> &Meta;
14}
15
16pub trait GetExtensions {
17 fn extensions(&self) -> &Extensions;
18 fn extensions_mut(&mut self) -> &mut Extensions;
19}
20
21pub trait RequestParamsMeta {
26 fn meta(&self) -> Option<&Meta>;
28 fn meta_mut(&mut self) -> &mut Option<Meta>;
30 fn set_meta(&mut self, meta: Meta) {
32 *self.meta_mut() = Some(meta);
33 }
34 fn progress_token(&self) -> Option<ProgressToken> {
36 self.meta().and_then(|m| m.get_progress_token())
37 }
38 fn set_progress_token(&mut self, token: ProgressToken) {
40 match self.meta_mut() {
41 Some(meta) => meta.set_progress_token(token),
42 none => {
43 let mut meta = Meta::new();
44 meta.set_progress_token(token);
45 *none = Some(meta);
46 }
47 }
48 }
49}
50
51pub trait TaskAugmentedRequestParamsMeta: RequestParamsMeta {
56 fn task(&self) -> Option<&JsonObject>;
58 fn task_mut(&mut self) -> &mut Option<JsonObject>;
60 fn set_task(&mut self, task: JsonObject) {
62 *self.task_mut() = Some(task);
63 }
64}
65
66impl GetExtensions for CustomNotification {
67 fn extensions(&self) -> &Extensions {
68 &self.extensions
69 }
70 fn extensions_mut(&mut self) -> &mut Extensions {
71 &mut self.extensions
72 }
73}
74
75impl GetMeta for CustomNotification {
76 fn get_meta_mut(&mut self) -> &mut Meta {
77 self.extensions_mut().get_or_insert_default()
78 }
79 fn get_meta(&self) -> &Meta {
80 self.extensions()
81 .get::<Meta>()
82 .unwrap_or(Meta::static_empty())
83 }
84}
85
86impl GetExtensions for CustomRequest {
87 fn extensions(&self) -> &Extensions {
88 &self.extensions
89 }
90 fn extensions_mut(&mut self) -> &mut Extensions {
91 &mut self.extensions
92 }
93}
94
95impl GetMeta for CustomRequest {
96 fn get_meta_mut(&mut self) -> &mut Meta {
97 self.extensions_mut().get_or_insert_default()
98 }
99 fn get_meta(&self) -> &Meta {
100 self.extensions()
101 .get::<Meta>()
102 .unwrap_or(Meta::static_empty())
103 }
104}
105
/// Implements [`GetExtensions`] and [`GetMeta`] for an enum.
///
/// Each listed variant must be a single-field tuple variant whose payload has
/// an `extensions` field; the generated accessors match on the variant and
/// borrow that field. `GetMeta` is then layered on top of the extensions.
macro_rules! variant_extension {
    (
        $Enum: ident {
            $($variant: ident)*
        }
    ) => {
        impl GetExtensions for $Enum {
            fn extensions(&self) -> &Extensions {
                match self {
                    $(
                        $Enum::$variant(v) => &v.extensions,
                    )*
                }
            }
            fn extensions_mut(&mut self) -> &mut Extensions {
                match self {
                    $(
                        $Enum::$variant(v) => &mut v.extensions,
                    )*
                }
            }
        }
        impl GetMeta for $Enum {
            fn get_meta_mut(&mut self) -> &mut Meta {
                self.extensions_mut().get_or_insert_default()
            }
            fn get_meta(&self) -> &Meta {
                // Lazy fallback: the static empty `Meta` is only looked up
                // when the extensions hold no `Meta`.
                match self.extensions().get::<Meta>() {
                    Some(meta) => meta,
                    None => Meta::static_empty(),
                }
            }
        }
    };
}
138
139variant_extension! {
140 ClientRequest {
141 PingRequest
142 InitializeRequest
143 CompleteRequest
144 SetLevelRequest
145 GetPromptRequest
146 ListPromptsRequest
147 ListResourcesRequest
148 ListResourceTemplatesRequest
149 ReadResourceRequest
150 SubscribeRequest
151 UnsubscribeRequest
152 CallToolRequest
153 ListToolsRequest
154 CustomRequest
155 GetTaskInfoRequest
156 ListTasksRequest
157 GetTaskResultRequest
158 CancelTaskRequest
159 }
160}
161
162variant_extension! {
163 ServerRequest {
164 PingRequest
165 CreateMessageRequest
166 ListRootsRequest
167 CreateElicitationRequest
168 CustomRequest
169 }
170}
171
172variant_extension! {
173 ClientNotification {
174 CancelledNotification
175 ProgressNotification
176 InitializedNotification
177 RootsListChangedNotification
178 CustomNotification
179 }
180}
181
182variant_extension! {
183 ServerNotification {
184 CancelledNotification
185 ProgressNotification
186 LoggingMessageNotification
187 ResourceUpdatedNotification
188 ResourceListChangedNotification
189 ToolListChangedNotification
190 PromptListChangedNotification
191 CustomNotification
192 }
193}
194#[derive(Debug, Serialize, Deserialize, Clone, Default, PartialEq)]
195#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
196#[serde(transparent)]
197pub struct Meta(pub JsonObject);
198const PROGRESS_TOKEN_FIELD: &str = "progressToken";
199impl Meta {
200 pub fn new() -> Self {
201 Self(JsonObject::new())
202 }
203
204 pub fn with_progress_token(token: ProgressToken) -> Self {
206 let mut meta = Self::new();
207 meta.set_progress_token(token);
208 meta
209 }
210
211 pub(crate) fn static_empty() -> &'static Self {
212 static EMPTY: std::sync::OnceLock<Meta> = std::sync::OnceLock::new();
213 EMPTY.get_or_init(Default::default)
214 }
215
216 pub fn get_progress_token(&self) -> Option<ProgressToken> {
217 self.0.get(PROGRESS_TOKEN_FIELD).and_then(|v| match v {
218 Value::String(s) => Some(ProgressToken(NumberOrString::String(s.to_string().into()))),
219 Value::Number(n) => {
220 if let Some(i) = n.as_i64() {
221 Some(ProgressToken(NumberOrString::Number(i)))
222 } else if let Some(u) = n.as_u64() {
223 if u <= i64::MAX as u64 {
224 Some(ProgressToken(NumberOrString::Number(u as i64)))
225 } else {
226 None
227 }
228 } else {
229 None
230 }
231 }
232 _ => None,
233 })
234 }
235
236 pub fn set_progress_token(&mut self, token: ProgressToken) {
237 match token.0 {
238 NumberOrString::String(ref s) => self.0.insert(
239 PROGRESS_TOKEN_FIELD.to_string(),
240 Value::String(s.to_string()),
241 ),
242 NumberOrString::Number(n) => self
243 .0
244 .insert(PROGRESS_TOKEN_FIELD.to_string(), Value::Number(n.into())),
245 };
246 }
247
248 pub fn extend(&mut self, other: Meta) {
249 for (k, v) in other.0.into_iter() {
250 self.0.insert(k, v);
251 }
252 }
253}
254
255impl Deref for Meta {
256 type Target = JsonObject;
257
258 fn deref(&self) -> &Self::Target {
259 &self.0
260 }
261}
262
263impl DerefMut for Meta {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269impl<Req, Resp, Noti> JsonRpcMessage<Req, Resp, Noti>
270where
271 Req: GetExtensions,
272 Noti: GetExtensions,
273{
274 pub fn insert_extension<T: Clone + Send + Sync + 'static>(&mut self, value: T) {
275 match self {
276 JsonRpcMessage::Request(json_rpc_request) => {
277 json_rpc_request.request.extensions_mut().insert(value);
278 }
279 JsonRpcMessage::Notification(json_rpc_notification) => {
280 json_rpc_notification
281 .notification
282 .extensions_mut()
283 .insert(value);
284 }
285 _ => {}
286 }
287 }
288}