dynamo_llm/tokenizers.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub mod hf;

#[cfg(feature = "sentencepiece")]
pub mod sp;

// TODO: Add tokenizer benchmarks
// TODO: Enable README.md as a module doc
// #[doc = include_str!("../README.md")]

use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::{ops::Deref, path::Path};

use crate::protocols::TokenIdType;
pub use anyhow::{Error, Result};

pub use hf::HuggingFaceTokenizer;

#[cfg(feature = "sentencepiece")]
pub use sp::SentencePieceTokenizer;

/// Represents the type of tokenizer being used
#[derive(Debug)]
pub enum TokenizerType {
    HuggingFace(String),
    #[cfg(feature = "sentencepiece")]
    SentencePiece(String),
}

/// Character offsets in the original text
pub type Offsets = (usize, usize);

/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
#[derive(Debug, Hash)]
pub struct Encoding {
    pub token_ids: Vec<TokenIdType>,
    pub tokens: Vec<String>,
    pub spans: Vec<Offsets>,
}

pub mod traits {
    use super::*;

    pub trait Encoder: Send + Sync {
        fn encode(&self, input: &str) -> Result<Encoding>;
    }

    pub trait Decoder: Send + Sync {
        fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
    }

    pub trait Tokenizer: Encoder + Decoder {
        // fn get_vocab_size(&self) -> usize;
        // fn make_unique_clone(&self) -> Box<dyn Tokenizer>;
    }
}

impl Encoding {
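    /// Hashes the full `Encoding` (token_ids, tokens, and spans) with the standard
    /// library's `DefaultHasher`. The value is only meaningful within a single
    /// process, since `DefaultHasher` output is not guaranteed to be stable across
    /// Rust releases.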
    pub fn get_hash(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        self.hash(&mut hasher);
        hasher.finish()
    }
}

/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);

impl Tokenizer {
    pub fn from_file(file_path: &str) -> Result<Tokenizer> {
        Ok(Tokenizer(create_tokenizer_from_file(file_path)?))
    }

    /// Create a stateful sequence object for decoding token_ids into text
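    ///
    /// A minimal usage sketch (not compiled as a doc-test; `tokenizer.json` is a
    /// hypothetical path and the token ids are placeholders):
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    /// let mut stream = tokenizer.decode_stream(false);
    /// for &id in &[1u32, 2, 3] {
    ///     if let Some(chunk) = stream.step(id)? {
    ///         print!("{chunk}");
    ///     }
    /// }
    /// ```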
    pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream {
        DecodeStream::new(self.0.clone(), skip_special_tokens)
    }
}

impl Deref for Tokenizer {
    type Target = Arc<dyn traits::Tokenizer>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
    fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Tokenizer(tokenizer)
    }
}

impl<T> From<Arc<T>> for Tokenizer
where
    T: traits::Tokenizer + 'static, // 'static is required to ensure T can be safely put into an Arc
{
    fn from(tokenizer: Arc<T>) -> Self {
        Tokenizer(tokenizer)
    }
}
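
// Sketch of how a custom tokenizer could plug into this wrapper through the
// blanket `From<Arc<T>>` impl above. `ByteTokenizer` is hypothetical and exists
// only to illustrate the trait surface; it treats each byte as a token id:
//
//   struct ByteTokenizer;
//
//   impl traits::Encoder for ByteTokenizer {
//       fn encode(&self, input: &str) -> Result<Encoding> {
//           Ok(Encoding {
//               token_ids: input.bytes().map(|b| b as TokenIdType).collect(),
//               tokens: input.chars().map(String::from).collect(),
//               spans: Vec::new(),
//           })
//       }
//   }
//
//   impl traits::Decoder for ByteTokenizer {
//       fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
//           let bytes: Vec<u8> = token_ids.iter().map(|&id| id as u8).collect();
//           Ok(String::from_utf8_lossy(&bytes).into_owned())
//       }
//   }
//
//   impl traits::Tokenizer for ByteTokenizer {}
//
//   let tokenizer: Tokenizer = Arc::new(ByteTokenizer).into();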

/// Create a tokenizer from a file path to a tokenizer file.
/// The file extension is used to determine the tokenizer type.
/// Supported file types are:
/// - json: HuggingFace tokenizer
/// - model: SentencePiece tokenizer
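///
/// A small sketch of expected usage (not compiled as a doc-test; the paths are
/// hypothetical):
///
/// ```ignore
/// // HuggingFace tokenizer from a JSON file
/// let hf = create_tokenizer_from_file("models/llama/tokenizer.json")?;
/// // SentencePiece tokenizer (requires the "sentencepiece" feature)
/// let sp = create_tokenizer_from_file("models/llama/tokenizer.model")?;
/// ```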
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    let path = Path::new(file_path);
    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .ok_or_else(|| Error::msg("Failed to read file extension".to_string()))?;

    match extension {
        "json" => {
            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
            Ok(Arc::new(tokenizer))
        }
        "model" => {
            #[cfg(feature = "sentencepiece")]
            {
                let tokenizer = SentencePieceTokenizer::from_file(file_path)?;
                Ok(Arc::new(tokenizer))
            }
            #[cfg(not(feature = "sentencepiece"))]
            {
                Err(Error::msg(
                    "SentencePiece tokenizer not supported".to_string(),
                ))
            }
        }
        _ => Err(Error::msg("Unsupported file type".to_string())),
    }
}

/// DecodeStream keeps the state necessary to produce individual chunks of
/// text from an input stream of token_ids.
///
/// This is necessary because, in general, a token_id cannot be decoded in
/// isolation: the resulting text depends on the surrounding ids, which are
/// typically needed to strip extra spaces correctly.
pub struct DecodeStream {
    /// The tokenizer used to decode token_ids
    tokenizer: Arc<dyn traits::Tokenizer>,

    skip_special_tokens: bool,
    /// A temporary buffer of the token_ids needed
    /// to produce valid string chunks.
    /// This typically contains 3 parts:
    ///  - read
    ///  - prefix
    ///  - rest
    ///
    /// Read is the part necessary to surround the prefix
    /// so that decoding the whole buffer produces a valid prefix.
    /// Prefix is the previously produced string, kept around to trim it off of
    /// the next valid chunk.
    ids: Vec<u32>,

    /// The previously returned chunk that needs to be discarded from the
    /// decoding of the current ids to produce the next chunk
    prefix: String,

    /// The index within the ids corresponding to the prefix so we can drain
    /// correctly
    prefix_index: usize,

    /// We need to keep two prefixes.
    /// `prefix` is the second one, which was already emitted, and is used to
    /// discard its part of the text when decoding all of the ids.
    /// `read_index` marks the prefix kept only for the starting side effects of the prefix.
    read_index: usize,
}

impl DecodeStream {
    pub fn new(tokenizer: Arc<dyn traits::Tokenizer>, skip_special_tokens: bool) -> Self {
        Self {
            tokenizer,
            skip_special_tokens,
            ids: Vec::new(),
            prefix: "".to_string(),
            prefix_index: 0,
            read_index: 0,
        }
    }

    /// Step appends a token_id to the internal state and tries to produce a text chunk.
    ///
    /// The method only fails if the internal state is corrupted.
    ///
    /// Returning `None` means the given id is not enough to produce a chunk.
    /// This typically happens with `byte_fallback` options where some tokens do not
    /// represent valid UTF-8, and only follow-up token_ids will help produce
    /// a valid chunk.
    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
        self.ids.push(id);
        let string = self
            .tokenizer
            .decode(self.ids.as_slice(), self.skip_special_tokens)?;

        if string.len() > self.prefix.len() && !string.ends_with('�') {
            if !(string.starts_with(&self.prefix)) {
                anyhow::bail!("Detokenizer failure: invalid prefix");
            }
            let new_text = &string[self.prefix.len()..].to_string();
            let new_prefix_index = self.ids.len() - self.prefix_index;
            self.prefix = self
                .tokenizer
                .decode(self.ids.as_slice(), self.skip_special_tokens)?;
            self.read_index = self.prefix_index;
            self.prefix_index = new_prefix_index;
            Ok(Some(new_text.to_string()))
        } else {
            Ok(None)
        }
    }
}

/// Maintains state for an ongoing sequence of tokens and their decoded text
pub struct Sequence {
    /// Encodes text -> token_ids
    tokenizer: Tokenizer,

    /// The current sequence of token ids
    token_ids: Vec<TokenIdType>,

    /// The position in the current sequence at which the last decoded token completed
    prefix_offset: usize,

    /// Current position in the sequence
    read_offset: usize,
}

impl std::fmt::Debug for Sequence {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sequence")
            .field("tokenizer", &"Arc<dyn Tokenizer>")
            .field(
                "token_ids",
                &format_args!("{}", {
                    if self.token_ids.len() <= 20 {
                        format!("{:?}", self.token_ids)
                    } else {
                        let first_ten = &self.token_ids[..10];
                        let last_ten = &self.token_ids[self.token_ids.len() - 10..];
                        format!("{:?} ... {:?}", first_ten, last_ten)
                    }
                }),
            )
            .field("prefix_offset", &self.prefix_offset)
            .field("read_offset", &self.read_offset)
            .field("token count", &self.token_ids.len())
            .finish()
    }
}

impl Sequence {
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            token_ids: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.token_ids.is_empty()
    }

    pub fn len(&self) -> usize {
        self.token_ids.len()
    }

    pub fn clear(&mut self) {
        self.token_ids.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
    }

    pub fn append_text(&mut self, input: &str) -> Result<()> {
        // let tokenizer = self.tokenizer.read().map_err(|err| {
        //     Error::msg(format!("Failed to acquire read lock on tokenizer: {}", err))
        // })?;

        let encoding = self.tokenizer.encode(input)?;
        self.token_ids.extend(encoding.token_ids);
        Ok(())
    }

    // Based on
    // https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
    // under Apache 2.0 license
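    //
    // A hedged usage sketch for the incremental decode below (the tokenizer file
    // path and `generated_token_ids` are hypothetical placeholders):
    //
    //   let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    //   let mut seq = Sequence::new(tokenizer);
    //   for &id in &generated_token_ids {
    //       // Returns "" until enough ids have arrived to emit complete characters.
    //       print!("{}", seq.append_token_id(id)?);
    //   }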
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<String> {
        self.token_ids.push(token_id);
        // log::trace!("pushed token_id: {}", token_id);

        let prefix_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..self.read_offset], false)?;

        let new_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..], false)?;

        // if the end character of the previously returned sequence is a multi-byte character
        // then we cannot split the text on that byte offset, so we roll back to the byte offset
        // of the start of that character
        let mut prefix_text_len = prefix_text.len();
        while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
            prefix_text_len -= 1;
        }
        let prefix_text_len = prefix_text_len;

        if new_text.len() > prefix_text.len() {
            if new_text.ends_with("�") {
                return Ok("".to_string());
            } else {
                // shift and update the state
                let new_text = new_text[prefix_text_len..].to_string().replace("�", "");
                self.prefix_offset = self.read_offset;
                self.read_offset = self.token_ids.len();
                return Ok(new_text);
            }
        }

        Ok("".to_string())
    }

    pub fn tokenizer(&self) -> Tokenizer {
        self.tokenizer.clone()
    }

    pub fn token_ids(&self) -> &[TokenIdType] {
        &self.token_ids
    }

    pub fn text(&self) -> Result<String> {
        // let tokenizer = self.tokenizer.read().map_err(|err| {
        //     Error::msg(format!("Failed to acquire read lock on tokenizer: {}", err))
        // })?;
        self.tokenizer.decode(&self.token_ids, false)
    }
}

/// The output of a `StopSequenceDecoder::append_token_id` operation.
/// Indicates whether text was produced or a stop condition was met.
pub enum SequenceDecoderOutput {
    /// The text for the appended token_id
    Text(String),

    /// The sequence of token_ids has partially matched a stop sequence, so the text is held
    /// until either a full match or a divergence
    Held,

    /// Indicates that a stop sequence has been matched and the decoder is stopped.
    /// Subsequent calls to append_token_id will return an error
    Stopped,

    /// Indicates that a stop token_id has been matched and the decoder is stopped.
    /// Subsequent calls to append_token_id will return an error.
    /// The text for the stop token_id is returned
    StoppedWithText(String),
}

/// A Sequence for decoding a stream of token ids into text and detecting stop sequences.
/// A stop sequence is either a matching token_id or a sequence of texts/strings which match.
/// Matches happen first at the token-level, then at the sequence-level. Hidden takes precedence
/// over visible. For example, if you put the same token_id in both `stop_token_ids_visible` and
/// `stop_token_ids_hidden`, the token_id will be treated as hidden.
#[derive(Debug)]
pub struct StopSequenceDecoder {
    // The current sequence of token ids
    sequence: Sequence,

    // Stop Tokens - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will be returned
    stop_token_ids_visible: Vec<TokenIdType>,

    // Stop Tokens - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will NOT be returned
    stop_token_ids_hidden: Vec<TokenIdType>,

    // Stop Words - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will be returned
    #[allow(dead_code)]
    stop_sequences_visible: Vec<String>,

    // Stop Words - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will NOT be returned
    stop_sequences_hidden: Vec<String>,

    // If the decoder has observed and returned a stop SequenceDecoderOutput,
    // further calls to append_token_id will return an error
    stopped: bool,

    // text jail - if a partial stop sequence is being observed, we hold/jail the text
    // until either the stop sequence is matched or the sequence is reset by a divergence
    state: String,
}

impl StopSequenceDecoder {
    /// Returns a builder for configuring a StopSequenceDecoder
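    ///
    /// A hedged configuration sketch (the stop token id and stop string are
    /// placeholders, not values defined by this crate):
    ///
    /// ```ignore
    /// let decoder = StopSequenceDecoder::builder(tokenizer)
    ///     .add_stop_token_id_hidden(2)          // e.g. the model's EOS id
    ///     .add_stop_sequences_hidden(&["\n\n"])
    ///     .build()?;
    /// ```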
    pub fn builder(tokenizer: Tokenizer) -> StopSequenceDecoderBuilder {
        StopSequenceDecoderBuilder::new(tokenizer)
    }

    /// Add a token_id to the sequence and return the SequenceDecoderOutput
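    ///
    /// A sketch of how a caller might consume the result (not taken from this
    /// crate's own call sites):
    ///
    /// ```ignore
    /// match decoder.append_token_id(token_id)? {
    ///     SequenceDecoderOutput::Text(text) => print!("{text}"),
    ///     SequenceDecoderOutput::Held => {} // partial stop-sequence match; keep feeding tokens
    ///     SequenceDecoderOutput::Stopped => { /* stop generation */ }
    ///     SequenceDecoderOutput::StoppedWithText(text) => print!("{text}"),
    /// }
    /// ```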
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Err(Error::msg("Decoder is stopped"));
        }

        // update the sequence
        let text = self.sequence.append_token_id(token_id)?;

        // append the text to the state
        self.state.push_str(text.as_str());

        let mut stop: bool = false;
        let mut visible: bool = false;

        if self.stop_token_ids_visible.contains(&token_id) {
            stop = true;
            visible = true;
        }

        if self.stop_token_ids_hidden.contains(&token_id) {
            stop = true;
            visible = false;
        }

        if stop {
            self.stopped = true;
            let state = std::mem::take(&mut self.state);
            if visible {
                return Ok(SequenceDecoderOutput::StoppedWithText(state));
            }
            return Ok(SequenceDecoderOutput::Stopped);
        }

        // determine if the accumulated state is a prefix of any hidden stop sequence
        for stop_sequence in self.stop_sequences_hidden.iter() {
            if stop_sequence.starts_with(&self.state) {
                if stop_sequence == &self.state {
                    // on a matched stop sequence, we do NOT return the jailed text
                    self.stopped = true;
                    return Ok(SequenceDecoderOutput::Stopped);
                } else {
                    return Ok(SequenceDecoderOutput::Held);
                }
            }
        }

        let state = std::mem::take(&mut self.state);
        Ok(SequenceDecoderOutput::Text(state))
    }

    pub fn is_empty(&self) -> bool {
        self.sequence.token_ids.is_empty()
    }

    pub fn len(&self) -> usize {
        self.sequence.token_ids.len()
    }

    pub fn is_complete(&self) -> bool {
        self.stopped
    }

    pub fn close(&mut self) {
        self.stopped = true;
    }
}

pub struct StopSequenceDecoderBuilder {
    tokenizer: Tokenizer,
    stop_token_ids_visible: Vec<TokenIdType>,
    stop_token_ids_hidden: Vec<TokenIdType>,
    stop_sequences_visible: Vec<String>,
    stop_sequences_hidden: Vec<String>,
}

impl StopSequenceDecoderBuilder {
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            stop_token_ids_visible: Vec::new(),
            stop_token_ids_hidden: Vec::new(),
            stop_sequences_visible: Vec::new(),
            stop_sequences_hidden: Vec::new(),
        }
    }

    /// Adds a visible stop token id to the StopSequenceDecoder
    pub fn add_stop_token_id_visible(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_visible.push(token_id);
        self
    }

    /// Adds a list of visible stop token ids to the StopSequenceDecoder.
    /// Each token_id is added as an individual match
    pub fn add_stop_token_ids_visible(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_visible.extend(token_ids);
        self
    }

    /// Adds a hidden stop token id to the StopSequenceDecoder
    pub fn add_stop_token_id_hidden(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_hidden.push(token_id);
        self
    }

    /// Adds a list of hidden stop token ids to the StopSequenceDecoder.
    /// Each token_id is added as an individual match
    pub fn add_stop_token_ids_hidden(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_hidden.extend(token_ids);
        self
    }

    pub fn add_stop_sequence_visible(mut self, text: &str) -> Self {
        self.stop_sequences_visible.push(text.to_string());
        self
    }

    pub fn add_stop_sequences_visible(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_visible
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    pub fn add_stop_sequence_hidden(mut self, text: &str) -> Self {
        self.stop_sequences_hidden.push(text.to_string());
        self
    }

    pub fn add_stop_sequences_hidden(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_hidden
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    pub fn build(self) -> Result<StopSequenceDecoder> {
        Ok(StopSequenceDecoder {
            sequence: Sequence::new(self.tokenizer.clone()),
            stop_token_ids_visible: self.stop_token_ids_visible,
            stop_token_ids_hidden: self.stop_token_ids_hidden,
            stop_sequences_visible: self.stop_sequences_visible,
            stop_sequences_hidden: self.stop_sequences_hidden,
            stopped: false,
            state: String::new(),
        })
    }
}