//! ctw 0.1.0
//!
//! CTW (context tree weighting) sequence predictor.
use rand::Rng;

/// One node of a binary context tree.
///
/// Each node keeps symbol counts, a running Krichevsky–Trofimov (KT) estimate
/// of the symbols counted here, and a weighted probability that mixes that
/// estimate with the probabilities stored in its two children.
#[derive(Debug, Clone)]
pub struct CtwNode {
    /// Number of `false` symbols recorded at this node.
    pub zeros: u64,
    /// Number of `true` symbols recorded at this node.
    pub ones: u64,
    /// Running KT estimate: multiplied by `(count + 0.5) / (zeros + ones + 1)`
    /// for every recorded symbol, where `count` is that symbol's count before
    /// the update. Starts at `1.0`.
    pub p_kt: f64,
    /// Weighted probability as last computed by `calc_probablity`: `p_kt` when
    /// neither child holds any counts, otherwise the average of `p_kt` and the
    /// product of the children's `prob` values (a missing child counts as `1.0`).
    pub prob: f64,
    /// Subtree followed when the next context bit is `false`.
    pub zero_child: Option<Box<CtwNode>>,
    /// Subtree followed when the next context bit is `true`.
    pub one_child: Option<Box<CtwNode>>,
}

impl Default for CtwNode {
    fn default() -> Self {
        Self {
            zeros: 0,
            ones: 0,
            p_kt: 1.0,
            prob: 1.0,
            zero_child: None,
            one_child: None,
        }
    }
}

impl CtwNode {
    /// Records a `false` symbol seen after `context`.
    ///
    /// Shorthand for `self.update(false, context)`.
    pub fn update_zero(&mut self, context: &[bool]) {
        self.update(false, context);
    }

    /// Records a `true` symbol seen after `context`.
    ///
    /// Shorthand for `self.update(true, context)`.
    pub fn update_one(&mut self, context: &[bool]) {
        self.update(true, context);
    }

    /// Records `symbol` at this node and at every node along `context`.
    ///
    /// `context[0]` selects the child of this node (`false` -> `zero_child`,
    /// `true` -> `one_child`), `context[1]` the grandchild, and so on. Missing
    /// nodes on that path are created. Each visited node bumps its count for
    /// `symbol`, multiplies `p_kt` by `(count + 0.5) / (zeros + ones + 1)`
    /// (counts taken before the bump) and refreshes `prob`.
    pub fn update(&mut self, symbol: bool, context: &[bool]) {
        // Descend first: this node's `prob` mixes in the child's fresh `prob`.
        if let Some((&bit, rest)) = context.split_first() {
            self.child_mut(bit)
                .get_or_insert_with(Box::default)
                .update(symbol, rest);
        }

        let total = self.zeros + self.ones;
        let count = if symbol { &mut self.ones } else { &mut self.zeros };
        self.p_kt = self.p_kt * (*count as f64 + 0.5) / (total + 1) as f64;
        *count += 1;
        self.prob = self.calc_probablity();
    }

    /// Undoes one earlier `update_zero(context)`.
    ///
    /// Shorthand for `self.revert(false, context)`.
    pub fn revert_zero(&mut self, context: &[bool]) {
        self.revert(false, context);
    }

    /// Undoes one earlier `update_one(context)`.
    ///
    /// Shorthand for `self.revert(true, context)`.
    pub fn revert_one(&mut self, context: &[bool]) {
        self.revert(true, context);
    }

    /// Undoes one earlier `update(symbol, context)`.
    ///
    /// Walks the same path as [`CtwNode::update`], but never creates nodes: a
    /// missing child has nothing recorded, so there is nothing to undo below
    /// it. A node whose count for `symbol` is already zero keeps its counts
    /// and `p_kt` untouched instead of being rescaled by a factor that matches
    /// no earlier update.
    ///
    /// The inverse scaling is done in floating point, so `p_kt` may differ
    /// from its pre-update value by rounding error.
    pub fn revert(&mut self, symbol: bool, context: &[bool]) {
        if let Some((&bit, rest)) = context.split_first() {
            if let Some(child) = self.child_mut(bit).as_deref_mut() {
                child.revert(symbol, rest);
            }
        }

        let count = if symbol { &mut self.ones } else { &mut self.zeros };
        if *count > 0 {
            *count -= 1;
            let remaining = *count as f64;
            let total = (self.zeros + self.ones + 1) as f64;
            // Exact inverse of the factor applied by `update`.
            self.p_kt = self.p_kt * total / (remaining + 0.5);
        }
        self.prob = self.calc_probablity();
    }

    /// Computes the weighted probability of this node from its current state.
    ///
    /// Returns `p_kt` when [`CtwNode::is_end_of_context`] holds; otherwise the
    /// average of `p_kt` and the product of the children's stored `prob`
    /// values, with a missing child contributing `1.0`.
    ///
    /// The misspelled name is kept because it is part of the public interface.
    pub fn calc_probablity(&self) -> f64 {
        if self.is_end_of_context() {
            return self.p_kt;
        }
        let p_zero = Self::stored_prob(&self.zero_child);
        let p_one = Self::stored_prob(&self.one_child);
        (self.p_kt + p_zero * p_one) / 2.0
    }

    /// Returns `true` when neither child has recorded any symbol.
    ///
    /// A child that exists but holds no counts is treated like a missing one.
    pub fn is_end_of_context(&self) -> bool {
        let is_empty = |child: &Option<Box<CtwNode>>| {
            child.as_ref().map_or(true, |node| node.zeros + node.ones == 0)
        };
        is_empty(&self.zero_child) && is_empty(&self.one_child)
    }

    /// Draws the next symbol following `context`.
    ///
    /// The probability of `true` is the value `prob` would take after
    /// `update_one(context)`, divided by the current `prob`. One value is
    /// drawn from `rng` in `0.0..1.0`; the result is `true` when that draw is
    /// below the probability.
    ///
    /// The tree is left untouched: no nodes are allocated and no field is
    /// rewritten. `&mut self` is kept only for interface compatibility.
    ///
    /// If `prob` is `0.0` (for example after floating-point underflow) the
    /// ratio is not a number and the method returns `false`.
    pub fn sample(&mut self, context: &[bool], mut rng: impl Rng) -> bool {
        let p_one = self.prob_after_update(true, context) / self.prob;

        let r = rng.gen_range(0.0..1.0);
        r < p_one
    }

    /// Selects the child slot addressed by one context bit.
    fn child_mut(&mut self, bit: bool) -> &mut Option<Box<CtwNode>> {
        if bit {
            &mut self.one_child
        } else {
            &mut self.zero_child
        }
    }

    /// Stored `prob` of a child, or `1.0` when the child is missing.
    fn stored_prob(child: &Option<Box<CtwNode>>) -> f64 {
        child.as_ref().map_or(1.0, |node| node.prob)
    }

    /// Returns the `prob` this node would hold after `update(symbol, context)`
    /// without modifying anything.
    ///
    /// Performs the same floating-point operations in the same order as
    /// `update`, so the result equals what a real update would store.
    fn prob_after_update(&self, symbol: bool, context: &[bool]) -> f64 {
        let count = if symbol { self.ones } else { self.zeros };
        let p_kt = self.p_kt * (count as f64 + 0.5) / (self.zeros + self.ones + 1) as f64;

        let (p_zero, p_one) = match context.split_first() {
            None => {
                // Deepest node: the children are not touched by the update,
                // so the leaf test gives the same answer before and after.
                if self.is_end_of_context() {
                    return p_kt;
                }
                (
                    Self::stored_prob(&self.zero_child),
                    Self::stored_prob(&self.one_child),
                )
            }
            Some((&bit, rest)) => {
                let next = if bit { &self.one_child } else { &self.zero_child };
                // `update` would create a missing child as a default node.
                let on_path = match next {
                    Some(child) => child.prob_after_update(symbol, rest),
                    None => Self::default().prob_after_update(symbol, rest),
                };
                // After the update the child on the path holds at least one
                // count, so this node always takes the mixing branch.
                if bit {
                    (Self::stored_prob(&self.zero_child), on_path)
                } else {
                    (on_path, Self::stored_prob(&self.one_child))
                }
            }
        };
        (p_kt + p_zero * p_one) / 2.0
    }
}