// parattice 0.2.2
//
// Recursive paraphrase lattice generator
// Documentation
use crate::lattice::Lattice;
use crate::lattice::LatticeNode;

use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;

/// A single state of the pattern-matching automaton built by
/// `LatticeSearcher::new`.
///
/// State `0` is the root; every other state is created while the patterns are
/// inserted word by word.
#[derive(Debug, Clone)]
pub struct PMANode<'a> {
    /// Outgoing transitions keyed by the word that labels them; the value is
    /// the index of the target state.
    edges: HashMap<&'a str, usize>,
    /// Index of the state to fall back to when no outgoing edge matches.
    fail: usize,
    /// Indices (into the pattern list) of the patterns recognised on reaching
    /// this state.
    matched: Vec<usize>,
}

/// Searches a `Lattice` for occurrences of a fixed set of word-sequence
/// patterns.
///
/// Built by `LatticeSearcher::new`, which compiles the patterns into the
/// automaton stored in `pma`.
pub struct LatticeSearcher<'a> {
    /// Automaton states; index `0` is the root state.
    pma: Vec<PMANode<'a>>,
    /// The patterns as passed to `new`; the `matched` entries of the automaton
    /// states index into this list.
    patterns: Vec<Vec<&'a str>>,
}

impl<'a> LatticeSearcher<'a> {
    /// Builds a searcher for `patterns`, where each pattern is a sequence of
    /// words.
    ///
    /// Construction is Aho–Corasick style and runs in two phases:
    ///
    /// 1. every pattern is inserted into a trie rooted at state `0`, and its
    ///    index is recorded on the state where it ends;
    /// 2. the trie is walked breadth-first to give every state below the first
    ///    level a failure link and to append the patterns recognised by the
    ///    failure target to the state's own `matched` list.
    pub fn new(patterns: Vec<Vec<&'a str>>) -> LatticeSearcher<'a> {
        let mut pma = vec![PMANode {
            edges: HashMap::new(),
            fail: 0,
            matched: vec![],
        }];

        // Phase 1: trie insertion.
        for (pattern_id, pattern) in patterns.iter().enumerate() {
            let mut current = 0;
            for &word in pattern {
                // Existing states always have an id below `fresh_id`, so getting
                // `fresh_id` back means the edge was just created.
                let fresh_id = pma.len();
                let target = *pma[current].edges.entry(word).or_insert(fresh_id);
                if target == fresh_id {
                    pma.push(PMANode {
                        edges: HashMap::new(),
                        fail: 0,
                        matched: vec![],
                    });
                }
                current = target;
            }
            pma[current].matched.push(pattern_id);
        }

        // Phase 2: failure links. First-level states keep `fail == 0`.
        let mut pending: VecDeque<usize> = pma[0].edges.values().cloned().collect();
        while let Some(parent) = pending.pop_front() {
            // Snapshot the edges so `pma` can be mutated while they are visited.
            let children: Vec<(&'a str, usize)> = pma[parent]
                .edges
                .iter()
                .map(|(&word, &child)| (word, child))
                .collect();
            for (word, child) in children {
                pending.push_back(child);

                // Follow the parent's failure chain until a state with an edge
                // for `word` is found; fall back to the root otherwise.
                let mut probe = pma[parent].fail;
                let fallback = loop {
                    if let Some(&target) = pma[probe].edges.get(word) {
                        break target;
                    }
                    if probe == 0 {
                        break 0;
                    }
                    probe = pma[probe].fail;
                };

                // The failure target is shallower than `child`, so its `matched`
                // list is already final at this point.
                let inherited = pma[fallback].matched.clone();
                pma[child].fail = fallback;
                pma[child].matched.extend(inherited);
            }
        }

        LatticeSearcher { pma, patterns }
    }

    /// Returns the automaton state reached from `state_id` after reading
    /// `edge_str`.
    ///
    /// Starting at `state_id`, the first state along the failure chain that has
    /// an outgoing edge labelled `edge_str` decides the result. If the root
    /// (state `0`) is reached without finding such an edge, the result is `0`.
    ///
    /// Panics if `state_id` (or a failure link) is out of bounds for `pma`.
    fn next_pma_state_id(pma: &[PMANode], state_id: usize, edge_str: &str) -> usize {
        let mut current = state_id;
        loop {
            match pma[current].edges.get(edge_str) {
                Some(&target) => return target,
                None if current == 0 => return 0,
                None => current = pma[current].fail,
            }
        }
    }

    /// Collects every path in `lattice` that ends at node `pos` and spells
    /// `phrase` when read forwards.
    ///
    /// The search walks `backwards` edges from `pos`, consuming `phrase` from
    /// its last word to its first. Each returned path has `phrase.len() + 1`
    /// entries: the first is `("", start_node)`, and every following entry is
    /// `(word, node)` where `word` labels the edge that leads into `node`.
    /// For an empty `phrase` the single path `[("", pos)]` is returned.
    fn backward_match(
        phrase: &[&'a str],
        lattice: &'a [LatticeNode],
        pos: usize,
    ) -> Vec<VecDeque<(&'a str, usize)>> {
        let mut complete = vec![];
        let mut pending = VecDeque::new();
        let mut start = VecDeque::new();
        start.push_front(("", pos));
        pending.push_back((phrase.len(), pos, start));
        while let Some((remaining, node_id, path)) = pending.pop_front() {
            if remaining == 0 {
                complete.push(path);
                continue;
            }
            let wanted = phrase[remaining - 1];
            for &(edge_str, edge_target) in lattice[node_id].backwards.keys() {
                if edge_str != wanted {
                    continue;
                }
                let mut extended = path.clone();
                // The current head is the node this edge leads into; label it
                // with the word before prepending the edge's source node.
                if let Some(head) = extended.front_mut() {
                    head.0 = edge_str;
                }
                extended.push_front(("", edge_target));
                pending.push_back((remaining - 1, edge_target, extended));
            }
        }
        complete
    }

    /// Finds the occurrences of all patterns in `lattice`.
    ///
    /// The lattice is explored breadth-first from node `0`, pairing every
    /// reached node with the automaton state obtained along the way; each
    /// `(node, state)` pair is expanded once. Whenever the state recognises a
    /// pattern, the concrete paths ending at that node are recovered by walking
    /// the lattice backwards.
    ///
    /// Each match is a path whose first entry is `("", start_node)` and whose
    /// remaining entries are `(word, node)` pairs, `node` being the node that
    /// the edge labelled `word` leads into.
    ///
    /// A pattern ending at a given node is reported once, even when several
    /// automaton states reached at that node recognise it. An empty lattice
    /// yields an empty result.
    pub fn search(&self, lattice: &'a Lattice) -> Vec<Vec<(&'a str, usize)>> {
        let lattice = &lattice.lattice;
        let mut result = vec![];
        if lattice.is_empty() {
            // There is no start node to index.
            return result;
        }

        let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
        // (lattice node, automaton state) pairs that have been enqueued.
        let mut visited: HashSet<(usize, usize)> = HashSet::new();
        // (lattice node, pattern index) pairs whose matches have been emitted.
        let mut reported: HashSet<(usize, usize)> = HashSet::new();
        queue.push_back((0, 0));
        visited.insert((0, 0));

        while let Some((lattice_node_id, pma_state_id)) = queue.pop_front() {
            for (lattice_edge_str, lattice_edge_target) in lattice[lattice_node_id].forwards.keys()
            {
                let pma_state_id_new =
                    Self::next_pma_state_id(&self.pma, pma_state_id, lattice_edge_str);
                // `insert` returns false when the pair was already present.
                if visited.insert((*lattice_edge_target, pma_state_id_new)) {
                    queue.push_back((*lattice_edge_target, pma_state_id_new));
                }
            }
            for &phrase_id in &self.pma[pma_state_id].matched {
                // The backward walk does not depend on the automaton state, so
                // repeating it for another state at the same node would only
                // duplicate the same paths.
                if !reported.insert((lattice_node_id, phrase_id)) {
                    continue;
                }
                let phrase = &self.patterns[phrase_id];
                result.extend(
                    Self::backward_match(phrase, lattice, lattice_node_id)
                        .into_iter()
                        .map(|path| path.into_iter().collect::<Vec<_>>()),
                );
            }
        }
        result
    }
}