/// Incremental detokenizer: token ids are pushed one at a time and the decoded
/// text can be read back in full or as "everything since the last read".
pub trait StreamingDetokenizer {
    /// Resets the detokenizer; the implementations in this file clear the
    /// recorded tokens, the accumulated text and the offset.
    fn reset(&mut self);

    /// Feeds the next token id into the stream.
    fn add_token(&mut self, token: u32);

    /// Marks the end of the stream; the implementations in this file use it to
    /// move whatever is still buffered into the text.
    fn finalize(&mut self);

    /// The text decoded so far (borrowed when possible, owned otherwise).
    fn text(&self) -> std::borrow::Cow<'_, str>;

    /// The token ids recorded by the implementor.
    fn tokens(&self) -> &[u32];

    /// Byte position in [`text`](Self::text) that
    /// [`last_segment`](Self::last_segment) resumes from.
    fn offset(&self) -> usize;

    /// Overwrites the position reported by [`offset`](Self::offset).
    fn set_offset(&mut self, offset: usize);

    /// Returns the part of [`text`](Self::text) starting at the stored offset
    /// and then moves the offset to the end of the text.
    ///
    /// An offset past the end is clamped, and an offset that falls inside a
    /// multi-byte character is advanced to the next character boundary, so the
    /// slice can never panic.
    fn last_segment(&mut self) -> String {
        let (segment, consumed) = {
            let full = self.text();
            let view: &str = &full;
            let total = view.len();
            let clamped = self.offset().min(total);
            // `total` itself is always a boundary, so the search always succeeds;
            // the fallback only exists to avoid an `unwrap`.
            let begin = (clamped..=total)
                .find(|&idx| view.is_char_boundary(idx))
                .unwrap_or(total);
            (view[begin..].to_owned(), total)
        };
        self.set_offset(consumed);
        segment
    }
}
/// Streaming detokenizer that re-decodes its pending tokens with a
/// caller-supplied closure every time a token is added.
pub struct NaiveStreamingDetokenizer<F> {
/// Decoder callback; the impl blocks in this file require `Fn(&[u32]) -> String`.
decode: F,
/// When set, one trailing space is dropped from the in-progress text.
clean_up_spaces: bool,
/// Every token id passed to `add_token` since the last reset.
tokens: Vec<u32>,
/// Value stored by `set_offset` and reported by `offset`.
offset: usize,
/// Text that has been committed and is no longer re-decoded.
text: String,
/// Token ids whose decoding has not been committed to `text` yet.
current_tokens: Vec<u32>,
/// Latest decoding of `current_tokens`, after the trailing clean-up.
current_text: String,
}
impl<F> NaiveStreamingDetokenizer<F>
where
    F: Fn(&[u32]) -> String,
{
    /// Builds an empty detokenizer around `decode`.
    ///
    /// `clean_up_spaces` controls whether a trailing space is hidden from the
    /// in-progress text (see `recompute_text`).
    pub fn new(decode: F, clean_up_spaces: bool) -> Self {
        let mut detok = Self {
            decode,
            clean_up_spaces,
            tokens: Vec::new(),
            offset: 0,
            text: String::new(),
            current_tokens: Vec::new(),
            current_text: String::new(),
        };
        detok.reset();
        detok
    }

    /// Re-decodes the pending tokens and commits them once a line is complete.
    ///
    /// A trailing U+FFFD is always dropped from the pending text, and a
    /// trailing space is dropped when `clean_up_spaces` is set. When the
    /// pending text ends with `'\n'` it is appended to `text` and the pending
    /// buffers are emptied.
    fn recompute_text(&mut self) {
        if !self.current_tokens.is_empty() {
            let mut pending = (self.decode)(&self.current_tokens);
            // At most one character is removed, chosen by what the text ends with.
            let drop_last = match pending.chars().next_back() {
                Some('\u{fffd}') => true,
                Some(' ') => self.clean_up_spaces,
                _ => false,
            };
            if drop_last {
                pending.pop();
            }
            self.current_text = pending;
        }

        if !self.current_text.ends_with('\n') {
            return;
        }
        self.text.push_str(&self.current_text);
        self.current_text.clear();
        self.current_tokens.clear();
    }
}
impl<F> StreamingDetokenizer for NaiveStreamingDetokenizer<F>
where
    F: Fn(&[u32]) -> String,
{
    /// Clears the tokens, the committed and pending text, and the offset.
    fn reset(&mut self) {
        self.tokens.clear();
        self.current_tokens.clear();
        self.text.clear();
        self.current_text.clear();
        self.offset = 0;
    }

    /// Records `token` and refreshes the pending text.
    fn add_token(&mut self, token: u32) {
        self.tokens.push(token);
        self.current_tokens.push(token);
        self.recompute_text();
    }

    /// Decodes the pending tokens one last time, without the trailing
    /// clean-up applied while streaming, and commits the result.
    fn finalize(&mut self) {
        let tail = (self.decode)(&self.current_tokens);
        self.text.push_str(tail.as_str());
        self.current_text.clear();
        self.current_tokens.clear();
    }

    /// Committed text followed by the pending text; borrows when nothing is
    /// pending and allocates otherwise.
    fn text(&self) -> std::borrow::Cow<'_, str> {
        use std::borrow::Cow;
        match self.current_text.as_str() {
            "" => Cow::Borrowed(self.text.as_str()),
            pending => Cow::Owned([self.text.as_str(), pending].concat()),
        }
    }

    /// All token ids added since the last reset.
    fn tokens(&self) -> &[u32] {
        self.tokens.as_slice()
    }

    /// Currently stored offset.
    fn offset(&self) -> usize {
        self.offset
    }

    /// Replaces the stored offset.
    fn set_offset(&mut self, offset: usize) {
        self.offset = offset;
    }
}
impl<F> NaiveStreamingDetokenizer<F>
where
    F: Fn(&[u32]) -> String,
{
    /// Owned copy of the committed text followed by the pending text, exactly
    /// as produced by the trait's `text` method.
    pub fn combined_text(&self) -> String {
        use crate::tokenizer::stream::StreamingDetokenizer;
        let combined = self.text();
        combined.into_owned()
    }
}
/// Lookup from byte-level BPE characters back to the raw byte they stand for,
/// backed by the generated `BYTE_DECODER` table.
#[cfg(feature = "tokenizer-gpt2")]
mod byte_decoder {
    use crate::tokenizer::generated::BYTE_DECODER;

    /// The raw table, exposed to the tests in this file.
    #[cfg(test)]
    pub(super) const TABLE: &[(char, u8)] = BYTE_DECODER;

    /// Returns the byte mapped to `c`, or `None` when `c` is not in the table.
    ///
    /// Uses a binary search, so the table must be sorted by character (the
    /// tests in this file assert strictly ascending keys).
    #[inline]
    pub(super) fn decode_char(c: char) -> Option<u8> {
        let idx = BYTE_DECODER
            .binary_search_by_key(&c, |&(key, _)| key)
            .ok()?;
        Some(BYTE_DECODER[idx].1)
    }
}
/// Streaming detokenizer that works on raw token bytes: it buffers them and
/// moves them into the text once they decode without a trailing U+FFFD.
#[cfg(feature = "tokenizer-spm")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-spm")))]
pub struct SpmStreamingDetokenizer {
/// When set, a leading space is removed from text flushed while `text` is still empty.
trim_space: bool,
/// Token id -> bytes contributed by that token.
tokenmap: std::collections::HashMap<u32, Vec<u8>>,
/// Every token id passed to `add_token` since the last reset.
tokens: Vec<u32>,
/// Value stored by `set_offset` and reported by `offset`.
offset: usize,
/// Text flushed so far.
text: String,
/// Bytes of tokens that have been added but not flushed into `text` yet.
unflushed: Vec<u8>,
}
#[cfg(feature = "tokenizer-spm")]
impl SpmStreamingDetokenizer {
    /// Builds a detokenizer from `(piece, id)` vocabulary pairs.
    ///
    /// Pieces of the exact form `<0xHH>` (two hexadecimal digits) are stored
    /// as the single byte `0xHH`; every other piece is stored as its UTF-8
    /// bytes. `trim_space` controls whether a leading space is removed from
    /// text flushed while the output is still empty.
    pub fn new<I, S>(vocab: I, trim_space: bool) -> Self
    where
        I: IntoIterator<Item = (S, u32)>,
        S: AsRef<str>,
    {
        let iter = vocab.into_iter();
        let mut tokenmap: std::collections::HashMap<u32, Vec<u8>> =
            std::collections::HashMap::with_capacity(iter.size_hint().0);
        for (value, id) in iter {
            let piece = value.as_ref().as_bytes();
            let bytes = match Self::byte_fallback(piece) {
                Some(byte) => vec![byte],
                None => piece.to_vec(),
            };
            tokenmap.insert(id, bytes);
        }
        let mut s = Self {
            trim_space,
            tokenmap,
            tokens: Vec::new(),
            offset: 0,
            text: String::new(),
            unflushed: Vec::new(),
        };
        s.reset();
        s
    }

    /// Parses a byte-fallback piece.
    ///
    /// Only the exact six-byte shape `<0xHH>` is accepted. Anything that
    /// merely starts with `<0x` (for example `<0x41BC>` or `<0x00`) is an
    /// ordinary text piece and must not be collapsed into one byte.
    fn byte_fallback(piece: &[u8]) -> Option<u8> {
        match piece {
            [b'<', b'0', b'x', hi, lo, b'>'] => {
                let hi = char::from(*hi).to_digit(16)?;
                let lo = char::from(*lo).to_digit(16)?;
                // Both digits are < 16, so the value always fits in a byte.
                u8::try_from(hi * 16 + lo).ok()
            }
            _ => None,
        }
    }

    /// Tries to move the buffered bytes into `text`.
    ///
    /// Every U+2581 separator in the buffer becomes a plain space and the
    /// result is decoded lossily as UTF-8. Unless `force` is set, nothing is
    /// flushed while the decoded text ends with U+FFFD. If `text` is still
    /// empty and `trim_space` is set, one leading space is skipped.
    fn try_flush(&mut self, force: bool) {
        let sep = "\u{2581}".as_bytes();
        let mut replaced: Vec<u8> = Vec::with_capacity(self.unflushed.len());
        let mut i = 0;
        while i < self.unflushed.len() {
            if self.unflushed[i..].starts_with(sep) {
                replaced.push(b' ');
                i += sep.len();
            } else {
                replaced.push(self.unflushed[i]);
                i += 1;
            }
        }

        let decoded = String::from_utf8_lossy(&replaced);
        if !force && decoded.ends_with('\u{fffd}') {
            return;
        }

        // A space is one byte, so skipping it keeps the slice on a char boundary.
        let skip = usize::from(self.text.is_empty() && self.trim_space && decoded.starts_with(' '));
        self.text.push_str(&decoded[skip..]);
        self.unflushed.clear();
    }
}
#[cfg(feature = "tokenizer-spm")]
impl StreamingDetokenizer for SpmStreamingDetokenizer {
    /// Clears the tokens, the flushed text, the byte buffer and the offset.
    fn reset(&mut self) {
        self.tokens.clear();
        self.text.clear();
        self.unflushed.clear();
        self.offset = 0;
    }

    /// Records `token`, buffers its bytes and attempts a non-forced flush.
    fn add_token(&mut self, token: u32) {
        self.tokens.push(token);
        // Ids that are not in the vocabulary contribute no bytes.
        let piece: &[u8] = self
            .tokenmap
            .get(&token)
            .map(Vec::as_slice)
            .unwrap_or_default();
        self.unflushed.extend_from_slice(piece);
        self.try_flush(false);
    }

    /// Forces the remaining buffered bytes into the text.
    fn finalize(&mut self) {
        self.try_flush(true);
        self.unflushed.clear();
    }

    /// The flushed text; always borrowed.
    fn text(&self) -> std::borrow::Cow<'_, str> {
        std::borrow::Cow::Borrowed(self.text.as_str())
    }

    /// All token ids added since the last reset.
    fn tokens(&self) -> &[u32] {
        self.tokens.as_slice()
    }

    /// Currently stored offset.
    fn offset(&self) -> usize {
        self.offset
    }

    /// Replaces the stored offset.
    fn set_offset(&mut self, offset: usize) {
        self.offset = offset;
    }
}
/// Streaming detokenizer for vocabularies whose token strings are mapped back
/// to raw bytes through the `byte_decoder` table before UTF-8 decoding.
#[cfg(feature = "tokenizer-bpe")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-bpe")))]
pub struct BpeStreamingDetokenizer {
/// When set, a leading space may also be trimmed if the rest starts with one of `BPE_SPACE_MATCHES`.
clean_spaces: bool,
/// Token id -> token string as given in the vocabulary.
tokenmap: std::collections::HashMap<u32, String>,
/// Largest token id seen in the vocabulary (0 for an empty vocabulary).
max_id: u32,
/// Every token id passed to `add_token` since the last reset.
tokens: Vec<u32>,
/// Value stored by `set_offset` and reported by `offset`.
offset: usize,
/// Text flushed so far.
text: String,
/// Token strings that have been added but not flushed into `text` yet.
unflushed: String,
}
#[cfg(feature = "tokenizer-bpe")]
use crate::tokenizer::generated::BPE_SPACE_MATCHES;
#[cfg(feature = "tokenizer-bpe")]
impl BpeStreamingDetokenizer {
    /// Builds a detokenizer from `(token string, id)` vocabulary pairs.
    ///
    /// The largest id is remembered so that `add_token` can tell an id that
    /// is merely missing from the map apart from one beyond the vocabulary.
    pub fn new<I, S>(vocab: I, clean_spaces: bool) -> Self
    where
        I: IntoIterator<Item = (S, u32)>,
        S: AsRef<str>,
    {
        let iter = vocab.into_iter();
        let mut tokenmap: std::collections::HashMap<u32, String> =
            std::collections::HashMap::with_capacity(iter.size_hint().0);
        let mut max_id: u32 = 0;
        for (value, id) in iter {
            max_id = max_id.max(id);
            tokenmap.insert(id, value.as_ref().to_owned());
        }
        let mut s = Self {
            clean_spaces,
            tokenmap,
            max_id,
            tokens: Vec::new(),
            offset: 0,
            text: String::new(),
            unflushed: String::new(),
        };
        s.reset();
        s
    }

    /// Maps every character of `seq` to its byte via the byte decoder and
    /// decodes the bytes lossily as UTF-8.
    ///
    /// Characters that are not in the table are kept as their own UTF-8
    /// encoding instead of being dropped.
    fn decode_bytes(&self, seq: &str) -> String {
        let mut raw: Vec<u8> = Vec::with_capacity(seq.len());
        let mut utf8 = [0u8; 4];
        for ch in seq.chars() {
            match byte_decoder::decode_char(ch) {
                Some(byte) => raw.push(byte),
                None => raw.extend_from_slice(ch.encode_utf8(&mut utf8).as_bytes()),
            }
        }
        String::from_utf8_lossy(&raw).into_owned()
    }

    /// Returns `current_text`, possibly without its leading space.
    ///
    /// The leading space is removed when no text has been emitted yet, or
    /// when `clean_spaces` is set and the remainder starts with one of
    /// `BPE_SPACE_MATCHES`. Text that does not start with a space (including
    /// the empty string) is returned unchanged.
    fn maybe_trim_space(&self, current_text: &str) -> String {
        let Some(rest) = current_text.strip_prefix(' ') else {
            return current_text.to_owned();
        };
        let trim = self.text.is_empty()
            || (self.clean_spaces && BPE_SPACE_MATCHES.iter().any(|m| rest.starts_with(m)));
        let kept = if trim { rest } else { current_text };
        kept.to_owned()
    }
}
#[cfg(feature = "tokenizer-bpe")]
impl StreamingDetokenizer for BpeStreamingDetokenizer {
    /// Clears the tokens, the flushed text, the pending buffer and the offset.
    fn reset(&mut self) {
        self.offset = 0;
        self.unflushed.clear();
        self.text.clear();
        self.tokens.clear();
    }

    /// Records `token`, buffers its string and flushes the buffer when possible.
    ///
    /// An id missing from the map contributes nothing if it is within the
    /// vocabulary range and `"!"` if it is beyond `max_id`. The buffer is kept
    /// back while its decoding ends with U+FFFD, and also when the token is a
    /// single character that decodes to byte 32, i.e. a lone space.
    fn add_token(&mut self, token: u32) {
        self.tokens.push(token);
        let v: &str = match self.tokenmap.get(&token) {
            Some(s) => s.as_str(),
            None if token <= self.max_id => "",
            None => "!",
        };
        let mut chars = v.chars();
        let single_space = match (chars.next(), chars.next()) {
            (Some(only), None) => byte_decoder::decode_char(only) == Some(32),
            _ => false,
        };
        self.unflushed.push_str(v);

        let text = self.decode_bytes(&self.unflushed);
        if !text.ends_with('\u{fffd}') && !single_space {
            let trimmed = self.maybe_trim_space(&text);
            self.text.push_str(&trimmed);
            self.unflushed.clear();
        }
    }

    /// Flushes whatever is still buffered into the text.
    ///
    /// The buffer is decoded with `decode_bytes`, the same path `add_token`
    /// uses, so characters missing from the byte-decoder table are preserved
    /// rather than silently discarded at end of stream.
    fn finalize(&mut self) {
        let current_text = self.decode_bytes(&self.unflushed);
        let trimmed = self.maybe_trim_space(&current_text);
        self.text.push_str(&trimmed);
        self.unflushed.clear();
    }

    /// The flushed text; always borrowed.
    fn text(&self) -> std::borrow::Cow<'_, str> {
        std::borrow::Cow::Borrowed(&self.text)
    }

    /// All token ids added since the last reset.
    fn tokens(&self) -> &[u32] {
        &self.tokens
    }

    /// Currently stored offset.
    fn offset(&self) -> usize {
        self.offset
    }

    /// Replaces the stored offset.
    fn set_offset(&mut self, offset: usize) {
        self.offset = offset;
    }
}
/// Kind of streaming detokenizer chosen for a tokenizer's decoder description.
///
/// Displays as the string returned by `as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
pub enum DetokenizerClass {
/// Fallback class; rendered as `"naive"`.
Naive,
/// Rendered as `"spm"`.
Spm,
/// Rendered as `"spm_no_space"`.
SpmNoSpace,
/// Rendered as `"bpe"`.
Bpe,
}
impl DetokenizerClass {
pub const fn as_str(&self) -> &'static str {
match self {
Self::Naive => "naive",
Self::Spm => "spm",
Self::SpmNoSpace => "spm_no_space",
Self::Bpe => "bpe",
}
}
}
/// Structural equality of two JSON values.
///
/// Objects match when they have the same number of entries and every key of
/// `a` maps to an equal value in `b`; arrays match element by element in
/// order; every other pair is compared with `==`.
#[cfg(any(feature = "tokenizer-spm", feature = "tokenizer-bpe"))]
fn json_eq(a: &serde_json::Value, b: &serde_json::Value) -> bool {
    use serde_json::Value;
    match (a, b) {
        (Value::Object(lhs), Value::Object(rhs)) => {
            if lhs.len() != rhs.len() {
                return false;
            }
            lhs.iter()
                .all(|(key, left)| matches!(rhs.get(key), Some(right) if json_eq(left, right)))
        }
        (Value::Array(lhs), Value::Array(rhs)) => {
            if lhs.len() != rhs.len() {
                return false;
            }
            lhs.iter()
                .zip(rhs.iter())
                .all(|(left, right)| json_eq(left, right))
        }
        (lhs, rhs) => lhs == rhs,
    }
}
/// Builds the reference decoder description that `infer_detokenizer_class`
/// compares against: a `Sequence` of `Replace`, `ByteFallback` and `Fuse`,
/// followed by a `Strip` step when `with_strip` is true.
#[cfg(any(feature = "tokenizer-spm", feature = "tokenizer-bpe"))]
fn spm_decoder_target(with_strip: bool) -> serde_json::Value {
    let strip_step = with_strip.then(|| {
        serde_json::json!({"type": "Strip", "content": " ", "start": 1, "stop": 0})
    });

    let mut decoders = Vec::from([
        serde_json::json!({"type": "Replace", "pattern": {"String": "▁"}, "content": " "}),
        serde_json::json!({"type": "ByteFallback"}),
        serde_json::json!({"type": "Fuse"}),
    ]);
    // `Option` iterates over zero or one item, so this appends only when requested.
    decoders.extend(strip_step);

    serde_json::json!({"type": "Sequence", "decoders": decoders})
}
/// Picks the detokenizer class for a tokenizer's JSON decoder description.
///
/// * no decoder: [`DetokenizerClass::Naive`]
/// * equal to the reference sequence with the `Strip` step: [`DetokenizerClass::Spm`]
/// * equal to the reference sequence without it: [`DetokenizerClass::SpmNoSpace`]
/// * `"type"` is `"ByteLevel"`: [`DetokenizerClass::Bpe`]
/// * anything else: [`DetokenizerClass::Naive`]
#[cfg(any(feature = "tokenizer-spm", feature = "tokenizer-bpe"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "tokenizer-spm", feature = "tokenizer-bpe")))
)]
pub fn infer_detokenizer_class(decoder: Option<&serde_json::Value>) -> DetokenizerClass {
    match decoder {
        None => DetokenizerClass::Naive,
        Some(desc) if json_eq(&spm_decoder_target(true), desc) => DetokenizerClass::Spm,
        Some(desc) if json_eq(&spm_decoder_target(false), desc) => DetokenizerClass::SpmNoSpace,
        Some(desc) if desc.get("type").and_then(|t| t.as_str()) == Some("ByteLevel") => {
            DetokenizerClass::Bpe
        }
        Some(_) => DetokenizerClass::Naive,
    }
}
/// Streaming detokenizer that re-decodes its pending tokens with a
/// `tokenizers::Tokenizer` every time a token is added.
#[cfg(feature = "tokenizer-stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-stream")))]
pub struct NaiveHfDetokenizer {
/// Tokenizer used for decoding.
hf: tokenizers::Tokenizer,
/// When set, one trailing space is dropped from the in-progress text.
clean_up_spaces: bool,
/// Every token id passed to `add_token` since the last reset.
tokens: Vec<u32>,
/// Value stored by `set_offset` and reported by `offset`.
offset: usize,
/// Text that has been committed and is no longer re-decoded.
text: String,
/// Token ids whose decoding has not been committed to `text` yet.
current_tokens: Vec<u32>,
/// Latest decoding of `current_tokens`, after the trailing clean-up.
current_text: String,
}
#[cfg(feature = "tokenizer-stream")]
impl NaiveHfDetokenizer {
    /// Builds an empty detokenizer around `hf`.
    ///
    /// `clean_up_spaces` controls whether a trailing space is hidden from the
    /// in-progress text (see `recompute_text`).
    pub fn new(hf: tokenizers::Tokenizer, clean_up_spaces: bool) -> Self {
        Self {
            hf,
            clean_up_spaces,
            offset: 0,
            tokens: Vec::new(),
            current_tokens: Vec::new(),
            text: String::new(),
            current_text: String::new(),
        }
    }

    /// Decodes `ids` with the second `decode` argument set to `false`; a
    /// decoding error yields an empty string.
    fn decode(&self, ids: &[u32]) -> String {
        self.hf.decode(ids, false).unwrap_or_default()
    }

    /// Re-decodes the pending tokens and commits them once a line is complete.
    ///
    /// A trailing U+FFFD is always dropped from the pending text, and a
    /// trailing space is dropped when `clean_up_spaces` is set. When the
    /// pending text ends with `'\n'` it is appended to `text` and the pending
    /// buffers are emptied.
    fn recompute_text(&mut self) {
        if !self.current_tokens.is_empty() {
            let mut pending = self.decode(&self.current_tokens);
            let hide_last = pending.ends_with('\u{fffd}')
                || (self.clean_up_spaces && pending.ends_with(' '));
            if hide_last {
                pending.pop();
            }
            self.current_text = pending;
        }

        if !self.current_text.ends_with('\n') {
            return;
        }
        self.text.push_str(&self.current_text);
        self.current_text.clear();
        self.current_tokens.clear();
    }
}
#[cfg(feature = "tokenizer-stream")]
impl StreamingDetokenizer for NaiveHfDetokenizer {
    /// Clears the tokens, the committed and pending text, and the offset.
    fn reset(&mut self) {
        self.tokens.clear();
        self.current_tokens.clear();
        self.text.clear();
        self.current_text.clear();
        self.offset = 0;
    }

    /// Records `token` and refreshes the pending text.
    fn add_token(&mut self, token: u32) {
        self.tokens.push(token);
        self.current_tokens.push(token);
        self.recompute_text();
    }

    /// Decodes the pending tokens one last time, without the trailing
    /// clean-up applied while streaming, and commits the result.
    fn finalize(&mut self) {
        let tail = self.decode(&self.current_tokens);
        self.text.push_str(tail.as_str());
        self.current_text.clear();
        self.current_tokens.clear();
    }

    /// Committed text followed by the pending text; borrows when nothing is
    /// pending and allocates otherwise.
    fn text(&self) -> std::borrow::Cow<'_, str> {
        use std::borrow::Cow;
        match self.current_text.as_str() {
            "" => Cow::Borrowed(self.text.as_str()),
            pending => Cow::Owned([self.text.as_str(), pending].concat()),
        }
    }

    /// All token ids added since the last reset.
    fn tokens(&self) -> &[u32] {
        self.tokens.as_slice()
    }

    /// Currently stored offset.
    fn offset(&self) -> usize {
        self.offset
    }

    /// Replaces the stored offset.
    fn set_offset(&mut self, offset: usize) {
        self.offset = offset;
    }
}
/// Closed set of streaming detokenizers plus an escape hatch for
/// caller-provided ones; it implements `StreamingDetokenizer` by forwarding
/// every call to the wrapped value.
#[cfg(feature = "tokenizer-stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokenizer-stream")))]
#[derive(derive_more::IsVariant, derive_more::Unwrap, derive_more::TryUnwrap)]
#[unwrap(ref, ref_mut)]
#[try_unwrap(ref, ref_mut)]
pub enum Detokenizer {
/// Tokenizer-backed naive detokenizer (boxed).
Naive(Box<NaiveHfDetokenizer>),
/// Byte-buffering SPM detokenizer; only with the `tokenizer-spm` feature.
#[cfg(feature = "tokenizer-spm")]
Spm(SpmStreamingDetokenizer),
/// Byte-level BPE detokenizer; only with the `tokenizer-bpe` feature.
#[cfg(feature = "tokenizer-bpe")]
Bpe(BpeStreamingDetokenizer),
/// Any other implementation, behind a trait object.
Custom(Box<dyn StreamingDetokenizer>),
}
/// Expands to a `match` over every `Detokenizer` variant that binds the
/// wrapped detokenizer to `$inner` and evaluates `$call` in each arm.
#[cfg(feature = "tokenizer-stream")]
macro_rules! dispatch_detokenizer {
    ($target:expr, $inner:ident => $call:expr) => {
        match $target {
            Detokenizer::Naive($inner) => $call,
            #[cfg(feature = "tokenizer-spm")]
            Detokenizer::Spm($inner) => $call,
            #[cfg(feature = "tokenizer-bpe")]
            Detokenizer::Bpe($inner) => $call,
            Detokenizer::Custom($inner) => $call,
        }
    };
}

/// Every method forwards to the same method of the wrapped detokenizer.
#[cfg(feature = "tokenizer-stream")]
impl StreamingDetokenizer for Detokenizer {
    fn reset(&mut self) {
        dispatch_detokenizer!(self, d => d.reset())
    }

    fn add_token(&mut self, token: u32) {
        dispatch_detokenizer!(self, d => d.add_token(token))
    }

    fn finalize(&mut self) {
        dispatch_detokenizer!(self, d => d.finalize())
    }

    fn text(&self) -> std::borrow::Cow<'_, str> {
        dispatch_detokenizer!(self, d => d.text())
    }

    fn tokens(&self) -> &[u32] {
        dispatch_detokenizer!(self, d => d.tokens())
    }

    fn offset(&self) -> usize {
        dispatch_detokenizer!(self, d => d.offset())
    }

    fn set_offset(&mut self, offset: usize) {
        dispatch_detokenizer!(self, d => d.set_offset(offset))
    }
}
/// Checks the generated byte-decoder table against a table built at runtime.
#[cfg(all(test, feature = "tokenizer-gpt2"))]
mod byte_decoder_tests {
    /// Builds the char -> byte map at runtime.
    ///
    /// `limits` splits `0..=255` into six consecutive ranges. Bytes in the
    /// even-numbered ranges are assigned the code points 256, 257, ... in
    /// order; bytes in the odd-numbered ranges map from the code point equal
    /// to the byte itself.
    fn legacy_make_byte_decoder() -> std::collections::HashMap<char, u8> {
        let limits: [u32; 7] = [
            0,
            '!' as u32,
            '~' as u32 + 1,
            '¡' as u32,
            '¬' as u32 + 1,
            '®' as u32,
            'ÿ' as u32 + 1,
        ];
        let mut map = std::collections::HashMap::new();
        let mut remapped: u32 = 0;
        for (range_idx, (&start, &stop)) in limits.iter().zip(&limits[1..]).enumerate() {
            let shifted = range_idx % 2 == 0;
            for b in start..stop {
                let c = if shifted {
                    let c = char::from_u32(256 + remapped).unwrap();
                    remapped += 1;
                    c
                } else {
                    char::from_u32(b).unwrap()
                };
                map.insert(c, b as u8);
            }
        }
        map
    }

    #[test]
    fn generated_byte_decoder_matches_legacy_algorithm() {
        let legacy = legacy_make_byte_decoder();
        let table = super::byte_decoder::TABLE;

        // Same size, and every legacy entry resolves through the lookup.
        assert_eq!(table.len(), legacy.len());
        for (&c, &b) in &legacy {
            assert_eq!(
                super::byte_decoder::decode_char(c),
                Some(b),
                "mismatch for char {c:?}"
            );
        }

        // The binary search in `decode_char` needs strictly ascending keys.
        assert!(table.windows(2).all(|pair| pair[0].0 < pair[1].0));

        // And every generated entry exists in the legacy map.
        for &(c, b) in table {
            assert_eq!(legacy.get(&c), Some(&b));
        }
    }
}