autoreply 0.3.4

autoreply: Model Context Protocol server for Bluesky profile and post search functionality
Documentation
use std::collections::HashMap;
use std::ops::Range;
use std::path::Path;
use std::{fs, io};

use prost::Message;
use thiserror::Error;

use super::proto::{self, model_proto, ModelProto};

/// Everything that can go wrong while loading or validating a SentencePiece model.
#[derive(Debug, Error)]
pub enum SentencePieceError {
    /// The model file could not be read from disk.
    #[error("failed to read model file: {0}")]
    Io(#[from] io::Error),

    /// The input bytes are not a valid `ModelProto` protobuf message.
    #[error("failed to decode model protobuf: {0}")]
    Decode(#[from] prost::DecodeError),

    /// The decoded model has no `trainer_spec` section.
    #[error("model missing trainer spec")]
    MissingTrainerSpec,

    /// The decoded model has no `normalizer_spec` section.
    #[error("model missing normalizer spec")]
    MissingNormalizerSpec,

    /// The decoded model contains zero pieces.
    #[error("model has empty vocabulary")]
    EmptyVocabulary,

    /// A piece's text is absent or empty. The payload identifies the offending entry
    /// (the vocabulary builder formats it as `"id N"`).
    #[error("piece '{0}' is empty")]
    EmptyPiece(String),
}

/// One vocabulary entry: its id, score, and type, plus the ranges that locate its text
/// inside a shared [`VocabularyStorage`].
///
/// The piece does not own its text; resolve it with [`VocabularyPiece::text`] /
/// [`VocabularyPiece::chars`] against the storage it was built with.
#[derive(Debug, Clone)]
pub struct VocabularyPiece {
    /// Position of this piece in the model's piece list.
    pub id: u32,
    /// Score recorded in the model (0.0 when the protobuf omits it).
    pub score: f32,
    /// The piece's type.
    pub kind: SentencePieceType,
    /// Byte range of this piece's text within the storage's concatenated `text`.
    text_range: Range<usize>,
    /// Index range of this piece's characters within the storage's `chars`.
    char_range: Range<usize>,
}

impl VocabularyPiece {
    /// Returns this piece's text, sliced out of `storage`'s concatenated text buffer.
    ///
    /// The returned slice borrows only from `storage`, not from `self`, so it may outlive
    /// the borrow of the piece itself.
    ///
    /// # Panics
    ///
    /// Panics if this piece's byte range is out of bounds (or not on a char boundary) for
    /// `storage`, which can only happen when `storage` is not the one the piece was built with.
    pub fn text<'s>(&self, storage: &'s VocabularyStorage) -> &'s str {
        &storage.text[self.text_range.clone()]
    }

    /// Returns this piece's characters, sliced out of `storage`'s shared char buffer.
    ///
    /// As with [`VocabularyPiece::text`], the result borrows only from `storage`.
    ///
    /// # Panics
    ///
    /// Panics if this piece's char range is out of bounds for `storage`.
    pub fn chars<'s>(&self, storage: &'s VocabularyStorage) -> &'s [char] {
        &storage.chars[self.char_range.clone()]
    }
}

/// Backing buffers shared by every [`VocabularyPiece`] of a model.
///
/// All piece texts are stored back-to-back in `text`, and their decoded characters
/// back-to-back in `chars`. Each piece keeps ranges into these buffers instead of owning
/// its own `String`.
#[derive(Debug, Clone, Default)]
pub struct VocabularyStorage {
    /// Concatenated UTF-8 text of all pieces.
    text: String,
    /// Concatenated characters of all pieces.
    chars: Vec<char>,
}

impl VocabularyStorage {
    /// Returns the concatenated text of every piece held by this storage.
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Returns the text of `piece`, borrowed from this storage.
    ///
    /// The result's lifetime is tied to `self` only (via elision), since the text lives
    /// here and not in `piece`.
    ///
    /// # Panics
    ///
    /// Panics if `piece`'s byte range is out of bounds (or not on a char boundary) for this
    /// storage, i.e. if `piece` was built against a different storage.
    pub fn piece_text(&self, piece: &VocabularyPiece) -> &str {
        &self.text[piece.text_range.clone()]
    }

    /// Returns the characters of `piece`, borrowed from this storage.
    ///
    /// # Panics
    ///
    /// Panics if `piece`'s char range is out of bounds for this storage.
    pub fn piece_chars(&self, piece: &VocabularyPiece) -> &[char] {
        &self.chars[piece.char_range.clone()]
    }
}

/// Crate-local mirror of the protobuf `SentencePiece.Type` enum.
///
/// Each variant corresponds one-to-one with the protobuf variant of the same name; see the
/// `From<model_proto::sentence_piece::Type>` impl for the mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SentencePieceType {
    /// Protobuf `Type::Normal`.
    Normal,
    /// Protobuf `Type::Unknown`.
    Unknown,
    /// Protobuf `Type::Control`.
    Control,
    /// Protobuf `Type::UserDefined`.
    UserDefined,
    /// Protobuf `Type::Byte`.
    Byte,
    /// Protobuf `Type::Unused`.
    Unused,
}

impl From<model_proto::sentence_piece::Type> for SentencePieceType {
    /// Maps each protobuf piece type onto the crate-local variant with the same name.
    fn from(value: model_proto::sentence_piece::Type) -> Self {
        use model_proto::sentence_piece::Type as ProtoType;

        // Exhaustive on purpose: a new protobuf variant should fail to compile here.
        match value {
            ProtoType::Normal => Self::Normal,
            ProtoType::Unknown => Self::Unknown,
            ProtoType::Control => Self::Control,
            ProtoType::UserDefined => Self::UserDefined,
            ProtoType::Byte => Self::Byte,
            ProtoType::Unused => Self::Unused,
        }
    }
}

/// A decoded SentencePiece model together with lookup tables derived from it.
///
/// Build one with [`SentencePieceModel::from_proto`] or the `load_*` helpers, which verify
/// that the trainer and normalizer specs are present and that the vocabulary is non-empty.
#[derive(Debug, Clone)]
pub struct SentencePieceModel {
    /// The protobuf message the model was decoded from.
    pub proto: ModelProto,
    /// One entry per element of `proto.pieces`; a piece's id equals its index here.
    pub vocab: Vec<VocabularyPiece>,
    /// Shared text/char buffers that the `vocab` entries point into.
    pub storage: VocabularyStorage,
    /// Maps a piece's text to its id.
    pub piece_index: HashMap<String, u32>,
    /// Id of the unknown piece, taken from `trainer_spec.unk_id` (0 when unset).
    pub unk_id: u32,
    /// Id from `trainer_spec.bos_id`; `None` when unset or negative.
    pub bos_id: Option<u32>,
    /// Id from `trainer_spec.eos_id`; `None` when unset or negative.
    pub eos_id: Option<u32>,
    /// Id from `trainer_spec.pad_id`; `None` when unset or negative.
    pub pad_id: Option<u32>,
}

impl SentencePieceModel {
    /// Reads a serialized SentencePiece model from `path` and decodes it.
    ///
    /// # Errors
    ///
    /// Returns [`SentencePieceError::Io`] if the file cannot be read; otherwise any error
    /// produced by [`SentencePieceModel::load_from_bytes`].
    pub fn load_from_file(path: impl AsRef<Path>) -> Result<Self, SentencePieceError> {
        let bytes = fs::read(path)?;
        Self::load_from_bytes(&bytes)
    }

    /// Decodes a serialized `ModelProto` and builds the model from it.
    ///
    /// # Errors
    ///
    /// Returns [`SentencePieceError::Decode`] if `bytes` is not a valid protobuf message;
    /// otherwise any error produced by [`SentencePieceModel::from_proto`].
    pub fn load_from_bytes(bytes: &[u8]) -> Result<Self, SentencePieceError> {
        let proto = ModelProto::decode(bytes)?;
        Self::from_proto(proto)
    }

    /// Validates `proto` and derives the vocabulary, piece index, and special ids from it.
    ///
    /// Special ids come from the trainer spec. `unk_id` falls back to 0 when it is unset
    /// or negative. `bos_id`/`eos_id`/`pad_id` become `None` when unset or negative.
    ///
    /// # Errors
    ///
    /// Checks run in this order, and the first failure is reported:
    /// - [`SentencePieceError::MissingTrainerSpec`] if `trainer_spec` is absent;
    /// - [`SentencePieceError::MissingNormalizerSpec`] if `normalizer_spec` is absent;
    /// - [`SentencePieceError::EmptyVocabulary`] if there are no pieces;
    /// - [`SentencePieceError::EmptyPiece`] if any piece has absent or empty text.
    pub fn from_proto(proto: ModelProto) -> Result<Self, SentencePieceError> {
        let Some(trainer) = proto.trainer_spec.as_ref() else {
            return Err(SentencePieceError::MissingTrainerSpec);
        };
        if proto.normalizer_spec.is_none() {
            return Err(SentencePieceError::MissingNormalizerSpec);
        }

        let (vocab, storage) = build_vocab(&proto)?;
        let piece_index = build_piece_index(&vocab, &storage);

        // Route unk_id through `option_id` instead of an `as` cast. A negative value
        // would otherwise wrap to a huge u32 that can never index the vocabulary. Fall
        // back to 0, exactly as for an unset value.
        let unk_id = option_id(trainer.unk_id).unwrap_or(0);
        let bos_id = option_id(trainer.bos_id);
        let eos_id = option_id(trainer.eos_id);
        let pad_id = option_id(trainer.pad_id);

        Ok(Self {
            proto,
            vocab,
            storage,
            piece_index,
            unk_id,
            bos_id,
            eos_id,
            pad_id,
        })
    }

    /// Returns all vocabulary pieces, indexed by id.
    pub fn vocab(&self) -> &[VocabularyPiece] {
        &self.vocab
    }

    /// Returns the shared buffers the pieces' text and chars live in.
    pub fn storage(&self) -> &VocabularyStorage {
        &self.storage
    }

    /// Returns the piece with the given id, or `None` if `id` is out of range.
    pub fn piece(&self, id: u32) -> Option<&VocabularyPiece> {
        self.vocab.get(id as usize)
    }

    /// Returns the characters of the piece with the given id, or `None` if out of range.
    pub fn piece_chars(&self, id: u32) -> Option<&[char]> {
        self.piece(id).map(|piece| piece.chars(&self.storage))
    }

    /// Returns the text of the piece with the given id, or `None` if out of range.
    pub fn piece_text(&self, id: u32) -> Option<&str> {
        self.piece(id).map(|piece| piece.text(&self.storage))
    }

    /// Returns the model's trainer spec.
    ///
    /// # Panics
    ///
    /// Panics if `proto.trainer_spec` was cleared after construction. The field is
    /// public, so the check done in `from_proto` cannot be enforced afterwards.
    pub fn trainer_spec(&self) -> &proto::TrainerSpec {
        self.proto
            .trainer_spec
            .as_ref()
            .expect("trainer_spec validated during construction")
    }

    /// Returns the model's normalizer spec.
    ///
    /// # Panics
    ///
    /// Panics if `proto.normalizer_spec` was cleared after construction. The field is
    /// public, so the check done in `from_proto` cannot be enforced afterwards.
    pub fn normalizer_spec(&self) -> &proto::NormalizerSpec {
        self.proto
            .normalizer_spec
            .as_ref()
            .expect("normalizer_spec validated during construction")
    }

    /// Returns the self-test data embedded in the model, if any.
    pub fn self_test_data(&self) -> Option<&proto::SelfTestData> {
        self.proto.self_test_data.as_ref()
    }
}

fn build_vocab(
    proto: &ModelProto,
) -> Result<(Vec<VocabularyPiece>, VocabularyStorage), SentencePieceError> {
    if proto.pieces.is_empty() {
        return Err(SentencePieceError::EmptyVocabulary);
    }

    let mut storage = VocabularyStorage::default();
    storage.text = String::new();
    storage.text.reserve(proto.pieces.len() * 4);
    storage.chars.reserve(proto.pieces.len() * 4);

    let mut vocab = Vec::with_capacity(proto.pieces.len());
    for (idx, piece) in proto.pieces.iter().enumerate() {
        let piece_text = piece.piece.clone().unwrap_or_default();
        if piece_text.is_empty() {
            return Err(SentencePieceError::EmptyPiece(format!("id {}", idx)));
        }

        let text_start = storage.text.len();
        storage.text.push_str(&piece_text);
        let text_end = storage.text.len();

        let chars_start = storage.chars.len();
        storage.chars.extend(piece_text.chars());
        let chars_end = storage.chars.len();

        vocab.push(VocabularyPiece {
            id: idx as u32,
            score: piece.score.unwrap_or(0.0),
            kind: piece_kind(piece),
            text_range: text_start..text_end,
            char_range: chars_start..chars_end,
        });
    }

    Ok((vocab, storage))
}

/// Decodes a piece's raw protobuf type into a [`SentencePieceType`].
///
/// Both an absent type and an integer that is not a known `Type` value map to
/// [`SentencePieceType::Normal`]; unrecognized values are not reported.
fn piece_kind(piece: &model_proto::SentencePiece) -> SentencePieceType {
    let Some(raw) = piece.r#type else {
        return SentencePieceType::Normal;
    };
    match model_proto::sentence_piece::Type::try_from(raw) {
        Ok(decoded) => SentencePieceType::from(decoded),
        Err(_) => SentencePieceType::Normal,
    }
}

/// Builds a text → id lookup table over `vocab`, resolving texts through `storage`.
///
/// If several pieces share the same text, the one appearing last in `vocab` wins.
/// NOTE(review): duplicates presumably indicate a malformed model; confirm whether they
/// should be rejected instead of silently overwritten.
fn build_piece_index(
    vocab: &[VocabularyPiece],
    storage: &VocabularyStorage,
) -> HashMap<String, u32> {
    vocab
        .iter()
        .map(|entry| (storage.piece_text(entry).to_owned(), entry.id))
        .collect()
}

/// Converts an optional signed protobuf id into an unsigned index.
///
/// Negative values (for example `pad_id = -1`) are treated like an absent id and yield
/// `None`; every non-negative value converts losslessly.
fn option_id(raw: Option<i32>) -> Option<u32> {
    raw.and_then(|value| u32::try_from(value).ok())
}

#[cfg(test)]
mod tests {
    use super::*;
    use proto::{NormalizerSpec, TrainerSpec};

    /// A minimal valid model: `<unk>` (id 0, Unknown) and `hello` (id 1, Normal, score -1).
    fn dummy_proto() -> ModelProto {
        ModelProto {
            pieces: vec![
                model_proto::SentencePiece {
                    piece: Some("<unk>".to_string()),
                    score: Some(0.0),
                    r#type: Some(model_proto::sentence_piece::Type::Unknown as i32),
                    ..Default::default()
                },
                model_proto::SentencePiece {
                    piece: Some("hello".to_string()),
                    score: Some(-1.0),
                    r#type: Some(model_proto::sentence_piece::Type::Normal as i32),
                    ..Default::default()
                },
            ],
            trainer_spec: Some(TrainerSpec {
                unk_id: Some(0),
                bos_id: Some(1),
                eos_id: Some(2),
                pad_id: Some(-1),
                ..Default::default()
            }),
            normalizer_spec: Some(NormalizerSpec::default()),
            ..Default::default()
        }
    }

    #[test]
    fn builds_vocab_index_and_special_ids() {
        let proto = dummy_proto();
        let model = SentencePieceModel::from_proto(proto).expect("model");
        assert_eq!(model.vocab.len(), 2);
        assert_eq!(model.unk_id, 0);
        assert_eq!(model.bos_id, Some(1));
        assert_eq!(model.eos_id, Some(2));
        assert_eq!(model.pad_id, None);
        assert_eq!(model.piece_index.get("hello"), Some(&1));
    }

    #[test]
    fn exposes_piece_text_chars_and_metadata() {
        let model = SentencePieceModel::from_proto(dummy_proto()).expect("model");
        assert_eq!(model.storage().text(), "<unk>hello");
        assert_eq!(model.piece_text(0), Some("<unk>"));
        assert_eq!(model.piece_text(1), Some("hello"));
        assert_eq!(model.piece_chars(1), Some(&['h', 'e', 'l', 'l', 'o'][..]));
        assert_eq!(model.piece_text(2), None);
        assert!(model.piece(2).is_none());
        assert_eq!(model.piece_index.get("<unk>"), Some(&0));

        let unk = model.piece(0).expect("unk piece");
        assert_eq!(unk.id, 0);
        assert_eq!(unk.kind, SentencePieceType::Unknown);

        let hello = model.piece(1).expect("hello piece");
        assert_eq!(hello.id, 1);
        assert_eq!(hello.kind, SentencePieceType::Normal);
        assert_eq!(hello.score, -1.0);

        assert_eq!(model.trainer_spec().bos_id, Some(1));
        assert!(model.self_test_data().is_none());
    }

    #[test]
    fn keeps_byte_and_char_ranges_separate_for_multibyte_pieces() {
        let mut proto = dummy_proto();
        // 6 chars, 9 bytes: '▁' is 3 bytes and 'é' is 2 bytes in UTF-8.
        proto.pieces[1].piece = Some("▁héllo".to_string());
        let model = SentencePieceModel::from_proto(proto).expect("model");
        assert_eq!(model.piece_text(1), Some("▁héllo"));
        assert_eq!(model.piece_chars(1).map(|chars| chars.len()), Some(6));
        assert_eq!(model.storage().text(), "<unk>▁héllo");
        assert_eq!(model.piece_index.get("▁héllo"), Some(&1));
    }

    #[test]
    fn missing_or_unrecognized_piece_type_defaults_to_normal() {
        let mut proto = dummy_proto();
        proto.pieces[0].r#type = None;
        proto.pieces[1].r#type = Some(9999);
        let model = SentencePieceModel::from_proto(proto).expect("model");
        assert_eq!(model.piece(0).map(|p| p.kind), Some(SentencePieceType::Normal));
        assert_eq!(model.piece(1).map(|p| p.kind), Some(SentencePieceType::Normal));
    }

    #[test]
    fn absent_special_ids_use_fallbacks() {
        let mut proto = dummy_proto();
        proto.trainer_spec = Some(TrainerSpec::default());
        let model = SentencePieceModel::from_proto(proto).expect("model");
        assert_eq!(model.unk_id, 0);
        assert_eq!(model.bos_id, None);
        assert_eq!(model.eos_id, None);
        assert_eq!(model.pad_id, None);
    }

    #[test]
    fn rejects_missing_trainer_spec() {
        let mut proto = dummy_proto();
        proto.trainer_spec = None;
        let err = SentencePieceModel::from_proto(proto).unwrap_err();
        assert!(matches!(err, SentencePieceError::MissingTrainerSpec));
    }

    #[test]
    fn rejects_missing_normalizer_spec() {
        let mut proto = dummy_proto();
        proto.normalizer_spec = None;
        let err = SentencePieceModel::from_proto(proto).unwrap_err();
        assert!(matches!(err, SentencePieceError::MissingNormalizerSpec));
    }

    #[test]
    fn rejects_empty_vocab() {
        let mut proto = dummy_proto();
        proto.pieces.clear();
        let err = SentencePieceModel::from_proto(proto).unwrap_err();
        assert!(matches!(err, SentencePieceError::EmptyVocabulary));
    }

    #[test]
    fn rejects_empty_piece_text() {
        let mut proto = dummy_proto();
        proto.pieces[1].piece = Some(String::new());
        let err = SentencePieceModel::from_proto(proto).unwrap_err();
        assert!(matches!(err, SentencePieceError::EmptyPiece(ref which) if which == "id 1"));
    }

    #[test]
    fn rejects_absent_piece_text() {
        let mut proto = dummy_proto();
        proto.pieces[0].piece = None;
        let err = SentencePieceModel::from_proto(proto).unwrap_err();
        assert!(matches!(err, SentencePieceError::EmptyPiece(ref which) if which == "id 0"));
    }
}