use core::num::NonZeroUsize;
use sync_cell_slice::SyncSlice;
use anyhow::{Context, Result};
use dsi_progress_logger::{ProgressLog, concurrent_progress_logger};
use rayon::prelude::*;
use super::MemoryUsage;
use super::sort_pairs::KMergeIters;
use crate::utils::SplitIters;
use crate::utils::{BatchCodec, CodecIter, DefaultBatchCodec};
/// Configuration for sorting pairs of node ids in parallel.
///
/// Create one with [`ParSortIters::new`], adjust it with the consuming builder
/// methods, and run it with one of the `sort*` methods.
pub struct ParSortIters {
    /// Number of nodes; source ids are expected to lie in `0..num_nodes`.
    num_nodes: usize,
    /// Number of contiguous node ranges the output is split into.
    num_partitions: NonZeroUsize,
    /// Memory budget from which the size of the in-memory batches is derived.
    memory_usage: MemoryUsage,
    /// Optional hint on the total number of pairs, forwarded to the progress logger.
    expected_num_pairs: Option<usize>,
}
impl ParSortIters {
    /// Sorts unlabeled `(src, dst)` pairs.
    ///
    /// `pairs` yields one iterator per input block. This is
    /// [`ParSortIters::try_sort`] instantiated with [`std::convert::Infallible`]
    /// as the error type.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`ParSortIters::try_sort`].
    pub fn sort(
        &self,
        pairs: impl IntoIterator<
            Item: IntoIterator<Item = (usize, usize), IntoIter: Send + Sync> + Send + Sync,
            IntoIter: ExactSizeIterator + Send + Sync,
        >,
    ) -> Result<SplitIters<impl IntoIterator<Item = (usize, usize), IntoIter: Send + Sync>>> {
        self.try_sort::<std::convert::Infallible>(pairs)
    }

    /// Sorts unlabeled `(src, dst)` pairs by delegating to
    /// [`ParSortIters::try_sort_labeled`] with a [`DefaultBatchCodec`].
    ///
    /// Every pair is given a `()` label on the way in, and the labels are
    /// stripped again from each output partition. The partition boundaries
    /// are passed through unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`ParSortIters::try_sort_labeled`].
    pub fn try_sort<E: Into<anyhow::Error>>(
        &self,
        pairs: impl IntoIterator<
            Item: IntoIterator<Item = (usize, usize), IntoIter: Send + Sync> + Send + Sync,
            IntoIter: ExactSizeIterator + Send + Sync,
        >,
    ) -> Result<SplitIters<impl IntoIterator<Item = (usize, usize), IntoIter: Send + Sync>>> {
        // Attach a unit label to each pair so the labeled implementation can be reused.
        let labeled_blocks = pairs
            .into_iter()
            .map(|block| block.into_iter().map(|pair| (pair, ())));
        let split = self.try_sort_labeled::<DefaultBatchCodec, E, _>(
            DefaultBatchCodec::default(),
            labeled_blocks,
        )?;

        // Remove the unit labels from every output partition.
        let unlabeled: Box<[_]> = split
            .iters
            .into_vec()
            .into_iter()
            .map(|partition| partition.into_iter().map(|(pair, _)| pair))
            .collect();

        Ok(SplitIters::new(split.boundaries, unlabeled))
    }
}
impl ParSortIters {
    /// Creates a sorter for `num_nodes` nodes.
    ///
    /// By default no expected number of pairs is set, the number of
    /// partitions is `num_cpus::get()`, and the memory budget is
    /// `MemoryUsage::default()`.
    ///
    /// # Errors
    ///
    /// Fails if `num_cpus::get()` reports zero CPUs.
    pub fn new(num_nodes: usize) -> Result<Self> {
        Ok(Self {
            num_nodes,
            expected_num_pairs: None,
            num_partitions: NonZeroUsize::new(num_cpus::get()).context("zero CPUs")?,
            memory_usage: MemoryUsage::default(),
        })
    }

    /// Sets the expected total number of pairs; it is only passed to the
    /// progress logger as `expected_updates`.
    pub fn expected_num_pairs(self, expected_num_pairs: usize) -> Self {
        Self {
            expected_num_pairs: Some(expected_num_pairs),
            ..self
        }
    }

    /// Sets the number of contiguous node ranges the output is split into.
    pub fn num_partitions(self, num_partitions: NonZeroUsize) -> Self {
        Self {
            num_partitions,
            ..self
        }
    }

    /// Sets the memory budget from which the in-memory batch size is derived.
    pub fn memory_usage(self, memory_usage: MemoryUsage) -> Self {
        Self {
            memory_usage,
            ..self
        }
    }

    /// Sorts labeled pairs using `batch_codec` for the intermediate batches.
    ///
    /// This is [`ParSortIters::try_sort_labeled`] instantiated with
    /// [`std::convert::Infallible`] as the error type.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`ParSortIters::try_sort_labeled`].
    pub fn sort_labeled<
        C: BatchCodec,
        P: IntoIterator<
            Item: IntoIterator<Item = ((usize, usize), C::Label), IntoIter: Send> + Send,
            IntoIter: ExactSizeIterator,
        >,
    >(
        &self,
        batch_codec: C,
        pairs: P,
    ) -> Result<
        SplitIters<
            impl IntoIterator<Item = ((usize, usize), C::Label), IntoIter: Send + Sync> + use<C, P>,
        >,
    > {
        self.try_sort_labeled::<C, std::convert::Infallible, P>(batch_codec, pairs)
    }

    /// Sorts labeled pairs read from several input blocks in parallel.
    ///
    /// Node ids are split into `num_partitions` contiguous ranges of
    /// `num_nodes.div_ceil(num_partitions)` ids each; the returned boundaries
    /// are clamped to `num_nodes`. Each pair is assigned to the range that
    /// contains its source node.
    ///
    /// Each input block is consumed on its own scoped thread. That thread
    /// buffers pairs per partition and passes every full buffer (and, at the
    /// end, every remaining buffer) to `par_sort_pairs::flush_buffer`,
    /// together with a temporary directory path and the partition's vector
    /// of [`CodecIter`] runs. Finally, the runs of each partition, gathered
    /// from all blocks in block order, are wrapped in a [`KMergeIters`].
    ///
    /// `E` appears only in the signature. The inputs are plain pairs, so this
    /// method never produces an error of type `E`.
    ///
    /// # Errors
    ///
    /// - the temporary directory cannot be created;
    /// - a buffer cannot be flushed;
    /// - a source node lies outside the partitioned node range (for instance
    ///   any pair when `num_nodes` is zero).
    ///
    /// # Panics
    ///
    /// Panics if `flush_buffer` returns without emptying the buffer, or if the
    /// outer iterator yields more blocks than its `len()` reported.
    pub fn try_sort_labeled<
        C: BatchCodec,
        E: Into<anyhow::Error>,
        P: IntoIterator<
            Item: IntoIterator<Item = ((usize, usize), C::Label), IntoIter: Send> + Send,
            IntoIter: ExactSizeIterator,
        >,
    >(
        &self,
        batch_codec: C,
        pairs: P,
    ) -> Result<
        SplitIters<
            impl IntoIterator<Item = ((usize, usize), C::Label), IntoIter: Send + Sync> + use<C, E, P>,
        >,
    > {
        let num_nodes = self.num_nodes;
        let num_partitions: usize = self.num_partitions.into();
        // NOTE(review): the budget is divided as if there were one block per
        // rayon thread, but one std thread (holding `num_partitions` buffers)
        // is spawned per input block. More blocks than rayon threads can
        // therefore exceed `memory_usage`; confirm this is intended.
        let num_buffers = rayon::current_num_threads() * num_partitions;
        // Clamp to 1 so a tiny budget never triggers flushes of empty buffers.
        let batch_size = self
            .memory_usage
            .batch_size::<((usize, usize), C::Label)>()
            .div_ceil(num_buffers)
            .max(1);
        let num_nodes_per_partition = num_nodes.div_ceil(num_partitions);

        let mut pl = concurrent_progress_logger!(
            display_memory = true,
            item_name = "pair",
            local_speed = true,
            expected_updates = self.expected_num_pairs,
        );
        pl.start("Reading and sorting pairs");
        pl.info(format_args!("Per-processor batch size: {}", batch_size));

        let presort_tmp_dir =
            tempfile::tempdir().context("Could not create temporary directory")?;

        let blocks = pairs.into_iter();
        let num_blocks = blocks.len();
        // One slot per input block. Each worker stores its per-partition runs,
        // or the error that stopped it, in its own slot.
        let mut presorted_blocks: Vec<Result<Vec<Vec<CodecIter<C>>>>> =
            (0..num_blocks).map(|_| Ok(Vec::new())).collect();
        let slots = presorted_blocks.as_sync_slice();

        std::thread::scope(|s| {
            let presort_tmp_dir = &presort_tmp_dir;
            let batch_codec = &batch_codec;
            for (block_id, block) in blocks.enumerate() {
                let mut pl = pl.clone();
                s.spawn(move || {
                    // Immediately-invoked closure so that `?` can be used for
                    // error propagation inside the worker.
                    let outcome = (|| -> Result<Vec<Vec<CodecIter<C>>>> {
                        let mut unsorted_buffers: Vec<Vec<_>> = (0..num_partitions)
                            .map(|_| Vec::with_capacity(batch_size))
                            .collect();
                        let mut presorted: Vec<Vec<CodecIter<C>>> =
                            (0..num_partitions).map(|_| Vec::new()).collect();

                        for ((src, dst), label) in block {
                            // `checked_div` guards the `num_nodes == 0` case. The
                            // range check turns an index panic into an error.
                            let partition_id = src
                                .checked_div(num_nodes_per_partition)
                                .filter(|&id| id < num_partitions)
                                .with_context(|| {
                                    format!(
                                        "Source node {src} is out of range (num_nodes = {num_nodes})"
                                    )
                                })?;
                            let buf = &mut unsorted_buffers[partition_id];
                            // Compare with `batch_size` rather than the capacity, which
                            // `flush_buffer` is not guaranteed to preserve.
                            if buf.len() >= batch_size {
                                let buf_len = buf.len();
                                super::par_sort_pairs::flush_buffer(
                                    presort_tmp_dir.path(),
                                    batch_codec,
                                    block_id,
                                    partition_id,
                                    &mut presorted[partition_id],
                                    buf,
                                )
                                .context("Could not flush buffer")?;
                                assert!(buf.is_empty(), "flush_buffer did not empty the buffer");
                                pl.update_with_count(buf_len);
                            }
                            buf.push(((src, dst), label));
                        }

                        // Flush whatever is left in every partition buffer.
                        for (partition_id, (partition_runs, mut buf)) in presorted
                            .iter_mut()
                            .zip(unsorted_buffers)
                            .enumerate()
                        {
                            let buf_len = buf.len();
                            super::par_sort_pairs::flush_buffer(
                                presort_tmp_dir.path(),
                                batch_codec,
                                block_id,
                                partition_id,
                                partition_runs,
                                &mut buf,
                            )
                            .context("Could not flush buffer at the end")?;
                            assert!(buf.is_empty(), "flush_buffer did not empty the buffer");
                            pl.update_with_count(buf_len);
                        }
                        Ok(presorted)
                    })();
                    // SAFETY: `block_id` comes from `enumerate`, so every spawned
                    // thread writes a distinct slot exactly once. The slots are only
                    // read after `std::thread::scope` has joined all threads.
                    unsafe {
                        slots[block_id].set(outcome);
                    }
                });
            }
        });

        // Surface the first worker error (in block order), if any.
        let presorted_blocks = presorted_blocks
            .into_iter()
            .collect::<Result<Vec<_>>>()?;

        // Concatenate, partition by partition, the runs produced by every block.
        let partitioned_presorted_pairs = presorted_blocks.into_par_iter().reduce(
            || (0..num_partitions).map(|_| Vec::new()).collect(),
            |mut acc: Vec<Vec<CodecIter<C>>>,
             other: Vec<Vec<CodecIter<C>>>|
             -> Vec<Vec<CodecIter<C>>> {
                assert_eq!(acc.len(), num_partitions);
                assert_eq!(other.len(), num_partitions);
                for (into, from) in acc.iter_mut().zip(other) {
                    into.extend(from);
                }
                acc
            },
        );
        pl.done();

        let boundaries: Box<[usize]> = (0..=num_partitions)
            .map(|i| (i * num_nodes_per_partition).min(num_nodes))
            .collect();
        let iters: Box<[_]> = partitioned_presorted_pairs
            .into_iter()
            .map(|partition| KMergeIters::new(partition))
            .collect();
        Ok(SplitIters::new(boundaries, iters))
    }
}