use std::num::NonZeroUsize;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use anyhow::{Context, Result, ensure};
use crossbeam_queue::SegQueue;
use dsi_progress_logger::{ProgressLog, concurrent_progress_logger};
use rayon::prelude::*;
use crate::utils::DefaultBatchCodec;
use super::MemoryUsage;
use super::sort_pairs::KMergeIters;
use super::{BatchCodec, CodecIter};
use crate::utils::SplitIters;
/// Sorts `(source, destination)` pairs in parallel, partitioning them into
/// contiguous source-node ranges so each partition can be consumed
/// independently after sorting.
pub struct ParSortPairs {
    // Total number of nodes; every pair's source id must be strictly below this.
    num_nodes: usize,
    // Optional hint for how many pairs will arrive; only used to size the
    // progress logger's expected-update count.
    expected_num_pairs: Option<usize>,
    // Number of node-range partitions the sorted output is split into.
    num_partitions: NonZeroUsize,
    // Memory budget from which per-thread, per-partition batch sizes are derived.
    memory_usage: MemoryUsage,
}
impl ParSortPairs {
    /// Sorts an infallible parallel stream of `(src, dst)` pairs.
    ///
    /// Convenience wrapper around [`Self::try_sort`] using an error type
    /// that can never be constructed.
    pub fn sort(
        &self,
        pairs: impl ParallelIterator<Item = (usize, usize)>,
    ) -> Result<SplitIters<impl IntoIterator<Item = (usize, usize), IntoIter: Clone + Send + Sync>>>
    {
        self.try_sort::<std::convert::Infallible>(pairs.map(Ok))
    }

    /// Sorts a fallible parallel stream of `(src, dst)` pairs,
    /// short-circuiting on the first error.
    ///
    /// Internally attaches a unit label to every pair so the labeled sorting
    /// machinery can be reused, then strips the labels off again.
    pub fn try_sort<E: Into<anyhow::Error>>(
        &self,
        pairs: impl ParallelIterator<Item = Result<(usize, usize), E>>,
    ) -> Result<SplitIters<impl IntoIterator<Item = (usize, usize), IntoIter: Clone + Send + Sync>>>
    {
        // Decorate each pair with a `()` label for the labeled sorter.
        let labeled = pairs.map(|pair| -> Result<_> {
            match pair {
                Ok((src, dst)) => Ok(((src, dst), ())),
                Err(e) => Err(e.into()),
            }
        });
        let split = self.try_sort_labeled(&DefaultBatchCodec::default(), labeled)?;
        // Remove the unit labels before handing the iterators back.
        let stripped: Vec<_> = Vec::from(split.iters)
            .into_iter()
            .map(|iter| iter.into_iter().map(|(pair, _)| pair))
            .collect();
        Ok(SplitIters::new(
            split.boundaries,
            stripped.into_boxed_slice(),
        ))
    }
}
impl ParSortPairs {
    /// Creates a sorter for graphs with `num_nodes` nodes.
    ///
    /// Defaults: no expected pair count, one partition per available CPU,
    /// and the default [`MemoryUsage`] budget.
    ///
    /// # Errors
    ///
    /// Returns an error if `num_cpus::get()` reports zero CPUs.
    pub fn new(num_nodes: usize) -> Result<Self> {
        Ok(Self {
            num_nodes,
            expected_num_pairs: None,
            num_partitions: NonZeroUsize::new(num_cpus::get()).context("zero CPUs")?,
            memory_usage: MemoryUsage::default(),
        })
    }

    /// Builder-style setter: hints how many pairs will be sorted.
    /// Only used to size the progress logger's expected-update count.
    pub fn expected_num_pairs(self, expected_num_pairs: usize) -> Self {
        Self {
            expected_num_pairs: Some(expected_num_pairs),
            ..self
        }
    }

    /// Builder-style setter: number of node-range partitions in the output.
    pub fn num_partitions(self, num_partitions: NonZeroUsize) -> Self {
        Self {
            num_partitions,
            ..self
        }
    }

    /// Builder-style setter: memory budget for in-memory sort batches.
    pub fn memory_usage(self, memory_usage: MemoryUsage) -> Self {
        Self {
            memory_usage,
            ..self
        }
    }

    /// Sorts an infallible parallel stream of labeled pairs; see
    /// [`Self::try_sort_labeled`] for details.
    pub fn sort_labeled<C: BatchCodec, P: ParallelIterator<Item = ((usize, usize), C::Label)>>(
        &self,
        batch_codec: &C,
        pairs: P,
    ) -> Result<
        SplitIters<
            impl IntoIterator<Item = ((usize, usize), C::Label), IntoIter: Clone + Send + Sync>
                + use<C, P>,
        >,
    > {
        self.try_sort_labeled::<C, std::convert::Infallible, _>(batch_codec, pairs.map(Ok))
    }

    /// Sorts a fallible parallel stream of `((src, dst), label)` pairs into
    /// `num_partitions` contiguous source-node ranges.
    ///
    /// Each rayon worker accumulates pairs into per-partition buffers; full
    /// buffers are encoded as sorted batches in a temporary directory via
    /// `batch_codec`. The result holds, per partition, a k-way merge over all
    /// of that partition's batches, together with the node-id boundaries.
    ///
    /// # Errors
    ///
    /// Fails on the first input error, on a source id `>= num_nodes`, or on
    /// any failure while writing/reading temporary batches.
    pub fn try_sort_labeled<
        C: BatchCodec,
        E: Into<anyhow::Error>,
        P: ParallelIterator<Item = Result<((usize, usize), C::Label), E>>,
    >(
        &self,
        batch_codec: &C,
        pairs: P,
    ) -> Result<
        SplitIters<
            impl IntoIterator<Item = ((usize, usize), C::Label), IntoIter: Clone + Send + Sync>
                + use<C, E, P>,
        >,
    > {
        let unsorted_pairs = pairs;
        let num_partitions = self.num_partitions.into();
        // One buffer per (thread, partition) pair; the memory budget is split
        // evenly across all of them.
        let num_buffers = rayon::current_num_threads() * num_partitions;
        let batch_size = self
            .memory_usage
            .batch_size::<((usize, usize), C::Label)>()
            .div_ceil(num_buffers);
        let num_nodes_per_partition = self.num_nodes.div_ceil(num_partitions);
        let mut pl = concurrent_progress_logger!(
            display_memory = true,
            item_name = "pair",
            local_speed = true,
            expected_updates = self.expected_num_pairs,
        );
        pl.start("Reading and sorting pairs");
        // Hands out a unique id to each freshly created thread state; the id
        // keeps temporary batch file names from colliding across workers.
        let worker_id = AtomicUsize::new(0);
        let presort_tmp_dir =
            tempfile::tempdir().context("Could not create temporary directory")?;
        // Pool of reusable per-thread states. A state's `Drop` impl pushes it
        // back onto this queue (see `SorterThreadState`), so rayon tasks can
        // reuse buffers released by earlier tasks instead of reallocating.
        let sorter_thread_states = Arc::new(SegQueue::<SorterThreadState<C>>::new());
        unsorted_pairs.try_for_each_init(
            || {
                // Reuse a recycled state if one is available, otherwise build a
                // fresh one with empty per-partition buffers.
                let mut state = sorter_thread_states
                    .pop()
                    .unwrap_or_else(|| SorterThreadState {
                        worker_id: worker_id.fetch_add(1, Ordering::Relaxed),
                        unsorted_buffers: (0..num_partitions)
                            .map(|_| Vec::with_capacity(batch_size))
                            .collect(),
                        sorted_pairs: (0..num_partitions).map(|_| Vec::new()).collect(),
                        queue: None,
                    });
                // Arm the recycle hook: when this state is dropped at the end
                // of the task, it returns itself to the shared queue.
                state.queue = Some(Arc::clone(&sorter_thread_states));
                (pl.clone(), state)
            },
            |(pl, thread_state), pair| -> Result<_> {
                let ((src, dst), label) = pair.map_err(Into::into)?;
                ensure!(
                    src < self.num_nodes,
                    "Expected {} nodes, but got node id {src}",
                    self.num_nodes
                );
                let partition_id = src / num_nodes_per_partition;
                let SorterThreadState {
                    worker_id,
                    sorted_pairs,
                    unsorted_buffers,
                    queue: _,
                } = thread_state;
                let sorted_pairs = &mut sorted_pairs[partition_id];
                let buf = &mut unsorted_buffers[partition_id];
                // Spill the buffer to disk before pushing once it is full, so
                // it never grows past its initial `batch_size` capacity.
                if buf.len() >= buf.capacity() {
                    let buf_len = buf.len();
                    flush_buffer(
                        presort_tmp_dir.path(),
                        batch_codec,
                        *worker_id,
                        partition_id,
                        sorted_pairs,
                        buf,
                    )
                    .context("Could not flush buffer")?;
                    assert!(buf.is_empty(), "flush_buffer did not empty the buffer");
                    pl.update_with_count(buf_len);
                }
                buf.push(((src, dst), label));
                Ok(())
            },
        )?;
        // All worker tuples have been dropped by now, so every state has been
        // pushed back onto the queue; drain it into a plain Vec.
        let sorter_thread_states: Vec<_> = std::iter::repeat(())
            .map_while(|()| sorter_thread_states.pop())
            .collect();
        // Flush each state's remaining (partially filled) buffers, yielding
        // per-thread lists of sorted-batch readers, one list per partition.
        let partitioned_presorted_pairs: Vec<Vec<CodecIter<C>>> = sorter_thread_states
            .into_par_iter()
            .map_with(pl.clone(), |pl, mut thread_state: SorterThreadState<C>| {
                // Move the buffers out of the state so its Drop (with `queue`
                // now `None`) sees only empty vectors and does not panic.
                let mut sorted_pairs = Vec::new();
                std::mem::swap(&mut sorted_pairs, &mut thread_state.sorted_pairs);
                let mut unsorted_buffers = Vec::new();
                std::mem::swap(&mut unsorted_buffers, &mut thread_state.unsorted_buffers);
                let mut partitioned_sorted_pairs = Vec::with_capacity(num_partitions);
                assert_eq!(sorted_pairs.len(), num_partitions);
                assert_eq!(unsorted_buffers.len(), num_partitions);
                for (partition_id, (mut sorted_pairs, mut buf)) in sorted_pairs
                    .into_iter()
                    .zip(unsorted_buffers.into_iter())
                    .enumerate()
                {
                    let buf_len = buf.len();
                    flush_buffer(
                        presort_tmp_dir.path(),
                        batch_codec,
                        thread_state.worker_id,
                        partition_id,
                        &mut sorted_pairs,
                        &mut buf,
                    )
                    .context("Could not flush buffer at the end")?;
                    assert!(buf.is_empty(), "flush_buffer did not empty the buffer");
                    pl.update_with_count(buf_len);
                    partitioned_sorted_pairs.push(sorted_pairs);
                }
                Ok(partitioned_sorted_pairs)
            })
            // Merge the per-thread partition lists: concatenate the batch
            // readers of matching partitions.
            .try_reduce(
                || (0..num_partitions).map(|_| Vec::new()).collect(),
                |mut pair_partitions1: Vec<Vec<CodecIter<C>>>,
                 pair_partitions2: Vec<Vec<CodecIter<C>>>|
                 -> Result<Vec<Vec<CodecIter<C>>>> {
                    assert_eq!(pair_partitions1.len(), num_partitions);
                    assert_eq!(pair_partitions2.len(), num_partitions);
                    for (partition1, partition2) in pair_partitions1
                        .iter_mut()
                        .zip(pair_partitions2.into_iter())
                    {
                        partition1.extend(partition2.into_iter());
                    }
                    Ok(pair_partitions1)
                },
            )?;
        pl.done();
        // Node-id boundaries of the partitions; the final boundary is clamped
        // to `num_nodes` since the last partition may be shorter.
        let boundaries: Vec<usize> = (0..=num_partitions)
            .map(|i| (i * num_nodes_per_partition).min(self.num_nodes))
            .collect();
        // Each partition becomes a single k-way merge over its sorted batches.
        let iters: Vec<_> = partitioned_presorted_pairs
            .into_iter()
            .map(|partition| KMergeIters::new(partition))
            .collect();
        Ok(SplitIters::new(
            boundaries.into_boxed_slice(),
            iters.into_boxed_slice(),
        ))
    }
}
/// Per-worker scratch state for the parallel sort, recycled through a shared
/// queue so rayon tasks can reuse buffers instead of reallocating them.
struct SorterThreadState<C: BatchCodec> {
    // Unique id per state, used to keep this worker's temporary batch file
    // names distinct from other workers'.
    worker_id: usize,
    // For each partition, readers over the batches already sorted and spilled
    // to disk by this worker.
    sorted_pairs: Vec<Vec<CodecIter<C>>>,
    // For each partition, the in-memory buffer of pairs not yet flushed.
    unsorted_buffers: Vec<Vec<((usize, usize), C::Label)>>,
    // When `Some`, `Drop` pushes this state back onto the queue for reuse;
    // when `None`, `Drop` asserts the state was fully drained.
    queue: Option<Arc<SegQueue<Self>>>,
}
impl<C: BatchCodec> SorterThreadState<C> {
    /// Builds a placeholder state: no buffers, no queue handle, and a sentinel
    /// worker id. Used as the swap target when a state's contents are moved
    /// out (e.g. while recycling it through the shared queue).
    fn new_empty() -> Self {
        Self {
            queue: None,
            worker_id: usize::MAX,
            unsorted_buffers: Vec::new(),
            sorted_pairs: Vec::new(),
        }
    }
}
impl<C: BatchCodec> Drop for SorterThreadState<C> {
    fn drop(&mut self) {
        // If we still hold a queue handle, recycle this state: move its
        // contents into a fresh value (leaving an empty placeholder behind)
        // and push that value back for another worker to reuse. Taking the
        // handle first guarantees the recycled copy's own drop skips this
        // branch later.
        if let Some(queue) = self.queue.take() {
            let recycled = std::mem::replace(self, Self::new_empty());
            queue.push(recycled);
        } else {
            // Final drop: by now all buffers must have been drained, either
            // because this is an empty placeholder or because the sort
            // consumed everything.
            assert!(
                self.sorted_pairs.iter().all(|vec| vec.is_empty()),
                "Dropped SorterThreadState without consuming sorted_pairs"
            );
            assert!(
                self.unsorted_buffers.iter().all(|vec| vec.is_empty()),
                "Dropped SorterThreadState without consuming unsorted_buffers"
            );
        }
    }
}
/// Spills `buf` to disk as a sorted batch under `tmp_dir`, appends a decoder
/// over the new batch to `sorted_pairs`, and clears `buf`.
///
/// Batch files are named from the worker id, the partition id, and the
/// running batch count, so concurrent workers never collide on paths.
///
/// # Errors
///
/// Fails if the target path already exists, or if encoding/decoding the
/// batch fails.
pub(crate) fn flush_buffer<C: BatchCodec>(
    tmp_dir: &Path,
    batch_codec: &C,
    worker_id: usize,
    partition_id: usize,
    sorted_pairs: &mut Vec<CodecIter<C>>,
    buf: &mut Vec<((usize, usize), C::Label)>,
) -> Result<()> {
    // The batch index is simply how many batches this worker has already
    // flushed for this partition.
    let path = tmp_dir.join(format!(
        "sorted_batch_{worker_id}_{partition_id}_{}",
        sorted_pairs.len()
    ));
    ensure!(
        !path.exists(),
        "Can't create temporary file {}, it already exists",
        path.display()
    );
    batch_codec
        .encode_batch(&path, buf)
        .with_context(|| format!("Could not write sorted batch to {}", path.display()))?;
    // Immediately open a reader over the freshly written batch so the caller
    // can later k-way-merge all batches of a partition.
    let reader = batch_codec
        .decode_batch(&path)
        .with_context(|| format!("Could not read sorted batch from {}", path.display()))?;
    sorted_pairs.push(reader.into_iter());
    buf.clear();
    Ok(())
}