//! vom_rs 0.3.0
//!
//! A library for Probabilistic Finite Automata.
//!
//! Documentation
use std::cmp::min;
use std::collections::{HashMap, VecDeque};
use std::fmt::Write;
use std::hash::Hash;

/**********************************************************************************************/
/* Helper functions to determine empirical probability of tokens and symbols within a sample. */
/**********************************************************************************************/

/// Chi function following Ron, Singer and Tishby (1996).
///
/// Returns `1` if `token` occupies `sample[j - token.len()..j]`, i.e. the token
/// ends just before index `j`, and `0` otherwise. A `j` for which that window does
/// not fit inside `sample` (too small to hold the token, or past the end of the
/// sample) yields `0` instead of panicking.
pub fn chi<T: Eq>(sample: &[T], token: &[T], j: usize) -> u32 {
    let window = j
        .checked_sub(token.len())
        .and_then(|start| sample.get(start..j));
    u32::from(window == Some(token))
}

/// Chi function, but with token + symbol instead of just the token.
///
/// Returns `1` when `symbol` sits at index `j - 1` of `sample` and `token` ends
/// immediately before it (i.e. `chi(sample, token, j - 1) == 1`), otherwise `0`.
/// `j == 0` and any `j` beyond `sample.len()` yield `0` rather than panicking.
pub fn chi_with_symbol<T: Eq>(sample: &[T], token: &[T], symbol: &T, j: usize) -> u32 {
    match j.checked_sub(1) {
        // Check the symbol first: `get` bounds-checks `pos` before `chi` slices up to it.
        Some(pos) => u32::from(sample.get(pos) == Some(symbol) && chi(sample, token, pos) == 1),
        // Nothing can precede the start of the sample.
        None => 0,
    }
}

/// determine empirical probablity of token given sample
pub fn empirical_probability_of_token<T: Eq>(sample: &[T], token: &[T], bound: usize) -> f32 {
    let mut sum = 0;
    for j in min(bound, token.len())..sample.len() {
        sum += chi(sample, token, j);
    }
    sum as f32 / (sample.len() - bound) as f32
}

/// Determine the empirical probability of `symbol` given `sample`.
///
/// Counts the occurrences of `symbol` anywhere in `sample` and divides by
/// `sample.len() - bound`. The count covers the whole sample but the denominator
/// does not. As a result, the value can exceed the symbol's plain relative
/// frequency, and can even exceed `1.0` for a large `bound`.
///
/// Returns `0.0` when `sample.len() <= bound`, instead of dividing by zero or
/// underflowing.
pub fn empirical_probability_of_symbol<T: Eq>(sample: &[T], symbol: &T, bound: usize) -> f32 {
    if sample.len() <= bound {
        return 0.0;
    }
    let occurrences = sample.iter().filter(|&s| s == symbol).count();
    occurrences as f32 / (sample.len() - bound) as f32
}

/// Determine the empirical probability of `symbol` directly following `token`.
///
/// For every `j` in `bound..sample.len()` this counts two things:
/// - how often `token` ends right before index `j` (see [`chi`]);
/// - of those, how often `sample[j]` is `symbol` (see [`chi_with_symbol`]).
///
/// The result is the ratio of the two counts, or `0.0` if both are zero. An empty
/// `token` falls back to [`empirical_probability_of_symbol`].
pub fn empirical_probability_of_symbol_given_token<T: Eq>(
    sample: &[T],
    token: &[T],
    symbol: &T,
    bound: usize,
) -> f32 {
    if token.is_empty() {
        // No context at all: use the plain symbol probability instead.
        return empirical_probability_of_symbol(sample, symbol, bound);
    }

    let (followed_by_symbol, token_hits) =
        (bound..sample.len()).fold((0u32, 0u32), |(with_sym, hits), j| {
            (
                with_sym + chi_with_symbol(sample, token, symbol, j + 1),
                hits + chi(sample, token, j),
            )
        });

    if followed_by_symbol == 0 && token_hits == 0 {
        0.0
    } else {
        followed_by_symbol as f32 / token_hits as f32
    }
}

/****************************************************************/
/* Node structure and operations of Probabilistic Suffix Trees. */
/****************************************************************/

/// Node structure to represent a Probabilistic Suffix Tree.
///
/// Each node is identified by its `label` (a sequence of symbols). It stores a
/// probability per symbol in `child_probability`, and owns its children, keyed by
/// the symbol that is prepended to this node's label to form the child's label.
#[derive(PartialEq, Clone, Debug)]
pub struct PstNode<T: Eq + Copy + Hash + std::fmt::Debug> {
    /// The symbol sequence this node stands for (empty for the root).
    pub label: Vec<T>,
    /// Probability assigned to each symbol of the alphabet at this node.
    pub child_probability: HashMap<T, f32>,
    /// Child nodes, keyed by the symbol prepended to `label`.
    pub children: HashMap<T, PstNode<T>>,
}

impl<T: Eq + Copy + Hash + std::fmt::Debug> PstNode<T> {
    /// Create a probabilistic suffix tree node with an empty label, no
    /// probabilities and no children.
    pub fn with_empty_label() -> Self {
        PstNode {
            label: Vec::new(),
            child_probability: HashMap::new(),
            children: HashMap::new(),
        }
    }

    /// Create a probabilistic suffix tree node with the given label, no
    /// probabilities and no children.
    fn with_label(label: &[T]) -> Self {
        PstNode {
            label: label.to_vec(),
            child_probability: HashMap::new(),
            children: HashMap::new(),
        }
    }

    /// Create a probabilistic suffix tree node with the given label and
    /// probability map, and no children.
    fn with_label_and_probs(label: &[T], probs: HashMap<T, f32>) -> Self {
        PstNode {
            label: label.to_vec(),
            child_probability: probs,
            children: HashMap::new(),
        }
    }

    /// Return the child stored under `key`, creating it with `label` if absent.
    ///
    /// A newly created child starts with a copy of this node's
    /// `child_probability` when `copy_gamma` is set, and with an empty map
    /// otherwise. An existing child is returned untouched.
    ///
    /// Returns `(inserted, child)`, where `inserted` tells the caller whether
    /// the node was newly created.
    fn get_or_insert_child(
        &mut self,
        key: T,
        label: &[T],
        copy_gamma: bool,
    ) -> (bool, &mut PstNode<T>) {
        use std::collections::hash_map::Entry;

        // A single lookup; the new node (and the possibly large clone of this
        // node's probabilities) is only built when the key is actually missing.
        match self.children.entry(key) {
            Entry::Occupied(slot) => (false, slot.into_mut()),
            Entry::Vacant(slot) => {
                let child = if copy_gamma {
                    PstNode::with_label_and_probs(label, self.child_probability.clone())
                } else {
                    PstNode::with_label(label)
                };
                (true, slot.insert(child))
            }
        }
    }
}

/// Add a leaf to a Probabilistic suffix tree, add nodes along the path if necessary.
///
/// The path from `node` is followed using the symbols of `label` from last to
/// first. The node reached after consuming `label[i..]` is created with that label
/// if it is missing. Newly created nodes get an empty probability map.
///
/// Returns the labels of the nodes that were newly created, shortest first. An
/// empty `label` adds nothing and returns an empty list.
pub fn add_leaf<T: Eq + Copy + Hash + std::fmt::Debug>(
    node: &mut PstNode<T>,
    label: &[T],
) -> Vec<Vec<T>> {
    let mut added_nodes: Vec<Vec<T>> = Vec::new();
    // An empty label denotes `node` itself; `label.len() - 1` would underflow.
    if let Some(last_idx) = label.len().checked_sub(1) {
        add_leaf_recursion(node, label, last_idx, false, &mut added_nodes);
    }
    added_nodes
}

/// Walk, and extend where needed, the path for `label` below `node`, starting at
/// `label[label_idx]` and moving towards `label[0]`.
///
/// At each step the child keyed by `label[label_idx]` is fetched, or created with
/// label `label[label_idx..]`. A created child copies the parent's probabilities
/// when `copy_gamma` is set. Labels of newly created nodes are appended to
/// `added_nodes`.
fn add_leaf_recursion<T: Eq + Copy + Hash + std::fmt::Debug>(
    node: &mut PstNode<T>,
    label: &[T],
    label_idx: usize,
    copy_gamma: bool,
    added_nodes: &mut Vec<Vec<T>>,
) {
    let tail = &label[label_idx..];
    let (was_created, child) = node.get_or_insert_child(tail[0], tail, copy_gamma);
    if was_created {
        added_nodes.push(tail.to_vec());
    }
    // Stop once the first symbol of the label has been consumed.
    if let Some(next_idx) = label_idx.checked_sub(1) {
        add_leaf_recursion(child, label, next_idx, copy_gamma, added_nodes);
    }
}

/// recursive function to fill in missing nodes (rarely called)
fn complete_inner_nodes<T: Eq + Copy + Hash + std::fmt::Debug>(
    node: &mut PstNode<T>,
    alphabet: &[T],
) {
    for child in node.children.values_mut() {
        complete_inner_nodes(child, alphabet);
    }
    // fill in if some symbols are missing, but only for inner nodes,
    // not for leaves
    if !node.children.is_empty() && node.children.len() != alphabet.len() {
        for symbol in alphabet {
            if !node.children.contains_key(symbol) {
                let mut label: Vec<T> = vec![*symbol];
                label.extend_from_slice(node.label.as_slice());
                node.children.insert(*symbol, PstNode::with_label(&label));
            }
        }
    }
}

/// Complete the probability function of a probabilistic suffix tree.
///
/// The root (`parent_label == None`) gets the smoothed distribution
/// `gamma_min + P(symbol) * (1 - |alphabet| * gamma_min)`. Every other node gets
/// the unsmoothed `P(symbol | parent_label)`. The function then recurses into all
/// children, passing this node's label as their `parent_label`.
///
/// NOTE(review): non-root nodes are estimated from their *parent's* label rather
/// than their own, and without the `gamma_min` smoothing applied at the root.
/// Confirm this matches the intended construction.
fn complete_gamma<T: Eq + Copy + Hash + std::fmt::Debug>(
    node: &mut PstNode<T>,
    parent_label: Option<&[T]>,
    alphabet: &[T],
    sample: &[T],
    gamma_min: f32,
    bound: usize,
) {
    // Probability mass left over after reserving gamma_min for every symbol.
    let spread = 1.0 - (alphabet.len() as f32 * gamma_min);

    for symbol in alphabet {
        let prob = if let Some(context) = parent_label {
            empirical_probability_of_symbol_given_token(sample, context, symbol, bound)
        } else {
            gamma_min + (empirical_probability_of_symbol(sample, symbol, bound) * spread)
        };
        node.child_probability.insert(*symbol, prob);
    }

    let own_label = node.label.as_slice();
    for child in node.children.values_mut() {
        complete_gamma(child, Some(own_label), alphabet, sample, gamma_min, bound);
    }
}

/// Learn a PST from a sample string and an alphabet. Returns the root.
///
/// Uses the construction of Ron, Singer and Tishby (1996):
/// 1. Candidate tokens are processed breadth-first. They start from the frequent
///    single symbols and grow one symbol at a time while shorter than `bound`, as
///    long as they are frequent enough.
/// 2. A token is added to the tree (see [`add_leaf`]) when, for some symbol, the
///    probability of that symbol following the token is large enough. It must also
///    be sufficiently larger than the probability of the symbol following
///    `token[1..]`.
/// 3. Missing children of inner nodes are then filled in and every node's
///    probabilities are computed.
///
/// * `sample` - the observed symbol sequence.
/// * `alphabet` - the symbols to consider.
/// * `bound` - maximum token length.
/// * `epsilon` - accuracy parameter the thresholds are derived from.
/// * `n` - enters the `epsilon0`/`epsilon1` formulas; presumably the paper's
///   bound on the number of states (TODO confirm).
pub fn learn_with_alphabet<T: Eq + Copy + Hash + std::fmt::Debug>(
    sample: &[T],
    alphabet: &[T],
    bound: usize,
    epsilon: f32,
    n: usize,
) -> PstNode<T> {
    let epsilon2 = epsilon / (48.0 * bound as f32);
    let gamma_min = epsilon2 / alphabet.len() as f32;
    // NOTE(review): confirm against the paper that this logarithm is the natural log.
    let epsilon0 = epsilon / (2.0 * n as f32 * bound as f32 * (1.0 / gamma_min).ln());
    let epsilon1 = epsilon2 / (8.0 * n as f32 * epsilon0 * gamma_min);
    let epsilon3 = epsilon0 * (1.0 - epsilon1);

    // All thresholds are loop-invariant, so compute them once up front.
    let sym_p_thresh = (1.0 + epsilon2) * gamma_min;
    let sym_p_suf_thresh = 1.0 + (3.0 * epsilon2);
    // (1 - epsilon1) * epsilon0 was observed to be almost always negative (the
    // reference Lisp implementation behaves the same), hence the clamp at zero.
    let token_thresh = f32::max(0.0_f32, (1.0 - epsilon1) * epsilon0);

    let mut root = PstNode::with_empty_label();

    // Start with the single-symbol tokens that are frequent enough.
    let mut tokens: VecDeque<Vec<T>> = alphabet
        .iter()
        .filter(|&symbol| empirical_probability_of_symbol(sample, symbol, bound) >= epsilon3)
        .map(|symbol| vec![*symbol])
        .collect();

    while let Some(token) = tokens.pop_front() {
        // Add the token if some symbol is likely enough after it and noticeably more
        // likely than after the token's suffix (the token without its first symbol).
        let suffix = &token[1..];
        let is_significant = alphabet.iter().any(|symbol| {
            let sym_p = empirical_probability_of_symbol_given_token(sample, &token, symbol, bound);
            let sym_p_suf =
                empirical_probability_of_symbol_given_token(sample, suffix, symbol, bound);
            // Check the divisor first so the ratio is only formed when it is meaningful.
            sym_p_suf > 0.0 && sym_p >= sym_p_thresh && sym_p / sym_p_suf > sym_p_suf_thresh
        });
        if is_significant {
            add_leaf(&mut root, &token);
        }

        // NOTE(review): candidates grow by *appending* a symbol. The significance test
        // above, however, compares a token with `token[1..]` (dropping its first
        // symbol), and `add_leaf` builds paths from the last symbol backwards. If
        // contexts are meant to be extended to the left, this should prepend instead.
        // Confirm against the reference algorithm.
        if token.len() < bound {
            for symbol in alphabet {
                let mut candidate = Vec::with_capacity(token.len() + 1);
                candidate.extend_from_slice(&token);
                candidate.push(*symbol);

                if empirical_probability_of_token(sample, &candidate, bound) > token_thresh {
                    tokens.push_back(candidate);
                }
            }
        }
    }

    // Complete the inner nodes of the tree where necessary, then fill in probabilities.
    complete_inner_nodes(&mut root, alphabet);
    complete_gamma(&mut root, None, alphabet, sample, gamma_min, bound);
    root
}

/*************************************************/
/* Helper functions to transform a PST to a PFA. */
/*************************************************/

/// Star property as defined in the original paper.
///
/// Not implemented yet (TBD): both arguments are ignored and the function always
/// returns `false`.
#[allow(dead_code)]
pub fn has_star_property<T: Eq + Copy + Hash + std::fmt::Debug>(
    _root: &mut PstNode<T>,
    _alphabet: &[T],
) -> bool {
    // Placeholder until the actual check is implemented.
    false
}

/// Find the longest suffix for label in the tree.
///
/// Starting at `root`, this follows the child keyed by the last symbol of `label`,
/// then the child keyed by the second-to-last symbol, and so on. It stops at the
/// first missing child or when `label` is exhausted, and returns the node reached.
pub fn find_longest_suffix_state<'a, T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &'a PstNode<T>,
    label: &[T],
) -> &'a PstNode<T> {
    let mut state = root;
    for symbol in label.iter().rev() {
        match state.children.get(symbol) {
            Some(next) => state = next,
            None => break,
        }
    }
    state
}

/// Find the longest suffix for label plus a symbol in the tree.
///
/// Steps from `root` to its child keyed by `symbol` and continues with
/// [`find_longest_suffix_state`] on `label`. If `root` has no such child, `root`
/// itself is returned.
pub fn find_longest_suffix_state_with_symbol<'a, T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &'a PstNode<T>,
    label: &[T],
    symbol: &T,
) -> &'a PstNode<T> {
    match root.children.get(symbol) {
        Some(child) => find_longest_suffix_state(child, label),
        None => root,
    }
}

/// Push the label of `root` and then, recursively, the labels of all its
/// descendants onto `labels` (pre-order, children in map iteration order).
fn collect_child_labels<T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &PstNode<T>,
    labels: &mut Vec<Vec<T>>,
) {
    labels.push(root.label.clone());
    for child in root.children.values() {
        collect_child_labels(child, labels);
    }
}

/// Get the labels of `node` and all of its descendants.
///
/// The node's own label comes first, followed by the labels of its subtrees.
pub fn get_child_labels<T: Eq + Copy + Hash + std::fmt::Debug>(node: &PstNode<T>) -> Vec<Vec<T>> {
    let mut labels = Vec::new();
    collect_child_labels(node, &mut labels);
    labels
}

/// Get all states that end with symbol.
///
/// These are the labels of the root's child keyed by `symbol` and of all its
/// descendants. The result is empty if the root has no such child.
pub fn get_suffix_symbol_states<T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &PstNode<T>,
    symbol: T,
) -> Vec<Vec<T>> {
    let mut labels = Vec::new();
    if let Some(subtree) = root.children.get(&symbol) {
        collect_child_labels(subtree, &mut labels);
    }
    labels
}

/// Recursively push onto `states` the label of every node in the subtree rooted
/// at `root` (including `root`) whose label contains `symbol` (pre-order).
fn get_states_containing_symbol_rec<T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &PstNode<T>,
    symbol: T,
    states: &mut Vec<Vec<T>>,
) {
    if root.label.contains(&symbol) {
        states.push(root.label.clone());
    }
    for child in root.children.values() {
        get_states_containing_symbol_rec(child, symbol, states);
    }
}

/// Get the labels of all states in the tree under `root`, `root` included, whose
/// label contains `symbol`.
pub fn get_states_containing_symbol<T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &PstNode<T>,
    symbol: T,
) -> Vec<Vec<T>> {
    let mut found = Vec::new();
    get_states_containing_symbol_rec(root, symbol, &mut found);
    found
}

/// Debug output recursion.
///
/// Writes a Graphviz node line for `node` with id `*idx`. The line shows the
/// node's label, followed by `symbol probability, ` pairs, with double quotes
/// stripped from the `Debug` renderings. Then, for each child, this increments
/// `*idx`, writes an edge from this node to the new id, and recurses.
fn to_dot_recursion<T: Eq + Copy + Hash + std::fmt::Debug>(
    node: &PstNode<T>,
    idx: &mut usize,
    w: &mut dyn Write,
) {
    let this_id = *idx;

    let mut text: String = node.label.iter().map(|sym| format!("{:?}", sym)).collect();
    text.retain(|ch| ch != '"');
    text.push_str(", ");

    for (sym, prob) in &node.child_probability {
        let mut sym_text = format!("{:?}", sym);
        sym_text.retain(|ch| ch != '"');
        write!(text, "{} {}, ", sym_text, prob).unwrap();
    }

    writeln!(w, "{}[label=\"{}\"]", this_id, text).unwrap();

    for child in node.children.values() {
        *idx += 1;
        writeln!(
            w,
            "{}->{}[weight=1.0, penwidth=1.0, rank=same, arrowsize=1.0]",
            this_id, *idx
        )
        .unwrap();
        to_dot_recursion(child, idx, w);
    }
}

/// Debug output.
pub fn to_dot<T: Eq + Copy + Hash + std::fmt::Debug>(root: &PstNode<T>) -> String {
    let mut w = String::new();
    writeln!(&mut w, "digraph{{").unwrap();
    let mut idx = 0;
    to_dot_recursion(root, &mut idx, &mut w);
    writeln!(&mut w, "}}").unwrap();
    w
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    #[test]
    fn test_chi_char() {
        let sample = vec!['a', 'b', 'c', 'd', 'e', 'f'];
        let token = vec!['c', 'd', 'e'];

        assert_eq!(chi(&sample, &token, 5), 1);
        assert_eq!(chi(&sample, &token, 4), 0);
    }

    #[test]
    fn test_chi_with_symbol_char() {
        let sample = vec!['a', 'b', 'c', 'd', 'e', 'f'];
        let token = vec!['c', 'd'];
        let symbol = 'e';

        assert_eq!(chi_with_symbol(&sample, &token, &symbol, 5), 1);
        assert_eq!(chi_with_symbol(&sample, &token, &symbol, 4), 0);
    }

    #[test]
    fn test_empirical_probability_of_token() {
        let sample = vec!['a', 'b', 'c', 'd', 'e', 'f'];
        let token = vec!['c', 'd'];

        assert_eq!(empirical_probability_of_token(&sample, &token, 2), 0.25);
    }

    #[test]
    fn test_empirical_probability_of_symbol() {
        let sample = vec!['a', 'b', 'c', 'd', 'e', 'f'];
        let symbol = 'c';

        assert_eq!(empirical_probability_of_symbol(&sample, &symbol, 2), 0.25);
    }

    #[test]
    fn test_empirical_probability_of_symbol_given_token() {
        let sample = vec!['a', 'b', 'c', 'd', 'e', 'f'];
        let token = vec!['c', 'd'];
        let symbol = 'e';

        assert_eq!(
            empirical_probability_of_symbol_given_token(&sample, &token, &symbol, 2),
            1.0
        );
    }

    #[test]
    fn test_print_dot() {
        let sample = vec![
            "x", "p", "x", "p", "x", "p", "x", "p", "x", "p", "~", "~", "~", "~", "~", "x", "g",
            "x", "o", "g", "x", "o", "g", "o", "x", "o", "g", "o", "x", "o", "g", "~", "o", "~",
            "o", "o", "~", "~", "~", "~", "x", "p", "x", "p", "x", "p", "o", "x", "p", "o", "x",
            "o", "~", "x", "o", "o", "o", "o", "o", "~", "x", "~", "x", "~",
        ];
        let alphabet = vec!["g", "p", "o", "x", "~"];

        let pst = learn_with_alphabet(&sample, &alphabet, 3, 0.01, 40);
        let dotstring = to_dot(&pst);

        // The output must at least be wrapped in a digraph declaration.
        assert!(dotstring.starts_with("digraph{\n"));
        assert!(dotstring.ends_with("}\n"));

        // Keep the rendered graph for manual inspection, but in the temp directory
        // so running the tests does not leave files behind in the working tree.
        let path = std::env::temp_dir().join("testpst");
        fs::write(&path, dotstring).expect("Unable to write file");
    }
}