use std::fs;
use std::path::Path;
use crate::model::{SentencePieceModel, Token, merge_or_byte_fallback_tokens};
use crate::normalizer::Normalizer;
use crate::util::{SPACE_SYMBOL, piece_to_byte, replace_space_symbol};
use crate::{Error, Result};
/// Tokenizer front end that owns a [`SentencePieceModel`] together with the
/// normalizers and extra-option lists used around encoding and decoding.
#[derive(Clone, Debug)]
pub struct SentencePieceProcessor {
/// The model used for encoding, decoding and piece/id lookups.
model: SentencePieceModel,
/// Normalizer applied to input text before it is handed to the model.
normalizer: Normalizer,
/// Optional normalizer run over decoded text; `None` when the model has no
/// denormalizer spec or that spec's `precompiled_charsmap` is empty.
denormalizer: Option<Normalizer>,
/// Options applied to the token list at the end of encoding.
encode_extra_options: Vec<ExtraOption>,
/// Options applied to the token list before it is decoded.
decode_extra_options: Vec<ExtraOption>,
}
/// Post-processing step that can be applied to a token list.
///
/// Values are produced by `parse_extra_options` from a `:`-separated string
/// and consumed, in order, by `apply_extra_options`.
#[derive(Clone, Copy, Debug)]
enum ExtraOption {
/// Parsed from `"bos"`: insert the model's BOS piece at the front.
Bos,
/// Parsed from `"eos"`: append the model's EOS piece at the end.
Eos,
/// Parsed from `"reverse"`: reverse the order of the tokens.
Reverse,
/// Parsed from `"unk"` or `"unk_piece"`: overwrite the piece text of every
/// token whose id is unknown with the model's unknown piece.
UnkPiece,
}
impl SentencePieceProcessor {
/// Reads the file at `path` and builds a processor from its serialized contents.
///
/// # Errors
/// Fails if the file cannot be read, or if [`Self::from_serialized_model`]
/// rejects the bytes.
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
    let serialized = fs::read(path)?;
    Self::from_serialized_model(serialized.as_slice())
}
/// Parses `bytes` as a serialized model and builds a processor around it.
///
/// # Errors
/// Fails if `SentencePieceModel::from_slice` cannot parse the bytes, or if
/// [`Self::from_model`] rejects the resulting model.
pub fn from_serialized_model(bytes: &[u8]) -> Result<Self> {
    let model = SentencePieceModel::from_slice(bytes)?;
    Self::from_model(model)
}
/// Builds a processor around an already-parsed model.
///
/// The normalizer is created from the model's normalizer and trainer specs
/// and is given the model's user symbols. A denormalizer is created only when
/// the model carries a denormalizer spec with a non-empty
/// `precompiled_charsmap`. Both extra-option lists start out empty.
///
/// # Errors
/// Fails if the normalizer cannot be constructed, if the user symbols contain
/// duplicates, or if the denormalizer cannot be constructed.
pub fn from_model(model: SentencePieceModel) -> Result<Self> {
    let mut normalizer = Normalizer::new(model.normalizer_spec(), model.trainer_spec())?;

    // Validate the symbols before the normalizer takes ownership of them.
    let symbols = model.user_symbols();
    crate::model::validate_no_duplicate_user_symbols(&symbols)?;
    normalizer.set_user_symbols(symbols);

    // A spec with an empty charsmap is treated the same as no spec at all.
    let denormalizer = match model.denormalizer_spec() {
        Some(spec) if !spec.precompiled_charsmap.is_empty() => {
            Some(Normalizer::new_denormalizer(spec)?)
        }
        _ => None,
    };

    Ok(Self {
        model,
        normalizer,
        denormalizer,
        encode_extra_options: Vec::new(),
        decode_extra_options: Vec::new(),
    })
}
/// Returns a shared reference to the model owned by this processor.
pub fn model(&self) -> &SentencePieceModel {
&self.model
}
/// Runs `input` through this processor's normalizer and returns the result.
///
/// # Errors
/// Returns whatever error the normalizer reports.
pub fn normalize(&self, input: &str) -> Result<String> {
self.normalizer.normalize(input)
}
/// Encodes `input` and returns the piece string of every resulting token.
///
/// The encode extra options are already applied to the returned list.
///
/// # Errors
/// Propagates any error from normalization or from the model's encoder.
pub fn encode(&self, input: &str) -> Result<Vec<String>> {
    let tokens = self.encode_tokens(input)?;
    let mut pieces = Vec::with_capacity(tokens.len());
    for token in tokens {
        pieces.push(token.piece);
    }
    Ok(pieces)
}
/// Encodes `input` and returns the id of every resulting token.
///
/// The encode extra options are already applied to the returned list.
///
/// # Errors
/// Propagates any error from normalization or from the model's encoder.
pub fn encode_to_ids(&self, input: &str) -> Result<Vec<usize>> {
    let tokens = self.encode_tokens(input)?;
    let mut ids = Vec::with_capacity(tokens.len());
    for token in tokens {
        ids.push(token.id);
    }
    Ok(ids)
}
/// Decodes a sequence of pieces into text.
///
/// This forwards directly to [`Self::decode_pieces`].
///
/// # Errors
/// Returns whatever error [`Self::decode_pieces`] returns.
pub fn decode<T: AsRef<str>>(&self, pieces: &[T]) -> Result<String> {
self.decode_pieces(pieces)
}
/// Decodes a sequence of piece strings into text.
///
/// Each piece is paired with the id the model reports for it, the decode
/// extra options are applied, and the resulting tokens are decoded.
///
/// # Errors
/// Propagates any error raised while decoding the tokens.
pub fn decode_pieces<T: AsRef<str>>(&self, pieces: &[T]) -> Result<String> {
    let mut tokens = Vec::with_capacity(pieces.len());
    for piece in pieces {
        let text: &str = piece.as_ref();
        tokens.push(Token {
            piece: text.to_owned(),
            id: self.model.piece_to_id(text),
        });
    }
    self.apply_extra_options(&self.decode_extra_options, &mut tokens);
    self.decode_tokens(&tokens)
}
/// Decodes a sequence of token ids into text.
///
/// Each id is resolved to its piece, the decode extra options are applied,
/// and the resulting tokens are decoded.
///
/// # Errors
/// Fails on the first id the model cannot map to a piece, and propagates any
/// error raised while decoding the tokens.
pub fn decode_ids(&self, ids: &[usize]) -> Result<String> {
    let mut tokens = ids
        .iter()
        .map(|&id| {
            self.model.id_to_piece(id).map(|piece| Token {
                piece: piece.to_owned(),
                id,
            })
        })
        .collect::<Result<Vec<_>>>()?;
    self.apply_extra_options(&self.decode_extra_options, &mut tokens);
    self.decode_tokens(&tokens)
}
/// Replaces the options applied after encoding.
///
/// `options` is a `:`-separated list of `bos`, `eos`, `reverse`, `unk` or
/// `unk_piece`; an empty string clears the list.
///
/// # Errors
/// Fails if an option name is not recognised, or if `bos`/`eos` is requested
/// but the model has no id for that piece. The stored options are left
/// untouched on failure.
pub fn set_encode_extra_options(&mut self, options: &str) -> Result<()> {
    let parsed = self.parse_extra_options(options)?;
    self.encode_extra_options = parsed;
    Ok(())
}
/// Replaces the options applied before decoding.
///
/// `options` is a `:`-separated list of `bos`, `eos`, `reverse`, `unk` or
/// `unk_piece`; an empty string clears the list.
///
/// # Errors
/// Fails if an option name is not recognised, or if `bos`/`eos` is requested
/// but the model has no id for that piece. The stored options are left
/// untouched on failure.
pub fn set_decode_extra_options(&mut self, options: &str) -> Result<()> {
    let parsed = self.parse_extra_options(options)?;
    self.decode_extra_options = parsed;
    Ok(())
}
/// Returns the id the model reports for its unknown token.
pub fn unk_id(&self) -> usize {
self.model.unk_id()
}
/// Returns the BOS id as reported by the model.
///
/// NOTE(review): `None` presumably means the model defines no BOS token —
/// confirm against `SentencePieceModel::bos_id`.
pub fn bos_id(&self) -> Option<usize> {
self.model.bos_id()
}
/// Returns the EOS id as reported by the model.
///
/// NOTE(review): `None` presumably means the model defines no EOS token —
/// confirm against `SentencePieceModel::eos_id`.
pub fn eos_id(&self) -> Option<usize> {
self.model.eos_id()
}
/// Returns the padding id as reported by the model.
///
/// NOTE(review): `None` presumably means the model defines no padding token —
/// confirm against `SentencePieceModel::pad_id`.
pub fn pad_id(&self) -> Option<usize> {
self.model.pad_id()
}
/// Returns the id the model assigns to `piece`.
///
/// NOTE(review): the lookup is infallible; pieces missing from the vocabulary
/// appear to map to an id for which the model's `is_unknown` holds (that is
/// how `ensure_special_defined` uses it) — confirm in the model.
pub fn piece_to_id(&self, piece: &str) -> usize {
self.model.piece_to_id(piece)
}
/// Returns the piece string the model stores for `id`.
///
/// # Errors
/// Propagates the model's error (presumably raised for ids outside the
/// vocabulary — confirm against `SentencePieceModel::id_to_piece`).
pub fn id_to_piece(&self, id: usize) -> Result<&str> {
self.model.id_to_piece(id)
}
/// Shared encoding pipeline behind `encode` and `encode_to_ids`.
///
/// Steps, in order: normalize the input, let the model encode the normalized
/// text, pass the result through `merge_or_byte_fallback_tokens`, then apply
/// the encode extra options.
///
/// # Errors
/// Propagates the first error reported by any of those steps.
fn encode_tokens(&self, input: &str) -> Result<Vec<Token>> {
    let normalized_text = self.normalizer.normalize(input)?;
    let raw_tokens = self.model.encode_normalized(&normalized_text)?;

    let mut encoded = merge_or_byte_fallback_tokens(&self.model, raw_tokens)?;
    self.apply_extra_options(&self.encode_extra_options, &mut encoded);

    Ok(encoded)
}
/// Turns a token sequence into text.
///
/// Byte tokens are collected into a buffer and flushed as (lossily decoded)
/// UTF-8 as soon as a non-byte token, or the end of the input, is reached.
/// Every other token is rendered by `decode_sentence_piece`. If a
/// denormalizer is configured, the assembled text is passed through it.
///
/// # Errors
/// Fails if a byte token's piece cannot be converted to a byte, if rendering
/// a piece fails, or if the denormalizer fails.
fn decode_tokens(&self, tokens: &[Token]) -> Result<String> {
    let mut output = String::new();
    let mut pending_bytes: Vec<u8> = Vec::new();
    // Whether a leading space symbol may still be stripped from the next piece.
    let mut expect_bos_ws = true;
    // Whether the previously rendered piece reported a stripped leading space.
    let mut saw_bos_ws = false;

    for token in tokens {
        if self.model.is_byte(token.id) {
            match piece_to_byte(&token.piece) {
                Some(byte) => {
                    pending_bytes.push(byte);
                    continue;
                }
                None => {
                    return Err(Error::model_parse(format!(
                        "invalid byte piece {}",
                        token.piece
                    )));
                }
            }
        }

        flush_byte_buffer(&mut pending_bytes, &mut output);

        // Once text has been produced, or a leading space was already
        // consumed, no further piece counts as the start of the sentence.
        expect_bos_ws = expect_bos_ws && !saw_bos_ws && output.is_empty();

        let (rendered, consumed) = self.decode_sentence_piece(token, expect_bos_ws)?;
        saw_bos_ws = consumed;
        output.push_str(&rendered);
    }
    flush_byte_buffer(&mut pending_bytes, &mut output);

    match &self.denormalizer {
        Some(denormalizer) => denormalizer.normalize(&output),
        None => Ok(output),
    }
}
/// Renders a single non-byte token as text.
///
/// Returns the rendered text and a flag telling the caller whether a leading
/// space symbol was consumed at the start of the sentence.
///
/// * Control tokens render as the empty string.
/// * Unknown tokens render as the model's unknown surface when the piece is
///   the model's own piece for that id, otherwise as the piece itself.
/// * Any other piece has its space symbols replaced. When `is_bos_ws` is set
///   and the normalizer adds a dummy prefix or removes extra whitespace, one
///   leading space symbol is stripped first. The flag is reported as `true`
///   only if a symbol was stripped and the normalizer does not remove extra
///   whitespace.
///
/// # Errors
/// Propagates the model's error when the unknown token's id cannot be mapped
/// back to a piece.
fn decode_sentence_piece(&self, token: &Token, is_bos_ws: bool) -> Result<(String, bool)> {
    if self.model.is_control(token.id) {
        return Ok((String::new(), false));
    }

    if self.model.is_unknown(token.id) {
        let surface = if self.model.id_to_piece(token.id)? == token.piece {
            self.model.unk_surface().to_owned()
        } else {
            token.piece.clone()
        };
        return Ok((surface, false));
    }

    let may_strip = is_bos_ws
        && (self.normalizer.add_dummy_prefix() || self.normalizer.remove_extra_whitespaces());
    let stripped = if may_strip {
        token.piece.strip_prefix(SPACE_SYMBOL)
    } else {
        None
    };

    Ok(match stripped {
        // With extra-whitespace removal enabled the flag stays false, so the
        // caller keeps stripping leading space symbols from later pieces.
        Some(rest) => (
            replace_space_symbol(rest),
            !self.normalizer.remove_extra_whitespaces(),
        ),
        None => (replace_space_symbol(token.piece.as_str()), false),
    })
}
fn parse_extra_options(&self, options: &str) -> Result<Vec<ExtraOption>> {
if options.is_empty() {
return Ok(Vec::new());
}
options
.split(':')
.map(|option| match option {
"bos" => {
self.ensure_special_defined(self.model.bos_piece())?;
Ok(ExtraOption::Bos)
}
"eos" => {
self.ensure_special_defined(self.model.eos_piece())?;
Ok(ExtraOption::Eos)
}
"reverse" => Ok(ExtraOption::Reverse),
"unk" | "unk_piece" => Ok(ExtraOption::UnkPiece),
other => Err(Error::invalid_input(format!(
"extra option {other:?} is not available"
))),
})
.collect()
}
/// Checks that `piece` resolves to an id the model does not consider unknown.
///
/// # Errors
/// Returns an invalid-input error naming `piece` when its id is unknown.
fn ensure_special_defined(&self, piece: &str) -> Result<()> {
    if !self.model.is_unknown(self.model.piece_to_id(piece)) {
        return Ok(());
    }
    Err(Error::invalid_input(format!(
        "id for special piece {piece:?} is not defined"
    )))
}
/// Applies `options` to `tokens` in place, one after another in list order.
///
/// * `Reverse` reverses the token order.
/// * `Eos` appends the model's EOS piece with the id the model reports for it.
/// * `Bos` inserts the model's BOS piece at the front in the same way.
/// * `UnkPiece` overwrites the piece text of every token whose id is unknown
///   with the model's unknown piece.
fn apply_extra_options(&self, options: &[ExtraOption], tokens: &mut Vec<Token>) {
    for option in options.iter().copied() {
        match option {
            ExtraOption::Reverse => tokens.reverse(),
            ExtraOption::Eos => {
                let eos = self.model.eos_piece();
                tokens.push(Token {
                    piece: eos.to_owned(),
                    id: self.model.piece_to_id(eos),
                });
            }
            ExtraOption::Bos => {
                let bos = self.model.bos_piece();
                let bos_token = Token {
                    piece: bos.to_owned(),
                    id: self.model.piece_to_id(bos),
                };
                tokens.insert(0, bos_token);
            }
            ExtraOption::UnkPiece => {
                tokens
                    .iter_mut()
                    .filter(|token| self.model.is_unknown(token.id))
                    .for_each(|token| token.piece = self.model.unk_piece().to_owned());
            }
        }
    }
}
}
/// Appends the buffered bytes to `text` and empties the buffer.
///
/// The bytes are interpreted as UTF-8; invalid sequences are replaced with
/// U+FFFD by `String::from_utf8_lossy`. An empty buffer leaves `text`
/// unchanged.
fn flush_byte_buffer(byte_buffer: &mut Vec<u8>, text: &mut String) {
    if !byte_buffer.is_empty() {
        // `from_utf8_lossy` borrows when the input is already valid UTF-8, so
        // the valid and invalid cases need no separate handling.
        text.push_str(&String::from_utf8_lossy(byte_buffer));
        byte_buffer.clear();
    }
}