dynamo_llm/
tokenizers.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4pub mod hf;
5
6// TODO: Add tokenizer benchmarks
7// TODO: Enable README.md as a module doc
8// #[doc = include_str!("../README.md")]
9
10use std::hash::{DefaultHasher, Hash, Hasher};
11use std::sync::Arc;
12use std::{ops::Deref, path::Path};
13
14use crate::protocols::TokenIdType;
15pub use anyhow::{Error, Result};
16
17pub use hf::HuggingFaceTokenizer;
18
/// Represents the type of tokenizer being used.
///
/// NOTE(review): the `String` payload presumably identifies the tokenizer
/// (e.g. a file path or model name); it is not constructed in this module,
/// so confirm against callers.
#[derive(Debug)]
pub enum TokenizerType {
    HuggingFace(String),
}
24
/// Character offsets `(start, end)` of a token span in the original text.
pub type Offsets = (usize, usize);
27
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
#[derive(Debug, Clone)]
pub enum Encoding {
    /// Hugging Face encoding; boxed to keep the enum small, since the
    /// underlying `tokenizers` type carries ids, tokens, and offsets.
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// Sentence Piece; only the raw token ids are retained.
    Sp(Vec<TokenIdType>),
}
36
37impl Encoding {
38    pub fn token_ids(&self) -> &[u32] {
39        match self {
40            Encoding::Hf(inner) => inner.get_ids(),
41            Encoding::Sp(inner) => inner,
42        }
43    }
44}
45
46impl Hash for Encoding {
47    fn hash<H: Hasher>(&self, state: &mut H) {
48        self.token_ids().hash(state);
49    }
50}
51
pub mod traits {
    use super::*;

    /// Converts text into token ids.
    pub trait Encoder: Send + Sync {
        /// Tokenize a single input string.
        fn encode(&self, input: &str) -> Result<Encoding>;
        /// Tokenize a batch of input strings.
        fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
    }

    /// Converts token ids back into text.
    pub trait Decoder: Send + Sync {
        /// Decode `token_ids` into a string; `skip_special_tokens` controls
        /// whether special tokens are omitted from the output.
        fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
    }

    /// A full tokenizer: both an [`Encoder`] and a [`Decoder`].
    pub trait Tokenizer: Encoder + Decoder {
        // fn get_vocab_size(&self) -> usize;
        // fn make_unique_clone(&self) -> Box<dyn Tokenizer>;
    }
}
69
70impl Encoding {
71    pub fn get_hash(&self) -> u64 {
72        let mut hasher = DefaultHasher::new();
73        self.hash(&mut hasher);
74        hasher.finish()
75    }
76}
77
/// Main tokenizer wrapper that provides a unified interface for different tokenizer
/// implementations.
///
/// Wraps an `Arc`, so cloning is cheap and the same underlying tokenizer is
/// shared across clones.
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
81
82impl Tokenizer {
83    pub fn from_file(file_path: &str) -> Result<Tokenizer> {
84        Ok(Tokenizer(create_tokenizer_from_file(file_path)?))
85    }
86
87    /// Create a stateful sequence object for decoding token_ids into text
88    pub fn decode_stream(
89        &self,
90        prompt_token_ids: &[TokenIdType],
91        skip_special_tokens: bool,
92    ) -> DecodeStream {
93        DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
94    }
95}
96
// Dereference to the inner trait object so `Tokenizer` can call
// `encode`/`decode` (and any other trait methods) directly on the wrapper.
impl Deref for Tokenizer {
    type Target = Arc<dyn traits::Tokenizer>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
104
/// Wrap an existing shared trait object without re-allocating.
impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
    fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Tokenizer(tokenizer)
    }
}
110
/// Wrap any concrete tokenizer already behind an `Arc`, coercing it to the
/// trait object form.
impl<T> From<Arc<T>> for Tokenizer
where
    T: traits::Tokenizer + 'static, // 'static is required to ensure T can be safely put into an Arc
{
    fn from(tokenizer: Arc<T>) -> Self {
        Tokenizer(tokenizer)
    }
}
119
120/// Create a tokenizer from a file path to a tokenizer file.
121/// The file extension is used to determine the tokenizer type.
122/// Supported file types are:
123/// - json: HuggingFace tokenizer
124pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
125    let path = Path::new(file_path);
126    let extension = path
127        .extension()
128        .and_then(std::ffi::OsStr::to_str)
129        .ok_or_else(|| Error::msg("Failed to read file extension".to_string()))?;
130
131    match extension {
132        "json" => {
133            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
134            Ok(Arc::new(tokenizer))
135        }
136        _ => Err(Error::msg("Unsupported file type".to_string())),
137    }
138}
139
// With incremental detokenization, we need to consider the final context tokens when handling the initial decode tokens.
// This is the initial offset from the end of the context that we start decoding from;
// i.e. decoding begins 5 tokens before the end of the prompt.
// Both Huggingface TGI and vLLM use this same value.
// See: https://github.com/huggingface/text-generation-inference/blob/24c2bff65924801ddf90fa24fcc72752d4f45538/server/text_generation_server/models/mamba.py#L169
// and https://github.com/vllm-project/vllm/blob/da2705198fa19030a25d0bea437f7be6547d47d4/vllm/transformers_utils/detokenizer_utils.py#L51
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
146
/// DecodeStream will keep the state necessary to produce individual chunks of
/// strings given an input stream of token_ids.
///
/// This is necessary because decoding in general cannot achieve that since strings
/// depend on surrounding ids to provide a valid string. Typically stripping extra spaces.
pub struct DecodeStream {
    /// The tokenizer used to decode token_ids
    tokenizer: Arc<dyn traits::Tokenizer>,

    /// Whether special tokens are omitted when decoding.
    skip_special_tokens: bool,
    /// A temporary buffer of the necessary token_ids needed
    /// to produce valid string chunks.
    /// This typically contains 3 parts:
    ///  - read
    ///  - prefix
    ///  - rest
    ///
    /// Read is the bit necessary to surround the prefix
    /// so decoding the whole ids produces a valid prefix.
    /// Prefix is the previously produced string, kept around to trim off of
    /// the next valid chunk
    all_token_ids: Vec<u32>,

    /// Start (in `all_token_ids`) of the context window decoded on each step.
    prefix_offset: usize,

    /// End (in `all_token_ids`) of the already-emitted text; tokens past this
    /// index have not yet been returned as a chunk.
    read_offset: usize,
}
174
175impl DecodeStream {
176    pub fn new(
177        tokenizer: Arc<dyn traits::Tokenizer>,
178        prompt_token_ids: &[TokenIdType],
179        skip_special_tokens: bool,
180    ) -> Self {
181        let num_input_tokens = prompt_token_ids.len();
182        let prompt_token_ids = prompt_token_ids.to_vec();
183        Self {
184            tokenizer,
185            skip_special_tokens,
186            all_token_ids: prompt_token_ids,
187            prefix_offset: num_input_tokens
188                .saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
189            read_offset: num_input_tokens,
190        }
191    }
192
193    /// Step appends a token_id to the internal state and tries to produce a text chunk.
194    ///
195    /// Implementation directly copied from Huggingface's TGI:
196    /// https://github.com/huggingface/text-generation-inference/blob/24c2bff65924801ddf90fa24fcc72752d4f45538/server/text_generation_server/models/model.py#L144
197    ///
198    /// Returning `None` means the given id is not enough to produce a chunk.
199    /// This typically happens with `byte_fallback` options where some tokens do not
200    /// represent valid UTF-8, and only follow-up token_ids will help produce
201    /// a valid chunk.
202    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
203        self.all_token_ids.push(id);
204
205        let prefix_text = self.tokenizer.decode(
206            &self.all_token_ids[self.prefix_offset..self.read_offset],
207            self.skip_special_tokens,
208        )?;
209
210        let new_text = self.tokenizer.decode(
211            &self.all_token_ids[self.prefix_offset..],
212            self.skip_special_tokens,
213        )?;
214
215        if new_text.len() > prefix_text.len() && !new_text.ends_with("�") {
216            let new_text = new_text[prefix_text.len()..].to_string();
217
218            self.prefix_offset = self.read_offset;
219            self.read_offset = self.all_token_ids.len();
220
221            Ok(Some(new_text))
222        } else {
223            Ok(None)
224        }
225    }
226}
227
/// Maintains state for an ongoing sequence of tokens and their decoded text.
///
/// Performs incremental detokenization: each appended token id may release a
/// chunk of newly-valid text (see `append_token_id`).
pub struct Sequence {
    /// Encodes text -> token_ids (and decodes back)
    tokenizer: Tokenizer,

    /// The current sequence of token ids
    token_ids: Vec<TokenIdType>,

    /// The position in the current sequence the last decoded token completed
    prefix_offset: usize,

    /// Current position in the sequence
    read_offset: usize,
}
242
243impl std::fmt::Debug for Sequence {
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        f.debug_struct("Sequence")
246            .field("tokenizer", &"Arc<dyn Tokenizer>")
247            .field(
248                "token_ids",
249                &format_args!("{}", {
250                    let token_ids = self.token_ids();
251                    if token_ids.len() <= 20 {
252                        format!("{:?}", token_ids)
253                    } else {
254                        let first_ten = &token_ids[..10];
255                        let last_ten = &token_ids[token_ids.len() - 10..];
256                        format!("{:?} ... {:?}", first_ten, last_ten)
257                    }
258                }),
259            )
260            .field("prefix_offset", &self.prefix_offset)
261            .field("read_offset", &self.read_offset)
262            .field("token count", &self.token_ids.len())
263            .finish()
264    }
265}
266
impl Sequence {
    /// Creates an empty sequence backed by `tokenizer`.
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            token_ids: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
        }
    }

    /// True if no token ids have been appended.
    pub fn is_empty(&self) -> bool {
        self.token_ids.is_empty()
    }

    /// Number of token ids in the sequence.
    pub fn len(&self) -> usize {
        self.token_ids.len()
    }

    /// Resets the sequence to its initial empty state.
    pub fn clear(&mut self) {
        self.token_ids.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
    }

    /// Tokenizes `input` and appends the resulting token ids to the sequence.
    ///
    /// Note: this does not advance `prefix_offset`/`read_offset`; the decode
    /// offsets only move in `append_token_id`.
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<String> {
        self.token_ids.push(token_id);

        // Decode the previous context window alone...
        let prefix_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..self.read_offset], false)?;

        // ...and again with every token appended since then; the new text is
        // the suffix that appeared.
        let new_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..], false)?;

        // if the end character of the previous returned sequence is a multi-byte character
        // then we can not split the text on that byte offset, so we roll back to the byte offset
        // of the start of that character
        let mut prefix_text_len = prefix_text.len();
        while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
            prefix_text_len -= 1;
        }
        // Freeze the rolled-back split point.
        let prefix_text_len = prefix_text_len;

        // NOTE(review): the growth check below uses `prefix_text.len()` while
        // the slice uses the rolled-back `prefix_text_len` — confirm the
        // asymmetry is intentional.
        if new_text.len() > prefix_text.len() {
            if new_text.ends_with("�") {
                // Trailing replacement char: the latest token leaves an
                // incomplete UTF-8 sequence; emit nothing until more arrive.
                return Ok("".to_string());
            } else {
                // shift and update the state
                let new_text = new_text[prefix_text_len..].to_string().replace("�", "");
                self.prefix_offset = self.read_offset;
                self.read_offset = self.token_ids.len();
                return Ok(new_text);
            }
        }

        Ok("".to_string())
    }

    /// Returns a handle to the underlying tokenizer (cheap: clones an `Arc`).
    pub fn tokenizer(&self) -> Tokenizer {
        self.tokenizer.clone()
    }

    /// The token ids accumulated so far.
    pub fn token_ids(&self) -> &[TokenIdType] {
        &self.token_ids
    }

    /// Decodes the entire sequence into text (special tokens are not skipped).
    pub fn text(&self) -> Result<String> {
        self.tokenizer.decode(&self.token_ids, false)
    }
}
355
/// The output conditions/values of a StopSequenceDecoder::append_token_id operation.
/// Result of decoding a token, indicating whether text was produced or a stop condition was met.
pub enum SequenceDecoderOutput {
    /// The text for the appended token_id
    Text(String),

    /// A sequence of token_ids has partially matched a stop sequence, so the text is held
    /// until either a match or a divergence
    Held,

    /// Indicates that a stop sequence has been matched and the decoder is stopped.
    /// Subsequent calls to append_token_id will return an error
    Stopped,

    /// Indicates that a stop token_id has been matched and the decoder is stopped.
    /// Subsequent calls to append_token_id will return an error.
    /// The text for the stop token_id is returned
    StoppedWithText(String),
}
375
/// A Sequence for decoding a stream of token ids into text and detecting stop sequences.
/// A stop sequence is either a matching token_id or a sequence of texts/strings which match.
/// Matches happen first at the token-level, then at the sequence-level. Hidden takes precedence
/// over visible. For example, if you put the same token_id in both `stop_token_ids_visible` and
/// `stop_token_ids_hidden`, the token_id will be treated as hidden.
#[derive(Debug)]
pub struct StopSequenceDecoder {
    // The current sequence of token ids
    sequence: Sequence,

    // Stop Tokens - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will be returned
    stop_token_ids_visible: Vec<TokenIdType>,

    // Stop Tokens - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will NOT be returned
    stop_token_ids_hidden: Vec<TokenIdType>,

    // Stop Words - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will be returned
    #[allow(dead_code)]
    stop_sequences_visible: Vec<String>,

    // Stop Words - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will NOT be returned
    stop_sequences_hidden: Vec<String>,

    // If the decoder has observed and returned a stop SequenceDecoderOutput,
    // further calls to append_token_id will return an error
    stopped: bool,

    // text jail - if a partial stop sequence is being observed, we hold/jail the text
    // until either the stop sequence is matched or the sequence is reset by a divergence
    state: String,
}
411
412impl StopSequenceDecoder {
413    /// Builder object for configurating a StopSequenceDecoder
414    pub fn builder(tokenizer: Tokenizer) -> StopSequenceDecoderBuilder {
415        StopSequenceDecoderBuilder::new(tokenizer)
416    }
417
418    /// Add a token_id to the sequence and return the SequenceDecoderOutput
419    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
420        if self.stopped {
421            return Err(Error::msg("Decoder is stopped"));
422        }
423
424        // update the sequence
425        let text = self.sequence.append_token_id(token_id)?;
426
427        // append the text to the state
428        self.state.push_str(text.as_str());
429
430        let mut stop: bool = false;
431        let mut visible: bool = false;
432
433        if self.stop_token_ids_visible.contains(&token_id) {
434            stop = true;
435            visible = true;
436        }
437
438        if self.stop_token_ids_hidden.contains(&token_id) {
439            stop = true;
440            visible = false;
441        }
442
443        if stop {
444            self.stopped = true;
445            let state = std::mem::take(&mut self.state);
446            if visible {
447                return Ok(SequenceDecoderOutput::StoppedWithText(state));
448            }
449            return Ok(SequenceDecoderOutput::Stopped);
450        }
451
452        // determine if state matches any of the stop sequences
453        for stop_sequence in self.stop_sequences_hidden.iter() {
454            if stop_sequence.starts_with(&self.state) {
455                if stop_sequence == &self.state {
456                    // on matched stop sequence, we do NOT return the jailed stop sequence
457                    self.stopped = true;
458                    return Ok(SequenceDecoderOutput::Stopped);
459                } else {
460                    return Ok(SequenceDecoderOutput::Held);
461                }
462            }
463        }
464
465        let state = std::mem::take(&mut self.state);
466        Ok(SequenceDecoderOutput::Text(state))
467    }
468
469    pub fn is_empty(&self) -> bool {
470        self.sequence.token_ids.is_empty()
471    }
472
473    pub fn len(&self) -> usize {
474        self.sequence.token_ids.len()
475    }
476
477    pub fn is_complete(&self) -> bool {
478        self.stopped
479    }
480
481    pub fn close(&mut self) {
482        self.stopped = true;
483    }
484}
485
/// Builder for `StopSequenceDecoder`: collects stop token ids and stop
/// strings (visible and hidden) before constructing the decoder.
pub struct StopSequenceDecoderBuilder {
    tokenizer: Tokenizer,
    stop_token_ids_visible: Vec<TokenIdType>,
    stop_token_ids_hidden: Vec<TokenIdType>,
    stop_sequences_visible: Vec<String>,
    stop_sequences_hidden: Vec<String>,
}
493
494impl StopSequenceDecoderBuilder {
495    pub fn new(tokenizer: Tokenizer) -> Self {
496        Self {
497            tokenizer,
498            stop_token_ids_visible: Vec::new(),
499            stop_token_ids_hidden: Vec::new(),
500            stop_sequences_visible: Vec::new(),
501            stop_sequences_hidden: Vec::new(),
502        }
503    }
504
505    /// Adds a visible stop token id to the StopSequenceDecoder
506    pub fn add_stop_token_id_visible(mut self, token_id: TokenIdType) -> Self {
507        self.stop_token_ids_visible.push(token_id);
508        self
509    }
510
511    /// Adds a list of visible stop token ids to the StopSequenceDecoder
512    /// Each token_id is added as for an individual match
513    pub fn add_stop_token_ids_visible(mut self, token_ids: &[TokenIdType]) -> Self {
514        self.stop_token_ids_visible.extend(token_ids);
515        self
516    }
517
518    /// Adds a hidden stop token id to the StopSequenceDecoder
519    pub fn add_stop_token_id_hidden(mut self, token_id: TokenIdType) -> Self {
520        self.stop_token_ids_hidden.push(token_id);
521        self
522    }
523
524    /// Adds a list of hidden stop token ids to the StopSequenceDecoder
525    /// Each token_id is added as for an individual match
526    pub fn add_stop_token_ids_hidden(mut self, token_ids: &[TokenIdType]) -> Self {
527        self.stop_token_ids_hidden.extend(token_ids);
528        self
529    }
530
531    pub fn add_stop_sequence_visible(mut self, text: &str) -> Self {
532        self.stop_sequences_visible.push(text.to_string());
533        self
534    }
535
536    pub fn add_stop_sequences_visible(mut self, strings: &[&str]) -> Self {
537        self.stop_sequences_visible
538            .extend(strings.iter().map(|text| text.to_string()));
539        self
540    }
541
542    pub fn add_stop_sequence_hidden(mut self, text: &str) -> Self {
543        self.stop_sequences_hidden.push(text.to_string());
544        self
545    }
546
547    pub fn add_stop_sequences_hidden(mut self, strings: &[&str]) -> Self {
548        self.stop_sequences_hidden
549            .extend(strings.iter().map(|text| text.to_string()));
550        self
551    }
552
553    pub fn build(self) -> Result<StopSequenceDecoder> {
554        Ok(StopSequenceDecoder {
555            sequence: Sequence::new(self.tokenizer.clone()),
556            stop_token_ids_visible: self.stop_token_ids_visible,
557            stop_token_ids_hidden: self.stop_token_ids_hidden,
558            stop_sequences_visible: self.stop_sequences_visible,
559            stop_sequences_hidden: self.stop_sequences_hidden,
560            stopped: false,
561            state: String::new(),
562        })
563    }
564}