naia_shared/connection/
encoder.rs1cfg_if! {
2 if #[cfg(feature = "zstd_support")]
3 {
4 use std::fs;
5
6 use log::info;
7
8 use zstd::{bulk::Compressor, dict::from_continuous};
9
10 use super::compression_config::CompressionMode;
11
12 #[derive(Clone)]
13 pub struct Encoder {
14 result: Vec<u8>,
15 encoder: EncoderType,
16 }
17
18 impl Encoder {
19 pub fn new(compression_mode: CompressionMode) -> Self {
20 let encoder = match compression_mode {
21 CompressionMode::Training(sample_size) => {
22 EncoderType::DictionaryTrainer(DictionaryTrainer::new(sample_size))
23 }
24 CompressionMode::Default(compression_level) => EncoderType::Compressor(
25 Compressor::new(compression_level).expect("error creating Compressor"),
26 ),
27 CompressionMode::Dictionary(compression_level, dictionary) => EncoderType::Compressor(
28 Compressor::with_dictionary(compression_level, &dictionary)
29 .expect("error creating Compressor with dictionary"),
30 ),
31 };
32
33 Self {
34 result: Vec::new(),
35 encoder,
36 }
37 }
38
39 pub fn encode(&mut self, payload: &[u8]) -> &[u8] {
40 match &mut self.encoder {
41 EncoderType::DictionaryTrainer(trainer) => {
42 trainer.record_bytes(payload);
43 self.result = Vec::with_capacity(1 + payload.len());
45 self.result.push(0u8);
46 self.result.extend_from_slice(payload);
47 return &self.result;
48 }
49 EncoderType::Compressor(encoder) => {
50 let compressed = encoder.compress(payload).expect("encode error");
51 if compressed.len() < payload.len() {
52 self.result = Vec::with_capacity(1 + compressed.len());
54 self.result.push(1u8);
55 self.result.extend_from_slice(&compressed);
56 } else {
57 self.result = Vec::with_capacity(1 + payload.len());
59 self.result.push(0u8);
60 self.result.extend_from_slice(payload);
61 }
62 return &self.result;
63 }
64 }
65 }
66 }
67
68 #[derive(Clone)]
69 pub enum EncoderType {
70 Compressor(Compressor<'static>),
71 DictionaryTrainer(DictionaryTrainer),
72 }
73
/// Collects outgoing payloads as samples for zstd dictionary training.
#[derive(Clone)]
pub struct DictionaryTrainer {
    /// Number of samples to collect before the dictionary is trained.
    target_sample_size: usize,
    /// Every recorded sample, concatenated back to back.
    sample_data: Vec<u8>,
    /// Byte length of each individual sample inside `sample_data`.
    sample_sizes: Vec<usize>,
    /// Sample count at which the next progress message is logged.
    next_alert_size: usize,
    /// Set once training has finished; later samples are ignored.
    training_complete: bool,
}
82
83 impl DictionaryTrainer {
84 pub fn new(target_sample_size: usize) -> Self {
88 Self {
89 target_sample_size,
90 sample_data: Vec::new(),
91 sample_sizes: Vec::new(),
92 next_alert_size: 0,
93 training_complete: false,
94 }
95 }
96
97 pub fn record_bytes(&mut self, bytes: &[u8]) {
98 if self.training_complete {
99 return;
100 }
101
102 self.sample_data.extend_from_slice(bytes);
103 self.sample_sizes.push(bytes.len());
104
105 let current_sample_size = self.sample_sizes.len();
106
107 if current_sample_size >= self.next_alert_size {
108 let percent =
109 ((self.next_alert_size as f32) / (self.target_sample_size as f32)) * 100.0;
110 info!("Dictionary training: {}% complete", percent);
111
112 self.next_alert_size += self.target_sample_size / 20;
113 }
114
115 if current_sample_size >= self.target_sample_size {
116 info!("Dictionary training complete!");
117 info!(
118 "Samples: {} ({} KB)",
119 self.sample_sizes.len(),
120 self.sample_data.len()
121 );
122 info!("Dictionary processing sample data...");
123
124 let target_dict_size = self.sample_data.len() / 100;
126 let dictionary =
127 from_continuous(&self.sample_data, &self.sample_sizes, target_dict_size)
128 .expect("Error while training dictionary");
129
130 fs::write("dictionary.txt", dictionary)
132 .expect("Error while writing dictionary to file");
133
134 info!("Dictionary written to `dictionary.txt`!");
135
136 self.training_complete = true;
137 }
138 }
139 }
140 }
141 else
142 {
143 use super::compression_config::CompressionMode;
144
/// Pass-through encoder used when the `zstd_support` feature is disabled.
#[derive(Clone)]
pub struct Encoder {
    /// Scratch buffer holding the most recently encoded payload.
    result: Vec<u8>,
}
150
151 impl Encoder {
152 pub fn new(_: CompressionMode) -> Self {
154 Self {
155 result: Vec::new(),
156 }
157 }
158
159 pub fn encode(&mut self, payload: &[u8]) -> &[u8] {
161 self.result = payload.to_vec();
162 &self.result
163 }
164 }
165 }
166}