/// Selector for how values are accumulated.
///
/// This type only names a strategy; the arithmetic behind each variant lives
/// in whatever code consumes it and is not part of this definition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccumulationStrategy {
    /// Accumulate by summation.
    Sum,
    /// Accumulate by keeping the maximum.
    Max,
    /// Accumulate by keeping the minimum.
    Min,
    /// Accumulate by averaging.
    Mean,
    /// Accumulate by multiplication.
    Product,
    /// Accumulate with a log-sum-exp reduction (presumably `ln(Σ exp(x))`;
    /// confirm against the consumer).
    LogSumExp,
}
/// Settings that describe how a run of elements is divided into chunks and
/// which accumulation strategy is selected.
///
/// All fields are public, so a value can be built with a struct literal as
/// well as through the constructors.
#[derive(Debug, Clone)]
pub struct PartitionConfig {
    /// Number of elements in one chunk.
    pub chunk_size: usize,
    /// Optional memory budget in bytes; `None` when no budget was given.
    /// NOTE(review): nothing in this section enforces the budget — confirm
    /// where (or whether) it is read.
    pub max_memory_bytes: Option<usize>,
    /// Selected accumulation strategy.
    pub accumulation: AccumulationStrategy,
    /// Parallelism flag. How it is honoured is decided by the consumer.
    pub parallel: bool,
    /// Small floating-point constant; presumably a numerical tolerance —
    /// confirm against the code that reads it.
    pub epsilon: f64,
}
impl PartitionConfig {
pub fn new(chunk_size: usize) -> Self {
PartitionConfig {
chunk_size,
..Default::default()
}
}
pub fn memory_bounded(max_bytes: usize, element_size: usize) -> Self {
let chunk_size = max_bytes.checked_div(element_size).unwrap_or(1).max(1);
PartitionConfig {
chunk_size,
max_memory_bytes: Some(max_bytes),
..Default::default()
}
}
pub fn with_strategy(mut self, strategy: AccumulationStrategy) -> Self {
self.accumulation = strategy;
self
}
pub fn with_parallel(mut self, parallel: bool) -> Self {
self.parallel = parallel;
self
}
pub fn chunks_for_size(&self, total_elements: usize) -> usize {
if self.chunk_size == 0 {
return 0;
}
total_elements.div_ceil(self.chunk_size)
}
}
impl Default for PartitionConfig {
    /// Baseline configuration: 4096 elements per chunk, no memory budget,
    /// [`AccumulationStrategy::Sum`], parallelism off and `epsilon = 1e-12`.
    fn default() -> Self {
        Self {
            accumulation: AccumulationStrategy::Sum,
            chunk_size: 4096,
            epsilon: 1e-12,
            max_memory_bytes: None,
            parallel: false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` supplies the documented baseline values.
    #[test]
    fn test_partition_config_default() {
        let cfg = PartitionConfig::default();
        assert_eq!(cfg.chunk_size, 4096);
        assert!(cfg.max_memory_bytes.is_none());
        assert_eq!(cfg.accumulation, AccumulationStrategy::Sum);
        assert!(!cfg.parallel);
        assert_eq!(cfg.epsilon, 1e-12);
    }

    /// `new` stores the chunk size and leaves the other fields at their defaults.
    #[test]
    fn test_partition_config_new() {
        let cfg = PartitionConfig::new(1024);
        assert_eq!(cfg.chunk_size, 1024);
        assert!(cfg.max_memory_bytes.is_none());
        assert_eq!(cfg.accumulation, AccumulationStrategy::Sum);
        assert!(!cfg.parallel);
    }

    /// `memory_bounded` divides the byte budget by the element size.
    #[test]
    fn test_partition_config_memory_bounded() {
        let cfg = PartitionConfig::memory_bounded(64, 8);
        assert_eq!(cfg.chunk_size, 8);
        assert_eq!(cfg.max_memory_bytes, Some(64));
    }

    /// Degenerate budgets never produce a zero chunk size.
    #[test]
    fn test_partition_config_memory_bounded_degenerate() {
        // Zero element size: the division is skipped and one element is used.
        let zero_elem = PartitionConfig::memory_bounded(64, 0);
        assert_eq!(zero_elem.chunk_size, 1);
        assert_eq!(zero_elem.max_memory_bytes, Some(64));

        // Budget smaller than one element: quotient 0 is raised to 1.
        assert_eq!(PartitionConfig::memory_bounded(4, 8).chunk_size, 1);
        assert_eq!(PartitionConfig::memory_bounded(0, 8).chunk_size, 1);
    }

    /// Builder methods replace exactly one field and can be chained.
    #[test]
    fn test_builder_methods() {
        let cfg = PartitionConfig::new(32)
            .with_strategy(AccumulationStrategy::LogSumExp)
            .with_parallel(true);
        assert_eq!(cfg.chunk_size, 32);
        assert_eq!(cfg.accumulation, AccumulationStrategy::LogSumExp);
        assert!(cfg.parallel);
        assert!(cfg.max_memory_bytes.is_none());
    }

    /// Chunk counts round up so a trailing partial chunk is included.
    #[test]
    fn test_chunks_for_size() {
        let cfg = PartitionConfig::new(10);
        assert_eq!(cfg.chunks_for_size(0), 0);
        assert_eq!(cfg.chunks_for_size(10), 1);
        assert_eq!(cfg.chunks_for_size(11), 2);
        assert_eq!(cfg.chunks_for_size(100), 10);
        assert_eq!(cfg.chunks_for_size(101), 11);
    }

    /// A zero chunk size reports zero chunks instead of dividing by zero.
    #[test]
    fn test_chunks_for_size_zero_chunk() {
        let cfg = PartitionConfig::new(0);
        assert_eq!(cfg.chunks_for_size(0), 0);
        assert_eq!(cfg.chunks_for_size(100), 0);
    }
}