//! Tokenizer abstractions over Hugging Face and (optionally) SentencePiece
//! backends, plus incremental detokenization utilities for streaming text.

pub mod hf;

#[cfg(feature = "sentencepiece")]
pub mod sp;

use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::{ops::Deref, path::Path};

use crate::protocols::TokenIdType;
pub use anyhow::{Error, Result};

pub use hf::HuggingFaceTokenizer;

#[cfg(feature = "sentencepiece")]
pub use sp::SentencePieceTokenizer;

/// Identifies which backing tokenizer implementation a model file maps to.
#[derive(Debug)]
pub enum TokenizerType {
    HuggingFace(String),
    #[cfg(feature = "sentencepiece")]
    SentencePiece(String),
}

/// (start, end) offsets of a token within the original input text.
pub type Offsets = (usize, usize);

/// The output of an `encode` call: token ids, token strings, and each
/// token's offsets in the source text.
#[derive(Debug, Hash)]
pub struct Encoding {
    pub token_ids: Vec<TokenIdType>,
    pub tokens: Vec<String>,
    pub spans: Vec<Offsets>,
}

pub mod traits {
    use super::*;

    pub trait Encoder: Send + Sync {
        fn encode(&self, input: &str) -> Result<Encoding>;
    }

    pub trait Decoder: Send + Sync {
        fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
    }

    /// A full tokenizer is anything that can both encode and decode.
    pub trait Tokenizer: Encoder + Decoder {}
}
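
// Implementing the combined trait is a one-line opt-in for any type that
// already implements `Encoder` and `Decoder`, e.g. for some hypothetical
// `MyBackend`:
//
//     impl traits::Tokenizer for MyBackend {}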

impl Encoding {
    /// Hashes the token ids, tokens, and spans, giving a cheap identity for
    /// in-process caching or deduplication.
    pub fn get_hash(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        self.hash(&mut hasher);
        hasher.finish()
    }
}
80
81#[derive(Clone)]
83pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
84
85impl Tokenizer {
86    pub fn from_file(file_path: &str) -> Result<Tokenizer> {
87        Ok(Tokenizer(create_tokenizer_from_file(file_path)?))
88    }
89
90    pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream {
92        DecodeStream::new(self.0.clone(), skip_special_tokens)
93    }
94}
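
// A minimal round-trip sketch; the `tokenizer.json` path is hypothetical and
// exact round-tripping depends on the underlying vocabulary:
//
//     let tokenizer = Tokenizer::from_file("tokenizer.json")?;
//     let encoding = tokenizer.encode("Hello, world!")?;
//     let text = tokenizer.decode(&encoding.token_ids, false)?;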

impl Deref for Tokenizer {
    type Target = Arc<dyn traits::Tokenizer>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
    fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Tokenizer(tokenizer)
    }
}

impl<T> From<Arc<T>> for Tokenizer
where
    T: traits::Tokenizer + 'static,
{
    fn from(tokenizer: Arc<T>) -> Self {
        Tokenizer(tokenizer)
    }
}

/// Loads a tokenizer based on file extension: `.json` selects the Hugging Face
/// tokenizer, `.model` selects SentencePiece (when the feature is enabled).
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    let path = Path::new(file_path);
    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .ok_or_else(|| Error::msg("Failed to read file extension".to_string()))?;

    match extension {
        "json" => {
            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
            Ok(Arc::new(tokenizer))
        }
        "model" => {
            #[cfg(feature = "sentencepiece")]
            {
                let tokenizer = SentencePieceTokenizer::from_file(file_path)?;
                Ok(Arc::new(tokenizer))
            }
            #[cfg(not(feature = "sentencepiece"))]
            {
                Err(Error::msg(
                    "SentencePiece support not compiled in; enable the `sentencepiece` feature"
                        .to_string(),
                ))
            }
        }
        _ => Err(Error::msg(format!(
            "Unsupported tokenizer file type: {extension}"
        ))),
    }
}

/// Incrementally decodes a stream of token ids into text, holding output back
/// until it forms valid UTF-8. This mirrors the `DecodeStream` design in the
/// Hugging Face `tokenizers` crate: decoding is stateful because one token id
/// rarely maps to a complete, printable string on its own.
pub struct DecodeStream {
    tokenizer: Arc<dyn traits::Tokenizer>,

    skip_special_tokens: bool,

    /// Buffered token ids kept as decode context.
    ids: Vec<u32>,

    /// Text already decoded from `ids`, used as the comparison prefix.
    prefix: String,

    /// Index within `ids` where the current `prefix` begins.
    prefix_index: usize,

    /// Index within `ids` up to which text has been emitted to the caller.
    read_index: usize,
}

impl DecodeStream {
    pub fn new(tokenizer: Arc<dyn traits::Tokenizer>, skip_special_tokens: bool) -> Self {
        Self {
            tokenizer,
            skip_special_tokens,
            ids: Vec::new(),
            prefix: String::new(),
            prefix_index: 0,
            read_index: 0,
        }
    }

    /// Feeds one token id into the stream. Returns `Some(text)` when the new
    /// id completes a printable chunk, or `None` while more ids are needed
    /// (for example, midway through a multi-byte UTF-8 character).
    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
        self.ids.push(id);
        let string = self
            .tokenizer
            .decode(self.ids.as_slice(), self.skip_special_tokens)?;

        // Emit only once the decoded text has grown and does not end in a
        // replacement character, which would indicate an incomplete UTF-8
        // sequence still being assembled.
        if string.len() > self.prefix.len() && !string.ends_with('�') {
            if !string.starts_with(&self.prefix) {
                anyhow::bail!("Detokenizer failure: invalid prefix");
            }
            let new_text = string[self.prefix.len()..].to_string();
            let new_prefix_index = self.ids.len() - self.prefix_index;
            // Keep only the ids still needed as context for the next decode.
            self.ids = self.ids.drain(self.read_index..).collect();
            self.prefix = self
                .tokenizer
                .decode(self.ids.as_slice(), self.skip_special_tokens)?;
            self.read_index = self.prefix_index;
            self.prefix_index = new_prefix_index;
            Ok(Some(new_text))
        } else {
            Ok(None)
        }
    }
}
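
// A streaming usage sketch; `generated_ids` stands in for whatever id source
// the model provides:
//
//     let mut stream = tokenizer.decode_stream(false);
//     for id in generated_ids {
//         if let Some(chunk) = stream.step(id)? {
//             print!("{chunk}");
//         }
//     }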

/// An append-only sequence of token ids plus the incremental-decoding state
/// needed to turn newly appended ids into text.
pub struct Sequence {
    tokenizer: Tokenizer,

    token_ids: Vec<TokenIdType>,

    /// Start of the decode context window within `token_ids`.
    prefix_offset: usize,

    /// End of the window whose text has already been returned to the caller.
    read_offset: usize,
}

impl std::fmt::Debug for Sequence {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sequence")
            .field("tokenizer", &"Arc<dyn Tokenizer>")
            .field(
                "token_ids",
                &format_args!("{}", {
                    if self.token_ids.len() <= 20 {
                        format!("{:?}", self.token_ids)
                    } else {
                        let first_ten = &self.token_ids[..10];
                        let last_ten = &self.token_ids[self.token_ids.len() - 10..];
                        format!("{:?} ... {:?}", first_ten, last_ten)
                    }
                }),
            )
            .field("prefix_offset", &self.prefix_offset)
            .field("read_offset", &self.read_offset)
            .field("token count", &self.token_ids.len())
            .finish()
    }
}

impl Sequence {
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            token_ids: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.token_ids.is_empty()
    }

    pub fn len(&self) -> usize {
        self.token_ids.len()
    }

    pub fn clear(&mut self) {
        self.token_ids.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
    }

    /// Tokenizes `input` and appends the resulting ids to the sequence.
    pub fn append_text(&mut self, input: &str) -> Result<()> {
        let encoding = self.tokenizer.encode(input)?;
        self.token_ids.extend(encoding.token_ids);
        Ok(())
    }

    /// Appends one token id and returns any text it completes. Returns an
    /// empty string while the decoder is midway through a multi-byte character.
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<String> {
        self.token_ids.push(token_id);

        // Decode the context window without and with the new token; the
        // difference between the two strings is the newly produced text.
        let prefix_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..self.read_offset], false)?;

        let new_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..], false)?;

        // Clamp the prefix length back to a char boundary of `new_text` so the
        // slice below cannot split a multi-byte character.
        let mut prefix_text_len = prefix_text.len();
        while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
            prefix_text_len -= 1;
        }
        let prefix_text_len = prefix_text_len;

        if new_text.len() > prefix_text.len() {
            if new_text.ends_with('�') {
                // Incomplete UTF-8 sequence; hold output until it resolves.
                return Ok("".to_string());
            } else {
                let new_text = new_text[prefix_text_len..].to_string().replace('�', "");
                self.prefix_offset = self.read_offset;
                self.read_offset = self.token_ids.len();
                return Ok(new_text);
            }
        }

        Ok("".to_string())
    }

    pub fn tokenizer(&self) -> Tokenizer {
        self.tokenizer.clone()
    }

    pub fn token_ids(&self) -> &[TokenIdType] {
        &self.token_ids
    }

    /// Decodes the full sequence back into text.
    pub fn text(&self) -> Result<String> {
        self.tokenizer.decode(&self.token_ids, false)
    }
}
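
// Incremental decoding sketch; `generated_ids` is a placeholder for a model's
// output stream:
//
//     let mut seq = Sequence::new(tokenizer.clone());
//     for id in generated_ids {
//         let chunk = seq.append_token_id(id)?;
//         if !chunk.is_empty() {
//             print!("{chunk}");
//         }
//     }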

/// The per-token verdict produced by [`StopSequenceDecoder::append_token_id`].
pub enum SequenceDecoderOutput {
    /// New text that is safe to emit downstream.
    Text(String),

    /// Text is withheld because it may be the start of a stop sequence.
    Held,

    /// A hidden stop condition fired; nothing further will be emitted.
    Stopped,

    /// A visible stop condition fired; emit this final text, then stop.
    StoppedWithText(String),
}

/// Wraps a [`Sequence`] with stop-token and stop-sequence detection, deciding
/// for each appended token whether text is emitted, held, or suppressed.
#[derive(Debug)]
pub struct StopSequenceDecoder {
    sequence: Sequence,

    /// Stop token ids whose trailing text is included in the output.
    stop_token_ids_visible: Vec<TokenIdType>,

    /// Stop token ids whose trailing text is suppressed.
    stop_token_ids_hidden: Vec<TokenIdType>,

    #[allow(dead_code)]
    stop_sequences_visible: Vec<String>,

    stop_sequences_hidden: Vec<String>,

    stopped: bool,

    /// Text held back while a potential stop sequence is being matched.
    state: String,
}

impl StopSequenceDecoder {
    pub fn builder(tokenizer: Tokenizer) -> StopSequenceDecoderBuilder {
        StopSequenceDecoderBuilder::new(tokenizer)
    }

    /// Appends one token id and classifies the result. Errors if called after
    /// the decoder has already stopped.
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Err(Error::msg("Decoder is stopped"));
        }

        let text = self.sequence.append_token_id(token_id)?;
        self.state.push_str(text.as_str());

        // Check stop token ids; a hidden match takes precedence over a visible
        // one for the same token.
        let mut stop: bool = false;
        let mut visible: bool = false;

        if self.stop_token_ids_visible.contains(&token_id) {
            stop = true;
            visible = true;
        }

        if self.stop_token_ids_hidden.contains(&token_id) {
            stop = true;
            visible = false;
        }

        if stop {
            self.stopped = true;
            let state = std::mem::take(&mut self.state);
            if visible {
                return Ok(SequenceDecoderOutput::StoppedWithText(state));
            }
            return Ok(SequenceDecoderOutput::Stopped);
        }

        // Hold output while the accumulated state could still grow into a
        // hidden stop sequence; stop once it matches one exactly.
        for stop_sequence in self.stop_sequences_hidden.iter() {
            if stop_sequence.starts_with(&self.state) {
                if stop_sequence == &self.state {
                    self.stopped = true;
                    return Ok(SequenceDecoderOutput::Stopped);
                } else {
                    return Ok(SequenceDecoderOutput::Held);
                }
            }
        }

        let state = std::mem::take(&mut self.state);
        Ok(SequenceDecoderOutput::Text(state))
    }

    pub fn is_empty(&self) -> bool {
        self.sequence.token_ids.is_empty()
    }

    pub fn len(&self) -> usize {
        self.sequence.token_ids.len()
    }

    /// Returns `true` once a stop condition has fired or [`Self::close`] was called.
    pub fn is_complete(&self) -> bool {
        self.stopped
    }

    /// Marks the decoder as stopped without emitting further output.
    pub fn close(&mut self) {
        self.stopped = true;
    }
}
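
// Driving the decoder, sketched; how each variant is handled is up to the
// caller:
//
//     match decoder.append_token_id(id)? {
//         SequenceDecoderOutput::Text(t) => print!("{t}"),
//         SequenceDecoderOutput::Held => {} // a stop sequence may be forming
//         SequenceDecoderOutput::Stopped => break,
//         SequenceDecoderOutput::StoppedWithText(t) => {
//             print!("{t}");
//             break;
//         }
//     }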

/// Fluent builder for [`StopSequenceDecoder`].
pub struct StopSequenceDecoderBuilder {
    tokenizer: Tokenizer,
    stop_token_ids_visible: Vec<TokenIdType>,
    stop_token_ids_hidden: Vec<TokenIdType>,
    stop_sequences_visible: Vec<String>,
    stop_sequences_hidden: Vec<String>,
}

impl StopSequenceDecoderBuilder {
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            stop_token_ids_visible: Vec::new(),
            stop_token_ids_hidden: Vec::new(),
            stop_sequences_visible: Vec::new(),
            stop_sequences_hidden: Vec::new(),
        }
    }

    /// Adds a stop token id whose text is included in the final output.
    pub fn add_stop_token_id_visible(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_visible.push(token_id);
        self
    }

    pub fn add_stop_token_ids_visible(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_visible.extend(token_ids);
        self
    }

    /// Adds a stop token id whose text is suppressed from the output.
    pub fn add_stop_token_id_hidden(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_hidden.push(token_id);
        self
    }

    pub fn add_stop_token_ids_hidden(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_hidden.extend(token_ids);
        self
    }

    pub fn add_stop_sequence_visible(mut self, text: &str) -> Self {
        self.stop_sequences_visible.push(text.to_string());
        self
    }

    pub fn add_stop_sequences_visible(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_visible
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    pub fn add_stop_sequence_hidden(mut self, text: &str) -> Self {
        self.stop_sequences_hidden.push(text.to_string());
        self
    }

    pub fn add_stop_sequences_hidden(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_hidden
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    pub fn build(self) -> Result<StopSequenceDecoder> {
        Ok(StopSequenceDecoder {
            sequence: Sequence::new(self.tokenizer.clone()),
            stop_token_ids_visible: self.stop_token_ids_visible,
            stop_token_ids_hidden: self.stop_token_ids_hidden,
            stop_sequences_visible: self.stop_sequences_visible,
            stop_sequences_hidden: self.stop_sequences_hidden,
            stopped: false,
            state: String::new(),
        })
    }
}
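
// Builder sketch; the stop token id (2) and stop string ("</s>") are
// illustrative, not tied to any particular model:
//
//     let mut decoder = StopSequenceDecoder::builder(tokenizer.clone())
//         .add_stop_token_id_hidden(2)
//         .add_stop_sequence_hidden("</s>")
//         .build()?;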