vom_rs 0.3.0

A library for Probabilistic Finite Automata.
Documentation
use crate::operations;
use crate::pst;
use rand::seq::SliceRandom;
use rand::Rng;
use std::cmp::Ordering;
use std::collections::hash_map::DefaultHasher;
use std::collections::{BTreeSet, HashMap};
use std::fmt::Write;
use std::hash::{Hash, Hasher};

pub type Label<T> = Vec<T>;
pub type LabelHash = u64;

/// Represents the result of a generic query.
///
/// NOTE(review): the code that fills this struct is not part of the visible
/// section of the file; the field notes below follow the field names only
/// and should be confirmed against the producer.
#[derive(Clone)]
pub struct PfaQueryResult<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Label of the state before the queried step (presumably).
    pub last_state: Label<T>,
    /// Label of the state after the queried step (presumably).
    pub current_state: Label<T>,
    /// Symbol associated with `last_state` (presumably).
    pub last_symbol: T,
    /// Symbol associated with `current_state` (presumably).
    pub next_symbol: T,
}

/// Describes one transition that was inserted into a PFA.
pub struct PfaInsertionResult<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Label of the state the transition starts at.
    pub source: Label<T>,
    /// Label of the state the transition leads to.
    pub destination: Label<T>,
    /// Symbol tied to the transition.
    pub symbol: T,
    /// Probability that was requested for the transition.
    pub prob: f32,
}

/// Describes one transition that was removed from a PFA.
pub struct PfaRemovalResult<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Label of the state the removed transition started at.
    pub source: Label<T>,
    /// Label of the state the removed transition led to.
    pub destination: Label<T>,
    /// Probability reported for the removed transition.
    pub prob: f32,
}

/// Summary of what an operation changed inside a PFA.
pub struct PfaOperationResult<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Transitions the operation created.
    pub added_transitions: Vec<PfaInsertionResult<T>>,
    /// Transitions the operation deleted.
    pub removed_transitions: Vec<PfaRemovalResult<T>>,
    /// Labels of states the operation created.
    pub added_states: Vec<Label<T>>,
    /// Labels of states the operation deleted.
    pub removed_states: Vec<Label<T>>,
    /// Symbol the operation introduced, if any.
    pub added_symbol: Option<T>,
    /// NOTE(review): presumably the symbol a newly added one was modelled
    /// on; no visible code sets this to anything but `None` - confirm.
    pub template_symbol: Option<T>,
}

/// An outgoing edge of a PFA state: the target state plus the probability
/// of taking the edge.
#[derive(Clone, Debug)]
pub struct PfaChild<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Probability attached to this edge.
    pub prob: f32,
    /// Label of the state the edge leads to.
    pub child: Label<T>,
    /// Hash identifying the target state; the comparison traits of this
    /// type look at this field only.
    pub child_hash: LabelHash,
}

impl<T> PartialEq for PfaChild<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Two edges are equal when they point at the same state hash; the
    /// probability and the stored label are not compared.
    fn eq(&self, rhs: &Self) -> bool {
        let (own, theirs) = (self.child_hash, rhs.child_hash);
        own == theirs
    }
}

/// Marker impl: the hash-based equality of `PfaChild` is a full equivalence relation.
impl<T> Eq for PfaChild<T> where T: Eq + Copy + Hash + std::fmt::Debug {}

impl<T> PartialOrd for PfaChild<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Always yields `Some`, forwarding to the total order defined by `Ord`.
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        let total = Ord::cmp(self, rhs);
        Some(total)
    }
}

impl<T> Ord for PfaChild<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug,
{
    /// Orders edges by the hash of their target state; probability and
    /// label play no role.
    fn cmp(&self, rhs: &Self) -> Ordering {
        Ord::cmp(&self.child_hash, &rhs.child_hash)
    }
}

/// Hash a value with a freshly created `DefaultHasher`.
///
/// This is how labels are turned into the `u64` keys used throughout the PFA.
pub fn calculate_hash<T: Hash>(t: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    Hash::hash(t, &mut hasher);
    Hasher::finish(&hasher)
}

/// An explicit rule a PFA can be inferred from.
///
/// NOTE(review): the code consuming rules is not part of the visible
/// section of the file; the field notes follow the field names and should
/// be confirmed there.
#[derive(Debug, Clone)]
pub struct Rule<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug + Ord,
{
    /// Symbol sequence the rule applies to (presumably the context that
    /// precedes `symbol`).
    pub source: Label<T>,
    /// Symbol the rule refers to (presumably the one emitted after `source`).
    pub symbol: T,
    /// Probability attached to the rule.
    pub probability: f32,
}

/// The main PFA data structure.
///
/// States are identified by the hash of their label (see `calculate_hash`);
/// `labels`, `children` and `parents` are all keyed by that hash.
#[derive(Clone)]
pub struct Pfa<T>
where
    T: Eq + Copy + Hash + std::fmt::Debug + Ord,
{
    /// Root of the inner PST that mirrors the state labels.
    pub pst_root: Option<pst::PstNode<T>>,
    /// Symbols that occur in the state labels, kept sorted and free of duplicates.
    pub alphabet: Vec<T>,
    /// Hash of the state the PFA is currently in, if any.
    pub current_state: Option<LabelHash>,
    /// Symbol that belongs to the current state, if any.
    pub current_symbol: Option<T>,
    /// Label of every state.
    pub labels: HashMap<LabelHash, Label<T>>,
    /// Outgoing edges of every state.
    pub children: HashMap<LabelHash, Vec<PfaChild<T>>>,
    /// Hashes of the states that have an edge into a given state.
    pub parents: HashMap<LabelHash, Vec<LabelHash>>,
    /// History of emitted symbols.
    pub history: Vec<T>,
    /// History of visited states (by hash).
    pub state_history: Vec<LabelHash>,
    /// Length the symbol history is padded to by `pad_history`.
    pub history_length: usize,
}

impl<T: Eq + Copy + Hash + std::fmt::Debug + Ord> Default for Pfa<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Eq + Copy + Hash + std::fmt::Debug + Ord> PartialEq for Pfa<T> {
    fn eq(&self, other: &Self) -> bool {
        // first, check if all key sets are the same
        let label_keys1: BTreeSet<LabelHash> = self.labels.keys().cloned().collect();
        let label_keys2: BTreeSet<LabelHash> = other.labels.keys().cloned().collect();
        if label_keys1 // difference must be empty
            .difference(&label_keys2)
            .next()
            .is_some()
        {
            return false;
        }

        let child_keys1: BTreeSet<LabelHash> = self.children.keys().cloned().collect();
        let child_keys2: BTreeSet<LabelHash> = other.children.keys().cloned().collect();
        if child_keys1 // difference must be empty
            .difference(&child_keys2)
            .next()
            .is_some()
        {
            return false;
        }

        let par_keys1: BTreeSet<LabelHash> = self.parents.keys().cloned().collect();
        let par_keys2: BTreeSet<LabelHash> = other.parents.keys().cloned().collect();
        if par_keys1 // difference must be empty
            .difference(&par_keys2)
            .next()
            .is_some()
        {
            return false;
        }

        // now see if connections are the same (children)
        for key in child_keys1.iter() {
            let ch1 = &self.children[key];
            let mut ch1set: BTreeSet<PfaChild<T>> = BTreeSet::new();
            let mut ch2set: BTreeSet<PfaChild<T>> = BTreeSet::new();
            for c in ch1.iter() {
                ch1set.insert(c.clone());
            }
            let ch2 = &self.children[key];
            for c in ch2.iter() {
                ch2set.insert(c.clone());
            }
            if ch1set // difference must be empty
                .difference(&ch2set)
                .next()
                .is_some()
            {
                return false;
            }
        }

        // now see if connections are the same (parents)
        for key in par_keys1.iter() {
            let ch1 = &self.parents[key];
            let mut ch1set: BTreeSet<LabelHash> = BTreeSet::new();
            let mut ch2set: BTreeSet<LabelHash> = BTreeSet::new();
            for c in ch1.iter() {
                ch1set.insert(*c);
            }
            let ch2 = &self.parents[key];
            for c in ch2.iter() {
                ch2set.insert(*c);
            }
            if ch1set // difference must be empty
                .difference(&ch2set)
                .next()
                .is_some()
            {
                return false;
            }
        }

        true
    }
}

impl<T: Eq + Copy + Hash + std::fmt::Debug + Ord> Pfa<T> {
    // empty pfa
    pub fn new() -> Self {
        Pfa {
            pst_root: Some(pst::PstNode::with_empty_label()),
            alphabet: Vec::new(),
            current_symbol: None,
            current_state: None,
            labels: HashMap::new(),
            children: HashMap::new(),
            parents: HashMap::new(),
            history: Vec::new(),
            state_history: Vec::new(),
            history_length: 9,
        }
    }

    /// Transfer the current state and symbol from another PFA to this one, if possible.
    ///
    /// The state is only taken over when this PFA has a label registered
    /// under the same hash; the symbol only when it is part of this PFA's
    /// alphabet. Anything that does not fit is left untouched. The histories
    /// are not transferred.
    #[allow(dead_code)]
    pub fn transfer_state(&mut self, other: &Pfa<T>) {
        match other.current_state {
            Some(hash) if self.labels.contains_key(&hash) => self.current_state = Some(hash),
            _ => {}
        }
        match other.current_symbol {
            Some(symbol) if self.alphabet.contains(&symbol) => {
                self.current_symbol = Some(symbol)
            }
            _ => {}
        }
    }

    //
    // BASIC OPERATIONS
    //

    /// Add a child to a given node within the PFA.
    ///
    /// Does nothing when `src` is not a known state. When `src` already has
    /// an edge to `dest`, the existing edge (and its probability) is kept and
    /// the new one is dropped. As a side effect the child list of `src` ends
    /// up sorted by child hash.
    pub fn add_child(&mut self, src: &Label<T>, dest: &Label<T>, prob: f32) {
        let parent_key = calculate_hash(src);
        let child_hash = calculate_hash(dest);
        let siblings = match self.children.get_mut(&parent_key) {
            Some(list) => list,
            None => return,
        };
        siblings.push(PfaChild {
            prob,
            child: dest.clone(),
            child_hash,
        });
        // `PfaChild` compares by hash only and `sort` is stable, so an older
        // edge to the same target stays in front and survives the dedup.
        siblings.sort();
        siblings.dedup();
    }

    /// Add a parent to a given node within the PFA.
    ///
    /// Records the hash of `src` in the parent list of `dest`. Does nothing
    /// when `dest` is not a known state. The parent list is kept sorted and
    /// free of duplicates.
    pub fn add_parent(&mut self, dest: &Label<T>, src: &Label<T>) {
        let parent_hash = calculate_hash(src);
        let child_key = calculate_hash(dest);
        let known_parents = match self.parents.get_mut(&child_key) {
            Some(list) => list,
            None => return,
        };
        known_parents.push(parent_hash);
        // dedup only removes neighbouring duplicates, hence the sort
        known_parents.sort_unstable();
        known_parents.dedup();
    }

    /// Check the inner consistency of a PFA.
    ///
    /// The PFA is consistent when `labels`, `parents` and `children` share
    /// exactly the same key set and every child or parent hash stored in them
    /// refers to an existing label. The first violation found is printed to
    /// stdout and makes the method return `false`.
    #[allow(dead_code)]
    pub fn check_consistency(&self) -> bool {
        let label_keys: BTreeSet<LabelHash> = self.labels.keys().copied().collect();
        let parent_keys: BTreeSet<LabelHash> = self.parents.keys().copied().collect();
        let child_keys: BTreeSet<LabelHash> = self.children.keys().copied().collect();

        // Every left-hand set has to be contained in its right-hand set;
        // checked in this order, the first failure is the one reported.
        let subset_checks = [
            (&label_keys, &parent_keys, "INCONSISTENCY - label and parents"),
            (&parent_keys, &label_keys, "INCONSISTENCY - parents and label"),
            (&label_keys, &child_keys, "INCONSISTENCY - label and children"),
            (&child_keys, &label_keys, "INCONSISTENCY - children and label"),
            (&parent_keys, &child_keys, "INCONSISTENCY - parent and children"),
            (&child_keys, &parent_keys, "INCONSISTENCY - children and parent"),
        ];
        for (lhs, rhs, complaint) in subset_checks.iter() {
            if !lhs.is_subset(rhs) {
                println!("{}", complaint);
                return false;
            }
        }

        // every edge has to point at an existing state
        for edge in self.children.values().flatten() {
            if !label_keys.contains(&edge.child_hash) {
                println!(
                    "INCONSISTENCY - child {:?} {} doesn't exist",
                    edge.child, edge.child_hash
                );
                return false;
            }
        }

        // every recorded parent has to be an existing state
        for parent in self.parents.values().flatten() {
            if !label_keys.contains(parent) {
                println!("INCONSISTENCY - parent {} doesn't exist", parent);
                return false;
            }
        }

        true
    }

    /// Add state to PFA without updating inner PST.
    ///
    /// All symbols of `label` are merged into the alphabet. The state is
    /// registered with empty child and parent lists; adding a label that
    /// already exists therefore replaces its lists with empty ones. If the
    /// PFA has no current state yet, the new state becomes the current one
    /// and the first symbol of a non-empty label becomes the current symbol.
    pub fn add_state(&mut self, label: &Label<T>) {
        // keep the alphabet a sorted set of every symbol seen in a label
        self.alphabet.extend(label.iter().copied());
        self.alphabet.sort();
        self.alphabet.dedup();

        let key = calculate_hash(label);
        self.children.insert(key, Vec::new());
        self.parents.insert(key, Vec::new());
        self.labels.insert(key, label.clone());

        if self.current_state.is_some() {
            return;
        }
        self.current_state = Some(key);
        if let Some(&first) = label.first() {
            self.current_symbol = Some(first);
        }
    }

    //
    // PFA INFO
    //

    /// Check if a PFA state is orphaned, i.e. has no parents.
    ///
    /// A label that is not a state of this PFA counts as orphaned as well.
    #[allow(dead_code)]
    pub fn state_orphaned(&self, label: &Label<T>) -> bool {
        self.state_orphaned_hash(calculate_hash(label))
    }

    /// Hash-based variant of `state_orphaned`: true when the state has an
    /// empty parent list or no parent list at all.
    #[allow(dead_code)]
    pub fn state_orphaned_hash(&self, label_hash: LabelHash) -> bool {
        self.parents
            .get(&label_hash)
            .map_or(true, |incoming| incoming.is_empty())
    }

    /// Check if a PFA has a specified state.
    pub fn has_state(&self, label: &Label<T>) -> bool {
        self.has_state_hash(calculate_hash(label))
    }

    /// Check if a PFA has a state registered under the given label hash.
    pub fn has_state_hash(&self, label_hash: LabelHash) -> bool {
        self.labels.get(&label_hash).is_some()
    }

    /// Check if this PFA has a certain transition.
    ///
    /// Returns `false` when either label is not a state of this PFA.
    pub fn has_transition(&self, src: &Label<T>, dest: &Label<T>) -> bool {
        self.has_transition_hash(calculate_hash(src), calculate_hash(dest))
    }

    /// Hash-based variant of `has_transition`.
    ///
    /// Both hashes have to be known labels and the source has to list the
    /// destination among its children (a linear scan of the child list).
    fn has_transition_hash(&self, src_hash: LabelHash, dest_hash: LabelHash) -> bool {
        let both_known =
            self.labels.contains_key(&src_hash) && self.labels.contains_key(&dest_hash);
        both_known
            && self.children[&src_hash]
                .iter()
                .any(|edge| edge.child_hash == dest_hash)
    }

    /// Look up the transition leaving `src` that emits `sym`.
    ///
    /// Returns the source label, the destination label and the transition
    /// probability, or `None` when `src` has no such transition.
    pub fn get_emission(&self, src: &Label<T>, sym: T) -> Option<(Label<T>, Label<T>, f32)> {
        self.get_emission_hash(calculate_hash(src), sym)
    }

    /// Check if this PFA has a certain possible emission.
    ///
    /// Scans the children of the state with hash `src_hash` for the first one
    /// whose label ends in `sym` and returns (source label, child label,
    /// probability). Returns `None` when the state has no child list or no
    /// matching child.
    fn get_emission_hash(&self, src_hash: LabelHash, sym: T) -> Option<(Label<T>, Label<T>, f32)> {
        let edge = self
            .children
            .get(&src_hash)?
            .iter()
            .find(|edge| edge.child.last() == Some(&sym))?;
        Some((
            self.labels[&src_hash].clone(),
            edge.child.clone(),
            edge.prob,
        ))
    }

    /// Make sure the overall probability of exiting connections doesn't exceed 1.0
    ///
    /// Label-based front end of `rebalance_state_hash`; panics when `state`
    /// has no child list in this PFA.
    pub fn rebalance_state(&mut self, state: &Label<T>) {
        self.rebalance_state_hash(&calculate_hash(state));
    }

    /// Rebalance the outgoing probabilities of one state.
    ///
    /// The probabilities of all children are handed to
    /// `operations::rebalance_float` and the returned values are written back
    /// in the same order. Panics when the state has no child list.
    // NOTE(review): the meaning of the arguments 1.0 and 0.35 is defined by
    // `operations::rebalance_float`, which is not visible here - presumably
    // the target sum and a lower bound or threshold; confirm.
    fn rebalance_state_hash(&mut self, state_hash: &LabelHash) {
        let current: Vec<f32> = self.children[state_hash]
            .iter()
            .map(|edge| edge.prob)
            .collect();

        let balanced = operations::rebalance_float(current, 1.0, 0.35);
        if let Some(edges) = self.children.get_mut(state_hash) {
            for (i, prob) in balanced.iter().enumerate() {
                edges[i].prob = *prob;
            }
        }
    }

    /// Run the outgoing probabilities of one state through
    /// `operations::free_probability_float` with the amount `to_free` and
    /// write the returned values back in the same order.
    ///
    /// Panics when the state has no child list.
    fn free_probability_state_hash(&mut self, state_hash: &LabelHash, to_free: f32) {
        let current: Vec<f32> = self.children[state_hash]
            .iter()
            .map(|edge| edge.prob)
            .collect();
        let freed = operations::free_probability_float(current, to_free);
        if let Some(edges) = self.children.get_mut(state_hash) {
            for (i, prob) in freed.iter().enumerate() {
                edges[i].prob = *prob;
            }
        }
    }

    /// Label-based front end of `free_probability_state_hash`.
    ///
    /// Panics when `state` has no child list in this PFA.
    #[allow(dead_code)]
    pub fn free_probability_state(&mut self, state: &Label<T>, to_free: f32) {
        self.free_probability_state_hash(&calculate_hash(state), to_free);
    }

    /// Run the outgoing probabilities of one state through
    /// `operations::blur_float` with the factor `blur` and write the returned
    /// values back in the same order.
    ///
    /// Panics when the state has no child list.
    fn blur_state_hash(&mut self, state_hash: &LabelHash, blur: f32) {
        let current: Vec<f32> = self.children[state_hash]
            .iter()
            .map(|edge| edge.prob)
            .collect();
        let blurred = operations::blur_float(current, blur);
        if let Some(edges) = self.children.get_mut(state_hash) {
            for (i, prob) in blurred.iter().enumerate() {
                edges[i].prob = *prob;
            }
        }
    }

    /// Run the outgoing probabilities of one state through
    /// `operations::sharpen_float` with the factor `sharpen` and write the
    /// returned values back in the same order.
    ///
    /// Panics when the state has no child list.
    fn sharpen_state_hash(&mut self, state_hash: &LabelHash, sharpen: f32) {
        let current: Vec<f32> = self.children[state_hash]
            .iter()
            .map(|edge| edge.prob)
            .collect();
        let sharpened = operations::sharpen_float(current, sharpen);
        if let Some(edges) = self.children.get_mut(state_hash) {
            for (i, prob) in sharpened.iter().enumerate() {
                edges[i].prob = *prob;
            }
        }
    }

    #[allow(dead_code)]
    pub fn blur(&mut self, blurriness: f32) {
        let keys: Vec<u64> = self.labels.keys().cloned().collect();
        for hash in keys.iter() {
            self.blur_state_hash(hash, blurriness);
        }
    }

    #[allow(dead_code)]
    pub fn sharpen(&mut self, sharpness: f32) {
        // get copy of keys
        let keys: Vec<u64> = self.labels.keys().cloned().collect();
        for hash in keys.iter() {
            self.sharpen_state_hash(hash, sharpness);
        }
    }

    #[allow(dead_code)]
    pub fn rebalance(&mut self) {
        // get copy of keys
        let keys: Vec<u64> = self.labels.keys().cloned().collect();
        for hash in keys.iter() {
            if !self.children[hash].is_empty() {
                self.rebalance_state_hash(hash);
            }
        }
    }

    /// Add a transition between two labeled states.
    ///
    /// Registers `dest` as child of `src` and `src` as parent of `dest`;
    /// each half is silently skipped when the respective state is unknown.
    /// With `rebalance` set, the outgoing probabilities of `src` are
    /// rebalanced afterwards.
    ///
    /// The returned record carries the requested `prob` and the last symbol
    /// of `dest`, regardless of what ended up stored.
    ///
    /// # Panics
    ///
    /// Panics when `dest` is empty, and when `rebalance` is set while `src`
    /// has no child list.
    pub fn add_state_transition(
        &mut self,
        src: &Label<T>,
        dest: &Label<T>,
        prob: f32,
        rebalance: bool,
    ) -> PfaInsertionResult<T> {
        self.add_child(src, dest, prob);
        self.add_parent(dest, src);

        if rebalance {
            self.rebalance_state(src);
        }

        let symbol = *dest.last().unwrap();
        PfaInsertionResult {
            source: src.to_vec(),
            destination: dest.to_vec(),
            symbol,
            prob,
        }
    }

    /// Add a transition to `dest` from every state that the inner PST
    /// reports for `suffix` (via `pst::get_suffix_symbol_states`) and that
    /// also exists in this PFA.
    ///
    /// Returns one insertion record per added transition; the result is
    /// empty when `dest` is not a state of this PFA.
    ///
    /// # Panics
    ///
    /// Panics when the PFA has no PST root.
    #[allow(dead_code)]
    pub fn add_symbol_transition(
        &mut self,
        suffix: T,
        dest: &Label<T>,
        prob: f32,
        rebalance: bool,
    ) -> Vec<PfaInsertionResult<T>> {
        if !self.has_state(dest) {
            return Vec::new();
        }

        let sources = pst::get_suffix_symbol_states(self.pst_root.as_ref().unwrap(), suffix);
        let mut insertions = Vec::new();
        for src in sources.iter() {
            // the PST may know labels the PFA doesn't have as states
            if self.has_state(src) {
                insertions.push(self.add_state_transition(src, dest, prob, rebalance));
            }
        }
        insertions
    }

    /// Remove the transition between two states given by their hashes.
    ///
    /// Deletes the destination from the child list of the source and the
    /// source from the parent list of the destination, each only if present.
    /// With `rebalance` set, the remaining outgoing probabilities of the
    /// source are rebalanced. The returned record carries the probability the
    /// removed edge had, or 0.0 when there was no such edge.
    ///
    /// # Panics
    ///
    /// Panics when the source has no child list, the destination has no
    /// parent list, or either hash has no label.
    fn remove_state_transition_hash(
        &mut self,
        src_hash: LabelHash,
        dest_hash: LabelHash,
        rebalance: bool,
    ) -> PfaRemovalResult<T> {
        // forward direction: src -> dest in the child list
        let child_pos = self.children[&src_hash]
            .iter()
            .position(|edge| edge.child_hash == dest_hash);
        let prob = match child_pos {
            Some(at) => self.children.get_mut(&src_hash).unwrap().remove(at).prob,
            None => 0.0,
        };

        // backward direction: src in the parent list of dest
        let parent_pos = self.parents[&dest_hash]
            .iter()
            .position(|parent| *parent == src_hash);
        if let Some(at) = parent_pos {
            self.parents.get_mut(&dest_hash).unwrap().remove(at);
        }

        if rebalance {
            self.rebalance_state_hash(&src_hash);
        }

        PfaRemovalResult {
            source: self.labels[&src_hash].clone(),
            destination: self.labels[&dest_hash].clone(),
            prob,
        }
    }

    /// This removes the transition between two labeled states.
    ///
    /// Label-based front end of `remove_state_transition_hash`; panics when
    /// either label is not a state of this PFA.
    pub fn remove_state_transition(
        &mut self,
        src: &Label<T>,
        dest: &Label<T>,
        rebalance: bool,
    ) -> PfaRemovalResult<T> {
        self.remove_state_transition_hash(calculate_hash(src), calculate_hash(dest), rebalance)
    }

    /// Add `prob_mod` to the probability of the edge from `src_hash` to
    /// `dest_hash`.
    ///
    /// Does nothing when the source has no child list or no such edge. The
    /// result is not rebalanced or clamped.
    fn modify_transition_probability(
        &mut self,
        src_hash: LabelHash,
        dest_hash: LabelHash,
        prob_mod: f32,
    ) {
        if let Some(edges) = self.children.get_mut(&src_hash) {
            edges
                .iter_mut()
                .filter(|edge| edge.child_hash == dest_hash)
                .for_each(|edge| edge.prob += prob_mod);
        }
    }

    /// This removes the transition between two symbols, that is, it removes
    /// the possibility of the two symbols being emitted in succession.
    ///
    /// Every existing transition from a state the inner PST reports for
    /// `suffix` to a state it reports for `dest` (both via
    /// `pst::get_suffix_symbol_states`) is removed, so a single call may
    /// delete several state transitions. One removal record is returned per
    /// deleted transition.
    ///
    /// # Panics
    ///
    /// Panics when the PFA has no PST root.
    #[allow(dead_code)]
    pub fn remove_symbol_transition(
        &mut self,
        suffix: T,
        dest: T,
        rebalance: bool,
    ) -> Vec<PfaRemovalResult<T>> {
        let root = self.pst_root.as_ref().unwrap();
        let sources = pst::get_suffix_symbol_states(root, suffix);
        let targets = pst::get_suffix_symbol_states(root, dest);

        let mut removals = Vec::new();
        for src in sources.iter() {
            for target in targets.iter() {
                if !self.has_transition(src, target) {
                    continue;
                }
                removals.push(self.remove_state_transition(src, target, rebalance));
            }
        }
        removals
    }

    /// Removes all transitions leaving a state.
    fn remove_outgoing_transitions(&mut self, state: &Label<T>) -> Vec<PfaRemovalResult<T>> {
        let hash = calculate_hash(state);
        let mut rem_hash = Vec::new();
        let mut removals = Vec::new();
        for ch in self.children[&hash].iter() {
            rem_hash.push(ch.child_hash);
            //println!("REMOVE OUTGOING - {:?} -> {:?}", state, ch.child);
        }
        for r in rem_hash.iter() {
            removals.push(self.remove_state_transition_hash(hash, *r, false));
        }

        //if !self.check_consistency() {
        //    panic!("PANIC! remove outgoing broke consistency");
        //}

        removals
    }

    /// Removes all transitions entering a state.
    fn remove_incoming_transitions(&mut self, state: &Label<T>) -> Vec<PfaRemovalResult<T>> {
        let hash = calculate_hash(state);
        let mut rem_hash = Vec::new();
        let mut removals = Vec::new();
        for par in self.parents[&hash].iter() {
            rem_hash.push(*par);
            //if let Some(l) = self.labels.get(&par) {
            //println!("REMOVE INCOMING - {:?} -> {:?}", l, state);
            //}else {
            //println!("careful, state with hash {} not existing in parents ! ({} -> {:?})", par, par,state);
            //}
        }
        for r in rem_hash.iter() {
            removals.push(self.remove_state_transition_hash(*r, hash, false));
        }

        //if !self.check_consistency() {
        //    panic!("PANIC! remove incoming broke consistency");
        //}

        removals
    }

    /// Removes all traces of a state from this PFA.
    ///
    /// All incoming and outgoing transitions of the state are removed
    /// (without rebalancing), the state is dropped from the state history and
    /// from `labels`, `parents` and `children`. If it was the current state,
    /// a new current state is chosen via `reset_current_state`.
    ///
    /// The result lists the removed transitions and the removed state. When
    /// `state` is not a state of this PFA, nothing happens and the result is
    /// empty.
    #[allow(dead_code)]
    fn purge_state(&mut self, state: &Label<T>) -> PfaOperationResult<T> {
        let mut result = PfaOperationResult {
            added_states: Vec::new(),
            removed_states: Vec::new(),
            added_transitions: Vec::new(),
            removed_transitions: Vec::new(),
            template_symbol: None,
            added_symbol: None,
        };

        // nothing to purge; the transition removal below would panic
        if !self.has_state(state) {
            return result;
        }

        result
            .removed_transitions
            .append(&mut self.remove_incoming_transitions(state));
        result
            .removed_transitions
            .append(&mut self.remove_outgoing_transitions(state));

        let hash = calculate_hash(state);

        // remove state from history, which might be empty now ...
        self.state_history.retain(|x| *x != hash);

        self.labels.remove(&hash);
        self.parents.remove(&hash);
        self.children.remove(&hash);

        // Pick a new current state only after the purged one is gone;
        // otherwise the fallbacks in `reset_current_state` could select the
        // very state that is being deleted.
        if self.current_state == Some(hash) {
            self.reset_current_state();
        }

        result.removed_states.push(state.clone());
        result
    }

    /// This operation removes all states that have
    /// no incoming connection (no parent state)
    /// BE CAREFUL and only use it once all the operations
    /// you want to perform are done, otherwise you might
    /// run into trouble!
    pub fn remove_orphaned_states(&mut self) {
        let mut orphans = Vec::new();
        for (k, v) in self.labels.iter() {
            if self.state_orphaned_hash(*k) {
                //println!("found orphan {:?}", v);
                orphans.push(v.clone());
            }
        }

        for o in orphans {
            self.purge_state(&o);
        }

        //if !self.check_consistency() {
        //    panic!("PANIC! remove orphans broke consistency");
        //}
    }

    /// Re-build inner PST from labels.
    ///
    /// Starts from a fresh root with an empty label, adds every state label
    /// as a leaf and replaces the previous PST with the result.
    pub fn rebuild_pst(&mut self) {
        let mut fresh_root = pst::PstNode::<T>::with_empty_label();
        for label in self.labels.values() {
            pst::add_leaf(&mut fresh_root, label);
        }
        self.pst_root = Some(fresh_root);
    }

    /// This method allows to prune the alphabet of a PFA.
    ///
    /// Every state whose label contains `symbol` (as reported by the inner
    /// PST and filtered down to states this PFA actually has) is removed.
    /// For each removed state the method tries to "bridge" the gap: every
    /// former parent gets connected to every former child, either through a
    /// new transition carrying the probability of the removed incoming edge,
    /// or by adding that probability to an already existing edge. The symbol
    /// is also removed from the alphabet and the symbol history, the removed
    /// states from the state history, and the inner PST is rebuilt from the
    /// remaining labels.
    ///
    /// `rebalance` is passed on to `add_state_transition` for the bridging
    /// transitions; removals and probability additions to existing edges are
    /// not rebalanced here.
    ///
    /// The result reports the removed states, the removed transitions and
    /// the bridging transitions that were added (minus those whose source and
    /// destination match a removed transition).
    ///
    /// # Panics
    ///
    /// Panics when the PFA has no PST root.
    pub fn remove_symbol(&mut self, symbol: T, rebalance: bool) -> PfaOperationResult<T> {
        // collect all states that contain the symbol
        let mut states_to_remove =
            pst::get_states_containing_symbol(self.pst_root.as_ref().unwrap(), symbol);
        // pst might contain states that the pfa hasn't, so make sure we filter these out ...
        states_to_remove.retain(|x| self.has_state(x));

        // remove symbol from alphabet
        self.alphabet.retain(|x| *x != symbol);
        // remove symbol from history, which might be empty now ...
        self.history.retain(|x| *x != symbol);

        // keep track of insertions and removals for the final report
        let mut insertions = Vec::new();
        let mut removals = Vec::new();

        for state in states_to_remove.iter() {
            let state_hash = calculate_hash(state);

            // remove state from state history, which might be empty now ...
            self.state_history.retain(|x| *x != state_hash);
            // Outgoing edges go first, so a self-loop shows up in `removals_out`
            // (with the state itself as destination) and not in `removals_in`.
            // mutable because we need to merge them later
            let mut removals_out = self.remove_outgoing_transitions(state);
            let mut removals_in = self.remove_incoming_transitions(state);

            // bridge: connect every former parent with every former child
            for r_in in removals_in.iter() {
                for r_out in removals_out.iter() {
                    let r_in_src_hash = calculate_hash(&r_in.source);
                    let r_out_dest_hash = calculate_hash(&r_out.destination);

                    // never bridge from or to the state that is being removed
                    if r_in_src_hash != state_hash && r_out_dest_hash != state_hash {
                        if !self.has_transition_hash(r_in_src_hash, r_out_dest_hash) {
                            // new edge, inheriting the probability of the incoming edge
                            insertions.push(self.add_state_transition(
                                &r_in.source,
                                &r_out.destination,
                                r_in.prob,
                                rebalance,
                            ));
                        } else {
                            // add the probability to an existing edge ...
                            self.modify_transition_probability(
                                r_in_src_hash,
                                r_out_dest_hash,
                                r_in.prob,
                            );
                        }
                    }
                }
            }

            // make sure we have a valid current state (if possible)
            if let Some(cur) = self.current_state {
                if cur == state_hash {
                    self.reset_current_state();
                }
            }

            self.children.remove(&state_hash);
            self.parents.remove(&state_hash);
            self.labels.remove(&state_hash);

            removals.append(&mut removals_in);
            removals.append(&mut removals_out);
        }

        // rebuild pst from remaining states !!
        self.rebuild_pst();

        // A bridge added for one state may be torn down again while a later
        // state is removed; make sure to remove such subsequently removed
        // transitions from the reported insertions.
        for rem in removals.iter() {
            insertions.retain(|x| x.source != rem.source || x.destination != rem.destination);
        }

        // let the outside world know what's happening ...
        PfaOperationResult {
            added_states: Vec::new(),
            removed_states: states_to_remove,
            added_transitions: insertions,
            removed_transitions: removals,
            template_symbol: None,
            added_symbol: None,
        }
    }

    /// Reset current state to an existing one, if possible, to ensure as
    /// much continuity as possible in case the PFA is changed.
    ///
    /// Candidates are tried in this order:
    ///
    /// 1. the most recent entry of the state history, if that state still
    ///    exists and has a non-empty label,
    /// 2. the shortest suffix of the symbol history that is the label of an
    ///    existing state,
    /// 3. the first single-symbol state found for a symbol of the alphabet.
    ///
    /// The current symbol becomes the last symbol of the chosen state's
    /// label. If no candidate exists, current state and symbol are set to
    /// `None`. The outcome is reported on stdout.
    pub fn reset_current_state(&mut self) {
        // if current state doesn't exist, choose previous one ...
        if let Some(hash) = self.state_history.last().copied() {
            // the history may refer to a state that no longer exists (or has
            // an empty label); in that case fall through to the next strategy
            if let Some(label) = self.labels.get(&hash) {
                if let Some(&symbol) = label.last() {
                    self.current_state = Some(hash);
                    self.current_symbol = Some(symbol);
                    println!(
                        "reset cur state (from state history) because removal {:?}",
                        label
                    );
                    return;
                }
            }
        }

        // Go through the symbol history, newest first, growing the candidate
        // label by one older symbol per step. One iterator has to be reused
        // for the whole walk: asking a fresh `iter().rev()` for its next
        // element in every round would yield the newest symbol forever.
        let mut stitch_state = Label::new();
        for symbol in self.history.iter().rev() {
            stitch_state.insert(0, *symbol);
            if self.has_state(&stitch_state) {
                self.current_state = Some(calculate_hash(&stitch_state));
                // the symbol of a state is the last one of its label, which
                // here is the most recently emitted symbol
                self.current_symbol = stitch_state.last().copied();
                println!(
                    "reset cur state (from symbol history) because removal {:?}",
                    stitch_state
                );
                return;
            }
        }

        // try single-symbol states if everything fails ...
        for s in self.alphabet.iter() {
            stitch_state.clear();
            stitch_state.push(*s);
            if self.has_state(&stitch_state) {
                self.current_state = Some(calculate_hash(&stitch_state));
                self.current_symbol = Some(*s);
                println!(
                    "reset cur state (from alphabet) because removal {:?}",
                    stitch_state
                );
                return;
            }
        }

        println!("can't find valid state in this pfa ...");
        self.current_state = None;
        self.current_symbol = None;
    }

    /// Pad the history of this PFA. Just in case the history is too short
    /// after a symbol removal.
    ///
    /// The symbol history is filled up to `history_length` entries by
    /// repeating its last symbol, or the first symbol of the alphabet when
    /// the history is empty. A history that is already long enough is left
    /// untouched, as is an empty history when the alphabet is empty too
    /// (there is nothing to pad with).
    #[allow(dead_code)]
    pub fn pad_history(&mut self) {
        if self.history.len() >= self.history_length {
            return;
        }
        let filler = match self.history.last().or_else(|| self.alphabet.first()) {
            Some(symbol) => *symbol,
            None => return,
        };
        // only ever grows here because of the length check above
        self.history.resize(self.history_length, filler);
    }

    /// Advance the PFA by `steps` transitions, then put the current state and
    /// current symbol back to what they were before the call.
    ///
    /// Only those two fields are restored; whatever `next_transition` appended
    /// to the symbol and state histories stays in place.
    #[allow(dead_code)]
    pub fn sim_steps(&mut self, steps: usize) {
        let saved = (self.current_state, self.current_symbol);

        let mut remaining = steps;
        while remaining > 0 {
            self.next_transition();
            remaining -= 1;
        }

        self.current_state = saved.0;
        self.current_symbol = saved.1;
    }

    /// Query the next transition of this PFA.
    ///
    /// The current state and current symbol are appended to their histories,
    /// then one child of the current state is drawn at random. Each child is
    /// weighted by its probability in whole percent, so children below one
    /// percent are never drawn. The drawn child becomes the current state and
    /// the last symbol of its label becomes the current symbol.
    ///
    /// Both histories are trimmed by one entry whenever they exceed
    /// `history_length`, whether or not a transition was found, so repeated
    /// calls on a stuck PFA cannot grow them without bound.
    ///
    /// Returns `None` if there is no current state, if no child can be drawn,
    /// or if the label of the drawn state is empty.
    ///
    /// # Panics
    ///
    /// Panics if the current state has no entry in the child table, if a
    /// drawable child refers to a state that does not exist, or if a current
    /// state is set without a current symbol.
    pub fn next_transition(&mut self) -> Option<PfaQueryResult<T>> {
        let mut choice_list = Vec::<LabelHash>::new();
        if let Some(cur) = self.current_state {
            self.state_history.push(cur);
            for c in &self.children[&cur] {
                let weight = (100.0 * c.prob) as i32;
                if weight <= 0 {
                    continue;
                }
                // One existence check per child is enough; it does not change
                // between the `weight` copies pushed below.
                if !self.has_state_hash(c.child_hash) {
                    panic!(
                        "WARNING - found non-existing state {:?} -> {:?} {}",
                        self.labels[&cur], c.child, c.child_hash
                    );
                }
                choice_list.extend(std::iter::repeat(c.child_hash).take(weight as usize));
            }
        }

        // push before updating
        if let Some(sym) = self.current_symbol {
            self.history.push(sym);
        }

        // Trim on every call (each call pushes at most one entry per history),
        // not only after a successful transition.
        if self.history.len() > self.history_length {
            self.history.drain(0..1);
        }
        if self.state_history.len() > self.history_length {
            self.state_history.drain(0..1);
        }

        let cur_state = self.current_state?;
        let next_state = *choice_list.choose(&mut rand::thread_rng())?;
        self.current_state = Some(next_state);

        // An empty label yields no symbol; the state change above is kept.
        let sym = *self.labels[&next_state].last()?;
        let last_symbol = self
            .current_symbol
            .expect("a current symbol must be set whenever a current state is set");
        self.current_symbol = Some(sym);

        Some(PfaQueryResult {
            last_state: self.labels[&cur_state].clone(),
            current_state: self.labels[&next_state].clone(),
            last_symbol,
            next_symbol: sym,
        })
    }

    /// Perform one transition and hand back just a symbol instead of the full
    /// query result.
    ///
    /// The value returned is the `last_symbol` field of the transition, i.e.
    /// the symbol that was current before the step was taken. `None` is
    /// returned when no transition could be made.
    #[allow(dead_code)]
    pub fn next_symbol(&mut self) -> Option<T> {
        self.next_transition()
            .map(|transition| transition.last_symbol)
    }

    /// Register the label of `root`, and recursively of every node below it,
    /// as a state of this PFA (used when building a PFA from a PST).
    fn copy_states_from_pst(&mut self, root: &pst::PstNode<T>) {
        self.add_state(&root.label);

        root.children
            .values()
            .for_each(|node| self.copy_states_from_pst(node));
    }

    /// Build a PFA from the PST rooted at `root`.
    ///
    /// Every PST node label becomes a state. For each state and each symbol
    /// of `alphabet` with a probability above zero, a transition is added to
    /// the label of the longest suffix state found for that state and symbol.
    ///
    /// So far there is no check for the 'star' property that is defined in
    /// the paper.
    fn from_pst_nostar(root: &pst::PstNode<T>, alphabet: &[T]) -> Self {
        let mut new_pfa = Pfa {
            pst_root: Some(pst::PstNode::with_empty_label()),
            alphabet: Vec::new(),
            current_symbol: None,
            current_state: None,
            labels: HashMap::new(),
            children: HashMap::new(),
            parents: HashMap::new(),
            history: Vec::new(),
            state_history: Vec::new(),
            history_length: 9,
        };

        new_pfa.copy_states_from_pst(root);

        // Work on a snapshot of the labels, since the PFA is modified while
        // the transitions are being added.
        let label_snapshot = new_pfa.labels.clone();

        for label in label_snapshot.values() {
            // the emission probabilities are read from this PST state
            let pst_state = pst::find_longest_suffix_state(root, label);

            for symbol in alphabet {
                let target = pst::find_longest_suffix_state_with_symbol(root, label, symbol);
                let tprob = *pst_state.child_probability.get(symbol).unwrap();

                // transitions with 0 prob don't make sense ...
                if tprob > 0.0 {
                    new_pfa.add_state_transition(label, &target.label, tprob, false);
                }
            }
        }

        new_pfa
    }

    fn learn_with_alphabet(
        sample: &[T],
        alphabet: &[T],
        bound: usize,
        epsilon: f32,
        n: usize,
    ) -> Self {
        let pst_root = crate::pst::learn_with_alphabet(sample, alphabet, bound, epsilon, n);
        let mut pfa = Pfa::from_pst_nostar(&pst_root, alphabet);
        if let Some(s) = sample.first() {
            pfa.current_symbol = Some(*s);
        }
        pfa
    }

    /// Learn pfa using sample only, alphabet will be inferred from sample.
    pub fn learn(sample: &[T], bound: usize, epsilon: f32, n: usize) -> Self {
        let mut alphabet = sample.to_vec();
        alphabet.sort();
        alphabet.dedup();

        Pfa::learn_with_alphabet(sample, &alphabet, bound, epsilon, n)
    }

    /// Add a rule to this PFA if possible.
    ///
    /// A rule states that after the symbol sequence `rule.source` the symbol
    /// `rule.symbol` follows with probability `rule.probability`.
    ///
    /// All work is done on a clone of the PFA. The clone replaces `self` only
    /// when the whole rule could be integrated; if an intermediate prefix of
    /// the source turns out to be unreachable, the function returns early and
    /// `self` is left exactly as it was.
    ///
    /// # Panics
    ///
    /// Panics if the PFA has no PST root (`pst_root` is `None`).
    pub fn add_rule(&mut self, rule: &Rule<T>) {
        // scratch copy; committed at the very end, dropped on early return
        let mut tmp_self = self.clone();

        //println!("add rule {:?}", rule);
        // `prefix` is the part of the source consumed so far (including the
        // current symbol), `last_prefix` is the same without the current symbol
        let mut prefix = Label::<T>::new();
        let mut last_prefix = Label::<T>::new();
        let mut suffix = rule.source.clone();

        // this doesn't make sense when adding single symbol transitions
        // as we need to check whether it exists as suffix ...
        for sym in suffix.drain(..) {
            prefix.push(sym);
            // make sure every prefix of the source exists as a state
            if !tmp_self.has_state(&prefix) {
                //println!("add state: {:?}", prefix);
                tmp_self.add_state(&prefix);
            }

            // empty state not considered ...
            if !last_prefix.is_empty() {
                let longest = &pst::find_longest_suffix_state(
                    tmp_self.pst_root.as_ref().unwrap(),
                    &last_prefix,
                );
                //println!("longest {:?}", longest.label);
                // find all states that end in suffix
                let longest_suf_states = pst::get_child_labels(longest);
                // NOTE(review): `transition.1` is used as the destination label
                // and `transition.2` as the probability below - confirm against
                // the return type of `get_emission`.
                if let Some(transition) = tmp_self.get_emission(&last_prefix, sym) {
                    // redirect the existing emission of `sym` so that it ends
                    // in the longer `prefix` state, keeping its probability
                    for l in longest_suf_states.iter() {
                        if tmp_self.has_state(l) {
                            //println!("ch {:?}",l);
                            // remove original
                            tmp_self.remove_state_transition(l, &transition.1, false);
                            tmp_self.add_state_transition(l, &prefix, transition.2, false);
                        }
                    }
                    // new information is unique ...
                    // (the complete source gets its outgoing transition from
                    // the rule itself, after the loop)
                    if prefix != rule.source {
                        let longest2 = &pst::find_longest_suffix_state(
                            tmp_self.pst_root.as_ref().unwrap(),
                            &prefix,
                        )
                        .label
                        .clone();
                        //println!("longest2 {:?}", longest2);

                        // copy over information
                        // (every emission found for `longest2` is added to the
                        // intermediate `prefix` state as well)
                        for asym in tmp_self.alphabet.clone().iter() {
                            if let Some(transition2) = tmp_self.get_emission(longest2, *asym) {
                                //println!("copy transition last {:?} cur {:?} {:?} {:?} {:?}", last_prefix, prefix, transition2.1, asym, transition2.2);
                                tmp_self.add_state_transition(
                                    &prefix,
                                    &transition2.1,
                                    transition2.2,
                                    false,
                                );
                            }
                        }
                    }
                } else {
                    //println!("impossible or existing rule {:?} {:?}", rule.source, rule.symbol);
                    return; // this rule is impossible and can't be added at this point of time
                }
            }

            last_prefix.push(sym);
            // update pst
            if let Some(root) = tmp_self.pst_root.as_mut() {
                //println!("add leaf {:?}", prefix);
                pst::add_leaf(root, &prefix);
            }
        }

        // add new information as specified by rule (that is omitted in the loop above ...)
        // `prefix` now equals source + symbol; the destination is the longest
        // suffix of it that is known to the PST
        prefix.push(rule.symbol);
        let mut longest =
            pst::find_longest_suffix_state(tmp_self.pst_root.as_ref().unwrap(), &prefix)
                .label
                .clone();
        //println!("longest 3 {:?}", longest);
        if !tmp_self.has_state(&longest) {
            //println!("clear not found");
            longest.clear();
            // force the subsequent:
        }
        // fall back to the single-symbol state of the rule's symbol, creating
        // it (state and PST leaf) if it does not exist yet
        if longest.is_empty() {
            longest.push(rule.symbol);
            if !tmp_self.has_state(&longest) {
                tmp_self.add_state(&longest);
                if let Some(root) = tmp_self.pst_root.as_mut() {
                    //println!("aaadd leaf {:?}", longest);
                    pst::add_leaf(root, &longest);
                }
            }
        }
        //println!("add final transition {:?} {:?}", rule.source, longest);

        // this needs a method to find the suffix states for a longer suffix, in case the rule source is longer
        // every label returned for the longest suffix node of the source gets
        // the new transition with the rule's probability
        let suf_node =
            pst::find_longest_suffix_state(tmp_self.pst_root.as_ref().unwrap(), &rule.source);
        let suf_states = pst::get_child_labels(suf_node);
        for state in suf_states.iter() {
            //println!("add {:?} {:?}", state, longest);
            //if self.has_state(state) {
            tmp_self.add_state_transition(state, &longest, rule.probability, false);
            //}
        }
        // if everything goes well, continue with the new version
        // that way, if an impossible rule is found, we don't end up
        // with all the garbage.

        //if !self.check_consistency() {
        //  panic!("PANIC! add rule {:?} broke consistency", rule);
        //}

        *self = tmp_self;
    }

    /// Add random connections between states.
    ///
    /// Every ordered pair of states (including a state paired with itself)
    /// that is not connected yet gets a new transition with probability
    /// `prob`, with a likelihood of `chance` per pair. `chance` is expected
    /// in the range `0.0..=1.0` and is evaluated in whole percent.
    #[allow(dead_code)]
    pub fn randomize_edges(&mut self, chance: f32, prob: f32) {
        // once more, not the most efficient algorithm ...
        let chance_int = (100.0 * chance) as i32;
        let mut rng = rand::thread_rng();
        let mut new_edges = Vec::new();

        for (k1, v1) in self.labels.iter() {
            for (k2, v2) in self.labels.iter() {
                // `rem_euclid` keeps the draw in 0..100. The `%` operator
                // would return negative values for negative random numbers,
                // which are always below `chance_int` and would therefore add
                // an edge for roughly half of all pairs regardless of `chance`.
                let c: i32 = rng.gen::<i32>().rem_euclid(100);
                if c < chance_int && !self.has_transition_hash(*k1, *k2) {
                    new_edges.push((v1.clone(), v2.clone()));
                }
            }
        }

        // The edges are collected first because adding them needs mutable
        // access to the PFA.
        for (source, destination) in new_edges.iter() {
            self.add_state_transition(source, destination, prob, false);
        }
    }

    /// Apply `repeat_symbol` with the given `chance` and `max_rep` to every
    /// symbol of the alphabet.
    #[allow(dead_code)]
    pub fn repeat(&mut self, chance: f32, max_rep: usize) {
        // iterate over a copy, the PFA is modified along the way
        let symbols = self.alphabet.clone();
        for sym in symbols {
            self.repeat_symbol(sym, chance, max_rep);
        }
    }

    /// Solidify this PFA, to make once-generated sequences more likely.
    ///
    /// The newest symbol of the history, preceded by `ctx_len` symbols of
    /// context, is turned into a rule with probability `1.0`. The rule is
    /// added and orphaned states are removed afterwards. Nothing happens if
    /// the history holds fewer than `ctx_len + 1` symbols.
    #[allow(dead_code)]
    pub fn solidify(&mut self, ctx_len: usize) {
        let hist_len = self.history.len();
        if hist_len < ctx_len + 1 {
            // can't solidify, not enough history yet
            return;
        }

        // index of the newest symbol; the context are the `ctx_len` entries
        // right in front of it
        let newest = hist_len - 1;
        let rule = Rule {
            source: self.history[newest - ctx_len..newest].to_vec(),
            symbol: self.history[newest],
            probability: 1.0,
        };

        self.add_rule(&rule);
        self.remove_orphaned_states();
    }

    /// Jump to an earlier state.
    ///
    /// Makes the state that was recorded `states` entries before the end of
    /// the state history the current state (`1` is the most recent entry).
    /// The call does nothing if `states` is `0` or exceeds the length of the
    /// state history.
    ///
    /// NOTE(review): only `current_state` is changed; the current symbol and
    /// both histories are left as they are - confirm that this is intended.
    #[allow(dead_code)]
    pub fn rewind(&mut self, states: usize) {
        // `checked_sub` rejects `states > len`; `get` rejects the
        // one-past-the-end index that results from `states == 0`.
        let target = self
            .state_history
            .len()
            .checked_sub(states)
            .and_then(|idx| self.state_history.get(idx))
            .copied();

        if let Some(hash) = target {
            self.current_state = Some(hash);
        }
    }

    /// Add possible repetitions for a single symbol.
    ///
    /// For every PFA state that ends in `sym`, probability `chance` is freed
    /// on that state and a chain of up to `max_rep - 1` additional states is
    /// created, each one `sym` longer than the previous. Each chain state is
    /// entered with probability `chance` and receives a transition to every
    /// child of the original state: with probability `1.0 - chance` for inner
    /// chain states and `1.0` for the final one.
    ///
    /// A `max_rep` of `0` or `1` creates no additional states.
    ///
    /// # Panics
    ///
    /// Panics if the PFA has no PST root (`pst_root` is `None`).
    #[allow(dead_code)]
    pub fn repeat_symbol(&mut self, sym: T, chance: f32, max_rep: usize) {
        let mut states_to_add = Vec::new();
        let mut transitions_to_add = Vec::new();

        let suffix_states: Vec<Label<T>> =
            pst::get_suffix_symbol_states(self.pst_root.as_ref().unwrap(), sym);

        // Number of chain states per matching state. Saturating, so that
        // `max_rep == 0` yields an empty chain instead of underflowing.
        let chain_len = max_rep.saturating_sub(1);

        // collect states we need to modify (that end in our symbol)
        for state in suffix_states.iter() {
            let state_hash = calculate_hash(state);

            // not all pst states are necessarily present in pfa ...
            // ignore those that aren't
            if !self.children.contains_key(&state_hash) {
                continue;
            }
            self.free_probability_state_hash(&state_hash, chance);

            let mut lab = state.clone();
            let mut last_lab = state.clone();

            for i in 0..chain_len {
                lab.push(sym);
                states_to_add.push(lab.clone());
                transitions_to_add.push((last_lab.clone(), lab.clone(), chance));

                // Inner chain states may repeat once more, so they only keep
                // the remaining probability for leaving the chain; the final
                // state has to leave.
                let exit_prob = if i + 1 < chain_len { 1.0 - chance } else { 1.0 };
                for ch in self.children[&state_hash].iter() {
                    transitions_to_add.push((lab.clone(), ch.child.clone(), exit_prob));
                }

                last_lab = lab.clone();
            }
        }

        // States first, so that every transition added below finds both ends.
        for state in states_to_add.iter() {
            self.add_state(state);
        }

        for trans in transitions_to_add.iter() {
            self.add_state_transition(&trans.0, &trans.1, trans.2, false);
        }
    }

    /// Infer PFA from given rules.
    ///
    /// The rules are sorted in place by the length of their source (shortest
    /// first, equal lengths keep their relative order) and then added one by
    /// one to an initially empty PFA.
    ///
    /// If `remove_orphans` is set, orphaned (unreachable) states that might
    /// occur during successive rule inference are removed at the end.
    pub fn infer_from_rules(rules: &mut [Rule<T>], remove_orphans: bool) -> Self {
        // make sure the rules are sorted !!
        rules.sort_by_key(|rule| rule.source.len());

        // add the necessary states, implicit states that model repetition, etc
        let mut inferred = Pfa::new();
        rules.iter().for_each(|rule| inferred.add_rule(rule));

        // non-reachable states might occur in the process of creating more
        // complex pfas, so let's remove them
        if remove_orphans {
            inferred.remove_orphaned_states();
        }

        inferred
    }

    /// Render the state history in a somewhat readable manner.
    ///
    /// Each recorded state hash is replaced by the debug output of its label,
    /// stripped of quotes, brackets, braces, commas, spaces and backslashes.
    /// Hashes without a label show up as `UNKNOWN`. The result is the debug
    /// output of the list of these strings.
    #[allow(dead_code)]
    pub fn get_state_history_string(&self) -> String {
        // characters removed from the debug output of a label
        const STRIPPED: [char; 9] = ['\"', '\'', '[', ']', '{', '}', ',', ' ', '\\'];

        let readable_history: Vec<String> = self
            .state_history
            .iter()
            .map(|hash| {
                let mut text = match self.labels.get(hash) {
                    Some(label) => format!("{:?}", label),
                    None => "UNKNOWN".to_string(),
                };
                text.retain(|c| !STRIPPED.contains(&c));
                text
            })
            .collect();

        format!("{:?}", readable_history)
    }

    /// Render the symbol history in a somewhat readable manner.
    ///
    /// Each recorded symbol is replaced by its debug output, stripped of
    /// quotes, brackets, braces, commas, spaces and backslashes. The result
    /// is the debug output of the list of these strings.
    #[allow(dead_code)]
    pub fn get_symbol_history_string(&self) -> String {
        // characters removed from the debug output of a symbol
        const STRIPPED: [char; 9] = ['\"', '\'', '[', ']', '{', '}', ',', ' ', '\\'];

        let readable_history: Vec<String> = self
            .history
            .iter()
            .map(|sym| {
                let mut text = format!("{:?}", sym);
                text.retain(|c| !STRIPPED.contains(&c));
                text
            })
            .collect();

        format!("{:?}", readable_history)
    }
}

/// Render `pfa` as a Graphviz `digraph` for easy debugging output.
///
/// Every state becomes a node named by its label hash and labelled with the
/// debug output of its label (stripped of quotes, brackets, braces, commas,
/// spaces and backslashes). Every transition becomes an edge whose label,
/// weight and pen width are the transition probability.
pub fn to_dot<T: Eq + Copy + Hash + Ord + std::fmt::Debug>(pfa: &Pfa<T>) -> String {
    // characters removed from the debug output of a label
    const STRIPPED: [char; 9] = ['\"', '\'', '[', ']', '{', '}', ',', ' ', '\\'];

    let mut dot = String::new();
    writeln!(&mut dot, "digraph{{").unwrap();

    // nodes
    for (hash, label) in &pfa.labels {
        let mut text = format!("{:?}", label);
        text.retain(|c| !STRIPPED.contains(&c));
        writeln!(&mut dot, "{}[label=\"{}\"]", hash, text).unwrap();
    }

    // edges
    for (source_hash, targets) in &pfa.children {
        for target in targets {
            writeln!(
                &mut dot,
                "{}->{}[label=\"{}\" weight=\"{}\", penwidth=\"{}\", rank=same, arrowsize=1.0]",
                source_hash, target.child_hash, target.prob, target.prob, target.prob
            )
            .unwrap();
        }
    }

    writeln!(&mut dot, "}}").unwrap();

    dot
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Shorthand for a `Rule<char>` whose source is given as a string.
    fn mk_rule(source: &str, symbol: char, probability: f32) -> Rule<char> {
        Rule {
            source: source.chars().collect(),
            symbol,
            probability,
        }
    }

    /// Builds a PFA from rules, extends it, removes the `a -> b` symbol
    /// transition and dumps the intermediate graphs as dot files.
    #[test]
    fn test_remove_symbol_transition() {
        let mut rules = vec![
            mk_rule("abab", 'd', 1.0),
            mk_rule("aaaaa", 'd', 1.0),
            mk_rule("a", 'a', 0.6),
            mk_rule("a", 'b', 0.4),
            mk_rule("b", 'c', 0.5),
            mk_rule("b", 'a', 0.5),
            mk_rule("c", 'd', 1.0),
            mk_rule("d", 'a', 1.0),
        ];

        let mut pfa = Pfa::<char>::infer_from_rules(&mut rules, true);

        // snapshot of PFA and PST before anything is removed
        let pfa_dot_before = to_dot::<char>(&pfa);
        let pst_dot_before = pst::to_dot::<char>(&pfa.pst_root.clone().unwrap());
        fs::write("before_removal.dot", pfa_dot_before).expect("Unable to write file");
        fs::write("before_removal_pst.dot", pst_dot_before).expect("Unable to write file");

        let longest =
            pst::find_longest_suffix_state(&pfa.pst_root.clone().unwrap(), &['a', 'b']).clone();
        println!("longest {:?}", longest.label);

        pfa.add_rule(&mk_rule("ab", 'e', 0.4));
        pfa.add_rule(&mk_rule("e", 'a', 1.0));

        pfa.rebalance();
        fs::write("intermediate.dot", to_dot::<char>(&pfa)).expect("Unable to write file");

        pfa.remove_symbol_transition('a', 'b', false);

        fs::write("after_removal.dot", to_dot::<char>(&pfa)).expect("Unable to write file");
    }

    /// Two rule sets that differ only in the order of their two longest
    /// rules must result in equal PFAs.
    #[test]
    fn test_rule_addition_order_equivalence() {
        let base_rules = vec![
            mk_rule("a", 'a', 0.1),
            mk_rule("a", 'b', 0.9),
            mk_rule("b", 'a', 0.8),
            mk_rule("b", 'c', 0.2),
            mk_rule("c", 'd', 1.0),
            mk_rule("d", 'a', 1.0),
        ];

        // "baba" first, then "bcda"
        let mut rules1 = base_rules.clone();
        rules1.push(mk_rule("baba", 'b', 1.0));
        rules1.push(mk_rule("bcda", 'b', 1.0));

        // the same two rules the other way round
        let mut rules2 = base_rules.clone();
        rules2.push(mk_rule("bcda", 'b', 1.0));
        rules2.push(mk_rule("baba", 'b', 1.0));

        let pfa1 = Pfa::<char>::infer_from_rules(&mut rules1, true);
        let pfa2 = Pfa::<char>::infer_from_rules(&mut rules2, true);

        assert!(pfa1 == pfa2);
    }
}