agenterra_rmcp/model/
meta.rs1use std::ops::{Deref, DerefMut};
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use super::{
7 ClientNotification, ClientRequest, Extensions, JsonObject, JsonRpcMessage, NumberOrString,
8 ProgressToken, ServerNotification, ServerRequest,
9};
10
11pub trait GetMeta {
12 fn get_meta_mut(&mut self) -> &mut Meta;
13 fn get_meta(&self) -> &Meta;
14}
15
16pub trait GetExtensions {
17 fn extensions(&self) -> &Extensions;
18 fn extensions_mut(&mut self) -> &mut Extensions;
19}
20
/// Implements [`GetExtensions`] and [`GetMeta`] for an enum.
///
/// Every listed variant must be a single-field tuple variant whose payload has an
/// `extensions: Extensions` field. Because the generated `match`es are exhaustive, the list
/// must name every variant of the enum.
///
/// Usage: `variant_extension! { Enum { VariantA VariantB ... } }`.
macro_rules! variant_extension {
    (
        $Enum:ident {
            $($variant:ident)*
        }
    ) => {
        impl GetExtensions for $Enum {
            fn extensions(&self) -> &Extensions {
                match self {
                    $( Self::$variant(payload) => &payload.extensions, )*
                }
            }
            fn extensions_mut(&mut self) -> &mut Extensions {
                match self {
                    $( Self::$variant(payload) => &mut payload.extensions, )*
                }
            }
        }
        impl GetMeta for $Enum {
            fn get_meta_mut(&mut self) -> &mut Meta {
                // `get_or_insert_default` stores a default `Meta` when none is present yet.
                self.extensions_mut().get_or_insert_default()
            }
            fn get_meta(&self) -> &Meta {
                // Reads never insert: fall back to a shared empty `Meta` instead.
                match self.extensions().get::<Meta>() {
                    Some(meta) => meta,
                    None => Meta::static_empty(),
                }
            }
        }
    };
}
53
54variant_extension! {
55 ClientRequest {
56 PingRequest
57 InitializeRequest
58 CompleteRequest
59 SetLevelRequest
60 GetPromptRequest
61 ListPromptsRequest
62 ListResourcesRequest
63 ListResourceTemplatesRequest
64 ReadResourceRequest
65 SubscribeRequest
66 UnsubscribeRequest
67 CallToolRequest
68 ListToolsRequest
69 }
70}
71
72variant_extension! {
73 ServerRequest {
74 PingRequest
75 CreateMessageRequest
76 ListRootsRequest
77 }
78}
79
80variant_extension! {
81 ClientNotification {
82 CancelledNotification
83 ProgressNotification
84 InitializedNotification
85 RootsListChangedNotification
86 }
87}
88
89variant_extension! {
90 ServerNotification {
91 CancelledNotification
92 ProgressNotification
93 LoggingMessageNotification
94 ResourceUpdatedNotification
95 ResourceListChangedNotification
96 ToolListChangedNotification
97 PromptListChangedNotification
98 }
99}
100#[derive(Debug, Serialize, Deserialize, Clone, Default)]
101#[serde(transparent)]
102pub struct Meta(pub JsonObject);
103const PROGRESS_TOKEN_FIELD: &str = "progressToken";
104impl Meta {
105 pub fn new() -> Self {
106 Self(JsonObject::new())
107 }
108
109 pub(crate) fn static_empty() -> &'static Self {
110 static EMPTY: std::sync::OnceLock<Meta> = std::sync::OnceLock::new();
111 EMPTY.get_or_init(Default::default)
112 }
113
114 pub fn get_progress_token(&self) -> Option<ProgressToken> {
115 self.0.get(PROGRESS_TOKEN_FIELD).and_then(|v| match v {
116 Value::String(s) => Some(ProgressToken(NumberOrString::String(s.to_string().into()))),
117 Value::Number(n) => n
118 .as_u64()
119 .map(|n| ProgressToken(NumberOrString::Number(n as u32))),
120 _ => None,
121 })
122 }
123
124 pub fn set_progress_token(&mut self, token: ProgressToken) {
125 match token.0 {
126 NumberOrString::String(ref s) => self.0.insert(
127 PROGRESS_TOKEN_FIELD.to_string(),
128 Value::String(s.to_string()),
129 ),
130 NumberOrString::Number(n) => self
131 .0
132 .insert(PROGRESS_TOKEN_FIELD.to_string(), Value::Number(n.into())),
133 };
134 }
135
136 pub fn extend(&mut self, other: Meta) {
137 for (k, v) in other.0.into_iter() {
138 self.0.insert(k, v);
139 }
140 }
141}
142
143impl Deref for Meta {
144 type Target = JsonObject;
145
146 fn deref(&self) -> &Self::Target {
147 &self.0
148 }
149}
150
151impl DerefMut for Meta {
152 fn deref_mut(&mut self) -> &mut Self::Target {
153 &mut self.0
154 }
155}
156
157impl<Req, Resp, Noti> JsonRpcMessage<Req, Resp, Noti>
158where
159 Req: GetExtensions,
160 Noti: GetExtensions,
161{
162 pub fn insert_extension<T: Clone + Send + Sync + 'static>(&mut self, value: T) {
163 match self {
164 JsonRpcMessage::Request(json_rpc_request) => {
165 json_rpc_request.request.extensions_mut().insert(value);
166 }
167 JsonRpcMessage::Notification(json_rpc_notification) => {
168 json_rpc_notification
169 .notification
170 .extensions_mut()
171 .insert(value);
172 }
173 JsonRpcMessage::BatchRequest(json_rpc_batch_request_items) => {
174 for item in json_rpc_batch_request_items {
175 match item {
176 super::JsonRpcBatchRequestItem::Request(json_rpc_request) => {
177 json_rpc_request
178 .request
179 .extensions_mut()
180 .insert(value.clone());
181 }
182 super::JsonRpcBatchRequestItem::Notification(json_rpc_notification) => {
183 json_rpc_notification
184 .notification
185 .extensions_mut()
186 .insert(value.clone());
187 }
188 }
189 }
190 }
191 _ => {}
192 }
193 }
194}