use std::path::{Path, PathBuf};
use snafu::{ResultExt, Snafu};
/// Minimal prost view of a SentencePiece model file (`ModelProto`).
///
/// Only the vocabulary pieces (tag 1) are declared here; no other fields of the
/// encoded model are represented in this struct.
#[derive(prost::Message)]
struct SpModelProto {
/// Vocabulary entries, in the order they appear in the encoded message.
/// `load_vocab` preserves this order, so position in this vector is used as the piece index.
#[prost(message, repeated, tag = "1")]
pieces: Vec<SpPiece>,
}
/// One vocabulary entry of a SentencePiece model.
///
/// Only the two fields the loader reads are declared: the piece text (tag 1)
/// and its type (tag 3). Tag 2 is intentionally not declared.
#[derive(prost::Message)]
struct SpPiece {
/// Surface text of the piece; `None` when the field is absent from the encoding.
#[prost(string, optional, tag = "1")]
piece: Option<String>,
/// Raw piece-type enum value, decoded as a plain `int32`; `None` when absent.
/// `load_vocab` treats an absent value as `1`.
#[prost(int32, optional, tag = "3")]
r#type: Option<i32>,
}
/// Errors that can occur while loading a SentencePiece model from disk.
///
/// Both variants record the offending `path` alongside the underlying cause.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
    /// The model file could not be read from the filesystem.
    #[snafu(display("reading SentencePiece model from {}: {source}", path.display()))]
    Io {
        path: PathBuf,
        source: std::io::Error,
    },
    /// The file was read, but its bytes did not decode as a protobuf model.
    #[snafu(display("parsing SentencePiece model at {}: {source}", path.display()))]
    Decode {
        path: PathBuf,
        source: prost::DecodeError,
    },
}
pub type Result<T> = std::result::Result<T, Error>;
/// Loads the vocabulary of a SentencePiece model file.
///
/// Returns one string per piece, in the order the pieces appear in the model,
/// so the index of each entry matches its position in the model. Pieces whose
/// type is `NORMAL` (1) or `USER_DEFINED` (4) keep their text. Pieces of any
/// other type are represented by an empty string, so positions stay aligned.
///
/// # Errors
///
/// Returns [`Error::Io`] if `path` cannot be read, and [`Error::Decode`] if its
/// contents cannot be decoded as a model protobuf.
pub fn load_vocab(path: &Path) -> Result<Vec<String>> {
    use prost::Message;

    // `SentencePiece.Type` values from sentencepiece_model.proto.
    const NORMAL: i32 = 1;
    const USER_DEFINED: i32 = 4;

    // Context selectors convert `&Path` into `PathBuf` only when an error is
    // actually produced, so the success path does not copy the path.
    let bytes = std::fs::read(path).context(IoSnafu { path })?;
    let proto = SpModelProto::decode(bytes.as_slice()).context(DecodeSnafu { path })?;

    let pieces: Vec<String> = proto
        .pieces
        .into_iter()
        // The proto declares `type` with `[default = NORMAL]`, so an absent
        // field is treated as NORMAL.
        .map(|p| match p.r#type.unwrap_or(NORMAL) {
            NORMAL | USER_DEFINED => p.piece.unwrap_or_default(),
            _ => String::new(),
        })
        .collect();
    Ok(pieces)
}