//! # kalosm-sample
//! This is a sampling library for Kalosm.
//!
//! It handles choosing a token from a probability distribution. Samplers can be used to constrain text generation: for example, a sampler can prevent the model from generating the same word twice in a row, or restrict the model to generating only single-digit numbers.
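//!
//! ## Example
//!
//! A minimal sketch of encoding and decoding through the type-erased
//! [`DynTokenizer`] (the `tokenizer.json` path is hypothetical; point it at
//! any Hugging Face tokenizer file):
//!
//! ```rust,no_run
//! use kalosm_sample::{DynTokenizer, Tokenizer};
//!
//! let hf = tokenizers::Tokenizer::from_file("tokenizer.json").unwrap();
//! let tokenizer = DynTokenizer::new(hf);
//!
//! let ids = tokenizer.encode("Hello, world!", true).unwrap();
//! let text = tokenizer.decode(&ids).unwrap();
//! println!("{:?} -> {}", ids, text);
//! ```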

#![warn(missing_docs)]

use std::borrow::Cow;
use std::ops::Deref;
use std::sync::Arc;
use tokenizers::Decoder;
use tokenizers::DecoderWrapper;
use tokenizers::Model;
use tokenizers::ModelWrapper;
use tokenizers::Normalizer;
use tokenizers::NormalizerWrapper;
use tokenizers::PostProcessor;
use tokenizers::PostProcessorWrapper;
use tokenizers::PreTokenizer;
use tokenizers::PreTokenizerWrapper;
use tokenizers::TokenizerImpl;

mod structured_parser;
pub use structured_parser::*;
#[cfg(feature = "llamacpp")]
mod llm;

/// A type-erased wrapper around any [`Tokenizer`] implementation.
pub struct DynTokenizer {
    tokenizer: Arc<dyn Tokenizer + Send + Sync>,
}

impl<M, N, PT, PP, D> From<tokenizers::tokenizer::TokenizerImpl<M, N, PT, PP, D>> for DynTokenizer
where
    M: Model + Send + Sync + 'static,
    N: Normalizer + Send + Sync + 'static,
    PT: PreTokenizer + Send + Sync + 'static,
    PP: PostProcessor + Send + Sync + 'static,
    D: Decoder + Send + Sync + 'static,
{
    fn from(tokenizer: tokenizers::tokenizer::TokenizerImpl<M, N, PT, PP, D>) -> Self {
        Self::new(tokenizer)
    }
}

impl From<tokenizers::Tokenizer> for DynTokenizer {
    fn from(tokenizer: tokenizers::Tokenizer) -> Self {
        Self::new(tokenizer)
    }
}

impl From<Arc<dyn Tokenizer + Send + Sync>> for DynTokenizer {
    fn from(tokenizer: Arc<dyn Tokenizer + Send + Sync>) -> Self {
        Self { tokenizer }
    }
}

impl DynTokenizer {
    /// Create a new `DynTokenizer` from any type that implements [`Tokenizer`].
    pub fn new<T: Tokenizer + Send + Sync + 'static>(tokenizer: T) -> Self {
        Self {
            tokenizer: Arc::new(tokenizer),
        }
    }
}

impl Tokenizer for DynTokenizer {
    fn encode(&self, text: &str, special_tokens: bool) -> anyhow::Result<Vec<u32>> {
        self.tokenizer.encode(text, special_tokens)
    }

    fn decode(&self, ids: &[u32]) -> anyhow::Result<Cow<'_, str>> {
        self.tokenizer.decode(ids)
    }

    fn get_all_tokens(&self) -> anyhow::Result<Cow<'_, [u32]>> {
        self.tokenizer.get_all_tokens()
    }
}

/// A tokenizer can encode a string into a list of token ids and decode a list of token ids back into a string.
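///
/// # Example
///
/// A minimal sketch of a custom implementation (`ByteTokenizer` is a
/// hypothetical toy type that treats each byte of the input as its own token):
///
/// ```rust
/// use std::borrow::Cow;
/// use kalosm_sample::Tokenizer;
///
/// struct ByteTokenizer;
///
/// impl Tokenizer for ByteTokenizer {
///     fn encode(&self, text: &str, _add_special_tokens: bool) -> anyhow::Result<Vec<u32>> {
///         Ok(text.bytes().map(u32::from).collect())
///     }
///
///     fn decode(&self, ids: &[u32]) -> anyhow::Result<Cow<'_, str>> {
///         let bytes: Vec<u8> = ids.iter().map(|&id| id as u8).collect();
///         Ok(String::from_utf8(bytes)?.into())
///     }
///
///     fn get_all_tokens(&self) -> anyhow::Result<Cow<'_, [u32]>> {
///         Ok((0u32..=255).collect::<Vec<u32>>().into())
///     }
/// }
/// ```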
pub trait Tokenizer {
    /// Encode a string into a list of token ids.
    fn encode(&self, text: &str, add_special_tokens: bool) -> anyhow::Result<Vec<u32>>;

    /// Encode a batch of strings into lists of token ids.
    fn encode_batch(
        &self,
        text: &[&str],
        add_special_tokens: bool,
    ) -> anyhow::Result<Vec<Vec<u32>>> {
        text.iter()
            .map(|text| self.encode(text, add_special_tokens))
            .collect()
    }

    /// Decode a list of token ids into a string.
    fn decode(&self, ids: &[u32]) -> anyhow::Result<Cow<'_, str>>;

    /// Decode a batch of token id lists into strings.
    fn decode_batch(&self, ids: &[&[u32]]) -> anyhow::Result<Vec<Cow<'_, str>>> {
        ids.iter().map(|id| self.decode(id)).collect()
    }

    /// Get the ids of every token in the vocabulary.
    fn get_all_tokens(&self) -> anyhow::Result<Cow<'_, [u32]>>;
}

impl<M, N, PT, PP, D> Tokenizer for tokenizers::tokenizer::TokenizerImpl<M, N, PT, PP, D>
where
    M: Model,
    N: Normalizer,
    PT: PreTokenizer,
    PP: PostProcessor,
    D: Decoder,
{
    fn encode(&self, text: &str, special_tokens: bool) -> anyhow::Result<Vec<u32>> {
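        // `TokenizerImpl` has an inherent `encode` that method resolution
        // prefers over this trait method, so this call does not recurse.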
        Ok(self
            .encode(text, special_tokens)
            .map_err(|e| anyhow::anyhow!(e))?
            .get_ids()
            .to_vec())
    }

    fn decode(&self, ids: &[u32]) -> anyhow::Result<Cow<'_, str>> {
        self.decode(ids, false)
            .map(|s| s.into())
            .map_err(|e| anyhow::anyhow!(e))
    }

    fn get_all_tokens(&self) -> anyhow::Result<Cow<'_, [u32]>> {
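        // This rebuilds the vocabulary list on every call; see
        // `FasterHuggingFaceTokenizer` for a version that caches it.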
        Ok(self
            .get_vocab(true)
            .into_values()
            .collect::<Vec<_>>()
            .into())
    }
}

/// A tokenizer that wraps a Hugging Face tokenizer and caches the full list of
/// token ids, so [`Tokenizer::get_all_tokens`] can return a borrowed slice
/// instead of rebuilding the vocabulary on every call.
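///
/// # Example
///
/// A minimal sketch (the `tokenizer.json` path is hypothetical):
///
/// ```rust,no_run
/// use kalosm_sample::{FasterHuggingFaceTokenizer, Tokenizer};
///
/// let hf = tokenizers::Tokenizer::from_file("tokenizer.json").unwrap();
/// let tokenizer = FasterHuggingFaceTokenizer::new(hf);
///
/// // Served from the list cached in `new`; no per-call vocabulary walk.
/// let all_tokens = tokenizer.get_all_tokens().unwrap();
/// println!("{} tokens in the vocabulary", all_tokens.len());
/// ```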
pub struct FasterHuggingFaceTokenizer {
    inner: tokenizers::Tokenizer,
    all_tokens: Vec<u32>,
}

impl FasterHuggingFaceTokenizer {
    /// Create a new `FasterHuggingFaceTokenizer` from a `tokenizers::Tokenizer`.
    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            all_tokens: tokenizer.get_vocab(true).into_values().collect(),
            inner: tokenizer,
        }
    }

    /// Get the inner tokenizer.
    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
        &self.inner
    }

    /// Get the inner tokenizer mutably.
    pub fn tokenizer_mut(&mut self) -> &mut tokenizers::Tokenizer {
        &mut self.inner
    }

    /// Consume the `FasterHuggingFaceTokenizer` and return the inner tokenizer.
    pub fn into_tokenizer(self) -> tokenizers::Tokenizer {
        self.inner
    }
}

impl Tokenizer for FasterHuggingFaceTokenizer {
    fn encode(&self, text: &str, special_tokens: bool) -> anyhow::Result<Vec<u32>> {
        self.inner.encode(text, special_tokens)
    }

    fn decode(&self, ids: &[u32]) -> anyhow::Result<Cow<'_, str>> {
        self.inner.decode(ids)
    }

    fn decode_batch(&self, ids: &[&[u32]]) -> anyhow::Result<Vec<Cow<'_, str>>> {
        self.inner.decode_batch(ids)
    }

    fn get_all_tokens(&self) -> anyhow::Result<Cow<'_, [u32]>> {
        Ok((&self.all_tokens).into())
    }
}

impl Tokenizer for tokenizers::Tokenizer {
    fn encode(&self, text: &str, special_tokens: bool) -> anyhow::Result<Vec<u32>> {
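        // Deref to the inner `TokenizerImpl` so we call its inherent `encode`
        // rather than recursing into this trait method.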
        let deref = self.deref();
        Ok(deref
            .encode(text, special_tokens)
            .map_err(|e| anyhow::anyhow!(e))?
            .get_ids()
            .to_vec())
    }

    fn decode(&self, ids: &[u32]) -> anyhow::Result<Cow<'_, str>> {
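        // Coerce to the concrete `TokenizerImpl` so we call its inherent
        // `decode` rather than this trait method.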
        let as_impl: &TokenizerImpl<
            ModelWrapper,
            NormalizerWrapper,
            PreTokenizerWrapper,
            PostProcessorWrapper,
            DecoderWrapper,
        > = self;
        Ok(as_impl
            .decode(ids, false)
            .map_err(|e| anyhow::anyhow!(e))?
            .into())
    }

    fn get_all_tokens(&self) -> anyhow::Result<Cow<'_, [u32]>> {
        let as_impl: &TokenizerImpl<
            ModelWrapper,
            NormalizerWrapper,
            PreTokenizerWrapper,
            PostProcessorWrapper,
            DecoderWrapper,
        > = self;
        Ok(as_impl
            .get_vocab(true)
            .into_values()
            .collect::<Vec<_>>()
            .into())
    }
}