lance_encoding/
compression_config.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Configuration for compression parameters
//!
//! This module provides types for configuring compression strategies
//! on a per-column or per-type basis using a parameter-driven approach.

9use std::collections::HashMap;
10
11use arrow_schema::DataType;
12
/// Byte stream split (BSS) encoding mode for floating point data.
///
/// A fieldless, `Copy` enum. It also derives `Eq` and `Hash` so a mode can be
/// used as a map/set key or wherever full equality is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BssMode {
    /// Never use BSS.
    Off,
    /// Always use BSS for floating point data.
    On,
    /// Automatically decide based on data characteristics.
    Auto,
}
23
24impl BssMode {
25    /// Convert to internal sensitivity value
26    pub fn to_sensitivity(&self) -> f32 {
27        match self {
28            Self::Off => 0.0,
29            Self::On => 1.0,
30            Self::Auto => 0.5, // Default sensitivity for auto mode
31        }
32    }
33
34    /// Parse from string
35    pub fn parse(s: &str) -> Option<Self> {
36        match s.to_lowercase().as_str() {
37            "off" => Some(Self::Off),
38            "on" => Some(Self::On),
39            "auto" => Some(Self::Auto),
40            _ => None,
41        }
42    }
43}
44
45/// Compression parameter configuration
46#[derive(Debug, Clone, PartialEq)]
47pub struct CompressionParams {
48    /// Column-level parameters: column name/pattern -> parameters
49    pub columns: HashMap<String, CompressionFieldParams>,
50
51    /// Type-level parameters: data type name -> parameters
52    pub types: HashMap<String, CompressionFieldParams>,
53}
54
55/// Field-level compression parameters
56#[derive(Debug, Clone, PartialEq, Default)]
57pub struct CompressionFieldParams {
58    /// RLE threshold (0.0-1.0)
59    /// When run_count < num_values * threshold, RLE will be used
60    pub rle_threshold: Option<f64>,
61
62    /// General compression scheme: "lz4", "zstd", "none"
63    pub compression: Option<String>,
64
65    /// Compression level (only for schemes that support it, e.g., zstd)
66    pub compression_level: Option<i32>,
67
68    /// Byte stream split mode for floating point data
69    pub bss: Option<BssMode>,
70}
71
72impl CompressionParams {
73    /// Create empty compression parameters
74    pub fn new() -> Self {
75        Self {
76            columns: HashMap::new(),
77            types: HashMap::new(),
78        }
79    }
80
81    /// Get effective parameters for a field (merging type params and column params)
82    pub fn get_field_params(
83        &self,
84        field_name: &str,
85        data_type: &DataType,
86    ) -> CompressionFieldParams {
87        let mut params = CompressionFieldParams::default();
88
89        // Apply type-level parameters
90        let type_name = data_type.to_string();
91        if let Some(type_params) = self.types.get(&type_name) {
92            params.merge(type_params);
93        }
94
95        // Apply column-level parameters (highest priority)
96        // First check exact match
97        if let Some(col_params) = self.columns.get(field_name) {
98            params.merge(col_params);
99        } else {
100            // Check pattern matching
101            for (pattern, col_params) in &self.columns {
102                if matches_pattern(field_name, pattern) {
103                    params.merge(col_params);
104                    break; // Use first matching pattern
105                }
106            }
107        }
108
109        params
110    }
111}
112
113impl Default for CompressionParams {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl CompressionFieldParams {
120    /// Merge another CompressionFieldParams, non-None values will override
121    pub fn merge(&mut self, other: &Self) {
122        if other.rle_threshold.is_some() {
123            self.rle_threshold = other.rle_threshold;
124        }
125        if other.compression.is_some() {
126            self.compression = other.compression.clone();
127        }
128        if other.compression_level.is_some() {
129            self.compression_level = other.compression_level;
130        }
131        if other.bss.is_some() {
132            self.bss = other.bss;
133        }
134    }
135}
136
/// Check whether `name` matches `pattern`.
///
/// Each `*` in `pattern` matches any (possibly empty) run of characters; every
/// other character must match literally. A pattern without `*` requires an
/// exact match. Any number of wildcards is supported, e.g. `*_id`, `log_*`,
/// `test_*_name`, `*_id*` or `a*b*c`.
fn matches_pattern(name: &str, pattern: &str) -> bool {
    // No wildcard at all: only an exact match counts.
    let Some((head, after_head)) = pattern.split_once('*') else {
        return name == pattern;
    };
    // `tail` is the literal after the last `*`. `middle` holds the
    // `*`-separated literals between the first and last wildcard; it is empty
    // when the pattern has a single `*`.
    let (middle, tail) = after_head.rsplit_once('*').unwrap_or(("", after_head));

    // Anchor the leading and trailing literals. The suffix is stripped from what
    // remains after the prefix, so the two can never overlap. This is why
    // `test_name` does not match `test_*_name`.
    let Some(rest) = name.strip_prefix(head) else {
        return false;
    };
    let Some(mut rest) = rest.strip_suffix(tail) else {
        return false;
    };

    // Each middle literal must appear, in order, in the unanchored remainder.
    // Taking the leftmost occurrence is always safe for `*`-only globs because
    // it leaves the most room for the literals that follow.
    for literal in middle.split('*') {
        match rest.find(literal) {
            Some(pos) => rest = &rest[pos + literal.len()..],
            None => return false,
        }
    }
    true
}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Field parameters that specify only a compression scheme.
    fn only_compression(scheme: &str) -> CompressionFieldParams {
        CompressionFieldParams {
            compression: Some(scheme.to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_pattern_matching() {
        // (name, pattern, expected match)
        let cases = [
            ("user_id", "*_id", true),
            ("product_id", "*_id", true),
            ("identity", "*_id", false),
            ("log_message", "log_*", true),
            ("log_level", "log_*", true),
            ("message_log", "log_*", false),
            ("test_field_name", "test_*_name", true),
            ("test_column_name", "test_*_name", true),
            ("test_name", "test_*_name", false),
            ("anything", "*", true),
            ("exact_match", "exact_match", true),
        ];
        for (name, pattern, expected) in cases {
            assert_eq!(matches_pattern(name, pattern), expected);
        }
    }

    #[test]
    fn test_field_params_merge() {
        let mut params = CompressionFieldParams::default();
        assert_eq!(
            params,
            CompressionFieldParams {
                rle_threshold: None,
                compression: None,
                compression_level: None,
                bss: None,
            }
        );

        params.merge(&CompressionFieldParams {
            rle_threshold: Some(0.3),
            compression: Some("lz4".to_string()),
            compression_level: None,
            bss: Some(BssMode::On),
        });
        assert_eq!(params.rle_threshold, Some(0.3));
        assert_eq!(params.compression.as_deref(), Some("lz4"));
        assert_eq!(params.compression_level, None);
        assert_eq!(params.bss, Some(BssMode::On));

        params.merge(&CompressionFieldParams {
            rle_threshold: None,
            compression: Some("zstd".to_string()),
            compression_level: Some(3),
            bss: Some(BssMode::Auto),
        });
        // rle_threshold survives because the overlay leaves it unset; every
        // other field is replaced or newly filled in.
        assert_eq!(params.rle_threshold, Some(0.3));
        assert_eq!(params.compression.as_deref(), Some("zstd"));
        assert_eq!(params.compression_level, Some(3));
        assert_eq!(params.bss, Some(BssMode::Auto));
    }

    #[test]
    fn test_get_field_params() {
        let mut params = CompressionParams::new();
        params.types.insert(
            "Int32".to_string(),
            CompressionFieldParams {
                rle_threshold: Some(0.5),
                ..only_compression("lz4")
            },
        );
        params.columns.insert(
            "*_id".to_string(),
            CompressionFieldParams {
                rle_threshold: Some(0.3),
                compression_level: Some(3),
                bss: None,
                ..only_compression("zstd")
            },
        );

        // Neither the type entry nor the column pattern applies.
        let unmatched = params.get_field_params("some_field", &DataType::Float32);
        assert_eq!(unmatched.compression, None);
        assert_eq!(unmatched.rle_threshold, None);

        // Only the type-level entry applies.
        let by_type = params.get_field_params("some_field", &DataType::Int32);
        assert_eq!(by_type.compression.as_deref(), Some("lz4"));
        assert_eq!(by_type.rle_threshold, Some(0.5));

        // The column pattern overrides the type-level entry.
        let by_column = params.get_field_params("user_id", &DataType::Int32);
        assert_eq!(by_column.compression.as_deref(), Some("zstd"));
        assert_eq!(by_column.compression_level, Some(3));
        assert_eq!(by_column.rle_threshold, Some(0.3));
    }

    #[test]
    fn test_exact_match_priority() {
        let mut params = CompressionParams::new();
        params
            .columns
            .insert("*_id".to_string(), only_compression("lz4"));
        params
            .columns
            .insert("user_id".to_string(), only_compression("zstd"));

        // The exact column name beats the wildcard pattern.
        let resolved = params.get_field_params("user_id", &DataType::Int32);
        assert_eq!(resolved.compression.as_deref(), Some("zstd"));
    }
}