pub mod hf;

#[cfg(feature = "sentencepiece")]
pub mod sp;

use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::{ops::Deref, path::Path};

use crate::protocols::TokenIdType;
use self::traits::{Decoder, Encoder};
pub use anyhow::{Error, Result};

pub use hf::HuggingFaceTokenizer;

#[cfg(feature = "sentencepiece")]
pub use sp::SentencePieceTokenizer;

/// Identifies a tokenizer backend together with its source (e.g. a file path).
#[derive(Debug)]
pub enum TokenizerType {
    HuggingFace(String),
    #[cfg(feature = "sentencepiece")]
    SentencePiece(String),
}

/// Byte-offset span `(start, end)` of a token within the original input text.
pub type Offsets = (usize, usize);

/// The result of tokenizing text: parallel vectors of token ids, token
/// strings, and the byte spans they cover in the input.
#[derive(Debug, Hash)]
pub struct Encoding {
    pub token_ids: Vec<TokenIdType>,
    pub tokens: Vec<String>,
    pub spans: Vec<Offsets>,
}

pub mod traits {
    use super::*;

    /// Turns text into an [`Encoding`].
    pub trait Encoder: Send + Sync {
        fn encode(&self, input: &str) -> Result<Encoding>;
    }

    /// Turns token ids back into text, optionally skipping special tokens.
    pub trait Decoder: Send + Sync {
        fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
    }

    /// A full tokenizer is anything that can both encode and decode.
    pub trait Tokenizer: Encoder + Decoder {}
}
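
// Illustrative sketch (not part of the original API): a minimal in-memory
// tokenizer implementing the traits above, used by the example tests further
// down. It maps each whitespace-separated word to its index in a fixed
// vocabulary, which is enough to exercise the streaming machinery below.
#[cfg(test)]
pub(crate) mod mock {
    use super::*;

    pub struct MockTokenizer {
        pub vocab: Vec<String>,
    }

    impl traits::Encoder for MockTokenizer {
        fn encode(&self, input: &str) -> Result<Encoding> {
            let mut encoding = Encoding {
                token_ids: Vec::new(),
                tokens: Vec::new(),
                spans: Vec::new(),
            };
            let mut offset = 0;
            for word in input.split_whitespace() {
                let id = self
                    .vocab
                    .iter()
                    .position(|v| v.as_str() == word)
                    .ok_or_else(|| Error::msg(format!("unknown token: {word}")))?;
                encoding.token_ids.push(id as TokenIdType);
                encoding.tokens.push(word.to_string());
                // Approximate spans (assumes single-space separation).
                encoding.spans.push((offset, offset + word.len()));
                offset += word.len() + 1;
            }
            Ok(encoding)
        }
    }

    impl traits::Decoder for MockTokenizer {
        fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
            let words = token_ids
                .iter()
                .map(|&id| {
                    self.vocab
                        .get(id as usize)
                        .map(String::as_str)
                        .ok_or_else(|| Error::msg(format!("unknown token id: {id}")))
                })
                .collect::<Result<Vec<_>>>()?;
            Ok(words.join(" "))
        }
    }

    impl traits::Tokenizer for MockTokenizer {}
}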

impl Encoding {
    /// Hashes the complete encoding (ids, tokens, and spans) with the
    /// standard library's [`DefaultHasher`]; useful as a cheap cache key.
    pub fn get_hash(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        self.hash(&mut hasher);
        hasher.finish()
    }
}
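
// A small, test-only check of the hashing behavior documented above: equal
// encodings hash equally, and changing any field changes the hash.
#[cfg(test)]
mod encoding_hash_example {
    use super::*;

    #[test]
    fn equal_encodings_hash_equally() {
        let make = || Encoding {
            token_ids: vec![1, 2, 3],
            tokens: vec!["a".into(), "b".into(), "c".into()],
            spans: vec![(0, 1), (1, 2), (2, 3)],
        };
        assert_eq!(make().get_hash(), make().get_hash());

        let mut changed = make();
        changed.token_ids[0] = 9;
        assert_ne!(changed.get_hash(), make().get_hash());
    }
}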

/// Cheaply clonable, shared handle to any [`traits::Tokenizer`] implementation.
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);

impl Tokenizer {
    /// Loads a tokenizer from a file, dispatching on the file extension.
    pub fn from_file(file_path: &str) -> Result<Tokenizer> {
        Ok(Tokenizer(create_tokenizer_from_file(file_path)?))
    }

    /// Creates a [`DecodeStream`] for incremental, token-by-token decoding.
    pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream {
        DecodeStream::new(self.0.clone(), skip_special_tokens)
    }
}
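
// Sketch of the file-based entry point. The fixture path below is
// hypothetical, so the test is ignored by default; it shows the intended
// round trip through `from_file`, `encode`, and `decode`.
#[cfg(test)]
mod from_file_example {
    use super::*;

    #[test]
    #[ignore = "requires a tokenizer.json fixture on disk"]
    fn load_encode_decode() -> Result<()> {
        let tokenizer = Tokenizer::from_file("tests/data/tokenizer.json")?;
        let encoding = tokenizer.encode("hello world")?;
        let text = tokenizer.decode(&encoding.token_ids, false)?;
        assert!(!text.is_empty());
        Ok(())
    }
}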

impl Deref for Tokenizer {
    type Target = Arc<dyn traits::Tokenizer>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
    fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Tokenizer(tokenizer)
    }
}

impl<T> From<Arc<T>> for Tokenizer
where
    T: traits::Tokenizer + 'static,
{
    fn from(tokenizer: Arc<T>) -> Self {
        Tokenizer(tokenizer)
    }
}
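
// The blanket `From<Arc<T>>` impl above lets a concrete tokenizer be wrapped
// without first casting to `Arc<dyn traits::Tokenizer>`. Illustrated with the
// test-only `MockTokenizer` sketch from earlier.
#[cfg(test)]
mod from_impl_example {
    use super::*;

    #[test]
    fn wraps_a_concrete_tokenizer() -> Result<()> {
        let tokenizer: Tokenizer = Arc::new(mock::MockTokenizer {
            vocab: vec!["hello".into(), "world".into()],
        })
        .into();
        assert_eq!(tokenizer.decode(&[0, 1], false)?, "hello world");
        Ok(())
    }
}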

/// Creates a tokenizer from a file on disk, dispatching on its extension:
/// `.json` loads a HuggingFace tokenizer, and `.model` loads a SentencePiece
/// tokenizer when the `sentencepiece` feature is enabled.
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    let path = Path::new(file_path);
    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .ok_or_else(|| Error::msg("Failed to read file extension"))?;

    match extension {
        "json" => {
            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
            Ok(Arc::new(tokenizer))
        }
        "model" => {
            #[cfg(feature = "sentencepiece")]
            {
                let tokenizer = SentencePieceTokenizer::from_file(file_path)?;
                Ok(Arc::new(tokenizer))
            }
            #[cfg(not(feature = "sentencepiece"))]
            {
                Err(Error::msg(
                    "SentencePiece tokenizers require the `sentencepiece` feature",
                ))
            }
        }
        _ => Err(Error::msg("Unsupported tokenizer file type")),
    }
}
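
// Quick, runnable check of the dispatch above: an unrecognized extension is
// rejected up front, before any file I/O is attempted.
#[cfg(test)]
mod extension_dispatch_example {
    use super::*;

    #[test]
    fn rejects_unknown_extensions() {
        assert!(create_tokenizer_from_file("tokenizer.txt").is_err());
    }
}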

/// Incremental detokenizer: push token ids one at a time with
/// [`DecodeStream::step`] and receive text chunks as soon as they are
/// unambiguous. Follows the streaming-decode approach of the HuggingFace
/// `tokenizers` crate.
pub struct DecodeStream {
    /// Tokenizer used to decode the buffered ids.
    tokenizer: Arc<dyn traits::Tokenizer>,

    /// Whether special tokens are skipped during decoding.
    skip_special_tokens: bool,

    /// Buffered token ids not yet fully emitted as text.
    ids: Vec<u32>,

    /// Decoded text of `ids[..prefix_index]`; compared against fresh decodes
    /// to determine what the newest token adds.
    prefix: String,

    /// Index into `ids` where the current prefix ends.
    prefix_index: usize,

    /// Index into `ids` up to which text has been returned to the caller.
    read_index: usize,
}

impl DecodeStream {
    pub fn new(tokenizer: Arc<dyn traits::Tokenizer>, skip_special_tokens: bool) -> Self {
        Self {
            tokenizer,
            skip_special_tokens,
            ids: Vec::new(),
            prefix: String::new(),
            prefix_index: 0,
            read_index: 0,
        }
    }

    /// Pushes one token id and returns the newly finalized text, if any.
    /// Returns `None` while the buffered ids still end in an incomplete
    /// (replacement-character) sequence.
    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
        self.ids.push(id);
        let string = self
            .tokenizer
            .decode(self.ids.as_slice(), self.skip_special_tokens)?;

        // Only emit once the decoded text has grown and does not end mid-glyph.
        if string.len() > self.prefix.len() && !string.ends_with('�') {
            if !string.starts_with(&self.prefix) {
                anyhow::bail!("Detokenizer failure: invalid prefix");
            }
            let new_text = string[self.prefix.len()..].to_string();
            let new_prefix_index = self.ids.len() - self.prefix_index;
            // Drop ids already reflected in emitted text so the window
            // re-decoded on each step stays bounded.
            self.ids = self.ids.drain(self.prefix_index..).collect();
            self.prefix = self
                .tokenizer
                .decode(self.ids.as_slice(), self.skip_special_tokens)?;
            self.read_index = self.prefix_index;
            self.prefix_index = new_prefix_index;
            Ok(Some(new_text))
        } else {
            Ok(None)
        }
    }
}
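
// A minimal streaming-decode sketch using the test-only `MockTokenizer`:
// each `step` returns the text contributed by the newly pushed id.
#[cfg(test)]
mod decode_stream_example {
    use super::*;

    #[test]
    fn emits_text_incrementally() -> Result<()> {
        let tokenizer: Arc<dyn traits::Tokenizer> = Arc::new(mock::MockTokenizer {
            vocab: vec!["hello".into(), "world".into()],
        });
        let mut stream = DecodeStream::new(tokenizer, false);

        let mut out = String::new();
        for id in [0u32, 1] {
            if let Some(chunk) = stream.step(id)? {
                out.push_str(&chunk);
            }
        }
        assert_eq!(out, "hello world");
        Ok(())
    }
}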

/// A growing token sequence with support for incremental detokenization via
/// [`Sequence::append_token_id`].
pub struct Sequence {
    /// Tokenizer used to map between text and token ids.
    tokenizer: Tokenizer,

    /// All token ids accumulated so far.
    token_ids: Vec<TokenIdType>,

    /// Start of the window re-decoded on each append, anchoring the new text.
    prefix_offset: usize,

    /// Offset up to which text has already been returned to the caller.
    read_offset: usize,
}

impl std::fmt::Debug for Sequence {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sequence")
            .field("tokenizer", &"Arc<dyn Tokenizer>")
            .field(
                "token_ids",
                &format_args!("{}", {
                    if self.token_ids.len() <= 20 {
                        format!("{:?}", self.token_ids)
                    } else {
                        let first_ten = &self.token_ids[..10];
                        let last_ten = &self.token_ids[self.token_ids.len() - 10..];
                        format!("{:?} ... {:?}", first_ten, last_ten)
                    }
                }),
            )
            .field("prefix_offset", &self.prefix_offset)
            .field("read_offset", &self.read_offset)
            .field("token count", &self.token_ids.len())
            .finish()
    }
}

impl Sequence {
    /// Creates an empty sequence backed by `tokenizer`.
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            token_ids: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.token_ids.is_empty()
    }

    pub fn len(&self) -> usize {
        self.token_ids.len()
    }

    /// Resets the sequence to its initial, empty state.
    pub fn clear(&mut self) {
        self.token_ids.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
    }

    /// Tokenizes `input` and appends the resulting ids to the sequence.
    pub fn append_text(&mut self, input: &str) -> Result<()> {
        let encoding = self.tokenizer.encode(input)?;
        self.token_ids.extend(encoding.token_ids);
        Ok(())
    }

    /// Appends one token id and returns any newly finalized text, or an empty
    /// string while the sequence still ends in an incomplete UTF-8 sequence.
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<String> {
        self.token_ids.push(token_id);
        let prefix_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..self.read_offset], false)?;

        let new_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..], false)?;

        // Clamp the split point to a char boundary so the slice below is safe.
        let mut prefix_text_len = prefix_text.len();
        while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
            prefix_text_len -= 1;
        }

        if new_text.len() > prefix_text.len() {
            if new_text.ends_with('�') {
                // The newest token ends mid-glyph; withhold output until the
                // character is complete.
                return Ok(String::new());
            }
            let new_text = new_text[prefix_text_len..].replace('�', "");
            self.prefix_offset = self.read_offset;
            self.read_offset = self.token_ids.len();
            return Ok(new_text);
        }

        Ok(String::new())
    }

    pub fn tokenizer(&self) -> Tokenizer {
        self.tokenizer.clone()
    }

    pub fn token_ids(&self) -> &[TokenIdType] {
        &self.token_ids
    }

    /// Decodes the entire sequence.
    pub fn text(&self) -> Result<String> {
        self.tokenizer.decode(&self.token_ids, false)
    }
}
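
// Sketch of incremental detokenization with `Sequence`: the concatenation of
// everything `append_token_id` returns equals the full decoded text. Uses the
// test-only `MockTokenizer` from earlier.
#[cfg(test)]
mod sequence_example {
    use super::*;

    #[test]
    fn incremental_decode_matches_full_decode() -> Result<()> {
        let tokenizer: Tokenizer = Arc::new(mock::MockTokenizer {
            vocab: vec!["hello".into(), "world".into()],
        })
        .into();
        let mut sequence = Sequence::new(tokenizer);

        let mut incremental = String::new();
        for id in [0, 1] {
            incremental.push_str(&sequence.append_token_id(id)?);
        }
        assert_eq!(incremental, sequence.text()?);
        Ok(())
    }
}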

/// Output of [`StopSequenceDecoder::append_token_id`].
pub enum SequenceDecoderOutput {
    /// New text produced by the appended token.
    Text(String),

    /// Text is withheld because it may be the prefix of a stop sequence.
    Held,

    /// A hidden stop condition fired; nothing is emitted.
    Stopped,

    /// A visible stop condition fired; the buffered text is emitted with it.
    StoppedWithText(String),
}

/// A [`Sequence`] wrapper that additionally detects stop tokens and stop
/// sequences, withholding text that might belong to a partial stop sequence.
#[derive(Debug)]
pub struct StopSequenceDecoder {
    sequence: Sequence,

    /// Stop token ids whose text is emitted to the caller.
    stop_token_ids_visible: Vec<TokenIdType>,

    /// Stop token ids that terminate the stream without being emitted.
    stop_token_ids_hidden: Vec<TokenIdType>,

    /// Visible stop sequences; accepted by the builder but not yet checked.
    #[allow(dead_code)]
    stop_sequences_visible: Vec<String>,

    /// Hidden stop sequences matched against buffered text.
    stop_sequences_hidden: Vec<String>,

    /// True once any stop condition has fired or `close` was called.
    stopped: bool,

    /// Text buffered while a potential stop sequence is being matched.
    state: String,
}

impl StopSequenceDecoder {
    pub fn builder(tokenizer: Tokenizer) -> StopSequenceDecoderBuilder {
        StopSequenceDecoderBuilder::new(tokenizer)
    }

    /// Appends one token id, classifying the result against the configured
    /// stop conditions. Errors if the decoder has already stopped.
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Err(Error::msg("Decoder is stopped"));
        }

        let text = self.sequence.append_token_id(token_id)?;
        self.state.push_str(text.as_str());

        let mut stop = false;
        let mut visible = false;

        if self.stop_token_ids_visible.contains(&token_id) {
            stop = true;
            visible = true;
        }

        if self.stop_token_ids_hidden.contains(&token_id) {
            stop = true;
            visible = false;
        }

        if stop {
            self.stopped = true;
            let state = std::mem::take(&mut self.state);
            if visible {
                return Ok(SequenceDecoderOutput::StoppedWithText(state));
            }
            return Ok(SequenceDecoderOutput::Stopped);
        }

        for stop_sequence in self.stop_sequences_hidden.iter() {
            if stop_sequence.starts_with(&self.state) {
                if stop_sequence == &self.state {
                    // Exact match: terminate without emitting the stop text.
                    self.stopped = true;
                    return Ok(SequenceDecoderOutput::Stopped);
                }
                // Partial match: hold the buffered text until it either
                // completes the stop sequence or diverges from it.
                return Ok(SequenceDecoderOutput::Held);
            }
        }

        let state = std::mem::take(&mut self.state);
        Ok(SequenceDecoderOutput::Text(state))
    }

    pub fn is_empty(&self) -> bool {
        self.sequence.token_ids.is_empty()
    }

    pub fn len(&self) -> usize {
        self.sequence.token_ids.len()
    }

    /// True once a stop condition has fired or the decoder was closed.
    pub fn is_complete(&self) -> bool {
        self.stopped
    }

    /// Marks the decoder as stopped; further appends will error.
    pub fn close(&mut self) {
        self.stopped = true;
    }
}

/// Builder for [`StopSequenceDecoder`].
pub struct StopSequenceDecoderBuilder {
    tokenizer: Tokenizer,
    stop_token_ids_visible: Vec<TokenIdType>,
    stop_token_ids_hidden: Vec<TokenIdType>,
    stop_sequences_visible: Vec<String>,
    stop_sequences_hidden: Vec<String>,
}

impl StopSequenceDecoderBuilder {
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            stop_token_ids_visible: Vec::new(),
            stop_token_ids_hidden: Vec::new(),
            stop_sequences_visible: Vec::new(),
            stop_sequences_hidden: Vec::new(),
        }
    }

    /// Adds a stop token id whose text is emitted when it fires.
    pub fn add_stop_token_id_visible(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_visible.push(token_id);
        self
    }

    pub fn add_stop_token_ids_visible(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_visible.extend(token_ids);
        self
    }

    /// Adds a stop token id that terminates the stream without being emitted.
    pub fn add_stop_token_id_hidden(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_hidden.push(token_id);
        self
    }

    pub fn add_stop_token_ids_hidden(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_hidden.extend(token_ids);
        self
    }

    pub fn add_stop_sequence_visible(mut self, text: &str) -> Self {
        self.stop_sequences_visible.push(text.to_string());
        self
    }

    pub fn add_stop_sequences_visible(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_visible
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    pub fn add_stop_sequence_hidden(mut self, text: &str) -> Self {
        self.stop_sequences_hidden.push(text.to_string());
        self
    }

    pub fn add_stop_sequences_hidden(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_hidden
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    pub fn build(self) -> Result<StopSequenceDecoder> {
        Ok(StopSequenceDecoder {
            sequence: Sequence::new(self.tokenizer),
            stop_token_ids_visible: self.stop_token_ids_visible,
            stop_token_ids_hidden: self.stop_token_ids_hidden,
            stop_sequences_visible: self.stop_sequences_visible,
            stop_sequences_hidden: self.stop_sequences_hidden,
            stopped: false,
            state: String::new(),
        })
    }
}
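
// End-to-end sketch: build a decoder with a hidden stop token, stream ids,
// and fold the outputs. Uses the test-only mock vocabulary from earlier; the
// id 2 stands in for an end-of-sequence token.
#[cfg(test)]
mod stop_sequence_example {
    use super::*;

    #[test]
    fn hidden_stop_token_ends_the_stream() -> Result<()> {
        let tokenizer: Tokenizer = Arc::new(mock::MockTokenizer {
            vocab: vec!["hello".into(), "world".into(), "<eos>".into()],
        })
        .into();
        let mut decoder = StopSequenceDecoder::builder(tokenizer)
            .add_stop_token_id_hidden(2)
            .build()?;

        let mut text = String::new();
        for id in [0, 1, 2] {
            match decoder.append_token_id(id)? {
                SequenceDecoderOutput::Text(t) | SequenceDecoderOutput::StoppedWithText(t) => {
                    text.push_str(&t)
                }
                SequenceDecoderOutput::Held | SequenceDecoderOutput::Stopped => {}
            }
            if decoder.is_complete() {
                break;
            }
        }
        // The hidden stop token terminates the stream without being emitted.
        assert_eq!(text, "hello world");
        Ok(())
    }
}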