use super::PrefixMap;
use fst::{Automaton, IntoStreamer, Map, MapBuilder, Streamer};
use std::borrow::Borrow;
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::hash::Hash;
/// `fst` automaton over keys: a key matches when its bytes are a prefix of the
/// wrapped input bytes (the whole input counts as a prefix of itself).
#[derive(Clone, Debug)]
struct Prefixes<'a>(&'a [u8]);
/// State of the [`Prefixes`] automaton while walking one FST key.
#[derive(Clone, Copy, Debug)]
enum PrefixesState {
    /// Every key byte seen so far equals the input; holds how many input bytes
    /// have been matched.
    Pos(usize),
    /// The key diverged from the input or ran past its end; terminal.
    Fail,
    /// Produced by `accept_eof` when a key ends while still in `Pos`.
    Success,
}
/// Matches FST keys that are byte prefixes of the wrapped input.
///
/// The automaton steps through the input one byte per key byte. As soon as a
/// key byte differs from the input (or the key is longer than the input) it
/// enters [`PrefixesState::Fail`] and stays there. A key is reported only via
/// `accept_eof`, i.e. when the key ends while the state is still `Pos`.
/// `start()` is not a match state, so callers wanting the zero-length key must
/// look it up separately.
impl Automaton for Prefixes<'_> {
    type State = PrefixesState;

    /// Begins with zero input bytes matched.
    fn start(&self) -> Self::State {
        PrefixesState::Pos(0)
    }

    /// Only the state produced by `accept_eof` is a match.
    fn is_match(&self, state: &Self::State) -> bool {
        matches!(state, PrefixesState::Success)
    }

    /// `Fail` is terminal, so no key below it can ever match. Returning `false`
    /// lets the `fst` stream prune that whole subtree; the trait's default
    /// (`true`) is correct but forces the search to visit every node of the
    /// map instead of only the path spelled by the input.
    fn can_match(&self, state: &Self::State) -> bool {
        !matches!(state, PrefixesState::Fail)
    }

    /// Advances while `byte` equals the next unmatched input byte; any other
    /// byte (or running past the end of the input) fails.
    fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
        match state {
            &PrefixesState::Pos(ind) if self.0.get(ind) == Some(&byte) => {
                PrefixesState::Pos(ind + 1)
            }
            _ => PrefixesState::Fail,
        }
    }

    /// A key that ends while still in `Pos` is a prefix of the input.
    fn accept_eof(&self, state: &Self::State) -> Option<Self::State> {
        match state {
            PrefixesState::Pos(_) => Some(PrefixesState::Success),
            _ => None,
        }
    }
}
/// Prefix map backed by an immutable [`fst::Map`].
///
/// The FST maps each key to an index into `values`. Equal values are stored
/// only once (see `from_vec`), so several keys may share one slot.
#[derive(Clone, Debug)]
pub struct FstPrefixMap<V> {
    /// Key -> index into `values`.
    map: Map<Vec<u8>>,
    /// Deduplicated values, ordered by first appearance during construction.
    values: Box<[V]>,
}
impl<V> FstPrefixMap<V>
where
    V: Hash + Eq,
{
    /// Builds a map from `(key, value)` pairs.
    ///
    /// Pairs are stably sorted by the key's bytes, then repeated keys are
    /// collapsed by `super::remove_ordered_dups` (presumably keeping the last
    /// pair for a key, as the `"bc"` test case expects — confirm against the
    /// helper). Equal values are deduplicated: each distinct value is stored
    /// once and the FST stores that value's index.
    pub fn from_vec<K, B>(inp: B) -> Self
    where
        K: Borrow<str> + Eq,
        B: Into<Vec<(K, V)>>,
    {
        let mut pairs: Vec<(K, V)> = inp.into();
        // Must be a stable sort so duplicate keys keep their input order for
        // the dedup helper; the FST builder also requires byte-sorted keys.
        pairs.sort_by(|a, b| {
            let (lhs, rhs): (&str, &str) = (a.0.borrow(), b.0.borrow());
            lhs.as_bytes().cmp(rhs.as_bytes())
        });
        super::remove_ordered_dups(&mut pairs);

        let mut builder = MapBuilder::memory();
        let mut slot_of: HashMap<V, usize> = HashMap::new();
        for (key, value) in pairs {
            let fresh = slot_of.len();
            let slot = match slot_of.entry(value) {
                Entry::Occupied(found) => *found.get(),
                Entry::Vacant(empty) => *empty.insert(fresh),
            };
            let out = u64::try_from(slot)
                .unwrap_or_else(|_| unreachable!("casting a usize into a u64"));
            builder
                .insert(key.borrow().as_bytes(), out)
                .unwrap_or_else(|_| unreachable!("memory builder"));
        }

        // Lay the distinct values out so that `values[slot]` is the value
        // assigned that slot above.
        let mut by_slot: Vec<(V, usize)> = slot_of.into_iter().collect();
        by_slot.sort_unstable_by_key(|entry| entry.1);
        let values: Box<[V]> = by_slot.into_iter().map(|(value, _)| value).collect();

        FstPrefixMap {
            map: builder.into_map(),
            values,
        }
    }
}
/// Collects `(key, value)` pairs into an [`FstPrefixMap`].
impl<K, V> FromIterator<(K, V)> for FstPrefixMap<V>
where
    K: Borrow<str> + Eq,
    V: Hash + Eq,
{
    /// Gathers the pairs into a `Vec` and delegates to `from_vec`.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
    {
        let pairs: Vec<(K, V)> = iter.into_iter().collect();
        Self::from_vec(pairs)
    }
}
impl<V> PrefixMap<V> for FstPrefixMap<V> {
    /// Returns the value of the longest key that is a prefix of `inp`.
    ///
    /// The `usize` is the matched key's length in bytes. Returns `None` when no
    /// key (including the empty key) is a prefix of `inp`.
    ///
    /// # Panics
    ///
    /// Only if a stored FST output does not fit in `usize`; `from_vec` writes
    /// outputs that originated as `usize` indices.
    fn get_longest_prefix<K>(&self, inp: K) -> Option<(usize, &V)>
    where
        K: AsRef<str>,
    {
        let haystack = inp.as_ref().as_bytes();

        // `Prefixes` does not treat its start state as a match, so the
        // zero-length key is looked up explicitly to seed the result.
        let mut best = None;
        if let Some(out) = self.map.get([]) {
            best = Some((0, &self.values[usize::try_from(out).unwrap()]));
        }

        let automaton = Prefixes(haystack);
        let mut hits = self.map.search(&automaton).into_stream();
        while let Some((key, out)) = hits.next() {
            let improves = match best {
                None => true,
                Some((len, _)) => key.len() > len,
            };
            if improves {
                let slot: usize = out.try_into().unwrap();
                best = Some((key.len(), &self.values[slot]));
            }
        }
        best
    }
}
#[cfg(test)]
mod tests {
    use super::{FstPrefixMap, PrefixMap};

    /// Longest match wins; repeated keys resolve to the later pair.
    #[test]
    fn correct_prefixes() {
        let map = FstPrefixMap::from_vec([("a", 0), ("abc", 1), ("bc", 2), ("bc", 3)]);
        assert_eq!(map.get_longest_prefix("abcd"), Some((3, &1)));
        assert_eq!(map.get_longest_prefix("ab"), Some((1, &0)));
        assert_eq!(map.get_longest_prefix("bca"), Some((2, &3)));
        assert_eq!(map.get_longest_prefix("bd"), None);
        assert_eq!(map.get_longest_prefix("💖"), None);
    }

    /// The empty key acts as a fallback and whitespace keys behave normally.
    #[test]
    fn works_for_perverse() {
        let map = FstPrefixMap::from_vec([("", 0), (" 3", 1)]);
        assert_eq!(map.get_longest_prefix(" 3 "), Some((2, &1)));
        assert_eq!(map.get_longest_prefix("ab"), Some((0, &0)));
    }

    /// Empty input only matches the empty key.
    #[test]
    fn empty_input() {
        let without_empty_key = FstPrefixMap::from_vec([("a", 0), ("ab", 1)]);
        assert_eq!(without_empty_key.get_longest_prefix(""), None);

        let with_empty_key = FstPrefixMap::from_vec([("", 5), ("a", 0)]);
        assert_eq!(with_empty_key.get_longest_prefix(""), Some((0, &5)));
    }

    /// A map with no keys never matches.
    #[test]
    fn empty_map() {
        let map: FstPrefixMap<u8> = std::iter::empty::<(&str, u8)>().collect();
        assert_eq!(map.get_longest_prefix("anything"), None);
        assert_eq!(map.get_longest_prefix(""), None);
    }

    /// Equal values are stored once but remain reachable from every key.
    #[test]
    fn shares_equal_values() {
        let map = FstPrefixMap::from_vec([("x", 7), ("y", 7), ("z", 8)]);
        assert_eq!(map.values.len(), 2);
        assert_eq!(map.get_longest_prefix("xa"), Some((1, &7)));
        assert_eq!(map.get_longest_prefix("ya"), Some((1, &7)));
        assert_eq!(map.get_longest_prefix("z"), Some((1, &8)));
    }

    /// Reported lengths are byte lengths, also for multibyte characters.
    #[test]
    fn multibyte_keys() {
        let map = FstPrefixMap::from_vec([("💖", 1), ("💖a", 2)]);
        assert_eq!(map.get_longest_prefix("💖ab"), Some((5, &2)));
        assert_eq!(map.get_longest_prefix("💖b"), Some((4, &1)));
    }
}