use crate::dataframe::DataFrame;
use crate::series::Series;
use crate::VeloxxError;
use rayon::prelude::*;
#[cfg(all(target_arch = "x86_64", not(target_arch = "wasm32")))]
use std::arch::x86_64::*;
use std::collections::HashMap;
struct DenseSimdGroupByParams<'a> {
group_values: &'a [i32],
group_bitmap: &'a [bool],
values: &'a [f64],
value_bitmap: &'a [bool],
group_col_name: &'a str,
value_col_name: &'a str,
min_key: i32,
range: usize,
}
pub struct FastGroupBy;
impl FastGroupBy {
pub fn simd_groupby_i32_sum(
group_values: &[i32],
group_bitmap: &[bool],
values: &[f64],
value_bitmap: &[bool],
group_col_name: &str,
value_col_name: &str,
) -> Result<DataFrame, VeloxxError> {
let (min_key, max_key, valid_count) =
Self::find_i32_range_simd(group_values, group_bitmap, value_bitmap);
if valid_count == 0 {
return Self::create_empty_result(group_col_name, value_col_name);
}
let range = (max_key as i64 - min_key as i64) as usize + 1;
if range <= 65536 && valid_count >= 1000 {
Self::dense_array_groupby_simd(DenseSimdGroupByParams {
group_values,
group_bitmap,
values,
value_bitmap,
group_col_name,
value_col_name,
min_key,
range,
})
} else {
Self::hash_groupby_simd(
group_values,
group_bitmap,
values,
value_bitmap,
group_col_name,
value_col_name,
)
}
}
#[cfg(target_arch = "x86_64")]
fn find_i32_range_simd(
group_values: &[i32],
group_bitmap: &[bool],
value_bitmap: &[bool],
) -> (i32, i32, usize) {
if !is_x86_feature_detected!("avx2") {
return Self::find_i32_range_scalar(group_values, group_bitmap, value_bitmap);
}
unsafe {
let mut min_vec = _mm256_set1_epi32(i32::MAX);
let mut max_vec = _mm256_set1_epi32(i32::MIN);
let mut count = 0usize;
let len = group_values.len();
let chunks = len / 8;
let remainder = len % 8;
for i in 0..chunks {
let base_idx = i * 8;
let mut valid_mask = 0u8;
for j in 0..8 {
let idx = base_idx + j;
if group_bitmap[idx] && value_bitmap[idx] {
valid_mask |= 1 << j;
count += 1;
}
}
if valid_mask != 0 {
let data =
_mm256_loadu_si256(group_values.as_ptr().add(base_idx) as *const __m256i);
let mask_expanded = Self::expand_byte_mask_to_i32_avx2(valid_mask);
let masked_data = _mm256_blendv_epi8(
_mm256_set1_epi32(i32::MAX), data,
mask_expanded,
);
min_vec = _mm256_min_epi32(min_vec, masked_data);
let masked_data_max =
_mm256_blendv_epi8(_mm256_set1_epi32(i32::MIN), data, mask_expanded);
max_vec = _mm256_max_epi32(max_vec, masked_data_max);
}
}
let min_array: [i32; 8] = std::mem::transmute(min_vec);
let max_array: [i32; 8] = std::mem::transmute(max_vec);
let mut final_min = min_array[0];
let mut final_max = max_array[0];
for i in 1..8 {
final_min = final_min.min(min_array[i]);
final_max = final_max.max(max_array[i]);
}
for i in (len - remainder)..len {
if group_bitmap[i] && value_bitmap[i] {
final_min = final_min.min(group_values[i]);
final_max = final_max.max(group_values[i]);
}
}
(final_min, final_max, count)
}
}
#[cfg(not(target_arch = "x86_64"))]
fn find_i32_range_simd(
group_values: &[i32],
group_bitmap: &[bool],
value_bitmap: &[bool],
) -> (i32, i32, usize) {
Self::find_i32_range_scalar(group_values, group_bitmap, value_bitmap)
}
fn find_i32_range_scalar(
group_values: &[i32],
group_bitmap: &[bool],
value_bitmap: &[bool],
) -> (i32, i32, usize) {
let mut min_key = i32::MAX;
let mut max_key = i32::MIN;
let mut count = 0;
for i in 0..group_values.len() {
if group_bitmap[i] && value_bitmap[i] {
min_key = min_key.min(group_values[i]);
max_key = max_key.max(group_values[i]);
count += 1;
}
}
(min_key, max_key, count)
}
#[allow(clippy::too_many_arguments)]
fn dense_array_groupby_simd(params: DenseSimdGroupByParams) -> Result<DataFrame, VeloxxError> {
let mut sums = vec![0.0f64; params.range];
let mut counts = vec![0u32; params.range];
Self::accumulate_dense_simd(
&mut sums,
&mut counts,
params.group_values,
params.group_bitmap,
params.values,
params.value_bitmap,
params.min_key,
);
let mut group_keys = Vec::new();
let mut sum_values = Vec::new();
for group_index in 0..params.range {
if counts[group_index] > 0 {
group_keys.push(params.min_key + group_index as i32);
sum_values.push(sums[group_index]);
}
}
let mut result = HashMap::new();
result.insert(
params.group_col_name.to_string(),
Series::I32(
params.group_col_name.to_string(),
group_keys.clone(),
vec![true; group_keys.len()],
),
);
result.insert(
params.value_col_name.to_string(),
Series::F64(
params.value_col_name.to_string(),
sum_values.clone(),
vec![true; sum_values.len()],
),
);
Ok(DataFrame::new(result.into_iter().collect()))
}
#[cfg(target_arch = "x86_64")]
fn accumulate_dense_simd(
sums: &mut [f64],
counts: &mut [u32],
group_values: &[i32],
group_bitmap: &[bool],
values: &[f64],
value_bitmap: &[bool],
min_key: i32,
) {
Self::accumulate_dense_scalar(
sums,
counts,
group_values,
group_bitmap,
values,
value_bitmap,
min_key,
);
}
#[cfg(not(target_arch = "x86_64"))]
fn accumulate_dense_simd(
sums: &mut [f64],
counts: &mut [u32],
group_values: &[i32],
group_bitmap: &[bool],
values: &[f64],
value_bitmap: &[bool],
min_key: i32,
) {
Self::accumulate_dense_scalar(
sums,
counts,
group_values,
group_bitmap,
values,
value_bitmap,
min_key,
);
}
fn accumulate_dense_scalar(
sums: &mut [f64],
counts: &mut [u32],
group_values: &[i32],
group_bitmap: &[bool],
values: &[f64],
value_bitmap: &[bool],
min_key: i32,
) {
for i in 0..group_values.len() {
if group_bitmap[i] && value_bitmap[i] {
let group_index = (group_values[i] - min_key) as usize;
if group_index < sums.len() {
sums[group_index] += values[i];
counts[group_index] += 1;
}
}
}
}
fn hash_groupby_simd(
group_values: &[i32],
group_bitmap: &[bool],
values: &[f64],
value_bitmap: &[bool],
group_col_name: &str,
value_col_name: &str,
) -> Result<DataFrame, VeloxxError> {
let map = (0..group_values.len())
.into_par_iter()
.filter_map(|i| {
if group_bitmap[i] && value_bitmap[i] {
Some((group_values[i], values[i]))
} else {
None
}
})
.fold(HashMap::<i32, (f64, u32)>::new, |mut map, (key, value)| {
let entry = map.entry(key).or_insert((0.0, 0));
entry.0 += value;
entry.1 += 1;
map
})
.reduce(HashMap::<i32, (f64, u32)>::new, |mut acc, map| {
for (key, (sum, count)) in map {
let entry = acc.entry(key).or_insert((0.0, 0));
entry.0 += sum;
entry.1 += count;
}
acc
});
let mut group_keys = Vec::new();
let mut sum_values = Vec::new();
for (key, (sum, _count)) in map {
group_keys.push(key);
sum_values.push(sum);
}
let mut result = std::collections::HashMap::new();
result.insert(
group_col_name.to_string(),
Series::I32(
group_col_name.to_string(),
group_keys.clone(),
vec![true; group_keys.len()],
),
);
result.insert(
value_col_name.to_string(),
Series::F64(
value_col_name.to_string(),
sum_values.clone(),
vec![true; sum_values.len()],
),
);
Ok(DataFrame::new(result.into_iter().collect()))
}
fn create_empty_result(
group_col_name: &str,
value_col_name: &str,
) -> Result<DataFrame, VeloxxError> {
let mut result = std::collections::HashMap::new();
result.insert(
group_col_name.to_string(),
Series::I32(group_col_name.to_string(), vec![], vec![]),
);
result.insert(
value_col_name.to_string(),
Series::F64(value_col_name.to_string(), vec![], vec![]),
);
Ok(DataFrame::new(result.into_iter().collect()))
}
#[cfg(target_arch = "x86_64")]
unsafe fn expand_byte_mask_to_i32_avx2(mask: u8) -> __m256i {
let mut expanded = [0i32; 8];
for (i, item) in expanded.iter_mut().enumerate() {
*item = if (mask & (1 << i)) != 0 { -1 } else { 0 };
}
_mm256_loadu_si256(expanded.as_ptr() as *const __m256i)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simd_groupby_basic() {
let group_values = vec![1, 2, 1, 3, 2];
let group_bitmap = vec![true, true, true, true, true];
let values = vec![10.0, 20.0, 30.0, 40.0, 50.0];
let value_bitmap = vec![true, true, true, true, true];
let result = FastGroupBy::simd_groupby_i32_sum(
&group_values,
&group_bitmap,
&values,
&value_bitmap,
"group",
"value",
)
.unwrap();
println!("Result row count: {}", result.row_count());
println!("Result column count: {}", result.column_count());
assert_eq!(result.column_count(), 2);
}
#[test]
fn test_dense_vs_hash_threshold() {
let group_values: Vec<i32> = (0..10000).map(|i| i % 100).collect(); let group_bitmap = vec![true; 10000];
let values: Vec<f64> = (0..10000).map(|i| i as f64).collect();
let value_bitmap = vec![true; 10000];
let result = FastGroupBy::simd_groupby_i32_sum(
&group_values,
&group_bitmap,
&values,
&value_bitmap,
"group",
"value",
)
.unwrap();
assert_eq!(result.row_count(), 100); }
}