Skip to main content

codec_rs/
detokenize.rs

// SPDX-License-Identifier: MIT
//! Stateful detokenizer: token IDs → text.
//!
//! It handles three correctness concerns:
//!
//! 1. Per-token decoding via the map's encoder (byte_level / metaspace / identity).
//! 2. Byte-fallback range — IDs in `[byte_fallback_start, byte_fallback_end]` are decoded as raw bytes.
//! 3. Partial multi-byte sequences across frame boundaries — buffered between calls when `partial: true`.

10use std::collections::{HashMap, HashSet};
11
12use crate::byte_encoder::{decode_byte_level_token, METASPACE};
13use crate::map::TokenizerMap;
14
/// Per-call settings for [`Detokenizer::render`].
///
/// The `Default` value is `partial: false, render_special: false`: the chunk
/// is treated as final and special tokens are left out of the output.
#[derive(Clone, Copy, Debug, Default)]
pub struct DetokenizeOptions {
    /// `true` when more chunks will follow. A trailing incomplete UTF-8
    /// sequence is then held back instead of being turned into replacement
    /// characters; pass `false` on the last chunk so held bytes are flushed.
    pub partial: bool,
    /// `true` to include special tokens (such as `<|eos|>`) in the output
    /// text. Defaults to `false`.
    pub render_special: bool,
}
25
/// Stateful detokenizer.
///
/// Construct with a [`TokenizerMap`]; call [`Detokenizer::render`]
/// repeatedly with chunks of IDs. State (the partial UTF-8 byte buffer)
/// persists across calls; [`Detokenizer::reset`] clears it.
///
/// `Debug` is derived for diagnostics. `Clone` copies the lookup tables and
/// any pending bytes, so a clone continues the same stream independently.
#[derive(Debug, Clone)]
pub struct Detokenizer {
    /// IDs treated as special tokens (skipped unless `render_special`).
    special_ids: HashSet<u32>,
    /// Inclusive byte-fallback ID range `fallback_start..=fallback_end`;
    /// `fallback_start` maps to byte 0x00. `new` falls back to `-1` / `-2`
    /// (start > end, i.e. an empty range) when the map has no fallback range.
    fallback_start: i64,
    /// Inclusive upper bound of the byte-fallback ID range.
    fallback_end: i64,
    /// `Some` when encoder == "byte_level": ID → raw token bytes.
    id_to_bytes: Option<HashMap<u32, Vec<u8>>>,
    /// `Some` for metaspace + identity: ID → display text.
    id_to_text: Option<HashMap<u32, String>>,
    /// Decoded-but-not-yet-emitted bytes (an incomplete UTF-8 sequence).
    byte_buffer: Vec<u8>,
}
41
42impl Detokenizer {
43    /// Build a detokenizer from a map.
44    pub fn new(map: &TokenizerMap) -> Self {
45        let special_ids: HashSet<u32> = map
46            .special_tokens
47            .as_ref()
48            .map(|s| s.values().copied().collect())
49            .unwrap_or_default();
50        let fallback_start = map.byte_fallback_start.unwrap_or(-1);
51        let fallback_end = map.byte_fallback_end.unwrap_or(-2);
52
53        let (id_to_bytes, id_to_text) = if map.encoder.as_deref() == Some("byte_level") {
54            (Some(build_byte_level_table(map)), None)
55        } else {
56            (None, Some(build_text_table(map)))
57        };
58
59        Self {
60            special_ids,
61            fallback_start,
62            fallback_end,
63            id_to_bytes,
64            id_to_text,
65            byte_buffer: Vec::new(),
66        }
67    }
68
69    /// Render a chunk of IDs to text. Stateful across calls.
70    pub fn render(&mut self, ids: &[u32], options: DetokenizeOptions) -> String {
71        let mut out = String::new();
72        let render_special = options.render_special;
73
74        for &id in ids {
75            // Byte-fallback range: SentencePiece reserves IDs for raw bytes 0x00-0xFF.
76            let id_i = id as i64;
77            if id_i >= self.fallback_start && id_i <= self.fallback_end {
78                let b = (id_i - self.fallback_start) as u8;
79                self.byte_buffer.push(b);
80                self.flush_all_bytes(&mut out);
81                continue;
82            }
83
84            if let Some(map_bytes) = &self.id_to_bytes {
85                // byte_level: every vocab token IS a byte sequence.
86                if self.special_ids.contains(&id) && !render_special {
87                    if !self.byte_buffer.is_empty() {
88                        self.flush_bytes_force(&mut out);
89                    }
90                    continue;
91                }
92                match map_bytes.get(&id) {
93                    None => {
94                        if !self.byte_buffer.is_empty() {
95                            self.flush_bytes_force(&mut out);
96                        }
97                        out.push('\u{FFFD}');
98                    }
99                    Some(bytes) => {
100                        self.byte_buffer.extend_from_slice(bytes);
101                        self.flush_all_bytes(&mut out);
102                    }
103                }
104                continue;
105            }
106
107            // metaspace / identity: token text is rendered directly.
108            if !self.byte_buffer.is_empty() {
109                self.flush_bytes_force(&mut out);
110            }
111            if self.special_ids.contains(&id) && !render_special {
112                continue;
113            }
114            match self.id_to_text.as_ref().and_then(|m| m.get(&id)) {
115                Some(text) => out.push_str(text),
116                None => out.push('\u{FFFD}'),
117            }
118        }
119
120        if !options.partial && !self.byte_buffer.is_empty() {
121            self.flush_bytes_force(&mut out);
122        }
123        out
124    }
125
126    /// Reset internal state — call between conversations / requests.
127    pub fn reset(&mut self) {
128        self.byte_buffer.clear();
129    }
130
131    /// Convenience: detokenize a complete sequence in one shot. Uses a
132    /// fresh detokenizer; partial buffering not exposed.
133    pub fn detokenize(map: &TokenizerMap, ids: &[u32], render_special: bool) -> String {
134        let mut d = Self::new(map);
135        d.render(ids, DetokenizeOptions { partial: false, render_special })
136    }
137
138    // ── Internals ──────────────────────────────────────────────────────────
139
140    fn flush_all_bytes(&mut self, out: &mut String) {
141        loop {
142            if self.byte_buffer.is_empty() {
143                return;
144            }
145            let needed = utf8_sequence_length(self.byte_buffer[0]);
146            if needed == 0 {
147                self.byte_buffer.remove(0);
148                out.push('\u{FFFD}');
149                continue;
150            }
151            if self.byte_buffer.len() < needed {
152                return;
153            }
154            let slice: Vec<u8> = self.byte_buffer.drain(..needed).collect();
155            match std::str::from_utf8(&slice) {
156                Ok(s) => out.push_str(s),
157                Err(_) => out.push('\u{FFFD}'),
158            }
159        }
160    }
161
162    fn flush_bytes_force(&mut self, out: &mut String) {
163        if self.byte_buffer.is_empty() {
164            return;
165        }
166        let bytes = std::mem::take(&mut self.byte_buffer);
167        // Lossy decode matches .NET's `Encoding.UTF8.GetString` (replacement char on invalid).
168        out.push_str(&String::from_utf8_lossy(&bytes));
169    }
170}
171
/// Length in bytes of the UTF-8 sequence introduced by lead byte `b`.
///
/// The result is judged purely from the count of leading 1-bits: 1 for
/// `0xxxxxxx`, 2 for `110xxxxx`, 3 for `1110xxxx`, 4 for `11110xxx`.
/// Continuation bytes (`10xxxxxx`) and `11111xxx` return 0. No further
/// validity checks are made, so overlong leads like `0xC0` and leads above
/// `0xF4` are still classified by their bit pattern.
fn utf8_sequence_length(b: u8) -> usize {
    match b.leading_ones() {
        0 => 1,
        n @ 2..=4 => n as usize,
        _ => 0,
    }
}
185
186fn build_byte_level_table(map: &TokenizerMap) -> HashMap<u32, Vec<u8>> {
187    let mut result = HashMap::new();
188    if let Some(vocab) = &map.vocab {
189        result.reserve(vocab.len());
190        for (token, &id) in vocab {
191            result.insert(id, decode_byte_level_token(token));
192        }
193    }
194    result
195}
196
197fn build_text_table(map: &TokenizerMap) -> HashMap<u32, String> {
198    let mut result: HashMap<u32, String> = HashMap::new();
199    let is_metaspace = map.encoder.as_deref() == Some("metaspace");
200
201    if let Some(vocab) = &map.vocab {
202        for (token, &id) in vocab {
203            // SentencePiece byte-fallback tokens (<0xHH>) live in vocab
204            // but are handled by the byte_fallback range path.
205            if is_byte_fallback_token(token) {
206                continue;
207            }
208            let text = if is_metaspace {
209                token.replace(METASPACE, " ")
210            } else {
211                token.clone()
212            };
213            result.insert(id, text);
214        }
215    }
216    if let Some(tokens) = &map.tokens {
217        for (id_str, text) in tokens {
218            if let Ok(id) = id_str.parse::<u32>() {
219                result.insert(id, text.clone());
220            }
221        }
222    }
223    result
224}
225
/// Whether `s` is a SentencePiece byte-fallback piece. The piece must be
/// exactly `<0xHH>`, with a lowercase `x` followed by two ASCII hex digits
/// of either case.
fn is_byte_fallback_token(s: &str) -> bool {
    match s.as_bytes() {
        [b'<', b'0', b'x', hi, lo, b'>'] => is_hex_byte(*hi) && is_hex_byte(*lo),
        _ => false,
    }
}

/// Whether `b` is an ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`).
fn is_hex_byte(b: u8) -> bool {
    matches!(b, b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F')
}