1use std::collections::HashMap;
10
11use arrow_schema::DataType;
12
/// Tri-state switch for the `bss` option of a field.
///
/// NOTE(review): "BSS" presumably stands for byte stream split — confirm
/// against the encoder that consumes this setting.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BssMode {
    /// Spelled `"off"` in textual form; sensitivity `0.0`.
    Off,
    /// Spelled `"on"` in textual form; sensitivity `1.0`.
    On,
    /// Spelled `"auto"` in textual form; sensitivity `0.5`.
    Auto,
}

impl BssMode {
    /// Returns the numeric sensitivity associated with this mode:
    /// `0.0` for [`BssMode::Off`], `1.0` for [`BssMode::On`] and `0.5` for
    /// [`BssMode::Auto`].
    pub fn to_sensitivity(&self) -> f32 {
        match *self {
            BssMode::Auto => 0.5,
            BssMode::On => 1.0,
            BssMode::Off => 0.0,
        }
    }

    /// Parses a mode from its textual name, ignoring letter case.
    ///
    /// Accepts `"off"`, `"on"` and `"auto"`; any other input yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // Lower-cased spelling of every mode, paired with the mode itself.
        const SPELLINGS: [(&str, BssMode); 3] = [
            ("off", BssMode::Off),
            ("on", BssMode::On),
            ("auto", BssMode::Auto),
        ];

        let wanted = s.to_lowercase();
        SPELLINGS
            .iter()
            .find(|(spelling, _)| wanted == *spelling)
            .map(|&(_, mode)| mode)
    }
}
44
45#[derive(Debug, Clone, PartialEq)]
47pub struct CompressionParams {
48 pub columns: HashMap<String, CompressionFieldParams>,
50
51 pub types: HashMap<String, CompressionFieldParams>,
53}
54
55#[derive(Debug, Clone, PartialEq, Default)]
57pub struct CompressionFieldParams {
58 pub rle_threshold: Option<f64>,
61
62 pub compression: Option<String>,
64
65 pub compression_level: Option<i32>,
67
68 pub bss: Option<BssMode>,
70
71 pub minichunk_size: Option<i64>,
73}
74
75impl CompressionParams {
76 pub fn new() -> Self {
78 Self {
79 columns: HashMap::new(),
80 types: HashMap::new(),
81 }
82 }
83
84 pub fn get_field_params(
86 &self,
87 field_name: &str,
88 data_type: &DataType,
89 ) -> CompressionFieldParams {
90 let mut params = CompressionFieldParams::default();
91
92 let type_name = data_type.to_string();
94 if let Some(type_params) = self.types.get(&type_name) {
95 params.merge(type_params);
96 }
97
98 if let Some(col_params) = self.columns.get(field_name) {
101 params.merge(col_params);
102 } else {
103 for (pattern, col_params) in &self.columns {
105 if matches_pattern(field_name, pattern) {
106 params.merge(col_params);
107 break; }
109 }
110 }
111
112 params
113 }
114}
115
116impl Default for CompressionParams {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl CompressionFieldParams {
123 pub fn merge(&mut self, other: &Self) {
125 if other.rle_threshold.is_some() {
126 self.rle_threshold = other.rle_threshold;
127 }
128 if other.compression.is_some() {
129 self.compression = other.compression.clone();
130 }
131 if other.compression_level.is_some() {
132 self.compression_level = other.compression_level;
133 }
134 if other.bss.is_some() {
135 self.bss = other.bss;
136 }
137 if other.minichunk_size.is_some() {
138 self.minichunk_size = other.minichunk_size;
139 }
140 }
141}
142
/// Returns `true` when `name` matches the glob-like `pattern`.
///
/// `*` stands for any run of characters, including an empty one; every other
/// character must match literally. A pattern may contain any number of
/// wildcards (`"log_*"`, `"*_id"`, `"test_*_name"`, `"*_id_*"`, ...). A
/// pattern without `*` matches only the identical name.
fn matches_pattern(name: &str, pattern: &str) -> bool {
    // The literal pieces between wildcards, in order.
    let mut literals = pattern.split('*');

    // `split` always yields at least one piece, so the fallback is never used.
    let head = literals.next().unwrap_or("");
    let tail = match literals.next_back() {
        Some(tail) => tail,
        // No wildcard at all: plain equality.
        None => return name == pattern,
    };

    // The first and last pieces are anchored to the ends of the name.
    // Stripping one after the other guarantees they do not overlap, so
    // "test_name" does not match "test_*_name".
    let anchored = name
        .strip_prefix(head)
        .and_then(|rest| rest.strip_suffix(tail));
    let mut remaining = match anchored {
        Some(middle) => middle,
        None => return false,
    };

    // Inner pieces must occur in order; taking the leftmost occurrence each
    // time leaves the most room for the pieces that follow.
    for piece in literals {
        match remaining.find(piece) {
            Some(at) => remaining = &remaining[at + piece.len()..],
            None => return false,
        }
    }

    true
}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Suffix, prefix, infix, catch-all and literal patterns.
    #[test]
    fn test_pattern_matching() {
        // (field name, pattern, expected result)
        let cases = [
            // Suffix patterns.
            ("user_id", "*_id", true),
            ("product_id", "*_id", true),
            ("identity", "*_id", false),
            // Prefix patterns.
            ("log_message", "log_*", true),
            ("log_level", "log_*", true),
            ("message_log", "log_*", false),
            // Wildcard in the middle.
            ("test_field_name", "test_*_name", true),
            ("test_column_name", "test_*_name", true),
            ("test_name", "test_*_name", false),
            // Catch-all and literal.
            ("anything", "*", true),
            ("exact_match", "exact_match", true),
        ];

        for &(name, pattern, expected) in cases.iter() {
            assert_eq!(
                matches_pattern(name, pattern),
                expected,
                "name={:?} pattern={:?}",
                name,
                pattern
            );
        }
    }

    /// `merge` overrides only the options the other side specifies.
    #[test]
    fn test_field_params_merge() {
        let mut merged = CompressionFieldParams::default();
        assert_eq!(merged.rle_threshold, None);
        assert_eq!(merged.compression, None);
        assert_eq!(merged.compression_level, None);
        assert_eq!(merged.bss, None);

        // First overlay: sets threshold, scheme and bss; leaves level alone.
        let first = CompressionFieldParams {
            rle_threshold: Some(0.3),
            compression: Some(String::from("lz4")),
            bss: Some(BssMode::On),
            ..Default::default()
        };
        merged.merge(&first);
        assert_eq!(merged.rle_threshold, Some(0.3));
        assert_eq!(merged.compression, Some(String::from("lz4")));
        assert_eq!(merged.compression_level, None);
        assert_eq!(merged.bss, Some(BssMode::On));

        // Second overlay: no threshold, so the earlier one must survive.
        let second = CompressionFieldParams {
            compression: Some(String::from("zstd")),
            compression_level: Some(3),
            bss: Some(BssMode::Auto),
            ..Default::default()
        };
        merged.merge(&second);
        assert_eq!(merged.rle_threshold, Some(0.3));
        assert_eq!(merged.compression, Some(String::from("zstd")));
        assert_eq!(merged.compression_level, Some(3));
        assert_eq!(merged.bss, Some(BssMode::Auto));
    }

    /// Type settings form the base layer; column patterns override them.
    #[test]
    fn test_get_field_params() {
        let mut config = CompressionParams::new();

        config.types.insert(
            String::from("Int32"),
            CompressionFieldParams {
                rle_threshold: Some(0.5),
                compression: Some(String::from("lz4")),
                ..Default::default()
            },
        );

        config.columns.insert(
            String::from("*_id"),
            CompressionFieldParams {
                rle_threshold: Some(0.3),
                compression: Some(String::from("zstd")),
                compression_level: Some(3),
                ..Default::default()
            },
        );

        // Neither the type nor any column entry applies.
        let unmatched = config.get_field_params("some_field", &DataType::Float32);
        assert_eq!(unmatched.compression, None);
        assert_eq!(unmatched.rle_threshold, None);

        // Only the type entry applies.
        let by_type = config.get_field_params("some_field", &DataType::Int32);
        assert_eq!(by_type.compression, Some(String::from("lz4")));
        assert_eq!(by_type.rle_threshold, Some(0.5));

        // The column pattern overrides the type entry.
        let by_pattern = config.get_field_params("user_id", &DataType::Int32);
        assert_eq!(by_pattern.compression, Some(String::from("zstd")));
        assert_eq!(by_pattern.compression_level, Some(3));
        assert_eq!(by_pattern.rle_threshold, Some(0.3));
    }

    /// An exact column name takes precedence over a matching pattern.
    #[test]
    fn test_exact_match_priority() {
        let mut config = CompressionParams::new();

        config.columns.insert(
            String::from("*_id"),
            CompressionFieldParams {
                compression: Some(String::from("lz4")),
                ..Default::default()
            },
        );

        config.columns.insert(
            String::from("user_id"),
            CompressionFieldParams {
                compression: Some(String::from("zstd")),
                ..Default::default()
            },
        );

        let resolved = config.get_field_params("user_id", &DataType::Int32);
        assert_eq!(resolved.compression, Some(String::from("zstd")));
    }
}