// naia_shared/connection/encoder.rs

1cfg_if! {
2    if #[cfg(feature = "zstd_support")]
3    {
4        use std::fs;
5
6        use log::info;
7
8        use zstd::{bulk::Compressor, dict::from_continuous};
9
10        use super::compression_config::CompressionMode;
11
12        pub struct Encoder {
13            result: Vec<u8>,
14            encoder: EncoderType,
15        }
16
17        impl Encoder {
18            pub fn new(compression_mode: CompressionMode) -> Self {
19                let encoder = match compression_mode {
20                    CompressionMode::Training(sample_size) => {
21                        EncoderType::DictionaryTrainer(DictionaryTrainer::new(sample_size))
22                    }
23                    CompressionMode::Default(compression_level) => EncoderType::Compressor(
24                        Compressor::new(compression_level).expect("error creating Compressor"),
25                    ),
26                    CompressionMode::Dictionary(compression_level, dictionary) => EncoderType::Compressor(
27                        Compressor::with_dictionary(compression_level, &dictionary)
28                            .expect("error creating Compressor with dictionary"),
29                    ),
30                };
31
32                Self {
33                    result: Vec::new(),
34                    encoder,
35                }
36            }
37
38            pub fn encode(&mut self, payload: &[u8]) -> &[u8] {
39                // TODO: only use compressed packet if the resulting size would be less!
40                match &mut self.encoder {
41                    EncoderType::DictionaryTrainer(trainer) => {
42                        trainer.record_bytes(payload);
43                        self.result = payload.to_vec();
44                        return &self.result;
45                    }
46                    EncoderType::Compressor(encoder) => {
47                        self.result = encoder.compress(payload).expect("encode error");
48                        return &self.result;
49                    }
50                }
51            }
52        }
53
54        pub enum EncoderType {
55            Compressor(Compressor<'static>),
56            DictionaryTrainer(DictionaryTrainer),
57        }
58
59        pub struct DictionaryTrainer {
60            sample_data: Vec<u8>,
61            sample_sizes: Vec<usize>,
62            next_alert_size: usize,
63            target_sample_size: usize,
64            training_complete: bool,
65        }
66
67        impl DictionaryTrainer {
68            /// `target_sample_size` here describes the number of samples (packets) to
69            /// train on. Obviously, the more samples trained on, the better
70            /// theoretical compression.
71            pub fn new(target_sample_size: usize) -> Self {
72                Self {
73                    target_sample_size,
74                    sample_data: Vec::new(),
75                    sample_sizes: Vec::new(),
76                    next_alert_size: 0,
77                    training_complete: false,
78                }
79            }
80
81            pub fn record_bytes(&mut self, bytes: &[u8]) {
82                if self.training_complete {
83                    return;
84                }
85
86                self.sample_data.extend_from_slice(bytes);
87                self.sample_sizes.push(bytes.len());
88
89                let current_sample_size = self.sample_sizes.len();
90
91                if current_sample_size >= self.next_alert_size {
92                    let percent =
93                        ((self.next_alert_size as f32) / (self.target_sample_size as f32)) * 100.0;
94                    info!("Dictionary training: {}% complete", percent);
95
96                    self.next_alert_size += self.target_sample_size / 20;
97                }
98
99                if current_sample_size >= self.target_sample_size {
100                    info!("Dictionary training complete!");
101                    info!(
102                        "Samples: {} ({} KB)",
103                        self.sample_sizes.len(),
104                        self.sample_data.len()
105                    );
106                    info!("Dictionary processing sample data...");
107
108                    // We have enough sample data to train the dictionary!
109                    let target_dict_size = self.sample_data.len() / 100;
110                    let dictionary =
111                        from_continuous(&self.sample_data, &self.sample_sizes, target_dict_size)
112                            .expect("Error while training dictionary");
113
114                    // Now need to ... write it to a file I guess
115                    fs::write("dictionary.txt", dictionary)
116                        .expect("Error while writing dictionary to file");
117
118                    info!("Dictionary written to `dictionary.txt`!");
119
120                    self.training_complete = true;
121                }
122            }
123        }
124    }
125    else
126    {
127        use super::compression_config::CompressionMode;
128
129        pub struct Encoder {
130            result: Vec<u8>
131        }
132
133        impl Encoder {
134            pub fn new(_: CompressionMode) -> Self {
135                Self {
136                    result: Vec::new(),
137                }
138            }
139
140            pub fn encode(&mut self, payload: &[u8]) -> &[u8] {
141                self.result = payload.to_vec();
142                &self.result
143            }
144        }
145    }
146}