1use std::collections::{HashMap, HashSet};
11
12use crate::byte_encoder::{decode_byte_level_token, METASPACE};
13use crate::map::TokenizerMap;
14
/// Per-call options for [`Detokenizer::render`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DetokenizeOptions {
    /// Streaming mode. When `true`, bytes that do not yet form a complete UTF-8
    /// character stay in the detokenizer's internal buffer for the next call.
    /// When `false`, such bytes are flushed lossily at the end of the call.
    pub partial: bool,
    /// When `true`, special-token ids are rendered; when `false` they are skipped.
    pub render_special: bool,
}
25
/// Converts token ids back into text.
///
/// The detokenizer carries a byte buffer between calls, so multi-byte UTF-8
/// characters split across tokens can be reassembled when rendering in
/// streaming (`partial`) mode.
#[derive(Debug, Clone)]
pub struct Detokenizer {
    /// Ids taken from the tokenizer map's special-token table.
    special_ids: HashSet<u32>,
    /// Inclusive lower bound of the raw-byte ("byte fallback") id range.
    /// When `fallback_start > fallback_end`, the range is empty.
    fallback_start: i64,
    /// Inclusive upper bound of the raw-byte id range.
    fallback_end: i64,
    /// id -> raw bytes; populated only for the `"byte_level"` encoder.
    id_to_bytes: Option<HashMap<u32, Vec<u8>>>,
    /// id -> text; populated for every other encoder.
    id_to_text: Option<HashMap<u32, String>>,
    /// Bytes that have not yet been emitted as text, typically an incomplete
    /// UTF-8 sequence left over from a partial render.
    byte_buffer: Vec<u8>,
}
41
42impl Detokenizer {
43 pub fn new(map: &TokenizerMap) -> Self {
45 let special_ids: HashSet<u32> = map
46 .special_tokens
47 .as_ref()
48 .map(|s| s.values().copied().collect())
49 .unwrap_or_default();
50 let fallback_start = map.byte_fallback_start.unwrap_or(-1);
51 let fallback_end = map.byte_fallback_end.unwrap_or(-2);
52
53 let (id_to_bytes, id_to_text) = if map.encoder.as_deref() == Some("byte_level") {
54 (Some(build_byte_level_table(map)), None)
55 } else {
56 (None, Some(build_text_table(map)))
57 };
58
59 Self {
60 special_ids,
61 fallback_start,
62 fallback_end,
63 id_to_bytes,
64 id_to_text,
65 byte_buffer: Vec::new(),
66 }
67 }
68
69 pub fn render(&mut self, ids: &[u32], options: DetokenizeOptions) -> String {
71 let mut out = String::new();
72 let render_special = options.render_special;
73
74 for &id in ids {
75 let id_i = id as i64;
77 if id_i >= self.fallback_start && id_i <= self.fallback_end {
78 let b = (id_i - self.fallback_start) as u8;
79 self.byte_buffer.push(b);
80 self.flush_all_bytes(&mut out);
81 continue;
82 }
83
84 if let Some(map_bytes) = &self.id_to_bytes {
85 if self.special_ids.contains(&id) && !render_special {
87 if !self.byte_buffer.is_empty() {
88 self.flush_bytes_force(&mut out);
89 }
90 continue;
91 }
92 match map_bytes.get(&id) {
93 None => {
94 if !self.byte_buffer.is_empty() {
95 self.flush_bytes_force(&mut out);
96 }
97 out.push('\u{FFFD}');
98 }
99 Some(bytes) => {
100 self.byte_buffer.extend_from_slice(bytes);
101 self.flush_all_bytes(&mut out);
102 }
103 }
104 continue;
105 }
106
107 if !self.byte_buffer.is_empty() {
109 self.flush_bytes_force(&mut out);
110 }
111 if self.special_ids.contains(&id) && !render_special {
112 continue;
113 }
114 match self.id_to_text.as_ref().and_then(|m| m.get(&id)) {
115 Some(text) => out.push_str(text),
116 None => out.push('\u{FFFD}'),
117 }
118 }
119
120 if !options.partial && !self.byte_buffer.is_empty() {
121 self.flush_bytes_force(&mut out);
122 }
123 out
124 }
125
126 pub fn reset(&mut self) {
128 self.byte_buffer.clear();
129 }
130
131 pub fn detokenize(map: &TokenizerMap, ids: &[u32], render_special: bool) -> String {
134 let mut d = Self::new(map);
135 d.render(ids, DetokenizeOptions { partial: false, render_special })
136 }
137
138 fn flush_all_bytes(&mut self, out: &mut String) {
141 loop {
142 if self.byte_buffer.is_empty() {
143 return;
144 }
145 let needed = utf8_sequence_length(self.byte_buffer[0]);
146 if needed == 0 {
147 self.byte_buffer.remove(0);
148 out.push('\u{FFFD}');
149 continue;
150 }
151 if self.byte_buffer.len() < needed {
152 return;
153 }
154 let slice: Vec<u8> = self.byte_buffer.drain(..needed).collect();
155 match std::str::from_utf8(&slice) {
156 Ok(s) => out.push_str(s),
157 Err(_) => out.push('\u{FFFD}'),
158 }
159 }
160 }
161
162 fn flush_bytes_force(&mut self, out: &mut String) {
163 if self.byte_buffer.is_empty() {
164 return;
165 }
166 let bytes = std::mem::take(&mut self.byte_buffer);
167 out.push_str(&String::from_utf8_lossy(&bytes));
169 }
170}
171
/// Returns the total length of the UTF-8 sequence introduced by lead byte `b`.
///
/// - ASCII bytes give 1, and multi-byte lead bytes give 2 to 4.
/// - Continuation bytes (`0x80..=0xBF`) and `0xF8..=0xFF` give 0.
/// - Lead bytes that are never valid in UTF-8 (`0xC0`, `0xC1`, `0xF5..=0xF7`)
///   are not rejected here; the caller's validation catches them.
fn utf8_sequence_length(b: u8) -> usize {
    match b {
        0x00..=0x7F => 1,
        0xC0..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xF7 => 4,
        0x80..=0xBF | 0xF8..=0xFF => 0,
    }
}
185
186fn build_byte_level_table(map: &TokenizerMap) -> HashMap<u32, Vec<u8>> {
187 let mut result = HashMap::new();
188 if let Some(vocab) = &map.vocab {
189 result.reserve(vocab.len());
190 for (token, &id) in vocab {
191 result.insert(id, decode_byte_level_token(token));
192 }
193 }
194 result
195}
196
197fn build_text_table(map: &TokenizerMap) -> HashMap<u32, String> {
198 let mut result: HashMap<u32, String> = HashMap::new();
199 let is_metaspace = map.encoder.as_deref() == Some("metaspace");
200
201 if let Some(vocab) = &map.vocab {
202 for (token, &id) in vocab {
203 if is_byte_fallback_token(token) {
206 continue;
207 }
208 let text = if is_metaspace {
209 token.replace(METASPACE, " ")
210 } else {
211 token.clone()
212 };
213 result.insert(id, text);
214 }
215 }
216 if let Some(tokens) = &map.tokens {
217 for (id_str, text) in tokens {
218 if let Ok(id) = id_str.parse::<u32>() {
219 result.insert(id, text.clone());
220 }
221 }
222 }
223 result
224}
225
226fn is_byte_fallback_token(s: &str) -> bool {
227 let bytes = s.as_bytes();
228 if bytes.len() != 6 {
229 return false;
230 }
231 if bytes[0] != b'<' || bytes[1] != b'0' || bytes[2] != b'x' || bytes[5] != b'>' {
232 return false;
233 }
234 is_hex_byte(bytes[3]) && is_hex_byte(bytes[4])
235}
236
/// Returns `true` if `b` is an ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`).
fn is_hex_byte(b: u8) -> bool {
    matches!(b, b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F')
}