use anyhow::{Context, Result};
use mmap_rs::Mmap;
use rayon::prelude::*;
use thiserror::Error;
use super::suffixes::*;
use super::*;
use crate::front_coded_list::FrontCodedList;
use crate::labels::LabelNameId;
use crate::utils::suffix_path;
/// Marker trait for the label-names slot of [`SwhGraphProperties`].
///
/// It is implemented by every [`LabelNames`] backend and by [`NoLabelNames`].
pub trait MaybeLabelNames {}

impl<L: LabelNames> MaybeLabelNames for L {}

/// Placeholder for the label-names slot when label names are not loaded.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NoLabelNames;

impl MaybeLabelNames for NoLabelNames {}

/// Label names backed by a [`FrontCodedList`] over [`Mmap`] buffers.
pub struct MappedLabelNames {
    label_names: FrontCodedList<Mmap, Mmap>,
}
/// Access to the list of label names of a graph.
///
/// The names exposed by [`LabelNames::label_names`] are base64-encoded;
/// `SwhGraphProperties::label_name` decodes them.
#[diagnostic::on_unimplemented(
    label = "does not have label names loaded",
    note = "Use `let graph = graph.load_properties(|props| props.load_label_names()).unwrap()` to load them",
    note = "Or replace `graph.init_properties()` with `graph.load_all_properties::<DynMphf>().unwrap()` to load all properties"
)]
pub trait LabelNames {
    /// Indexable collection of base64-encoded label names, borrowed from `self`.
    type LabelNames<'a>: GetIndex<Output = Vec<u8>>
    where
        Self: 'a;

    /// Returns the collection of base64-encoded label names.
    fn label_names(&self) -> Self::LabelNames<'_>;
}
impl LabelNames for MappedLabelNames {
    type LabelNames<'a>
        = &'a FrontCodedList<Mmap, Mmap>
    where
        Self: 'a;

    /// Borrows the underlying front-coded list of base64-encoded names.
    #[inline(always)]
    fn label_names(&self) -> Self::LabelNames<'_> {
        let Self { label_names } = self;
        label_names
    }
}
/// In-memory label names, stored base64-encoded.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct VecLabelNames {
    label_names: Vec<Vec<u8>>,
}

impl VecLabelNames {
    /// Builds an in-memory label-name list from raw (not yet encoded) names.
    ///
    /// Each name is encoded with `base64_simd::STANDARD` and stored in the order
    /// given, so that index `i` of the result corresponds to `label_names[i]`.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    pub fn new<S: AsRef<[u8]>>(label_names: Vec<S>) -> Result<Self> {
        let encoder = base64_simd::STANDARD;
        let mut encoded_names = Vec::with_capacity(label_names.len());
        for raw_name in label_names {
            encoded_names.push(encoder.encode_to_string(raw_name).into_bytes());
        }
        Ok(Self {
            label_names: encoded_names,
        })
    }
}
impl LabelNames for VecLabelNames {
    type LabelNames<'a>
        = &'a [Vec<u8>]
    where
        Self: 'a;

    /// Borrows the stored base64-encoded names as a slice.
    #[inline(always)]
    fn label_names(&self) -> Self::LabelNames<'_> {
        &self.label_names[..]
    }
}
impl<
        MAPS: MaybeMaps,
        TIMESTAMPS: MaybeTimestamps,
        PERSONS: MaybePersons,
        CONTENTS: MaybeContents,
        STRINGS: MaybeStrings,
    > SwhGraphProperties<MAPS, TIMESTAMPS, PERSONS, CONTENTS, STRINGS, NoLabelNames>
{
    /// Loads the label names from the front-coded list stored at the graph path
    /// with the `LABEL_NAME` suffix.
    ///
    /// # Errors
    ///
    /// Returns an error if that front-coded list cannot be loaded.
    pub fn load_label_names(
        self,
    ) -> Result<SwhGraphProperties<MAPS, TIMESTAMPS, PERSONS, CONTENTS, STRINGS, MappedLabelNames>>
    {
        let list_path = suffix_path(&self.path, LABEL_NAME);
        let label_names =
            FrontCodedList::load(list_path).context("Could not load label names")?;
        self.with_label_names(MappedLabelNames { label_names })
    }

    /// Installs `label_names` as the label-names backend, keeping every other
    /// property unchanged.
    ///
    /// The cached "are label names in base64 order" detection is reset to its
    /// `Default` value, since it describes the previous backend.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    pub fn with_label_names<LABELNAMES: MaybeLabelNames>(
        self,
        label_names: LABELNAMES,
    ) -> Result<SwhGraphProperties<MAPS, TIMESTAMPS, PERSONS, CONTENTS, STRINGS, LABELNAMES>> {
        let SwhGraphProperties {
            maps,
            timestamps,
            persons,
            contents,
            strings,
            path,
            num_nodes,
            ..
        } = self;
        Ok(SwhGraphProperties {
            maps,
            timestamps,
            persons,
            contents,
            strings,
            label_names,
            path,
            num_nodes,
            label_names_are_in_base64_order: Default::default(),
        })
    }
}
impl<
MAPS: MaybeMaps,
TIMESTAMPS: MaybeTimestamps,
PERSONS: MaybePersons,
CONTENTS: MaybeContents,
STRINGS: MaybeStrings,
LABELNAMES: LabelNames,
> SwhGraphProperties<MAPS, TIMESTAMPS, PERSONS, CONTENTS, STRINGS, LABELNAMES>
{
pub fn num_label_names(&self) -> u64 {
self.label_names
.label_names()
.len()
.try_into()
.expect("label_names.len() overflowed u64")
}
pub fn iter_label_name_ids(&self) -> impl Iterator<Item = LabelNameId> {
(0..self.num_label_names()).map(LabelNameId)
}
pub fn par_iter_label_name_ids(&self) -> impl ParallelIterator<Item = LabelNameId> {
(0..self.num_label_names()).into_par_iter().map(LabelNameId)
}
#[inline]
pub fn label_name_base64(&self, label_name_id: LabelNameId) -> Vec<u8> {
self.try_label_name_base64(label_name_id)
.unwrap_or_else(|e| panic!("Cannot get label name: {e}"))
}
#[inline]
pub fn try_label_name_base64(
&self,
label_name_id: LabelNameId,
) -> Result<Vec<u8>, OutOfBoundError> {
let index = label_name_id
.0
.try_into()
.expect("label_name_id overflowed usize");
self.label_names
.label_names()
.get(index)
.ok_or(OutOfBoundError {
index,
len: self.label_names.label_names().len(),
})
}
#[inline]
pub fn label_name(&self, label_name_id: LabelNameId) -> Vec<u8> {
self.try_label_name(label_name_id)
.unwrap_or_else(|e| panic!("Cannot get label name: {e}"))
}
#[inline]
pub fn try_label_name(&self, label_name_id: LabelNameId) -> Result<Vec<u8>, OutOfBoundError> {
let base64 = base64_simd::STANDARD;
self.try_label_name_base64(label_name_id).map(|name| {
base64.decode_to_vec(name).unwrap_or_else(|name| {
panic!(
"Could not decode label_name of id {}: {:?}",
label_name_id.0, name
)
})
})
}
#[inline]
pub fn label_name_id(
&self,
name: impl AsRef<[u8]>,
) -> Result<LabelNameId, LabelIdFromNameError> {
use std::cmp::Ordering::*;
let base64 = base64_simd::STANDARD;
let name = name.as_ref();
let name_base64 = base64.encode_to_string(name.as_ref()).into_bytes();
macro_rules! bisect {
($compare_pivot:expr) => {{
let compare_pivot = $compare_pivot;
let res: Result<LabelNameId, LabelIdFromNameError> = {
let mut min = 0;
let mut max = self
.label_names
.label_names()
.len()
.saturating_sub(1)
.try_into()
.expect("number of labels overflowed u64");
while min <= max {
let pivot = (min + max) / 2;
let pivot_id = LabelNameId(pivot);
let pivot_name = self.label_name_base64(pivot_id);
if min == max {
if pivot_name.as_slice() == name_base64 {
return Ok(pivot_id);
} else {
break;
}
} else {
match compare_pivot(pivot_id) {
Less => min = pivot.saturating_add(1),
Equal => return Ok(pivot_id),
Greater => max = pivot.saturating_sub(1),
}
}
}
Err(LabelIdFromNameError(name_base64.to_vec()))
};
res
}};
}
let bisect_base64 = || {
bisect!(|pivot_id| self
.label_name_base64(pivot_id)
.as_slice()
.cmp(&name_base64))
};
let bisect_nonbase64 =
|| bisect!(|pivot_id| self.label_name(pivot_id).as_slice().cmp(name));
match self.label_names_are_in_base64_order.get() {
Some(true) => {
bisect_base64()
}
Some(false) => {
bisect_nonbase64()
}
None => {
match bisect_nonbase64() {
Ok(label_name_id) => {
if let Err(LabelIdFromNameError(_)) = bisect_base64() {
log::debug!("Labels are not in base64 order");
let _ = self.label_names_are_in_base64_order.set(false);
}
Ok(label_name_id)
}
Err(LabelIdFromNameError(_)) => {
match bisect_base64() {
Ok(label_name_id) => {
log::debug!("Labels are in base64 order");
let _ = self.label_names_are_in_base64_order.set(true);
Ok(label_name_id)
}
Err(LabelIdFromNameError(e)) => {
Err(LabelIdFromNameError(e))
}
}
}
}
}
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, Error)]
#[error("Unknown label name: {}", String::from_utf8_lossy(.0))]
pub struct LabelIdFromNameError(pub Vec<u8>);