Skip to main content

agenterra_rmcp/model/
meta.rs

1use std::ops::{Deref, DerefMut};
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use super::{
7    ClientNotification, ClientRequest, Extensions, JsonObject, JsonRpcMessage, NumberOrString,
8    ProgressToken, ServerNotification, ServerRequest,
9};
10
11pub trait GetMeta {
12    fn get_meta_mut(&mut self) -> &mut Meta;
13    fn get_meta(&self) -> &Meta;
14}
15
16pub trait GetExtensions {
17    fn extensions(&self) -> &Extensions;
18    fn extensions_mut(&mut self) -> &mut Extensions;
19}
20
/// Implements [`GetExtensions`] and [`GetMeta`] for an enum whose listed variants are all
/// single-field tuple variants wrapping a value with an `extensions: Extensions` field.
///
/// Syntax: `variant_extension! { EnumName { VariantA VariantB ... } }` — variant names are
/// separated by whitespace, not commas. The generated `match` expressions have no wildcard
/// arm, so the list must name every variant of the enum.
macro_rules! variant_extension {
    (
        $ty:ident {
            $($arm:ident)*
        }
    ) => {
        impl GetExtensions for $ty {
            fn extensions(&self) -> &Extensions {
                match self {
                    $( $ty::$arm(inner) => &inner.extensions, )*
                }
            }

            fn extensions_mut(&mut self) -> &mut Extensions {
                match self {
                    $( $ty::$arm(inner) => &mut inner.extensions, )*
                }
            }
        }

        impl GetMeta for $ty {
            fn get_meta(&self) -> &Meta {
                // Fall back to a shared empty `Meta` so read-only access never has to insert one.
                match self.extensions().get::<Meta>() {
                    Some(meta) => meta,
                    None => Meta::static_empty(),
                }
            }

            fn get_meta_mut(&mut self) -> &mut Meta {
                // Inserts a default (empty) `Meta` into the extensions if none is present yet.
                self.extensions_mut().get_or_insert_default()
            }
        }
    };
}
53
54variant_extension! {
55    ClientRequest {
56        PingRequest
57        InitializeRequest
58        CompleteRequest
59        SetLevelRequest
60        GetPromptRequest
61        ListPromptsRequest
62        ListResourcesRequest
63        ListResourceTemplatesRequest
64        ReadResourceRequest
65        SubscribeRequest
66        UnsubscribeRequest
67        CallToolRequest
68        ListToolsRequest
69    }
70}
71
72variant_extension! {
73    ServerRequest {
74        PingRequest
75        CreateMessageRequest
76        ListRootsRequest
77    }
78}
79
80variant_extension! {
81    ClientNotification {
82        CancelledNotification
83        ProgressNotification
84        InitializedNotification
85        RootsListChangedNotification
86    }
87}
88
89variant_extension! {
90    ServerNotification {
91        CancelledNotification
92        ProgressNotification
93        LoggingMessageNotification
94        ResourceUpdatedNotification
95        ResourceListChangedNotification
96        ToolListChangedNotification
97        PromptListChangedNotification
98    }
99}
100#[derive(Debug, Serialize, Deserialize, Clone, Default)]
101#[serde(transparent)]
102pub struct Meta(pub JsonObject);
103const PROGRESS_TOKEN_FIELD: &str = "progressToken";
104impl Meta {
105    pub fn new() -> Self {
106        Self(JsonObject::new())
107    }
108
109    pub(crate) fn static_empty() -> &'static Self {
110        static EMPTY: std::sync::OnceLock<Meta> = std::sync::OnceLock::new();
111        EMPTY.get_or_init(Default::default)
112    }
113
114    pub fn get_progress_token(&self) -> Option<ProgressToken> {
115        self.0.get(PROGRESS_TOKEN_FIELD).and_then(|v| match v {
116            Value::String(s) => Some(ProgressToken(NumberOrString::String(s.to_string().into()))),
117            Value::Number(n) => n
118                .as_u64()
119                .map(|n| ProgressToken(NumberOrString::Number(n as u32))),
120            _ => None,
121        })
122    }
123
124    pub fn set_progress_token(&mut self, token: ProgressToken) {
125        match token.0 {
126            NumberOrString::String(ref s) => self.0.insert(
127                PROGRESS_TOKEN_FIELD.to_string(),
128                Value::String(s.to_string()),
129            ),
130            NumberOrString::Number(n) => self
131                .0
132                .insert(PROGRESS_TOKEN_FIELD.to_string(), Value::Number(n.into())),
133        };
134    }
135
136    pub fn extend(&mut self, other: Meta) {
137        for (k, v) in other.0.into_iter() {
138            self.0.insert(k, v);
139        }
140    }
141}
142
143impl Deref for Meta {
144    type Target = JsonObject;
145
146    fn deref(&self) -> &Self::Target {
147        &self.0
148    }
149}
150
151impl DerefMut for Meta {
152    fn deref_mut(&mut self) -> &mut Self::Target {
153        &mut self.0
154    }
155}
156
157impl<Req, Resp, Noti> JsonRpcMessage<Req, Resp, Noti>
158where
159    Req: GetExtensions,
160    Noti: GetExtensions,
161{
162    pub fn insert_extension<T: Clone + Send + Sync + 'static>(&mut self, value: T) {
163        match self {
164            JsonRpcMessage::Request(json_rpc_request) => {
165                json_rpc_request.request.extensions_mut().insert(value);
166            }
167            JsonRpcMessage::Notification(json_rpc_notification) => {
168                json_rpc_notification
169                    .notification
170                    .extensions_mut()
171                    .insert(value);
172            }
173            JsonRpcMessage::BatchRequest(json_rpc_batch_request_items) => {
174                for item in json_rpc_batch_request_items {
175                    match item {
176                        super::JsonRpcBatchRequestItem::Request(json_rpc_request) => {
177                            json_rpc_request
178                                .request
179                                .extensions_mut()
180                                .insert(value.clone());
181                        }
182                        super::JsonRpcBatchRequestItem::Notification(json_rpc_notification) => {
183                            json_rpc_notification
184                                .notification
185                                .extensions_mut()
186                                .insert(value.clone());
187                        }
188                    }
189                }
190            }
191            _ => {}
192        }
193    }
194}