Skip to main content

oxirs_stream/
custom_serialization.rs

1//! # Custom Serialization Formats
2//!
3//! Extensible serialization framework allowing users to register custom serialization
4//! formats and providing additional modern serialization options beyond the standard formats.
5//!
6//! ## Features
7//!
8//! - **Custom Serializer Trait**: Define your own serialization formats
9//! - **Format Registry**: Register and discover custom serializers
//! - **Additional Formats**: BSON, Thrift, FlexBuffers, RON, Ion
11//! - **Zero-Copy Serialization**: Support for zero-copy deserialization
12//! - **Benchmarking**: Built-in performance benchmarking
13//! - **Schema Validation**: Optional schema validation for custom formats
14//!
15//! ## Example
16//!
//! ```rust,ignore
//! use anyhow::Result;
//! use oxirs_stream::custom_serialization::{CustomSerializer, SerializerRegistry};
//!
//! // Define a custom serializer
//! struct MyCustomSerializer;
//!
//! impl CustomSerializer for MyCustomSerializer {
//!     fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
//!         // Custom serialization logic
//!         Ok(data.to_vec())
//!     }
//!
//!     fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
//!         // Custom deserialization logic
//!         Ok(data.to_vec())
//!     }
//!
//!     // `format_name` has no default implementation and must be provided
//!     fn format_name(&self) -> &str {
//!         "my-format"
//!     }
//! }
//!
//! // Register the serializer (`register` is async and takes `&self`)
//! let registry = SerializerRegistry::new();
//! registry.register("my-format", Box::new(MyCustomSerializer)).await?;
//! ```
39
40use anyhow::{anyhow, Result};
41use chrono::{DateTime, Utc};
42use serde::{Deserialize, Serialize};
43use std::collections::HashMap;
44use std::sync::Arc;
45use std::time::Instant;
46use tokio::sync::RwLock;
47use tracing::{debug, info};
48
49/// Custom serializer trait
50pub trait CustomSerializer: Send + Sync {
51    /// Serialize data to bytes
52    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>>;
53
54    /// Deserialize bytes to data
55    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>>;
56
57    /// Get format name
58    fn format_name(&self) -> &str;
59
60    /// Get format version
61    fn format_version(&self) -> &str {
62        "1.0.0"
63    }
64
65    /// Get magic bytes for format detection
66    fn magic_bytes(&self) -> Option<&[u8]> {
67        None
68    }
69
70    /// Supports zero-copy deserialization
71    fn supports_zero_copy(&self) -> bool {
72        false
73    }
74
75    /// Validate schema (optional)
76    fn validate_schema(&self, _schema: &[u8], _data: &[u8]) -> Result<bool> {
77        Ok(true)
78    }
79
80    /// Get serialization statistics
81    fn stats(&self) -> SerializerStats {
82        SerializerStats::default()
83    }
84}
85
/// Usage statistics reported by a serializer through `CustomSerializer::stats`.
///
/// The trait's default `stats` implementation returns `SerializerStats::default()`,
/// i.e. every field zero; serializers that track usage are expected to fill these in.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SerializerStats {
    /// Total number of bytes serialized
    pub bytes_serialized: u64,

    /// Total number of bytes deserialized
    pub bytes_deserialized: u64,

    /// Number of serialization operations performed
    pub serialization_count: u64,

    /// Number of deserialization operations performed
    pub deserialization_count: u64,

    /// Average time of one serialization operation, in milliseconds
    pub avg_serialization_time_ms: f64,

    /// Average time of one deserialization operation, in milliseconds
    pub avg_deserialization_time_ms: f64,

    /// Number of errors encountered
    pub error_count: u64,
}
110
111/// Serializer registry
112pub struct SerializerRegistry {
113    serializers: Arc<RwLock<HashMap<String, Box<dyn CustomSerializer>>>>,
114    benchmarks: Arc<RwLock<HashMap<String, SerializerBenchmark>>>,
115}
116
117impl SerializerRegistry {
118    /// Create a new serializer registry
119    pub fn new() -> Self {
120        // Register built-in serializers
121        Self {
122            serializers: Arc::new(RwLock::new(HashMap::new())),
123            benchmarks: Arc::new(RwLock::new(HashMap::new())),
124        }
125    }
126
127    /// Register a custom serializer
128    pub async fn register(&self, name: &str, serializer: Box<dyn CustomSerializer>) -> Result<()> {
129        let mut serializers = self.serializers.write().await;
130
131        if serializers.contains_key(name) {
132            return Err(anyhow!("Serializer '{}' already registered", name));
133        }
134
135        serializers.insert(name.to_string(), serializer);
136        info!("Registered custom serializer: {}", name);
137        Ok(())
138    }
139
140    /// Unregister a serializer
141    pub async fn unregister(&self, name: &str) -> Result<()> {
142        let mut serializers = self.serializers.write().await;
143
144        if serializers.remove(name).is_some() {
145            info!("Unregistered serializer: {}", name);
146            Ok(())
147        } else {
148            Err(anyhow!("Serializer '{}' not found", name))
149        }
150    }
151
152    /// Get a serializer by name
153    pub async fn get(&self, name: &str) -> Result<String> {
154        let serializers = self.serializers.read().await;
155
156        if serializers.contains_key(name) {
157            Ok(name.to_string())
158        } else {
159            Err(anyhow!("Serializer '{}' not found", name))
160        }
161    }
162
163    /// List all registered serializers
164    pub async fn list(&self) -> Vec<String> {
165        let serializers = self.serializers.read().await;
166        serializers.keys().cloned().collect()
167    }
168
169    /// Serialize using a specific format
170    pub async fn serialize(&self, format: &str, data: &[u8]) -> Result<Vec<u8>> {
171        let serializers = self.serializers.read().await;
172
173        let serializer = serializers
174            .get(format)
175            .ok_or_else(|| anyhow!("Serializer '{}' not found", format))?;
176
177        let start = Instant::now();
178        let result = serializer.serialize(data)?;
179        let duration = start.elapsed();
180
181        // Update benchmarks
182        drop(serializers);
183        self.update_benchmark(format, duration.as_secs_f64() * 1000.0, true)
184            .await;
185
186        Ok(result)
187    }
188
189    /// Deserialize using a specific format
190    pub async fn deserialize(&self, format: &str, data: &[u8]) -> Result<Vec<u8>> {
191        let serializers = self.serializers.read().await;
192
193        let serializer = serializers
194            .get(format)
195            .ok_or_else(|| anyhow!("Serializer '{}' not found", format))?;
196
197        let start = Instant::now();
198        let result = serializer.deserialize(data)?;
199        let duration = start.elapsed();
200
201        // Update benchmarks
202        drop(serializers);
203        self.update_benchmark(format, duration.as_secs_f64() * 1000.0, false)
204            .await;
205
206        Ok(result)
207    }
208
209    /// Auto-detect format from magic bytes
210    pub async fn detect_format(&self, data: &[u8]) -> Option<String> {
211        let serializers = self.serializers.read().await;
212
213        for (name, serializer) in serializers.iter() {
214            if let Some(magic) = serializer.magic_bytes() {
215                if data.len() >= magic.len() && &data[0..magic.len()] == magic {
216                    return Some(name.clone());
217                }
218            }
219        }
220
221        None
222    }
223
224    /// Get benchmarks for a specific serializer
225    pub async fn get_benchmark(&self, format: &str) -> Option<SerializerBenchmark> {
226        let benchmarks = self.benchmarks.read().await;
227        benchmarks.get(format).cloned()
228    }
229
230    /// Get all benchmarks
231    pub async fn all_benchmarks(&self) -> HashMap<String, SerializerBenchmark> {
232        let benchmarks = self.benchmarks.read().await;
233        benchmarks.clone()
234    }
235
236    /// Update benchmark statistics
237    async fn update_benchmark(&self, format: &str, duration_ms: f64, is_serialization: bool) {
238        let mut benchmarks = self.benchmarks.write().await;
239
240        let benchmark = benchmarks
241            .entry(format.to_string())
242            .or_insert_with(SerializerBenchmark::default);
243
244        if is_serialization {
245            benchmark.serialization_times.push(duration_ms);
246            benchmark.serialization_count += 1;
247        } else {
248            benchmark.deserialization_times.push(duration_ms);
249            benchmark.deserialization_count += 1;
250        }
251    }
252}
253
254impl Default for SerializerRegistry {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
260/// Serializer benchmark results
261#[derive(Debug, Clone, Default, Serialize, Deserialize)]
262pub struct SerializerBenchmark {
263    /// Serialization operation count
264    pub serialization_count: u64,
265
266    /// Deserialization operation count
267    pub deserialization_count: u64,
268
269    /// Serialization times (ms)
270    pub serialization_times: Vec<f64>,
271
272    /// Deserialization times (ms)
273    pub deserialization_times: Vec<f64>,
274
275    /// Last updated timestamp
276    pub last_updated: Option<DateTime<Utc>>,
277}
278
279impl SerializerBenchmark {
280    /// Get average serialization time
281    pub fn avg_serialization_time(&self) -> f64 {
282        if self.serialization_times.is_empty() {
283            0.0
284        } else {
285            self.serialization_times.iter().sum::<f64>() / self.serialization_times.len() as f64
286        }
287    }
288
289    /// Get average deserialization time
290    pub fn avg_deserialization_time(&self) -> f64 {
291        if self.deserialization_times.is_empty() {
292            0.0
293        } else {
294            self.deserialization_times.iter().sum::<f64>() / self.deserialization_times.len() as f64
295        }
296    }
297
298    /// Get P95 serialization time
299    pub fn p95_serialization_time(&self) -> f64 {
300        self.percentile(&self.serialization_times, 0.95)
301    }
302
303    /// Get P95 deserialization time
304    pub fn p95_deserialization_time(&self) -> f64 {
305        self.percentile(&self.deserialization_times, 0.95)
306    }
307
308    /// Calculate percentile
309    fn percentile(&self, times: &[f64], p: f64) -> f64 {
310        if times.is_empty() {
311            return 0.0;
312        }
313
314        let mut sorted = times.to_vec();
315        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
316
317        let index = ((sorted.len() as f64 - 1.0) * p) as usize;
318        sorted[index]
319    }
320}
321
322/// BSON serializer (Binary JSON)
323pub struct BsonSerializer;
324
325impl CustomSerializer for BsonSerializer {
326    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
327        // Simulated BSON serialization
328        let mut result = Vec::new();
329        result.extend_from_slice(b"BSON");
330        result.extend_from_slice(data);
331        Ok(result)
332    }
333
334    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
335        if data.len() < 4 {
336            return Err(anyhow!("Invalid BSON data"));
337        }
338        Ok(data[4..].to_vec())
339    }
340
341    fn format_name(&self) -> &str {
342        "bson"
343    }
344
345    fn magic_bytes(&self) -> Option<&[u8]> {
346        Some(b"BSON")
347    }
348}
349
350/// Thrift serializer
351pub struct ThriftSerializer;
352
353impl CustomSerializer for ThriftSerializer {
354    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
355        // Simulated Thrift serialization
356        let mut result = Vec::new();
357        result.extend_from_slice(b"THFT");
358        result.extend_from_slice(data);
359        Ok(result)
360    }
361
362    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
363        if data.len() < 4 {
364            return Err(anyhow!("Invalid Thrift data"));
365        }
366        Ok(data[4..].to_vec())
367    }
368
369    fn format_name(&self) -> &str {
370        "thrift"
371    }
372
373    fn magic_bytes(&self) -> Option<&[u8]> {
374        Some(b"THFT")
375    }
376}
377
378/// FlexBuffers serializer (zero-copy)
379pub struct FlexBuffersSerializer;
380
381impl CustomSerializer for FlexBuffersSerializer {
382    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
383        // Simulated FlexBuffers serialization
384        let mut result = Vec::new();
385        result.extend_from_slice(b"FLEX");
386        result.extend_from_slice(data);
387        Ok(result)
388    }
389
390    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
391        if data.len() < 4 {
392            return Err(anyhow!("Invalid FlexBuffers data"));
393        }
394        Ok(data[4..].to_vec())
395    }
396
397    fn format_name(&self) -> &str {
398        "flexbuffers"
399    }
400
401    fn magic_bytes(&self) -> Option<&[u8]> {
402        Some(b"FLEX")
403    }
404
405    fn supports_zero_copy(&self) -> bool {
406        true
407    }
408}
409
410/// RON (Rusty Object Notation) serializer
411pub struct RonSerializer;
412
413impl CustomSerializer for RonSerializer {
414    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
415        // Simulated RON serialization
416        let mut result = Vec::new();
417        result.extend_from_slice(b"RON\0");
418        result.extend_from_slice(data);
419        Ok(result)
420    }
421
422    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
423        if data.len() < 4 {
424            return Err(anyhow!("Invalid RON data"));
425        }
426        Ok(data[4..].to_vec())
427    }
428
429    fn format_name(&self) -> &str {
430        "ron"
431    }
432
433    fn magic_bytes(&self) -> Option<&[u8]> {
434        Some(b"RON\0")
435    }
436}
437
438/// Ion serializer (Amazon Ion)
439pub struct IonSerializer;
440
441impl CustomSerializer for IonSerializer {
442    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
443        // Simulated Ion serialization
444        let mut result = Vec::new();
445        result.extend_from_slice(b"ION\x01");
446        result.extend_from_slice(data);
447        Ok(result)
448    }
449
450    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
451        if data.len() < 4 {
452            return Err(anyhow!("Invalid Ion data"));
453        }
454        Ok(data[4..].to_vec())
455    }
456
457    fn format_name(&self) -> &str {
458        "ion"
459    }
460
461    fn magic_bytes(&self) -> Option<&[u8]> {
462        Some(b"ION\x01")
463    }
464}
465
466/// Serializer benchmark suite
467pub struct SerializerBenchmarkSuite {
468    registry: Arc<SerializerRegistry>,
469    test_data: Vec<Vec<u8>>,
470}
471
472impl SerializerBenchmarkSuite {
473    /// Create a new benchmark suite
474    pub fn new(registry: Arc<SerializerRegistry>) -> Self {
475        Self {
476            registry,
477            test_data: Self::generate_test_data(),
478        }
479    }
480
481    /// Generate test data of various sizes
482    fn generate_test_data() -> Vec<Vec<u8>> {
483        use scirs2_core::random::rng;
484        use scirs2_core::Rng;
485
486        let mut rand_gen = rng();
487        let sizes = [100, 1024, 10_240, 102_400]; // 100B, 1KB, 10KB, 100KB
488
489        sizes
490            .iter()
491            .map(|&size| (0..size).map(|_| rand_gen.random_range(0..=255)).collect())
492            .collect()
493    }
494
495    /// Run benchmark for a specific serializer
496    pub async fn benchmark(&self, format: &str, iterations: usize) -> Result<BenchmarkResults> {
497        let mut results = BenchmarkResults {
498            format: format.to_string(),
499            iterations,
500            serialization_times: Vec::new(),
501            deserialization_times: Vec::new(),
502            sizes: Vec::new(),
503        };
504
505        for test_data in &self.test_data {
506            let mut ser_times = Vec::new();
507            let mut deser_times = Vec::new();
508
509            for _ in 0..iterations {
510                // Benchmark serialization
511                let start = Instant::now();
512                let serialized = self.registry.serialize(format, test_data).await?;
513                ser_times.push(start.elapsed().as_secs_f64() * 1000.0);
514
515                // Benchmark deserialization
516                let start = Instant::now();
517                self.registry.deserialize(format, &serialized).await?;
518                deser_times.push(start.elapsed().as_secs_f64() * 1000.0);
519            }
520
521            let avg_ser = ser_times.iter().sum::<f64>() / ser_times.len() as f64;
522            let avg_deser = deser_times.iter().sum::<f64>() / deser_times.len() as f64;
523
524            results.serialization_times.push(avg_ser);
525            results.deserialization_times.push(avg_deser);
526            results.sizes.push(test_data.len());
527        }
528
529        debug!("Benchmark completed for {}: {:?}", format, results);
530        Ok(results)
531    }
532
533    /// Compare multiple serializers
534    pub async fn compare(
535        &self,
536        formats: &[String],
537        iterations: usize,
538    ) -> Result<Vec<BenchmarkResults>> {
539        let mut all_results = Vec::new();
540
541        for format in formats {
542            let results = self.benchmark(format, iterations).await?;
543            all_results.push(results);
544        }
545
546        Ok(all_results)
547    }
548}
549
550/// Benchmark results
551#[derive(Debug, Clone, Serialize, Deserialize)]
552pub struct BenchmarkResults {
553    /// Format name
554    pub format: String,
555
556    /// Number of iterations
557    pub iterations: usize,
558
559    /// Serialization times (ms) for each data size
560    pub serialization_times: Vec<f64>,
561
562    /// Deserialization times (ms) for each data size
563    pub deserialization_times: Vec<f64>,
564
565    /// Data sizes tested
566    pub sizes: Vec<usize>,
567}
568
569impl BenchmarkResults {
570    /// Get total average serialization time
571    pub fn avg_serialization_time(&self) -> f64 {
572        if self.serialization_times.is_empty() {
573            0.0
574        } else {
575            self.serialization_times.iter().sum::<f64>() / self.serialization_times.len() as f64
576        }
577    }
578
579    /// Get total average deserialization time
580    pub fn avg_deserialization_time(&self) -> f64 {
581        if self.deserialization_times.is_empty() {
582            0.0
583        } else {
584            self.deserialization_times.iter().sum::<f64>() / self.deserialization_times.len() as f64
585        }
586    }
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a registry containing only the BSON serializer, registered as `"bson"`.
    async fn bson_registry() -> SerializerRegistry {
        let registry = SerializerRegistry::new();
        registry
            .register("bson", Box::new(BsonSerializer))
            .await
            .unwrap();
        registry
    }

    // A registered serializer shows up in the listing.
    #[tokio::test]
    async fn test_register_custom_serializer() {
        let registry = bson_registry().await;

        let names = registry.list().await;
        assert!(names.contains(&"bson".to_string()));
    }

    // A payload survives a serialize/deserialize round trip.
    #[tokio::test]
    async fn test_serialize_deserialize() {
        let registry = bson_registry().await;

        let data = b"test data";
        let encoded = registry.serialize("bson", data).await.unwrap();
        let decoded = registry.deserialize("bson", &encoded).await.unwrap();

        assert_eq!(decoded, data);
    }

    // Magic bytes select the matching serializer among several.
    #[tokio::test]
    async fn test_format_detection() {
        let registry = bson_registry().await;
        registry
            .register("thrift", Box::new(ThriftSerializer))
            .await
            .unwrap();

        let detected = registry.detect_format(b"BSONtest data").await;

        assert_eq!(detected, Some("bson".to_string()));
    }

    // The benchmark suite reports results for the requested format.
    #[tokio::test]
    async fn test_benchmark() {
        let registry = Arc::new(bson_registry().await);

        let suite = SerializerBenchmarkSuite::new(Arc::clone(&registry));
        let results = suite.benchmark("bson", 10).await.unwrap();

        assert_eq!(results.format, "bson");
        assert_eq!(results.iterations, 10);
        assert!(!results.serialization_times.is_empty());
    }

    // All five bundled serializers can be registered side by side.
    #[tokio::test]
    async fn test_multiple_formats() {
        let registry = SerializerRegistry::new();

        let bundled: Vec<(&str, Box<dyn CustomSerializer>)> = vec![
            ("bson", Box::new(BsonSerializer)),
            ("thrift", Box::new(ThriftSerializer)),
            ("flexbuffers", Box::new(FlexBuffersSerializer)),
            ("ron", Box::new(RonSerializer)),
            ("ion", Box::new(IonSerializer)),
        ];
        for (name, serializer) in bundled {
            registry.register(name, serializer).await.unwrap();
        }

        let names = registry.list().await;
        assert_eq!(names.len(), 5);
    }

    // Unregistering removes the serializer from the listing.
    #[tokio::test]
    async fn test_unregister() {
        let registry = bson_registry().await;
        let bson = "bson".to_string();

        assert!(registry.list().await.contains(&bson));

        registry.unregister("bson").await.unwrap();

        assert!(!registry.list().await.contains(&bson));
    }
}