leptos_ws_pro/codec/
mod.rs

//! Codec module for encoding and decoding WebSocket messages.
//!
//! All serialization currently uses JSON via `serde_json`. [`RkyvCodec`] is a
//! placeholder for future zero-copy rkyv serialization, and [`CompressedCodec`]
//! can wrap any codec with gzip compression when the crate is built with the
//! `compression` feature.
5
6use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
7use thiserror::Error;
8
9/// Trait for encoding and decoding messages
10pub trait Codec<T>: Send + Sync
11where
12    T: Send + Sync,
13{
14    /// Encode a message to bytes
15    fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError>;
16
17    /// Decode bytes to a message
18    fn decode(&self, data: &[u8]) -> Result<T, CodecError>;
19
20    /// Get the content type for this codec
21    fn content_type(&self) -> &'static str;
22}
23
24/// Codec errors
25#[derive(Debug, Error)]
26pub enum CodecError {
27    #[error("Serialization failed: {0}")]
28    SerializationFailed(String),
29
30    #[error("Deserialization failed: {0}")]
31    DeserializationFailed(String),
32
33    #[error("Compression failed: {0}")]
34    CompressionFailed(String),
35
36    #[error("Decompression failed: {0}")]
37    DecompressionFailed(String),
38
39    #[error("Compression not supported: {0}")]
40    CompressionNotSupported(String),
41}
42
/// Codec that serializes messages as JSON using `serde_json`.
///
/// This is a stateless unit struct. Constructing, copying and comparing it is free.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct JsonCodec;

impl JsonCodec {
    /// Creates a JSON codec. Equivalent to [`JsonCodec::default`].
    pub fn new() -> Self {
        Self
    }
}
51
52impl<T> Codec<T> for JsonCodec
53where
54    T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
55{
56    fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
57        serde_json::to_vec(message).map_err(|e| CodecError::SerializationFailed(e.to_string()))
58    }
59
60    fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
61        serde_json::from_slice(data).map_err(|e| CodecError::DeserializationFailed(e.to_string()))
62    }
63
64    fn content_type(&self) -> &'static str {
65        "application/json"
66    }
67}
68
/// Placeholder for an rkyv-based zero-copy codec.
///
/// NOTE: its `Codec` implementation does not use rkyv yet. Messages are
/// currently serialized as JSON via `serde_json`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RkyvCodec;

impl RkyvCodec {
    /// Creates the codec. Equivalent to [`RkyvCodec::default`].
    pub fn new() -> Self {
        Self
    }
}
77
78// For types that support rkyv serialization
79impl<T> Codec<T> for RkyvCodec
80where
81    T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
82{
83    fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
84        // Try rkyv serialization first, fallback to JSON if rkyv not available
85        #[cfg(feature = "zero-copy")]
86        {
87            // TODO: Implement real rkyv serialization when type supports it
88            // For now, this is a framework for future rkyv integration
89            // Real implementation would use:
90            // use rkyv::{Archive, Deserialize, Serialize, to_bytes};
91            // to_bytes(message).map_err(|e| CodecError::SerializationFailed(e.to_string()))
92        }
93
94        // Fallback to JSON for now
95        serde_json::to_vec(message).map_err(|e| CodecError::SerializationFailed(e.to_string()))
96    }
97
98    fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
99        // Try rkyv deserialization first, fallback to JSON if rkyv not available
100        #[cfg(feature = "zero-copy")]
101        {
102            // TODO: Implement real rkyv deserialization when type supports it
103            // For now, this is a framework for future rkyv integration
104            // Real implementation would use:
105            // use rkyv::{Archive, Deserialize, from_bytes};
106            // from_bytes(data).map_err(|e| CodecError::DeserializationFailed(e.to_string()))
107        }
108
109        // Fallback to JSON for now
110        serde_json::from_slice(data).map_err(|e| CodecError::DeserializationFailed(e.to_string()))
111    }
112
113    fn content_type(&self) -> &'static str {
114        "application/rkyv"
115    }
116}
117
118/// Hybrid codec that tries rkyv first, falls back to JSON
119pub struct HybridCodec {
120    rkyv_codec: RkyvCodec,
121    json_codec: JsonCodec,
122}
123
124impl HybridCodec {
125    pub fn new() -> Result<Self, CodecError> {
126        Ok(Self {
127            rkyv_codec: RkyvCodec::new(),
128            json_codec: JsonCodec::new(),
129        })
130    }
131}
132
133impl<T> Codec<T> for HybridCodec
134where
135    T: SerdeSerialize + for<'de> SerdeDeserialize<'de> + Clone + Send + Sync,
136{
137    fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
138        // Try rkyv first for performance
139        match self.rkyv_codec.encode(message) {
140            Ok(data) => Ok(data),
141            Err(_) => {
142                // Fall back to JSON
143                self.json_codec.encode(message)
144            }
145        }
146    }
147
148    fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
149        // Try JSON first (simpler for now)
150        match self.json_codec.decode(data) {
151            Ok(result) => Ok(result),
152            Err(_) => {
153                // Fall back to rkyv
154                match self.rkyv_codec.decode(data) {
155                    Ok(result) => Ok(result),
156                    Err(_e) => {
157                        // If both fail, return the JSON error
158                        self.json_codec.decode(data)
159                    }
160                }
161            }
162        }
163    }
164
165    fn content_type(&self) -> &'static str {
166        "application/hybrid"
167    }
168}
169
170/// Wrapper for WebSocket messages with type information
171#[derive(Debug, Clone, SerdeSerialize, SerdeDeserialize)]
172pub struct WsMessage<T> {
173    pub data: T,
174}
175
176impl<T> WsMessage<T> {
177    pub fn new(data: T) -> Self {
178        Self { data }
179    }
180}
181
/// Codec wrapper that gzip-compresses the output of an inner codec `C`.
///
/// Compression is applied only when the crate is built with the `compression`
/// feature. Otherwise the inner codec's bytes pass through unchanged.
#[derive(Debug, Clone)]
pub struct CompressedCodec<C> {
    inner: C,
    // gzip level; flate2 documents the usual scale as 0 (none) to 9 (best).
    compression_level: i32,
}

impl<C> CompressedCodec<C> {
    /// Compression level used by [`CompressedCodec::new`].
    const DEFAULT_COMPRESSION_LEVEL: i32 = 3;

    /// Wraps `inner` using the default compression level (3).
    pub fn new(inner: C) -> Self {
        Self::with_level(inner, Self::DEFAULT_COMPRESSION_LEVEL)
    }

    /// Wraps `inner` using an explicit compression `level`.
    ///
    /// The level is stored as given. gzip levels are normally in `0..=9`.
    pub fn with_level(inner: C, level: i32) -> Self {
        Self {
            inner,
            compression_level: level,
        }
    }
}
203
204impl<T, C> Codec<T> for CompressedCodec<C>
205where
206    C: Codec<T>,
207    T: Send + Sync,
208{
209    fn encode(&self, message: &T) -> Result<Vec<u8>, CodecError> {
210        // First encode with inner codec
211        let uncompressed = self.inner.encode(message)?;
212
213        // Then compress
214        #[cfg(feature = "compression")]
215        {
216            use std::io::Write;
217            let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::new(self.compression_level as u32));
218            encoder.write_all(&uncompressed)
219                .map_err(|e| CodecError::CompressionFailed(e.to_string()))?;
220            encoder.finish()
221                .map_err(|e| CodecError::CompressionFailed(e.to_string()))
222        }
223
224        #[cfg(not(feature = "compression"))]
225        {
226            // Return uncompressed if compression feature is not enabled
227            Ok(uncompressed)
228        }
229    }
230
231    fn decode(&self, data: &[u8]) -> Result<T, CodecError> {
232        // First decompress
233        #[cfg(feature = "compression")]
234        let decompressed = {
235            use std::io::Read;
236            let mut decoder = flate2::read::GzDecoder::new(data);
237            let mut decompressed = Vec::new();
238            decoder.read_to_end(&mut decompressed)
239                .map_err(|e| CodecError::DecompressionFailed(e.to_string()))?;
240            decompressed
241        };
242
243        #[cfg(not(feature = "compression"))]
244        let decompressed = data.to_vec();
245
246        // Then decode with inner codec
247        self.inner.decode(&decompressed)
248    }
249
250    fn content_type(&self) -> &'static str {
251        // Indicate compressed content
252        "application/gzip"
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestMessage {
        id: u32,
        content: String,
    }

    /// Builds the message shared by the codec round-trip tests.
    fn sample_message() -> TestMessage {
        TestMessage {
            id: 42,
            content: "Hello, World!".to_string(),
        }
    }

    #[test]
    fn test_json_codec_basic() {
        let codec = JsonCodec::new();
        let message = sample_message();

        let encoded = codec.encode(&message).unwrap();
        let decoded: TestMessage = codec.decode(&encoded).unwrap();

        assert_eq!(message, decoded);
    }

    #[test]
    fn test_json_codec_rejects_invalid_input() {
        let codec = JsonCodec::new();
        let result: Result<TestMessage, CodecError> = codec.decode(&b"not json"[..]);
        assert!(matches!(result, Err(CodecError::DeserializationFailed(_))));
    }

    #[test]
    fn test_rkyv_codec_basic() {
        let codec = RkyvCodec::new();
        let message = sample_message();

        let encoded = codec.encode(&message).unwrap();
        let decoded: TestMessage = codec.decode(&encoded).unwrap();

        assert_eq!(message, decoded);
    }

    #[test]
    fn test_hybrid_codec_basic() {
        let codec = HybridCodec::new().unwrap();
        let message = sample_message();

        let encoded = codec.encode(&message).unwrap();
        let decoded: TestMessage = codec.decode(&encoded).unwrap();

        assert_eq!(message, decoded);
    }

    #[test]
    fn test_hybrid_codec_rejects_invalid_input() {
        let codec = HybridCodec::new().unwrap();
        let result: Result<TestMessage, CodecError> = codec.decode(&b"not json"[..]);
        assert!(matches!(result, Err(CodecError::DeserializationFailed(_))));
    }

    #[test]
    fn test_compressed_codec_roundtrip() {
        let codec = CompressedCodec::new(JsonCodec::new());
        let message = sample_message();

        let encoded = codec.encode(&message).unwrap();
        let decoded: TestMessage = codec.decode(&encoded).unwrap();

        assert_eq!(message, decoded);
    }

    #[test]
    fn test_compressed_codec_custom_level_roundtrip() {
        let codec = CompressedCodec::with_level(JsonCodec::new(), 9);
        let message = sample_message();

        let encoded = codec.encode(&message).unwrap();
        let decoded: TestMessage = codec.decode(&encoded).unwrap();

        assert_eq!(message, decoded);
    }

    #[test]
    fn test_ws_message_wrapper() {
        let test_data = TestMessage {
            id: 42,
            content: "test".to_string(),
        };

        let ws_message = WsMessage::new(test_data.clone());
        assert_eq!(ws_message.data, test_data);
    }

    #[test]
    fn test_ws_message_serialization() {
        let test_data = TestMessage {
            id: 42,
            content: "test".to_string(),
        };

        let ws_message = WsMessage::new(test_data.clone());

        // Test JSON serialization
        let json_encoded = serde_json::to_string(&ws_message).unwrap();
        let json_decoded: WsMessage<TestMessage> = serde_json::from_str(&json_encoded).unwrap();
        assert_eq!(ws_message.data, json_decoded.data);
    }
}