Skip to main content

naia_shared/connection/
encoder.rs

1cfg_if! {
2    if #[cfg(feature = "zstd_support")]
3    {
4        use std::fs;
5
6        use log::info;
7
8        use zstd::{bulk::Compressor, dict::from_continuous};
9
10        use super::compression_config::CompressionMode;
11
12        #[derive(Clone)]
13        pub struct Encoder {
14            result: Vec<u8>,
15            encoder: EncoderType,
16        }
17
18        impl Encoder {
19            pub fn new(compression_mode: CompressionMode) -> Self {
20                let encoder = match compression_mode {
21                    CompressionMode::Training(sample_size) => {
22                        EncoderType::DictionaryTrainer(DictionaryTrainer::new(sample_size))
23                    }
24                    CompressionMode::Default(compression_level) => EncoderType::Compressor(
25                        Compressor::new(compression_level).expect("error creating Compressor"),
26                    ),
27                    CompressionMode::Dictionary(compression_level, dictionary) => EncoderType::Compressor(
28                        Compressor::with_dictionary(compression_level, &dictionary)
29                            .expect("error creating Compressor with dictionary"),
30                    ),
31                };
32
33                Self {
34                    result: Vec::new(),
35                    encoder,
36                }
37            }
38
39            pub fn encode(&mut self, payload: &[u8]) -> &[u8] {
40                match &mut self.encoder {
41                    EncoderType::DictionaryTrainer(trainer) => {
42                        trainer.record_bytes(payload);
43                        // Training mode: emit uncompressed with is_compressed=0 prefix
44                        self.result = Vec::with_capacity(1 + payload.len());
45                        self.result.push(0u8);
46                        self.result.extend_from_slice(payload);
47                        return &self.result;
48                    }
49                    EncoderType::Compressor(encoder) => {
50                        let compressed = encoder.compress(payload).expect("encode error");
51                        if compressed.len() < payload.len() {
52                            // Compression is beneficial: is_compressed=1 prefix + compressed bytes
53                            self.result = Vec::with_capacity(1 + compressed.len());
54                            self.result.push(1u8);
55                            self.result.extend_from_slice(&compressed);
56                        } else {
57                            // Compression is not beneficial: is_compressed=0 prefix + original bytes
58                            self.result = Vec::with_capacity(1 + payload.len());
59                            self.result.push(0u8);
60                            self.result.extend_from_slice(payload);
61                        }
62                        return &self.result;
63                    }
64                }
65            }
66        }
67
68        #[derive(Clone)]
69        pub enum EncoderType {
70            Compressor(Compressor<'static>),
71            DictionaryTrainer(DictionaryTrainer),
72        }
73
74        #[derive(Clone)]
75        pub struct DictionaryTrainer {
76            sample_data: Vec<u8>,
77            sample_sizes: Vec<usize>,
78            next_alert_size: usize,
79            target_sample_size: usize,
80            training_complete: bool,
81        }
82
83        impl DictionaryTrainer {
84            /// `target_sample_size` here describes the number of samples (packets) to
85            /// train on. Obviously, the more samples trained on, the better
86            /// theoretical compression.
87            pub fn new(target_sample_size: usize) -> Self {
88                Self {
89                    target_sample_size,
90                    sample_data: Vec::new(),
91                    sample_sizes: Vec::new(),
92                    next_alert_size: 0,
93                    training_complete: false,
94                }
95            }
96
97            pub fn record_bytes(&mut self, bytes: &[u8]) {
98                if self.training_complete {
99                    return;
100                }
101
102                self.sample_data.extend_from_slice(bytes);
103                self.sample_sizes.push(bytes.len());
104
105                let current_sample_size = self.sample_sizes.len();
106
107                if current_sample_size >= self.next_alert_size {
108                    let percent =
109                        ((self.next_alert_size as f32) / (self.target_sample_size as f32)) * 100.0;
110                    info!("Dictionary training: {}% complete", percent);
111
112                    self.next_alert_size += self.target_sample_size / 20;
113                }
114
115                if current_sample_size >= self.target_sample_size {
116                    info!("Dictionary training complete!");
117                    info!(
118                        "Samples: {} ({} KB)",
119                        self.sample_sizes.len(),
120                        self.sample_data.len()
121                    );
122                    info!("Dictionary processing sample data...");
123
124                    // We have enough sample data to train the dictionary!
125                    let target_dict_size = self.sample_data.len() / 100;
126                    let dictionary =
127                        from_continuous(&self.sample_data, &self.sample_sizes, target_dict_size)
128                            .expect("Error while training dictionary");
129
130                    // Now need to ... write it to a file I guess
131                    fs::write("dictionary.txt", dictionary)
132                        .expect("Error while writing dictionary to file");
133
134                    info!("Dictionary written to `dictionary.txt`!");
135
136                    self.training_complete = true;
137                }
138            }
139        }
140    }
141    else
142    {
143        use super::compression_config::CompressionMode;
144
145        /// Packet encoder (no-op variant: passes payload through unchanged).
146        #[derive(Clone)]
147        pub struct Encoder {
148            result: Vec<u8>
149        }
150
151        impl Encoder {
152            /// Creates a no-op encoder (compression mode is ignored in this build variant).
153            pub fn new(_: CompressionMode) -> Self {
154                Self {
155                    result: Vec::new(),
156                }
157            }
158
159            /// Returns the payload unchanged (no-op compression).
160            pub fn encode(&mut self, payload: &[u8]) -> &[u8] {
161                self.result = payload.to_vec();
162                &self.result
163            }
164        }
165    }
166}