use std::io::{Read, Write};
use std::path::PathBuf;
use anyhow::{ensure, Context, Result};
use dsi_progress_logger::ProgressLog;
use rayon::prelude::*;
use super::ParallelDeduplicatingExternalSorter;
/// Divisor used by `BytestringExternalSorter::buffer_capacity` to turn
/// `buffer_size` into a number of strings.
// NOTE(review): presumably an estimate of the typical string length in bytes — confirm.
const AVERAGE_STRING_LENGTH: usize = 64;
/// An owned, fixed-length sequence of bytes (a boxed byte slice).
type Bytestring = Box<[u8]>;
/// Sorter for [`Bytestring`]s, configured only by a buffer size.
#[derive(Copy, Clone)]
struct BytestringExternalSorter {
/// Divided by [`AVERAGE_STRING_LENGTH`] to obtain the buffer capacity in strings.
// NOTE(review): presumably a memory budget in bytes — confirm against callers.
buffer_size: usize,
}
/// Length value that terminates a serialized stream of strings.
///
/// It can never be mistaken for a real length because `serialize` rejects
/// strings that are exactly `u32::MAX` bytes long.
const END_OF_STREAM_MARKER: u32 = u32::MAX;

impl ParallelDeduplicatingExternalSorter<Bytestring> for BytestringExternalSorter {
    /// Returns `buffer_size / AVERAGE_STRING_LENGTH` (rounded up), itself rounded up
    /// to the next power of two.
    #[inline(always)]
    fn buffer_capacity(&self) -> usize {
        self.buffer_size
            .div_ceil(AVERAGE_STRING_LENGTH)
            .next_power_of_two()
    }

    /// Sorts `vec` in place, in lexicographic byte order.
    ///
    /// Strings are first distributed into one bucket per value of their first two
    /// bytes, then every bucket is sorted independently (in parallel), and finally
    /// the buckets are concatenated back into `vec` in bucket order.
    ///
    /// This implementation never returns an error.
    fn sort_vec(&self, vec: &mut Vec<Bytestring>) -> Result<()> {
        // One bucket per possible value of the two-byte prefix.
        const NUM_PARTITIONS: usize = 1 << 16;

        // Missing bytes count as zero. This is consistent with lexicographic order:
        // a string that is a strict prefix of another sorts before it, and zero is
        // the smallest possible byte, so short strings land in the first bucket
        // that can contain their extensions.
        let byte_at = |string: &Bytestring, index: usize| -> usize {
            string.get(index).copied().map_or(0, usize::from)
        };

        let capacity = vec.len().div_ceil(NUM_PARTITIONS);
        let mut partitions: Vec<Vec<Bytestring>> = (0..NUM_PARTITIONS)
            .map(|_| Vec::with_capacity(capacity))
            .collect();

        // `drain` empties `vec` but keeps its allocation for the refill below.
        for string in vec.drain(..) {
            let partition_id = (byte_at(&string, 0) << 8) | byte_at(&string, 1);
            partitions[partition_id].push(string);
        }

        partitions
            .par_iter_mut()
            .for_each(|partition| partition.sort_unstable());

        for partition in partitions {
            vec.extend(partition);
        }
        Ok(())
    }

    /// Writes `strings` to a new zstd-compressed file at `path`.
    ///
    /// Each string is stored as its length (a native-endian `u32`) followed by its
    /// bytes; the stream ends with [`END_OF_STREAM_MARKER`] in place of a length.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be created (including when it already exists,
    /// as it is opened with `create_new`), if the encoder cannot be set up, if a
    /// string is `u32::MAX` bytes or longer, or if writing/flushing fails.
    fn serialize(path: PathBuf, strings: impl Iterator<Item = Bytestring>) -> Result<()> {
        const COMPRESSION_LEVEL: i32 = 3;
        let write_error = || format!("Could not write string to {}", path.display());

        let file = std::fs::File::create_new(&path)
            .with_context(|| format!("Could not create {}", path.display()))?;
        let mut encoder = zstd::stream::write::Encoder::new(file, COMPRESSION_LEVEL)
            .with_context(|| format!("Could not create ZSTD encoder for {}", path.display()))?;

        for string in strings {
            let len = u32::try_from(string.len()).context("String is 2^32 bytes or longer")?;
            // This length is reserved as the end-of-stream marker.
            ensure!(len != END_OF_STREAM_MARKER, "String is 2^32 -1 bytes long");
            encoder
                .write_all(&len.to_ne_bytes())
                .with_context(write_error)?;
            encoder.write_all(&string).with_context(write_error)?;
        }

        encoder
            .write_all(&END_OF_STREAM_MARKER.to_ne_bytes())
            .with_context(write_error)?;
        // `finish` writes the zstd epilogue; without it the file would be truncated.
        encoder
            .finish()
            .with_context(|| format!("Could not flush to {}", path.display()))?;
        Ok(())
    }

    /// Lazily reads back the strings written by `serialize` from the file at `path`.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be opened or the zstd decoder cannot be created.
    ///
    /// # Panics
    ///
    /// The returned iterator panics if a length or a string cannot be read in
    /// full (I/O error, corrupted or truncated file), because its item type
    /// leaves no way to report an error. The panic message names the file and
    /// the underlying error.
    fn deserialize(path: PathBuf) -> Result<impl Iterator<Item = Bytestring>> {
        let file = std::fs::File::open(&path)
            .with_context(|| format!("Could not open {}", path.display()))?;
        let mut decoder = zstd::stream::read::Decoder::new(file)
            .with_context(|| format!("Could not decompress sorted file {}", path.display()))?;

        Ok(std::iter::from_fn(move || {
            let mut len_bytes = [0u8; 4];
            decoder.read_exact(&mut len_bytes).unwrap_or_else(|e| {
                panic!("Could not read string size from {}: {e}", path.display())
            });
            let len = u32::from_ne_bytes(len_bytes);
            if len == END_OF_STREAM_MARKER {
                return None;
            }

            let len = usize::try_from(len).expect("u32 string length does not fit in usize");
            let mut string = vec![0u8; len].into_boxed_slice();
            decoder
                .read_exact(&mut string)
                .unwrap_or_else(|e| panic!("Could not read string from {}: {e}", path.display()));
            Some(string)
        }))
    }
}
/// Runs [`ParallelDeduplicatingExternalSorter::par_sort_dedup`] on `iter` using a
/// [`BytestringExternalSorter`] configured with `buffer_size`.
///
/// `pl` is forwarded unchanged to `par_sort_dedup`.
///
/// # Errors
///
/// Returns whatever error `par_sort_dedup` returns.
// NOTE(review): judging by its name, `par_sort_dedup` yields the strings sorted
// and without duplicates — confirm against the trait definition.
pub fn par_sort_strings<Iter: ParallelIterator<Item = Bytestring>>(
    iter: Iter,
    pl: impl ProgressLog + Send,
    buffer_size: usize,
) -> Result<impl Iterator<Item = Bytestring>> {
    let sorter = BytestringExternalSorter { buffer_size };
    sorter.par_sort_dedup(iter, pl)
}