use crate::dataframe::DataFrame;
use crate::series::Series;
use crate::VeloxxError;
use rayon::prelude::*;
#[cfg(all(target_arch = "x86_64", not(target_arch = "wasm32")))]
use std::arch::x86_64::*;
use std::collections::HashMap;
/// Stateless namespace type for the group-by sum routines defined in this file.
///
/// It has no fields; every routine is an associated function called as
/// `UltraFastGroupBy::...` and never takes `self`.
pub struct UltraFastGroupBy;
impl UltraFastGroupBy {
/// Groups `values` by the `i32` keys in `group_values` and sums each group.
///
/// A row takes part in the aggregation only when both `group_bitmap[i]` and
/// `value_bitmap[i]` are `true`. The row count is `group_values.len()`.
///
/// Inputs with fewer than 1000 rows are handled by `simple_groupby`; anything
/// larger is handed to `parallel_simd_groupby`.
///
/// # Arguments
/// * `group_values` / `group_bitmap` - group keys and their validity flags.
/// * `values` / `value_bitmap` - values to sum and their validity flags.
/// * `group_col_name` / `value_col_name` - names of the two output columns.
///
/// # Panics
/// The downstream routines index/slice the other three slices using positions
/// derived from `group_values.len()`, so a shorter slice can cause a panic.
pub fn ultra_simd_groupby_i32_sum(
    group_values: &[i32],
    group_bitmap: &[bool],
    values: &[f64],
    value_bitmap: &[bool],
    group_col_name: &str,
    value_col_name: &str,
) -> Result<DataFrame, VeloxxError> {
    // Row count at which the chunked parallel path takes over.
    const PARALLEL_THRESHOLD: usize = 1000;

    if group_values.len() >= PARALLEL_THRESHOLD {
        Self::parallel_simd_groupby(
            group_values,
            group_bitmap,
            values,
            value_bitmap,
            group_col_name,
            value_col_name,
        )
    } else {
        Self::simple_groupby(
            group_values,
            group_bitmap,
            values,
            value_bitmap,
            group_col_name,
            value_col_name,
        )
    }
}
/// Aggregates the input in 8192-row chunks through rayon's parallel iterator
/// and merges the per-chunk partial results into one output frame.
///
/// Each chunk is reduced by `process_chunk_simd` into a map of
/// `key -> (sum, count)`. The partial maps are collected in chunk order and
/// merged sequentially, so for any given key the partial sums are always added
/// in the same order.
fn parallel_simd_groupby(
    group_values: &[i32],
    group_bitmap: &[bool],
    values: &[f64],
    value_bitmap: &[bool],
    group_col_name: &str,
    value_col_name: &str,
) -> Result<DataFrame, VeloxxError> {
    const CHUNK_ROWS: usize = 8192;

    let total_rows = group_values.len();
    let chunk_count = total_rows.div_ceil(CHUNK_ROWS);

    let partials: Vec<HashMap<i32, (f64, u32)>> = (0..chunk_count)
        .into_par_iter()
        .map(|chunk| {
            let lo = chunk * CHUNK_ROWS;
            // The last chunk may be shorter than CHUNK_ROWS.
            let hi = total_rows.min(lo + CHUNK_ROWS);
            let rows = lo..hi;
            Self::process_chunk_simd(
                &group_values[rows.clone()],
                &group_bitmap[rows.clone()],
                &values[rows.clone()],
                &value_bitmap[rows],
            )
        })
        .collect();

    // Sequential merge of the partial maps, walked in chunk order.
    let mut merged: HashMap<i32, (f64, u32)> = HashMap::new();
    for (key, (sum, count)) in partials.into_iter().flatten() {
        let slot = merged.entry(key).or_insert((0.0, 0u32));
        slot.0 += sum;
        slot.1 += count;
    }

    Self::map_to_dataframe(merged, group_col_name, value_col_name)
}
/// Reduces one chunk of rows to a map of `key -> (sum, count)`.
///
/// On x86_64, when AVX2 is detected at runtime, the work is delegated to
/// `process_chunk_avx2`. Otherwise the chunk is accumulated by the scalar loop
/// below. A row is counted only when both of its validity flags are `true`.
///
/// # Panics
/// Panics if `group_bitmap`, `values` or `value_bitmap` is shorter than
/// `group_values`.
fn process_chunk_simd(
    group_values: &[i32],
    group_bitmap: &[bool],
    values: &[f64],
    value_bitmap: &[bool],
) -> std::collections::HashMap<i32, (f64, u32)> {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return Self::process_chunk_avx2(group_values, group_bitmap, values, value_bitmap);
        }
    }

    let len = group_values.len();
    // One bounds check per slice up front, so the loop below needs none.
    let group_bitmap = &group_bitmap[..len];
    let values = &values[..len];
    let value_bitmap = &value_bitmap[..len];

    // Allocate only on the fallback path: when the AVX2 branch above is taken
    // it builds its own map, so a map created earlier would be thrown away.
    let mut map: HashMap<i32, (f64, u32)> = HashMap::with_capacity(len / 4);
    let rows = group_values
        .iter()
        .zip(group_bitmap)
        .zip(values)
        .zip(value_bitmap);
    for (((&key, &key_valid), &value), &value_valid) in rows {
        if key_valid && value_valid {
            let entry = map.entry(key).or_insert((0.0, 0));
            entry.0 += value;
            entry.1 += 1;
        }
    }
    map
}
/// Reduces one chunk of rows to a map of `key -> (sum, count)`.
///
/// Rows are visited in index order and a row is counted only when both of its
/// validity flags are `true`, so per-key sums are accumulated in row order.
///
/// The accumulation is scalar and fully safe: every row ends in a hash-map
/// entry update, so no vector loads are issued and no `unsafe` is needed.
/// Rows are folded straight into the map without any intermediate buffer.
///
/// NOTE(review): this function does not reference `std::arch::x86_64`; if
/// nothing else in the file does, the file-level glob import can be dropped.
///
/// # Panics
/// Panics if `group_bitmap`, `values` or `value_bitmap` is shorter than
/// `group_values`.
#[cfg(target_arch = "x86_64")]
fn process_chunk_avx2(
    group_values: &[i32],
    group_bitmap: &[bool],
    values: &[f64],
    value_bitmap: &[bool],
) -> std::collections::HashMap<i32, (f64, u32)> {
    let len = group_values.len();
    // One bounds check per slice up front; a too-short slice fails here with a
    // panic instead of being read out of bounds.
    let group_bitmap = &group_bitmap[..len];
    let values = &values[..len];
    let value_bitmap = &value_bitmap[..len];

    let mut map: HashMap<i32, (f64, u32)> = HashMap::with_capacity(len / 4);
    let rows = group_values
        .iter()
        .zip(group_bitmap)
        .zip(values)
        .zip(value_bitmap);
    for (((&key, &key_valid), &value), &value_valid) in rows {
        if key_valid && value_valid {
            let entry = map.entry(key).or_insert((0.0, 0));
            entry.0 += value;
            entry.1 += 1;
        }
    }
    map
}
/// Single-threaded group-by sum used for small inputs.
///
/// Walks the rows in index order, skips any row whose group flag or value flag
/// is `false`, and accumulates `(sum, count)` per key before handing the map
/// to `map_to_dataframe`.
fn simple_groupby(
    group_values: &[i32],
    group_bitmap: &[bool],
    values: &[f64],
    value_bitmap: &[bool],
    group_col_name: &str,
    value_col_name: &str,
) -> Result<DataFrame, VeloxxError> {
    let mut totals: HashMap<i32, (f64, u32)> = HashMap::new();

    (0..group_values.len())
        // The value flag is only consulted when the group flag is set.
        .filter(|&row| group_bitmap[row] && value_bitmap[row])
        .for_each(|row| {
            let slot = totals.entry(group_values[row]).or_insert((0.0, 0u32));
            slot.0 += values[row];
            slot.1 += 1;
        });

    Self::map_to_dataframe(totals, group_col_name, value_col_name)
}
/// Converts an aggregation map of `key -> (sum, count)` into a two-column
/// `DataFrame`.
///
/// The output has one row per key, ordered by ascending key: an `I32` column
/// named `group_col_name` holding the keys and an `F64` column named
/// `value_col_name` holding the sums. Every output row is marked valid. The
/// per-key count is not emitted. An empty map yields two empty columns.
///
/// This function itself only ever returns `Ok`.
fn map_to_dataframe(
    map: std::collections::HashMap<i32, (f64, u32)>,
    group_col_name: &str,
    value_col_name: &str,
) -> Result<DataFrame, VeloxxError> {
    // Consume the map once instead of collecting keys and looking each one up
    // again; the count component is dropped here.
    let mut rows: Vec<(i32, f64)> = map
        .into_iter()
        .map(|(key, (sum, _count))| (key, sum))
        .collect();
    // Keys are unique, so an unstable sort gives a deterministic order.
    rows.sort_unstable_by_key(|&(key, _)| key);

    let (group_keys, sum_values): (Vec<i32>, Vec<f64>) = rows.into_iter().unzip();
    // Captured before the vectors are moved into the series below.
    let row_count = group_keys.len();

    let mut result = HashMap::new();
    result.insert(
        group_col_name.to_string(),
        Series::I32(
            group_col_name.to_string(),
            group_keys,
            vec![true; row_count],
        ),
    );
    result.insert(
        value_col_name.to_string(),
        Series::F64(
            value_col_name.to_string(),
            sum_values,
            vec![true; row_count],
        ),
    );
    Ok(DataFrame::new(result.into_iter().collect()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Five rows is below the 1000-row cutoff, so this exercises the simple path.
    #[test]
    fn test_ultra_simd_groupby_basic() {
        let keys: [i32; 5] = [1, 2, 1, 3, 2];
        let amounts: [f64; 5] = [10.0, 20.0, 30.0, 40.0, 50.0];
        let key_valid = [true; 5];
        let amount_valid = [true; 5];

        let frame = UltraFastGroupBy::ultra_simd_groupby_i32_sum(
            &keys,
            &key_valid,
            &amounts,
            &amount_valid,
            "group",
            "value",
        )
        .unwrap();

        // Three distinct keys (1, 2, 3) and the two named columns.
        assert_eq!(frame.row_count(), 3);
        assert_eq!(frame.column_count(), 2);
    }

    // 50 000 rows is above the cutoff, so this exercises the chunked parallel path.
    #[test]
    fn test_ultra_simd_large_dataset() {
        const ROWS: usize = 50_000;
        const DISTINCT_KEYS: i32 = 1000;

        let keys: Vec<i32> = (0..ROWS as i32).map(|row| row % DISTINCT_KEYS).collect();
        let amounts: Vec<f64> = (0..ROWS as i32).map(f64::from).collect();
        let key_valid = vec![true; ROWS];
        let amount_valid = vec![true; ROWS];

        let frame = UltraFastGroupBy::ultra_simd_groupby_i32_sum(
            &keys,
            &key_valid,
            &amounts,
            &amount_valid,
            "group",
            "value",
        )
        .unwrap();

        assert_eq!(frame.row_count(), 1000);
    }
}