1use std::collections::HashMap;
10
11use arrow_schema::DataType;
12
/// Mode setting for BSS, stored in [`CompressionFieldParams::bss`].
///
/// NOTE(review): "BSS" presumably refers to byte-stream-split encoding —
/// confirm against the encoder that consumes this value.
///
/// [`BssMode::parse`] accepts the case-insensitive names `"off"`, `"on"` and
/// `"auto"`; [`BssMode::to_sensitivity`] maps the variants to `0.0`, `1.0`
/// and `0.5` respectively.
///
/// The enum is fieldless, so `Eq` and `Hash` are derived as well, allowing it
/// to be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BssMode {
    /// Parsed from `"off"`; sensitivity `0.0`.
    Off,
    /// Parsed from `"on"`; sensitivity `1.0`.
    On,
    /// Parsed from `"auto"`; sensitivity `0.5`.
    Auto,
}
23
24impl BssMode {
25 pub fn to_sensitivity(&self) -> f32 {
27 match self {
28 Self::Off => 0.0,
29 Self::On => 1.0,
30 Self::Auto => 0.5, }
32 }
33
34 pub fn parse(s: &str) -> Option<Self> {
36 match s.to_lowercase().as_str() {
37 "off" => Some(Self::Off),
38 "on" => Some(Self::On),
39 "auto" => Some(Self::Auto),
40 _ => None,
41 }
42 }
43}
44
45#[derive(Debug, Clone, PartialEq)]
47pub struct CompressionParams {
48 pub columns: HashMap<String, CompressionFieldParams>,
50
51 pub types: HashMap<String, CompressionFieldParams>,
53}
54
55#[derive(Debug, Clone, PartialEq, Default)]
57pub struct CompressionFieldParams {
58 pub rle_threshold: Option<f64>,
61
62 pub compression: Option<String>,
64
65 pub compression_level: Option<i32>,
67
68 pub bss: Option<BssMode>,
70}
71
72impl CompressionParams {
73 pub fn new() -> Self {
75 Self {
76 columns: HashMap::new(),
77 types: HashMap::new(),
78 }
79 }
80
81 pub fn get_field_params(
83 &self,
84 field_name: &str,
85 data_type: &DataType,
86 ) -> CompressionFieldParams {
87 let mut params = CompressionFieldParams::default();
88
89 let type_name = data_type.to_string();
91 if let Some(type_params) = self.types.get(&type_name) {
92 params.merge(type_params);
93 }
94
95 if let Some(col_params) = self.columns.get(field_name) {
98 params.merge(col_params);
99 } else {
100 for (pattern, col_params) in &self.columns {
102 if matches_pattern(field_name, pattern) {
103 params.merge(col_params);
104 break; }
106 }
107 }
108
109 params
110 }
111}
112
113impl Default for CompressionParams {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl CompressionFieldParams {
120 pub fn merge(&mut self, other: &Self) {
122 if other.rle_threshold.is_some() {
123 self.rle_threshold = other.rle_threshold;
124 }
125 if other.compression.is_some() {
126 self.compression = other.compression.clone();
127 }
128 if other.compression_level.is_some() {
129 self.compression_level = other.compression_level;
130 }
131 if other.bss.is_some() {
132 self.bss = other.bss;
133 }
134 }
135}
136
/// Returns `true` if `name` matches the glob-style `pattern`.
///
/// `*` matches any run of characters, including an empty one, and may appear
/// any number of times. Every other character must match literally. As a
/// result:
/// - a pattern without `*` requires an exact match;
/// - `"*"` matches every name, including the empty string;
/// - `"*log*"` matches any name containing `log`.
fn matches_pattern(name: &str, pattern: &str) -> bool {
    // Without any `*` the pattern is a literal, so require an exact match.
    let (head, tail) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return name == pattern,
    };

    // View the pattern as `head*middle*last`. `head` precedes the first `*`,
    // `last` follows the final `*`, and `middle` lies in between. `middle` may
    // be empty or may contain more `*`.
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));

    // Anchor the literal prefix and suffix. Stripping the suffix from what is
    // left after the prefix keeps the two from overlapping, so "a*a" does not
    // match "a".
    let mut rest = match name.strip_prefix(head).and_then(|r| r.strip_suffix(last)) {
        Some(rest) => rest,
        None => return false,
    };

    // Interior segments must occur in order. Consuming the leftmost occurrence
    // of each leaves the most room for the segments that follow, so a greedy
    // scan is exact. Empty segments (from `**` or an empty `middle`) always
    // match.
    for segment in middle.split('*') {
        match rest.find(segment) {
            Some(at) => rest = &rest[at + segment.len()..],
            None => return false,
        }
    }
    true
}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for field params that set only a codec name.
    fn codec(name: &str) -> CompressionFieldParams {
        CompressionFieldParams {
            compression: Some(name.to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_pattern_matching() {
        // Each case is (field name, pattern, expected match).
        let cases = [
            ("user_id", "*_id", true),
            ("product_id", "*_id", true),
            ("identity", "*_id", false),
            ("log_message", "log_*", true),
            ("log_level", "log_*", true),
            ("message_log", "log_*", false),
            ("test_field_name", "test_*_name", true),
            ("test_column_name", "test_*_name", true),
            ("test_name", "test_*_name", false),
            ("anything", "*", true),
            ("exact_match", "exact_match", true),
        ];
        for &(name, pattern, expected) in cases.iter() {
            assert_eq!(
                matches_pattern(name, pattern),
                expected,
                "matches_pattern({:?}, {:?})",
                name,
                pattern
            );
        }
    }

    #[test]
    fn test_field_params_merge() {
        let mut params = CompressionFieldParams::default();
        assert_eq!(
            params,
            CompressionFieldParams {
                rle_threshold: None,
                compression: None,
                compression_level: None,
                bss: None,
            }
        );

        // First layer: every setting except the level is provided.
        params.merge(&CompressionFieldParams {
            rle_threshold: Some(0.3),
            compression: Some("lz4".to_string()),
            compression_level: None,
            bss: Some(BssMode::On),
        });
        assert_eq!(
            params,
            CompressionFieldParams {
                rle_threshold: Some(0.3),
                compression: Some("lz4".to_string()),
                compression_level: None,
                bss: Some(BssMode::On),
            }
        );

        // Second layer: `rle_threshold` is None, so the earlier value survives.
        params.merge(&CompressionFieldParams {
            rle_threshold: None,
            compression: Some("zstd".to_string()),
            compression_level: Some(3),
            bss: Some(BssMode::Auto),
        });
        assert_eq!(
            params,
            CompressionFieldParams {
                rle_threshold: Some(0.3),
                compression: Some("zstd".to_string()),
                compression_level: Some(3),
                bss: Some(BssMode::Auto),
            }
        );
    }

    #[test]
    fn test_get_field_params() {
        let mut params = CompressionParams::new();
        params.types.insert(
            "Int32".to_string(),
            CompressionFieldParams {
                rle_threshold: Some(0.5),
                ..codec("lz4")
            },
        );
        params.columns.insert(
            "*_id".to_string(),
            CompressionFieldParams {
                rle_threshold: Some(0.3),
                compression_level: Some(3),
                ..codec("zstd")
            },
        );

        // Neither a type nor a column override applies.
        let float_field = params.get_field_params("some_field", &DataType::Float32);
        assert_eq!(float_field.compression, None);
        assert_eq!(float_field.rle_threshold, None);

        // Only the Int32 type override applies.
        let int_field = params.get_field_params("some_field", &DataType::Int32);
        assert_eq!(int_field.compression.as_deref(), Some("lz4"));
        assert_eq!(int_field.rle_threshold, Some(0.5));

        // The `*_id` column pattern overrides the type-level settings.
        let id_field = params.get_field_params("user_id", &DataType::Int32);
        assert_eq!(id_field.compression.as_deref(), Some("zstd"));
        assert_eq!(id_field.compression_level, Some(3));
        assert_eq!(id_field.rle_threshold, Some(0.3));
    }

    #[test]
    fn test_exact_match_priority() {
        let mut params = CompressionParams::new();
        params.columns.insert("*_id".to_string(), codec("lz4"));
        params.columns.insert("user_id".to_string(), codec("zstd"));

        let resolved = params.get_field_params("user_id", &DataType::Int32);
        assert_eq!(resolved.compression.as_deref(), Some("zstd"));
    }
}
290}