naia_shared/connection/
encoder.rs1cfg_if! {
2 if #[cfg(feature = "zstd_support")]
3 {
4 use std::fs;
5
6 use log::info;
7
8 use zstd::{bulk::Compressor, dict::from_continuous};
9
10 use super::compression_config::CompressionMode;
11
12 pub struct Encoder {
13 result: Vec<u8>,
14 encoder: EncoderType,
15 }
16
17 impl Encoder {
18 pub fn new(compression_mode: CompressionMode) -> Self {
19 let encoder = match compression_mode {
20 CompressionMode::Training(sample_size) => {
21 EncoderType::DictionaryTrainer(DictionaryTrainer::new(sample_size))
22 }
23 CompressionMode::Default(compression_level) => EncoderType::Compressor(
24 Compressor::new(compression_level).expect("error creating Compressor"),
25 ),
26 CompressionMode::Dictionary(compression_level, dictionary) => EncoderType::Compressor(
27 Compressor::with_dictionary(compression_level, &dictionary)
28 .expect("error creating Compressor with dictionary"),
29 ),
30 };
31
32 Self {
33 result: Vec::new(),
34 encoder,
35 }
36 }
37
38 pub fn encode(&mut self, payload: &[u8]) -> &[u8] {
39 match &mut self.encoder {
41 EncoderType::DictionaryTrainer(trainer) => {
42 trainer.record_bytes(payload);
43 self.result = payload.to_vec();
44 return &self.result;
45 }
46 EncoderType::Compressor(encoder) => {
47 self.result = encoder.compress(payload).expect("encode error");
48 return &self.result;
49 }
50 }
51 }
52 }
53
54 pub enum EncoderType {
55 Compressor(Compressor<'static>),
56 DictionaryTrainer(DictionaryTrainer),
57 }
58
/// Accumulates payload samples until enough have been collected to train a
/// zstd dictionary, which is then written to disk.
pub struct DictionaryTrainer {
    /// Number of samples to collect before training runs.
    target_sample_size: usize,
    /// Sample count at which the next progress message is logged.
    next_alert_size: usize,
    /// Every recorded payload, concatenated back to back.
    sample_data: Vec<u8>,
    /// Length of each payload stored in `sample_data`, in recording order.
    sample_sizes: Vec<usize>,
    /// Set once the dictionary has been produced; later samples are ignored.
    training_complete: bool,
}
66
67 impl DictionaryTrainer {
68 pub fn new(target_sample_size: usize) -> Self {
72 Self {
73 target_sample_size,
74 sample_data: Vec::new(),
75 sample_sizes: Vec::new(),
76 next_alert_size: 0,
77 training_complete: false,
78 }
79 }
80
81 pub fn record_bytes(&mut self, bytes: &[u8]) {
82 if self.training_complete {
83 return;
84 }
85
86 self.sample_data.extend_from_slice(bytes);
87 self.sample_sizes.push(bytes.len());
88
89 let current_sample_size = self.sample_sizes.len();
90
91 if current_sample_size >= self.next_alert_size {
92 let percent =
93 ((self.next_alert_size as f32) / (self.target_sample_size as f32)) * 100.0;
94 info!("Dictionary training: {}% complete", percent);
95
96 self.next_alert_size += self.target_sample_size / 20;
97 }
98
99 if current_sample_size >= self.target_sample_size {
100 info!("Dictionary training complete!");
101 info!(
102 "Samples: {} ({} KB)",
103 self.sample_sizes.len(),
104 self.sample_data.len()
105 );
106 info!("Dictionary processing sample data...");
107
108 let target_dict_size = self.sample_data.len() / 100;
110 let dictionary =
111 from_continuous(&self.sample_data, &self.sample_sizes, target_dict_size)
112 .expect("Error while training dictionary");
113
114 fs::write("dictionary.txt", dictionary)
116 .expect("Error while writing dictionary to file");
117
118 info!("Dictionary written to `dictionary.txt`!");
119
120 self.training_complete = true;
121 }
122 }
123 }
124 }
125 else
126 {
127 use super::compression_config::CompressionMode;
128
/// Pass-through encoder used when the `zstd_support` feature is disabled.
///
/// Keeps a copy of the most recently encoded payload so `encode` can hand
/// back a borrowed slice.
pub struct Encoder {
    result: Vec<u8>,
}
132
133 impl Encoder {
134 pub fn new(_: CompressionMode) -> Self {
135 Self {
136 result: Vec::new(),
137 }
138 }
139
140 pub fn encode(&mut self, payload: &[u8]) -> &[u8] {
141 self.result = payload.to_vec();
142 &self.result
143 }
144 }
145 }
146}