use alloc::borrow::ToOwned;
use core::fmt::{Display, Formatter};
use core::num::NonZero;
use core::str::FromStr;
/// Extension trait that tags every item of an iterator with the partition it falls into.
///
/// Partitions are consecutive, half-open ranges of item lengths: partition 0 starts at
/// length 0, and partition `k` starts where partition `k - 1` ends and is
/// `partition_sizer(k)` long.
pub trait PartitionIterate: Iterator {
    /// Tags each item with its partition index, measuring items with `length_getter`.
    ///
    /// # Safety
    ///
    /// NOTE(review): nothing visible here relies on this contract for memory safety. The
    /// partition cursor only ever moves forward, so the `unsafe` marker presumably flags the
    /// requirement that item lengths arrive in non-decreasing order — confirm with the author.
    unsafe fn partition_with_get<S, G>(
        self,
        partition_sizer: S,
        length_getter: G,
    ) -> impl Iterator<Item = (PartitionIndex, Self::Item)>
    where
        S: Fn(PartitionIndex) -> PartitionSize,
        G: Fn(&Self::Item) -> u64;

    /// Like [`PartitionIterate::partition_with_get`], but each item's length is the item's
    /// own value, obtained through [`ToOwned`].
    ///
    /// # Safety
    ///
    /// Same contract as [`PartitionIterate::partition_with_get`].
    unsafe fn partition_with<S>(
        self,
        partition_sizer: S,
    ) -> impl Iterator<Item = (PartitionIndex, Self::Item)>
    where
        S: Fn(PartitionIndex) -> PartitionSize,
        Self::Item: ToOwned<Owned = u64>,
        Self: Sized,
    {
        let length_of = <Self::Item as ToOwned>::to_owned;
        // SAFETY: forwarded unchanged; the caller upholds this method's contract.
        unsafe { self.partition_with_get(partition_sizer, length_of) }
    }
}
/// Blanket implementation: any iterator can be partitioned via [`PartitionIterator`].
impl<I: Iterator> PartitionIterate for I {
    unsafe fn partition_with_get<S, G>(
        self,
        partition_sizer: S,
        length_getter: G,
    ) -> impl Iterator<Item = (PartitionIndex, Self::Item)>
    where
        S: Fn(PartitionIndex) -> PartitionSize,
        G: Fn(&Self::Item) -> u64,
    {
        // SAFETY: the caller upholds this method's contract, which is passed on unchanged.
        unsafe { PartitionIterator::<S, I, G>::new(self, partition_sizer, length_getter) }
    }
}
/// Growth factor used by [`PartitionSize::exponential`].
///
/// Build it through [`PartitionExponentBase::new`], which rejects values `<= 1.0`.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct PartitionExponentBase(f32);
impl PartitionExponentBase {
    /// Creates an exponent base after validating it.
    ///
    /// # Errors
    ///
    /// Returns an error if `base` is not strictly greater than `1.0` (this includes `NaN`,
    /// which a plain `base <= 1.0` check would let through), or if `base` is infinite.
    pub fn new(base: f32) -> anyhow::Result<Self> {
        // NaN compares false with everything, so it must be rejected explicitly. Letting it
        // through would make `PartitionSize::exponential` compute a size of 0.
        if base.is_nan() || base <= 1.0 {
            return Err(anyhow::anyhow!("'base' must be greater than 1.0"));
        }
        // +inf would make every partition after the first `u64::MAX` long.
        if base.is_infinite() {
            return Err(anyhow::anyhow!("'base' must be finite"));
        }
        Ok(Self(base))
    }
}
impl Default for PartitionExponentBase {
fn default() -> Self {
PartitionExponentBase::new(1.1f32).unwrap()
}
}
impl FromStr for PartitionExponentBase {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let base: f32 = s.parse()?;
PartitionExponentBase::new(base)
}
}
impl Display for PartitionExponentBase {
    /// Formats the base exactly like the wrapped `f32`.
    ///
    /// Delegating to the inner `Display` (rather than `write!(f, "{}", ..)`) keeps the
    /// caller's width, fill, and precision flags, e.g. `format!("{:.2}", base)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        Display::fmt(&self.0, f)
    }
}
/// Zero-based number of a partition. Partitions are numbered in increasing order of the
/// lengths they cover.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct PartitionIndex(pub u32);
/// Length of one partition, i.e. how many consecutive item lengths it covers. Never zero.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct PartitionSize(pub NonZero<u64>);
impl PartitionSize {
    /// Creates a partition size, returning `None` if `size` is zero.
    pub fn new(size: u64) -> Option<Self> {
        NonZero::new(size).map(Self)
    }

    /// Creates a partition size without checking that it is non-zero.
    ///
    /// # Safety
    ///
    /// `size` must not be zero.
    pub unsafe fn new_unchecked(size: u64) -> Self {
        // SAFETY: the caller guarantees `size != 0`.
        Self(unsafe { NonZero::new_unchecked(size) })
    }

    /// Creates a partition size from an already non-zero value.
    pub fn constant(size: NonZero<u64>) -> Self {
        Self(size)
    }

    /// Size of partition `index` for exponential growth: `ceil(base ^ index)`.
    ///
    /// Indices above `i32::MAX` are treated as `i32::MAX`, and results too large for a
    /// `u64` saturate to `u64::MAX`. The result is always at least 1.
    pub fn exponential(base: PartitionExponentBase, index: PartitionIndex) -> Self {
        // `powi` takes an `i32`. Saturate instead of reinterpreting the bits: a wrapped,
        // negative exponent gives a power below 1 that can round down to 0.0. The `min`
        // makes the cast lossless.
        let exponent = index.0.min(i32::MAX as u32) as i32;
        // Float-to-int `as` saturates (+inf -> u64::MAX) and maps NaN to 0. Clamp to 1 so
        // the non-zero invariant holds no matter what `base` contains.
        let size = (base.0.powi(exponent).ceil() as u64).max(1);
        // SAFETY: `size >= 1` thanks to the `max(1)` above.
        unsafe { Self::new_unchecked(size) }
    }
}
/// Iterator adapter returned by the blanket [`PartitionIterate`] implementation.
///
/// Tags every item of `source_iter` with the index of the partition its length falls into,
/// walking the partitions in increasing index order.
struct PartitionIterator<S, I, G> {
    /// Partition the cursor currently points at; it is only ever incremented.
    current_index: PartitionIndex,
    /// Maps a partition index to that partition's size.
    partition_sizer: S,
    /// Sum of the sizes of all partitions before `current_index`, i.e. where the current
    /// partition starts.
    cumulated_size: u64,
    /// Iterator whose items are being tagged.
    source_iter: I,
    /// Computes an item's length, which decides its partition.
    length_getter: G,
}
impl<S, I, G> PartitionIterator<S, I, G> {
    /// Wraps `source_iter` with the cursor on partition 0 and no length consumed yet.
    ///
    /// # Safety
    ///
    /// NOTE(review): this constructor performs no unsafe operation itself. The marker mirrors
    /// the [`PartitionIterate`] methods that call it, which presumably require items in
    /// non-decreasing length order — confirm with the author.
    pub unsafe fn new(source_iter: I, partition_sizer: S, length_getter: G) -> Self {
        let first_partition = PartitionIndex(0);
        Self {
            source_iter,
            partition_sizer,
            length_getter,
            current_index: first_partition,
            cumulated_size: 0,
        }
    }
}
impl<S, I, G> Iterator for PartitionIterator<S, I, G>
where
    I: Iterator,
    S: Fn(PartitionIndex) -> PartitionSize,
    G: Fn(&I::Item) -> u64,
{
    type Item = (PartitionIndex, I::Item);

    /// Pulls the next source item and tags it with the partition its length falls into.
    ///
    /// Partition `k` covers the half-open length range `[start, start + size)`, where `size`
    /// is `partition_sizer(k)` and `start` is the sum of the sizes of partitions `0..k`.
    /// The cursor only moves forward: an item whose length is below the current partition's
    /// start is still tagged with the current partition.
    ///
    /// # Panics
    ///
    /// Panics if advancing would push the partition index past `u32::MAX`.
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.source_iter.next()?;
        let item_length = (self.length_getter)(&item);
        loop {
            let partition_size = (self.partition_sizer)(self.current_index).0.get();
            match self.cumulated_size.checked_add(partition_size) {
                // The item lies at or past the end of this partition, so move to the next one.
                // `partition_end <= item_length`, so the new start always fits in a `u64`.
                Some(partition_end) if partition_end <= item_length => {
                    self.cumulated_size = partition_end;
                    self.current_index = PartitionIndex(
                        self.current_index
                            .0
                            .checked_add(1)
                            .expect("partition index overflowed u32"),
                    );
                }
                // Either `item_length < partition_end`, or the end overflows `u64` and is
                // therefore larger than any possible length: the item belongs here.
                _ => return Some((self.current_index, item)),
            }
        }
    }

    /// Exactly one tagged item is produced per source item, so the source's bounds apply.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.source_iter.size_hint()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{vec, vec::Vec};

    /// Partitions the lengths `0..len` with `sizer` and keeps only the partition numbers.
    fn partition_numbers<S>(len: u64, sizer: S) -> Vec<u32>
    where
        S: Fn(PartitionIndex) -> PartitionSize,
    {
        // SAFETY: the items (which double as lengths) come out of `0..len` in increasing order.
        let tagged = unsafe { (0..len).partition_with(sizer) };
        tagged.map(|(index, _length)| index.0).collect()
    }

    #[test]
    fn simple_test() {
        let four = PartitionSize(NonZero::new(4).unwrap());
        let expected: Vec<u32> = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3];
        assert_eq!(expected, partition_numbers(16, |_| four));
    }

    #[test]
    fn exponential_test() {
        let base = PartitionExponentBase::new(1.3f32).unwrap();
        let expected: Vec<u32> = vec![
            0, // [0, 1)
            1, 1, // [1, 3)
            2, 2, // [3, 5)
            3, 3, 3, // [5, 8)
            4, 4, 4, // [8, 11)
            5, 5, 5, 5, // [11, 15)
            6, 6, 6, 6, 6, // [15, 20)
            7, 7, 7, 7, 7, 7, 7, // [20, 27)
            8, 8, 8, 8, 8, // [27, 36), input stops at 31
        ];
        assert_eq!(
            expected,
            partition_numbers(32, |i| PartitionSize::exponential(base, i))
        );
    }
}