1use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
7use thiserror::Error;
8
9pub trait Codec<T>: Send + Sync
11where
12 T: Send + Sync,
13{
14 fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError>;
16
17 fn decode(&self, data: &[u8]) -> Result<T, CodecError>;
19
20 fn content_type(&self) -> &'static str;
22}
23
24#[derive(Debug, Error)]
26pub enum CodecError {
27 #[error("Serialization failed: {0}")]
28 SerializationFailed(String),
29
30 #[error("Deserialization failed: {0}")]
31 DeserializationFailed(String),
32
33 #[error("Compression failed: {0}")]
34 CompressionFailed(String),
35
36 #[error("Decompression failed: {0}")]
37 DecompressionFailed(String),
38
39 #[error("Compression not supported: {0}")]
40 CompressionNotSupported(String),
41}
42
/// Stateless codec handle for JSON encoding.
///
/// The type carries no data, so it is free to copy and can be created with
/// either [`JsonCodec::new`] or [`Default::default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct JsonCodec;

impl JsonCodec {
    /// Creates a new `JsonCodec`.
    ///
    /// Equivalent to `JsonCodec::default()`; provided as a `const fn` so the
    /// codec can also be built in constant contexts.
    pub const fn new() -> Self {
        Self
    }
}
51
52impl<T> Codec<T> for JsonCodec
53where
54 T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
55{
56 fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
57 serde_json::to_vec(message).map_err(|e| CodecError::SerializationFailed(e.to_string()))
58 }
59
60 fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
61 serde_json::from_slice(data).map_err(|e| CodecError::DeserializationFailed(e.to_string()))
62 }
63
64 fn content_type(&self) -> &'static str {
65 "application/json"
66 }
67}
68
/// Stateless codec handle for the rkyv-labelled codec.
///
/// The type carries no data, so it is free to copy and can be created with
/// either [`RkyvCodec::new`] or [`Default::default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct RkyvCodec;

impl RkyvCodec {
    /// Creates a new `RkyvCodec`.
    ///
    /// Equivalent to `RkyvCodec::default()`; provided as a `const fn` so the
    /// codec can also be built in constant contexts.
    pub const fn new() -> Self {
        Self
    }
}
77
78impl<T> Codec<T> for RkyvCodec
80where
81 T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
82{
83 fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
84 serde_json::to_vec(message).map_err(|e| CodecError::SerializationFailed(e.to_string()))
87 }
88
89 fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
90 serde_json::from_slice(data).map_err(|e| CodecError::DeserializationFailed(e.to_string()))
92 }
93
94 fn content_type(&self) -> &'static str {
95 "application/rkyv"
96 }
97}
98
99pub struct HybridCodec {
101 rkyv_codec: RkyvCodec,
102 json_codec: JsonCodec,
103}
104
105impl HybridCodec {
106 pub fn new() -> Result<Self, CodecError> {
107 Ok(Self {
108 rkyv_codec: RkyvCodec::new(),
109 json_codec: JsonCodec::new(),
110 })
111 }
112}
113
114impl<T> Codec<T> for HybridCodec
115where
116 T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
117{
118 fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
119 match self.rkyv_codec.encode(message) {
121 Ok(data) => Ok(data),
122 Err(_) => {
123 self.json_codec.encode(message)
125 }
126 }
127 }
128
129 fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
130 match self.rkyv_codec.decode(data) {
132 Ok(result) => Ok(result),
133 Err(_) => {
134 self.json_codec.decode(data)
136 }
137 }
138 }
139
140 fn content_type(&self) -> &'static str {
141 "application/hybrid"
142 }
143}
144
145#[derive(Debug, Clone, SerdeSerialize, SerdeDeserialize)]
147pub struct WsMessage<T> {
148 pub data: T,
149}
150
151impl<T> WsMessage<T> {
152 pub fn new(data: T) -> Self {
153 Self { data }
154 }
155}
156
/// Wrapper that layers compression on top of another codec.
pub struct CompressedCodec<C> {
    /// Codec that performs the actual message (de)serialization.
    inner: C,
    /// Compression level; the constructors keep it within
    /// `MIN_LEVEL..=MAX_LEVEL`.
    compression_level: i32,
}

impl<C> CompressedCodec<C> {
    /// Level used by [`CompressedCodec::new`].
    const DEFAULT_LEVEL: i32 = 3;
    /// Lowest accepted level.
    const MIN_LEVEL: i32 = 0;
    /// Highest accepted level.
    const MAX_LEVEL: i32 = 9;

    /// Wraps `inner` using the default compression level (3).
    pub fn new(inner: C) -> Self {
        Self::with_level(inner, Self::DEFAULT_LEVEL)
    }

    /// Wraps `inner` using the given compression `level`.
    ///
    /// The level is clamped to `0..=9`. The encoder converts the stored
    /// level to `u32`, so a negative value would otherwise wrap around to a
    /// huge number, and values above 9 fall outside the usual 0-9 scale.
    pub fn with_level(inner: C, level: i32) -> Self {
        Self {
            inner,
            compression_level: level.clamp(Self::MIN_LEVEL, Self::MAX_LEVEL),
        }
    }
}
178
179impl<T, C> Codec<T> for CompressedCodec<C>
180where
181 C: Codec<T>,
182 T: Send + Sync,
183{
184 fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
185 let uncompressed = self.inner.encode(message)?;
187
188 #[cfg(feature = "compression")]
190 {
191 use std::io::Write;
192 let mut encoder = flate2::write::GzEncoder::new(
193 Vec::new(),
194 flate2::Compression::new(self.compression_level as u32),
195 );
196 encoder
197 .write_all(&uncompressed)
198 .map_err(|e| CodecError::CompressionFailed(e.to_string()))?;
199 encoder
200 .finish()
201 .map_err(|e| CodecError::CompressionFailed(e.to_string()))
202 }
203
204 #[cfg(not(feature = "compression"))]
205 {
206 Ok(uncompressed)
208 }
209 }
210
211 fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
212 #[cfg(feature = "compression")]
214 let decompressed = {
215 use std::io::Read;
216 let mut decoder = flate2::read::GzDecoder::new(data);
217 let mut decompressed = Vec::new();
218 decoder
219 .read_to_end(&mut decompressed)
220 .map_err(|e| CodecError::DecompressionFailed(e.to_string()))?;
221 decompressed
222 };
223
224 #[cfg(not(feature = "compression"))]
225 let decompressed = data.to_vec();
226
227 self.inner.decode(&decompressed)
229 }
230
231 fn content_type(&self) -> &'static str {
232 "application/gzip"
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal payload shared by every test in this module.
    #[derive(
        Debug,
        Clone,
        Serialize,
        Deserialize,
        PartialEq,
        rkyv::Archive,
        rkyv::Serialize,
        rkyv::Deserialize,
    )]
    struct TestMessage {
        id: u32,
        content: String,
    }

    /// Builds a `TestMessage` with the id used throughout the tests (42).
    fn sample(content: &str) -> TestMessage {
        TestMessage {
            id: 42,
            content: content.to_string(),
        }
    }

    /// Encodes `original` with `codec`, decodes the bytes again and checks
    /// that the result equals the input.
    fn assert_round_trip<C>(codec: &C, original: &TestMessage)
    where
        C: Codec<TestMessage>,
    {
        let bytes = codec.encode(original).unwrap();
        let restored: TestMessage = codec.decode(&bytes).unwrap();

        assert_eq!(original, &restored);
    }

    #[test]
    fn test_json_codec_basic() {
        assert_round_trip(&JsonCodec::new(), &sample("Hello, World!"));
    }

    #[test]
    fn test_rkyv_codec_basic() {
        assert_round_trip(&RkyvCodec::new(), &sample("Hello, World!"));
    }

    #[test]
    fn test_hybrid_codec_basic() {
        let codec = HybridCodec::new().unwrap();

        assert_round_trip(&codec, &sample("Hello, World!"));
    }

    #[test]
    fn test_ws_message_wrapper() {
        let payload = sample("test");
        let wrapped = WsMessage::new(payload.clone());

        assert_eq!(wrapped.data, payload);
    }

    #[test]
    fn test_ws_message_serialization() {
        let wrapped = WsMessage::new(sample("test"));

        // Round-trip the envelope through its JSON text form.
        let json_text = serde_json::to_string(&wrapped).unwrap();
        let parsed: WsMessage<TestMessage> = serde_json::from_str(&json_text).unwrap();

        assert_eq!(wrapped.data, parsed.data);
    }

    /// Compares encoded sizes only; no timing is measured here.
    #[test]
    fn test_rkyv_performance() {
        let message = sample("Hello, rkyv!");

        let rkyv_codec = RkyvCodec::new();
        let rkyv_bytes = rkyv_codec.encode(&message).unwrap();
        let restored: TestMessage = rkyv_codec.decode(&rkyv_bytes).unwrap();
        assert_eq!(message, restored);

        let json_bytes = JsonCodec::new().encode(&message).unwrap();

        assert!(rkyv_bytes.len() <= json_bytes.len());

        println!("JSON size: {} bytes", json_bytes.len());
        println!("rkyv size: {} bytes", rkyv_bytes.len());
    }
}