autoreply 0.3.5

autoreply: Model Context Protocol server for Bluesky profile and post search functionality
Documentation
use std::{
    cell::{RefCell, RefMut},
    collections::HashMap,
};

use unicode_normalization::UnicodeNormalization;

use super::proto::NormalizerSpec;

const DEFAULT_SPACE_CHAR: char = ' ';
const ESCAPED_WHITESPACE: char = '\u{2581}';

#[derive(Debug, Clone)]
struct CharUnit {
    ch: char,
    origin: usize,
    is_dummy_prefix: bool,
}

#[derive(Debug, Default, Clone)]
struct NormalizerWorkspace {
    units: Vec<CharUnit>,
    scratch: Vec<CharUnit>,
    chars: Vec<char>,
    positions: Vec<usize>,
}

impl NormalizerWorkspace {
    fn reserve(&mut self, capacity: usize, add_dummy_prefix: bool) {
        let needed = capacity + if add_dummy_prefix { 1 } else { 0 };
        self.units
            .reserve(needed.saturating_sub(self.units.capacity()));
        self.scratch
            .reserve(needed.saturating_sub(self.scratch.capacity()));
        self.chars
            .reserve(needed.saturating_sub(self.chars.capacity()));
        self.positions
            .reserve(needed.saturating_sub(self.positions.capacity()));
    }

    fn prepare(&mut self, capacity: usize, add_dummy_prefix: bool) {
        self.reserve(capacity, add_dummy_prefix);
        self.units.clear();
        self.scratch.clear();
        self.chars.clear();
        self.positions.clear();
    }

    fn apply_nfkc(&mut self) {
        self.scratch.clear();
        for unit in self.units.drain(..) {
            for normalized in std::iter::once(unit.ch).nfkc() {
                self.scratch.push(CharUnit {
                    ch: normalized,
                    origin: unit.origin,
                    is_dummy_prefix: unit.is_dummy_prefix,
                });
            }
        }
        std::mem::swap(&mut self.units, &mut self.scratch);
    }

    fn collapse_whitespace(&mut self) {
        self.scratch.clear();
        let mut prev_was_space = false;

        for unit in self.units.drain(..) {
            if unit.ch.is_whitespace() && !unit.is_dummy_prefix {
                if self.scratch.is_empty() || prev_was_space {
                    continue;
                }
                prev_was_space = true;
                self.scratch.push(CharUnit {
                    ch: DEFAULT_SPACE_CHAR,
                    origin: unit.origin,
                    is_dummy_prefix: false,
                });
            } else {
                prev_was_space = unit.ch.is_whitespace();
                self.scratch.push(unit);
            }
        }

        while matches!(self.scratch.last(), Some(last) if last.ch.is_whitespace() && !last.is_dummy_prefix)
        {
            self.scratch.pop();
        }

        std::mem::swap(&mut self.units, &mut self.scratch);
    }

    fn rebuild_outputs(&mut self) {
        self.chars.clear();
        self.positions.clear();
        self.chars.extend(self.units.iter().map(|u| u.ch));
        self.positions.extend(self.units.iter().map(|u| u.origin));
    }
}

pub struct NormalizedString<'a> {
    workspace: RefMut<'a, NormalizerWorkspace>,
}

impl<'a> NormalizedString<'a> {
    pub fn chars(&self) -> &[char] {
        &self.workspace.chars
    }

    pub fn positions(&self) -> &[usize] {
        &self.workspace.positions
    }

    pub fn len(&self) -> usize {
        self.workspace.chars.len()
    }

    pub fn is_empty(&self) -> bool {
        self.workspace.chars.is_empty()
    }

    pub fn to_string(&self) -> String {
        self.workspace.chars.iter().collect()
    }
}

#[derive(Debug, Clone)]
pub struct Normalizer {
    add_dummy_prefix: bool,
    remove_extra_whitespaces: bool,
    escape_whitespaces: bool,
    rules: NormalizationTrie,
    workspace: RefCell<NormalizerWorkspace>,
}

impl Normalizer {
    pub fn from_spec(spec: &NormalizerSpec) -> Self {
        let rules = NormalizationTrie::from_spec(spec);
        Self {
            add_dummy_prefix: spec.add_dummy_prefix.unwrap_or(true),
            remove_extra_whitespaces: spec.remove_extra_whitespaces.unwrap_or(true),
            escape_whitespaces: spec.escape_whitespaces.unwrap_or(true),
            rules,
            workspace: RefCell::new(NormalizerWorkspace::default()),
        }
    }

    pub fn reserve(&self, capacity: usize) {
        let mut workspace = self.workspace.borrow_mut();
        workspace.reserve(capacity, self.add_dummy_prefix);
    }

    pub fn prewarm(&self, capacity: usize) {
        self.reserve(capacity);
        let _ = self.normalize(&"x".repeat(capacity));
    }

    pub fn normalize(&self, input: &str) -> NormalizedString<'_> {
        let mut workspace = self.workspace.borrow_mut();
        workspace.prepare(input.len(), self.add_dummy_prefix);

        if self.add_dummy_prefix {
            workspace.units.push(CharUnit {
                ch: DEFAULT_SPACE_CHAR,
                origin: 0,
                is_dummy_prefix: true,
            });
        }

        let mut iter = input.char_indices().peekable();
        while let Some((byte_idx, ch)) = iter.next() {
            if let Some((replacement, consumed_bytes)) = self.rules.apply(&input[byte_idx..]) {
                if !replacement.is_empty() {
                    for &sub_ch in replacement {
                        workspace.units.push(CharUnit {
                            ch: sub_ch,
                            origin: byte_idx,
                            is_dummy_prefix: false,
                        });
                    }
                }

                let target_end = byte_idx + consumed_bytes;
                while let Some(&(next_idx, _)) = iter.peek() {
                    if next_idx < target_end {
                        iter.next();
                    } else {
                        break;
                    }
                }
                continue;
            }

            workspace.units.push(CharUnit {
                ch,
                origin: byte_idx,
                is_dummy_prefix: false,
            });
        }

        workspace.apply_nfkc();
        if self.remove_extra_whitespaces {
            workspace.collapse_whitespace();
        }

        if self.escape_whitespaces {
            for unit in &mut workspace.units {
                if unit.ch == DEFAULT_SPACE_CHAR {
                    unit.ch = ESCAPED_WHITESPACE;
                }
            }
        }

        workspace.rebuild_outputs();

        NormalizedString { workspace }
    }
}

#[derive(Debug, Clone, Default)]
struct NormalizationTrie {
    root: TrieNode,
}

#[derive(Debug, Clone, Default)]
struct TrieNode {
    children: HashMap<char, TrieNode>,
    replacement: Option<Vec<char>>,
}

impl NormalizationTrie {
    fn from_spec(spec: &NormalizerSpec) -> Self {
        let mut trie = Self::default();
        if let Some(tables) = spec.normalization_rule_tsv.as_deref() {
            trie.load_rules(tables);
        }
        trie
    }

    fn load_rules(&mut self, tsv_data: &str) {
        for line in tsv_data.lines() {
            if line.trim().is_empty() || line.trim_start().starts_with('#') {
                continue;
            }

            let mut parts = line.splitn(2, '\t');
            let Some(src_part) = parts.next() else {
                continue;
            };
            let Some(dst_part) = parts.next() else {
                continue;
            };

            let src_chars = parse_hex_chars(src_part);
            let dst_chars = hex_sequence_to_chars(dst_part);

            if src_chars.is_empty() {
                continue;
            }

            let mut node = &mut self.root;
            for &ch in &src_chars {
                node = node.children.entry(ch).or_insert_with(TrieNode::default);
            }
            node.replacement = Some(dst_chars);
        }
    }

    fn apply(&self, text: &str) -> Option<(&[char], usize)> {
        let mut node = &self.root;
        let mut chars = text.chars();
        let mut matched_len = 0;
        let mut last_match: Option<(&[char], usize)> = None;

        while let Some(ch) = chars.next() {
            if let Some(child) = node.children.get(&ch) {
                node = child;
                matched_len += ch.len_utf8();
                if let Some(ref replacement) = node.replacement {
                    last_match = Some((replacement.as_slice(), matched_len));
                }
            } else {
                break;
            }
        }

        last_match
    }
}

fn parse_hex_chars(hex_str: &str) -> Vec<char> {
    hex_str
        .split_whitespace()
        .filter_map(|s| u32::from_str_radix(s, 16).ok())
        .filter_map(char::from_u32)
        .collect()
}

fn hex_sequence_to_chars(hex_str: &str) -> Vec<char> {
    if hex_str.is_empty() {
        return Vec::new();
    }
    parse_hex_chars(hex_str)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sentencepiece::proto::NormalizerSpec;

    fn reference_normalizer() -> Normalizer {
        // Minimal normalization rules for testing control character removal
        let normalization_rules = "\
# SentencePiece NFKC normalization rules
# Format: source_hex<tab>target_hex
# Control characters to remove (map to empty)
007F\t
008F\t
009F\t
000B\t
";
        let spec = NormalizerSpec {
            add_dummy_prefix: Some(true),
            remove_extra_whitespaces: Some(true),
            escape_whitespaces: Some(true),
            normalization_rule_tsv: Some(normalization_rules.to_string()),
            ..Default::default()
        };
        Normalizer::from_spec(&spec)
    }

    #[test]
    fn normalizes_basic_text() {
        let norm = reference_normalizer();
        let cases = vec![
            ("Hello", "▁Hello"),
            ("  Hello  ", "▁Hello"),
            ("Hello world", "▁Hello▁world"),
        ];

        for (input, expected) in cases {
            let normalized = norm.normalize(input);
            assert_eq!(
                normalized.chars().iter().collect::<String>(),
                expected,
                "input: {input:?}"
            );
        }
    }

    #[test]
    fn drops_control_characters_like_reference() {
        let norm = reference_normalizer();
        for &codepoint in &[0x7F, 0x8F, 0x9F, 0x0B] {
            let input = char::from_u32(codepoint).unwrap().to_string();
            let normalized = norm.normalize(&input);
            let result = normalized.chars().iter().collect::<String>();
            // After removing control char and with add_dummy_prefix, we expect just the space
            assert_eq!(
                result, "",
                "codepoint: U+{codepoint:04X}, got: {:?}",
                result
            );
        }
    }
}