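//! weighted_trie 0.2.0
//!
//! Weighted prefix trie for autosuggestions.
//!
//! # Documentation
//!
//! Words are inserted with a `u32` weight. Every node along a word's path keeps a
//! short list of the highest-weighted words passing through it, so a prefix lookup
//! is a single walk down the trie.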
use compact_str::CompactString;
use hashbrown::HashMap;
use smallvec::SmallVec;
use std::cmp::Reverse;
use std::mem::size_of;

/// Default cap on how many suggestions each trie node retains.
const MAX_SUGGESTIONS_PER_NODE: usize = 10;
/// Number of `(char, child)` edges kept inline in `Children::Small`; an insert that
/// arrives while the inline form already holds this many entries promotes the node
/// to `Children::Large`.
const SMALL_CHILDREN_CAPACITY: usize = 4;
/// Default maximum word length accepted by `WeightedTrie::insert`, compared against
/// `str::len`, i.e. measured in UTF-8 bytes rather than characters.
const MAX_WORD_LENGTH: usize = 100;

/// Index of a node in the trie's `nodes` vector.
type NodeIndex = u32;
/// Index of a word in the trie's `words` vector.
type WordIndex = u32;
/// One suggestion packed into a single integer: weight in the upper 32 bits,
/// word index in the lower 32 bits (see `pack_suggestion`).
type PackedSuggestion = u64;

/// Combines a weight and a word index into one `PackedSuggestion`.
///
/// The weight goes into the upper 32 bits and the word index into the lower
/// 32 bits, so comparing packed values orders them by weight first.
#[inline(always)]
const fn pack_suggestion(weight: u32, word_idx: WordIndex) -> PackedSuggestion {
    let high = (weight as u64) << u32::BITS;
    let low = word_idx as u64;
    high | low
}

/// Recovers the weight stored in the upper 32 bits of a packed suggestion.
#[inline(always)]
const fn get_weight(packed: PackedSuggestion) -> u32 {
    let upper = packed >> u32::BITS;
    upper as u32
}

/// Recovers the word index stored in the lower 32 bits of a packed suggestion.
#[inline(always)]
const fn get_word_idx(packed: PackedSuggestion) -> WordIndex {
    (packed & (u32::MAX as u64)) as WordIndex
}

/// Outgoing edges of a trie node, keyed by the next character.
///
/// Starts in the inline `Small` form; `Children::insert` switches it to the
/// `Large` hash-map form when a new edge must be added while the inline form
/// already holds `SMALL_CHILDREN_CAPACITY` entries.
#[derive(Clone)]
enum Children {
    /// Up to `SMALL_CHILDREN_CAPACITY` `(char, child index)` pairs, searched linearly.
    Small(SmallVec<[(char, NodeIndex); SMALL_CHILDREN_CAPACITY]>),
    /// Character -> child index map used after promotion from `Small`.
    Large(HashMap<char, NodeIndex>),
}

impl Children {
    /// Creates an empty edge set in the inline (`Small`) representation.
    #[inline]
    fn new() -> Self {
        Self::Small(SmallVec::new())
    }

    /// Returns the index of the child reached through character `c`, if any.
    #[inline]
    fn get(&self, c: char) -> Option<NodeIndex> {
        match self {
            Self::Small(vec) => vec.iter().find_map(|&(ch, idx)| (ch == c).then_some(idx)),
            Self::Large(map) => map.get(&c).copied(),
        }
    }

    /// Points the edge for `c` at `idx`, overwriting any existing edge for `c`.
    ///
    /// The inline form is promoted to a hash map only when a *new* character
    /// must be added while it already holds `SMALL_CHILDREN_CAPACITY` entries;
    /// overwriting an existing edge never triggers promotion.
    #[inline]
    fn insert(&mut self, c: char, idx: NodeIndex) {
        match self {
            Self::Small(vec) => {
                // Check for an existing edge first so a full inline vector is not
                // spilled to the heap just to overwrite one of its own entries.
                if let Some(entry) = vec.iter_mut().find(|(ch, _)| *ch == c) {
                    entry.1 = idx;
                } else if vec.len() < SMALL_CHILDREN_CAPACITY {
                    vec.push((c, idx));
                } else {
                    *self = Self::promote(vec, c, idx);
                }
            }
            Self::Large(map) => {
                map.insert(c, idx);
            }
        }
    }

    /// Builds the `Large` form from a full inline vector plus one new edge.
    ///
    /// Kept out of line and marked `#[cold]`: a node is promoted at most once.
    #[cold]
    fn promote(
        vec: &mut SmallVec<[(char, NodeIndex); SMALL_CHILDREN_CAPACITY]>,
        c: char,
        idx: NodeIndex,
    ) -> Self {
        let mut map: HashMap<_, _> = vec.drain(..).collect();
        map.insert(c, idx);
        Self::Large(map)
    }
}

/// A single trie node: its outgoing edges plus the suggestions recorded for
/// the prefix that ends at this node.
#[derive(Default)]
pub struct TrieNode {
    /// Outgoing edges keyed by the next character.
    children: Children,
    /// Packed `(weight, word index)` suggestions for words passing through this
    /// node, kept in descending weight order and capped at the owning trie's
    /// `max_suggestions`.
    suggestions: SmallVec<[PackedSuggestion; 2]>,
}

impl Default for Children {
    fn default() -> Self {
        Self::new()
    }
}

/// Prefix trie that returns the highest-weighted words for a given prefix.
///
/// Nodes are stored in a flat vector (`nodes`) and addressed by `NodeIndex`;
/// each distinct word gets one slot in `words` (plus a copy as a `word_map`
/// key) and is referenced from nodes by `WordIndex`.
pub struct WeightedTrie {
    /// All nodes; `nodes[root]` is the root node.
    nodes: Vec<TrieNode>,
    /// Index of the root node (the constructors in this file set it to 0).
    root: NodeIndex,
    /// Distinct inserted words, indexed by `WordIndex`.
    words: Vec<CompactString>,
    /// Reverse lookup from word text to its index in `words`.
    word_map: HashMap<CompactString, WordIndex>,
    /// `insert` rejects words whose `str::len` (UTF-8 byte length) exceeds this.
    max_word_length: usize,
    /// Maximum number of suggestions retained at each node.
    max_suggestions: usize,
}

/// A word paired with its ranking weight; the input element of `WeightedTrie::build`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct WeightedString {
    /// The word to insert.
    pub word: String,
    /// Ranking weight; higher weights are suggested first.
    pub weight: u32,
}

impl WeightedString {
    /// Creates a `WeightedString` from anything convertible into a `String`.
    pub fn new(word: impl Into<String>, weight: u32) -> Self {
        Self {
            word: word.into(),
            weight,
        }
    }
}

/// Approximate memory breakdown of a `WeightedTrie`, as returned by
/// `WeightedTrie::memory_stats`. All byte figures are estimates.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct MemoryStats {
    /// Number of nodes in the trie, including the root.
    pub nodes_count: usize,
    /// Capacity of the node vector, in nodes.
    pub nodes_vec_capacity: usize,
    /// `nodes_count * size_of::<TrieNode>()`: bytes of node structs in use.
    pub nodes_struct_size: usize,
    /// Number of distinct words stored.
    pub words_count: usize,
    /// Sum of the byte lengths of all stored words.
    pub words_storage_bytes: usize,
    /// Sum of the reported capacities of all stored words.
    pub words_capacity_bytes: usize,
    /// Capacity (in entries) of the word -> index map.
    pub word_map_capacity: usize,
    /// Total number of suggestions across all nodes.
    pub suggestions_total: usize,
    /// Heap bytes of suggestion lists that spilled out of inline storage.
    pub suggestions_heap_bytes: usize,
    /// Number of nodes whose children use the inline representation.
    pub children_small_count: usize,
    /// Number of nodes whose children use the hash-map representation.
    pub children_large_count: usize,
    /// Estimated heap bytes of hash-map children.
    pub children_heap_bytes: usize,
    /// Rough estimate of the trie's total memory footprint in bytes.
    pub total_bytes: usize,
}

impl WeightedTrie {
    pub fn new() -> Self {
        Self::with_config(MAX_WORD_LENGTH, MAX_SUGGESTIONS_PER_NODE)
    }

    pub fn with_max_word_length(max_word_length: usize) -> Self {
        Self::with_config(max_word_length, MAX_SUGGESTIONS_PER_NODE)
    }

    pub fn with_max_suggestions(max_suggestions: usize) -> Self {
        Self::with_config(MAX_WORD_LENGTH, max_suggestions)
    }

    pub fn with_config(max_word_length: usize, max_suggestions: usize) -> Self {
        Self {
            nodes: vec![TrieNode::default()],
            root: 0,
            words: Vec::new(),
            word_map: HashMap::new(),
            max_word_length,
            max_suggestions,
        }
    }

    pub fn memory_stats(&self) -> MemoryStats {
        let nodes_count = self.nodes.len();
        let nodes_vec_capacity = self.nodes.capacity();
        let nodes_struct_size = nodes_count * size_of::<TrieNode>();

        let words_count = self.words.len();
        let words_storage_bytes: usize = self.words.iter().map(|s| s.len()).sum();
        let words_capacity_bytes: usize = self.words.iter().map(|s| s.capacity()).sum();
        let word_map_capacity = self.word_map.capacity();

        let (
            suggestions_total,
            suggestions_heap_bytes,
            children_small_count,
            children_large_count,
            children_heap_bytes,
        ) = self.nodes.iter().fold(
            (0, 0, 0, 0, 0),
            |(sugg_total, sugg_heap, small, large, child_heap), node| {
                let sugg_heap_add = if node.suggestions.spilled() {
                    node.suggestions.capacity() * size_of::<PackedSuggestion>()
                } else {
                    0
                };

                let (small_add, large_add, child_heap_add) = match &node.children {
                    Children::Small(_) => (1, 0, 0),
                    Children::Large(map) => (
                        0,
                        1,
                        map.capacity() * (size_of::<char>() + size_of::<u32>() + 8),
                    ),
                };

                (
                    sugg_total + node.suggestions.len(),
                    sugg_heap + sugg_heap_add,
                    small + small_add,
                    large + large_add,
                    child_heap + child_heap_add,
                )
            },
        );

        let total_bytes = nodes_struct_size
            + nodes_vec_capacity * size_of::<TrieNode>()
            + words_capacity_bytes
            + word_map_capacity * (size_of::<CompactString>() + size_of::<u32>() + 8)
            + suggestions_heap_bytes
            + children_heap_bytes;

        MemoryStats {
            nodes_count,
            nodes_vec_capacity,
            nodes_struct_size,
            words_count,
            words_storage_bytes,
            words_capacity_bytes,
            word_map_capacity,
            suggestions_total,
            suggestions_heap_bytes,
            children_small_count,
            children_large_count,
            children_heap_bytes,
            total_bytes,
        }
    }

    pub fn build(weighted_strings: Vec<WeightedString>) -> Self {
        Self::build_with_config(weighted_strings, MAX_WORD_LENGTH, MAX_SUGGESTIONS_PER_NODE)
    }

    pub fn build_with_max_word_length(
        weighted_strings: Vec<WeightedString>,
        max_word_length: usize,
    ) -> Self {
        Self::build_with_config(weighted_strings, max_word_length, MAX_SUGGESTIONS_PER_NODE)
    }

    pub fn build_with_max_suggestions(
        weighted_strings: Vec<WeightedString>,
        max_suggestions: usize,
    ) -> Self {
        Self::build_with_config(weighted_strings, MAX_WORD_LENGTH, max_suggestions)
    }

    pub fn build_with_config(
        weighted_strings: Vec<WeightedString>,
        max_word_length: usize,
        max_suggestions: usize,
    ) -> Self {
        let count = weighted_strings.len();
        let mut trie = Self {
            nodes: Vec::with_capacity((count * 2).max(1000)),
            root: 0,
            words: Vec::with_capacity(count),
            word_map: HashMap::with_capacity(count),
            max_word_length,
            max_suggestions,
        };
        trie.nodes.push(TrieNode::default());

        for WeightedString { word, weight } in weighted_strings {
            trie.insert(word, weight);
        }

        trie.words.shrink_to_fit();
        trie.word_map.shrink_to_fit();
        trie.nodes.shrink_to_fit();

        trie
    }

    pub fn insert(&mut self, word: impl Into<String>, weight: u32) -> bool {
        let word = word.into();

        if word.len() > self.max_word_length {
            return false;
        }

        let word_compact = CompactString::from(&word);
        let word_idx = *self
            .word_map
            .entry(word_compact.clone())
            .or_insert_with(|| {
                self.words.push(word_compact);
                (self.words.len() - 1) as WordIndex
            });

        let packed = pack_suggestion(weight, word_idx);
        let mut node_idx = self.root;

        for c in word.chars() {
            node_idx = self.get_or_create_child(node_idx, c);
            self.insert_suggestion(node_idx, word_idx, packed, weight);
        }

        true
    }

    #[inline]
    fn get_or_create_child(&mut self, node_idx: NodeIndex, c: char) -> NodeIndex {
        if let Some(idx) = self.nodes[node_idx as usize].children.get(c) {
            return idx;
        }

        let new_idx = self.nodes.len() as NodeIndex;
        self.nodes.push(TrieNode::default());
        self.nodes[node_idx as usize].children.insert(c, new_idx);
        new_idx
    }

    #[inline]
    fn insert_suggestion(
        &mut self,
        node_idx: NodeIndex,
        word_idx: WordIndex,
        packed: PackedSuggestion,
        weight: u32,
    ) {
        let node = &mut self.nodes[node_idx as usize];

        if let Some(pos) = node
            .suggestions
            .iter()
            .position(|&p| get_word_idx(p) == word_idx)
        {
            if weight > get_weight(node.suggestions[pos]) {
                node.suggestions.remove(pos);
            } else {
                return;
            }
        }

        let pos = node
            .suggestions
            .binary_search_by_key(&Reverse(weight), |&p| Reverse(get_weight(p)))
            .unwrap_or_else(|x| x);

        node.suggestions.insert(pos, packed);

        if node.suggestions.len() > self.max_suggestions {
            node.suggestions.truncate(self.max_suggestions);
        }
    }

    pub fn search(&self, prefix: &str) -> Vec<String> {
        let mut node_idx = self.root;

        for c in prefix.chars() {
            node_idx = match self.nodes[node_idx as usize].children.get(c) {
                Some(idx) => idx,
                None => return Vec::new(),
            };
        }

        self.nodes[node_idx as usize]
            .suggestions
            .iter()
            .map(|&packed| self.words[get_word_idx(packed) as usize].to_string())
            .collect()
    }

    pub fn max_word_length(&self) -> usize {
        self.max_word_length
    }

    pub fn max_suggestions(&self) -> usize {
        self.max_suggestions
    }
}

impl Default for WeightedTrie {
    fn default() -> Self {
        Self::new()
    }
}