// sentencepiece-rs 0.2.1
//
// Rust runtime reimplementation of SentencePiece model loading, normalization, encoding, and decoding.
// Documentation
use sentencepiece_rs::{ModelType, SPACE_SYMBOL, SentencePieceProcessor};

// Piece type tags. Each is passed as the `kind` argument of `piece`, which
// writes it as varint field 3 of the encoded piece.
// NOTE(review): the values look like the `SentencePiece.Type` enum from
// sentencepiece_model.proto (where 5 would be UNUSED) — confirm against the proto.
const NORMAL: i32 = 1;
const UNKNOWN: i32 = 2;
const CONTROL: i32 = 3;
const USER_DEFINED: i32 = 4;
const BYTE: i32 = 6;

/// Checks the metadata reported for the nine-piece fixture: vocabulary size,
/// model type, the special-token ids, and both directions of piece/id lookup.
#[test]
fn loads_model_metadata() {
    let sp = test_processor();

    // Model-level properties.
    assert_eq!(sp.model().vocab_size(), 9);
    assert_eq!(sp.model().model_type(), ModelType::Unigram);

    // Special-token ids follow the order of the first three fixture pieces.
    assert_eq!(sp.unk_id(), 0);
    assert_eq!(sp.bos_id(), Some(1));
    assert_eq!(sp.eos_id(), Some(2));

    // A piece that is not in the vocabulary is expected to map to id 0.
    assert_eq!(sp.piece_to_id("missing"), 0);
    assert_eq!(sp.id_to_piece(4).unwrap(), "▁hello");
}

/// Normalization on the fixture, whose serialized model carries only pieces:
/// padded and repeated spaces are expected to collapse into single `▁`
/// markers, and blank input is expected to normalize to the empty string.
#[test]
fn normalizes_whitespace_without_charsmap() {
    let sp = test_processor();

    let normalized = sp.normalize("  hello   world  ").unwrap();
    assert_eq!(normalized, "▁hello▁world");

    // Empty input first, then whitespace-only input.
    for blank in ["", "     "] {
        assert_eq!(sp.normalize(blank).unwrap(), "");
    }
}

/// Round-trips a two-word sentence through the unigram fixture, both as
/// pieces and as ids.
#[test]
fn encodes_and_decodes_unigram() {
    let sp = test_processor();
    let text = "hello world";

    let pieces = sp.encode(text).unwrap();
    assert_eq!(pieces, ["▁hello", "▁world"]);

    // Ids 4 and 5 are the positions of the two pieces in the fixture vocabulary.
    let ids = sp.encode_to_ids(text).unwrap();
    assert_eq!(ids, [4, 5]);

    // Both representations must decode back to the original text.
    assert_eq!(sp.decode(&pieces).unwrap(), text);
    assert_eq!(sp.decode_ids(&ids).unwrap(), text);
}

/// Text outside the vocabulary: `z` has no piece in the fixture, so the test
/// expects it to surface as its own piece, map to id 0, and still decode back
/// to the original text. A literal `<unk>` piece is expected to decode to "".
#[test]
fn handles_unknown_pieces_and_surfaces() {
    let sp = test_processor();
    let text = "hello z";

    let pieces = sp.encode(text).unwrap();
    assert_eq!(pieces, ["▁hello", SPACE_SYMBOL, "z"]);
    assert_eq!(sp.encode_to_ids(text).unwrap(), [4, 3, 0]);
    assert_eq!(sp.decode(&pieces).unwrap(), text);

    // Decoding the unknown piece by name yields no surface text.
    let explicit_unk = ["<unk>"];
    assert_eq!(sp.decode(&explicit_unk).unwrap(), "");
}

/// Encodes and decodes Japanese text whose two words are both present in the
/// fixture vocabulary.
#[test]
fn handles_unicode_text() {
    let sp = test_processor();
    let text = "こんにちは 世界";

    let pieces = sp.encode(text).unwrap();
    assert_eq!(pieces, ["▁こんにちは", "▁世界"]);
    assert_eq!(sp.decode(&pieces).unwrap(), text);
}

/// The `<USER>` piece is registered as USER_DEFINED in the fixture; the test
/// expects it to come out of encoding as a single piece and to survive decoding.
#[test]
fn keeps_user_defined_symbols_intact() {
    let sp = test_processor();
    let text = "hello <USER>";

    let pieces = sp.encode(text).unwrap();
    assert_eq!(pieces, ["▁hello", SPACE_SYMBOL, "<USER>"]);
    assert_eq!(sp.decode(&pieces).unwrap(), text);
}

/// With the byte-fallback fixture, a character that has no piece of its own
/// is expected to be emitted as one `<0xNN>` piece per UTF-8 byte and to
/// decode back to the original character.
#[test]
fn byte_fallback_round_trips_unicode_unknowns() {
    let sp = byte_fallback_processor();
    let text = "hi 🚀";

    let pieces = sp.encode(text).unwrap();
    // F0 9F 9A 80 is the UTF-8 encoding of the rocket emoji.
    let expected = ["▁hi", SPACE_SYMBOL, "<0xF0>", "<0x9F>", "<0x9A>", "<0x80>"];
    assert_eq!(pieces, expected);

    assert_eq!(sp.decode(&pieces).unwrap(), text);
}

/// With the BPE fixture, `abc` is expected to come out as the space marker,
/// the merged `ab` piece, and a lone `c`, and to decode back to `abc`.
#[test]
fn bpe_model_merges_best_pairs() {
    let sp = bpe_processor();
    let text = "abc";

    let pieces = sp.encode(text).unwrap();
    assert_eq!(pieces, [SPACE_SYMBOL, "ab", "c"]);
    assert_eq!(sp.decode(&pieces).unwrap(), text);
}

/// Exercises the encode extra options: `bos:eos` is expected to wrap the
/// pieces in `<s>` / `</s>`, and a later `reverse` setting is expected to
/// yield only the reversed pieces, with no `<s>` / `</s>` left over.
#[test]
fn extra_options_work_in_order() {
    let text = "hello world";
    let mut sp = test_processor();

    sp.set_encode_extra_options("bos:eos").unwrap();
    let wrapped = sp.encode(text).unwrap();
    assert_eq!(wrapped, ["<s>", "▁hello", "▁world", "</s>"]);

    sp.set_encode_extra_options("reverse").unwrap();
    let reversed = sp.encode(text).unwrap();
    assert_eq!(reversed, ["▁world", "▁hello"]);
}

/// Builds the default fixture: a nine-piece vocabulary serialized without a
/// trainer section and loaded through `from_serialized_model`.
///
/// The tests assert ids that match the row order below (e.g. `▁hello` is 4).
fn test_processor() -> SentencePieceProcessor {
    // One row per piece: (text, score, kind).
    let vocab: [(&str, f32, i32); 9] = [
        ("<unk>", 0.0, UNKNOWN),
        ("<s>", 0.0, CONTROL),
        ("</s>", 0.0, CONTROL),
        (SPACE_SYMBOL, -5.0, NORMAL),
        ("▁hello", 0.0, NORMAL),
        ("▁world", 0.0, NORMAL),
        ("▁こんにちは", 0.0, NORMAL),
        ("▁世界", 0.0, NORMAL),
        ("<USER>", 0.0, USER_DEFINED),
    ];
    let pieces: Vec<Vec<u8>> = vocab
        .iter()
        .map(|&(text, score, kind)| piece(text, score, kind))
        .collect();

    let serialized = model(pieces, Vec::new());
    SentencePieceProcessor::from_serialized_model(&serialized).unwrap()
}

/// Builds a fixture with five regular pieces followed by 256 BYTE pieces
/// named `<0x00>` through `<0xFF>`, plus a trainer section with boolean
/// field 35 set to true.
///
/// NOTE(review): field 35 is presumably `TrainerSpec.byte_fallback` — confirm
/// against sentencepiece_model.proto.
fn byte_fallback_processor() -> SentencePieceProcessor {
    let mut pieces = vec![
        piece("<unk>", 0.0, UNKNOWN),
        piece("<s>", 0.0, CONTROL),
        piece("</s>", 0.0, CONTROL),
        piece(SPACE_SYMBOL, -5.0, NORMAL),
        piece("▁hi", 0.0, NORMAL),
    ];
    // One fallback piece per possible byte value, in ascending order.
    pieces.extend((0u8..=255).map(|byte| piece(&format!("<0x{byte:02X}>"), -20.0, BYTE)));

    let mut trainer = Vec::new();
    bool_field(&mut trainer, 35, true);

    let serialized = model(pieces, trainer);
    SentencePieceProcessor::from_serialized_model(&serialized).unwrap()
}

/// Builds a small fixture with single-letter pieces plus a higher-scored
/// `ab` piece, and a trainer section with varint field 3 set to 2.
///
/// NOTE(review): field 3 = 2 is presumably `TrainerSpec.model_type = BPE` —
/// confirm against sentencepiece_model.proto.
fn bpe_processor() -> SentencePieceProcessor {
    // One row per piece: (text, score, kind).
    let vocab: [(&str, f32, i32); 8] = [
        ("<unk>", 0.0, UNKNOWN),
        ("<s>", 0.0, CONTROL),
        ("</s>", 0.0, CONTROL),
        (SPACE_SYMBOL, 0.0, NORMAL),
        ("a", 0.0, NORMAL),
        ("b", 0.0, NORMAL),
        ("c", 0.0, NORMAL),
        ("ab", 10.0, NORMAL),
    ];
    let pieces: Vec<Vec<u8>> = vocab
        .iter()
        .map(|&(text, score, kind)| piece(text, score, kind))
        .collect();

    let mut trainer = Vec::new();
    varint_field(&mut trainer, 3, 2);

    let serialized = model(pieces, trainer);
    SentencePieceProcessor::from_serialized_model(&serialized).unwrap()
}

/// Serializes a model from already-encoded parts.
///
/// Every entry of `pieces` is written as a length-delimited field 1, in
/// order. A non-empty `trainer` is appended as a length-delimited field 2;
/// an empty one is omitted entirely.
fn model(pieces: Vec<Vec<u8>>, trainer: Vec<u8>) -> Vec<u8> {
    let mut encoded = Vec::new();
    pieces
        .iter()
        .for_each(|entry| message_field(&mut encoded, 1, entry));

    if trainer.is_empty() {
        return encoded;
    }
    message_field(&mut encoded, 2, &trainer);
    encoded
}

/// Encodes one vocabulary entry as three consecutive fields:
/// field 1 holds `text` (length-delimited), field 2 holds the raw bits of
/// `score` (fixed 32-bit), and field 3 holds `kind` (varint).
fn piece(text: &str, score: f32, kind: i32) -> Vec<u8> {
    let score_bits = score.to_bits();
    // `as` sign-extends, so a negative kind becomes a ten-byte varint.
    let kind_value = kind as u64;

    let mut encoded = Vec::new();
    string_field(&mut encoded, 1, text);
    fixed32_field(&mut encoded, 2, score_bits);
    varint_field(&mut encoded, 3, kind_value);
    encoded
}

/// Appends a length-delimited field to `out`: the key for `field` with wire
/// type 2, then the payload length as a varint, then the payload bytes.
fn message_field(out: &mut Vec<u8>, field: u32, bytes: &[u8]) {
    let payload_len = bytes.len() as u64;

    key(out, field, 2);
    varint(out, payload_len);
    out.extend(bytes.iter().copied());
}

/// Appends `value` to `out` as a length-delimited field holding its UTF-8 bytes.
fn string_field(out: &mut Vec<u8>, field: u32, value: &str) {
    let utf8 = value.as_bytes();
    message_field(out, field, utf8);
}

/// Appends a fixed 32-bit field to `out`: the key for `field` with wire
/// type 5, followed by the four bytes of `value` in little-endian order.
fn fixed32_field(out: &mut Vec<u8>, field: u32, value: u32) {
    key(out, field, 5);
    let little_endian = value.to_le_bytes();
    out.extend(little_endian);
}

/// Appends a boolean to `out` as a varint field: 1 for `true`, 0 for `false`.
fn bool_field(out: &mut Vec<u8>, field: u32, value: bool) {
    let as_int: u64 = if value { 1 } else { 0 };
    varint_field(out, field, as_int);
}

/// Appends a varint field to `out`: the key for `field` with wire type 0,
/// followed by `value` in varint form.
fn varint_field(out: &mut Vec<u8>, field: u32, value: u64) {
    // Wire type written into the key ahead of a varint payload.
    const VARINT_WIRE: u8 = 0;

    key(out, field, VARINT_WIRE);
    varint(out, value);
}

/// Appends a field key to `out`: the field number shifted left by three bits,
/// OR-ed with the wire type, written as a varint.
fn key(out: &mut Vec<u8>, field: u32, wire: u8) {
    // Widen before shifting so large field numbers keep their top bits.
    let field_bits = u64::from(field) << 3;
    let tag = field_bits | u64::from(wire);
    varint(out, tag);
}

/// Appends `value` to `out` as a base-128 varint: seven bits per byte, least
/// significant group first, with the high bit set on every byte but the last.
fn varint(out: &mut Vec<u8>, mut value: u64) {
    loop {
        let low_bits = (value & 0x7F) as u8;
        value >>= 7;
        if value == 0 {
            // Final group: no continuation bit.
            out.push(low_bits);
            return;
        }
        out.push(low_bits | 0x80);
    }
}