mod vocab;
pub use vocab::VocabDecoder;
use crate::encoder::EncoderType;
use crate::postprocessor::PostProcessor;
use crate::types::TokenId;
/// Strategy used by `Decoder::decode` to turn token ids back into text.
///
/// The explicit `#[repr(u32)]` discriminants are the same values that
/// `DecoderType::from_u32` maps back to each variant.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DecoderType {
    /// Byte-level decoding; this is the `Default` variant.
    #[default]
    ByteLevel = 0,
    /// WordPiece decoding, where `##`-prefixed pieces continue the previous word.
    WordPiece = 1,
    /// Metaspace decoding, where `\u{2581}` marks a word boundary.
    Metaspace = 2,
}
impl DecoderType {
    /// Picks the decoding strategy that matches an encoder.
    ///
    /// `EncoderType::WordPiece` maps to [`DecoderType::WordPiece`];
    /// `EncoderType::SentencePiece` and `EncoderType::Unigram` map to
    /// [`DecoderType::Metaspace`]; every other encoder falls back to
    /// [`DecoderType::ByteLevel`].
    pub fn from_encoder_type(encoder_type: EncoderType) -> Self {
        if matches!(encoder_type, EncoderType::WordPiece) {
            Self::WordPiece
        } else if matches!(
            encoder_type,
            EncoderType::SentencePiece | EncoderType::Unigram
        ) {
            Self::Metaspace
        } else {
            Self::ByteLevel
        }
    }

    /// Converts a raw `u32` discriminant back into a `DecoderType`.
    ///
    /// Returns `None` for any value that is not one of the enum's
    /// discriminants (currently 0, 1 and 2).
    pub fn from_u32(v: u32) -> Option<Self> {
        [Self::ByteLevel, Self::WordPiece, Self::Metaspace]
            .iter()
            .copied()
            .find(|&candidate| candidate as u32 == v)
    }
}
/// Turns token ids back into bytes or text.
///
/// Pairs a per-token byte table with the strategy that `Decoder::decode`
/// dispatches on.
#[derive(Clone)]
pub struct Decoder {
    /// Strategy `decode` dispatches on.
    decoder_type: DecoderType,
    /// Per-token byte storage that all decoding paths read from.
    vocab: VocabDecoder,
}
impl Decoder {
pub fn new(token_bytes: Vec<Vec<u8>>) -> Self {
Self {
vocab: VocabDecoder::new(token_bytes),
decoder_type: DecoderType::ByteLevel,
}
}
pub fn for_encoder(token_bytes: Vec<Vec<u8>>, encoder_type: EncoderType) -> Self {
Self {
vocab: VocabDecoder::new(token_bytes),
decoder_type: DecoderType::from_encoder_type(encoder_type),
}
}
pub fn with_type(vocab: VocabDecoder, decoder_type: DecoderType) -> Self {
Self { vocab, decoder_type }
}
pub fn decoder_type(&self) -> DecoderType {
self.decoder_type
}
pub fn vocab(&self) -> &VocabDecoder {
&self.vocab
}
pub fn into_vocab(self) -> VocabDecoder {
self.vocab
}
#[inline]
pub fn vocab_size(&self) -> usize {
self.vocab.vocab_size()
}
#[inline]
pub fn token_to_bytes(&self, token: TokenId) -> &[u8] {
self.vocab.token_to_bytes(token)
}
#[inline]
pub fn token_len(&self, token: TokenId) -> usize {
self.vocab.token_len(token)
}
pub fn decode_bytes(&self, tokens: &[TokenId]) -> Vec<u8> {
self.vocab.decode(tokens)
}
pub fn decode(&self, tokens: &[TokenId], post_processor: &PostProcessor) -> Option<String> {
match self.decoder_type {
DecoderType::ByteLevel => self.vocab.decode_to_string(tokens),
DecoderType::WordPiece => decode_wordpiece(tokens, &self.vocab, post_processor),
DecoderType::Metaspace => decode_metaspace(tokens, &self.vocab, post_processor),
}
}
pub fn decode_to_string(&self, tokens: &[TokenId]) -> Option<String> {
self.vocab.decode_to_string(tokens)
}
pub fn from_parts(data: Vec<u8>, offsets: Vec<u32>, decoder_type: DecoderType) -> Self {
Self {
vocab: VocabDecoder::from_parts(data, offsets),
decoder_type,
}
}
pub fn as_parts(&self) -> (&[u8], &[u32]) {
self.vocab.as_parts()
}
pub fn token_bytes(&self) -> Vec<Vec<u8>> {
self.vocab.token_bytes()
}
}
/// Decodes `tokens` using the WordPiece convention.
///
/// Ids for which `post_processor.is_special_token` is true are skipped. A piece
/// starting with `##` is appended to the current word with that prefix removed.
/// Any other piece starts a new word and is preceded by a single space, unless
/// nothing has been emitted yet.
///
/// Returns `None` as soon as a kept token's bytes are not valid UTF-8.
fn decode_wordpiece(
    tokens: &[TokenId],
    vocab: &VocabDecoder,
    post_processor: &PostProcessor,
) -> Option<String> {
    tokens
        .iter()
        .copied()
        .filter(|&id| !post_processor.is_special_token(id))
        .try_fold(String::new(), |mut text, id| {
            let piece = std::str::from_utf8(vocab.token_to_bytes(id)).ok()?;
            if let Some(suffix) = piece.strip_prefix("##") {
                text.push_str(suffix);
            } else {
                // An empty buffer doubles as "first word": no separator before it.
                if !text.is_empty() {
                    text.push(' ');
                }
                text.push_str(piece);
            }
            Some(text)
        })
}
/// Decodes `tokens` using the Metaspace (SentencePiece-style) convention.
///
/// Ids for which `post_processor.is_special_token` is true are skipped. The bytes
/// of the remaining tokens are concatenated and validated as UTF-8 once, over the
/// whole sequence. A character whose bytes are split across several tokens
/// (e.g. byte-fallback pieces, if the vocabulary stores them as raw bytes) therefore
/// still decodes. Every `\u{2581}` is then turned into an ASCII space, and one
/// leading space, if present, is dropped.
///
/// Returns `None` if the concatenated bytes are not valid UTF-8.
fn decode_metaspace(
    tokens: &[TokenId],
    vocab: &VocabDecoder,
    post_processor: &PostProcessor,
) -> Option<String> {
    let mut bytes = Vec::new();
    for &id in tokens {
        if post_processor.is_special_token(id) {
            continue;
        }
        bytes.extend_from_slice(vocab.token_to_bytes(id));
    }
    // Validate the joined sequence rather than each piece: per-piece validation
    // rejects multi-byte characters that straddle token boundaries. The marker is
    // replaced once over the whole string instead of allocating per token.
    let mut text = String::from_utf8(bytes).ok()?.replace('\u{2581}', " ");
    // Drop one leading space, typically produced by the marker on the first word.
    if text.starts_with(' ') {
        text.remove(0);
    }
    Some(text)
}