use bstr::BStr;
use intaglio::Symbol;
use rustc_hash::FxHasher;
use serde::de::Visitor;
use serde::ser::SerializeSeq;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::hash::BuildHasherDefault;
use std::marker::PhantomData;
/// Hasher builder shared by the symbol tables of `StringPool` and
/// `BStringPool` (FxHash from `rustc_hash`).
type HashBuilder = BuildHasherDefault<FxHasher>;

/// Interns UTF-8 strings and hands out identifiers of type `T` for them.
///
/// Identifiers are the symbol table's `u32` ids converted into `T`, so `T`
/// must round-trip through `u32`. Those bounds are stated on the `impl`
/// blocks that need them rather than on the struct definition itself.
pub(crate) struct StringPool<T> {
    /// Symbol table storing each distinct string exactly once.
    pool: intaglio::SymbolTable<HashBuilder>,
    /// Accumulated length, in bytes, of every string added to `pool`.
    size: usize,
    /// Marks the identifier type without storing any `T`.
    phantom: PhantomData<T>,
}
impl<T> StringPool<T>
where
T: From<u32> + Into<u32>,
{
pub fn new() -> Self {
Self {
pool: intaglio::SymbolTable::with_hasher(HashBuilder::default()),
size: 0,
phantom: Default::default(),
}
}
#[inline]
pub fn get_or_intern(&mut self, s: &str) -> T {
if let Some(s) = self.pool.check_interned(s) {
T::from(s.id())
} else {
self.size += s.len();
T::from(self.pool.intern(s.to_string()).unwrap().id())
}
}
#[inline]
pub fn get(&self, id: T) -> Option<&str> {
self.pool.get(Symbol::from(id.into()))
}
}
impl<T> Serialize for StringPool<T>
where
    T: From<u32> + Into<u32>,
{
    /// Writes the pool as a sequence of its strings, in the order yielded by
    /// the symbol table's `strings()` iterator.
    ///
    /// The matching visitor re-interns the elements in that same order. The
    /// round-trip test relies on this to get back the original identifiers.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let total = self.pool.len();
        let mut elements = serializer.serialize_seq(Some(total))?;
        self.pool
            .strings()
            .try_for_each(|entry| elements.serialize_element(entry))?;
        elements.end()
    }
}
impl<'de, T> Deserialize<'de> for StringPool<T>
where
T: From<u32> + Into<u32>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_seq(StringPoolVisitor::new())
}
}
/// serde visitor that rebuilds a `StringPool` from a sequence of strings.
struct StringPoolVisitor<T> {
    /// Carries the pool's identifier type; no `T` value is stored.
    phantom: PhantomData<T>,
}

impl<T> StringPoolVisitor<T> {
    /// Creates the visitor; it holds no state besides the type marker.
    fn new() -> Self {
        let phantom = PhantomData;
        StringPoolVisitor { phantom }
    }
}
impl<'de, T> Visitor<'de> for StringPoolVisitor<T>
where
T: From<u32> + Into<u32>,
{
type Value = StringPool<T>;
fn expecting(
&self,
formatter: &mut std::fmt::Formatter,
) -> std::fmt::Result {
formatter.write_str("a StringPool")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut pool = StringPool::new();
while let Some(string) = seq.next_element()? {
pool.get_or_intern(string);
}
Ok(pool)
}
}
/// Interns arbitrary byte strings (not necessarily valid UTF-8) and hands out
/// identifiers of type `T` for them.
///
/// Identifiers are the symbol table's `u32` ids converted into `T`, so `T`
/// must round-trip through `u32`. Those bounds are stated on the `impl`
/// blocks that need them rather than on the struct definition itself.
pub struct BStringPool<T> {
    /// Byte-string symbol table storing each distinct value exactly once.
    pool: intaglio::bytes::SymbolTable<HashBuilder>,
    /// Accumulated length, in bytes, of every byte string added to `pool`.
    size: usize,
    /// Marks the identifier type without storing any `T`.
    phantom: PhantomData<T>,
}
impl<T> BStringPool<T>
where
T: From<u32> + Into<u32>,
{
pub fn new() -> Self {
Self {
pool: intaglio::bytes::SymbolTable::with_hasher(
HashBuilder::default(),
),
size: 0,
phantom: Default::default(),
}
}
#[inline]
pub fn get_or_intern<S>(&mut self, s: S) -> T
where
S: AsRef<[u8]>,
{
let bytes = s.as_ref();
if let Some(s) = self.pool.check_interned(bytes) {
T::from(s.id())
} else {
self.size += bytes.len();
T::from(self.pool.intern(bytes.to_owned()).unwrap().id())
}
}
#[inline]
pub fn get(&self, id: T) -> Option<&BStr> {
self.get_bytes(id).map(BStr::new)
}
#[inline]
pub fn get_bytes(&self, id: T) -> Option<&[u8]> {
self.pool.get(Symbol::from(id.into()))
}
#[inline]
pub fn get_str(&self, id: T) -> Option<&str> {
self.get_bytes(id)
.map(|s| {
std::str::from_utf8(s)
.expect("using BStringPool::get_str with a string that is not valid UTF-8")
})
}
}
impl<T> Serialize for BStringPool<T>
where
T: From<u32> + Into<u32>,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = serializer.serialize_seq(Some(self.pool.len()))?;
for string in self.pool.bytestrings() {
seq.serialize_element(string)?
}
seq.end()
}
}
impl<'de, T> Deserialize<'de> for BStringPool<T>
where
T: From<u32> + Into<u32>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_seq(BStringPoolVisitor::new())
}
}
/// serde visitor that rebuilds a `BStringPool` from a sequence of byte
/// strings.
struct BStringPoolVisitor<T> {
    /// Carries the pool's identifier type; no `T` value is stored.
    phantom: PhantomData<T>,
}

impl<T> BStringPoolVisitor<T> {
    /// Creates the visitor; it holds no state besides the type marker.
    fn new() -> Self {
        let phantom = PhantomData;
        BStringPoolVisitor { phantom }
    }
}
impl<'de, T> Visitor<'de> for BStringPoolVisitor<T>
where
T: From<u32> + Into<u32>,
{
type Value = BStringPool<T>;
fn expecting(
&self,
formatter: &mut std::fmt::Formatter,
) -> std::fmt::Result {
formatter.write_str("a BStringPool")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut pool = BStringPool::new();
while let Some(string) = seq.next_element::<&[u8]>()? {
pool.get_or_intern(string);
}
Ok(pool)
}
}
#[cfg(test)]
mod test {
    use super::{BStringPool, StringPool};
    use bstr::BStr;
    use pretty_assertions::assert_eq;

    /// A `StringPool` must survive a bincode round trip with its
    /// identifiers intact.
    #[test]
    fn string_pool_serde() {
        let config = bincode::config::standard();

        let mut original: StringPool<u32> = StringPool::new();
        for text in ["foo", "bar"].iter() {
            original.get_or_intern(text);
        }

        let encoded = bincode::serde::encode_to_vec(&original, config).unwrap();
        let (restored, _): (StringPool<u32>, _) =
            bincode::serde::decode_from_slice(&encoded, config).unwrap();

        // Identifiers follow insertion order; anything past the end is absent.
        assert_eq!(restored.get(0), Some("foo"));
        assert_eq!(restored.get(1), Some("bar"));
        assert_eq!(restored.get(2), None);
    }

    /// A `BStringPool` must survive a bincode round trip with its
    /// identifiers intact.
    #[test]
    fn bstring_pool_serde() {
        let config = bincode::config::standard();

        let mut original: BStringPool<u32> = BStringPool::new();
        for text in ["foo", "bar"].iter() {
            original.get_or_intern(text);
        }

        let encoded = bincode::serde::encode_to_vec(&original, config).unwrap();
        let (restored, _): (BStringPool<u32>, _) =
            bincode::serde::decode_from_slice(&encoded, config).unwrap();

        // Identifiers follow insertion order; anything past the end is absent.
        assert_eq!(restored.get(0), Some(BStr::new("foo")));
        assert_eq!(restored.get(1), Some(BStr::new("bar")));
        assert_eq!(restored.get(2), None);
    }
}