leptos_ws_pro/codec/
mod.rs

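//! Codec module for encoding and decoding WebSocket messages.
//!
//! This module provides a simple JSON-based codec system for WebSocket messages,
//! plus an optional gzip wrapper that compresses only when the `compression` feature is enabled.
//! Future versions will include zero-copy serialization with rkyv; until then the
//! rkyv codec defined here falls back to JSON.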
5
6use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
7use thiserror::Error;
8
9/// Trait for encoding and decoding messages
10pub trait Codec<T>: Send + Sync
11where
12    T: Send + Sync,
13{
14    /// Encode a message to bytes
15    fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError>;
16
17    /// Decode bytes to a message
18    fn decode(&self, data: &[u8]) -> Result<T, CodecError>;
19
20    /// Get the content type for this codec
21    fn content_type(&self) -> &'static str;
22}
23
24/// Codec errors
25#[derive(Debug, Error)]
26pub enum CodecError {
27    #[error("Serialization failed: {0}")]
28    SerializationFailed(String),
29
30    #[error("Deserialization failed: {0}")]
31    DeserializationFailed(String),
32
33    #[error("Compression failed: {0}")]
34    CompressionFailed(String),
35
36    #[error("Decompression failed: {0}")]
37    DecompressionFailed(String),
38
39    #[error("Compression not supported: {0}")]
40    CompressionNotSupported(String),
41}
42
/// JSON codec backed by serde.
///
/// The type is a stateless unit struct, so [`JsonCodec::new`] and
/// [`JsonCodec::default`] are interchangeable.
#[derive(Debug, Clone, Default)]
pub struct JsonCodec;

impl JsonCodec {
    /// Creates a new JSON codec.
    pub fn new() -> Self {
        Self
    }
}
51
52impl<T> Codec<T> for JsonCodec
53where
54    T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
55{
56    fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
57        serde_json::to_vec(message).map_err(|e| CodecError::SerializationFailed(e.to_string()))
58    }
59
60    fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
61        serde_json::from_slice(data).map_err(|e| CodecError::DeserializationFailed(e.to_string()))
62    }
63
64    fn content_type(&self) -> &'static str {
65        "application/json"
66    }
67}
68
/// Codec reserved for rkyv-based zero-copy serialization.
///
/// The type is a stateless unit struct, so [`RkyvCodec::new`] and
/// [`RkyvCodec::default`] are interchangeable.
#[derive(Debug, Clone, Default)]
pub struct RkyvCodec;

impl RkyvCodec {
    /// Creates a new rkyv codec.
    pub fn new() -> Self {
        Self
    }
}
77
78// Rkyv implementation with JSON fallback for compatibility
79impl<T> Codec<T> for RkyvCodec
80where
81    T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
82{
83    fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
84        // For now, use JSON but with rkyv content type to indicate future support
85        // In a real implementation, we would check if the type supports rkyv and use it
86        serde_json::to_vec(message).map_err(|e| CodecError::SerializationFailed(e.to_string()))
87    }
88
89    fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
90        // For now, use JSON but with rkyv content type to indicate future support
91        serde_json::from_slice(data).map_err(|e| CodecError::DeserializationFailed(e.to_string()))
92    }
93
94    fn content_type(&self) -> &'static str {
95        "application/rkyv"
96    }
97}
98
99/// Hybrid codec that tries rkyv first, falls back to JSON
100pub struct HybridCodec {
101    rkyv_codec: RkyvCodec,
102    json_codec: JsonCodec,
103}
104
105impl HybridCodec {
106    pub fn new() -> Result<Self, CodecError> {
107        Ok(Self {
108            rkyv_codec: RkyvCodec::new(),
109            json_codec: JsonCodec::new(),
110        })
111    }
112}
113
114impl<T> Codec<T> for HybridCodec
115where
116    T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
117{
118    fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
119        // Try rkyv first for performance
120        match self.rkyv_codec.encode(message) {
121            Ok(data) => Ok(data),
122            Err(_) => {
123                // Fall back to JSON
124                self.json_codec.encode(message)
125            }
126        }
127    }
128
129    fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
130        // Try rkyv first, then JSON
131        match self.rkyv_codec.decode(data) {
132            Ok(result) => Ok(result),
133            Err(_) => {
134                // Fall back to JSON
135                self.json_codec.decode(data)
136            }
137        }
138    }
139
140    fn content_type(&self) -> &'static str {
141        "application/hybrid"
142    }
143}
144
145/// Wrapper for WebSocket messages with type information
146#[derive(Debug, Clone, SerdeSerialize, SerdeDeserialize)]
147pub struct WsMessage<T> {
148    pub data: T,
149}
150
151impl<T> WsMessage<T> {
152    pub fn new(data: T) -> Self {
153        Self { data }
154    }
155}
156
/// Wrapper that pairs an inner codec with a compression level.
///
/// The stored level is always within `0..=9`, the range gzip encoders accept;
/// values outside that range are clamped when the wrapper is built.
#[derive(Debug, Clone)]
pub struct CompressedCodec<C> {
    inner: C,
    compression_level: i32,
}

impl<C> CompressedCodec<C> {
    /// Level used by [`CompressedCodec::new`].
    const DEFAULT_LEVEL: i32 = 3;
    /// Lowest accepted level (no compression).
    const MIN_LEVEL: i32 = 0;
    /// Highest accepted level (best compression).
    const MAX_LEVEL: i32 = 9;

    /// Wraps `inner` using the default compression level (3).
    pub fn new(inner: C) -> Self {
        Self::with_level(inner, Self::DEFAULT_LEVEL)
    }

    /// Wraps `inner` using an explicit compression level.
    ///
    /// `level` is clamped to `0..=9` so that a negative or oversized value can
    /// never reach the compressor.
    pub fn with_level(inner: C, level: i32) -> Self {
        Self {
            inner,
            compression_level: level.clamp(Self::MIN_LEVEL, Self::MAX_LEVEL),
        }
    }
}
178
179impl<T, C> Codec<T> for CompressedCodec<C>
180where
181    C: Codec<T>,
182    T: Send + Sync,
183{
184    fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
185        // First encode with inner codec
186        let uncompressed = self.inner.encode(message)?;
187
188        // Then compress
189        #[cfg(feature = "compression")]
190        {
191            use std::io::Write;
192            let mut encoder = flate2::write::GzEncoder::new(
193                Vec::new(),
194                flate2::Compression::new(self.compression_level as u32),
195            );
196            encoder
197                .write_all(&uncompressed)
198                .map_err(|e| CodecError::CompressionFailed(e.to_string()))?;
199            encoder
200                .finish()
201                .map_err(|e| CodecError::CompressionFailed(e.to_string()))
202        }
203
204        #[cfg(not(feature = "compression"))]
205        {
206            // Return uncompressed if compression feature is not enabled
207            Ok(uncompressed)
208        }
209    }
210
211    fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
212        // First decompress
213        #[cfg(feature = "compression")]
214        let decompressed = {
215            use std::io::Read;
216            let mut decoder = flate2::read::GzDecoder::new(data);
217            let mut decompressed = Vec::new();
218            decoder
219                .read_to_end(&mut decompressed)
220                .map_err(|e| CodecError::DecompressionFailed(e.to_string()))?;
221            decompressed
222        };
223
224        #[cfg(not(feature = "compression"))]
225        let decompressed = data.to_vec();
226
227        // Then decode with inner codec
228        self.inner.decode(&decompressed)
229    }
230
231    fn content_type(&self) -> &'static str {
232        // Indicate compressed content
233        "application/gzip"
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Small message type used by every test in this module.
    #[derive(
        Debug,
        Clone,
        Serialize,
        Deserialize,
        PartialEq,
        rkyv::Archive,
        rkyv::Serialize,
        rkyv::Deserialize,
    )]
    struct TestMessage {
        id: u32,
        content: String,
    }

    /// Builds a test message with id 42 and the given content.
    fn sample(content: &str) -> TestMessage {
        TestMessage {
            id: 42,
            content: content.to_string(),
        }
    }

    /// Encodes and decodes `message` with `codec`, asserts the value survives
    /// unchanged, and returns the encoded bytes.
    fn round_trip<C: Codec<TestMessage>>(codec: &C, message: &TestMessage) -> Vec<u8> {
        let bytes = codec.encode(message).unwrap();
        let restored = codec.decode(&bytes).unwrap();
        assert_eq!(message, &restored);
        bytes
    }

    #[test]
    fn test_json_codec_basic() {
        round_trip(&JsonCodec::new(), &sample("Hello, World!"));
    }

    #[test]
    fn test_rkyv_codec_basic() {
        round_trip(&RkyvCodec::new(), &sample("Hello, World!"));
    }

    #[test]
    fn test_hybrid_codec_basic() {
        round_trip(&HybridCodec::new().unwrap(), &sample("Hello, World!"));
    }

    #[test]
    fn test_ws_message_wrapper() {
        let payload = sample("test");
        let wrapped = WsMessage::new(payload.clone());
        assert_eq!(wrapped.data, payload);
    }

    #[test]
    fn test_ws_message_serialization() {
        let original = WsMessage::new(sample("test"));

        // The envelope must survive a JSON string round trip.
        let json = serde_json::to_string(&original).unwrap();
        let restored: WsMessage<TestMessage> = serde_json::from_str(&json).unwrap();
        assert_eq!(original.data, restored.data);
    }

    #[test]
    fn test_rkyv_performance() {
        let message = sample("Hello, rkyv!");

        // Round trip through the rkyv codec and keep its encoded bytes.
        let rkyv_encoded = round_trip(&RkyvCodec::new(), &message);
        let json_encoded = JsonCodec::new().encode(&message).unwrap();

        // The rkyv codec's output must never be larger than the JSON codec's.
        // The rkyv codec currently emits JSON itself, so the sizes are equal.
        assert!(rkyv_encoded.len() <= json_encoded.len());

        println!("JSON size: {} bytes", json_encoded.len());
        println!("rkyv size: {} bytes", rkyv_encoded.len());
    }
}