1use std::path::Path;
3use crate::util::result::Result;
4
/// Thin wrapper around a Hugging Face `tokenizers::Tokenizer`, configured at
/// construction for batch encoding (dynamic padding and optional truncation).
pub struct Tokenizer {
    // Underlying Hugging Face tokenizer; configured in `Tokenizer::new`.
    tokenizer: tokenizers::Tokenizer,
}
9
10impl Tokenizer {
11
12 pub fn new<P: AsRef<Path>>(tokenizer_path: P, max_length: Option<usize>) -> Result<Self> {
13 let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)?;
14
15 if let Some(length) = max_length {
16 let mut truncation = tokenizers::TruncationParams::default();
17 truncation.max_length = length;
18 tokenizer.with_truncation(Some(truncation))?;
19 }
20
21 let mut padding = tokenizers::PaddingParams::default();
22 padding.strategy = tokenizers::PaddingStrategy::BatchLongest;
23
24 tokenizer.with_padding(Some(padding));
25
26 Ok(Self { tokenizer })
27 }
28
29 pub fn tokenize<'s, E: Into<tokenizers::EncodeInput<'s>> + Send>(&self, input: Vec<E>) -> Result<(ndarray::Array2<i64>, ndarray::Array2<i64>)> {
30 let encodings = self.tokenizer.encode_batch(input, true)?;
31 let max_tokens = encodings.first().map(|x| x.len()).unwrap_or(0);
32 let mut input_ids = ndarray::Array2::zeros((0, max_tokens));
33 let mut attn_masks = ndarray::Array2::zeros((0, max_tokens));
34 for encoding in encodings {
35 input_ids.push_row(ndarray::ArrayView::from(&Self::to_i64(encoding.get_ids()).to_vec()))?;
36 attn_masks.push_row(ndarray::ArrayView::from(&Self::to_i64(encoding.get_attention_mask()).to_vec()))?;
37 }
38 Ok((input_ids, attn_masks))
39 }
40
41 fn to_i64(array: &[u32]) -> Vec<i64> {
42 array.iter().map(|x| *x as i64).collect()
43 }
44
45}