//! clip_txt/lib.rs

1use anyhow::anyhow;
2use anyhow::Result;
3use std::path::Path;
4pub use tokenizers::tokenizer::Tokenizer;
/// Wrapper around a HuggingFace `Tokenizer` that is configured at load time
/// with fixed-length truncation (see [`Tokener::from_file`]).
pub struct Tokener {
  // Underlying tokenizer; truncation params are installed in `from_file`.
  tokenizer: Tokenizer,
}
8
/// Batch encoding result: parallel vectors of (token-id sequences, attention
/// masks), one boxed slice per input string, all padded to equal length.
pub type VecIdsMask = (Vec<Box<[u32]>>, Vec<Box<[u32]>>);
/// Single-input encoding result: (token ids, attention mask).
pub type IdsMask = (Box<[u32]>, Box<[u32]>);
11
12impl Tokener {
13  pub fn from_file(path: impl AsRef<Path>, max_length: usize) -> Result<Self> {
14    let mut tokenizer = Tokenizer::from_file(path).map_err(|e| anyhow!(e))?;
15    let truncation = tokenizers::utils::truncation::TruncationParams {
16      max_length,
17      ..Default::default()
18    };
19    tokenizer.with_truncation(Some(truncation));
20    Ok(Tokener { tokenizer })
21  }
22
23  pub fn encode_batch(
24    &self,
25    txt_li: impl ExactSizeIterator + Iterator<Item = impl AsRef<str>>,
26  ) -> Result<VecIdsMask> {
27    let len = txt_li.len();
28    let mut id_li_li = Vec::with_capacity(len);
29    let mut mask_li = Vec::with_capacity(len);
30    if len > 0 {
31      let tokenizer = &self.tokenizer;
32      let li = tokenizer
33        .encode_batch(
34          txt_li.map(|x| x.as_ref().to_owned()).collect::<Vec<_>>(),
35          true,
36        )
37        .map_err(|e| anyhow!(e))?;
38      let max = li.iter().map(|item| item.get_ids().len()).max().unwrap();
39      for encoding in li {
40        let id_li = encoding.get_ids();
41        let mask = encoding.get_attention_mask();
42
43        let diff = max - id_li.len();
44        let (id_li, mask) = if diff == 0 {
45          (Box::from(id_li), Box::from(mask))
46        } else {
47          let mut id_li = id_li.to_vec();
48          let mut mask = mask.to_vec();
49          id_li.extend(std::iter::repeat(0).take(diff));
50          mask.extend(std::iter::repeat(0).take(diff));
51          (id_li.into(), mask.into())
52        };
53
54        id_li_li.push(id_li);
55        mask_li.push(mask);
56      }
57    }
58    Ok((id_li_li, mask_li))
59  }
60
61  pub fn encode(&self, txt: impl AsRef<str>) -> Result<IdsMask> {
62    let tokenizer = &self.tokenizer;
63    let encoding = tokenizer
64      .encode(txt.as_ref(), true)
65      .map_err(|e| anyhow!(e))?;
66    let id_li = Box::from(encoding.get_ids());
67    let mask = Box::from(encoding.get_attention_mask());
68    Ok((id_li, mask))
69  }
70}
71
#[cfg(test)]
mod tests {
  use super::*;

  /// Smoke test: loads `<cwd>/process/tokenizer.json` and prints the
  /// encoding of a couple of sample prompts.
  #[test]
  fn test() -> Result<()> {
    let tokenizer_json = std::env::current_dir()?
      .join("process")
      .join("tokenizer.json");
    let tokener = Tokener::from_file(tokenizer_json, 77)?;
    let samples = [
      "a photo of dog",
      "a photo of chinese woman",
      //"房贷利率(mortgage rate),是指在银行办理的用于购房的贷款,该贷款要按照银行规定的利率支付利息。中国房贷利率是由中国人民银行统一规定的,各个商业银行执行的时候可以在一定的区间内自行浮动;房贷利率不是一直不变的,而是经常变动的。",
    ];
    for word in samples {
      let encoded = tokener.encode(word)?;
      println!("\n❯ {}\n{:?}\n", word, encoded);
    }
    // dbg!(tokener.encode_batch(samples.into())?);
    Ok(())
  }
}