infotheory 1.1.1

The algorithmic information theory library.
Documentation
use ahash::AHashMap;

/// Index of a byte within the repeat model's history buffer.
type RepeatPos = u32;
/// Sentinel stored in the repeat table when no earlier occurrence has been recorded.
const REPEAT_POS_NONE: RepeatPos = RepeatPos::MAX;
/// Repeat-table key built from four consecutive history bytes.
type RepeatKey = u32;

/// Narrows a history index to a [`RepeatPos`].
///
/// # Panics
///
/// Panics with "text repeat position overflow" if `pos` is not strictly below
/// [`REPEAT_POS_NONE`], since such a value would be indistinguishable from the sentinel
/// (or would not fit at all).
#[inline]
fn repeat_pos_from_usize(pos: usize) -> RepeatPos {
    let limit = REPEAT_POS_NONE as usize;
    assert!(pos < limit, "text repeat position overflow");
    pos as RepeatPos
}

/// Per-byte context features describing the text observed so far.
///
/// Populated by `TextContextAnalyzer::update` in this module and read back through
/// `TextContextAnalyzer::state`. The all-zero `Default` value is the state before any
/// byte has been observed.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct NeuralContextState {
    /// Most recently observed byte.
    pub(crate) prev1: u8,
    /// Byte observed immediately before `prev1`.
    pub(crate) prev2: u8,
    /// `classify_byte(prev1)`.
    pub(crate) prev1_class: u8,
    /// `classify_byte(prev2)`.
    pub(crate) prev2_class: u8,
    /// Length of the current run of identical bytes (1 for a fresh byte), capped at 255.
    pub(crate) run_len: u16,
    /// Continuation bytes still expected after `prev1` according to its UTF-8 lead byte
    /// (0 when not inside a multi-byte sequence); heuristic, not validated.
    pub(crate) utf8_left: u8,
    /// Whether `prev1` is a word byte (ASCII alphanumeric, `_`, or any byte >= 0x80).
    pub(crate) in_word: bool,
    /// Bucketed length of the word in progress (0 when not inside a word).
    pub(crate) word_len_bucket: u8,
    /// Shape class of the most recently completed word (see `WordTracker::class`).
    pub(crate) prev_word_class: u8,
    /// Kind of the innermost open bracket: 0 none, 1 `(`, 2 `[`, 3 `{`, 4 `<`.
    pub(crate) bracket_bucket: u8,
    /// Bit 0 toggles on every `"` byte, bit 1 on every `'` byte.
    pub(crate) quote_flags: u8,
    /// Whether `prev1` is one of `.`, `!`, `?`.
    pub(crate) sentence_boundary: bool,
    /// Whether `prev1` ends a run of two or more consecutive `\n` bytes.
    pub(crate) paragraph_break: bool,
    /// Bucketed length of the repeat-model match behind the current prediction
    /// (see `bucket_repeat_len`).
    pub(crate) repeat_len_bucket: u8,
    /// Whether `prev1` equalled the byte the repeat model predicted before it arrived.
    pub(crate) copied_last_byte: bool,
    /// Whether at least one byte has been observed.
    pub(crate) has_history: bool,
}

/// Alias of [`NeuralContextState`]; both names refer to the same type.
// NOTE(review): presumably kept so callers using the "history" name keep compiling — confirm.
pub(crate) type NeuralHistoryState = NeuralContextState;

/// Accumulates the length and character mix of the word currently being read.
#[derive(Clone, Copy, Debug, Default)]
struct WordTracker {
    /// Bytes observed since the last `clear` (saturating).
    len: u16,
    /// An ASCII lowercase letter was observed.
    saw_lower: bool,
    /// An ASCII uppercase letter was observed.
    saw_upper: bool,
    /// An ASCII digit was observed.
    saw_digit: bool,
    /// Some other byte was observed.
    saw_other: bool,
}

impl WordTracker {
    /// Forgets everything about the current word.
    fn clear(&mut self) {
        *self = WordTracker::default();
    }

    /// Records one more byte of the current word.
    fn observe(&mut self, byte: u8) {
        self.len = self.len.saturating_add(1);
        let flag = match byte {
            b'a'..=b'z' => &mut self.saw_lower,
            b'A'..=b'Z' => &mut self.saw_upper,
            b'0'..=b'9' => &mut self.saw_digit,
            _ => &mut self.saw_other,
        };
        *flag = true;
    }

    /// Coarse word shape.
    ///
    /// Returns 1 for all-lowercase, 2 for all-uppercase, 3 for all-digit words,
    /// 4 for any other non-empty mix, and 0 when nothing has been observed.
    fn class(self) -> u8 {
        match (self.saw_lower, self.saw_upper, self.saw_digit, self.saw_other) {
            (false, false, true, false) => 3,
            (false, true, false, false) => 2,
            (true, false, false, false) => 1,
            _ if self.len > 0 => 4,
            _ => 0,
        }
    }
}

/// Predicts the next byte by locating the previous occurrence of the most recent
/// four-byte context in the history and proposing the byte that followed it.
#[derive(Clone, Debug, Default)]
struct LocalRepeatState {
    /// Every byte observed so far; only ever appended to.
    history: Vec<u8>,
    /// Four-byte context key -> (latest end index, previous end index or `REPEAT_POS_NONE`).
    table: AHashMap<RepeatKey, (RepeatPos, RepeatPos)>,
    /// Byte proposed for the next `update`, if any.
    predicted: Option<u8>,
    /// Length of the match backing `predicted` (0 when there is no prediction).
    match_len: usize,
    /// Whether the most recent byte equalled the prediction made before it.
    copied_last_byte: bool,
}

impl LocalRepeatState {
    const MIN_LEN: usize = 4;
    const MAX_LEN: usize = 255;

    /// Looks up the earlier occurrence of the current four-byte suffix and returns the
    /// byte that followed it together with the match length (at most `MAX_LEN`).
    /// Returns `(None, 0)` when there is no usable earlier occurrence.
    fn predict_from_history(&self) -> (Option<u8>, usize) {
        const MISS: (Option<u8>, usize) = (None, 0);
        let hist = &self.history;
        let len = hist.len();
        if len < Self::MIN_LEN {
            return MISS;
        }
        let cur = len - 1;
        let key = match repeat_key(hist) {
            Some(key) => key,
            None => return MISS,
        };
        let (latest, previous) = match self.table.get(&key) {
            Some(&slots) => slots,
            None => return MISS,
        };
        // `update` registers the current suffix before predicting, so the latest slot
        // normally points at the current position; fall back to the one before it.
        let candidate = if latest == repeat_pos_from_usize(cur) {
            previous
        } else {
            latest
        };
        if candidate == REPEAT_POS_NONE {
            return MISS;
        }
        let cand = candidate as usize;
        if cand + 1 >= len {
            return MISS;
        }
        // The key bytes already agree; extend the match backwards from there.
        let extra = (Self::MIN_LEN..Self::MAX_LEN)
            .take_while(|&back| {
                cur >= back && cand >= back && hist[cur - back] == hist[cand - back]
            })
            .count();
        (Some(hist[cand + 1]), Self::MIN_LEN + extra)
    }

    /// Bucketed `match_len`, via `bucket_repeat_len`.
    fn repeat_len_bucket(&self) -> u8 {
        bucket_repeat_len(self.match_len)
    }

    /// Appends `symbol`, records whether it was the predicted byte, registers the new
    /// four-byte suffix, and refreshes the prediction for the following byte.
    fn update(&mut self, symbol: u8) {
        self.copied_last_byte = matches!(self.predicted, Some(guess) if guess == symbol);
        self.history.push(symbol);
        if let Some(key) = repeat_key(&self.history) {
            let newest = repeat_pos_from_usize(self.history.len() - 1);
            let slots = self
                .table
                .entry(key)
                .or_insert((REPEAT_POS_NONE, REPEAT_POS_NONE));
            // Shift the old "latest" into "previous" and record the new end position.
            *slots = (newest, slots.0);
        }
        (self.predicted, self.match_len) = self.predict_from_history();
    }
}

/// Streaming byte-level feature extractor for text.
///
/// Feed bytes one at a time through [`TextContextAnalyzer::update`] and read the
/// resulting features with [`TextContextAnalyzer::state`].
#[derive(Clone, Debug, Default)]
pub(crate) struct TextContextAnalyzer {
    /// Features returned by `state()`.
    state: NeuralContextState,
    /// The word in progress; cleared when a word ends and when a new one starts.
    word: WordTracker,
    /// Consecutive `\n` bytes seen, saturating at 3.
    newline_run: u8,
    /// Open-bracket kinds (1 `(`, 2 `[`, 3 `{`, 4 `<`); innermost at `bracket_depth - 1`.
    bracket_stack: [u8; 8],
    /// Number of valid entries in `bracket_stack`.
    bracket_depth: usize,
    /// Repeat predictor keyed on the last four bytes of history.
    repeat: LocalRepeatState,
}

impl TextContextAnalyzer {
    /// Creates an analyzer that has not observed any bytes yet.
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of the current feature snapshot.
    pub(crate) fn state(&self) -> NeuralContextState {
        self.state
    }

    /// Advances every feature by one input byte.
    pub(crate) fn update(&mut self, symbol: u8) {
        // The repeat model records whether `symbol` matched its previous prediction
        // (in `copied_last_byte`) before re-predicting; that flag is reused below
        // instead of being recomputed here.
        self.repeat.update(symbol);

        let byte_class = classify_byte(symbol);
        self.state.run_len = if self.state.has_history && symbol == self.state.prev1 {
            self.state.run_len.saturating_add(1).min(255)
        } else {
            1
        };

        let prev_in_word = self.state.in_word;
        let word_byte = is_word_byte(symbol);
        if word_byte {
            if !prev_in_word {
                self.word.clear();
            }
            self.word.observe(symbol);
        } else if prev_in_word {
            // Word just ended: remember its shape before forgetting it.
            self.state.prev_word_class = self.word.class();
            self.word.clear();
        }
        self.state.in_word = word_byte;
        self.state.word_len_bucket = bucket_word_len(self.word.len);

        self.update_structure(symbol);

        self.state.prev2 = self.state.prev1;
        self.state.prev2_class = self.state.prev1_class;
        self.state.prev1 = symbol;
        self.state.prev1_class = byte_class;
        self.state.utf8_left = utf8_left_after(symbol, self.state.utf8_left);
        self.state.repeat_len_bucket = self.repeat.repeat_len_bucket();
        self.state.copied_last_byte = self.repeat.copied_last_byte;
        self.state.has_history = true;
    }

    /// Refreshes the sentence, paragraph, bracket and quote features for `symbol`.
    fn update_structure(&mut self, symbol: u8) {
        self.state.sentence_boundary = matches!(symbol, b'.' | b'!' | b'?');

        self.newline_run = if symbol == b'\n' {
            self.newline_run.saturating_add(1).min(3)
        } else {
            0
        };
        self.state.paragraph_break = self.newline_run >= 2;

        match symbol {
            b'(' => self.push_bracket(1),
            b'[' => self.push_bracket(2),
            b'{' => self.push_bracket(3),
            b'<' => self.push_bracket(4),
            b')' => self.pop_bracket(1),
            b']' => self.pop_bracket(2),
            b'}' => self.pop_bracket(3),
            b'>' => self.pop_bracket(4),
            b'"' => self.state.quote_flags ^= 0x1,
            b'\'' => self.state.quote_flags ^= 0x2,
            _ => {}
        }
        self.state.bracket_bucket = match self.bracket_depth {
            0 => 0,
            depth => self.bracket_stack[depth - 1],
        };
    }

    /// Opens a bracket of kind `bracket`.
    ///
    /// When the fixed-size stack is full the innermost slot is overwritten, so the
    /// newest bracket still reports as innermost.
    fn push_bracket(&mut self, bracket: u8) {
        if self.bracket_depth < self.bracket_stack.len() {
            self.bracket_stack[self.bracket_depth] = bracket;
            self.bracket_depth += 1;
        } else {
            let top = self.bracket_stack.len() - 1;
            self.bracket_stack[top] = bracket;
        }
    }

    /// Closes the innermost open bracket of kind `bracket`.
    ///
    /// Any unmatched brackets opened after it are discarded as well. A closer with no
    /// matching opener is ignored.
    fn pop_bracket(&mut self, bracket: u8) {
        if let Some(idx) = self.bracket_stack[..self.bracket_depth]
            .iter()
            .rposition(|&open| open == bracket)
        {
            self.bracket_depth = idx;
        }
    }
}

/// Maps a byte to a coarse class.
///
/// The classes are:
/// - 1: ASCII letter
/// - 2: ASCII digit
/// - 3: space, tab, LF or CR
/// - 4: ASCII punctuation
/// - 5: UTF-8 lead byte (`0xC0..=0xFF`)
/// - 6: UTF-8 continuation byte (`0x80..=0xBF`)
/// - 0: anything else (other control bytes, DEL)
#[inline]
pub(crate) fn classify_byte(byte: u8) -> u8 {
    if byte.is_ascii_alphabetic() {
        1
    } else if byte.is_ascii_digit() {
        2
    } else if matches!(byte, b' ' | b'\t' | b'\n' | b'\r') {
        3
    } else if byte.is_ascii_punctuation() {
        4
    } else if byte >= 0xC0 {
        5
    } else if byte >= 0x80 {
        6
    } else {
        0
    }
}

/// Whether `byte` can be part of a word: ASCII alphanumerics, `_`, and every
/// non-ASCII byte (so UTF-8 encoded letters stay inside words).
#[inline]
fn is_word_byte(byte: u8) -> bool {
    !byte.is_ascii() || byte.is_ascii_alphanumeric() || byte == b'_'
}

/// Buckets a word length on a roughly logarithmic scale.
///
/// The buckets are 0, 1 and 2 for those exact lengths, 3 for 3–4, 4 for 5–8,
/// 5 for 9–16, and 6 for anything longer.
#[inline]
fn bucket_word_len(len: u16) -> u8 {
    if len <= 2 {
        // The three shortest lengths each get their own bucket.
        len as u8
    } else if len <= 4 {
        3
    } else if len <= 8 {
        4
    } else if len <= 16 {
        5
    } else {
        6
    }
}

/// Buckets a repeat-match length into 0..=7.
///
/// The buckets are 0 for no match, then 1–3, 4–5, 6–8, 9–12, 13–16 and 17–24,
/// and 7 for anything longer.
#[inline]
pub(crate) fn bucket_repeat_len(len: usize) -> u8 {
    // Inclusive upper bound of buckets 0..=6; the bucket index is the number of
    // bounds strictly below `len`.
    const UPPER: [usize; 7] = [0, 3, 5, 8, 12, 16, 24];
    UPPER.partition_point(|&hi| hi < len) as u8
}

/// Number of UTF-8 continuation bytes still expected after `symbol`.
///
/// A continuation byte consumes one from `prev_left` (saturating at 0). A lead byte
/// restarts the count from its sequence length. Any other byte resets it to 0.
#[inline]
fn utf8_left_after(symbol: u8, prev_left: u8) -> u8 {
    match symbol {
        0x80..=0xBF => prev_left.saturating_sub(1),
        0xC0..=0xDF => 1,
        0xE0..=0xEF => 2,
        0xF0..=0xF7 => 3,
        _ => 0,
    }
}

#[inline]
fn repeat_key(history: &[u8]) -> Option<RepeatKey> {
    if history.len() < LocalRepeatState::MIN_LEN {
        return None;
    }
    let n = history.len();
    Some(u32::from_be_bytes([
        history[n - 4],
        history[n - 3],
        history[n - 2],
        history[n - 1],
    ]))
}