lance_encoding/
compression_config.rs1use std::collections::HashMap;
10
11use arrow::datatypes::DataType;
12
13#[derive(Debug, Clone, PartialEq)]
15pub struct CompressionParams {
16 pub columns: HashMap<String, CompressionFieldParams>,
18
19 pub types: HashMap<String, CompressionFieldParams>,
21}
22
/// Compression settings that apply to a single field.
///
/// Every setting is optional: `None` means "not specified at this level", so layering
/// settings with [`CompressionFieldParams::merge`] keeps whatever a lower-priority
/// level provided.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CompressionFieldParams {
    /// Presumably the threshold that decides when run-length encoding is used.
    /// NOTE(review): the exact meaning and valid range are defined by the encoder that
    /// reads this value — confirm there.
    pub rle_threshold: Option<f64>,

    /// Name of the compression scheme, e.g. `"lz4"` or `"zstd"`.
    pub compression: Option<String>,

    /// Level for the compression scheme (e.g. `3` with `"zstd"`).
    pub compression_level: Option<i32>,
}
36
37impl CompressionParams {
38 pub fn new() -> Self {
40 Self {
41 columns: HashMap::new(),
42 types: HashMap::new(),
43 }
44 }
45
46 pub fn get_field_params(
48 &self,
49 field_name: &str,
50 data_type: &DataType,
51 ) -> CompressionFieldParams {
52 let mut params = CompressionFieldParams::default();
53
54 let type_name = data_type.to_string();
56 if let Some(type_params) = self.types.get(&type_name) {
57 params.merge(type_params);
58 }
59
60 if let Some(col_params) = self.columns.get(field_name) {
63 params.merge(col_params);
64 } else {
65 for (pattern, col_params) in &self.columns {
67 if matches_pattern(field_name, pattern) {
68 params.merge(col_params);
69 break; }
71 }
72 }
73
74 params
75 }
76}
77
78impl Default for CompressionParams {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl CompressionFieldParams {
85 pub fn merge(&mut self, other: &Self) {
87 if other.rle_threshold.is_some() {
88 self.rle_threshold = other.rle_threshold;
89 }
90 if other.compression.is_some() {
91 self.compression = other.compression.clone();
92 }
93 if other.compression_level.is_some() {
94 self.compression_level = other.compression_level;
95 }
96 }
97}
98
/// Returns `true` if `name` matches the column `pattern`.
///
/// Each `*` in `pattern` matches any (possibly empty) run of characters, and `*` may
/// appear any number of times anywhere in the pattern: `"*_id"`, `"log_*"`,
/// `"test_*_name"`, `"*id*"` and `"a*b*c"` all work. A pattern without `*` matches only
/// the identical name. All other characters are compared literally; there is no way
/// to escape a literal `*`.
fn matches_pattern(name: &str, pattern: &str) -> bool {
    let (prefix, after_first_star) = match pattern.split_once('*') {
        Some(parts) => parts,
        // No wildcard: the pattern is a plain column name.
        None => return name == pattern,
    };
    // Text after the last '*' is the literal suffix; text between the first and last
    // '*' holds the interior segments (empty when there is only one '*').
    let (interior, suffix) = after_first_star
        .rsplit_once('*')
        .unwrap_or(("", after_first_star));

    let mut rest = match name.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };
    // Each interior segment must occur, in order, after the previous one. Taking the
    // leftmost occurrence leaves the longest tail for the segments that follow.
    for segment in interior.split('*') {
        match rest.find(segment) {
            Some(pos) => rest = &rest[pos + segment.len()..],
            None => return false,
        }
    }
    // The suffix must fit in what remains, so it cannot overlap text already matched
    // by the prefix or an interior segment.
    rest.ends_with(suffix)
}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CompressionFieldParams` from its three settings.
    fn field_params(
        rle_threshold: Option<f64>,
        compression: Option<&str>,
        compression_level: Option<i32>,
    ) -> CompressionFieldParams {
        CompressionFieldParams {
            rle_threshold,
            compression: compression.map(str::to_string),
            compression_level,
        }
    }

    #[test]
    fn test_pattern_matching() {
        // Leading wildcard: the name must end with the literal part.
        for name in ["user_id", "product_id"] {
            assert!(matches_pattern(name, "*_id"), "{} should match *_id", name);
        }
        assert!(!matches_pattern("identity", "*_id"));

        // Trailing wildcard: the name must start with the literal part.
        for name in ["log_message", "log_level"] {
            assert!(matches_pattern(name, "log_*"), "{} should match log_*", name);
        }
        assert!(!matches_pattern("message_log", "log_*"));

        // Interior wildcard: prefix and suffix must both fit without overlapping.
        for name in ["test_field_name", "test_column_name"] {
            assert!(
                matches_pattern(name, "test_*_name"),
                "{} should match test_*_name",
                name
            );
        }
        assert!(!matches_pattern("test_name", "test_*_name"));

        // Catch-all and literal patterns.
        assert!(matches_pattern("anything", "*"));
        assert!(matches_pattern("exact_match", "exact_match"));
    }

    #[test]
    fn test_field_params_merge() {
        let mut merged = CompressionFieldParams::default();
        assert_eq!(merged, field_params(None, None, None));

        merged.merge(&field_params(Some(0.3), Some("lz4"), None));
        assert_eq!(merged, field_params(Some(0.3), Some("lz4"), None));

        // Unset values in the overlay must not clear values already present.
        merged.merge(&field_params(None, Some("zstd"), Some(3)));
        assert_eq!(merged, field_params(Some(0.3), Some("zstd"), Some(3)));
    }

    #[test]
    fn test_get_field_params() {
        let mut config = CompressionParams::new();
        config.types.insert(
            "Int32".to_string(),
            field_params(Some(0.5), Some("lz4"), None),
        );
        config.columns.insert(
            "*_id".to_string(),
            field_params(Some(0.3), Some("zstd"), Some(3)),
        );

        // Neither a type rule nor a column rule applies.
        let resolved = config.get_field_params("some_field", &DataType::Float32);
        assert_eq!(resolved.compression, None);
        assert_eq!(resolved.rle_threshold, None);

        // Only the type rule applies.
        let resolved = config.get_field_params("some_field", &DataType::Int32);
        assert_eq!(resolved.compression.as_deref(), Some("lz4"));
        assert_eq!(resolved.rle_threshold, Some(0.5));

        // The column pattern overrides the type rule.
        let resolved = config.get_field_params("user_id", &DataType::Int32);
        assert_eq!(resolved.compression.as_deref(), Some("zstd"));
        assert_eq!(resolved.compression_level, Some(3));
        assert_eq!(resolved.rle_threshold, Some(0.3));
    }

    #[test]
    fn test_exact_match_priority() {
        let mut config = CompressionParams::new();
        config
            .columns
            .insert("*_id".to_string(), field_params(None, Some("lz4"), None));
        config
            .columns
            .insert("user_id".to_string(), field_params(None, Some("zstd"), None));

        // An exact column name beats a pattern that also matches.
        let resolved = config.get_field_params("user_id", &DataType::Int32);
        assert_eq!(resolved.compression.as_deref(), Some("zstd"));
    }
}