pub mod hf;

use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::{ops::Deref, path::Path};

use crate::protocols::TokenIdType;
pub use anyhow::{Error, Result};

pub use hf::HuggingFaceTokenizer;

/// Identifies the backend a tokenizer was loaded from.
#[derive(Debug)]
pub enum TokenizerType {
    /// A HuggingFace tokenizer, identified by the path to its `tokenizer.json`.
    HuggingFace(String),
}

/// Byte offsets `(start, end)` of a token within the original input text.
pub type Offsets = (usize, usize);

/// A tokenized input, wrapping the backend-specific representation.
#[derive(Debug, Clone)]
pub enum Encoding {
    /// A full HuggingFace `tokenizers` encoding (ids, offsets, masks, ...).
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// A bare list of token ids (SentencePiece-style).
    Sp(Vec<TokenIdType>),
}

impl Encoding {
    /// The token ids of this encoding, regardless of backend.
    pub fn token_ids(&self) -> &[u32] {
        match self {
            Encoding::Hf(inner) => inner.get_ids(),
            Encoding::Sp(inner) => inner,
        }
    }
}

// Hash only the token ids, so encodings with identical ids hash identically
// no matter which backend produced them.
impl Hash for Encoding {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.token_ids().hash(state);
    }
}

pub mod traits {
    use super::*;

    /// Converts text into token ids.
    pub trait Encoder: Send + Sync {
        fn encode(&self, input: &str) -> Result<Encoding>;
        fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
    }

    /// Converts token ids back into text.
    pub trait Decoder: Send + Sync {
        fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
    }

    /// A complete tokenizer: both an [`Encoder`] and a [`Decoder`].
    pub trait Tokenizer: Encoder + Decoder {}
}
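
// A minimal sketch of implementing the traits above: a toy tokenizer that
// treats each input byte as one token id. `ByteTokenizer` is illustrative
// only and not part of the real API; it assumes `TokenIdType` is wide enough
// to hold a byte (it is `u32` here).
#[allow(dead_code)]
struct ByteTokenizer;

impl traits::Encoder for ByteTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        // Each byte widens losslessly into a token id.
        Ok(Encoding::Sp(input.bytes().map(|b| b as TokenIdType).collect()))
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        inputs.iter().map(|input| self.encode(input)).collect()
    }
}

impl traits::Decoder for ByteTokenizer {
    fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
        // Narrow each id back to a byte and re-validate as UTF-8.
        let bytes: Vec<u8> = token_ids.iter().map(|&id| id as u8).collect();
        Ok(String::from_utf8(bytes)?)
    }
}

impl traits::Tokenizer for ByteTokenizer {}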

impl Encoding {
    /// A 64-bit hash of the token ids, usable as a lightweight cache key.
    pub fn get_hash(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        self.hash(&mut hasher);
        hasher.finish()
    }
}
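
// Sketch: because the `Hash` impl covers only token ids, the hash is
// independent of the producing backend. Hypothetical helper for illustration.
#[allow(dead_code)]
fn encoding_cache_key(ids: Vec<TokenIdType>) -> u64 {
    Encoding::Sp(ids).get_hash()
}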

/// Shared, clonable handle to any [`traits::Tokenizer`] implementation.
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);

impl Tokenizer {
    /// Loads a tokenizer from a file, dispatching on the file extension.
    pub fn from_file(file_path: &str) -> Result<Tokenizer> {
        Ok(Tokenizer(create_tokenizer_from_file(file_path)?))
    }

    /// Creates a [`DecodeStream`] that incrementally detokenizes ids appended
    /// after the given prompt.
    pub fn decode_stream(
        &self,
        prompt_token_ids: &[TokenIdType],
        skip_special_tokens: bool,
    ) -> DecodeStream {
        DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
    }
}

impl Deref for Tokenizer {
    type Target = Arc<dyn traits::Tokenizer>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
    fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Tokenizer(tokenizer)
    }
}

impl<T> From<Arc<T>> for Tokenizer
where
    T: traits::Tokenizer + 'static,
{
    fn from(tokenizer: Arc<T>) -> Self {
        Tokenizer(tokenizer)
    }
}

/// Creates a tokenizer from a file on disk. Currently only HuggingFace
/// `tokenizer.json` files are supported.
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    let path = Path::new(file_path);
    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .ok_or_else(|| Error::msg("Failed to read file extension"))?;

    match extension {
        "json" => {
            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
            Ok(Arc::new(tokenizer))
        }
        _ => Err(Error::msg(format!("Unsupported file type: .{extension}"))),
    }
}
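
// Sketch: load a tokenizer and round-trip a string through encode/decode.
// The "tokenizer.json" path is a placeholder, not a file shipped with this
// crate.
#[allow(dead_code)]
fn example_round_trip() -> Result<String> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    // `Tokenizer` derefs to the trait object, so `encode`/`decode` are
    // called directly on the handle.
    let encoding = tokenizer.encode("Hello, world!")?;
    tokenizer.decode(encoding.token_ids(), false)
}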

/// Initial lookback window (in tokens) used to seed `prefix_offset` when a
/// [`DecodeStream`] is created, so early steps have decoding context.
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;

/// Incrementally decodes token ids into text, emitting only complete UTF-8
/// fragments.
pub struct DecodeStream {
    tokenizer: Arc<dyn traits::Tokenizer>,

    skip_special_tokens: bool,
    /// The prompt token ids followed by every id passed to [`DecodeStream::step`].
    all_token_ids: Vec<u32>,

    /// Start of the decoding context window; tokens in
    /// `prefix_offset..read_offset` were already emitted but are re-decoded
    /// for context.
    prefix_offset: usize,

    /// End of the tokens whose text has already been emitted.
    read_offset: usize,
}

impl DecodeStream {
    pub fn new(
        tokenizer: Arc<dyn traits::Tokenizer>,
        prompt_token_ids: &[TokenIdType],
        skip_special_tokens: bool,
    ) -> Self {
        let num_input_tokens = prompt_token_ids.len();
        let prompt_token_ids = prompt_token_ids.to_vec();
        Self {
            tokenizer,
            skip_special_tokens,
            all_token_ids: prompt_token_ids,
            prefix_offset: num_input_tokens
                .saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
            read_offset: num_input_tokens,
        }
    }

    /// Appends one token id and returns any newly completed text, or `None`
    /// if no complete text is available yet.
    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
        self.all_token_ids.push(id);

        // Decode the already-emitted window for context...
        let prefix_text = self.tokenizer.decode(
            &self.all_token_ids[self.prefix_offset..self.read_offset],
            self.skip_special_tokens,
        )?;

        // ...and the same window extended with the new token.
        let new_text = self.tokenizer.decode(
            &self.all_token_ids[self.prefix_offset..],
            self.skip_special_tokens,
        )?;

        // Emit only when the new token yields extra text that does not end in
        // U+FFFD, i.e. the decoder is not mid-way through a multi-byte character.
        if new_text.len() > prefix_text.len() && !new_text.ends_with("�") {
            let new_text = new_text[prefix_text.len()..].to_string();

            self.prefix_offset = self.read_offset;
            self.read_offset = self.all_token_ids.len();

            Ok(Some(new_text))
        } else {
            Ok(None)
        }
    }
}
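
// Sketch: stream text from incrementally generated token ids. `generated`
// stands in for ids produced by a model; the tokenizer path is a placeholder.
#[allow(dead_code)]
fn example_decode_stream(
    prompt_ids: &[TokenIdType],
    generated: &[TokenIdType],
) -> Result<String> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let mut stream = tokenizer.decode_stream(prompt_ids, true);
    let mut out = String::new();
    for &id in generated {
        // `step` returns `None` while a multi-byte character is incomplete.
        if let Some(chunk) = stream.step(id)? {
            out.push_str(&chunk);
        }
    }
    Ok(out)
}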

/// A stateful sequence of token ids that supports incremental detokenization.
pub struct Sequence {
    /// The tokenizer used to encode text and decode token ids.
    tokenizer: Tokenizer,

    /// All token ids appended so far.
    token_ids: Vec<TokenIdType>,

    /// Start of the decoding context window (see [`DecodeStream`]).
    prefix_offset: usize,

    /// End of the tokens whose text has already been returned.
    read_offset: usize,
}

impl std::fmt::Debug for Sequence {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sequence")
            .field("tokenizer", &"Arc<dyn Tokenizer>")
            .field(
                "token_ids",
                &format_args!("{}", {
                    let token_ids = self.token_ids();
                    if token_ids.len() <= 20 {
                        format!("{:?}", token_ids)
                    } else {
                        let first_ten = &token_ids[..10];
                        let last_ten = &token_ids[token_ids.len() - 10..];
                        format!("{:?} ... {:?}", first_ten, last_ten)
                    }
                }),
            )
            .field("prefix_offset", &self.prefix_offset)
            .field("read_offset", &self.read_offset)
            .field("token count", &self.token_ids.len())
            .finish()
    }
}

impl Sequence {
    /// Creates an empty sequence backed by the given tokenizer.
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            token_ids: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.token_ids.is_empty()
    }

    /// The number of token ids in the sequence.
    pub fn len(&self) -> usize {
        self.token_ids.len()
    }

    /// Clears the sequence and resets the detokenization offsets.
    pub fn clear(&mut self) {
        self.token_ids.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
    }

    /// Tokenizes `input` and appends the resulting ids to the sequence.
    pub fn append_text(&mut self, input: &str) -> Result<()> {
        let encoding = self.tokenizer.encode(input)?;
        self.token_ids.extend(encoding.token_ids());
        Ok(())
    }

    /// Appends one token id and returns any newly completed text, or an empty
    /// string if the decoded output still ends mid-character.
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<String> {
        self.token_ids.push(token_id);

        // Decode the already-returned window for context...
        let prefix_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..self.read_offset], false)?;

        // ...and the same window extended with the new token.
        let new_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..], false)?;

        // Walk the split point back to a char boundary of `new_text` so the
        // slice below cannot panic inside a multi-byte character.
        let mut prefix_text_len = prefix_text.len();
        while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
            prefix_text_len -= 1;
        }
        let prefix_text_len = prefix_text_len;

        if new_text.len() > prefix_text.len() {
            if new_text.ends_with("�") {
                // Mid-character; hold the text until the character completes.
                return Ok("".to_string());
            } else {
                let new_text = new_text[prefix_text_len..].to_string().replace("�", "");
                self.prefix_offset = self.read_offset;
                self.read_offset = self.token_ids.len();
                return Ok(new_text);
            }
        }

        Ok("".to_string())
    }

    /// A clone of the underlying tokenizer handle.
    pub fn tokenizer(&self) -> Tokenizer {
        self.tokenizer.clone()
    }

    pub fn token_ids(&self) -> &[TokenIdType] {
        &self.token_ids
    }

    /// Decodes the entire sequence into text.
    pub fn text(&self) -> Result<String> {
        self.tokenizer.decode(&self.token_ids, false)
    }
}
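
// Sketch: feed ids into a `Sequence` one at a time and collect the emitted
// text. The tokenizer path is a placeholder.
#[allow(dead_code)]
fn example_sequence(ids: &[TokenIdType]) -> Result<String> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let mut sequence = Sequence::new(tokenizer);
    let mut out = String::new();
    for &id in ids {
        // Returns "" while the decoded text would end mid-character.
        out.push_str(&sequence.append_token_id(id)?);
    }
    Ok(out)
}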

/// The result of appending one token to a [`StopSequenceDecoder`].
pub enum SequenceDecoderOutput {
    /// Newly decoded text, safe to emit.
    Text(String),

    /// Text is withheld because it might be the prefix of a stop sequence.
    Held,

    /// A hidden stop condition matched; nothing further will be emitted.
    Stopped,

    /// A visible stop condition matched; emit this final text, then stop.
    StoppedWithText(String),
}

/// Decodes token ids while watching for stop tokens and stop sequences.
#[derive(Debug)]
pub struct StopSequenceDecoder {
    sequence: Sequence,

    /// Token ids that stop decoding; the accumulated text is emitted.
    stop_token_ids_visible: Vec<TokenIdType>,

    /// Token ids that stop decoding; the accumulated text is suppressed.
    stop_token_ids_hidden: Vec<TokenIdType>,

    /// Accepted by the builder but not yet checked in `append_token_id`.
    #[allow(dead_code)]
    stop_sequences_visible: Vec<String>,

    /// Strings that stop decoding; the matched text is suppressed.
    stop_sequences_hidden: Vec<String>,

    /// Whether a stop condition has been hit (or `close` was called).
    stopped: bool,

    /// Text held back while it could still be a stop-sequence prefix.
    state: String,
}

impl StopSequenceDecoder {
    pub fn builder(tokenizer: Tokenizer) -> StopSequenceDecoderBuilder {
        StopSequenceDecoderBuilder::new(tokenizer)
    }

    /// Appends one token id, returning text, a hold, or a stop signal.
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Err(Error::msg("Decoder is stopped"));
        }

        let text = self.sequence.append_token_id(token_id)?;
        self.state.push_str(text.as_str());

        let mut stop: bool = false;
        let mut visible: bool = false;

        if self.stop_token_ids_visible.contains(&token_id) {
            stop = true;
            visible = true;
        }

        if self.stop_token_ids_hidden.contains(&token_id) {
            stop = true;
            visible = false;
        }

        if stop {
            self.stopped = true;
            let state = std::mem::take(&mut self.state);
            if visible {
                return Ok(SequenceDecoderOutput::StoppedWithText(state));
            }
            return Ok(SequenceDecoderOutput::Stopped);
        }

        for stop_sequence in self.stop_sequences_hidden.iter() {
            if stop_sequence.starts_with(&self.state) {
                if stop_sequence == &self.state {
                    // Exact match: stop and suppress the matched text.
                    self.stopped = true;
                    return Ok(SequenceDecoderOutput::Stopped);
                } else {
                    // Partial match: hold the text until it resolves.
                    return Ok(SequenceDecoderOutput::Held);
                }
            }
        }

        // No stop condition matched; flush the accumulated text.
        let state = std::mem::take(&mut self.state);
        Ok(SequenceDecoderOutput::Text(state))
    }

    pub fn is_empty(&self) -> bool {
        self.sequence.token_ids.is_empty()
    }

    pub fn len(&self) -> usize {
        self.sequence.token_ids.len()
    }

    /// Whether a stop condition has been hit or the decoder was closed.
    pub fn is_complete(&self) -> bool {
        self.stopped
    }

    /// Marks the decoder as stopped; further appends return an error.
    pub fn close(&mut self) {
        self.stopped = true;
    }
}

/// Builder for [`StopSequenceDecoder`].
pub struct StopSequenceDecoderBuilder {
    tokenizer: Tokenizer,
    stop_token_ids_visible: Vec<TokenIdType>,
    stop_token_ids_hidden: Vec<TokenIdType>,
    stop_sequences_visible: Vec<String>,
    stop_sequences_hidden: Vec<String>,
}

impl StopSequenceDecoderBuilder {
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            stop_token_ids_visible: Vec::new(),
            stop_token_ids_hidden: Vec::new(),
            stop_sequences_visible: Vec::new(),
            stop_sequences_hidden: Vec::new(),
        }
    }

    /// Adds a stop token id whose text is included in the final output.
    pub fn add_stop_token_id_visible(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_visible.push(token_id);
        self
    }

    /// Adds multiple visible stop token ids.
    pub fn add_stop_token_ids_visible(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_visible.extend(token_ids);
        self
    }

    /// Adds a stop token id whose text is suppressed.
    pub fn add_stop_token_id_hidden(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_hidden.push(token_id);
        self
    }

    /// Adds multiple hidden stop token ids.
    pub fn add_stop_token_ids_hidden(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_hidden.extend(token_ids);
        self
    }

    /// Adds a stop sequence whose text is included in the final output.
    pub fn add_stop_sequence_visible(mut self, text: &str) -> Self {
        self.stop_sequences_visible.push(text.to_string());
        self
    }

    /// Adds multiple visible stop sequences.
    pub fn add_stop_sequences_visible(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_visible
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    /// Adds a stop sequence whose text is suppressed.
    pub fn add_stop_sequence_hidden(mut self, text: &str) -> Self {
        self.stop_sequences_hidden.push(text.to_string());
        self
    }

    /// Adds multiple hidden stop sequences.
    pub fn add_stop_sequences_hidden(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_hidden
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    pub fn build(self) -> Result<StopSequenceDecoder> {
        Ok(StopSequenceDecoder {
            sequence: Sequence::new(self.tokenizer.clone()),
            stop_token_ids_visible: self.stop_token_ids_visible,
            stop_token_ids_hidden: self.stop_token_ids_hidden,
            stop_sequences_visible: self.stop_sequences_visible,
            stop_sequences_hidden: self.stop_sequences_hidden,
            stopped: false,
            state: String::new(),
        })
    }
}
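
// Sketch: build a decoder that stops on an end-of-sequence token id and a
// hidden stop string, then drain generated ids through it. The id 2 and the
// "</s>" text are illustrative values, as is the tokenizer path.
#[allow(dead_code)]
fn example_stop_sequence_decoder(ids: &[TokenIdType]) -> Result<String> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let mut decoder = StopSequenceDecoder::builder(tokenizer)
        .add_stop_token_id_hidden(2)
        .add_stop_sequence_hidden("</s>")
        .build()?;

    let mut out = String::new();
    for &id in ids {
        match decoder.append_token_id(id)? {
            SequenceDecoderOutput::Text(text) => out.push_str(&text),
            // Possible stop-sequence prefix; text is withheld for now.
            SequenceDecoderOutput::Held => {}
            SequenceDecoderOutput::Stopped => break,
            SequenceDecoderOutput::StoppedWithText(text) => {
                out.push_str(&text);
                break;
            }
        }
    }
    Ok(out)
}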