use super::{BitReader, BitWriter};
use crate::traits::SortedIterator;
use crate::utils::{ArcMmapHelper, MmapHelper, Triple};
use crate::{
traits::{BitDeserializer, BitSerializer},
utils::{BatchCodec, humanize},
};
use std::sync::Arc;
use anyhow::{Context, Result};
use dsi_bitstream::prelude::*;
use mmap_rs::MmapFlags;
use rdst::*;
/// A batch codec that stores labelled arcs as gaps.
///
/// Each source is written as the gap from the previous source, using the code
/// selected by `SRC_CODE`. Each destination is written as the gap from the
/// previous destination of the *same* source, restarting from zero whenever
/// the source changes, using the code selected by `DST_CODE`. Labels are
/// written with `serializer` and read back with `deserializer`.
#[derive(Clone, Debug)]
pub struct GapsCodec<
    E: Endianness = NE,
    S: BitSerializer<E, BitWriter<E>> = (),
    D: BitDeserializer<E, BitReader<E>, DeserType = S::SerType> + Clone = (),
    const SRC_CODE: usize = { dsi_bitstream::dispatch::code_consts::GAMMA },
    const DST_CODE: usize = { dsi_bitstream::dispatch::code_consts::DELTA },
> where
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
    /// Writes the label attached to each arc.
    pub serializer: S,
    /// Reads labels back; a clone of it is handed to every decoded batch.
    pub deserializer: D,
    /// Binds the codec to the endianness `E` without storing a value of it.
    pub _marker: core::marker::PhantomData<E>,
}
impl<E, S, D, const SRC_CODE: usize, const DST_CODE: usize> GapsCodec<E, S, D, SRC_CODE, DST_CODE>
where
    E: Endianness,
    S: BitSerializer<E, BitWriter<E>> + Send + Sync,
    D: BitDeserializer<E, BitReader<E>, DeserType = S::SerType> + Send + Sync + Clone,
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
    /// Builds a codec that writes labels with `serializer` and reads them
    /// back with `deserializer`.
    pub fn new(serializer: S, deserializer: D) -> Self {
        let _marker = core::marker::PhantomData;
        Self {
            serializer,
            deserializer,
            _marker,
        }
    }
}
impl<E, S, D, const SRC_CODE: usize, const DST_CODE: usize> core::default::Default
    for GapsCodec<E, S, D, SRC_CODE, DST_CODE>
where
    E: Endianness,
    S: BitSerializer<E, BitWriter<E>> + Send + Sync + Default,
    D: BitDeserializer<E, BitReader<E>, DeserType = S::SerType> + Send + Sync + Clone + Default,
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
    /// Builds a codec from the default serializer and default deserializer.
    fn default() -> Self {
        let serializer = <S as Default>::default();
        let deserializer = <D as Default>::default();
        Self::new(serializer, deserializer)
    }
}
/// Space statistics for one batch encoded by `GapsCodec`.
///
/// The bit counters cover only the per-arc data. The delta-coded length
/// header written at the start of the batch file is not counted.
#[derive(Debug, Clone, Copy)]
pub struct GapsStats {
    /// Number of `((src, dst), label)` triples in the batch.
    pub total_triples: usize,
    /// Bits spent on source gaps.
    pub src_bits: usize,
    /// Bits spent on destination gaps.
    pub dst_bits: usize,
    /// Bits spent on serialized labels.
    pub labels_bits: usize,
}
/// Human-readable summary of the space taken by an encoded batch.
///
/// Each component (sources, destinations, labels, total) is printed as a
/// size in bytes, formatted with `humanize`, followed by the average number
/// of bits per arc. An empty batch reports `0.000` bits per arc instead of
/// `NaN`.
impl core::fmt::Display for GapsStats {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let total_bits = self.src_bits + self.dst_bits + self.labels_bits;
        // Average bits per arc. Without the guard, an empty batch would
        // compute 0.0 / 0.0 and print NaN.
        let per_arc = |bits: usize| -> f64 {
            if self.total_triples == 0 {
                0.0
            } else {
                bits as f64 / self.total_triples as f64
            }
        };
        write!(
            f,
            "src: {}B ({:.3} bits / arc), dst: {}B ({:.3} bits / arc), labels: {}B ({:.3} bits / arc), total: {}B ({:.3} bits / arc)",
            humanize(self.src_bits as f64 / 8.0),
            per_arc(self.src_bits),
            humanize(self.dst_bits as f64 / 8.0),
            per_arc(self.dst_bits),
            humanize(self.labels_bits as f64 / 8.0),
            per_arc(self.labels_bits),
            humanize(total_bits as f64 / 8.0),
            per_arc(total_bits),
        )
    }
}
impl<E, S, D, const SRC_CODE: usize, const DST_CODE: usize> BatchCodec
    for GapsCodec<E, S, D, SRC_CODE, DST_CODE>
where
    E: Endianness,
    S: BitSerializer<E, BitWriter<E>> + Send + Sync,
    D: BitDeserializer<E, BitReader<E>, DeserType = S::SerType> + Send + Sync + Clone,
    S::SerType: Send + Sync + Copy + 'static + core::fmt::Debug,
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
    type Label = S::SerType;
    type DecodedBatch = GapsIter<E, D, SRC_CODE, DST_CODE>;
    type EncodedBatchStats = GapsStats;

    /// Radix-sorts `batch` in place, then writes it to `path` through
    /// `encode_sorted_batch`.
    fn encode_batch(
        &self,
        path: impl AsRef<std::path::Path>,
        batch: &mut [((usize, usize), Self::Label)],
    ) -> Result<(usize, Self::EncodedBatchStats)> {
        let sort_start = std::time::Instant::now();
        Triple::cast_batch_mut(batch).radix_sort_unstable();
        log::debug!("Sorted {} arcs in {:?}", batch.len(), sort_start.elapsed());
        self.encode_sorted_batch(path, batch)
    }

    /// Writes an already-sorted `batch` to `path`.
    ///
    /// File layout: the number of triples (delta code), then for each triple
    /// the source gap, the destination gap and the serialized label. Returns
    /// the number of bits spent on triples (excluding the length header)
    /// together with a per-component breakdown.
    fn encode_sorted_batch(
        &self,
        path: impl AsRef<std::path::Path>,
        batch: &[((usize, usize), Self::Label)],
    ) -> Result<(usize, Self::EncodedBatchStats)> {
        debug_assert!(Triple::cast_batch(batch).is_sorted());
        let out_path = path.as_ref();
        let mut writer = buf_bit_writer::from_path::<E, usize>(out_path).with_context(|| {
            format!(
                "Could not create BatchIterator temporary file {}",
                out_path.display()
            )
        })?;
        writer
            .write_delta(batch.len() as u64)
            .context("Could not write length")?;

        let mut src_bits: usize = 0;
        let mut dst_bits: usize = 0;
        let mut labels_bits: usize = 0;
        let mut last_src = 0;
        let mut last_dst = 0;
        for &((src, dst), label) in batch {
            src_bits += ConstCode::<SRC_CODE>
                .write(&mut writer, (src - last_src) as u64)
                .with_context(|| format!("Could not write {src} after {last_src}"))?;
            // Destination gaps are relative to the previous arc of the same
            // source; a new source restarts them from zero.
            if src != last_src {
                last_dst = 0;
            }
            dst_bits += ConstCode::<DST_CODE>
                .write(&mut writer, (dst - last_dst) as u64)
                .with_context(|| format!("Could not write {dst} after {last_dst}"))?;
            labels_bits += self
                .serializer
                .serialize(&label, &mut writer)
                .context("Could not serialize label")?;
            last_src = src;
            last_dst = dst;
        }
        writer.flush().context("Could not flush stream")?;

        let stats = GapsStats {
            total_triples: batch.len(),
            src_bits,
            dst_bits,
            labels_bits,
        };
        Ok((src_bits + dst_bits + labels_bits, stats))
    }

    /// Memory-maps the batch at `path` and returns an iterator over its
    /// triples, after reading the length header.
    fn decode_batch(&self, path: impl AsRef<std::path::Path>) -> Result<Self::DecodedBatch> {
        let in_path = path.as_ref();
        let backend = ArcMmapHelper(Arc::new(
            MmapHelper::mmap(
                in_path,
                MmapFlags::TRANSPARENT_HUGE_PAGES | MmapFlags::SEQUENTIAL,
            )
            .with_context(|| format!("Could not mmap {}", in_path.display()))?,
        ));
        let mut stream = <BufBitReader<E, _>>::new(MemWordReader::new(backend));
        // The first value in the file is the number of encoded triples.
        let len = stream.read_delta().context("Could not read length")? as usize;
        Ok(GapsIter {
            deserializer: self.deserializer.clone(),
            stream,
            len,
            current: 0,
            prev_src: 0,
            prev_dst: 0,
        })
    }
}
/// Iterator over the labelled arcs of a batch written by `GapsCodec`.
///
/// It yields `((src, dst), label)` triples in the order they were encoded,
/// undoing the gap coding as it goes.
///
/// NOTE(review): the default `DST_CODE` here is `GAMMA`, whereas `GapsCodec`
/// defaults to `DELTA`. As a result, the bare `GapsIter` type is not the
/// decoded-batch type of a default `GapsCodec`. Confirm this is intended.
#[derive(Clone, Debug)]
pub struct GapsIter<
    E: Endianness = NE,
    D: BitDeserializer<E, BitReader<E>> = (),
    const SRC_CODE: usize = { dsi_bitstream::dispatch::code_consts::GAMMA },
    const DST_CODE: usize = { dsi_bitstream::dispatch::code_consts::GAMMA },
> where
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
    /// Decodes the label that follows each pair of gaps.
    deserializer: D,
    /// Bit reader over the encoded batch (a memory map, when built by
    /// `GapsCodec::decode_batch`).
    stream: BitReader<E>,
    /// Total number of triples in the batch, read from the file header.
    len: usize,
    /// Number of triples yielded so far.
    current: usize,
    /// Last decoded source; base for the next source gap.
    prev_src: usize,
    /// Last decoded destination; base for the next destination gap.
    prev_dst: usize,
}
// SAFETY: `GapsIter` yields triples in exactly the order in which they were
// encoded. `GapsCodec::encode_batch` radix-sorts a batch before encoding it.
// `GapsCodec::encode_sorted_batch` requires sorted input, but checks it only
// with a `debug_assert!`, so in release builds sortedness relies on its
// callers honouring that contract.
unsafe impl<E, D, const SRC_CODE: usize, const DST_CODE: usize> SortedIterator
    for GapsIter<E, D, SRC_CODE, DST_CODE>
where
    E: Endianness,
    D: BitDeserializer<E, BitReader<E>>,
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
}
impl<
E: Endianness,
D: BitDeserializer<E, BitReader<E>>,
const SRC_CODE: usize,
const DST_CODE: usize,
> Iterator for GapsIter<E, D, SRC_CODE, DST_CODE>
where
BitReader<E>: BitRead<E> + CodesRead<E>,
BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
type Item = ((usize, usize), D::DeserType);
fn next(&mut self) -> Option<Self::Item> {
if self.current >= self.len {
return None;
}
let src_gap = ConstCode::<SRC_CODE>.read(&mut self.stream).ok()?;
let dst_gap = ConstCode::<DST_CODE>.read(&mut self.stream).ok()?;
let label = self.deserializer.deserialize(&mut self.stream).ok()?;
self.prev_src += src_gap as usize;
if src_gap != 0 {
self.prev_dst = 0;
}
self.prev_dst += dst_gap as usize;
self.current += 1;
Some(((self.prev_src, self.prev_dst), label))
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.len(), Some(self.len()))
}
}
impl<E, D, const SRC_CODE: usize, const DST_CODE: usize> ExactSizeIterator
    for GapsIter<E, D, SRC_CODE, DST_CODE>
where
    E: Endianness,
    D: BitDeserializer<E, BitReader<E>>,
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
    /// Number of triples that have not been yielded yet.
    fn len(&self) -> usize {
        let Self {
            len: total,
            current: consumed,
            ..
        } = self;
        total - consumed
    }
}