dynamo_llm/tokenizers/hf.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use tokenizers::tokenizer::Tokenizer as HfTokenizer;

use super::{
    traits::{Decoder, Encoder, Tokenizer},
    Encoding, Error, Result, TokenIdType,
};

/// Adapter that exposes a Hugging Face `tokenizers::Tokenizer` through the
/// crate-local `Encoder`, `Decoder`, and `Tokenizer` traits.
pub struct HuggingFaceTokenizer {
    tokenizer: HfTokenizer,
}

impl HuggingFaceTokenizer {
    /// Loads a tokenizer from a serialized tokenizer file on disk (for
    /// example, a Hugging Face `tokenizer.json`).
    pub fn from_file(model_name: &str) -> Result<Self> {
        let tokenizer = HfTokenizer::from_file(model_name)
            .map_err(|err| Error::msg(format!("Error loading tokenizer: {}", err)))?;

        Ok(HuggingFaceTokenizer { tokenizer })
    }

    /// Wraps an already constructed Hugging Face tokenizer.
    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
        HuggingFaceTokenizer { tokenizer }
    }
}

impl Encoder for HuggingFaceTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        // `false`: do not add special tokens (e.g. BOS/EOS) while encoding.
        let encoding = self
            .tokenizer
            .encode(input, false)
            .map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?;

        // Copy the token ids, token strings, and offsets into the
        // crate-local `Encoding` type.
        let token_ids = encoding.get_ids().to_vec();
        let tokens = encoding.get_tokens().to_vec();
        let spans = encoding.get_offsets().to_vec();

        Ok(Encoding {
            token_ids,
            tokens,
            spans,
        })
    }
}

impl Decoder for HuggingFaceTokenizer {
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
        let text = self
            .tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(|err| Error::msg(format!("Error decoding input: {}", err)))?;

        Ok(text)
    }
}

impl Tokenizer for HuggingFaceTokenizer {}

/// Allows an `HfTokenizer` to be converted with `.into()`, mirroring
/// `from_tokenizer`.
impl From<HfTokenizer> for HuggingFaceTokenizer {
    fn from(tokenizer: HfTokenizer) -> Self {
        HuggingFaceTokenizer { tokenizer }
    }
}
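
// A minimal usage sketch, not part of the original module: it shows how the
// wrapper is expected to round-trip text through `encode` and `decode`. The
// `tests/data/tokenizer.json` path is a hypothetical placeholder, so the test
// is marked `#[ignore]` and only runs when such a file is supplied.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "requires a local tokenizer.json file"]
    fn encode_decode_roundtrip() -> Result<()> {
        // Hypothetical path; point at any serialized Hugging Face tokenizer.
        let tokenizer = HuggingFaceTokenizer::from_file("tests/data/tokenizer.json")?;

        let encoding = tokenizer.encode("Hello, world!")?;
        // One token string and one offset span per token id.
        assert_eq!(encoding.token_ids.len(), encoding.tokens.len());
        assert_eq!(encoding.token_ids.len(), encoding.spans.len());

        // Skipping special tokens on decode should yield non-empty text for a
        // non-empty input.
        let text = tokenizer.decode(&encoding.token_ids, true)?;
        assert!(!text.is_empty());

        Ok(())
    }
}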