use std::collections::HashMap;
use arrow_schema::DataType;
/// Setting for the BSS encoding option (NOTE(review): presumably "byte stream split" — confirm).
///
/// Build one from text with [`BssMode::parse`]; turn it into a numeric knob with
/// [`BssMode::to_sensitivity`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BssMode {
    /// Parsed from `"off"`; sensitivity `0.0`.
    Off,
    /// Parsed from `"on"`; sensitivity `1.0`.
    On,
    /// Parsed from `"auto"`; sensitivity `0.5`.
    Auto,
}

impl BssMode {
    /// Returns the sensitivity value associated with this mode.
    ///
    /// `Off` is `0.0`, `Auto` is `0.5` and `On` is `1.0`.
    pub fn to_sensitivity(&self) -> f32 {
        match *self {
            BssMode::Off => 0.0,
            BssMode::Auto => 0.5,
            BssMode::On => 1.0,
        }
    }

    /// Parses a mode name, ignoring case.
    ///
    /// Accepts `off`, `on` and `auto` in any letter case. Returns `None` for anything else,
    /// including input with surrounding whitespace.
    pub fn parse(s: &str) -> Option<Self> {
        const NAMES: [(&str, BssMode); 3] = [
            ("off", BssMode::Off),
            ("on", BssMode::On),
            ("auto", BssMode::Auto),
        ];
        let normalized = s.to_lowercase();
        NAMES
            .iter()
            .find(|(name, _)| *name == normalized)
            .map(|&(_, mode)| mode)
    }
}
/// Compression settings attached to data types and to columns.
///
/// Use [`CompressionParams::get_field_params`] to resolve the effective settings for a
/// single field; column-level settings are merged on top of type-level ones.
#[derive(Debug, Clone, PartialEq)]
pub struct CompressionParams {
    /// Per-column settings, keyed either by an exact column name or by a pattern
    /// containing `*` wildcards (e.g. `"*_id"`, `"log_*"`).
    pub columns: HashMap<String, CompressionFieldParams>,
    /// Per-type settings, keyed by the `to_string()` rendering of an Arrow [`DataType`]
    /// (e.g. `"Int32"`).
    pub types: HashMap<String, CompressionFieldParams>,
}
/// Optional compression settings for one field or one data type.
///
/// Every setting is an `Option`, where `None` means "not specified at this level". This
/// lets [`CompressionFieldParams::merge`] layer more specific settings over more general
/// ones without clobbering values the overlay does not set. `Default` leaves everything
/// `None`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct CompressionFieldParams {
    /// Threshold for run-length encoding (NOTE(review): presumably RLE; exact meaning is
    /// defined by the consumer of this struct — confirm).
    pub rle_threshold: Option<f64>,
    /// Compression codec name, e.g. `"lz4"` or `"zstd"`.
    pub compression: Option<String>,
    /// Codec compression level (NOTE(review): valid range depends on the codec — confirm).
    pub compression_level: Option<i32>,
    /// BSS mode; see [`BssMode`].
    pub bss: Option<BssMode>,
    /// Mini-chunk size (NOTE(review): unit — bytes or values — is not visible here; confirm).
    pub minichunk_size: Option<i64>,
}
impl CompressionParams {
    /// Creates an empty configuration with no type-level or column-level settings.
    pub fn new() -> Self {
        Self {
            columns: HashMap::new(),
            types: HashMap::new(),
        }
    }

    /// Resolves the effective compression settings for one field.
    ///
    /// Layers are combined with [`CompressionFieldParams::merge`], so each layer only
    /// overrides the settings it actually specifies:
    ///
    /// 1. Start with every setting unset (`CompressionFieldParams::default()`).
    /// 2. Apply the `types` entry keyed by `data_type.to_string()` (e.g. `"Int32"`), if any.
    /// 3. Apply one `columns` entry: the key equal to `field_name` if there is one,
    ///    otherwise the most specific wildcard key that matches `field_name`.
    ///
    /// "Most specific" means the pattern with the most non-`*` characters. Ties are broken
    /// by picking the lexicographically smallest pattern. The choice therefore never
    /// depends on `HashMap` iteration order.
    pub fn get_field_params(
        &self,
        field_name: &str,
        data_type: &DataType,
    ) -> CompressionFieldParams {
        let mut params = CompressionFieldParams::default();

        let type_name = data_type.to_string();
        if let Some(type_params) = self.types.get(&type_name) {
            params.merge(type_params);
        }

        // An exact column name always beats any wildcard pattern.
        let column_params = self
            .columns
            .get(field_name)
            .or_else(|| self.best_pattern_match(field_name));
        if let Some(col_params) = column_params {
            params.merge(col_params);
        }

        params
    }

    /// Returns the settings of the most specific key in `columns` that matches
    /// `field_name` as a wildcard pattern, or `None` if no key matches.
    ///
    /// Iterating a `HashMap` and taking the first hit would be nondeterministic when
    /// patterns overlap (e.g. `"*_id"` and `"user_*"` both match `"user_id"`). Instead,
    /// every match is ranked and the best one is chosen.
    fn best_pattern_match(&self, field_name: &str) -> Option<&CompressionFieldParams> {
        // Number of literal (non-wildcard) characters: more literals = more specific.
        fn specificity(pattern: &str) -> usize {
            pattern.chars().filter(|&c| c != '*').count()
        }

        self.columns
            .iter()
            .filter(|(pattern, _)| matches_pattern(field_name, pattern.as_str()))
            .max_by(|(a, _), (b, _)| {
                specificity(a.as_str())
                    .cmp(&specificity(b.as_str()))
                    // Reversed comparison: at equal specificity the smaller pattern wins.
                    .then_with(|| b.cmp(a))
            })
            .map(|(_, params)| params)
    }
}
impl Default for CompressionParams {
fn default() -> Self {
Self::new()
}
}
impl CompressionFieldParams {
    /// Overlays `other` onto `self`, one setting at a time.
    ///
    /// A setting that is `Some` in `other` replaces the same setting in `self`. A setting
    /// that is `None` in `other` leaves `self` unchanged. This is how more specific
    /// settings (columns) are layered over more general ones (types).
    pub fn merge(&mut self, other: &Self) {
        // `a.or(b)` yields `a` when it is `Some`, otherwise `b` — i.e. "other wins if set".
        self.rle_threshold = other.rle_threshold.or(self.rle_threshold);

        // `String` is not `Copy`: clone only when the overlay actually carries a value.
        if let Some(codec) = &other.compression {
            self.compression = Some(codec.clone());
        }

        self.compression_level = other.compression_level.or(self.compression_level);
        self.bss = other.bss.or(self.bss);
        self.minichunk_size = other.minichunk_size.or(self.minichunk_size);
    }
}
/// Returns `true` if `name` matches `pattern`, where each `*` matches any (possibly empty)
/// sequence of characters and every other character must match literally.
///
/// A pattern without `*` is a plain equality check. Any number of wildcards is supported:
/// `"*_id"`, `"log_*"`, `"test_*_name"`, `"*"`, as well as `"*_id_*"` or `"a*b*c"`.
///
/// The text before the first `*` must be a prefix of `name`, and the text after the last
/// `*` must be a suffix. The two may not overlap, so `"test_name"` does not match
/// `"test_*_name"`. Any literal segments between wildcards must then appear in order in
/// what remains. Taking the leftmost occurrence of each segment is sufficient for
/// `*`-only globs.
fn matches_pattern(name: &str, pattern: &str) -> bool {
    let (head, tail) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return name == pattern,
    };
    // `middle` holds the `*`-separated literals between the first and the last wildcard.
    let (middle, last) = match tail.rsplit_once('*') {
        Some((middle, last)) => (middle, last),
        None => ("", tail),
    };

    // Anchor both ends; stripping the suffix from the remainder prevents overlap.
    let mut rest = match name.strip_prefix(head).and_then(|r| r.strip_suffix(last)) {
        Some(core) => core,
        None => return false,
    };

    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match rest.find(segment) {
            // `find` returns a char boundary and `segment` is valid UTF-8, so this slice is safe.
            Some(pos) => rest = &rest[pos + segment.len()..],
            None => return false,
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pattern_matching() {
        // Each case is (field name, pattern, expected match result).
        let cases = [
            ("user_id", "*_id", true),
            ("product_id", "*_id", true),
            ("identity", "*_id", false),
            ("log_message", "log_*", true),
            ("log_level", "log_*", true),
            ("message_log", "log_*", false),
            ("test_field_name", "test_*_name", true),
            ("test_column_name", "test_*_name", true),
            ("test_name", "test_*_name", false),
            ("anything", "*", true),
            ("exact_match", "exact_match", true),
        ];
        for &(name, pattern, expected) in cases.iter() {
            assert_eq!(
                matches_pattern(name, pattern),
                expected,
                "matches_pattern({:?}, {:?})",
                name,
                pattern
            );
        }
    }

    #[test]
    fn test_field_params_merge() {
        let mut merged = CompressionFieldParams::default();
        assert_eq!(merged.rle_threshold, None);
        assert_eq!(merged.compression, None);
        assert_eq!(merged.compression_level, None);
        assert_eq!(merged.bss, None);

        // First overlay: sets threshold, codec and BSS; leaves the level unset.
        let first = CompressionFieldParams {
            rle_threshold: Some(0.3),
            compression: Some("lz4".to_string()),
            bss: Some(BssMode::On),
            ..Default::default()
        };
        merged.merge(&first);
        assert_eq!(merged.rle_threshold, Some(0.3));
        assert_eq!(merged.compression, Some("lz4".to_string()));
        assert_eq!(merged.compression_level, None);
        assert_eq!(merged.bss, Some(BssMode::On));

        // Second overlay: its `None` threshold must not clobber the earlier value.
        let second = CompressionFieldParams {
            compression: Some("zstd".to_string()),
            compression_level: Some(3),
            bss: Some(BssMode::Auto),
            ..Default::default()
        };
        merged.merge(&second);
        assert_eq!(merged.rle_threshold, Some(0.3));
        assert_eq!(merged.compression, Some("zstd".to_string()));
        assert_eq!(merged.compression_level, Some(3));
        assert_eq!(merged.bss, Some(BssMode::Auto));
    }

    #[test]
    fn test_get_field_params() {
        let mut config = CompressionParams::new();
        config.types.insert(
            "Int32".to_string(),
            CompressionFieldParams {
                rle_threshold: Some(0.5),
                compression: Some("lz4".to_string()),
                ..Default::default()
            },
        );
        config.columns.insert(
            "*_id".to_string(),
            CompressionFieldParams {
                rle_threshold: Some(0.3),
                compression: Some("zstd".to_string()),
                compression_level: Some(3),
                ..Default::default()
            },
        );

        // Neither a type rule nor a column rule applies.
        let unmatched = config.get_field_params("some_field", &DataType::Float32);
        assert_eq!(unmatched.compression, None);
        assert_eq!(unmatched.rle_threshold, None);

        // Only the type rule applies.
        let by_type = config.get_field_params("some_field", &DataType::Int32);
        assert_eq!(by_type.compression, Some("lz4".to_string()));
        assert_eq!(by_type.rle_threshold, Some(0.5));

        // The column pattern is layered over the type rule and overrides it.
        let by_column = config.get_field_params("user_id", &DataType::Int32);
        assert_eq!(by_column.compression, Some("zstd".to_string()));
        assert_eq!(by_column.compression_level, Some(3));
        assert_eq!(by_column.rle_threshold, Some(0.3));
    }

    #[test]
    fn test_exact_match_priority() {
        let mut config = CompressionParams::new();
        config.columns.insert(
            "*_id".to_string(),
            CompressionFieldParams {
                compression: Some("lz4".to_string()),
                ..Default::default()
            },
        );
        config.columns.insert(
            "user_id".to_string(),
            CompressionFieldParams {
                compression: Some("zstd".to_string()),
                ..Default::default()
            },
        );

        // The exact column name must win over the wildcard pattern.
        let resolved = config.get_field_params("user_id", &DataType::Int32);
        assert_eq!(resolved.compression, Some("zstd".to_string()));
    }
}