use crate::VeloxxError;
use rayon::prelude::*;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Mutex;
/// Output of [`ParallelGroupBy::parallel_group_by`]: the input elements bucketed by key.
///
/// `Debug` and `Clone` are derived so results can be inspected and duplicated; the
/// derives only require `K`/`V` to implement those traits when the impls are used.
#[derive(Debug, Clone)]
pub struct ParallelGroupByResult<K, V> {
    /// Maps each key returned by the key function to the elements that produced it.
    pub groups: HashMap<K, Vec<V>>,
}
/// Collections whose elements can be partitioned into groups in parallel.
///
/// `K` is the grouping key type and `V` is the element type.
pub trait ParallelGroupBy<K, V> {
    /// Partitions the elements into groups keyed by `key_fn(&element)`.
    ///
    /// `key_fn` may be called concurrently from several threads, hence the
    /// `Send + Sync` requirements on it and on the key/element types.
    ///
    /// # Errors
    ///
    /// Implementations report failures as a [`VeloxxError`].
    fn parallel_group_by<F>(&self, key_fn: F) -> Result<ParallelGroupByResult<K, V>, VeloxxError>
    where
        F: Fn(&V) -> K + Send + Sync,
        K: Eq + Hash + Send + Sync,
        V: Send + Sync;
}
/// Parallel group-by for a vector of values.
///
/// Elements are cloned into their groups, which is why `V: Clone` is required.
impl<K, V> ParallelGroupBy<K, V> for Vec<V>
where
    K: Eq + Hash + Sync + Send,
    V: Clone + Sync + Send,
{
    /// Groups the vector's elements by `key_fn`, computing keys in parallel with rayon.
    ///
    /// Each rayon task accumulates into its own private `HashMap` (no shared lock, so
    /// workers never contend). The partial maps are then merged pairwise. In each merge,
    /// `left` covers elements that precede those in `right`, and `right`'s members are
    /// appended after `left`'s. As a result, every group lists its elements in the same
    /// relative order as they appear in `self`.
    ///
    /// An empty vector yields an empty `groups` map.
    ///
    /// # Errors
    ///
    /// This implementation does not fail; the `Result` is required by the trait signature.
    fn parallel_group_by<F>(&self, key_fn: F) -> Result<ParallelGroupByResult<K, V>, VeloxxError>
    where
        F: Fn(&V) -> K + Sync + Send,
    {
        let groups = self
            .par_iter()
            // Per-task accumulation: each split builds its own map independently.
            .fold(HashMap::new, |mut local: HashMap<K, Vec<V>>, item| {
                local.entry(key_fn(item)).or_default().push(item.clone());
                local
            })
            // Combine adjacent partial maps; `append` moves the items without cloning.
            .reduce(HashMap::new, |mut left, right| {
                for (key, mut items) in right {
                    left.entry(key).or_default().append(&mut items);
                }
                left
            });
        Ok(ParallelGroupByResult { groups })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::series::Series;

    /// Series are grouped by name; two of the three share the name "series1".
    #[test]
    fn test_parallel_group_by_series() {
        let series1 = Series::new_i32("series1", vec![Some(1), Some(2), Some(3)]);
        let series2 = Series::new_i32("series2", vec![Some(4), Some(5), Some(6)]);
        let series3 = Series::new_i32("series1", vec![Some(7), Some(8), Some(9)]);
        let data = vec![series1, series2, series3];
        let result = data.parallel_group_by(|s| s.name().to_string()).unwrap();
        assert_eq!(result.groups.len(), 2);
        assert_eq!(result.groups.get("series1").unwrap().len(), 2);
        assert_eq!(result.groups.get("series2").unwrap().len(), 1);
    }

    /// Integers split by parity: checks both group sizes and group membership.
    #[test]
    fn test_parallel_group_by_integers() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let result = data.parallel_group_by(|&x| x % 2).unwrap();
        assert_eq!(result.groups.len(), 2);
        assert_eq!(result.groups.get(&0).unwrap().len(), 5);
        assert_eq!(result.groups.get(&1).unwrap().len(), 5);

        // Sort before comparing so the assertion does not depend on the order in
        // which elements were placed into each group.
        let mut evens = result.groups.get(&0).unwrap().clone();
        evens.sort_unstable();
        assert_eq!(evens, vec![2, 4, 6, 8, 10]);

        let mut odds = result.groups.get(&1).unwrap().clone();
        odds.sort_unstable();
        assert_eq!(odds, vec![1, 3, 5, 7, 9]);
    }

    /// An empty input produces no groups rather than an error.
    #[test]
    fn test_parallel_group_by_empty_input() {
        let data: Vec<i32> = Vec::new();
        let result = data.parallel_group_by(|&x| x % 2).unwrap();
        assert!(result.groups.is_empty());
    }

    /// A larger input (enough to be split across workers) loses no elements and
    /// places each element under the key computed from it.
    #[test]
    fn test_parallel_group_by_keeps_every_element() {
        let data: Vec<u32> = (0..1_000).collect();
        let result = data.parallel_group_by(|&x| x % 7).unwrap();
        assert_eq!(result.groups.len(), 7);

        let total: usize = result.groups.values().map(Vec::len).sum();
        assert_eq!(total, data.len());

        for (key, members) in &result.groups {
            assert!(members.iter().all(|x| x % 7 == *key));
        }
    }
}