use std::{
fs::File,
io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write},
sync::Mutex,
};
use clap::builder::PossibleValue;
use log::{info, trace};
use super::*;
/// Strategy used to split the key hashes into shards.
///
/// The variant selects which routine `PtrHash::shards` dispatches to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, MemSize)]
#[cfg_attr(feature = "epserde", derive(epserde::prelude::Epserde))]
#[cfg_attr(feature = "epserde", repr(C))]
#[cfg_attr(feature = "epserde", zero_copy)]
pub enum Sharding {
/// No sharding: all hashes are collected into a single in-memory vector.
/// This is the default.
#[default]
None,
/// In-memory sharding: the keys are iterated once per shard, and each pass
/// keeps only the hashes that belong to that shard.
Memory,
/// On-disk sharding: hashes are written to temporary files per shard and read
/// back one shard at a time. Implemented as the hybrid strategy with a memory
/// budget of `usize::MAX`.
Disk,
/// Hybrid sharding: like `Disk`, but the payload is a budget in bytes that is
/// divided by the hash size and the average shard size to decide how many
/// shards are written to disk per pass over the keys.
Hybrid(usize),
}
/// Makes [`Sharding`] selectable as a command-line value through `clap`.
impl clap::ValueEnum for Sharding {
    /// Returns the variants offered on the command line.
    ///
    /// `Hybrid` is listed with a fixed payload of `1 << 37`.
    fn value_variants<'a>() -> &'a [Self] {
        const CLI_VARIANTS: &[Sharding] = &[
            Sharding::None,
            Sharding::Memory,
            Sharding::Disk,
            Sharding::Hybrid(1 << 37),
        ];
        CLI_VARIANTS
    }

    /// Returns the lowercase command-line name of this variant.
    ///
    /// Every variant has a name, so the result is always `Some`; the payload
    /// of `Hybrid` is ignored.
    fn to_possible_value<'a>(&self) -> Option<PossibleValue> {
        let cli_name = match self {
            Self::None => "none",
            Self::Memory => "memory",
            Self::Disk => "disk",
            Self::Hybrid(_) => "hybrid",
        };
        Some(PossibleValue::new(cli_name))
    }
}
impl<Key: KeyT, BF: BucketFn, F: Packed, Hx: Hasher<Key>> PtrHash<Key, BF, F, Hx, Vec<u8>> {
/// Returns an iterator over the shards, each shard being a vector of key hashes.
///
/// Dispatches on `self.params.sharding`:
/// - `Sharding::None`: one vector holding all hashes,
/// - `Sharding::Memory`: one pass over the keys per shard,
/// - `Sharding::Disk`: the on-disk routine with a budget of `usize::MAX`,
/// - `Sharding::Hybrid(mem)`: the on-disk routine with a budget of `mem` bytes.
///
/// `keys` is taken by value and handed straight to the selected routine; each
/// routine clones it itself if it needs more than one pass.
pub(crate) fn shards<'a>(
    &'a self,
    keys: impl ParallelIterator<Item = impl Borrow<Key>> + Clone + 'a,
) -> Box<dyn Iterator<Item = Vec<Hx::H>> + 'a> {
    // Exactly one arm runs, so `keys` can be moved into it without cloning.
    match self.params.sharding {
        Sharding::None => self.no_sharding(keys),
        Sharding::Memory => self.shard_keys_in_memory(keys),
        Sharding::Disk => self.shard_keys_hybrid(usize::MAX, keys),
        Sharding::Hybrid(mem) => self.shard_keys_hybrid(mem, keys),
    }
}
/// Hashes every key and returns all hashes as one single shard.
///
/// The hashes are collected eagerly, before this function returns; the
/// returned iterator yields that one vector and then ends.
fn no_sharding<'a>(
    &'a self,
    keys: impl ParallelIterator<Item = impl Borrow<Key>> + Clone + 'a,
) -> Box<dyn Iterator<Item = Vec<Hx::H>> + 'a> {
    trace!("No sharding: collecting all {} hashes in memory.", self.n);
    let timer = std::time::Instant::now();
    let all_hashes: Vec<_> = keys.map(|k| self.hash_key(k.borrow())).collect();
    log_duration("collect hash", timer);
    // A one-element iterator: the whole key set is the only shard.
    Box::new(Some(all_hashes).into_iter())
}
/// Produces the shards one at a time, re-iterating all keys for each shard.
///
/// The returned iterator is lazy: advancing it to shard `i` hashes a clone of
/// `keys` and keeps only the hashes for which `self.shard(hash) == i`.
fn shard_keys_in_memory<'a>(
    &'a self,
    keys: impl ParallelIterator<Item = impl Borrow<Key>> + Clone + 'a,
) -> Box<dyn Iterator<Item = Vec<Hx::H>> + 'a> {
    trace!(
        "In-memory sharding: iterate keys once for each of {} shards, each of ~{} keys.",
        self.shards,
        self.n / self.shards
    );
    // Builds the hash vector of a single shard by scanning every key.
    let collect_shard = move |shard: usize| {
        trace!("Shard {shard:>3}/{:3}\r", self.shards);
        let timer = std::time::Instant::now();
        let shard_hashes: Vec<_> = keys
            .clone()
            .filter_map(|key| {
                let h = self.hash_key(key.borrow());
                if self.shard(h) == shard {
                    Some(h)
                } else {
                    None
                }
            })
            .collect();
        trace!(
            "Shard {shard:>3}/{:3}: {} keys",
            self.shards,
            shard_hashes.len()
        );
        log_duration("collect shrd", timer);
        shard_hashes
    };
    Box::new((0..self.shards).map(collect_shard))
}
/// Produces the shards by spilling hashes to temporary files, several shards per pass.
///
/// `mem` is a budget in bytes. It is divided by the size of one hash and by the
/// average number of keys per shard to obtain `shards_on_disk`, the number of
/// shards handled in one pass. For every group of `shards_on_disk` consecutive
/// shards the returned iterator lazily:
///
/// 1. creates a fresh temporary directory with one file per shard of the group,
/// 2. hashes a clone of `keys` in parallel and appends each hash that falls in
///    the group to its shard's file (hashes of other shards are discarded),
/// 3. yields the shards of the group one by one, reading each file back into a
///    `Vec`.
///
/// # Panics
///
/// Panics if the budget is too small to hold a single average shard, and on any
/// I/O failure while creating, writing, or reading the temporary files.
fn shard_keys_hybrid<'a>(
    &'a self,
    mem: usize,
    keys: impl ParallelIterator<Item = impl Borrow<Key>> + Clone + 'a,
) -> Box<dyn Iterator<Item = Vec<Hx::H>> + 'a> {
    let total_shards = self.shards;
    // Average shard size; the `max(1)` keeps the division defined for zero shards.
    let keys_per_shard = self.n / total_shards.max(1);
    // With fewer keys than shards the average rounds down to 0. Treat it as 1 so
    // the budget computation cannot divide by zero.
    let shards_on_disk = mem / std::mem::size_of::<Hx::H>() / keys_per_shard.max(1);
    assert!(
        shards_on_disk > 0,
        "Each shard takes more than the provided memory."
    );
    if mem < usize::MAX {
        info!("Hybrid sharding: writing hashes to disk for {shards_on_disk} shards at a time, for total {} shards each of ~{} keys.", total_shards, keys_per_shard);
    } else {
        info!(
            "On-disk sharding: writing hashes to disk for all {} shards at a time, each of ~{} keys.",
            total_shards, keys_per_shard
        );
    }
    let it = (0..total_shards)
        .step_by(shards_on_disk)
        .flat_map(move |first_shard| {
            let temp_dir = tempfile::TempDir::new().unwrap();
            info!("TMP PATH: {:?}", temp_dir.path());
            let shard_range = first_shard..(first_shard + shards_on_disk).min(total_shards);
            info!("Writing keys for shards {shard_range:?}/{}", total_shards);
            let start = std::time::Instant::now();

            // One (writer, element count) pair per shard of this group, shared by
            // all worker threads through a mutex.
            let writers = shard_range
                .clone()
                .map(|shard| {
                    Mutex::new((
                        BufWriter::new(
                            File::options()
                                .read(true)
                                .write(true)
                                .create(true)
                                // The directory is fresh, so this is a no-op; it
                                // states explicitly that the file starts empty.
                                .truncate(true)
                                .open(temp_dir.path().join(format!("{}.tmp", shard)))
                                .unwrap(),
                        ),
                        0,
                    ))
                })
                .collect_vec();

            // Each rayon worker gets its own buffer per shard; the buffers write
            // their remaining content to the shared file when they are dropped.
            let init = || writers.iter().map(ThreadLocalBuf::new).collect_vec();
            keys.clone()
                .map(|key| self.hash_key(key.borrow()))
                .for_each_init(init, |bufs, h| {
                    let shard = self.shard(h);
                    if shard_range.contains(&shard) {
                        bufs[shard - shard_range.start].push(h);
                    }
                });
            let start = log_duration("Writing files", start);

            // Flush every writer and rewind its file so it can be read back.
            let files = writers
                .into_iter()
                .map(|w| {
                    let (mut w, cnt) = w.into_inner().unwrap();
                    w.flush().unwrap();
                    let mut file = w.into_inner().unwrap();
                    file.seek(SeekFrom::Start(0)).unwrap();
                    (file, cnt)
                })
                .collect_vec();
            log_duration("Flushing writers", start);

            files.into_iter().map(move |(f, cnt)| {
                // Referencing `temp_dir` makes this `move` closure own it, so the
                // directory is removed only after the shard files have been read
                // (when the iterator is dropped) rather than while they are
                // still pending.
                let _keep_alive = &temp_dir;
                let start = std::time::Instant::now();
                let mut v = vec![Hx::H::default(); cnt];
                let mut reader = BufReader::new(f);
                // SAFETY: `u8` has alignment 1, so the prefix and suffix are empty
                // and `data` covers exactly the bytes of `v`. Filling it from the
                // file reproduces bytes that were written from `Hx::H` values by
                // `ThreadLocalBuf`.
                // NOTE(review): assumes `Hx::H` is plain data without padding.
                let (pre, data, post) = unsafe { v.align_to_mut::<u8>() };
                assert!(pre.is_empty());
                assert!(post.is_empty());
                Read::read_exact(&mut reader, data).unwrap();
                log_duration("Read shard", start);
                v
            })
        });
    Box::new(it)
}
}
/// Staging buffer that batches elements before appending them to a shared file.
///
/// Elements are collected in `buf` and written as raw bytes to the shared writer
/// whenever the buffer is full, and once more when the value is dropped.
struct ThreadLocalBuf<'a, H> {
    /// Elements not yet written to the file.
    buf: Vec<H>,
    /// Shared writer, paired with the number of elements written to it so far.
    file: &'a Mutex<(BufWriter<File>, usize)>,
}

impl<'a, H> ThreadLocalBuf<'a, H> {
    /// Number of elements buffered by [`ThreadLocalBuf::new`] before a flush.
    const DEFAULT_CAPACITY: usize = 1 << 28;

    /// Creates a buffer for `file` with the default capacity.
    fn new(file: &'a Mutex<(BufWriter<File>, usize)>) -> Self {
        Self::with_capacity(file, Self::DEFAULT_CAPACITY)
    }

    /// Creates a buffer for `file` that flushes once it holds `capacity` elements.
    fn with_capacity(file: &'a Mutex<(BufWriter<File>, usize)>, capacity: usize) -> Self {
        Self {
            buf: Vec::with_capacity(capacity),
            file,
        }
    }

    /// Appends `h`, flushing to the file when the buffer reaches its capacity.
    fn push(&mut self, h: H) {
        self.buf.push(h);
        if self.buf.len() == self.buf.capacity() {
            self.flush();
        }
    }

    /// Writes all buffered elements to the shared file and empties the buffer.
    ///
    /// Also adds the number of written elements to the shared counter.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned or the write fails.
    fn flush(&mut self) {
        // Nothing to write: do not contend for (or panic on) the shared lock.
        if self.buf.is_empty() {
            return;
        }
        let mut file = self.file.lock().unwrap();
        // SAFETY: `u8` has alignment 1, so the prefix and suffix are empty and
        // `bytes` covers exactly the memory of the buffered elements.
        // NOTE(review): assumes `H` is plain data without padding bytes.
        let (pre, bytes, post) = unsafe { self.buf.align_to::<u8>() };
        assert!(pre.is_empty());
        assert!(post.is_empty());
        file.0.write_all(bytes).unwrap();
        file.1 += self.buf.len();
        self.buf.clear();
    }
}

/// Writes whatever is still buffered when the buffer goes out of scope.
impl<'a, H> Drop for ThreadLocalBuf<'a, H> {
    fn drop(&mut self) {
        // A panic raised from `drop` while the thread is already unwinding aborts
        // the process, so do not attempt the fallible flush in that case.
        if std::thread::panicking() {
            return;
        }
        self.flush();
    }
}