lance_encoding/
compression_config.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Configuration for compression parameters
5//!
6//! This module provides types for configuring compression strategies
7//! on a per-column or per-type basis using a parameter-driven approach.
8
9use std::collections::HashMap;
10
11use arrow::datatypes::DataType;
12
/// Compression parameter configuration
///
/// Holds two independent override tables. When parameters are resolved for a
/// field via `get_field_params`, the type-level entry is applied first and
/// the column-level entry is merged on top of it, so column settings win
/// wherever both specify a value.
#[derive(Debug, Clone, PartialEq)]
pub struct CompressionParams {
    /// Column-level parameters: column name/pattern -> parameters
    ///
    /// A key is either an exact column name or a pattern containing the `*`
    /// wildcard (e.g. `"*_id"`, `"log_*"`). An exact-name entry takes
    /// priority over any pattern entry.
    pub columns: HashMap<String, CompressionFieldParams>,

    /// Type-level parameters: data type name -> parameters
    ///
    /// Keys are compared against the string form of the Arrow data type
    /// (`DataType::to_string()`), e.g. `"Int32"`.
    pub types: HashMap<String, CompressionFieldParams>,
}
22
/// Field-level compression parameters
///
/// Every setting is optional. `None` means "not specified at this level":
/// `merge` leaves the receiving value untouched for any field that is `None`
/// in the overriding parameters.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct CompressionFieldParams {
    /// RLE threshold (0.0-1.0)
    /// When run_count < num_values * threshold, RLE will be used
    pub rle_threshold: Option<f64>,

    /// General compression scheme: "lz4", "zstd", "none"
    pub compression: Option<String>,

    /// Compression level (only for schemes that support it, e.g., zstd)
    pub compression_level: Option<i32>,
}
36
37impl CompressionParams {
38    /// Create empty compression parameters
39    pub fn new() -> Self {
40        Self {
41            columns: HashMap::new(),
42            types: HashMap::new(),
43        }
44    }
45
46    /// Get effective parameters for a field (merging type params and column params)
47    pub fn get_field_params(
48        &self,
49        field_name: &str,
50        data_type: &DataType,
51    ) -> CompressionFieldParams {
52        let mut params = CompressionFieldParams::default();
53
54        // Apply type-level parameters
55        let type_name = data_type.to_string();
56        if let Some(type_params) = self.types.get(&type_name) {
57            params.merge(type_params);
58        }
59
60        // Apply column-level parameters (highest priority)
61        // First check exact match
62        if let Some(col_params) = self.columns.get(field_name) {
63            params.merge(col_params);
64        } else {
65            // Check pattern matching
66            for (pattern, col_params) in &self.columns {
67                if matches_pattern(field_name, pattern) {
68                    params.merge(col_params);
69                    break; // Use first matching pattern
70                }
71            }
72        }
73
74        params
75    }
76}
77
/// The default configuration is the empty one produced by
/// `CompressionParams::new`.
impl Default for CompressionParams {
    fn default() -> Self {
        // Delegate so there is a single place that defines "empty".
        Self::new()
    }
}
83
84impl CompressionFieldParams {
85    /// Merge another CompressionFieldParams, non-None values will override
86    pub fn merge(&mut self, other: &Self) {
87        if other.rle_threshold.is_some() {
88            self.rle_threshold = other.rle_threshold;
89        }
90        if other.compression.is_some() {
91            self.compression = other.compression.clone();
92        }
93        if other.compression_level.is_some() {
94            self.compression_level = other.compression_level;
95        }
96    }
97}
98
/// Check if a name matches a pattern (supports wildcards)
///
/// The only special character is `*`, which matches any run of characters,
/// including the empty run. A pattern may contain any number of wildcards
/// (`"*"`, `"log_*"`, `"*_id"`, `"test_*_name"`, `"*_id*"`, `"a*b*c"`).
/// A pattern without `*` matches only the identical name.
///
/// Returns `true` when `name` matches `pattern`.
fn matches_pattern(name: &str, pattern: &str) -> bool {
    // No wildcard at all: plain equality.
    let (prefix, rest) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return name == pattern,
    };

    // Text after the last `*` must be a suffix of the name; anything between
    // the first and last `*` is a sequence of segments to find in order.
    let (middle, suffix) = match rest.rsplit_once('*') {
        Some(parts) => parts,
        None => ("", rest),
    };

    // Anchor both ends. Stripping the prefix before the suffix guarantees the
    // two cannot overlap (e.g. "test_name" must not match "test_*_name").
    let anchored = name
        .strip_prefix(prefix)
        .and_then(|tail| tail.strip_suffix(suffix));
    let mut remaining = match anchored {
        Some(inner) => inner,
        None => return false,
    };

    // Taking the leftmost occurrence of each inner segment leaves the most
    // room for the segments that follow, so greedy search is sufficient.
    // Empty segments come from consecutive `*` and impose no constraint.
    for segment in middle.split('*').filter(|segment| !segment.is_empty()) {
        match remaining.find(segment) {
            Some(pos) => remaining = &remaining[pos + segment.len()..],
            None => return false,
        }
    }

    true
}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Field params that specify nothing but a compression scheme.
    fn scheme_only(scheme: &str) -> CompressionFieldParams {
        CompressionFieldParams {
            compression: Some(scheme.to_string()),
            ..Default::default()
        }
    }

    /// Wildcard matching: suffix, prefix, infix, catch-all and exact patterns.
    #[test]
    fn test_pattern_matching() {
        // (name, pattern, expected result)
        let cases: [(&str, &str, bool); 11] = [
            // Leading wildcard
            ("user_id", "*_id", true),
            ("product_id", "*_id", true),
            ("identity", "*_id", false),
            // Trailing wildcard
            ("log_message", "log_*", true),
            ("log_level", "log_*", true),
            ("message_log", "log_*", false),
            // Wildcard in the middle
            ("test_field_name", "test_*_name", true),
            ("test_column_name", "test_*_name", true),
            ("test_name", "test_*_name", false),
            // Catch-all and exact
            ("anything", "*", true),
            ("exact_match", "exact_match", true),
        ];

        for &(name, pattern, expected) in cases.iter() {
            assert_eq!(
                matches_pattern(name, pattern),
                expected,
                "name={:?}, pattern={:?}",
                name,
                pattern
            );
        }
    }

    /// `merge` replaces only the settings that the other side specifies.
    #[test]
    fn test_field_params_merge() {
        let mut merged = CompressionFieldParams::default();
        assert_eq!(merged.rle_threshold, None);
        assert_eq!(merged.compression, None);
        assert_eq!(merged.compression_level, None);

        // First overlay: sets threshold and scheme, leaves level unset.
        let first = CompressionFieldParams {
            rle_threshold: Some(0.3),
            compression: Some("lz4".to_string()),
            compression_level: None,
        };
        merged.merge(&first);
        assert_eq!(
            merged,
            CompressionFieldParams {
                rle_threshold: Some(0.3),
                compression: Some("lz4".to_string()),
                compression_level: None,
            }
        );

        // Second overlay: replaces scheme, adds level, keeps threshold.
        let second = CompressionFieldParams {
            rle_threshold: None,
            compression: Some("zstd".to_string()),
            compression_level: Some(3),
        };
        merged.merge(&second);
        assert_eq!(merged.rle_threshold, Some(0.3)); // Not overridden
        assert_eq!(merged.compression, Some("zstd".to_string())); // Overridden
        assert_eq!(merged.compression_level, Some(3)); // New value
    }

    /// Type-level params apply by data type; column-level params override them.
    #[test]
    fn test_get_field_params() {
        let mut config = CompressionParams::new();

        // Type-level entry
        let int32_params = CompressionFieldParams {
            rle_threshold: Some(0.5),
            compression: Some("lz4".to_string()),
            ..Default::default()
        };
        config.types.insert("Int32".to_string(), int32_params);

        // Column-level entry (pattern)
        let id_params = CompressionFieldParams {
            rle_threshold: Some(0.3),
            compression: Some("zstd".to_string()),
            compression_level: Some(3),
        };
        config.columns.insert("*_id".to_string(), id_params);

        // Neither type nor column matches: everything stays unset.
        let resolved = config.get_field_params("some_field", &DataType::Float32);
        assert_eq!(resolved.compression, None);
        assert_eq!(resolved.rle_threshold, None);

        // Only the type matches.
        let resolved = config.get_field_params("some_field", &DataType::Int32);
        assert_eq!(resolved.compression, Some("lz4".to_string())); // From type
        assert_eq!(resolved.rle_threshold, Some(0.5)); // From type

        // Both match: column pattern overrides the type-level values.
        let resolved = config.get_field_params("user_id", &DataType::Int32);
        assert_eq!(resolved.compression, Some("zstd".to_string())); // From column
        assert_eq!(resolved.compression_level, Some(3)); // From column
        assert_eq!(resolved.rle_threshold, Some(0.3)); // From column (overrides type)
    }

    /// An exact column name beats a wildcard pattern that also matches.
    #[test]
    fn test_exact_match_priority() {
        let mut config = CompressionParams::new();

        // Pattern entry
        config.columns.insert("*_id".to_string(), scheme_only("lz4"));
        // Exact entry
        config
            .columns
            .insert("user_id".to_string(), scheme_only("zstd"));

        // Exact match should win
        let resolved = config.get_field_params("user_id", &DataType::Int32);
        assert_eq!(resolved.compression, Some("zstd".to_string()));
    }
}