#![deny(missing_docs)]
use std::{
cell::RefCell, collections::HashMap, error::Error, hash::Hasher, sync::Arc,
};
use rustc_hash::FxHasher;
use serde::{de::Visitor, Deserializer};
// Per-thread interning tables, keyed by the `quick_hash` of a value's bytes.
// Every thread gets its own pair of tables, so no locking is involved. Nothing in this
// file removes individual entries; `clear_arc_cache` is the only way to shrink them, so
// they grow with every distinct value deserialized on the thread.
thread_local! {
    /// Interned byte buffers, populated by `ArcU8sVisitor`.
    static INTERNED_U8S: RefCell<HashMap<u64, Arc<[u8]>>> = RefCell::default();
    /// Interned strings, populated by `ArcStrVisitor`.
    static INTERNED_STRINGS: RefCell<HashMap<u64, Arc<str>>> = RefCell::default();
}
struct ArcU8sVisitor {}
impl<'de> Visitor<'de> for ArcU8sVisitor {
type Value = Arc<[u8]>;
fn expecting(
&self,
formatter: &mut std::fmt::Formatter,
) -> std::fmt::Result {
write!(formatter, "Expected a slice of bytes")
}
fn visit_bytes<E>(self, buffer: &[u8]) -> Result<Self::Value, E>
where
E: Error,
{
let hash = quick_hash(buffer);
INTERNED_U8S.with_borrow_mut(
|lookup_table: &mut HashMap<_, Arc<[u8]>>| {
lookup_table
.entry(hash)
.or_insert_with(|| Arc::from(buffer));
match lookup_table.get(&hash) {
Some(arc) if arc.as_ref() == buffer => Ok(arc.clone()),
_ => Ok(Arc::from(buffer)),
}
},
)
}
}
/// Deserializes a byte buffer into an `Arc<[u8]>`. Where possible it reuses an
/// allocation with identical contents that was already interned on this thread.
///
/// Intended for use as `#[serde(deserialize_with = "intern_arc_u8s")]`. Interned
/// buffers stay cached, and the cache keeps one strong reference to each, until
/// [`clear_arc_cache`] is called on the same thread.
///
/// # Errors
///
/// Returns the deserializer's error when the input cannot be read as bytes.
pub fn intern_arc_u8s<'de, D>(deserializer: D) -> Result<Arc<[u8]>, D::Error>
where
    D: Deserializer<'de>,
{
    // Ask for bytes, which is what the visitor consumes. `deserialize_str` would make
    // self-describing formats (e.g. JSON) and bincode call `visit_str` instead.
    deserializer.deserialize_bytes(ArcU8sVisitor {})
}
/// Serde visitor that turns a string into an interned `Arc<str>`.
///
/// Interned values live in the thread-local `INTERNED_STRINGS` table. The table is keyed
/// by the `quick_hash` of the string's UTF-8 bytes.
struct ArcStrVisitor {}

impl<'de> Visitor<'de> for ArcStrVisitor {
    type Value = Arc<str>;

    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Serde renders this as "invalid type: ..., expected <this>", so use a noun phrase.
        formatter.write_str("a string")
    }

    /// Returns the cached `Arc<str>` equal to `buffer`, caching it on first sight.
    ///
    /// If a *different* string already occupies the same 64-bit hash slot, the method
    /// returns a fresh, uncached `Arc`. The result is therefore always equal to `buffer`.
    fn visit_str<E>(self, buffer: &str) -> Result<Self::Value, E>
    where
        E: Error,
    {
        let hash = quick_hash(buffer.as_bytes());
        let interned = INTERNED_STRINGS.with_borrow_mut(|lookup_table| {
            // A single lookup: the entry API inserts on a miss and hands back the slot.
            let cached = lookup_table.entry(hash).or_insert_with(|| Arc::from(buffer));
            if **cached == *buffer {
                Arc::clone(cached)
            } else {
                Arc::from(buffer)
            }
        });
        Ok(interned)
    }
}
/// Deserializes a string into an `Arc<str>`. Where possible it reuses an
/// allocation with identical contents that was already interned on this thread.
///
/// Intended for use as `#[serde(deserialize_with = "intern_arc_str")]`. Interned
/// strings stay cached, and the cache keeps one strong reference to each, until
/// [`clear_arc_cache`] is called on the same thread.
///
/// # Errors
///
/// Returns the deserializer's error when the input is not a string.
pub fn intern_arc_str<'de, D>(deserializer: D) -> Result<Arc<str>, D::Error>
where
    D: Deserializer<'de>,
{
    deserializer.deserialize_str(ArcStrVisitor {})
}
/// Empties both interning caches (bytes and strings) for the current thread.
///
/// `Arc`s that were already handed out stay valid; the cache only drops its own
/// strong reference. Later deserializations will allocate and cache fresh values.
/// The tables are thread-local, so caches on other threads are not touched.
pub fn clear_arc_cache() {
    INTERNED_U8S.with_borrow_mut(HashMap::clear);
    INTERNED_STRINGS.with_borrow_mut(HashMap::clear);
}
fn quick_hash(data: &[u8]) -> u64 {
let mut hasher = FxHasher::default();
hasher.write(data);
hasher.finish()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three identical names must end up sharing one allocation. The thread-local cache
    /// holds a fourth reference until `clear_arc_cache` releases it.
    #[test]
    fn deserialize_strings() {
        #[derive(serde_derive::Deserialize)]
        struct Person {
            #[serde(deserialize_with = "intern_arc_str")]
            name: Arc<str>,
        }

        let json = r#"
[
{ "name": "Yenna" },
{ "name": "Yenna" },
{ "name": "Yenna" }
]
"#;
        let people: Vec<Person> = serde_json::from_str(json).unwrap();
        let shared_name = &people[0].name;

        // Three `Person`s plus the cache entry.
        assert_eq!(Arc::strong_count(shared_name), 4);

        clear_arc_cache();

        // Only the cache's own reference is gone.
        assert_eq!(Arc::strong_count(shared_name), 3);
    }
}