use std::cmp::min;
use std::collections::{HashMap, VecDeque};
use std::fmt::Write;
use std::hash::Hash;
/// Indicator for "`token` ends right before position `j` of `sample`".
///
/// Returns `1` when the window `sample[j - token.len()..j]` exists and is
/// equal to `token`, and `0` otherwise. An empty `token` matches at every
/// `j <= sample.len()`.
///
/// Positions for which the window does not exist (`j < token.len()` or
/// `j > sample.len()`) yield `0` instead of panicking.
pub fn chi<T: Eq>(sample: &[T], token: &[T], j: usize) -> u32 {
    let window = j
        // `None` when the token is longer than the prefix ending at `j`
        .checked_sub(token.len())
        // `None` when `j` lies beyond the end of the sample
        .and_then(|start| sample.get(start..j));
    u32::from(window.map_or(false, |w| w == token))
}
pub fn chi_with_symbol<T: Eq>(sample: &[T], token: &[T], symbol: &T, j: usize) -> u32 {
u32::from(chi(sample, token, j - 1) == 1 && sample[j - 1] == *symbol)
}
/// Relative frequency of `token` in `sample`.
///
/// Counts the positions `j` in `min(bound, token.len())..sample.len()` for
/// which [`chi`] reports a match and divides that count by
/// `sample.len() - bound`.
///
/// Note that the counting window starts at `min(bound, token.len())` while
/// the divisor always uses `bound`. `bound` is expected not to exceed
/// `sample.len()`: the subtraction is plain `usize` arithmetic.
pub fn empirical_probability_of_token<T: Eq>(sample: &[T], token: &[T], bound: usize) -> f32 {
    let first = min(bound, token.len());
    let hits: u32 = (first..sample.len())
        .map(|j| chi(sample, token, j))
        .sum();
    let positions = sample.len() - bound;
    hits as f32 / positions as f32
}
/// Relative frequency of a single `symbol` in `sample`.
///
/// Every element of `sample` is inspected (the first `bound` elements are
/// counted too); the number of matches is divided by `sample.len() - bound`.
///
/// `bound` is expected not to exceed `sample.len()`: the subtraction is plain
/// `usize` arithmetic.
pub fn empirical_probability_of_symbol<T: Eq>(sample: &[T], symbol: &T, bound: usize) -> f32 {
    let hits: i32 = sample.iter().map(|s| i32::from(s == symbol)).sum();
    let positions = sample.len() - bound;
    hits as f32 / positions as f32
}
/// Relative frequency of `symbol` directly after an occurrence of `token`.
///
/// With an empty `token` this is just
/// [`empirical_probability_of_symbol`]. Otherwise, for every `j` in
/// `bound..sample.len()` it counts
///
/// * occurrences of `token` ending before `j` that are followed by `symbol`
///   (via [`chi_with_symbol`] at position `j + 1`), and
/// * occurrences of `token` ending before `j` (via [`chi`]),
///
/// and returns the ratio of the two counts, or `0.0` when both are zero.
pub fn empirical_probability_of_symbol_given_token<T: Eq>(
    sample: &[T],
    token: &[T],
    symbol: &T,
    bound: usize,
) -> f32 {
    if token.is_empty() {
        return empirical_probability_of_symbol(sample, symbol, bound);
    }

    let (followed_by_symbol, token_hits) =
        (bound..sample.len()).fold((0u32, 0u32), |(with_symbol, plain), j| {
            (
                with_symbol + chi_with_symbol(sample, token, symbol, j + 1),
                plain + chi(sample, token, j),
            )
        });

    // Avoid 0/0 when the token never occurs in the counted window.
    if followed_by_symbol == 0 && token_hits == 0 {
        return 0.0;
    }
    followed_by_symbol as f32 / token_hits as f32
}
/// A node of a prediction suffix tree.
///
/// Children are keyed by a single symbol; throughout this file a child's
/// label is that key symbol followed by the rest of the label it was created
/// for (see `add_leaf_recursion` and `complete_inner_nodes`).
#[derive(Debug, PartialEq, Clone)]
pub struct PstNode<T: Eq + Copy + Hash + std::fmt::Debug> {
    /// Symbol sequence identifying this node; empty for a node created with
    /// [`PstNode::with_empty_label`].
    pub label: Vec<T>,
    /// Probability stored per symbol for this node.
    pub child_probability: HashMap<T, f32>,
    /// Child nodes, keyed by symbol.
    pub children: HashMap<T, PstNode<T>>,
}

impl<T: Eq + Copy + Hash + std::fmt::Debug> PstNode<T> {
    /// Creates a node with an empty label, no probabilities and no children.
    pub fn with_empty_label() -> Self {
        Self::with_label(&[])
    }

    /// Creates a childless node carrying a copy of `label` and no probabilities.
    fn with_label(label: &[T]) -> Self {
        Self::with_label_and_probs(label, HashMap::new())
    }

    /// Creates a childless node carrying a copy of `label` and the given
    /// probability table.
    fn with_label_and_probs(label: &[T], probs: HashMap<T, f32>) -> Self {
        PstNode {
            label: label.to_vec(),
            child_probability: probs,
            children: HashMap::new(),
        }
    }

    /// Returns the child stored under `key`, creating it first if necessary.
    ///
    /// The returned flag is `true` when a new child was inserted. A new child
    /// gets a copy of `label`; when `copy_gamma` is set it also starts with a
    /// clone of this node's `child_probability`, otherwise with an empty
    /// table. An existing child is returned untouched and `label` /
    /// `copy_gamma` are ignored.
    fn get_or_insert_child(
        &mut self,
        key: T,
        label: &[T],
        copy_gamma: bool,
    ) -> (bool, &mut PstNode<T>) {
        // Matching on the entry hashes `key` once and builds the new node
        // (including the probability-table clone) only when it is needed.
        match self.children.entry(key) {
            std::collections::hash_map::Entry::Occupied(slot) => (false, slot.into_mut()),
            std::collections::hash_map::Entry::Vacant(slot) => {
                let child = if copy_gamma {
                    PstNode::with_label_and_probs(label, self.child_probability.clone())
                } else {
                    PstNode::with_label(label)
                };
                (true, slot.insert(child))
            }
        }
    }
}
/// Inserts the path for `label` below `node` and reports what was created.
///
/// The path is walked from the last symbol of `label` to the first (see
/// `add_leaf_recursion`); nodes that already exist are reused. The returned
/// list holds the label of every node that had to be created, in creation
/// order.
///
/// An empty `label` has no symbols to descend along, so nothing is inserted
/// and an empty list is returned.
pub fn add_leaf<T: Eq + Copy + Hash + std::fmt::Debug>(
    node: &mut PstNode<T>,
    label: &[T],
) -> Vec<Vec<T>> {
    let mut added_nodes: Vec<Vec<T>> = Vec::new();
    // `checked_sub` guards the empty label, for which `len() - 1` would underflow.
    if let Some(last_idx) = label.len().checked_sub(1) {
        add_leaf_recursion(node, label, last_idx, false, &mut added_nodes);
    }
    added_nodes
}
/// Walks `label` backwards from `label_idx` down to index 0, descending one
/// tree level per symbol and creating missing nodes on the way.
///
/// At index `i` the child is looked up under `label[i]` and, if it has to be
/// created, receives the label `label[i..]`. The label of every newly created
/// node is appended to `added_nodes`. `copy_gamma` is forwarded unchanged to
/// `PstNode::get_or_insert_child`.
fn add_leaf_recursion<T: Eq + Copy + Hash + std::fmt::Debug>(
    node: &mut PstNode<T>,
    label: &[T],
    label_idx: usize,
    copy_gamma: bool,
    added_nodes: &mut Vec<Vec<T>>,
) {
    let mut current = node;
    for idx in (0..=label_idx).rev() {
        let key = label[idx];
        let suffix = &label[idx..];
        let (inserted, child) = current.get_or_insert_child(key, suffix, copy_gamma);
        if inserted {
            added_nodes.push(suffix.to_vec());
        }
        // Continue one level deeper on the next (earlier) symbol.
        current = child;
    }
}
fn complete_inner_nodes<T: Eq + Copy + Hash + std::fmt::Debug>(
node: &mut PstNode<T>,
alphabet: &[T],
) {
for child in node.children.values_mut() {
complete_inner_nodes(child, alphabet);
}
if !node.children.is_empty() && node.children.len() != alphabet.len() {
for symbol in alphabet {
if !node.children.contains_key(symbol) {
let mut label: Vec<T> = vec![*symbol];
label.extend_from_slice(node.label.as_slice());
node.children.insert(*symbol, PstNode::with_label(&label));
}
}
}
}
fn complete_gamma<T: Eq + Copy + Hash + std::fmt::Debug>(
node: &mut PstNode<T>,
parent_label: Option<&[T]>,
alphabet: &[T],
sample: &[T],
gamma_min: f32,
bound: usize,
) {
match parent_label {
Some(p) => {
for symbol in alphabet {
node.child_probability.insert(
*symbol,
empirical_probability_of_symbol_given_token(sample, p, symbol, bound),
);
}
}
None => {
for symbol in alphabet {
let prob = gamma_min
+ (empirical_probability_of_symbol(sample, symbol, bound)
* (1.0 - (alphabet.len() as f32 * gamma_min)));
node.child_probability.insert(*symbol, prob);
}
}
};
for child in node.children.values_mut() {
complete_gamma(child, Some(&node.label), alphabet, sample, gamma_min, bound);
}
}
/// Learns a prediction suffix tree from `sample` over the given `alphabet`.
///
/// * `sample` – the observed symbol sequence.
/// * `alphabet` – the symbols to consider.
/// * `bound` – maximum length of the candidate tokens that are examined; it
///   is also passed through to all empirical-probability helpers.
/// * `epsilon` – accuracy parameter from which all thresholds are derived.
/// * `n` – enters the threshold formulas only (presumably an upper bound on
///   the number of states of the target model — TODO confirm).
///
/// Procedure, as implemented here:
///
/// 1. Seed a queue with every single-symbol token whose empirical
///    probability reaches `epsilon3`.
/// 2. For each queued token: if some symbol is predicted with probability at
///    least `(1 + epsilon2) * gamma_min`, and more than `1 + 3 * epsilon2`
///    times more likely than after the token minus its first symbol, add the
///    token to the tree via [`add_leaf`].
/// 3. While the token is shorter than `bound`, queue every one-symbol
///    extension (symbol appended) whose empirical probability exceeds the
///    token threshold.
/// 4. Finally complete inner nodes and fill in the probability tables.
pub fn learn_with_alphabet<T: Eq + Copy + Hash + std::fmt::Debug>(
    sample: &[T],
    alphabet: &[T],
    bound: usize,
    epsilon: f32,
    n: usize,
) -> PstNode<T> {
    let bound_f = bound as f32;
    let n_f = n as f32;

    // Derived constants.
    let epsilon2 = epsilon / (48.0 * bound_f);
    let gamma_min = epsilon2 / alphabet.len() as f32;
    let epsilon0 = epsilon / (2.0 * n_f * bound_f * (1.0 / gamma_min).ln());
    let epsilon1 = epsilon2 / (8.0 * n_f * epsilon0 * gamma_min);
    let epsilon3 = epsilon0 * (1.0 - epsilon1);

    // Decision thresholds; none of them depends on the token being examined.
    let sym_p_thresh = (1.0 + epsilon2) * gamma_min;
    let sym_p_suf_thresh = 1.0 + (3.0 * epsilon2);
    let token_thresh = f32::max(0.0_f32, (1.0 - epsilon1) * epsilon0);

    let mut root = PstNode::with_empty_label();

    let mut tokens: VecDeque<Vec<T>> = alphabet
        .iter()
        .filter(|symbol| empirical_probability_of_symbol(sample, *symbol, bound) >= epsilon3)
        .map(|symbol| vec![*symbol])
        .collect();

    while let Some(token) = tokens.pop_front() {
        // Queued tokens always hold at least one symbol, so this slice is valid.
        let suffix = &token[1..];

        let is_informative = alphabet.iter().any(|symbol| {
            let sym_p = empirical_probability_of_symbol_given_token(sample, &token, symbol, bound);
            let sym_p_suf =
                empirical_probability_of_symbol_given_token(sample, suffix, symbol, bound);
            sym_p >= sym_p_thresh && (sym_p / sym_p_suf) > sym_p_suf_thresh && sym_p_suf > 0.0
        });
        if is_informative {
            add_leaf(&mut root, token.as_slice());
        }

        if token.len() < bound {
            for symbol in alphabet {
                let mut candidate = Vec::with_capacity(token.len() + 1);
                candidate.extend_from_slice(&token);
                candidate.push(*symbol);
                if empirical_probability_of_token(sample, &candidate, bound) > token_thresh {
                    tokens.push_back(candidate);
                }
            }
        }
    }

    complete_inner_nodes(&mut root, alphabet);
    complete_gamma(&mut root, None, alphabet, sample, gamma_min, bound);
    root
}
/// Placeholder check on a tree and its alphabet.
///
/// The current body ignores both arguments (hence the `_` prefixes) and
/// always returns `false`; no property is actually evaluated yet.
// `dead_code` is allowed because nothing in this file calls the stub.
#[allow(dead_code)]
pub fn has_star_property<T: Eq + Copy + Hash + std::fmt::Debug>(
_root: &mut PstNode<T>,
_alphabet: &[T],
) -> bool {
false
}
/// Follows `label` backwards through the tree as far as possible.
///
/// Starting at `root`, the symbols of `label` are consumed from last to
/// first; each one selects the child stored under that symbol. The walk stops
/// when `label` is exhausted or the current node has no child for the next
/// symbol, and the node reached at that point is returned (`root` itself if
/// not even the last symbol matches or `label` is empty).
pub fn find_longest_suffix_state<'a, T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &'a PstNode<T>,
    label: &[T],
) -> &'a PstNode<T> {
    let mut state = root;
    for symbol in label.iter().rev() {
        match state.children.get(symbol) {
            Some(child) => state = child,
            None => break,
        }
    }
    state
}
/// Like [`find_longest_suffix_state`], but descends along `symbol` first.
///
/// If `root` has a child for `symbol`, the search for `label` continues from
/// that child; otherwise `root` is returned unchanged.
pub fn find_longest_suffix_state_with_symbol<'a, T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &'a PstNode<T>,
    label: &[T],
    symbol: &T,
) -> &'a PstNode<T> {
    match root.children.get(symbol) {
        Some(child) => find_longest_suffix_state(child, label),
        None => root,
    }
}
/// Appends the label of `root` and then, depth-first, the labels of all its
/// descendants to `labels`.
///
/// Children are visited in the iteration order of the `children` map.
fn collect_child_labels<T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &PstNode<T>,
    labels: &mut Vec<Vec<T>>,
) {
    labels.push(root.label.clone());
    for child in root.children.values() {
        collect_child_labels(child, labels);
    }
}
/// Returns the labels of `node` and of every node below it.
///
/// The first entry is always `node`'s own label; the descendants follow in
/// the depth-first order produced by `collect_child_labels`.
pub fn get_child_labels<T: Eq + Copy + Hash + std::fmt::Debug>(node: &PstNode<T>) -> Vec<Vec<T>> {
    let mut labels = Vec::new();
    collect_child_labels(node, &mut labels);
    labels
}
/// Returns the labels of the subtree stored under `symbol` directly below
/// `root`.
///
/// The list starts with the label of that child and continues with all of its
/// descendants; it is empty when `root` has no child for `symbol`.
pub fn get_suffix_symbol_states<T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &PstNode<T>,
    symbol: T,
) -> Vec<Vec<T>> {
    let mut labels = Vec::new();
    if let Some(subtree) = root.children.get(&symbol) {
        collect_child_labels(subtree, &mut labels);
    }
    labels
}
/// Depth-first helper: appends to `states` the label of every node in the
/// subtree of `root` (including `root`) whose label contains `symbol`.
fn get_states_containing_symbol_rec<T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &PstNode<T>,
    symbol: T,
    states: &mut Vec<Vec<T>>,
) {
    if root.label.contains(&symbol) {
        states.push(root.label.clone());
    }
    for child in root.children.values() {
        get_states_containing_symbol_rec(child, symbol, states);
    }
}
/// Returns the labels of all nodes in the tree rooted at `root` whose label
/// contains `symbol` at any position.
pub fn get_states_containing_symbol<T: Eq + Copy + Hash + std::fmt::Debug>(
    root: &PstNode<T>,
    symbol: T,
) -> Vec<Vec<T>> {
    let mut matching = Vec::new();
    get_states_containing_symbol_rec(root, symbol, &mut matching);
    matching
}
/// Writes the DOT statements for `node` and its subtree to `w`.
///
/// `idx` is a running counter shared by the whole traversal: its value on
/// entry becomes this node's DOT identifier, and it is incremented once for
/// every child before that child is emitted.
///
/// The node's display text is the `Debug` rendering of its label symbols
/// (double quotes stripped), followed by `", "` and one `"<symbol> <prob>, "`
/// entry per stored probability.
fn to_dot_recursion<T: Eq + Copy + Hash + std::fmt::Debug>(
    node: &PstNode<T>,
    idx: &mut usize,
    w: &mut dyn Write,
) {
    let node_id = *idx;

    let mut text: String = node.label.iter().map(|c| format!("{:?}", c)).collect();
    // Debug output of string-like symbols carries quotes that would end the DOT label early.
    text.retain(|ch| ch != '"');
    text.push_str(", ");
    for (sym, prob) in &node.child_probability {
        let sym_text: String = format!("{:?}", sym)
            .chars()
            .filter(|ch| *ch != '"')
            .collect();
        write!(text, "{} {}, ", sym_text, prob).unwrap();
    }
    writeln!(w, "{}[label=\"{}\"]", node_id, text).unwrap();

    for child in node.children.values() {
        *idx += 1;
        writeln!(
            w,
            "{}->{}[weight=1.0, penwidth=1.0, rank=same, arrowsize=1.0]",
            node_id, *idx
        )
        .unwrap();
        to_dot_recursion(child, idx, w);
    }
}
/// Renders the tree rooted at `root` as a Graphviz `digraph` and returns the
/// DOT source as a string.
///
/// Nodes are numbered from 0 in the order `to_dot_recursion` visits them.
pub fn to_dot<T: Eq + Copy + Hash + std::fmt::Debug>(root: &PstNode<T>) -> String {
    let mut dot = String::new();
    let mut next_id: usize = 0;
    writeln!(dot, "digraph{{").unwrap();
    to_dot_recursion(root, &mut next_id, &mut dot);
    writeln!(dot, "}}").unwrap();
    dot
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Six distinct symbols shared by the counting tests.
    const SAMPLE: [char; 6] = ['a', 'b', 'c', 'd', 'e', 'f'];

    /// `chi` matches only at the position right after the token.
    #[test]
    fn test_chi_char() {
        let token = ['c', 'd', 'e'];
        assert_eq!(chi(&SAMPLE, &token, 5), 1);
        assert_eq!(chi(&SAMPLE, &token, 4), 0);
    }

    /// `chi_with_symbol` needs the token and the following symbol to line up.
    #[test]
    fn test_chi_with_symbol_char() {
        let token = ['c', 'd'];
        let symbol = 'e';
        assert_eq!(chi_with_symbol(&SAMPLE, &token, &symbol, 5), 1);
        assert_eq!(chi_with_symbol(&SAMPLE, &token, &symbol, 4), 0);
    }

    /// One occurrence of the token over `len - bound = 4` positions.
    #[test]
    fn test_empirical_probability_of_token() {
        let token = ['c', 'd'];
        assert_eq!(empirical_probability_of_token(&SAMPLE, &token, 2), 0.25);
    }

    /// One occurrence of the symbol over `len - bound = 4` positions.
    #[test]
    fn test_empirical_probability_of_symbol() {
        let symbol = 'c';
        assert_eq!(empirical_probability_of_symbol(&SAMPLE, &symbol, 2), 0.25);
    }

    /// The token is always followed by the symbol in the sample.
    #[test]
    fn test_empirical_probability_of_symbol_given_token() {
        let token = ['c', 'd'];
        let symbol = 'e';
        assert_eq!(
            empirical_probability_of_symbol_given_token(&SAMPLE, &token, &symbol, 2),
            1.0
        );
    }

    /// Learns a tree from a fixed sample and writes its DOT rendering to the
    /// file `testpst` in the current working directory; there are no
    /// assertions on the content.
    #[test]
    fn test_print_dot() {
        let sample = [
            "x", "p", "x", "p", "x", "p", "x", "p", "x", "p", "~", "~", "~", "~", "~", "x", "g",
            "x", "o", "g", "x", "o", "g", "o", "x", "o", "g", "o", "x", "o", "g", "~", "o", "~",
            "o", "o", "~", "~", "~", "~", "x", "p", "x", "p", "x", "p", "o", "x", "p", "o", "x",
            "o", "~", "x", "o", "o", "o", "o", "o", "~", "x", "~", "x", "~",
        ];
        let alphabet = ["g", "p", "o", "x", "~"];
        let pst = learn_with_alphabet(&sample, &alphabet, 3, 0.01, 40);
        let dotstring = to_dot(&pst);
        fs::write("testpst", dotstring).expect("Unable to write file");
    }
}