1use gbp::CodecError;
4use gbp_core::PayloadCodec;
5use serde::{Deserialize, Serialize};
6use serde_bytes::ByteBuf;
7
/// Kind of payload carried in a [`GtpMessage`] body.
///
/// The explicit discriminants are the values written into
/// [`GtpMessage::content_type`] (via `as u8`), so they are part of the wire
/// format and must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum GtpContentType {
    /// Plain text; the type used by [`GtpMessage::plain`].
    Plain = 0x00,
    /// Markdown text.
    Markdown = 0x01,
    /// Binary body.
    Binary = 0x02,
    /// Attachment reference. NOTE(review): reference format is not visible here.
    AttachmentRef = 0x03,
}
21
22#[derive(Clone, Debug, Serialize, Deserialize)]
24pub struct GtpMessage {
25 #[serde(rename = "mid")]
27 pub message_id: u64,
28 #[serde(rename = "sid")]
30 pub sender_id: u32,
31 #[serde(rename = "ts")]
33 pub timestamp_ms: u64,
34 #[serde(rename = "rid")]
36 pub request_id: u32,
37 #[serde(rename = "fl")]
39 pub flags: u8,
40 #[serde(rename = "ct")]
42 pub content_type: u8,
43 #[serde(rename = "len")]
45 pub content_length: u32,
46 #[serde(rename = "body")]
48 pub content: ByteBuf,
49}
50
51impl GtpMessage {
52 pub fn plain(sender_id: u32, message_id: u64, text: &str) -> Self {
54 let body = text.as_bytes().to_vec();
55 Self {
56 message_id,
57 sender_id,
58 timestamp_ms: 0,
59 request_id: 0,
60 flags: 0x01,
61 content_type: GtpContentType::Plain as u8,
62 content_length: body.len() as u32,
63 content: ByteBuf::from(body),
64 }
65 }
66
67 pub fn to_cbor(&self) -> Vec<u8> {
69 let mut buf = Vec::new();
70 ciborium::into_writer(self, &mut buf).expect("cbor encode");
71 buf
72 }
73
74 pub fn from_cbor(data: &[u8]) -> Result<Self, CodecError> {
76 let m: Self = ciborium::from_reader(data).map_err(|e| CodecError::Decode(e.to_string()))?;
77 if m.content_length as usize != m.content.len() {
78 return Err(CodecError::PayloadSizeMismatch);
79 }
80 Ok(m)
81 }
82
83 pub fn text(&self) -> Option<&str> {
85 std::str::from_utf8(&self.content).ok()
86 }
87
88 pub fn to_bytes(&self, codec: PayloadCodec) -> Vec<u8> {
90 match codec {
91 PayloadCodec::Cbor => self.to_cbor(),
92 PayloadCodec::Protobuf => {
93 use prost::Message as _;
94 gbp_proto::gtp::GtpMessage::from(self).encode_to_vec()
95 }
96 PayloadCodec::FlatBuffers => {
97 let mut b = gbp_flat::planus::Builder::new();
98 b.finish(gbp_flat::gtp::GtpMessage::from(self), None).to_vec()
99 }
100 }
101 }
102
103 pub fn from_bytes(data: &[u8], codec: PayloadCodec) -> Result<Self, CodecError> {
105 match codec {
106 PayloadCodec::Cbor => Self::from_cbor(data),
107 PayloadCodec::Protobuf => {
108 use prost::Message as _;
109 let p = gbp_proto::gtp::GtpMessage::decode(data)
110 .map_err(|e| CodecError::Decode(e.to_string()))?;
111 Self::try_from(p).map_err(|_| CodecError::PayloadSizeMismatch)
112 }
113 PayloadCodec::FlatBuffers => {
114 use gbp_flat::planus::ReadAsRoot as _;
115 let r = gbp_flat::gtp::GtpMessageRef::read_as_root(data)
116 .map_err(|e| CodecError::Decode(e.to_string()))?;
117 Self::try_from(r).map_err(|_| CodecError::PayloadSizeMismatch)
118 }
119 }
120 }
121}
122
123impl From<&GtpMessage> for gbp_proto::gtp::GtpMessage {
126 fn from(m: &GtpMessage) -> Self {
127 Self {
128 message_id: m.message_id,
129 sender_id: m.sender_id,
130 timestamp_ms: m.timestamp_ms,
131 request_id: m.request_id,
132 flags: m.flags as u32,
133 content_type: m.content_type as u32,
134 content_length: m.content_length,
135 content: m.content.to_vec(),
136 }
137 }
138}
139
140impl TryFrom<gbp_proto::gtp::GtpMessage> for GtpMessage {
141 type Error = ();
142 fn try_from(p: gbp_proto::gtp::GtpMessage) -> Result<Self, ()> {
143 if p.content_length as usize != p.content.len() {
144 return Err(());
145 }
146 Ok(Self {
147 message_id: p.message_id,
148 sender_id: p.sender_id,
149 timestamp_ms: p.timestamp_ms,
150 request_id: p.request_id,
151 flags: p.flags as u8,
152 content_type: p.content_type as u8,
153 content_length: p.content_length,
154 content: ByteBuf::from(p.content),
155 })
156 }
157}
158
159impl From<&GtpMessage> for gbp_flat::gtp::GtpMessage {
162 fn from(m: &GtpMessage) -> Self {
163 Self {
164 message_id: m.message_id,
165 sender_id: m.sender_id,
166 timestamp_ms: m.timestamp_ms,
167 request_id: m.request_id,
168 flags: m.flags as u32,
169 content_type: m.content_type as u32,
170 content_length: m.content_length,
171 content: Some(m.content.to_vec()),
172 }
173 }
174}
175
176impl<'a> TryFrom<gbp_flat::gtp::GtpMessageRef<'a>> for GtpMessage {
177 type Error = ();
178 fn try_from(r: gbp_flat::gtp::GtpMessageRef<'a>) -> Result<Self, ()> {
179 let content = r.content().map_err(|_| ())?.unwrap_or(&[]).to_vec();
180 let content_length = r.content_length().map_err(|_| ())?;
181 if content_length as usize != content.len() {
182 return Err(());
183 }
184 Ok(Self {
185 message_id: r.message_id().map_err(|_| ())?,
186 sender_id: r.sender_id().map_err(|_| ())?,
187 timestamp_ms: r.timestamp_ms().map_err(|_| ())?,
188 request_id: r.request_id().map_err(|_| ())?,
189 flags: r.flags().map_err(|_| ())? as u8,
190 content_type: r.content_type().map_err(|_| ())? as u8,
191 content_length,
192 content: ByteBuf::from(content),
193 })
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    const TEXT: &str = "codec roundtrip";

    /// A message with non-zero header fields, so a codec that dropped or
    /// defaulted any of them would be caught by `assert_same`.
    fn sample() -> GtpMessage {
        let mut msg = GtpMessage::plain(42, 0xDEAD_BEEF, TEXT);
        msg.timestamp_ms = 1_700_000_000_123;
        msg.request_id = 7;
        msg
    }

    /// Asserts that every field of `decoded` matches `orig`.
    fn assert_same(decoded: &GtpMessage, orig: &GtpMessage) {
        assert_eq!(decoded.message_id, orig.message_id);
        assert_eq!(decoded.sender_id, orig.sender_id);
        assert_eq!(decoded.timestamp_ms, orig.timestamp_ms);
        assert_eq!(decoded.request_id, orig.request_id);
        assert_eq!(decoded.flags, orig.flags);
        assert_eq!(decoded.content_type, orig.content_type);
        assert_eq!(decoded.content_length, orig.content_length);
        assert_eq!(decoded.content, orig.content);
        assert_eq!(decoded.text(), Some(TEXT));
    }

    #[test]
    fn cbor_roundtrip() {
        let orig = sample();
        let bytes = orig.to_bytes(PayloadCodec::Cbor);
        let decoded = GtpMessage::from_bytes(&bytes, PayloadCodec::Cbor).unwrap();
        assert_same(&decoded, &orig);
    }

    #[test]
    fn protobuf_roundtrip() {
        let orig = sample();
        let bytes = orig.to_bytes(PayloadCodec::Protobuf);
        let decoded = GtpMessage::from_bytes(&bytes, PayloadCodec::Protobuf).unwrap();
        assert_same(&decoded, &orig);
    }

    #[test]
    fn flatbuffers_roundtrip() {
        let orig = sample();
        let bytes = orig.to_bytes(PayloadCodec::FlatBuffers);
        let decoded = GtpMessage::from_bytes(&bytes, PayloadCodec::FlatBuffers).unwrap();
        assert_same(&decoded, &orig);
    }

    #[test]
    fn codec_bytes_differ() {
        let msg = sample();
        let cbor = msg.to_bytes(PayloadCodec::Cbor);
        let proto = msg.to_bytes(PayloadCodec::Protobuf);
        let flat = msg.to_bytes(PayloadCodec::FlatBuffers);
        assert_ne!(cbor, proto);
        assert_ne!(cbor, flat);
        assert_ne!(proto, flat);
    }

    /// Every decoder must refuse a message whose declared length disagrees with its body.
    #[test]
    fn length_mismatch_rejected_by_every_codec() {
        let mut msg = sample();
        msg.content_length += 1;

        let cbor = msg.to_bytes(PayloadCodec::Cbor);
        assert!(matches!(
            GtpMessage::from_bytes(&cbor, PayloadCodec::Cbor),
            Err(CodecError::PayloadSizeMismatch)
        ));

        let proto = msg.to_bytes(PayloadCodec::Protobuf);
        assert!(matches!(
            GtpMessage::from_bytes(&proto, PayloadCodec::Protobuf),
            Err(CodecError::PayloadSizeMismatch)
        ));

        let flat = msg.to_bytes(PayloadCodec::FlatBuffers);
        assert!(matches!(
            GtpMessage::from_bytes(&flat, PayloadCodec::FlatBuffers),
            Err(CodecError::PayloadSizeMismatch)
        ));
    }
}