use rustc_hash::{FxHashMap, FxHashSet};
use super::node::{ScdawgNode, NIL};
use crate::value::DictionaryValue;
use crate::CharUnit;
/// Mutable core of the suffix-automaton index: a node arena plus the
/// bookkeeping needed to map automaton nodes back to the inserted terms.
///
/// `U` is the unit the terms are split into (see `CharUnit::iter_str`), `V` is
/// the optional per-term payload.
#[derive(Debug)]
pub struct ScdawgCoreInner<U: CharUnit, V: DictionaryValue> {
/// Node arena. Both constructors place `ScdawgNode::root()` at index 0, and
/// all node references elsewhere in this type are indices into this vector.
pub nodes: Vec<ScdawgNode<U, V>>,
/// Index of the node the most recent `sa_extend` call ended on; `insert`
/// resets it to 0 (the root) before feeding a new term.
pub last: usize,
/// Number of terms accepted by `insert`; also used as the index assigned to
/// the next inserted term.
pub term_count: usize,
/// Accepted terms in insertion order. The `term_idx` stored in a node's
/// `term_ends` is an index into this vector.
pub terms: Vec<String>,
/// Set of accepted terms, used for duplicate detection and `contains`.
pub term_set: FxHashSet<String>,
/// Latest value recorded for each term through `insert_with_value`.
pub term_values: FxHashMap<String, V>,
/// `true` while the `left_edges` lists on the nodes reflect the current
/// suffix links; cleared by `sa_extend`, set by `compute_left_edges`.
pub left_edges_computed: bool,
}
impl<U: CharUnit, V: DictionaryValue> ScdawgCoreInner<U, V> {
/// Creates an empty core that contains only the root node (index 0) and no
/// terms.
pub fn new() -> Self {
    let root = ScdawgNode::root();
    Self {
        nodes: vec![root],
        last: 0,
        term_count: 0,
        terms: Vec::default(),
        term_set: Default::default(),
        term_values: Default::default(),
        left_edges_computed: false,
    }
}
/// Creates an empty core with storage reserved up front.
///
/// * `term_count` – expected number of terms; sizes `terms`, `term_set` and
///   `term_values`.
/// * `total_chars` – expected total number of units over all terms; the node
///   arena is reserved for twice that many nodes (saturating on overflow).
///
/// The reservation is only a hint: the returned core holds just the root node
/// and reports zero terms.
pub fn with_capacity(term_count: usize, total_chars: usize) -> Self {
    let node_budget = total_chars.saturating_mul(2);
    let mut arena = Vec::with_capacity(node_budget);
    arena.push(ScdawgNode::root());

    let terms = Vec::with_capacity(term_count);
    let term_set = FxHashSet::with_capacity_and_hasher(term_count, Default::default());
    let term_values = FxHashMap::with_capacity_and_hasher(term_count, Default::default());

    Self {
        nodes: arena,
        last: 0,
        term_count: 0,
        terms,
        term_set,
        term_values,
        left_edges_computed: false,
    }
}
/// Appends a node built by `ScdawgNode::new(length, suffix_link, first_char)`
/// to the arena and returns the index it was stored at.
pub fn alloc_node(&mut self, length: usize, suffix_link: usize, first_char: U) -> usize {
    let node = ScdawgNode::new(length, suffix_link, first_char);
    self.nodes.push(node);
    self.nodes.len() - 1
}
/// Appends a copy of node `src` to the arena and returns the index of the
/// copy. The copy is an exact `Clone` of the source; callers adjust the fields
/// that must differ.
///
/// # Panics
///
/// Panics if `src` is not a valid index into `nodes`.
pub fn clone_node(&mut self, src: usize) -> usize {
    let duplicate = self.nodes[src].clone();
    self.nodes.push(duplicate);
    self.nodes.len() - 1
}
/// Extends the automaton by one unit `c`, continuing from `self.last`.
///
/// `term_idx` / `pos` identify the term being inserted and the position of `c`
/// inside it; the pair is appended to `term_ends` of the node that represents
/// the prefix ending at `pos`. On return `self.last` is that node and
/// `left_edges_computed` is cleared.
///
/// Because `insert` restarts every term at the root, `self.last` may already
/// have an outgoing edge on `c` (the prefix read so far occurs in an earlier
/// term). That case is handled by reusing or splitting the existing target
/// instead of allocating a new node: a freshly allocated node would receive no
/// incoming edge, so anything recorded on it (`is_final`, `value`) could not be
/// found by walking edges from the root.
pub fn sa_extend(&mut self, c: U, term_idx: usize, pos: usize) {
    let last = self.last;

    // Generalized case: the transition already exists.
    if let Some(q) = self.nodes[last].get_edge(c) {
        let target = if self.nodes[last].length + 1 == self.nodes[q].length {
            // `q` represents exactly the extended prefix; reuse it.
            q
        } else {
            // `q` also represents longer strings; split off a node whose
            // length matches the extended prefix and continue from it.
            self.split_for_edge(last, q, c)
        };
        self.nodes[target].term_ends.push((term_idx, pos));
        self.last = target;
        self.left_edges_computed = false;
        return;
    }

    // Regular case: allocate the node for the extended prefix.
    let first_char = if self.nodes[last].length == 0 {
        c
    } else {
        self.nodes[last].first_char
    };
    let new_length = self.nodes[last].length + 1;
    let cur = self.alloc_node(new_length, 0, first_char);
    self.nodes[cur].parent = last;
    self.nodes[cur].parent_label = c;
    self.nodes[cur].depth = self.nodes[last].depth + 1;

    // Phase 1: walk the suffix links, adding the missing `c` edge, until a
    // node that already has one (or the end of the chain) is reached.
    let mut p = last;
    while p != NIL && self.nodes[p].get_edge(c).is_none() {
        self.nodes[p].set_edge(c, cur);
        p = self.nodes[p].suffix_link;
    }

    // Phase 2: choose the suffix link of the new node.
    let link = if p == NIL {
        0
    } else {
        let q = self.nodes[p]
            .get_edge(c)
            .expect("invariant: p has edge c by Phase 1 break condition");
        if self.nodes[p].length + 1 == self.nodes[q].length {
            q
        } else {
            self.split_for_edge(p, q, c)
        }
    };
    self.nodes[cur].suffix_link = link;

    self.nodes[cur].term_ends.push((term_idx, pos));
    self.last = cur;
    self.left_edges_computed = false;
}

/// Splits `q`, the target of the `c` edge leaving `p`, by cloning it into a
/// node of length `nodes[p].length + 1`, and returns the clone's index.
///
/// The clone keeps `q`'s outgoing edges and suffix link but starts without
/// term ends, final flag or value. `q`'s suffix link is pointed at the clone,
/// and every `c` edge into `q` found along the suffix-link chain starting at
/// `p` is redirected to the clone.
fn split_for_edge(&mut self, p: usize, q: usize, c: U) -> usize {
    let clone = self.clone_node(q);

    let prefix_length = self.nodes[p].length;
    let first_char = if prefix_length == 0 {
        c
    } else {
        self.nodes[p].first_char
    };
    let depth = self.nodes[p].depth + 1;

    {
        let node = &mut self.nodes[clone];
        node.length = prefix_length + 1;
        node.first_char = first_char;
        node.parent = p;
        node.parent_label = c;
        node.depth = depth;
        node.term_ends.clear();
        node.is_final = false;
        node.value = None;
    }
    self.nodes[q].suffix_link = clone;

    let mut walker = p;
    while walker != NIL && self.nodes[walker].get_edge(c) == Some(q) {
        self.nodes[walker].set_edge(c, clone);
        walker = self.nodes[walker].suffix_link;
    }
    clone
}
/// Inserts `term` into the index.
///
/// Returns `false` and leaves the structure untouched when the term is already
/// present. Otherwise the term is fed unit by unit (as produced by
/// `U::iter_str`) through `sa_extend` starting from the root, the node the
/// last step ended on is flagged `is_final`, the term is recorded in `terms`
/// and `term_set`, and `true` is returned.
pub fn insert(&mut self, term: &str) -> bool {
    let already_present = self.term_set.contains(term);
    if already_present {
        return false;
    }

    // Index this term will occupy in `terms`.
    let index = self.term_count;
    self.last = 0;
    U::iter_str(term)
        .enumerate()
        .for_each(|(pos, unit)| self.sa_extend(unit, index, pos));

    let end = self.last;
    self.nodes[end].is_final = true;

    let owned = term.to_string();
    self.terms.push(owned.clone());
    self.term_set.insert(owned);
    self.term_count += 1;
    true
}
/// Inserts `term` together with `value`, or updates the value of an existing
/// term.
///
/// Returns `true` when the term was newly inserted and `false` when it was
/// already present. In both cases `term_values` ends up holding `value` for
/// `term`.
///
/// For an existing term the per-node copy is refreshed only if the node
/// reached by walking `term` from the root is flagged `is_final`; otherwise
/// only `term_values` changes.
pub fn insert_with_value(&mut self, term: &str, value: V) -> bool {
    if !self.term_set.contains(term) {
        if !self.insert(term) {
            return false;
        }
        // `insert` leaves `last` on the node it flagged as final.
        let end = self.last;
        self.nodes[end].value = Some(value.clone());
        self.term_values.insert(term.to_string(), value);
        return true;
    }

    // Known term: overwrite the stored value.
    self.term_values.insert(term.to_string(), value.clone());
    let final_node = self
        .find_substring_fast(term)
        .filter(|&node| self.nodes[node].is_final);
    if let Some(node) = final_node {
        self.nodes[node].value = Some(value);
    }
    false
}
/// Rebuilds the `left_edges` list of every node from the suffix links.
///
/// Each node other than index 0 whose suffix link is not `NIL` is registered
/// on its suffix-link target as `(first_char, node_index)`. The call is a
/// no-op while `left_edges_computed` is still set.
pub fn compute_left_edges(&mut self) {
    if self.left_edges_computed {
        return;
    }

    self.nodes
        .iter_mut()
        .for_each(|node| node.left_edges.clear());

    let total = self.nodes.len();
    for child in 1..total {
        let (link, label) = {
            let node = &self.nodes[child];
            (node.suffix_link, node.first_char)
        };
        if link == NIL {
            continue;
        }
        self.nodes[link].left_edges.push((label, child));
    }

    self.left_edges_computed = true;
}
/// Walks `pattern` unit by unit from the root along outgoing edges.
///
/// Returns the index of the node reached after the last unit, or `None` as
/// soon as a required edge is missing. The empty pattern yields the root
/// (`Some(0)`).
pub fn find_substring_fast(&self, pattern: &str) -> Option<usize> {
    if pattern.is_empty() {
        return Some(0);
    }
    U::iter_str(pattern).try_fold(0usize, |node, unit| self.nodes[node].get_edge(unit))
}
/// Returns `true` when `pattern` can be walked from the root, i.e. when
/// `find_substring_fast` reaches a node for it.
pub fn contains_substring(&self, pattern: &str) -> bool {
    let reached = self.find_substring_fast(pattern);
    reached.is_some()
}
/// Lists the occurrences of `pattern` as `(term, start_position)` pairs.
///
/// The empty pattern reports every term once at position 0. For any other
/// pattern the node reached by `find_substring_fast` is expanded with
/// `collect_term_positions`; an unmatched pattern yields an empty vector.
///
/// The expansion follows `left_edges`, so `compute_left_edges` should have
/// been called since the last insertion for the result to be complete.
pub fn find_exact_substring(&self, pattern: &str) -> Vec<(String, usize)> {
    if pattern.is_empty() {
        let mut every_term = Vec::with_capacity(self.terms.len());
        for term in &self.terms {
            every_term.push((term.clone(), 0));
        }
        return every_term;
    }

    let mut hits = Vec::new();
    if let Some(end_node) = self.find_substring_fast(pattern) {
        let pattern_len = U::from_str(pattern).len();
        self.collect_term_positions(end_node, pattern_len, &mut hits);
    }
    hits
}
/// Appends to `results` one `(term, start_position)` pair for every term end
/// recorded on `node` or on any node reachable from it through `left_edges`.
///
/// `pattern_len` is the length of the matched pattern in units; the start
/// position is `end_pos + 1 - pattern_len`. Entries whose end position is too
/// small for the pattern, or whose term index is not (yet) present in `terms`,
/// are skipped.
///
/// Nodes are visited in pre-order (a node first, then its left-edge targets in
/// stored order). The traversal uses an explicit stack rather than recursion
/// so that long suffix-link chains cannot overflow the call stack.
pub fn collect_term_positions(
    &self,
    node: usize,
    pattern_len: usize,
    results: &mut Vec<(String, usize)>,
) {
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        let state = &self.nodes[current];

        for &(term_idx, end_pos) in &state.term_ends {
            if end_pos + 1 < pattern_len {
                continue;
            }
            if let Some(term) = self.terms.get(term_idx) {
                results.push((term.clone(), end_pos + 1 - pattern_len));
            }
        }

        // Push the targets, then reverse the pushed run so they are popped in
        // stored order.
        let first_new = pending.len();
        for &(_, target) in &state.left_edges {
            pending.push(target);
        }
        pending[first_new..].reverse();
    }
}
/// Returns `true` when `term` was inserted as a whole term (membership in
/// `term_set`); substrings of inserted terms do not count.
pub fn contains(&self, term: &str) -> bool {
self.term_set.contains(term)
}
/// Returns the number of terms accepted so far (the `term_count` field).
pub fn term_count(&self) -> usize {
self.term_count
}
/// Iterates over the stored terms in the order they appear in `terms`
/// (`insert` appends to that vector).
pub fn iter_terms(&self) -> impl Iterator<Item = &String> {
self.terms.iter()
}
/// Counts the occurrences of `pattern` across all inserted terms.
///
/// For the empty pattern the result is the sum over all terms of
/// `unit_length + 1`. For any other pattern the node reached by
/// `find_substring_fast` is expanded with `count_occurrences`; an unmatched
/// pattern counts as 0.
///
/// The expansion follows `left_edges`, so `compute_left_edges` should have
/// been called since the last insertion for the count to be complete.
pub fn frequency(&self, pattern: &str) -> usize {
    if pattern.is_empty() {
        let mut total = 0;
        for term in &self.terms {
            total += U::from_str(term).len() + 1;
        }
        return total;
    }

    let mut occurrences = 0;
    if let Some(node) = self.find_substring_fast(pattern) {
        self.count_occurrences(node, &mut occurrences);
    }
    occurrences
}
/// Adds to `count` the number of term ends recorded on `node` and on every
/// node reachable from it through `left_edges`.
///
/// The traversal uses an explicit stack rather than recursion so that long
/// suffix-link chains cannot overflow the call stack.
pub fn count_occurrences(&self, node: usize, count: &mut usize) {
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        let state = &self.nodes[current];
        *count += state.term_ends.len();
        for &(_, target) in &state.left_edges {
            pending.push(target);
        }
    }
}
}
impl<U: CharUnit, V: DictionaryValue> Default for ScdawgCoreInner<U, V> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte units: duplicate rejection, substring lookup and frequencies.
    #[test]
    fn scdawg_inner_byte_smoke() {
        let mut core: ScdawgCoreInner<u8, ()> = ScdawgCoreInner::new();

        assert!(core.insert("cat"));
        assert!(core.insert("car"));
        // A repeated term is rejected and does not change the count.
        assert!(!core.insert("cat"));
        assert_eq!(core.term_count(), 2);

        assert!(core.contains_substring("ca"));
        assert!(core.contains_substring("at"));
        assert!(!core.contains_substring("zz"));

        // Frequencies rely on the left edges being up to date.
        core.compute_left_edges();
        assert_eq!(core.frequency("ca"), 2);
        assert_eq!(core.frequency("at"), 1);
    }

    /// `char` units: a non-ASCII term is handled per character.
    #[test]
    fn scdawg_inner_char_smoke() {
        let mut core: ScdawgCoreInner<char, ()> = ScdawgCoreInner::new();

        assert!(core.insert("café"));
        assert!(!core.insert("café"));
        assert_eq!(core.term_count(), 1);

        assert!(core.contains_substring("café"));
        assert!(core.contains_substring("afé"));

        core.compute_left_edges();
        assert_eq!(core.frequency("café"), 1);
    }
}