use crate::operations;
use crate::pst;
use rand::seq::SliceRandom;
use rand::Rng;
use std::cmp::Ordering;
use std::collections::hash_map::DefaultHasher;
use std::collections::{BTreeSet, HashMap};
use std::fmt::Write;
use std::hash::{Hash, Hasher};
/// Sequence of symbols that identifies a state of the automaton.
pub type Label<T> = Vec<T>;
/// Hash of a [`Label`] (see [`calculate_hash`]); the key type of the `labels`, `children`
/// and `parents` maps of [`Pfa`].
pub type LabelHash = u64;
/// Outcome of one step taken by [`Pfa::next_transition`].
#[derive(Clone)]
pub struct PfaQueryResult<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Label of the state the step started from.
    pub last_state: Label<T>,
    /// Label of the state the step moved to.
    pub current_state: Label<T>,
    /// Symbol that was current before the step.
    pub last_symbol: T,
    /// Last symbol of the destination state's label.
    pub next_symbol: T,
}
/// Description of an edge added by [`Pfa::add_state_transition`].
pub struct PfaInsertionResult<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Label of the source state.
    pub source: Label<T>,
    /// Label of the destination state.
    pub destination: Label<T>,
    /// Last symbol of `destination`, i.e. the symbol emitted by the edge.
    pub symbol: T,
    /// Probability requested for the edge (before any rebalancing).
    pub prob: f32,
}
/// Description of an edge removed from the automaton.
pub struct PfaRemovalResult<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Label of the source state.
    pub source: Label<T>,
    /// Label of the destination state.
    pub destination: Label<T>,
    /// Probability the edge carried; `0.0` if no such edge existed.
    pub prob: f32,
}
/// Summary of a structural edit of the automaton: what was added and what was removed.
pub struct PfaOperationResult<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Edges created by the operation.
    pub added_transitions: Vec<PfaInsertionResult<T>>,
    /// Edges deleted by the operation.
    pub removed_transitions: Vec<PfaRemovalResult<T>>,
    /// States created by the operation.
    pub added_states: Vec<Label<T>>,
    /// States deleted by the operation.
    pub removed_states: Vec<Label<T>>,
    /// Newly introduced symbol, if any (left `None` by the operations shown here).
    pub added_symbol: Option<T>,
    /// Template symbol, if any (left `None` by the operations shown here).
    pub template_symbol: Option<T>,
}
/// One outgoing edge of a state: destination label, its cached hash, and the transition
/// probability.
///
/// Equality and ordering look only at `child_hash`, so a sorted and deduplicated edge list
/// holds at most one edge per destination, whatever the probabilities are.
#[derive(Clone, Debug)]
pub struct PfaChild<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Probability of taking this edge.
    pub prob: f32,
    /// Label of the destination state.
    pub child: Label<T>,
    /// Hash of `child`.
    pub child_hash: LabelHash,
}

impl<T> PartialEq for PfaChild<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    fn eq(&self, other: &Self) -> bool {
        other.child_hash == self.child_hash
    }
}

impl<T> Eq for PfaChild<T> where T: Eq + Copy + Hash + std::fmt::Debug {}

impl<T> PartialOrd for PfaChild<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T> Ord for PfaChild<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.child_hash, &other.child_hash)
    }
}
/// Hashes any `Hash` value with the standard library's `DefaultHasher`.
///
/// Every `DefaultHasher::new()` instance behaves the same, so equal inputs give equal
/// hashes within one program. The values are not promised to be stable across Rust
/// releases.
pub fn calculate_hash<T: Hash>(t: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    Hash::hash(t, &mut hasher);
    Hasher::finish(&hasher)
}
/// A rule "after the context `source`, emit `symbol` with `probability`", consumed by
/// [`Pfa::add_rule`].
#[derive(Debug, Clone)]
pub struct Rule<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug + Ord,
{
    /// Context (symbol sequence) the rule applies to.
    pub source: Label<T>,
    /// Symbol emitted after `source`.
    pub symbol: T,
    /// Probability of emitting `symbol`.
    pub probability: f32,
}
/// A probabilistic finite automaton over symbols of type `T`.
///
/// States are identified by their [`Label`] and keyed by [`LabelHash`]. The `labels`,
/// `children` and `parents` maps are expected to share one key set (see
/// `check_consistency`).
#[derive(Clone)]
pub struct Pfa<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug + Ord,
{
    /// Root of the `pst` tree over the state labels (`rebuild_pst` adds every label as a leaf).
    pub pst_root: Option<pst::PstNode<T>>,
    /// Sorted, deduplicated symbols seen in state labels.
    pub alphabet: Vec<T>,
    /// Hash of the state the automaton is currently in.
    pub current_state: Option<LabelHash>,
    /// Symbol associated with the current state.
    pub current_symbol: Option<T>,
    /// State hash -> state label.
    pub labels: HashMap<LabelHash, Label<T>>,
    /// State hash -> outgoing edges.
    pub children: HashMap<LabelHash, Vec<PfaChild<T>>>,
    /// State hash -> hashes of the states with an edge into it.
    pub parents: HashMap<LabelHash, Vec<LabelHash>>,
    /// Recently emitted symbols, bounded by `history_length` in `next_transition`.
    pub history: Vec<T>,
    /// Recently visited state hashes, bounded by `history_length` in `next_transition`.
    pub state_history: Vec<LabelHash>,
    /// Maximum length kept for `history` and `state_history`.
    pub history_length: usize,
}

impl<T> Default for Pfa<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug + Ord,
{
    /// Equivalent to [`Pfa::new`].
    fn default() -> Self {
        Pfa::new()
    }
}
impl<T: Eq + Copy + Hash + std::fmt::Debug + Ord> PartialEq for Pfa<T> {
fn eq(&self, other: &Self) -> bool {
let label_keys1: BTreeSet<LabelHash> = self.labels.keys().cloned().collect();
let label_keys2: BTreeSet<LabelHash> = other.labels.keys().cloned().collect();
if label_keys1 .difference(&label_keys2)
.next()
.is_some()
{
return false;
}
let child_keys1: BTreeSet<LabelHash> = self.children.keys().cloned().collect();
let child_keys2: BTreeSet<LabelHash> = other.children.keys().cloned().collect();
if child_keys1 .difference(&child_keys2)
.next()
.is_some()
{
return false;
}
let par_keys1: BTreeSet<LabelHash> = self.parents.keys().cloned().collect();
let par_keys2: BTreeSet<LabelHash> = other.parents.keys().cloned().collect();
if par_keys1 .difference(&par_keys2)
.next()
.is_some()
{
return false;
}
for key in child_keys1.iter() {
let ch1 = &self.children[key];
let mut ch1set: BTreeSet<PfaChild<T>> = BTreeSet::new();
let mut ch2set: BTreeSet<PfaChild<T>> = BTreeSet::new();
for c in ch1.iter() {
ch1set.insert(c.clone());
}
let ch2 = &self.children[key];
for c in ch2.iter() {
ch2set.insert(c.clone());
}
if ch1set .difference(&ch2set)
.next()
.is_some()
{
return false;
}
}
for key in par_keys1.iter() {
let ch1 = &self.parents[key];
let mut ch1set: BTreeSet<LabelHash> = BTreeSet::new();
let mut ch2set: BTreeSet<LabelHash> = BTreeSet::new();
for c in ch1.iter() {
ch1set.insert(*c);
}
let ch2 = &self.parents[key];
for c in ch2.iter() {
ch2set.insert(*c);
}
if ch1set .difference(&ch2set)
.next()
.is_some()
{
return false;
}
}
true
}
}
impl<T: Eq + Copy + Hash + std::fmt::Debug + Ord> Pfa<T> {
pub fn new() -> Self {
Pfa {
pst_root: Some(pst::PstNode::with_empty_label()),
alphabet: Vec::new(),
current_symbol: None,
current_state: None,
labels: HashMap::new(),
children: HashMap::new(),
parents: HashMap::new(),
history: Vec::new(),
state_history: Vec::new(),
history_length: 9,
}
}
/// Adopts `other`'s current state and current symbol, each only if it is valid here.
///
/// The state hash must be a known state, and the symbol must be in this alphabet.
#[allow(dead_code)]
pub fn transfer_state(&mut self, other: &Pfa<T>) {
    if let Some(hash) = other.current_state.filter(|h| self.labels.contains_key(h)) {
        self.current_state = Some(hash);
    }
    if let Some(sym) = other.current_symbol.filter(|s| self.alphabet.contains(s)) {
        self.current_symbol = Some(sym);
    }
}
/// Adds the edge `src -> dest` with probability `prob` to `src`'s outgoing edges.
///
/// Does nothing if `src` has no edge list (unknown state). If an edge to `dest` already
/// exists, it is kept and the new one is dropped.
pub fn add_child(&mut self, src: &Label<T>, dest: &Label<T>, prob: f32) {
    let child_hash = calculate_hash(dest);
    if let Some(edges) = self.children.get_mut(&calculate_hash(src)) {
        edges.push(PfaChild {
            prob,
            child: dest.to_vec(),
            child_hash,
        });
        // `sort` is stable, so an existing equal edge stays ahead of the new one and
        // `dedup` discards the newcomer.
        edges.sort();
        edges.dedup();
    }
}
/// Records `src` as a parent of `dest`, keeping the parent list sorted and duplicate-free.
///
/// Does nothing if `dest` has no parent list (unknown state).
pub fn add_parent(&mut self, dest: &Label<T>, src: &Label<T>) {
    let parent_hash = calculate_hash(src);
    if let Some(parents) = self.parents.get_mut(&calculate_hash(dest)) {
        parents.push(parent_hash);
        parents.sort_unstable();
        parents.dedup();
    }
}
/// Verifies the internal invariants and reports the first violation on stdout.
///
/// The invariants are:
/// - `labels`, `parents` and `children` share one key set;
/// - every edge points at a known state;
/// - every recorded parent is a known state.
///
/// Returns `true` when all of them hold.
#[allow(dead_code)]
pub fn check_consistency(&self) -> bool {
    let label_keys: BTreeSet<LabelHash> = self.labels.keys().copied().collect();
    let parent_keys: BTreeSet<LabelHash> = self.parents.keys().copied().collect();
    let child_keys: BTreeSet<LabelHash> = self.children.keys().copied().collect();

    // Each check fails when its left set holds a key missing from its right set. The order
    // fixes which message is printed first.
    let key_checks = [
        (&label_keys, &parent_keys, "INCONSISTENCY - label and parents"),
        (&parent_keys, &label_keys, "INCONSISTENCY - parents and label"),
        (&label_keys, &child_keys, "INCONSISTENCY - label and children"),
        (&child_keys, &label_keys, "INCONSISTENCY - children and label"),
        (&parent_keys, &child_keys, "INCONSISTENCY - parent and children"),
        (&child_keys, &parent_keys, "INCONSISTENCY - children and parent"),
    ];
    for (lhs, rhs, message) in key_checks.iter() {
        if !lhs.is_subset(rhs) {
            println!("{}", message);
            return false;
        }
    }

    let dangling_child = self
        .children
        .values()
        .flat_map(|edges| edges.iter())
        .find(|edge| !label_keys.contains(&edge.child_hash));
    if let Some(edge) = dangling_child {
        println!(
            "INCONSISTENCY - child {:?} {} doesn't exist",
            edge.child, edge.child_hash
        );
        return false;
    }

    let dangling_parent = self
        .parents
        .values()
        .flat_map(|parents| parents.iter())
        .find(|p| !label_keys.contains(*p));
    if let Some(par) = dangling_parent {
        println!("INCONSISTENCY - parent {} doesn't exist", par);
        return false;
    }
    true
}
/// Registers `label` as a state and merges its symbols into the (sorted, deduplicated)
/// alphabet.
///
/// Adding a state that already exists keeps its edge lists intact. Resetting them would
/// leave other states' `children`/`parents` entries pointing at edges that no longer
/// exist on this side.
///
/// If the automaton has no current state yet, the new state becomes current. Its current
/// symbol is the label's last symbol (if any), the same convention used by
/// `next_transition` and `reset_current_state`.
pub fn add_state(&mut self, label: &Label<T>) {
    self.alphabet.extend_from_slice(label);
    self.alphabet.sort();
    self.alphabet.dedup();
    let label_hash = calculate_hash(label);
    self.children.entry(label_hash).or_default();
    self.parents.entry(label_hash).or_default();
    self.labels
        .entry(label_hash)
        .or_insert_with(|| label.to_vec());
    if self.current_state.is_none() {
        self.current_state = Some(label_hash);
        if let Some(&sym) = label.last() {
            self.current_symbol = Some(sym);
        }
    }
}
/// Returns `true` if `label` has no incoming edges (or is unknown).
#[allow(dead_code)]
pub fn state_orphaned(&self, label: &Label<T>) -> bool {
    self.state_orphaned_hash(calculate_hash(label))
}
/// Returns `true` if the state with hash `label_hash` has no parents.
///
/// Also returns `true` if it has no parent list at all.
#[allow(dead_code)]
pub fn state_orphaned_hash(&self, label_hash: LabelHash) -> bool {
    self.parents
        .get(&label_hash)
        .map_or(true, |parents| parents.is_empty())
}
/// Returns `true` if `label` is a known state.
pub fn has_state(&self, label: &Label<T>) -> bool {
    self.has_state_hash(calculate_hash(label))
}
/// Returns `true` if a state with hash `label_hash` is known.
pub fn has_state_hash(&self, label_hash: LabelHash) -> bool {
    self.labels.get(&label_hash).is_some()
}
/// Returns `true` if both states exist and there is an edge `src -> dest`.
pub fn has_transition(&self, src: &Label<T>, dest: &Label<T>) -> bool {
    self.has_transition_hash(calculate_hash(src), calculate_hash(dest))
}
/// Hash-based variant of `has_transition`.
///
/// Returns `false` when either state is unknown.
fn has_transition_hash(&self, src_hash: LabelHash, dest_hash: LabelHash) -> bool {
    let both_known = self.labels.contains_key(&src_hash) && self.labels.contains_key(&dest_hash);
    both_known
        && self.children[&src_hash]
            .iter()
            .any(|edge| edge.child_hash == dest_hash)
}
/// Finds the edge out of `src` whose destination label ends with `sym`.
///
/// Returns `(src label, destination label, probability)`, or `None` if there is no such
/// edge.
pub fn get_emission(&self, src: &Label<T>, sym: T) -> Option<(Label<T>, Label<T>, f32)> {
    self.get_emission_hash(calculate_hash(src), sym)
}
/// Hash-based variant of `get_emission`; the first matching edge wins.
fn get_emission_hash(&self, src_hash: LabelHash, sym: T) -> Option<(Label<T>, Label<T>, f32)> {
    let edge = self
        .children
        .get(&src_hash)?
        .iter()
        .find(|edge| edge.child.last() == Some(&sym))?;
    Some((self.labels[&src_hash].clone(), edge.child.clone(), edge.prob))
}
/// Rebalances the outgoing probabilities of `state` (see `rebalance_state_hash`).
pub fn rebalance_state(&mut self, state: &Label<T>) {
    self.rebalance_state_hash(&calculate_hash(state));
}
/// Replaces the outgoing probabilities of a state with
/// `operations::rebalance_float(probs, 1.0, 0.35)`, position by position.
///
/// Panics if the state has no edge list.
fn rebalance_state_hash(&mut self, state_hash: &LabelHash) {
    let before: Vec<f32> = self.children[state_hash]
        .iter()
        .map(|edge| edge.prob)
        .collect();
    let after = operations::rebalance_float(before, 1.0, 0.35);
    if let Some(edges) = self.children.get_mut(state_hash) {
        for (idx, p) in after.iter().enumerate() {
            edges[idx].prob = *p;
        }
    }
}
/// Replaces the outgoing probabilities of a state with
/// `operations::free_probability_float(probs, to_free)`, position by position.
///
/// Panics if the state has no edge list.
fn free_probability_state_hash(&mut self, state_hash: &LabelHash, to_free: f32) {
    let before: Vec<f32> = self.children[state_hash]
        .iter()
        .map(|edge| edge.prob)
        .collect();
    let after = operations::free_probability_float(before, to_free);
    if let Some(edges) = self.children.get_mut(state_hash) {
        for (idx, p) in after.iter().enumerate() {
            edges[idx].prob = *p;
        }
    }
}
/// Label-based variant of `free_probability_state_hash`.
#[allow(dead_code)]
pub fn free_probability_state(&mut self, state: &Label<T>, to_free: f32) {
    self.free_probability_state_hash(&calculate_hash(state), to_free);
}
/// Replaces the outgoing probabilities of a state with
/// `operations::blur_float(probs, blur)`, position by position.
///
/// Panics if the state has no edge list.
fn blur_state_hash(&mut self, state_hash: &LabelHash, blur: f32) {
    let before: Vec<f32> = self.children[state_hash]
        .iter()
        .map(|edge| edge.prob)
        .collect();
    let after = operations::blur_float(before, blur);
    if let Some(edges) = self.children.get_mut(state_hash) {
        for (idx, p) in after.iter().enumerate() {
            edges[idx].prob = *p;
        }
    }
}
/// Replaces the outgoing probabilities of a state with
/// `operations::sharpen_float(probs, sharpen)`, position by position.
///
/// Panics if the state has no edge list.
fn sharpen_state_hash(&mut self, state_hash: &LabelHash, sharpen: f32) {
    let before: Vec<f32> = self.children[state_hash]
        .iter()
        .map(|edge| edge.prob)
        .collect();
    let after = operations::sharpen_float(before, sharpen);
    if let Some(edges) = self.children.get_mut(state_hash) {
        for (idx, p) in after.iter().enumerate() {
            edges[idx].prob = *p;
        }
    }
}
#[allow(dead_code)]
pub fn blur(&mut self, blurriness: f32) {
let keys: Vec<u64> = self.labels.keys().cloned().collect();
for hash in keys.iter() {
self.blur_state_hash(hash, blurriness);
}
}
#[allow(dead_code)]
pub fn sharpen(&mut self, sharpness: f32) {
let keys: Vec<u64> = self.labels.keys().cloned().collect();
for hash in keys.iter() {
self.sharpen_state_hash(hash, sharpness);
}
}
#[allow(dead_code)]
pub fn rebalance(&mut self) {
let keys: Vec<u64> = self.labels.keys().cloned().collect();
for hash in keys.iter() {
if !self.children[hash].is_empty() {
self.rebalance_state_hash(hash);
}
}
}
/// Adds the edge `src -> dest` with probability `prob` and records the parent link.
///
/// If `rebalance` is set, `src`'s outgoing probabilities are rebalanced afterwards.
///
/// Returns a description of the insertion.
///
/// # Panics
/// Panics if `dest` is empty (the emitted symbol is `dest`'s last symbol).
pub fn add_state_transition(
    &mut self,
    src: &Label<T>,
    dest: &Label<T>,
    prob: f32,
    rebalance: bool,
) -> PfaInsertionResult<T> {
    self.add_child(src, dest, prob);
    self.add_parent(dest, src);
    if rebalance {
        self.rebalance_state(src);
    }
    let symbol = *dest.last().unwrap();
    PfaInsertionResult {
        source: src.to_vec(),
        destination: dest.to_vec(),
        symbol,
        prob,
    }
}
/// Adds an edge to `dest` from every known state returned by
/// `pst::get_suffix_symbol_states` for `suffix`.
///
/// Does nothing if `dest` is not a known state.
#[allow(dead_code)]
pub fn add_symbol_transition(
    &mut self,
    suffix: T,
    dest: &Label<T>,
    prob: f32,
    rebalance: bool,
) -> Vec<PfaInsertionResult<T>> {
    if !self.has_state(dest) {
        return Vec::new();
    }
    let sources = pst::get_suffix_symbol_states(self.pst_root.as_ref().unwrap(), suffix);
    let mut insertions = Vec::new();
    for src in sources.iter() {
        if !self.has_state(src) {
            continue;
        }
        insertions.push(self.add_state_transition(src, dest, prob, rebalance));
    }
    insertions
}
/// Removes the edge `src -> dest` and the matching parent link.
///
/// If `rebalance` is set, `src` is rebalanced afterwards. The reported probability is
/// `0.0` when the edge did not exist.
///
/// # Panics
/// Panics if either state is missing from `children`/`parents`/`labels`.
fn remove_state_transition_hash(
    &mut self,
    src_hash: LabelHash,
    dest_hash: LabelHash,
    rebalance: bool,
) -> PfaRemovalResult<T> {
    let edge_pos = self.children[&src_hash]
        .iter()
        .position(|edge| edge.child_hash == dest_hash);
    let prob = match edge_pos {
        Some(pos) => self.children.get_mut(&src_hash).unwrap().remove(pos).prob,
        None => 0.0,
    };
    let parent_pos = self.parents[&dest_hash]
        .iter()
        .position(|parent| *parent == src_hash);
    if let Some(pos) = parent_pos {
        self.parents.get_mut(&dest_hash).unwrap().remove(pos);
    }
    if rebalance {
        self.rebalance_state_hash(&src_hash);
    }
    PfaRemovalResult {
        source: self.labels[&src_hash].clone(),
        destination: self.labels[&dest_hash].clone(),
        prob,
    }
}
/// Label-based variant of `remove_state_transition_hash`.
pub fn remove_state_transition(
    &mut self,
    src: &Label<T>,
    dest: &Label<T>,
    rebalance: bool,
) -> PfaRemovalResult<T> {
    self.remove_state_transition_hash(calculate_hash(src), calculate_hash(dest), rebalance)
}
/// Adds `prob_mod` to the probability of every `src -> dest` edge.
///
/// Does nothing if `src` is unknown.
fn modify_transition_probability(
    &mut self,
    src_hash: LabelHash,
    dest_hash: LabelHash,
    prob_mod: f32,
) {
    if let Some(edges) = self.children.get_mut(&src_hash) {
        edges
            .iter_mut()
            .filter(|edge| edge.child_hash == dest_hash)
            .for_each(|edge| edge.prob += prob_mod);
    }
}
/// Removes every existing edge between the states returned by
/// `pst::get_suffix_symbol_states` for `suffix` and those returned for `dest`.
#[allow(dead_code)]
pub fn remove_symbol_transition(
    &mut self,
    suffix: T,
    dest: T,
    rebalance: bool,
) -> Vec<PfaRemovalResult<T>> {
    let root = self.pst_root.as_ref().unwrap();
    let sources = pst::get_suffix_symbol_states(root, suffix);
    let targets = pst::get_suffix_symbol_states(root, dest);
    let mut removals = Vec::new();
    for src in &sources {
        for target in &targets {
            if self.has_transition(src, target) {
                removals.push(self.remove_state_transition(src, target, rebalance));
            }
        }
    }
    removals
}
fn remove_outgoing_transitions(&mut self, state: &Label<T>) -> Vec<PfaRemovalResult<T>> {
let hash = calculate_hash(state);
let mut rem_hash = Vec::new();
let mut removals = Vec::new();
for ch in self.children[&hash].iter() {
rem_hash.push(ch.child_hash);
}
for r in rem_hash.iter() {
removals.push(self.remove_state_transition_hash(hash, *r, false));
}
removals
}
fn remove_incoming_transitions(&mut self, state: &Label<T>) -> Vec<PfaRemovalResult<T>> {
let hash = calculate_hash(state);
let mut rem_hash = Vec::new();
let mut removals = Vec::new();
for par in self.parents[&hash].iter() {
rem_hash.push(*par);
}
for r in rem_hash.iter() {
removals.push(self.remove_state_transition_hash(*r, hash, false));
}
removals
}
/// Deletes `state` together with all its incoming and outgoing edges.
///
/// The state is also dropped from `state_history`. If it was the current state, a new one
/// is chosen via `reset_current_state`.
///
/// The state is removed from the maps *before* the reset. Otherwise the reset's history
/// and alphabet fallbacks could pick the very state being deleted and leave a dangling
/// `current_state`.
///
/// Returns the removed edges and the removed state.
#[allow(dead_code)]
fn purge_state(&mut self, state: &Label<T>) -> PfaOperationResult<T> {
    let mut removed_transitions = self.remove_incoming_transitions(state);
    removed_transitions.append(&mut self.remove_outgoing_transitions(state));
    let hash = calculate_hash(state);
    self.state_history.retain(|x| *x != hash);
    self.labels.remove(&hash);
    self.parents.remove(&hash);
    self.children.remove(&hash);
    if self.current_state == Some(hash) {
        self.reset_current_state();
    }
    PfaOperationResult {
        added_states: Vec::new(),
        removed_states: vec![state.clone()],
        added_transitions: Vec::new(),
        removed_transitions,
        template_symbol: None,
        added_symbol: None,
    }
}
pub fn remove_orphaned_states(&mut self) {
let mut orphans = Vec::new();
for (k, v) in self.labels.iter() {
if self.state_orphaned_hash(*k) {
orphans.push(v.clone());
}
}
for o in orphans {
self.purge_state(&o);
}
}
/// Rebuilds `pst_root` from scratch, adding every state label as a leaf.
pub fn rebuild_pst(&mut self) {
    let mut root = pst::PstNode::<T>::with_empty_label();
    for label in self.labels.values() {
        pst::add_leaf(&mut root, label);
    }
    self.pst_root = Some(root);
}
/// Removes `symbol` from the automaton.
///
/// The states removed are those returned by `pst::get_states_containing_symbol` that exist
/// here (presumably every state whose label contains `symbol`). Each removed state is
/// "bridged": every (incoming, outgoing) edge pair whose endpoints are not the state itself
/// yields an edge `in.source -> out.destination` carrying the incoming edge's probability.
/// If that edge already exists, the incoming probability is added to it instead.
///
/// The symbol is also dropped from the alphabet and from the symbol history. Removed states
/// are dropped from the state history, and the current state is reset if it was removed.
/// The PST is rebuilt at the end.
///
/// Returns the removed states, all removed edges, and the bridge edges that were added
/// (minus those that were themselves removed later in the same call).
///
/// # Panics
/// Panics if `pst_root` is `None`.
pub fn remove_symbol(&mut self, symbol: T, rebalance: bool) -> PfaOperationResult<T> {
let mut states_to_remove =
pst::get_states_containing_symbol(self.pst_root.as_ref().unwrap(), symbol);
// Only states that actually exist in this automaton are processed.
states_to_remove.retain(|x| self.has_state(x));
self.alphabet.retain(|x| *x != symbol);
self.history.retain(|x| *x != symbol);
let mut insertions = Vec::new();
let mut removals = Vec::new();
for state in states_to_remove.iter() {
let state_hash = calculate_hash(state);
self.state_history.retain(|x| *x != state_hash);
// Outgoing edges are detached first, so a self-loop shows up only in `removals_out`.
let mut removals_out = self.remove_outgoing_transitions(state);
let mut removals_in = self.remove_incoming_transitions(state);
// Bridge every predecessor to every successor of the removed state.
// NOTE(review): each bridge receives the full incoming probability regardless of the
// outgoing edge's probability — confirm this is intended (vs. r_in.prob * r_out.prob).
for r_in in removals_in.iter() {
for r_out in removals_out.iter() {
let r_in_src_hash = calculate_hash(&r_in.source);
let r_out_dest_hash = calculate_hash(&r_out.destination);
if r_in_src_hash != state_hash && r_out_dest_hash != state_hash {
if !self.has_transition_hash(r_in_src_hash, r_out_dest_hash) {
insertions.push(self.add_state_transition(
&r_in.source,
&r_out.destination,
r_in.prob,
rebalance,
));
} else {
// Edge already exists: fold the incoming probability into it.
self.modify_transition_probability(
r_in_src_hash,
r_out_dest_hash,
r_in.prob,
);
}
}
}
}
if let Some(cur) = self.current_state {
if cur == state_hash {
self.reset_current_state();
}
}
self.children.remove(&state_hash);
self.parents.remove(&state_hash);
self.labels.remove(&state_hash);
removals.append(&mut removals_in);
removals.append(&mut removals_out);
}
self.rebuild_pst();
// Do not report bridge edges that were removed again by a later iteration.
for rem in removals.iter() {
insertions.retain(|x| x.source != rem.source || x.destination != rem.destination);
}
PfaOperationResult {
added_states: Vec::new(),
removed_states: states_to_remove,
added_transitions: insertions,
removed_transitions: removals,
template_symbol: None,
added_symbol: None,
}
}
/// Picks a new current state, e.g. after the current one was removed.
///
/// Tries, in order:
/// 1. the most recent entry of `state_history`, if it still names a known state with a
///    non-empty label;
/// 2. the shortest suffix of the symbol `history` that is a known state;
/// 3. the first single-symbol state found while scanning `alphabet`.
///
/// `current_symbol` is set to the last symbol of the chosen state's label. If no candidate
/// is found, both `current_state` and `current_symbol` become `None`.
pub fn reset_current_state(&mut self) {
    if let Some(&hash) = self.state_history.last() {
        // A stale or empty-label entry is skipped instead of panicking on lookup/unwrap.
        if let Some(stitch_state) = self.labels.get(&hash).cloned() {
            if let Some(&sym) = stitch_state.last() {
                self.current_state = Some(hash);
                self.current_symbol = Some(sym);
                println!(
                    "reset cur state (from state history) because removal {:?}",
                    stitch_state
                );
                return;
            }
        }
    }
    // Grow the candidate backwards through the symbol history, one symbol per step.
    // A single pass over the reversed history guarantees termination.
    let mut stitch_state = Label::new();
    for &sym in self.history.iter().rev() {
        stitch_state.insert(0, sym);
        if self.has_state(&stitch_state) {
            self.current_state = Some(calculate_hash(&stitch_state));
            self.current_symbol = stitch_state.last().copied();
            println!(
                "reset cur state (from symbol history) because removal {:?}",
                stitch_state
            );
            return;
        }
    }
    for s in self.alphabet.iter() {
        stitch_state.clear();
        stitch_state.push(*s);
        if self.has_state(&stitch_state) {
            self.current_state = Some(calculate_hash(&stitch_state));
            self.current_symbol = Some(*s);
            println!(
                "reset cur state (from alphabet) because removal {:?}",
                stitch_state
            );
            return;
        }
    }
    println!("can't find valid state in this pfa ...");
    self.current_state = None;
    self.current_symbol = None;
}
/// Extends the symbol history to `history_length` entries.
///
/// The fill value is the last history symbol, or the first alphabet symbol when the
/// history is empty. Does nothing if the history is already long enough, or if it is empty
/// and the alphabet is empty too (previously this indexed `alphabet[0]` and panicked).
#[allow(dead_code)]
pub fn pad_history(&mut self) {
    if self.history.len() >= self.history_length {
        return;
    }
    let fill = match self.history.last() {
        Some(&last) => last,
        None => match self.alphabet.first() {
            Some(&first) => first,
            None => return,
        },
    };
    self.history.resize(self.history_length, fill);
}
/// Runs `steps` random transitions and then restores the current state and symbol.
///
/// The symbol and state histories keep the simulated entries.
#[allow(dead_code)]
pub fn sim_steps(&mut self, steps: usize) {
    let (saved_state, saved_symbol) = (self.current_state, self.current_symbol);
    (0..steps).for_each(|_| {
        self.next_transition();
    });
    self.current_state = saved_state;
    self.current_symbol = saved_symbol;
}
/// Takes one random step from the current state.
///
/// Each outgoing edge gets `(100 * prob) as i32` tickets, so edges below 1% are never
/// taken, and one ticket is drawn uniformly.
///
/// Before the draw, the current state is pushed onto `state_history` and the current
/// symbol onto `history`. Afterwards both histories are trimmed to `history_length`,
/// whether or not a step was taken, so a dead-end state cannot grow them without bound.
///
/// Returns `None` in these cases:
/// - there is no current state;
/// - no edge has a ticket;
/// - the destination label is empty;
/// - there was no current symbol before the step.
///
/// In the last two cases the step itself is still taken.
///
/// # Panics
/// Panics if the current state has no entry in `children`, or if a ticketed edge points
/// at an unknown state.
pub fn next_transition(&mut self) -> Option<PfaQueryResult<T>> {
    let mut choice_list = Vec::<LabelHash>::new();
    if let Some(cur) = self.current_state {
        self.state_history.push(cur);
        for c in &self.children[&cur] {
            let prob = (100.0 * c.prob) as i32;
            for _ in 0..prob {
                if self.has_state_hash(c.child_hash) {
                    choice_list.push(c.child_hash);
                } else {
                    panic!(
                        "WARNING - found non-existing state {:?} -> {:?} {}",
                        self.labels[&cur], c.child, c.child_hash
                    );
                }
            }
        }
    }
    if let Some(sym) = self.current_symbol {
        self.history.push(sym);
    }

    let chosen = choice_list.choose(&mut rand::thread_rng()).copied();
    let mut result = None;
    if let (Some(prev_hash), Some(next_hash)) = (self.current_state, chosen) {
        self.current_state = Some(next_hash);
        if let Some(&next_symbol) = self.labels[&next_hash].last() {
            let previous_symbol = self.current_symbol.replace(next_symbol);
            result = previous_symbol.map(|last_symbol| PfaQueryResult {
                last_state: self.labels[&prev_hash].clone(),
                current_state: self.labels[&next_hash].clone(),
                last_symbol,
                next_symbol,
            });
        }
    }

    let cap = self.history_length;
    if self.history.len() > cap {
        let excess = self.history.len() - cap;
        self.history.drain(..excess);
    }
    if self.state_history.len() > cap {
        let excess = self.state_history.len() - cap;
        self.state_history.drain(..excess);
    }
    result
}
/// Takes one step and returns the symbol that was current *before* it (`last_symbol`).
#[allow(dead_code)]
pub fn next_symbol(&mut self) -> Option<T> {
    self.next_transition().map(|step| step.last_symbol)
}
/// Adds every node label of the PST rooted at `root` as a state (pre-order, recursive).
fn copy_states_from_pst(&mut self, root: &pst::PstNode<T>) {
    self.add_state(&root.label);
    root.children
        .values()
        .for_each(|child| self.copy_states_from_pst(child));
}
fn from_pst_nostar(root: &pst::PstNode<T>, alphabet: &[T]) -> Self {
let mut new_pfa = Pfa {
pst_root: Some(pst::PstNode::with_empty_label()),
alphabet: Vec::new(),
current_symbol: None,
current_state: None,
labels: HashMap::new(),
children: HashMap::new(),
parents: HashMap::new(),
history: Vec::new(),
state_history: Vec::new(),
history_length: 9,
};
new_pfa.copy_states_from_pst(root);
let labels = new_pfa.labels.clone();
for (_, label) in labels.iter() {
let pst_state = pst::find_longest_suffix_state(root, label);
for symbol in alphabet.iter() {
let longest_suffix_state =
pst::find_longest_suffix_state_with_symbol(root, label, symbol);
let tprob = *pst_state.child_probability.get(symbol).unwrap();
if tprob > 0.0 {
new_pfa.add_state_transition(label, &longest_suffix_state.label, tprob, false);
}
}
}
new_pfa
}
/// Learns a PST from `sample` over `alphabet` and converts it into an automaton.
///
/// The current symbol is set to the first sample symbol, if there is one.
fn learn_with_alphabet(
    sample: &[T],
    alphabet: &[T],
    bound: usize,
    epsilon: f32,
    n: usize,
) -> Self {
    let tree = crate::pst::learn_with_alphabet(sample, alphabet, bound, epsilon, n);
    let mut pfa = Pfa::from_pst_nostar(&tree, alphabet);
    pfa.current_symbol = sample.first().copied().or(pfa.current_symbol);
    pfa
}
pub fn learn(sample: &[T], bound: usize, epsilon: f32, n: usize) -> Self {
let mut alphabet = sample.to_vec();
alphabet.sort();
alphabet.dedup();
Pfa::learn_with_alphabet(sample, &alphabet, bound, epsilon, n)
}
/// Adds the rule "after the context `rule.source`, emit `rule.symbol` with probability
/// `rule.probability`".
///
/// The work is done on a clone that replaces `self` only at the very end. Suppose that,
/// while walking the prefixes of `rule.source`, the previous prefix has no outgoing edge
/// emitting the next source symbol. Then the function returns early and `self` is left
/// unchanged.
///
/// Outline:
/// 1. Each prefix of `rule.source` is made a state. From the second prefix on:
///    - let `D` be the destination of the previous prefix's emission of the new symbol;
///    - every known state `l` among `pst::get_child_labels` of the previous prefix's
///      longest PST suffix node has its edge `l -> D` replaced by `l -> prefix`, with the
///      same probability;
///    - prefixes shorter than the whole source inherit the emissions of their longest PST
///      suffix state.
///
///    Each prefix is then added to the PST.
/// 2. The target state is the longest PST suffix state of `source + symbol`, falling back
///    to `[symbol]` (created if missing).
/// 3. An edge with `rule.probability` is added to the target from every label that
///    `pst::get_child_labels` returns for the longest PST suffix node of `rule.source`.
///
/// No rebalancing is performed.
///
/// # Panics
/// Panics if `pst_root` is `None`.
pub fn add_rule(&mut self, rule: &Rule<T>) {
// Work on a copy so that an early return leaves `self` untouched.
let mut tmp_self = self.clone();
let mut prefix = Label::<T>::new();
let mut last_prefix = Label::<T>::new();
let mut suffix = rule.source.clone();
for sym in suffix.drain(..) {
prefix.push(sym);
if !tmp_self.has_state(&prefix) {
tmp_self.add_state(&prefix);
}
// From the second prefix on, splice the new prefix state into the chain.
if !last_prefix.is_empty() {
let longest = &pst::find_longest_suffix_state(
tmp_self.pst_root.as_ref().unwrap(),
&last_prefix,
);
let longest_suf_states = pst::get_child_labels(longest);
if let Some(transition) = tmp_self.get_emission(&last_prefix, sym) {
// Redirect `l -> transition.1` to `l -> prefix`, keeping the probability.
for l in longest_suf_states.iter() {
if tmp_self.has_state(l) {
tmp_self.remove_state_transition(l, &transition.1, false);
tmp_self.add_state_transition(l, &prefix, transition.2, false);
}
}
// A prefix shorter than the whole source inherits the emissions of its longest
// suffix state (looked up before `prefix` itself is added to the PST below).
if prefix != rule.source {
let longest2 = &pst::find_longest_suffix_state(
tmp_self.pst_root.as_ref().unwrap(),
&prefix,
)
.label
.clone();
for asym in tmp_self.alphabet.clone().iter() {
if let Some(transition2) = tmp_self.get_emission(longest2, *asym) {
tmp_self.add_state_transition(
&prefix,
&transition2.1,
transition2.2,
false,
);
}
}
}
} else {
// NOTE(review): a missing emission silently discards the whole rule.
return; }
}
last_prefix.push(sym);
if let Some(root) = tmp_self.pst_root.as_mut() {
pst::add_leaf(root, &prefix);
}
}
// Target of the new edges: longest known suffix state of `source + symbol`, else `[symbol]`.
prefix.push(rule.symbol);
let mut longest =
pst::find_longest_suffix_state(tmp_self.pst_root.as_ref().unwrap(), &prefix)
.label
.clone();
if !tmp_self.has_state(&longest) {
longest.clear();
}
if longest.is_empty() {
longest.push(rule.symbol);
if !tmp_self.has_state(&longest) {
tmp_self.add_state(&longest);
if let Some(root) = tmp_self.pst_root.as_mut() {
pst::add_leaf(root, &longest);
}
}
}
// Connect the child labels of the source's longest PST suffix node to the target.
let suf_node =
pst::find_longest_suffix_state(tmp_self.pst_root.as_ref().unwrap(), &rule.source);
let suf_states = pst::get_child_labels(suf_node);
for state in suf_states.iter() {
tmp_self.add_state_transition(state, &longest, rule.probability, false);
}
// Commit the modified copy.
*self = tmp_self;
}
/// Randomly adds edges with probability `prob`.
///
/// Each ordered pair of states (self-pairs included) that has no edge yet gets a new edge
/// with chance `chance`, at a resolution of 1%. No rebalancing is performed.
#[allow(dead_code)]
pub fn randomize_edges(&mut self, chance: f32, prob: f32) {
    let mut new_edges = Vec::new();
    let chance_int = (100.0 * chance) as i32;
    let mut rng = rand::thread_rng();
    for (k1, v1) in self.labels.iter() {
        for (k2, v2) in self.labels.iter() {
            // `rem_euclid` keeps the roll in 0..100. A plain `% 100` on an i32 yields
            // -99..=99, so about half the rolls would pass regardless of `chance`.
            let c: i32 = rng.gen::<i32>().rem_euclid(100);
            if c < chance_int && !self.has_transition_hash(*k1, *k2) {
                new_edges.push((v1.clone(), v2.clone()));
            }
        }
    }
    for e in new_edges.iter() {
        self.add_state_transition(&e.0, &e.1, prob, false);
    }
}
/// Applies `repeat_symbol` to every symbol of the current alphabet.
#[allow(dead_code)]
pub fn repeat(&mut self, chance: f32, max_rep: usize) {
    let symbols = self.alphabet.clone();
    for sym in symbols {
        self.repeat_symbol(sym, chance, max_rep);
    }
}
/// Turns the tail of the symbol history into a rule with probability 1.0, then removes
/// orphaned states.
///
/// The rule's context is the `ctx_len` symbols before the last one, and the emitted symbol
/// is the last one. Does nothing if the history is shorter than `ctx_len + 1`.
#[allow(dead_code)]
pub fn solidify(&mut self, ctx_len: usize) {
    if self.history.len() < ctx_len + 1 {
        return;
    }
    let end = self.history.len() - 1;
    let source = self.history[end - ctx_len..end].to_vec();
    let symbol = self.history[end];
    self.add_rule(&Rule {
        source,
        symbol,
        probability: 1.0,
    });
    self.remove_orphaned_states();
}
/// Moves `current_state` back to the entry `states` positions from the end of
/// `state_history`; `current_symbol` is left unchanged.
///
/// Does nothing when `states` is 0, or when it exceeds the history length. (`states == 0`
/// used to index one past the end and panic.)
#[allow(dead_code)]
pub fn rewind(&mut self, states: usize) {
    if states == 0 {
        return;
    }
    if let Some(idx) = self.state_history.len().checked_sub(states) {
        self.current_state = Some(self.state_history[idx]);
    }
}
/// Lets `sym` repeat up to `max_rep` times in a row after the states that
/// `pst::get_suffix_symbol_states` returns for `sym`.
///
/// For each such known state:
/// - `chance` probability is freed;
/// - a chain `state+sym`, `state+sym+sym`, … of `max_rep - 1` states is added, each link
///   taken with probability `chance`;
/// - every chain state exits to the original state's children: intermediate links with
///   `1 - chance`, the last link with `1.0`.
///
/// With `max_rep < 2` there is nothing to repeat, and the function returns without touching
/// any probabilities. This also avoids the `max_rep - 1` underflow for `max_rep == 0`.
#[allow(dead_code)]
pub fn repeat_symbol(&mut self, sym: T, chance: f32, max_rep: usize) {
    if max_rep < 2 {
        return;
    }
    let mut states_to_add = Vec::new();
    let mut transitions_to_add = Vec::new();
    let suffix_states: Vec<Label<T>> =
        pst::get_suffix_symbol_states(self.pst_root.as_ref().unwrap(), sym);
    for state in suffix_states.iter() {
        let state_hash = calculate_hash(state);
        if !self.children.contains_key(&state_hash) {
            continue;
        }
        self.free_probability_state_hash(&state_hash, chance);
        let mut lab = state.clone();
        let mut last_lab = state.clone();
        for i in 0..(max_rep - 1) {
            lab.push(sym);
            states_to_add.push(lab.clone());
            transitions_to_add.push((last_lab.clone(), lab.clone(), chance));
            // Intermediate chain states may still repeat; the last one must exit.
            let exit_prob = if i < max_rep - 2 { 1.0 - chance } else { 1.0 };
            for ch in self.children[&state_hash].iter() {
                transitions_to_add.push((lab.clone(), ch.child.clone(), exit_prob));
            }
            last_lab = lab.clone();
        }
    }
    for state in states_to_add.iter() {
        self.add_state(state);
    }
    for trans in transitions_to_add.iter() {
        self.add_state_transition(&trans.0, &trans.1, trans.2, false);
    }
}
/// Builds an automaton by adding `rules` in order of increasing source length.
///
/// The sort is stable, so rules of equal length keep their given order. Orphaned states
/// are optionally removed afterwards. `rules` is sorted in place.
pub fn infer_from_rules(rules: &mut [Rule<T>], remove_orphans: bool) -> Self {
    rules.sort_by_key(|rule| rule.source.len());
    let mut pfa = Pfa::new();
    rules.iter().for_each(|rule| pfa.add_rule(rule));
    if remove_orphans {
        pfa.remove_orphaned_states();
    }
    pfa
}
/// Renders the state history as a `Debug`-formatted list of compact labels.
///
/// Quotes, brackets, braces, commas, spaces and backslashes are stripped from each label.
/// Unknown hashes render as `UNKNOWN`.
#[allow(dead_code)]
pub fn get_state_history_string(&self) -> String {
    const STRIP: [char; 9] = ['"', '\'', '[', ']', '{', '}', ',', ' ', '\\'];
    let readable: Vec<String> = self
        .state_history
        .iter()
        .map(|hash| {
            let raw = match self.labels.get(hash) {
                Some(label) => format!("{:?}", label),
                None => "UNKNOWN".to_string(),
            };
            raw.chars().filter(|c| !STRIP.contains(c)).collect::<String>()
        })
        .collect();
    format!("{:?}", readable)
}
/// Renders the symbol history as a `Debug`-formatted list of compact symbols.
///
/// Quotes, brackets, braces, commas, spaces and backslashes are stripped from each symbol.
#[allow(dead_code)]
pub fn get_symbol_history_string(&self) -> String {
    const STRIP: [char; 9] = ['"', '\'', '[', ']', '{', '}', ',', ' ', '\\'];
    let readable: Vec<String> = self
        .history
        .iter()
        .map(|sym| {
            format!("{:?}", sym)
                .chars()
                .filter(|c| !STRIP.contains(c))
                .collect::<String>()
        })
        .collect();
    format!("{:?}", readable)
}
}
/// Renders `pfa` as a Graphviz DOT digraph.
///
/// Nodes are named by state hash and labeled with the `Debug` form of the state label,
/// stripped of quotes, brackets, braces, commas, spaces and backslashes. Edges carry their
/// probability as label, weight and pen width.
pub fn to_dot<T: Eq + Copy + Hash + Ord + std::fmt::Debug>(pfa: &Pfa<T>) -> String {
    const STRIP: [char; 9] = ['"', '\'', '[', ']', '{', '}', ',', ' ', '\\'];
    let mut dot = String::new();
    writeln!(dot, "digraph{{").unwrap();
    for (hash, label) in &pfa.labels {
        let text: String = format!("{:?}", label)
            .chars()
            .filter(|c| !STRIP.contains(c))
            .collect();
        writeln!(dot, "{}[label=\"{}\"]", hash, text).unwrap();
    }
    for (src, edges) in &pfa.children {
        for edge in edges {
            writeln!(
                dot,
                "{}->{}[label=\"{}\" weight=\"{}\", penwidth=\"{}\", rank=same, arrowsize=1.0]",
                src, edge.child_hash, edge.prob, edge.prob, edge.prob
            )
            .unwrap();
        }
    }
    writeln!(dot, "}}").unwrap();
    dot
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Shorthand for a `Rule<char>` whose source is given as a string.
    fn rule(source: &str, symbol: char, probability: f32) -> Rule<char> {
        Rule {
            source: source.chars().collect(),
            symbol,
            probability,
        }
    }

    #[test]
    fn test_remove_symbol_transition() {
        let mut rules = vec![
            rule("abab", 'd', 1.0),
            rule("aaaaa", 'd', 1.0),
            rule("a", 'a', 0.6),
            rule("a", 'b', 0.4),
            rule("b", 'c', 0.5),
            rule("b", 'a', 0.5),
            rule("c", 'd', 1.0),
            rule("d", 'a', 1.0),
        ];
        let mut pfa = Pfa::<char>::infer_from_rules(&mut rules, true);
        let root = pfa.pst_root.clone().unwrap();
        let dot_string_before = to_dot::<char>(&pfa);
        let dot_string_before_pst = pst::to_dot::<char>(&root);
        fs::write("before_removal.dot", dot_string_before).expect("Unable to write file");
        fs::write("before_removal_pst.dot", dot_string_before_pst).expect("Unable to write file");

        let longest = pst::find_longest_suffix_state(&root, &['a', 'b']).clone();
        println!("longest {:?}", longest.label);

        pfa.add_rule(&rule("ab", 'e', 0.4));
        pfa.add_rule(&rule("e", 'a', 1.0));
        pfa.rebalance();
        fs::write("intermediate.dot", to_dot::<char>(&pfa)).expect("Unable to write file");

        pfa.remove_symbol_transition('a', 'b', false);
        fs::write("after_removal.dot", to_dot::<char>(&pfa)).expect("Unable to write file");
    }

    #[test]
    fn test_rule_addition_order_equivalence() {
        let base = vec![
            rule("a", 'a', 0.1),
            rule("a", 'b', 0.9),
            rule("b", 'a', 0.8),
            rule("b", 'c', 0.2),
            rule("c", 'd', 1.0),
            rule("d", 'a', 1.0),
        ];
        let mut rules1 = base.clone();
        rules1.extend(vec![rule("baba", 'b', 1.0), rule("bcda", 'b', 1.0)]);
        let mut rules2 = base;
        rules2.extend(vec![rule("bcda", 'b', 1.0), rule("baba", 'b', 1.0)]);

        let pfa1 = Pfa::<char>::infer_from_rules(&mut rules1, true);
        let pfa2 = Pfa::<char>::infer_from_rules(&mut rules2, true);
        assert!(pfa1 == pfa2);
    }
}