use super::strategy::GpuHashStrategy;
/// Top-level strategy selected for a group-by operation.
///
/// Values are normally produced by
/// [`GpuGroupByStrategy::choose_for_cardinality`], which picks a variant from
/// the estimated number of distinct groups, or by
/// [`GpuGroupByStrategy::choose_with_override`], which can force `RadixSort`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuGroupByStrategy {
/// Hash-based grouping, parameterised by the wrapped [`GpuHashStrategy`].
/// Selected when fewer than 100_000 groups are estimated.
Hash(GpuHashStrategy),
/// Radix-sort-based grouping. Selected when at least 100_000 groups are
/// estimated, or when explicitly forced by the caller.
// NOTE(review): the sort/aggregate implementation itself is not in this file.
RadixSort,
}
impl GpuGroupByStrategy {
    /// Selects a group-by strategy from the estimated number of distinct groups.
    ///
    /// The mapping is:
    ///
    /// * fewer than 100 groups -> `Hash(GpuHashStrategy::Linear)`
    /// * fewer than 10_000 groups -> `Hash(GpuHashStrategy::Cuckoo)`
    /// * fewer than 100_000 groups -> `Hash(GpuHashStrategy::RobinHood)`
    /// * 100_000 groups or more -> `RadixSort`
    pub fn choose_for_cardinality(estimated_groups: usize) -> Self {
        const LINEAR_LIMIT: usize = 100;
        const CUCKOO_LIMIT: usize = 10_000;
        const RADIX_SORT_THRESHOLD: usize = 100_000;

        // Arms are ordered by ascending limit, so the first matching guard wins.
        match estimated_groups {
            groups if groups < LINEAR_LIMIT => Self::Hash(GpuHashStrategy::Linear),
            groups if groups < CUCKOO_LIMIT => Self::Hash(GpuHashStrategy::Cuckoo),
            groups if groups < RADIX_SORT_THRESHOLD => Self::Hash(GpuHashStrategy::RobinHood),
            _ => Self::RadixSort,
        }
    }

    /// Like [`Self::choose_for_cardinality`], but returns `RadixSort`
    /// unconditionally when `force_radix` is `true`.
    pub fn choose_with_override(estimated_groups: usize, force_radix: bool) -> Self {
        if !force_radix {
            return Self::choose_for_cardinality(estimated_groups);
        }
        Self::RadixSort
    }
}
/// Estimates the number of distinct values in `keys`.
///
/// Slices of at most 10_000 elements are counted exactly. Longer slices are
/// sampled at a fixed stride (nominally 10% of the input, capped at 100_000
/// samples); the number of distinct sampled values is scaled linearly to the
/// full length, padded by 20%, and finally clamped to `keys.len()`.
///
/// The stride is rounded down, so slightly more than the nominal number of
/// elements may be visited; together with the linear scaling this biases the
/// estimate upward for low-cardinality inputs.
///
/// Only compiled in test builds.
#[cfg(test)]
pub fn estimate_cardinality_i32(keys: &[i32]) -> usize {
    use std::collections::HashSet;

    const EXACT_COUNT_LIMIT: usize = 10_000;
    const MAX_SAMPLE_SIZE: usize = 100_000;

    let n = keys.len();
    if n <= EXACT_COUNT_LIMIT {
        return keys.iter().copied().collect::<HashSet<i32>>().len();
    }

    // n > 10_000 here, so sample_size >= 1_000 and stride >= 10: the stride
    // passed to `step_by` can never be zero.
    let sample_size = (n / 10).min(MAX_SAMPLE_SIZE);
    let stride = n / sample_size;
    let distinct_in_sample = keys
        .iter()
        .step_by(stride)
        .copied()
        .collect::<HashSet<i32>>()
        .len();

    // `distinct_in_sample * n` can exceed `usize::MAX` on 32-bit targets, so
    // the scaling is done in u128 (widening casts, lossless).
    let scaled = (distinct_in_sample as u128 * n as u128 / sample_size as u128) * 12 / 10;

    // After clamping the value is <= n, so narrowing back to usize is lossless.
    scaled.min(n as u128) as usize
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inputs of at most 10_000 keys take the exact-count path, so the result
    /// must equal the true number of distinct values.
    #[test]
    fn test_cardinality_estimation_exact() {
        let keys: Vec<i32> = (0..1000).collect();
        let est = estimate_cardinality_i32(&keys);
        assert_eq!(est, 1000, "Estimated: {}", est);

        let keys = vec![42; 1000];
        let est = estimate_cardinality_i32(&keys);
        assert_eq!(est, 1, "Estimated: {}", est);

        let keys: Vec<i32> = Vec::new();
        assert_eq!(estimate_cardinality_i32(&keys), 0);
    }

    /// Inputs longer than 10_000 keys take the sampling path; the estimate is
    /// never larger than the input length and stays small for constant input.
    #[test]
    fn test_cardinality_estimation_sampled() {
        // All keys distinct: the padded estimate is clamped to the input length.
        let keys: Vec<i32> = (0..20_000).collect();
        let est = estimate_cardinality_i32(&keys);
        assert_eq!(est, keys.len(), "Estimated: {}", est);

        // A single distinct key: the estimate must stay far below the length.
        let keys = vec![7; 20_000];
        let est = estimate_cardinality_i32(&keys);
        assert!((1..=100).contains(&est), "Estimated: {}", est);
    }

    /// One representative cardinality per strategy.
    #[test]
    fn test_strategy_selection() {
        let strategy = GpuGroupByStrategy::choose_for_cardinality(50);
        assert_eq!(strategy, GpuGroupByStrategy::Hash(GpuHashStrategy::Linear));

        let strategy = GpuGroupByStrategy::choose_for_cardinality(500);
        assert_eq!(strategy, GpuGroupByStrategy::Hash(GpuHashStrategy::Cuckoo));

        let strategy = GpuGroupByStrategy::choose_for_cardinality(50_000);
        assert_eq!(
            strategy,
            GpuGroupByStrategy::Hash(GpuHashStrategy::RobinHood)
        );

        let strategy = GpuGroupByStrategy::choose_for_cardinality(200_000);
        assert_eq!(strategy, GpuGroupByStrategy::RadixSort);
    }

    /// Values on either side of each threshold map to the expected strategy.
    #[test]
    fn test_strategy_selection_boundaries() {
        let cases = [
            (0, GpuGroupByStrategy::Hash(GpuHashStrategy::Linear)),
            (99, GpuGroupByStrategy::Hash(GpuHashStrategy::Linear)),
            (100, GpuGroupByStrategy::Hash(GpuHashStrategy::Cuckoo)),
            (9_999, GpuGroupByStrategy::Hash(GpuHashStrategy::Cuckoo)),
            (10_000, GpuGroupByStrategy::Hash(GpuHashStrategy::RobinHood)),
            (99_999, GpuGroupByStrategy::Hash(GpuHashStrategy::RobinHood)),
            (100_000, GpuGroupByStrategy::RadixSort),
            (usize::MAX, GpuGroupByStrategy::RadixSort),
        ];
        for (groups, expected) in cases.iter().copied() {
            assert_eq!(
                GpuGroupByStrategy::choose_for_cardinality(groups),
                expected,
                "estimated_groups = {}",
                groups
            );
        }
    }

    /// `force_radix` wins over the cardinality heuristic; otherwise the
    /// heuristic result is returned unchanged.
    #[test]
    fn test_choose_with_override() {
        assert_eq!(
            GpuGroupByStrategy::choose_with_override(50, true),
            GpuGroupByStrategy::RadixSort
        );
        assert_eq!(
            GpuGroupByStrategy::choose_with_override(50, false),
            GpuGroupByStrategy::choose_for_cardinality(50)
        );
        assert_eq!(
            GpuGroupByStrategy::choose_with_override(200_000, false),
            GpuGroupByStrategy::RadixSort
        );
    }
}