use std::collections::{HashMap, HashSet};
use crate::byte_encoder::{decode_byte_level_token, METASPACE};
use crate::map::TokenizerMap;
/// Per-call switches controlling how [`Detokenizer::render`] produces text.
#[derive(Debug, Clone, Copy, Default)]
pub struct DetokenizeOptions {
    /// When `true`, a trailing incomplete UTF-8 sequence stays buffered for the next call
    /// instead of being flushed lossily at the end of this one.
    pub partial: bool,
    /// When `true`, special tokens are emitted; when `false`, they are skipped.
    pub render_special: bool,
}
/// Stateful converter from token ids back to text.
///
/// Holds the lookup tables built from a tokenizer map, plus a buffer of raw bytes that have not
/// yet formed a complete UTF-8 character. This lets output be produced incrementally across
/// several `render` calls.
#[derive(Debug)]
pub struct Detokenizer {
    /// Ids of special tokens; `render` skips them unless special rendering is requested.
    special_ids: HashSet<u32>,
    /// First id of the inclusive byte-fallback range; id `i` in range encodes byte
    /// `i - fallback_start`.
    fallback_start: i64,
    /// Last id (inclusive) of the byte-fallback range.
    fallback_end: i64,
    /// Raw bytes per id; populated only when the map uses the `"byte_level"` encoder.
    id_to_bytes: Option<HashMap<u32, Vec<u8>>>,
    /// Display text per id; populated for every other encoder.
    id_to_text: Option<HashMap<u32, String>>,
    /// Bytes still waiting to be completed into UTF-8 characters.
    byte_buffer: Vec<u8>,
}
impl Detokenizer {
    /// Creates a detokenizer for `map`.
    ///
    /// If `map.encoder` is `"byte_level"`, every vocabulary id is pre-decoded to raw bytes.
    /// Otherwise ids are mapped to display text. When the map declares no byte-fallback
    /// range, the empty range `-1..=-2` is used, so no `u32` id is treated as a fallback byte.
    pub fn new(map: &TokenizerMap) -> Self {
        let special_ids: HashSet<u32> = map
            .special_tokens
            .as_ref()
            .map(|s| s.values().copied().collect())
            .unwrap_or_default();
        // `-1..=-2` is empty, so with no declared range no id falls inside it.
        let fallback_start = map.byte_fallback_start.unwrap_or(-1);
        let fallback_end = map.byte_fallback_end.unwrap_or(-2);
        let (id_to_bytes, id_to_text) = if map.encoder.as_deref() == Some("byte_level") {
            (Some(build_byte_level_table(map)), None)
        } else {
            (None, Some(build_text_table(map)))
        };
        Self {
            special_ids,
            fallback_start,
            fallback_end,
            id_to_bytes,
            id_to_text,
            byte_buffer: Vec::new(),
        }
    }

    /// Converts `ids` to text, continuing from any bytes buffered by earlier calls.
    ///
    /// * Byte-fallback ids and byte-level tokens contribute raw bytes. These are emitted as
    ///   soon as they form complete UTF-8 characters; malformed bytes become U+FFFD.
    /// * Special ids are dropped unless `options.render_special` is set. Dropping one still
    ///   flushes any pending bytes.
    /// * Ids missing from the lookup table render as U+FFFD.
    /// * With `options.partial`, a trailing incomplete UTF-8 sequence stays buffered for the
    ///   next call. Otherwise it is flushed lossily before returning.
    pub fn render(&mut self, ids: &[u32], options: DetokenizeOptions) -> String {
        let mut out = String::new();
        for &id in ids {
            if let Some(byte) = self.fallback_byte(id) {
                self.byte_buffer.push(byte);
                self.flush_all_bytes(&mut out);
                continue;
            }
            if self.special_ids.contains(&id) && !options.render_special {
                // A skipped special token still terminates any pending byte sequence.
                self.flush_bytes_force(&mut out);
                continue;
            }
            if let Some(table) = &self.id_to_bytes {
                match table.get(&id) {
                    Some(bytes) => {
                        self.byte_buffer.extend_from_slice(bytes);
                        self.flush_all_bytes(&mut out);
                    }
                    None => {
                        self.flush_bytes_force(&mut out);
                        out.push('\u{FFFD}');
                    }
                }
                continue;
            }
            // Text tokens cannot complete a pending byte sequence, so flush it first.
            self.flush_bytes_force(&mut out);
            match self.id_to_text.as_ref().and_then(|m| m.get(&id)) {
                Some(text) => out.push_str(text),
                None => out.push('\u{FFFD}'),
            }
        }
        if !options.partial {
            self.flush_bytes_force(&mut out);
        }
        out
    }

    /// Discards any buffered bytes, so the next `render` starts from a clean state.
    pub fn reset(&mut self) {
        self.byte_buffer.clear();
    }

    /// One-shot helper: builds a fresh detokenizer for `map` and renders `ids` with
    /// `partial` disabled, so no bytes are left buffered.
    pub fn detokenize(map: &TokenizerMap, ids: &[u32], render_special: bool) -> String {
        let mut d = Self::new(map);
        d.render(ids, DetokenizeOptions { partial: false, render_special })
    }

    /// Returns the byte encoded by `id` when it falls inside the inclusive byte-fallback range.
    fn fallback_byte(&self, id: u32) -> Option<u8> {
        let id = i64::from(id);
        // NOTE(review): assumes the range spans at most 256 ids; larger offsets are truncated
        // by the `as u8` cast.
        (self.fallback_start..=self.fallback_end)
            .contains(&id)
            .then(|| (id - self.fallback_start) as u8)
    }

    /// Emits every complete UTF-8 character at the front of the byte buffer.
    ///
    /// Each maximal invalid subsequence becomes one U+FFFD, matching
    /// `String::from_utf8_lossy`, and decoding resumes right after it, so valid bytes that
    /// follow a bad lead byte are never lost. Only a trailing sequence that is a valid
    /// *prefix* of a character is kept buffered, waiting for more bytes.
    fn flush_all_bytes(&mut self, out: &mut String) {
        let mut consumed = 0;
        while consumed < self.byte_buffer.len() {
            let pending = &self.byte_buffer[consumed..];
            match std::str::from_utf8(pending) {
                Ok(text) => {
                    out.push_str(text);
                    consumed += pending.len();
                }
                Err(err) => {
                    let valid = err.valid_up_to();
                    // The prefix up to `valid_up_to` is well-formed, so this borrows
                    // without allocating or substituting anything.
                    out.push_str(&String::from_utf8_lossy(&pending[..valid]));
                    match err.error_len() {
                        Some(invalid) => {
                            out.push('\u{FFFD}');
                            consumed += valid + invalid;
                        }
                        None => {
                            // Incomplete but possibly valid tail: wait for more bytes.
                            consumed += valid;
                            break;
                        }
                    }
                }
            }
        }
        self.byte_buffer.drain(..consumed);
    }

    /// Emits the entire byte buffer, replacing malformed or incomplete sequences with U+FFFD,
    /// and leaves the buffer empty.
    fn flush_bytes_force(&mut self, out: &mut String) {
        if self.byte_buffer.is_empty() {
            return;
        }
        let bytes = std::mem::take(&mut self.byte_buffer);
        out.push_str(&String::from_utf8_lossy(&bytes));
    }
}
/// Returns the total length of the UTF-8 sequence that starts with lead byte `b`.
///
/// Returns 0 for bytes that can never begin a valid sequence: continuation bytes
/// (`0x80..=0xBF`), the overlong-only leads `0xC0`/`0xC1`, and `0xF5..=0xFF`, which would
/// encode code points above U+10FFFF (RFC 3629). Rejecting them up front stops a caller from
/// waiting for, and then discarding, bytes that follow an impossible lead.
fn utf8_sequence_length(b: u8) -> usize {
    match b {
        0x00..=0x7F => 1,
        0xC2..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xF4 => 4,
        _ => 0,
    }
}
/// Maps every vocabulary id to the raw bytes its byte-level token string decodes to.
///
/// Returns an empty table when the map has no vocabulary. If several tokens share an id,
/// the one visited last wins.
fn build_byte_level_table(map: &TokenizerMap) -> HashMap<u32, Vec<u8>> {
    map.vocab
        .iter()
        .flatten()
        .map(|(token, &id)| (id, decode_byte_level_token(token)))
        .collect()
}
/// Maps ids to display text for non-byte-level encoders.
///
/// Vocabulary entries are included, except byte-fallback tokens such as `<0x0A>`. For the
/// `"metaspace"` encoder, `METASPACE` characters are replaced with spaces. Entries in
/// `map.tokens` (keyed by the decimal id string) are inserted afterwards verbatim, so they
/// override vocabulary text; keys that do not parse as `u32` are ignored.
fn build_text_table(map: &TokenizerMap) -> HashMap<u32, String> {
    let metaspace = matches!(map.encoder.as_deref(), Some("metaspace"));
    let mut table: HashMap<u32, String> = map
        .vocab
        .iter()
        .flatten()
        .filter(|(token, _)| !is_byte_fallback_token(token))
        .map(|(token, &id)| {
            let display = if metaspace {
                token.replace(METASPACE, " ")
            } else {
                token.clone()
            };
            (id, display)
        })
        .collect();
    for (raw_id, display) in map.tokens.iter().flatten() {
        let Ok(id) = raw_id.parse::<u32>() else {
            continue;
        };
        table.insert(id, display.clone());
    }
    table
}
/// Reports whether `s` is exactly a byte-fallback token of the form `<0xHH>`, where `HH` is
/// two ASCII hex digits of either case (the `x` must be lowercase).
fn is_byte_fallback_token(s: &str) -> bool {
    match s.as_bytes() {
        [b'<', b'0', b'x', hi, lo, b'>'] => is_hex_byte(*hi) && is_hex_byte(*lo),
        _ => false,
    }
}
/// Reports whether `b` is an ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`).
fn is_hex_byte(b: u8) -> bool {
    matches!(b, b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F')
}