mcp_attr/
utils.rs

1//! Types used in the MCP protocol that are not defined in the schema
2
3use std::{borrow::Cow, marker::PhantomData};
4
5use base64::Engine;
6use parse_display::Display;
7use serde::{Deserialize, Deserializer, Serialize};
8use serde_json::{Map, Value};
9
/// Byte buffer that travels through serde as a Base64 string (standard alphabet, padded).
///
/// Serializing writes the Base64 text of the wrapped bytes. Deserializing expects such text and
/// decodes it back into the raw bytes.
///
/// # Example
///
/// ```
/// use mcp_attr::utils::Base64Bytes;
/// use serde_json::json;
///
/// let original = Base64Bytes(vec![1, 2, 3, 4, 5]);
/// let encoded = json!(original);
/// assert_eq!(encoded, json!("AQIDBAU="));
///
/// let decoded: Base64Bytes = serde_json::from_value(encoded).unwrap();
/// assert_eq!(decoded.0, vec![1, 2, 3, 4, 5]);
/// ```
#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Base64Bytes(pub Vec<u8>);
30
31impl Serialize for Base64Bytes {
32    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
33    where
34        S: serde::Serializer,
35    {
36        let s = base64::prelude::BASE64_STANDARD.encode(&self.0);
37        serializer.serialize_str(&s)
38    }
39}
40
41impl<'de> Deserialize<'de> for Base64Bytes {
42    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
43    where
44        D: Deserializer<'de>,
45    {
46        let s: Cow<'de, str> = Deserialize::deserialize(deserializer)?;
47        base64::prelude::BASE64_STANDARD
48            .decode(&*s)
49            .map_err(serde::de::Error::custom)
50            .map(Base64Bytes)
51    }
52}
53
54/// Type representing an empty JSON object
55///
56/// This type is used when you want to output an empty JSON object `{}` in JSON serialization,
57/// and accept any JSON object when deserializing, but its content is ignored.
58///
59/// # Example
60///
61/// ```
62/// use mcp_attr::utils::Empty;
63/// use serde_json::json;
64///
65/// let empty = Empty::default();
66/// let json = json!(empty);
67/// assert_eq!(json, json!({}));
68///
69/// let empty: Empty = serde_json::from_value(json!({ "key": "value" })).unwrap();
70/// let json = json!(empty);
71/// assert_eq!(json, json!({}));
72/// ```
73#[derive(Serialize, Default)]
74#[serde(transparent)]
75pub struct Empty(#[allow(unused)] Map<String, Value>);
76
77impl<'de> Deserialize<'de> for Empty {
78    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79    where
80        D: Deserializer<'de>,
81    {
82        let _: Map<String, Value> = Deserialize::deserialize(deserializer)?;
83        Ok(Empty::default())
84    }
85}
86
/// Marker that is serialized as a fixed string chosen by its type parameter.
///
/// `Tag<T>` always serializes to the string `T::TAG` (see [`TagData`]). It deserializes only
/// from a string equal to `T::TAG`; any other string is rejected. On success, the wrapped value
/// is `T::default()`.
///
/// # Example
///
/// ```
/// use mcp_attr::utils::{Tag, TagData};
/// use serde_json::json;
///
/// #[derive(Default)]
/// struct MyTag;
///
/// impl TagData for MyTag {
///     const TAG: &'static str = "my-tag";
/// }
///
/// let json = serde_json::to_value(&Tag(MyTag::default())).unwrap();
/// assert_eq!(json, json!("my-tag"));
///
/// let _tag: Tag<MyTag> = serde_json::from_value(json).unwrap();
/// ```
pub struct Tag<T>(pub T);
114
/// Supplies the fixed string that [`Tag`] writes and accepts for this type.
///
/// Implementors must be `Default`, because a successfully deserialized `Tag<Self>` is built from
/// `Self::default()`.
pub trait TagData
where
    Self: Default,
{
    /// The exact string that a `Tag<Self>` serializes to and accepts when deserializing.
    const TAG: &'static str;
}
122
123impl<T: TagData> Serialize for Tag<T> {
124    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
125    where
126        S: serde::Serializer,
127    {
128        serializer.serialize_str(T::TAG)
129    }
130}
131
132impl<'de, T: TagData> Deserialize<'de> for Tag<T> {
133    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134    where
135        D: Deserializer<'de>,
136    {
137        let s: Cow<'de, str> = Deserialize::deserialize(deserializer)?;
138        if s != T::TAG {
139            return Err(serde::de::Error::custom(format!("expected tag {}", T::TAG)));
140        }
141        Ok(Tag(T::default()))
142    }
143}
144
145#[derive(
146    Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Display, Clone, Copy,
147)]
148pub struct ProtocolVersion(&'static str);
149
150impl ProtocolVersion {
151    pub const LATEST: Self = Self::V_2025_03_26;
152    pub const V_2024_11_05: Self = Self("2024-11-05");
153    pub const V_2025_03_26: Self = Self("2025-03-26");
154
155    pub fn as_str(&self) -> &'static str {
156        self.0
157    }
158}
159
160#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Display)]
161#[display("{json}")]
162pub struct Json<T> {
163    json: String,
164    _marker: PhantomData<T>,
165}
166
167impl<T: Serialize> Json<T> {
168    pub fn from(value: &T) -> Result<Self, serde_json::Error> {
169        Ok(Self {
170            json: serde_json::to_string_pretty(value)?,
171            _marker: PhantomData,
172        })
173    }
174}
175impl<T> Json<T> {
176    pub fn into_string(self) -> String {
177        self.json
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_base64_bytes() {
        let bytes = Base64Bytes(vec![1, 2, 3, 4, 5]);
        let json = json!(bytes);
        assert_eq!(json, json!("AQIDBAU="));

        let bytes: Base64Bytes = serde_json::from_value(json).unwrap();
        assert_eq!(bytes.0, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn base64_bytes_empty_round_trip() {
        let json = json!(Base64Bytes::default());
        assert_eq!(json, json!(""));

        let bytes: Base64Bytes = serde_json::from_value(json).unwrap();
        assert!(bytes.0.is_empty());
    }

    #[test]
    fn base64_bytes_rejects_invalid_input() {
        // Characters outside the Base64 alphabet.
        assert!(serde_json::from_value::<Base64Bytes>(json!("not base64!")).is_err());
        // Not a string at all.
        assert!(serde_json::from_value::<Base64Bytes>(json!(42)).is_err());
    }

    #[test]
    fn empty_discards_contents_and_requires_object() {
        assert_eq!(json!(Empty::default()), json!({}));

        let parsed: Empty = serde_json::from_value(json!({ "a": 1 })).unwrap();
        assert_eq!(json!(parsed), json!({}));

        assert!(serde_json::from_value::<Empty>(json!([1, 2])).is_err());
        assert!(serde_json::from_value::<Empty>(json!("x")).is_err());
    }

    #[derive(Default)]
    struct TestTag;

    impl TagData for TestTag {
        const TAG: &'static str = "test-tag";
    }

    #[test]
    fn tag_round_trip_and_mismatch() {
        assert_eq!(serde_json::to_value(Tag(TestTag)).unwrap(), json!("test-tag"));
        assert!(serde_json::from_value::<Tag<TestTag>>(json!("test-tag")).is_ok());

        let err = serde_json::from_value::<Tag<TestTag>>(json!("other")).unwrap_err();
        assert!(err.to_string().contains("expected tag test-tag"));
    }

    #[test]
    fn protocol_version_serializes_as_string() {
        assert_eq!(ProtocolVersion::V_2024_11_05.as_str(), "2024-11-05");
        assert_eq!(
            json!(ProtocolVersion::LATEST),
            json!(ProtocolVersion::LATEST.as_str())
        );
        assert!(ProtocolVersion::V_2024_11_05 < ProtocolVersion::V_2025_03_26);
    }

    #[test]
    fn json_holds_pretty_text() {
        let rendered = Json::from(&vec![1, 2]).unwrap();
        assert_eq!(rendered.to_string(), "[\n  1,\n  2\n]");
        assert_eq!(rendered.into_string(), "[\n  1,\n  2\n]");
    }
}