use crate::prelude::{
collections::btree_map::{
BTreeMap,
Entry,
},
marker::PhantomData,
num::NonZeroU32,
vec::Vec,
};
#[cfg(feature = "serde")]
use serde::{
Deserialize,
Serialize,
};
/// A symbol id that carries no lifetime, unlike [`Symbol`].
///
/// Can be obtained from a [`Symbol`] via [`Symbol::into_untracked`] or decoded from its
/// SCALE encoding. With the `serde` feature it (de)serializes transparently as its `id`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct UntrackedSymbol<T> {
    /// The non-zero numeric symbol id.
    id: NonZeroU32,
    /// Associates the symbol with `T` without storing a `T`.
    #[cfg_attr(feature = "serde", serde(skip))]
    marker: PhantomData<fn() -> T>,
}
impl<T> scale::Encode for UntrackedSymbol<T> {
    /// Writes the symbol as the SCALE encoding of its raw `u32` id.
    fn encode_to<W: scale::Output + ?Sized>(&self, dest: &mut W) {
        let raw_id: u32 = self.id.get();
        scale::Encode::encode_to(&raw_id, dest)
    }
}
impl<T> scale::Decode for UntrackedSymbol<T> {
fn decode<I: scale::Input>(value: &mut I) -> Result<Self, scale::Error> {
let id = <u32 as scale::Decode>::decode(value)?;
if id < 1 {
return Err("UntrackedSymbol::id should be a non-zero unsigned integer".into())
}
let id = NonZeroU32::new(id).expect("ID is non zero");
Ok(UntrackedSymbol {
id,
marker: Default::default(),
})
}
}
impl<T> UntrackedSymbol<T> {
pub fn id(&self) -> NonZeroU32 {
self.id
}
}
/// A symbol produced by an [`Interner`], carrying a lifetime `'a`.
///
/// The `Interner` methods in this module tie `'a` to the borrow of the interner that
/// produced the symbol; use [`Symbol::into_untracked`] to drop it. Ids are non-zero and are
/// handed out by `Interner::intern_or_get` starting at 1, in insertion order.
/// With the `serde` feature it serializes (only) transparently as its `id`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(Serialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct Symbol<'a, T> {
    /// The non-zero numeric symbol id.
    id: NonZeroU32,
    /// Associates the symbol with `'a` and `T` without storing either.
    #[cfg_attr(feature = "serde", serde(skip))]
    marker: PhantomData<fn() -> &'a T>,
}
impl<T> Symbol<'_, T> {
pub fn into_untracked(self) -> UntrackedSymbol<T> {
UntrackedSymbol {
id: self.id,
marker: PhantomData,
}
}
}
/// A deduplicating store that assigns each distinct value of type `T` a [`Symbol`].
///
/// With the `serde` feature only `vec` (the values in insertion order) is serialized; the
/// lookup `map` is skipped.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct Interner<T> {
    /// Maps each interned value to its 0-based position in `vec`.
    #[cfg_attr(feature = "serde", serde(skip))]
    map: BTreeMap<T, usize>,
    /// Interned values in insertion order; symbol id `n` refers to `vec[n - 1]`.
    vec: Vec<T>,
}
impl<T> Interner<T>
where
    T: Ord,
{
    /// Creates an interner that holds no values yet.
    pub fn new() -> Self {
        let map = BTreeMap::new();
        let vec = Vec::new();
        Self { map, vec }
    }
}
impl<T: Ord> Default for Interner<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Interner<T>
where
    T: Ord + Clone,
{
    /// Converts a 0-based position in `vec` into the matching 1-based [`Symbol`].
    ///
    /// # Panics
    ///
    /// Panics if `idx + 1` does not fit in a `u32`.
    fn symbol_for_index<'a>(idx: usize) -> Symbol<'a, T> {
        assert!(
            idx < u32::MAX as usize,
            "Interner cannot hold more than u32::MAX symbols"
        );
        // `idx < u32::MAX`, so `idx + 1` lies in `1..=u32::MAX` and is never zero.
        let id = NonZeroU32::new(idx as u32 + 1).expect("idx + 1 is non-zero");
        Symbol {
            id,
            marker: PhantomData,
        }
    }

    /// Interns `s`, returning whether it was newly inserted together with its symbol.
    ///
    /// New values get the next 1-based id in insertion order. A value that is already
    /// present yields its existing symbol and `false`.
    ///
    /// # Panics
    ///
    /// Panics if interning a new value would require an id larger than `u32::MAX`. The
    /// interner is left unchanged in that case.
    pub fn intern_or_get(&mut self, s: T) -> (bool, Symbol<T>) {
        let next_idx = self.vec.len();
        match self.map.entry(s.clone()) {
            Entry::Vacant(vacant) => {
                // Build the symbol before mutating, so an overflow panic leaves the
                // interner untouched.
                let sym = Self::symbol_for_index(next_idx);
                vacant.insert(next_idx);
                self.vec.push(s);
                (true, sym)
            }
            Entry::Occupied(occupied) => (false, Self::symbol_for_index(*occupied.get())),
        }
    }

    /// Looks up the symbol of an already-interned value without inserting it.
    ///
    /// Returns `None` if `s` has not been interned. Otherwise the result is the same
    /// symbol [`Interner::intern_or_get`] returned for `s`.
    pub fn get(&self, s: &T) -> Option<Symbol<T>> {
        // `map` stores 0-based positions. Shift to the 1-based id so the result agrees
        // with `intern_or_get` and round-trips through `resolve`.
        self.map.get(s).map(|&idx| Self::symbol_for_index(idx))
    }

    /// Returns the value interned under `sym`, or `None` if no such value exists.
    pub fn resolve(&self, sym: Symbol<T>) -> Option<&T> {
        // Ids are 1-based; `id` is non-zero so the subtraction cannot underflow.
        let idx = (sym.id.get() - 1) as usize;
        // `Vec::get` already returns `None` for out-of-range indices.
        self.vec.get(idx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    type StringInterner = Interner<&'static str>;

    /// Interns `new_symbol` and asserts that the returned symbol carries `expected_id`.
    fn assert_id(interner: &mut StringInterner, new_symbol: &'static str, expected_id: u32) {
        let (_inserted, symbol) = interner.intern_or_get(new_symbol);
        assert_eq!(symbol.id.get(), expected_id);
    }

    /// Resolves the symbol with id `symbol_id` and asserts it maps to `expected_str`
    /// (`None` asserts that the id does not resolve).
    fn assert_resolve<E>(interner: &mut StringInterner, symbol_id: u32, expected_str: E)
    where
        E: Into<Option<&'static str>>,
    {
        let symbol = Symbol {
            id: NonZeroU32::new(symbol_id).unwrap(),
            marker: PhantomData,
        };
        let resolved: Option<&'static str> = interner.resolve(symbol).copied();
        assert_eq!(resolved, expected_str.into());
    }

    #[test]
    fn simple() {
        let mut interner = StringInterner::new();
        // Re-interning "Hello" must hand back its original id.
        let interned = [("Hello", 1), (", World!", 2), ("1 2 3", 3), ("Hello", 1)];
        for &(text, id) in interned.iter() {
            assert_id(&mut interner, text, id);
        }
        let resolvable = [(1, "Hello"), (2, ", World!"), (3, "1 2 3")];
        for &(id, text) in resolvable.iter() {
            assert_resolve(&mut interner, id, text);
        }
        // Only three distinct values were interned, so id 4 is unknown.
        assert_resolve(&mut interner, 4, None);
    }
}