use super::{BitReader, BitWriter};
use crate::traits::SortedIterator;
use crate::utils::{ArcMmapHelper, MmapHelper, Triple};
use crate::{
traits::{BitDeserializer, BitSerializer},
utils::{BatchCodec, humanize},
};
use std::sync::Arc;
use anyhow::{Context, Result};
use dsi_bitstream::prelude::*;
use mmap_rs::MmapFlags;
use rdst::*;
/// A batch codec that stores labelled arcs grouped by source node.
///
/// For every run of triples sharing a source, the encoder writes the gap from
/// the previous source, the number of triples in the run (the outdegree), and
/// then, for each triple, the gap from the previous destination followed by
/// the serialized label.
///
/// The const parameters `OUTDEGREE_CODE`, `SRC_CODE` and `DST_CODE` select,
/// through `ConstCode`, the code used for outdegrees, source gaps and
/// destination gaps respectively.
#[derive(Clone, Debug)]
pub struct GroupedGapsCodec<
E: Endianness = NE,
S: BitSerializer<E, BitWriter<E>> = (),
D: BitDeserializer<E, BitReader<E>, DeserType = S::SerType> + Clone = (),
const OUTDEGREE_CODE: usize = { dsi_bitstream::dispatch::code_consts::GAMMA },
const SRC_CODE: usize = { dsi_bitstream::dispatch::code_consts::GAMMA },
const DST_CODE: usize = { dsi_bitstream::dispatch::code_consts::DELTA },
> where
BitReader<E>: BitRead<E>,
BitWriter<E>: BitWrite<E>,
{
/// Serializer used to write each arc label to the bitstream.
pub serializer: S,
/// Deserializer that is cloned into every iterator returned by `decode_batch`.
pub deserializer: D,
/// Ties the codec to the endianness `E` without storing a value of that type.
pub _marker: core::marker::PhantomData<E>,
}
impl<E, S, D, const OUTDEGREE_CODE: usize, const SRC_CODE: usize, const DST_CODE: usize>
    GroupedGapsCodec<E, S, D, OUTDEGREE_CODE, SRC_CODE, DST_CODE>
where
    E: Endianness,
    S: BitSerializer<E, BitWriter<E>> + Send + Sync,
    D: BitDeserializer<E, BitReader<E>, DeserType = S::SerType> + Send + Sync + Clone,
    BitReader<E>: BitRead<E>,
    BitWriter<E>: BitWrite<E>,
{
    /// Creates a codec from the label `serializer` and the matching
    /// `deserializer`.
    pub fn new(serializer: S, deserializer: D) -> Self {
        let _marker = core::marker::PhantomData;
        Self {
            _marker,
            serializer,
            deserializer,
        }
    }
}
impl<E, S, D, const OUTDEGREE_CODE: usize, const SRC_CODE: usize, const DST_CODE: usize> Default
    for GroupedGapsCodec<E, S, D, OUTDEGREE_CODE, SRC_CODE, DST_CODE>
where
    E: Endianness,
    S: BitSerializer<E, BitWriter<E>> + Default,
    D: BitDeserializer<E, BitReader<E>, DeserType = S::SerType> + Clone + Default,
    BitReader<E>: BitRead<E>,
    BitWriter<E>: BitWrite<E>,
{
    /// Builds a codec from the default serializer and the default
    /// deserializer.
    fn default() -> Self {
        Self {
            _marker: core::marker::PhantomData,
            serializer: Default::default(),
            deserializer: Default::default(),
        }
    }
}
/// Per-component bit counts collected while encoding one batch with
/// `GroupedGapsCodec`.
#[derive(Debug, Clone, Copy)]
pub struct GroupedGapsStats {
/// Number of triples in the encoded batch.
pub total_triples: usize,
/// Bits written for the outdegree of each source group.
pub outdegree_bits: usize,
/// Bits written for the gaps between consecutive sources.
pub src_bits: usize,
/// Bits written for the gaps between consecutive destinations.
pub dst_bits: usize,
/// Bits written by the label serializer.
pub labels_bits: usize,
}
impl core::fmt::Display for GroupedGapsStats {
    /// Writes, for each component (outdegrees, sources, destinations,
    /// labels), its size in bytes passed through `humanize` and its average
    /// cost in bits per arc.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let arcs = self.total_triples as f64;
        // Size of a component in bytes, rendered by `humanize`.
        let bytes = |bits: usize| humanize(bits as f64 / 8.0);
        // Average number of bits a component costs per encoded arc.
        let per_arc = |bits: usize| bits as f64 / arcs;
        write!(
            f,
            "outdegree: {}B ({:.3} bits / arc), src: {}B ({:.3} bits / arc), dst: {}B ({:.3} bits / arc), labels: {}B ({:.3} bits / arc)",
            bytes(self.outdegree_bits),
            per_arc(self.outdegree_bits),
            bytes(self.src_bits),
            per_arc(self.src_bits),
            bytes(self.dst_bits),
            per_arc(self.dst_bits),
            bytes(self.labels_bits),
            per_arc(self.labels_bits),
        )
    }
}
impl<E, S, D, const OUTDEGREE_CODE: usize, const SRC_CODE: usize, const DST_CODE: usize> BatchCodec
    for GroupedGapsCodec<E, S, D, OUTDEGREE_CODE, SRC_CODE, DST_CODE>
where
    E: Endianness,
    S: BitSerializer<E, BitWriter<E>> + Send + Sync,
    D: BitDeserializer<E, BitReader<E>, DeserType = S::SerType> + Send + Sync + Clone,
    S::SerType: Send + Sync + Copy + 'static,
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
    type Label = S::SerType;
    type DecodedBatch = GroupedGapsIter<E, D, OUTDEGREE_CODE, SRC_CODE, DST_CODE>;
    type EncodedBatchStats = GroupedGapsStats;

    /// Sorts `batch` in place with a radix sort, then encodes it to `path`
    /// through [`Self::encode_sorted_batch`].
    fn encode_batch(
        &self,
        path: impl AsRef<std::path::Path>,
        batch: &mut [((usize, usize), Self::Label)],
    ) -> Result<(usize, Self::EncodedBatchStats)> {
        let sort_timer = std::time::Instant::now();
        Triple::cast_batch_mut(batch).radix_sort_unstable();
        log::debug!("Sorted {} arcs in {:?}", batch.len(), sort_timer.elapsed());
        self.encode_sorted_batch(path, batch)
    }

    /// Encodes an already sorted `batch` to the file at `path`.
    ///
    /// The stream starts with the batch length in delta code. Then, for each
    /// run of triples sharing a source, it holds the source gap, the run
    /// length, and for every triple the destination gap followed by the
    /// serialized label.
    ///
    /// Returns the sum of the four counters in the statistics together with
    /// the statistics themselves; the bits used by the length prefix are not
    /// part of that sum.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be created, if any write or label
    /// serialization fails, or if the final flush fails.
    fn encode_sorted_batch(
        &self,
        path: impl AsRef<std::path::Path>,
        batch: &[((usize, usize), Self::Label)],
    ) -> Result<(usize, Self::EncodedBatchStats)> {
        debug_assert!(Triple::cast_batch(batch).is_sorted(), "Batch is not sorted");

        let target = path.as_ref();
        let mut writer = buf_bit_writer::from_path::<E, usize>(target).with_context(|| {
            format!(
                "Could not create BatchIterator temporary file {}",
                target.display()
            )
        })?;

        // The decoder needs the number of triples up front.
        writer
            .write_delta(batch.len() as u64)
            .context("Could not write length")?;

        let mut stats = GroupedGapsStats {
            total_triples: batch.len(),
            outdegree_bits: 0,
            src_bits: 0,
            dst_bits: 0,
            labels_bits: 0,
        };

        let mut prev_src = 0;
        // Each chunk is a maximal run of consecutive triples with the same
        // source; chunks are never empty, so indexing the first item is safe.
        for group in batch.chunk_by(|a, b| a.0.0 == b.0.0) {
            let src = group[0].0.0;
            let outdegree = group.len();

            stats.src_bits += ConstCode::<SRC_CODE>
                .write(&mut writer, (src - prev_src) as _)
                .with_context(|| format!("Could not write {src} after {prev_src}"))?;
            stats.outdegree_bits += ConstCode::<OUTDEGREE_CODE>
                .write(&mut writer, outdegree as _)
                .with_context(|| format!("Could not write outdegree {outdegree} for {src}"))?;

            // Destination gaps restart from zero for every source.
            let mut prev_dst = 0;
            for ((_, dst), label) in group {
                stats.dst_bits += ConstCode::<DST_CODE>
                    .write(&mut writer, (dst - prev_dst) as _)
                    .with_context(|| format!("Could not write {dst} after {prev_dst}"))?;
                stats.labels_bits += self
                    .serializer
                    .serialize(label, &mut writer)
                    .context("Could not serialize label")?;
                prev_dst = *dst;
            }
            prev_src = src;
        }

        writer.flush().context("Could not flush stream")?;

        let total_bits =
            stats.outdegree_bits + stats.src_bits + stats.dst_bits + stats.labels_bits;
        Ok((total_bits, stats))
    }

    /// Memory-maps the file at `path`, reads the length prefix, and returns an
    /// iterator over the encoded triples that uses a clone of this codec's
    /// deserializer.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be mapped or the length cannot be read.
    fn decode_batch(&self, path: impl AsRef<std::path::Path>) -> Result<Self::DecodedBatch> {
        let source = path.as_ref();
        let mmap = MmapHelper::mmap(
            source,
            MmapFlags::TRANSPARENT_HUGE_PAGES | MmapFlags::SEQUENTIAL,
        )
        .with_context(|| format!("Could not mmap {}", source.display()))?;
        let words = MemWordReader::new(ArcMmapHelper(Arc::new(mmap)));
        let mut stream = <BufBitReader<E, _>>::new(words);

        let len = stream.read_delta().context("Could not read length")? as usize;

        Ok(GroupedGapsIter {
            deserializer: self.deserializer.clone(),
            stream,
            len,
            current: 0,
            src: 0,
            dst_left: 0,
            prev_dst: 0,
        })
    }
}
/// Iterator over the triples of a batch encoded by `GroupedGapsCodec`; it is
/// the `DecodedBatch` type returned by `GroupedGapsCodec::decode_batch`.
///
/// Triples are yielded in the order in which they were written.
#[derive(Clone, Debug)]
pub struct GroupedGapsIter<
E: Endianness = NE,
D: BitDeserializer<E, BitReader<E>> = (),
const OUTDEGREE_CODE: usize = { dsi_bitstream::dispatch::code_consts::GAMMA },
const SRC_CODE: usize = { dsi_bitstream::dispatch::code_consts::GAMMA },
// NOTE(review): this default is GAMMA, while `GroupedGapsCodec` defaults its
// `DST_CODE` to DELTA, so the two defaulted types do not describe the same
// format — confirm whether the difference is intentional.
const DST_CODE: usize = { dsi_bitstream::dispatch::code_consts::GAMMA },
> where
BitReader<E>: BitRead<E>,
BitWriter<E>: BitWrite<E>,
{
/// Deserializer used to read the label of each triple.
deserializer: D,
/// Bitstream the triples are read from.
stream: BitReader<E>,
/// Total number of triples in the batch, as read from the length prefix.
len: usize,
/// Number of triples yielded so far.
current: usize,
/// Source of the current group: the running sum of the source gaps read.
src: usize,
/// Number of triples still to be read for the current source.
dst_left: usize,
/// Last destination yielded for the current source; reset to zero whenever
/// a new source group starts.
prev_dst: usize,
}
// SAFETY (assumed): the iterator replays triples in the order in which they
// were written, and `encode_sorted_batch` writes them in the order of a batch
// that it debug-asserts to be sorted. TODO: confirm this matches the exact
// contract required by `SortedIterator`.
unsafe impl<E, D, const OUTDEGREE_CODE: usize, const SRC_CODE: usize, const DST_CODE: usize>
    SortedIterator for GroupedGapsIter<E, D, OUTDEGREE_CODE, SRC_CODE, DST_CODE>
where
    E: Endianness,
    D: BitDeserializer<E, BitReader<E>>,
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
}
impl<E, D, const OUTDEGREE_CODE: usize, const SRC_CODE: usize, const DST_CODE: usize> Iterator
    for GroupedGapsIter<E, D, OUTDEGREE_CODE, SRC_CODE, DST_CODE>
where
    E: Endianness,
    D: BitDeserializer<E, BitReader<E>>,
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
    type Item = ((usize, usize), D::DeserType);

    /// Decodes the next `((src, dst), label)` triple.
    ///
    /// Returns `None` once `len` triples have been yielded, and also when any
    /// read from the stream or the label deserialization fails: the
    /// underlying error is discarded.
    fn next(&mut self) -> Option<Self::Item> {
        let exhausted = self.current >= self.len;
        if exhausted {
            return None;
        }

        // The previous group is finished: read the header of the next one.
        let group_finished = self.dst_left == 0;
        if group_finished {
            let gap = ConstCode::<SRC_CODE>.read(&mut self.stream).ok()?;
            self.src += gap as usize;
            let outdegree = ConstCode::<OUTDEGREE_CODE>.read(&mut self.stream).ok()?;
            self.dst_left = outdegree as usize;
            // Destination gaps restart from zero for every source.
            self.prev_dst = 0;
        }

        let gap = ConstCode::<DST_CODE>.read(&mut self.stream).ok()?;
        let label = self.deserializer.deserialize(&mut self.stream).ok()?;

        // State is advanced only after both reads have succeeded.
        self.prev_dst += gap as usize;
        self.current += 1;
        self.dst_left -= 1;

        let arc = (self.src, self.prev_dst);
        Some((arc, label))
    }

    /// Reports the exact number of triples not yet yielded as both bounds.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.len - self.current;
        (remaining, Some(remaining))
    }
}
impl<E, D, const OUTDEGREE_CODE: usize, const SRC_CODE: usize, const DST_CODE: usize>
    ExactSizeIterator for GroupedGapsIter<E, D, OUTDEGREE_CODE, SRC_CODE, DST_CODE>
where
    E: Endianness,
    D: BitDeserializer<E, BitReader<E>>,
    BitReader<E>: BitRead<E> + CodesRead<E>,
    BitWriter<E>: BitWrite<E> + CodesWrite<E>,
{
    /// Number of triples not yet yielded: the batch length minus the count
    /// of triples already returned.
    fn len(&self) -> usize {
        let yielded = self.current;
        self.len - yielded
    }
}