// gliclass/tokenizer.rs

//! Wrapper around HuggingFace tokenizers
use std::path::Path;
use crate::util::result::Result;

/// Wrapper around HuggingFace tokenizers
pub struct Tokenizer {
    // Underlying HF tokenizer; truncation and padding behaviour is
    // configured once at construction time (see `new`).
    tokenizer: tokenizers::Tokenizer,
}
9
10impl Tokenizer {
11
12    pub fn new<P: AsRef<Path>>(tokenizer_path: P, max_length: Option<usize>) -> Result<Self> {
13        let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)?;
14        
15        if let Some(length) = max_length {
16            let mut truncation = tokenizers::TruncationParams::default();
17            truncation.max_length = length;
18            tokenizer.with_truncation(Some(truncation))?;
19        }
20
21        let mut padding = tokenizers::PaddingParams::default();
22        padding.strategy = tokenizers::PaddingStrategy::BatchLongest;    
23        
24        tokenizer.with_padding(Some(padding));
25        
26        Ok(Self { tokenizer })
27    }
28
29    pub fn tokenize<'s, E: Into<tokenizers::EncodeInput<'s>> + Send>(&self, input: Vec<E>) -> Result<(ndarray::Array2<i64>, ndarray::Array2<i64>)> {
30        let encodings = self.tokenizer.encode_batch(input, true)?;
31        let max_tokens = encodings.first().map(|x| x.len()).unwrap_or(0);
32        let mut input_ids = ndarray::Array2::zeros((0, max_tokens));
33        let mut attn_masks = ndarray::Array2::zeros((0, max_tokens));
34        for encoding in encodings {
35            input_ids.push_row(ndarray::ArrayView::from(&Self::to_i64(encoding.get_ids()).to_vec()))?;
36            attn_masks.push_row(ndarray::ArrayView::from(&Self::to_i64(encoding.get_attention_mask()).to_vec()))?;
37        }
38        Ok((input_ids, attn_masks))
39    }
40
41    fn to_i64(array: &[u32]) -> Vec<i64> {
42        array.iter().map(|x| *x as i64).collect()
43    }
44
45}