/// Owned byte buffer: a thin newtype wrapper around `Vec<u8>`.
///
/// The wrapped vector is exposed as the public tuple field `.0`.
/// Equality, cloning and debug formatting are derived from the inner vector,
/// and the `Default` value wraps an empty vector.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Bytes(
    /// The raw bytes carried by this value.
    pub Vec<u8>,
);
10
11impl From<Vec<u8>> for Bytes {
12 fn from(v: Vec<u8>) -> Self {
13 Bytes(v)
14 }
15}
16
17impl From<Bytes> for Vec<u8> {
18 fn from(b: Bytes) -> Self {
19 b.0
20 }
21}
22
23impl AsRef<[u8]> for Bytes {
24 fn as_ref(&self) -> &[u8] {
25 &self.0
26 }
27}
28
29impl serde::Serialize for Bytes {
30 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
31 serializer.serialize_bytes(&self.0)
32 }
33}
34
35impl crate::response::IntoResponse for Bytes {
36 fn into_tool_events(self, default_language: &str) -> Vec<crate::context::RawToolEvent> {
37 self.0.into_tool_events(default_language)
39 }
40}
41
42impl<'de> serde::Deserialize<'de> for Bytes {
43 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
44 struct BytesVisitor;
45
46 impl serde::de::Visitor<'_> for BytesVisitor {
47 type Value = Bytes;
48
49 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
50 f.write_str("a CBOR byte string or a base64-encoded string")
51 }
52
53 fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Bytes, E> {
54 Ok(Bytes(v.to_vec()))
55 }
56
57 fn visit_byte_buf<E: serde::de::Error>(self, v: Vec<u8>) -> Result<Bytes, E> {
58 Ok(Bytes(v))
59 }
60
61 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Bytes, E> {
62 use base64::Engine as _;
63 base64::engine::general_purpose::STANDARD
64 .decode(v)
65 .map(Bytes)
66 .map_err(|e| E::custom(format!("invalid base64 for Bytes: {e}")))
67 }
68
69 fn visit_string<E: serde::de::Error>(self, v: String) -> Result<Bytes, E> {
70 self.visit_str(&v)
71 }
72 }
73
74 deserializer.deserialize_any(BytesVisitor)
75 }
76}
77
78impl schemars::JsonSchema for Bytes {
79 fn schema_name() -> std::borrow::Cow<'static, str> {
80 "Bytes".into()
81 }
82
83 fn inline_schema() -> bool {
86 true
87 }
88
89 fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
90 schemars::json_schema!({
91 "type": "string",
92 "contentEncoding": "base64",
93 "contentMediaType": "application/octet-stream"
94 })
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use ciborium::value::Value;

    /// Encodes a CBOR value into a freshly allocated buffer.
    fn cbor_of(v: &Value) -> Vec<u8> {
        let mut encoded = Vec::new();
        ciborium::into_writer(v, &mut encoded).unwrap();
        encoded
    }

    /// Serializing `Bytes` must produce a CBOR byte string.
    #[test]
    fn serializes_to_cbor_byte_string() {
        let payload = Bytes(b"hello".to_vec());
        let mut encoded = Vec::new();
        ciborium::into_writer(&payload, &mut encoded).unwrap();

        let decoded: Value = ciborium::from_reader(encoded.as_slice()).unwrap();
        assert_eq!(decoded, Value::Bytes(b"hello".to_vec()));
    }

    /// A CBOR byte string is accepted verbatim.
    #[test]
    fn deserializes_from_cbor_byte_string() {
        let encoded = cbor_of(&Value::Bytes(b"hi".to_vec()));
        let decoded: Bytes = ciborium::from_reader(encoded.as_slice()).unwrap();
        assert_eq!(decoded.0, b"hi");
    }

    /// CBOR text is treated as base64 and decoded.
    #[test]
    fn deserializes_from_base64_text() {
        let encoded = cbor_of(&Value::Text("aGVsbG8=".into()));
        let decoded: Bytes = ciborium::from_reader(encoded.as_slice()).unwrap();
        assert_eq!(decoded.0, b"hello");
    }

    /// Text that is not base64 at all is an error.
    #[test]
    fn rejects_invalid_base64_text() {
        let encoded = cbor_of(&Value::Text("@@@".into()));
        let outcome: Result<Bytes, _> = ciborium::from_reader(encoded.as_slice());
        assert!(outcome.is_err());
    }

    /// Base64 text with its padding stripped is an error.
    #[test]
    fn rejects_unpadded_base64() {
        let encoded = cbor_of(&Value::Text("aGVsbG8".into()));
        let outcome: Result<Bytes, _> = ciborium::from_reader(encoded.as_slice());
        assert!(outcome.is_err());
    }

    /// The first response event carries the raw bytes as octet-stream content.
    #[test]
    fn bytes_response_is_octet_stream() {
        use crate::response::IntoResponse;

        let first = Bytes(b"\x89PNG".to_vec())
            .into_tool_events("en")
            .into_iter()
            .next()
            .unwrap();
        let crate::context::RawToolEvent::Content {
            data, mime_type, ..
        } = first
        else {
            panic!("expected Content event");
        };

        assert_eq!(
            mime_type.as_deref(),
            Some(crate::constants::MIME_OCTET_STREAM)
        );
        assert_eq!(data, b"\x89PNG");
    }

    /// The standalone schema is a base64 string of octet-stream content.
    #[test]
    fn schema_advertises_base64() {
        let schema = schemars::schema_for!(Bytes);
        let json = serde_json::to_value(&schema).unwrap();

        let expectations = [
            ("type", "string"),
            ("contentEncoding", "base64"),
            ("contentMediaType", "application/octet-stream"),
        ];
        for (key, expected) in expectations {
            assert_eq!(
                json.get(key).and_then(serde_json::Value::as_str),
                Some(expected)
            );
        }
    }

    /// Struct with a `Bytes` field, used to check the derive macros compose.
    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct Params {
        data: Bytes,
    }

    /// `Bytes` works as a field under derived `Deserialize` and `JsonSchema`.
    #[test]
    fn bytes_field_composes_with_derive() {
        let schema = schemars::schema_for!(Params);
        let json = serde_json::to_value(&schema).unwrap();
        assert_eq!(json["properties"]["data"]["contentEncoding"], "base64");

        let entry = (Value::Text("data".into()), Value::Bytes(b"hi".to_vec()));
        let encoded = cbor_of(&Value::Map(vec![entry]));
        let params: Params = ciborium::from_reader(encoded.as_slice()).unwrap();
        assert_eq!(params.data.0, b"hi");
    }
}