leptos_ws_pro/codec/
mod.rs1use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
7use thiserror::Error;
8
9pub trait Codec<T>: Send + Sync
11where
12 T: Send + Sync,
13{
14 fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError>;
16
17 fn decode(&self, data: &[u8]) -> Result<T, CodecError>;
19
20 fn content_type(&self) -> &'static str;
22}
23
24#[derive(Debug, Error)]
26pub enum CodecError {
27 #[error("Serialization failed: {0}")]
28 SerializationFailed(String),
29
30 #[error("Deserialization failed: {0}")]
31 DeserializationFailed(String),
32
33 #[error("Compression failed: {0}")]
34 CompressionFailed(String),
35
36 #[error("Decompression failed: {0}")]
37 DecompressionFailed(String),
38
39 #[error("Compression not supported: {0}")]
40 CompressionNotSupported(String),
41}
42
/// Codec that encodes messages as JSON using `serde_json`.
///
/// The type is stateless. [`JsonCodec::new`] and [`Default::default`] yield the same value,
/// and copying it is free.
#[derive(Debug, Clone, Copy, Default)]
pub struct JsonCodec;
45
46impl JsonCodec {
47 pub fn new() -> Self {
48 Self
49 }
50}
51
52impl<T> Codec<T> for JsonCodec
53where
54 T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
55{
56 fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
57 serde_json::to_vec(message).map_err(|e| CodecError::SerializationFailed(e.to_string()))
58 }
59
60 fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
61 serde_json::from_slice(data).map_err(|e| CodecError::DeserializationFailed(e.to_string()))
62 }
63
64 fn content_type(&self) -> &'static str {
65 "application/json"
66 }
67}
68
/// Codec labelled as rkyv (`"application/rkyv"`).
///
/// The type is stateless. [`RkyvCodec::new`] and [`Default::default`] yield the same value.
/// NOTE(review): its `Codec` impl currently serializes through `serde_json`, so the
/// bytes it produces are JSON.
#[derive(Debug, Clone, Copy, Default)]
pub struct RkyvCodec;
71
72impl RkyvCodec {
73 pub fn new() -> Self {
74 Self
75 }
76}
77
78impl<T> Codec<T> for RkyvCodec
80where
81 T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
82{
83 fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
84 #[cfg(feature = "zero-copy")]
86 {
87 }
93
94 serde_json::to_vec(message).map_err(|e| CodecError::SerializationFailed(e.to_string()))
96 }
97
98 fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
99 #[cfg(feature = "zero-copy")]
101 {
102 }
108
109 serde_json::from_slice(data).map_err(|e| CodecError::DeserializationFailed(e.to_string()))
111 }
112
113 fn content_type(&self) -> &'static str {
114 "application/rkyv"
115 }
116}
117
118pub struct HybridCodec {
120 rkyv_codec: RkyvCodec,
121 json_codec: JsonCodec,
122}
123
124impl HybridCodec {
125 pub fn new() -> Result<Self, CodecError> {
126 Ok(Self {
127 rkyv_codec: RkyvCodec::new(),
128 json_codec: JsonCodec::new(),
129 })
130 }
131}
132
133impl<T> Codec<T> for HybridCodec
134where
135 T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
136{
137 fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
138 match self.rkyv_codec.encode(message) {
140 Ok(data) => Ok(data),
141 Err(_) => {
142 self.json_codec.encode(message)
144 }
145 }
146 }
147
148 fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
149 match self.json_codec.decode(data) {
151 Ok(result) => Ok(result),
152 Err(_) => {
153 match self.rkyv_codec.decode(data) {
155 Ok(result) => Ok(result),
156 Err(_e) => {
157 self.json_codec.decode(data)
159 }
160 }
161 }
162 }
163 }
164
165 fn content_type(&self) -> &'static str {
166 "application/hybrid"
167 }
168}
169
170#[derive(Debug, Clone, SerdeSerialize, SerdeDeserialize)]
172pub struct WsMessage<T> {
173 pub data: T,
174}
175
176impl<T> WsMessage<T> {
177 pub fn new(data: T) -> Self {
178 Self { data }
179 }
180}
181
/// Wraps another codec and gzip-compresses its output.
///
/// Compression happens only when the `compression` feature is enabled; otherwise
/// the bytes pass through unchanged.
#[derive(Debug, Clone)]
pub struct CompressedCodec<C> {
    /// Codec that produces and consumes the uncompressed bytes.
    inner: C,
    /// gzip level handed to flate2 when encoding. flate2 documents a 0–9 scale.
    compression_level: i32,
}
187
188impl<C> CompressedCodec<C> {
189 pub fn new(inner: C) -> Self {
190 Self {
191 inner,
192 compression_level: 3, }
194 }
195
196 pub fn with_level(inner: C, level: i32) -> Self {
197 Self {
198 inner,
199 compression_level: level,
200 }
201 }
202}
203
204impl<T, C> Codec<T> for CompressedCodec<C>
205where
206 C: Codec<T>,
207 T: Send + Sync,
208{
209 fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
210 let uncompressed = self.inner.encode(message)?;
212
213 #[cfg(feature = "compression")]
215 {
216 use std::io::Write;
217 let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::new(self.compression_level as u32));
218 encoder.write_all(&uncompressed)
219 .map_err(|e| CodecError::CompressionFailed(e.to_string()))?;
220 encoder.finish()
221 .map_err(|e| CodecError::CompressionFailed(e.to_string()))
222 }
223
224 #[cfg(not(feature = "compression"))]
225 {
226 Ok(uncompressed)
228 }
229 }
230
231 fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
232 #[cfg(feature = "compression")]
234 let decompressed = {
235 use std::io::Read;
236 let mut decoder = flate2::read::GzDecoder::new(data);
237 let mut decompressed = Vec::new();
238 decoder.read_to_end(&mut decompressed)
239 .map_err(|e| CodecError::DecompressionFailed(e.to_string()))?;
240 decompressed
241 };
242
243 #[cfg(not(feature = "compression"))]
244 let decompressed = data.to_vec();
245
246 self.inner.decode(&decompressed)
248 }
249
250 fn content_type(&self) -> &'static str {
251 "application/gzip"
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestMessage {
        id: u32,
        content: String,
    }

    /// Fixture shared by the tests: a message with a fixed id and the given text.
    fn sample(content: &str) -> TestMessage {
        TestMessage {
            id: 42,
            content: content.to_string(),
        }
    }

    /// Encodes `message` with `codec` and decodes the result, panicking on any error.
    fn roundtrip<C: Codec<TestMessage>>(codec: &C, message: &TestMessage) -> TestMessage {
        let encoded = codec.encode(message).unwrap();
        codec.decode(&encoded).unwrap()
    }

    #[test]
    fn test_json_codec_basic() {
        let message = sample("Hello, World!");
        assert_eq!(roundtrip(&JsonCodec::new(), &message), message);
    }

    #[test]
    fn test_rkyv_codec_basic() {
        let message = sample("Hello, World!");
        assert_eq!(roundtrip(&RkyvCodec::new(), &message), message);
    }

    #[test]
    fn test_hybrid_codec_basic() {
        let message = sample("Hello, World!");
        assert_eq!(roundtrip(&HybridCodec::new().unwrap(), &message), message);
    }

    #[test]
    fn test_hybrid_codec_decodes_plain_json() {
        let message = sample("from json");
        let encoded = JsonCodec::new().encode(&message).unwrap();
        let decoded: TestMessage = HybridCodec::new().unwrap().decode(&encoded).unwrap();
        assert_eq!(decoded, message);
    }

    #[test]
    fn test_compressed_codec_roundtrip() {
        let message = sample("Hello, World!");
        let default_level = CompressedCodec::new(JsonCodec::new());
        let max_level = CompressedCodec::with_level(JsonCodec::new(), 9);
        assert_eq!(roundtrip(&default_level, &message), message);
        assert_eq!(roundtrip(&max_level, &message), message);
    }

    #[test]
    fn test_invalid_input_is_reported_as_deserialization_error() {
        let garbage = b"definitely not json";

        let json: Result<TestMessage, CodecError> = JsonCodec::new().decode(garbage);
        assert!(matches!(json, Err(CodecError::DeserializationFailed(_))));

        let hybrid: Result<TestMessage, CodecError> =
            HybridCodec::new().unwrap().decode(garbage);
        assert!(matches!(hybrid, Err(CodecError::DeserializationFailed(_))));
    }

    #[test]
    fn test_content_types() {
        assert_eq!(
            <JsonCodec as Codec<TestMessage>>::content_type(&JsonCodec::new()),
            "application/json"
        );
        assert_eq!(
            <RkyvCodec as Codec<TestMessage>>::content_type(&RkyvCodec::new()),
            "application/rkyv"
        );
        assert_eq!(
            <HybridCodec as Codec<TestMessage>>::content_type(&HybridCodec::new().unwrap()),
            "application/hybrid"
        );
        assert_eq!(
            <CompressedCodec<JsonCodec> as Codec<TestMessage>>::content_type(
                &CompressedCodec::new(JsonCodec::new())
            ),
            "application/gzip"
        );
    }

    #[test]
    fn test_ws_message_wrapper() {
        let test_data = sample("test");
        let ws_message = WsMessage::new(test_data.clone());
        assert_eq!(ws_message.data, test_data);
    }

    #[test]
    fn test_ws_message_serialization() {
        let ws_message = WsMessage::new(sample("test"));

        let json_encoded = serde_json::to_string(&ws_message).unwrap();
        // Pin the wire shape: the payload sits under a single `data` key.
        assert_eq!(json_encoded, r#"{"data":{"id":42,"content":"test"}}"#);

        let json_decoded: WsMessage<TestMessage> = serde_json::from_str(&json_encoded).unwrap();
        assert_eq!(ws_message.data, json_decoded.data);
    }
}