use crate::{
traits::SortedIterator,
utils::{BatchCodec, CodecIter, DefaultBatchCodec, MemoryUsage},
};
use anyhow::{Context, anyhow};
use dary_heap::PeekMut;
use std::path::{Path, PathBuf};
/// External-memory sorter for `(x, y)` pairs, each carrying a label of type
/// `C::Label`.
///
/// Pairs are buffered in memory; whenever the buffer holds `batch_size`
/// elements it is encoded with `batch_codec` into its own file inside
/// `tmp_dir`. The sorted sequence is obtained by merging all batch files.
pub struct SortPairs<C = DefaultBatchCodec>
where
    C: BatchCodec,
{
    /// Number of buffered pairs that triggers a dump to disk.
    batch_size: usize,
    /// Directory receiving one file per dumped batch.
    tmp_dir: PathBuf,
    /// Codec used to write and read back batch files.
    batch_codec: C,
    /// Number of batches dumped so far (also the index of the next one).
    num_batches: usize,
    /// Number of pairs in the most recently dumped batch.
    last_batch_len: usize,
    /// In-memory buffer of labeled pairs not yet dumped.
    batch: Vec<((usize, usize), C::Label)>,
}
impl SortPairs<DefaultBatchCodec> {
    /// Creates an unlabeled sorter using [`DefaultBatchCodec`].
    ///
    /// `memory_usage` bounds the in-memory buffer; batches are dumped into
    /// `tmp_dir`, which must exist and be empty.
    pub fn new<P: AsRef<Path>>(memory_usage: MemoryUsage, tmp_dir: P) -> anyhow::Result<Self> {
        let codec = DefaultBatchCodec::default();
        Self::new_labeled(memory_usage, tmp_dir, codec)
    }

    /// Buffers the pair `(x, y)` with the unit label, dumping a batch to
    /// disk when the buffer becomes full.
    pub fn push(&mut self, x: usize, y: usize) -> anyhow::Result<()> {
        self.push_labeled(x, y, ())
    }

    /// Pushes every pair of `pairs` and returns a merged, sorted iterator
    /// over everything pushed so far.
    pub fn sort(
        &mut self,
        pairs: impl IntoIterator<Item = (usize, usize)>,
    ) -> anyhow::Result<KMergeIters<CodecIter<DefaultBatchCodec>>> {
        let labeled = pairs.into_iter().map(|pair| (pair, ()));
        self.sort_labeled(labeled)
    }

    /// Fallible variant of [`SortPairs::sort`]: stops at, and returns, the
    /// first error yielded by `pairs`.
    pub fn try_sort<E: Into<anyhow::Error>>(
        &mut self,
        pairs: impl IntoIterator<Item = Result<(usize, usize), E>>,
    ) -> anyhow::Result<KMergeIters<CodecIter<DefaultBatchCodec>, ()>> {
        let labeled = pairs
            .into_iter()
            .map(|result| result.map(|pair| (pair, ())));
        self.try_sort_labeled(labeled)
    }
}
impl<C: BatchCodec> SortPairs<C> {
    /// Creates a sorter that encodes batches with `batch_codec`.
    ///
    /// `memory_usage` determines how many labeled pairs are buffered before a
    /// batch is dumped; batch files are written into `dir`.
    ///
    /// # Errors
    ///
    /// Returns an error if `dir` cannot be listed or is not empty.
    pub fn new_labeled<P: AsRef<Path>>(
        memory_usage: MemoryUsage,
        dir: P,
        batch_codec: C,
    ) -> anyhow::Result<Self> {
        let dir = dir.as_ref();
        let mut dir_entries =
            std::fs::read_dir(dir).with_context(|| format!("Could not list {}", dir.display()))?;
        // Any entry (even one that fails to be read) means the directory is in use.
        if dir_entries.next().is_some() {
            return Err(anyhow!("{} is not empty", dir.display()));
        }
        let batch_size = memory_usage.batch_size::<(usize, usize, C::Label)>();
        Ok(SortPairs {
            batch_size,
            batch_codec,
            tmp_dir: dir.to_owned(),
            num_batches: 0,
            last_batch_len: 0,
            batch: Vec::with_capacity(batch_size),
        })
    }

    /// Buffers the labeled pair `((x, y), t)`, dumping the buffer to disk as
    /// a new batch once it holds `batch_size` elements.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while dumping the batch.
    pub fn push_labeled(&mut self, x: usize, y: usize, t: C::Label) -> anyhow::Result<()> {
        self.batch.push(((x, y), t));
        if self.batch.len() >= self.batch_size {
            self.dump()?;
        }
        Ok(())
    }

    /// Encodes the in-memory buffer into the next batch file (named by the
    /// batch index in six-digit hexadecimal) and clears the buffer.
    ///
    /// Does nothing if the buffer is empty.
    fn dump(&mut self) -> anyhow::Result<()> {
        if self.batch.is_empty() {
            return Ok(());
        }
        let batch_path = self.tmp_dir.join(format!("{:06x}", self.num_batches));
        let start = std::time::Instant::now();
        let (bit_size, stats) = self.batch_codec.encode_batch(batch_path, &mut self.batch)?;
        log::info!(
            "Dumped batch {} with {} arcs ({} bits, {:.2} bits / arc) in {:.3} seconds, stats: {}",
            self.num_batches,
            self.batch.len(),
            bit_size,
            bit_size as f64 / self.batch.len() as f64,
            start.elapsed().as_secs_f64(),
            stats
        );
        self.last_batch_len = self.batch.len();
        self.batch.clear();
        self.num_batches += 1;
        Ok(())
    }

    /// Dumps any buffered pairs and returns an iterator merging all batches
    /// written so far.
    ///
    /// Every batch file is decoded before the merge is built, so a failure
    /// is reported here rather than surfacing later during iteration.
    ///
    /// # Errors
    ///
    /// Returns an error if the pending batch cannot be dumped or if any batch
    /// file cannot be decoded.
    pub fn iter(&mut self) -> anyhow::Result<KMergeIters<CodecIter<C>, C::Label>> {
        self.dump()?;
        let batches = (0..self.num_batches)
            .map(|batch_idx| {
                let batch_path = self.tmp_dir.join(format!("{batch_idx:06x}"));
                self.batch_codec
                    .decode_batch(batch_path)
                    .with_context(|| {
                        format!(
                            "Could not decode batch {batch_idx:06x} in {}",
                            self.tmp_dir.display()
                        )
                    })
                    .map(IntoIterator::into_iter)
            })
            .collect::<anyhow::Result<Vec<_>>>()?;
        Ok(KMergeIters::new(batches))
    }

    /// Pushes every labeled pair of `pairs` and returns a merged, sorted
    /// iterator over everything pushed so far.
    pub fn sort_labeled(
        &mut self,
        pairs: impl IntoIterator<Item = ((usize, usize), C::Label)>,
    ) -> anyhow::Result<KMergeIters<CodecIter<C>, C::Label>> {
        self.try_sort_labeled::<std::convert::Infallible>(pairs.into_iter().map(Ok))
    }

    /// Fallible variant of [`SortPairs::sort_labeled`]: stops at, and
    /// returns, the first error yielded by `pairs`.
    pub fn try_sort_labeled<E: Into<anyhow::Error>>(
        &mut self,
        pairs: impl IntoIterator<Item = Result<((usize, usize), C::Label), E>>,
    ) -> anyhow::Result<KMergeIters<CodecIter<C>, C::Label>> {
        for pair in pairs {
            let ((x, y), label) = pair.map_err(Into::into)?;
            self.push_labeled(x, y, label)?;
        }
        self.iter()
    }
}
/// One input of a k-way merge: its next pending element (`head`) together
/// with the rest of the input (`tail`).
///
/// Equality and ordering look only at the pair inside `head` (labels are
/// ignored), and the ordering is reversed, so that in a max-heap the input
/// whose next pair is smallest sits at the top.
#[derive(Clone, Debug)]
struct HeadTail<T, I: Iterator<Item = ((usize, usize), T)>> {
    head: ((usize, usize), T),
    tail: I,
}

impl<T, I: Iterator<Item = ((usize, usize), T)>> PartialEq for HeadTail<T, I> {
    #[inline(always)]
    fn eq(&self, other: &Self) -> bool {
        let (this_pair, _) = &self.head;
        let (other_pair, _) = &other.head;
        this_pair == other_pair
    }
}

impl<T, I: Iterator<Item = ((usize, usize), T)>> Eq for HeadTail<T, I> {}

impl<T, I: Iterator<Item = ((usize, usize), T)>> PartialOrd for HeadTail<T, I> {
    #[inline(always)]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T, I: Iterator<Item = ((usize, usize), T)>> Ord for HeadTail<T, I> {
    /// Compares head pairs in reverse: a smaller pair counts as "greater".
    #[inline(always)]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.head.0.cmp(&other.head.0).reverse()
    }
}
/// A k-way merge of iterators over `((usize, usize), T)` items.
///
/// Each non-empty input is kept in a quaternary heap keyed (in reverse) on
/// its next pair, so iteration always returns the smallest pending pair;
/// when every input is itself sorted, the output is sorted too.
#[derive(Clone, Debug)]
pub struct KMergeIters<I, T = ()>
where
    I: Iterator<Item = ((usize, usize), T)>,
{
    heap: dary_heap::QuaternaryHeap<HeadTail<T, I>>,
}
impl<T, I: Iterator<Item = ((usize, usize), T)>> KMergeIters<I, T> {
    /// Builds a merge over `iters`.
    ///
    /// Empty inputs are discarded right away; every other input is split into
    /// its first element and its remainder and pushed onto the heap. The heap
    /// is preallocated with the upper bound of the size hint of `iters`, or
    /// with room for 10 inputs when no upper bound is known.
    pub fn new(iters: impl IntoIterator<Item = I>) -> Self {
        let inputs = iters.into_iter();
        let capacity = match inputs.size_hint() {
            (_, Some(upper)) => upper,
            (_, None) => 10,
        };
        let mut heap = dary_heap::QuaternaryHeap::with_capacity(capacity);
        for mut tail in inputs {
            let Some(head) = tail.next() else {
                continue;
            };
            heap.push(HeadTail { head, tail });
        }
        Self { heap }
    }
}
// SAFETY: `next` always returns the smallest pending head among all inputs
// (the heap is a max-heap over `HeadTail`, whose ordering is the reverse of
// the head pair's ordering). Merging inputs that are each sorted therefore
// yields a sorted sequence. NOTE(review): this assumes the `SortedIterator`
// contract is "pairs are yielded in nondecreasing order" — confirm against
// the trait documentation.
unsafe impl<T, I: Iterator<Item = ((usize, usize), T)> + SortedIterator> SortedIterator
    for KMergeIters<I, T>
{
}

impl<T, I: Iterator<Item = ((usize, usize), T)>> Iterator for KMergeIters<I, T> {
    type Item = ((usize, usize), T);

    /// Returns the smallest pending element, or `None` once every input is
    /// exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        let mut head_tail = self.heap.peek_mut()?;
        match head_tail.tail.next() {
            // This input is exhausted: drop it from the heap and return its
            // last buffered element.
            None => Some(PeekMut::pop(head_tail).head),
            // Buffer the input's next element; dropping `head_tail` at the
            // end of this scope restores the heap order.
            Some(next) => Some(std::mem::replace(&mut head_tail.head, next)),
        }
    }

    /// Sums, over all pending inputs, one buffered head plus the bounds
    /// reported by the input's tail.
    ///
    /// This is exact whenever every tail reports an exact hint, which keeps
    /// the `ExactSizeIterator` implementation consistent with `size_hint`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.heap
            .iter()
            .fold((0, Some(0)), |(lower, upper), head_tail| {
                let (tail_lower, tail_upper) = head_tail.tail.size_hint();
                (
                    lower.saturating_add(tail_lower).saturating_add(1),
                    upper
                        .zip(tail_upper)
                        .and_then(|(acc, tail)| acc.checked_add(tail))
                        .and_then(|acc| acc.checked_add(1)),
                )
            })
    }
}
impl<T, I: Iterator<Item = ((usize, usize), T)> + ExactSizeIterator> ExactSizeIterator
    for KMergeIters<I, T>
{
    /// Counts the pending elements: each input on the heap contributes its
    /// buffered head plus whatever remains in its tail.
    fn len(&self) -> usize {
        let mut remaining = 0;
        for head_tail in self.heap.iter() {
            remaining += head_tail.tail.len() + 1;
        }
        remaining
    }
}
impl<T, I: Iterator<Item = ((usize, usize), T)>> core::default::Default for KMergeIters<I, T> {
fn default() -> Self {
KMergeIters {
heap: dary_heap::QuaternaryHeap::default(),
}
}
}
impl<T, I: Iterator<Item = ((usize, usize), T)>> core::iter::Sum for KMergeIters<I, T> {
fn sum<J: Iterator<Item = Self>>(iter: J) -> Self {
let mut heap = dary_heap::QuaternaryHeap::default();
for mut kmerge in iter {
heap.extend(kmerge.heap.drain());
}
KMergeIters { heap }
}
}
impl<T, I: IntoIterator<Item = ((usize, usize), T)>> core::iter::Sum<I>
for KMergeIters<I::IntoIter, T>
{
fn sum<J: Iterator<Item = I>>(iter: J) -> Self {
KMergeIters::new(iter.map(IntoIterator::into_iter))
}
}
impl<T, I: Iterator<Item = ((usize, usize), T)>> core::iter::FromIterator<Self>
for KMergeIters<I, T>
{
fn from_iter<J: IntoIterator<Item = Self>>(iter: J) -> Self {
iter.into_iter().sum()
}
}
impl<T, I: IntoIterator<Item = ((usize, usize), T)>> core::iter::FromIterator<I>
for KMergeIters<I::IntoIter, T>
{
fn from_iter<J: IntoIterator<Item = I>>(iter: J) -> Self {
KMergeIters::new(iter.into_iter().map(IntoIterator::into_iter))
}
}
impl<T, I: IntoIterator<Item = ((usize, usize), T)>> core::ops::AddAssign<I>
    for KMergeIters<I::IntoIter, T>
{
    /// Adds one more input to the merge; an empty input is ignored.
    fn add_assign(&mut self, rhs: I) {
        let mut tail = rhs.into_iter();
        let Some(head) = tail.next() else {
            return;
        };
        self.heap.push(HeadTail { head, tail });
    }
}
impl<T, I: Iterator<Item = ((usize, usize), T)>> core::ops::AddAssign for KMergeIters<I, T> {
    /// Moves every pending input of `rhs` into this merge.
    fn add_assign(&mut self, mut rhs: Self) {
        let pending = rhs.heap.drain();
        self.heap.extend(pending);
    }
}
impl<T, I: IntoIterator<Item = ((usize, usize), T)>> Extend<I> for KMergeIters<I::IntoIter, T> {
    /// Adds every non-empty iterable of `iter` as a new input of the merge.
    fn extend<J: IntoIterator<Item = I>>(&mut self, iter: J) {
        let entries = iter
            .into_iter()
            .map(|input| input.into_iter())
            .filter_map(|mut tail| tail.next().map(|head| HeadTail { head, tail }));
        self.heap.extend(entries);
    }
}
impl<T, I: Iterator<Item = ((usize, usize), T)>> Extend<KMergeIters<I, T>> for KMergeIters<I, T> {
    /// Moves the pending inputs of every merge in `iter` into this one.
    fn extend<J: IntoIterator<Item = KMergeIters<I, T>>>(&mut self, iter: J) {
        iter.into_iter()
            .for_each(|mut other| self.heap.extend(other.heap.drain()));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        traits::{BitDeserializer, BitSerializer},
        utils::{BitReader, BitWriter, gaps::GapsCodec},
    };
    use dsi_bitstream::prelude::*;

    /// Test label (de)serializer storing `usize` values with Elias δ codes.
    #[derive(Clone, Debug)]
    struct MyDessert<E: Endianness> {
        _marker: std::marker::PhantomData<E>,
    }

    impl<E: Endianness> Default for MyDessert<E> {
        fn default() -> Self {
            MyDessert {
                _marker: std::marker::PhantomData,
            }
        }
    }

    impl<E: Endianness> BitDeserializer<E, BitReader<E>> for MyDessert<E>
    where
        BitReader<E>: BitRead<E> + CodesRead<E>,
    {
        type DeserType = usize;
        fn deserialize(
            &self,
            bitstream: &mut BitReader<E>,
        ) -> Result<Self::DeserType, <BitReader<E> as BitRead<E>>::Error> {
            bitstream.read_delta().map(|x| x as usize)
        }
    }

    impl<E: Endianness> BitSerializer<E, BitWriter<E>> for MyDessert<E>
    where
        BitWriter<E>: BitWrite<E> + CodesWrite<E>,
    {
        type SerType = usize;
        fn serialize(
            &self,
            value: &Self::SerType,
            bitstream: &mut BitWriter<E>,
        ) -> Result<usize, <BitWriter<E> as BitWrite<E>>::Error> {
            bitstream.write_delta(*value as u64)
        }
    }

    /// Pushes labeled pairs spanning several batches and checks that both the
    /// merged iterator and a clone of it yield every pair exactly once, in
    /// order, and then stop.
    #[test]
    fn test_sort_pairs() -> anyhow::Result<()> {
        use tempfile::Builder;
        let dir = Builder::new().prefix("test_sort_pairs_").tempdir()?;
        // 25 pairs with a batch size of 10 produce three batches to merge.
        let mut sp = SortPairs::new_labeled(
            MemoryUsage::BatchSize(10),
            dir.path(),
            GapsCodec::<BE, MyDessert<BE>, MyDessert<BE>>::default(),
        )?;
        let n = 25;
        for i in 0..n {
            sp.push_labeled(i, i + 1, i + 2)?;
        }
        let iter = sp.iter()?;
        let cloned = iter.clone();
        let expected: Vec<_> = (0..n).map(|i| ((i, i + 1), i + 2)).collect();
        assert_eq!(iter.collect::<Vec<_>>(), expected);
        assert_eq!(cloned.collect::<Vec<_>>(), expected);
        Ok(())
    }

    /// Checks `sort` and `sort_labeled` on unsorted input split into several
    /// small batches, so that the result really comes from a k-way merge.
    #[test]
    fn test_sort_and_sort_labeled() -> anyhow::Result<()> {
        use tempfile::Builder;
        let dir = Builder::new().prefix("test_sort_").tempdir()?;
        let mut sp = SortPairs::new(MemoryUsage::BatchSize(2), dir.path())?;
        let pairs = [(3, 4), (1, 2), (5, 6), (0, 1), (2, 3)];
        let sorted_pairs: Vec<_> = sp.sort(pairs)?.map(|(pair, _)| pair).collect();
        assert_eq!(sorted_pairs, vec![(0, 1), (1, 2), (2, 3), (3, 4), (5, 6)]);

        let dir2 = Builder::new().prefix("test_sort_labeled_").tempdir()?;
        let mut sp2 = SortPairs::new_labeled(
            MemoryUsage::BatchSize(2),
            dir2.path(),
            GapsCodec::<BE, MyDessert<BE>, MyDessert<BE>>::default(),
        )?;
        let labeled_pairs = [
            ((3, 4), 7),
            ((1, 2), 5),
            ((5, 6), 9),
            ((0, 1), 4),
            ((2, 3), 6),
        ];
        let sorted_labeled: Vec<_> = sp2
            .sort_labeled(labeled_pairs)?
            .map(|((x, y), label)| (x, y, label))
            .collect();
        assert_eq!(
            sorted_labeled,
            vec![(0, 1, 4), (1, 2, 5), (2, 3, 6), (3, 4, 7), (5, 6, 9)]
        );
        Ok(())
    }
}