use crate::tokenization::hf_tokenizers::tokenizer::{Decoder, Result};
use serde::{Deserialize, Serialize};
/// Decoder configuration for turning a sequence of WordPiece tokens back into text.
///
/// The serde representation is internally tagged: the tag is stored under the
/// `"type"` key alongside the fields below.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub struct WordPiece {
    /// Continuation marker. When decoding, every occurrence of a single space
    /// followed by this prefix is deleted from the space-joined tokens.
    pub prefix: String,
    /// When `true`, decoding also applies a fixed list of spacing replacements
    /// around punctuation and English contractions.
    pub cleanup: bool,
}
impl WordPiece {
pub fn new(prefix: String, cleanup: bool) -> Self {
Self { prefix, cleanup }
}
}
impl Default for WordPiece {
    /// Returns a decoder that uses `"##"` as the continuation prefix and has
    /// cleanup enabled.
    fn default() -> Self {
        Self::new("##".to_owned(), true)
    }
}
impl Decoder for WordPiece {
    /// Joins `tokens` into a single string.
    ///
    /// Steps, in order:
    /// 1. The tokens are joined with a single space.
    /// 2. Every occurrence of `" "` + `self.prefix` is removed, which glues
    ///    continuation pieces onto the preceding token. A prefix on the very
    ///    first token is not preceded by a space and is therefore left in place.
    /// 3. If `self.cleanup` is set, a fixed list of literal replacements is
    ///    applied one after another over the whole string.
    ///
    /// This implementation always returns `Ok`.
    fn decode(&self, tokens: Vec<String>) -> Result<String> {
        // Applied sequentially and in this exact order; each rule sees the
        // output of the previous one.
        // NOTE(review): the `" do not"` rule rewrites words rather than spacing;
        // it is kept because it is part of the existing behaviour.
        const CLEANUP_RULES: [(&str, &str); 11] = [
            (" .", "."),
            (" ?", "?"),
            (" !", "!"),
            (" ,", ","),
            (" ' ", "'"),
            (" n't", "n't"),
            (" 'm", "'m"),
            (" do not", " don't"),
            (" 's", "'s"),
            (" 've", "'ve"),
            (" 're", "'re"),
        ];

        let continuation = format!(" {}", self.prefix);
        let joined = tokens.join(" ");
        let merged = joined.replace(continuation.as_str(), "");

        let text = if self.cleanup {
            CLEANUP_RULES
                .iter()
                .fold(merged, |acc, &(from, to)| acc.replace(from, to))
        } else {
            merged
        };

        Ok(text)
    }
}