use anyhow::anyhow;
use anyhow::Result;
use std::path::Path;
pub use tokenizers::tokenizer::Tokenizer;

/// Thin wrapper around a HuggingFace `Tokenizer` that truncates every
/// encoding to a configured maximum length.
pub struct Tokener {
    tokenizer: Tokenizer,
}

/// Padded token ids and attention masks for a batch of texts.
pub type VecIdsMask = (Vec<Box<[u32]>>, Vec<Box<[u32]>>);
/// Token ids and attention mask for a single text.
pub type IdsMask = (Box<[u32]>, Box<[u32]>);

impl Tokener {
    /// Load a tokenizer from a `tokenizer.json` file and truncate every
    /// encoding to at most `max_length` tokens.
    pub fn from_file(path: impl AsRef<Path>, max_length: usize) -> Result<Self> {
        let mut tokenizer = Tokenizer::from_file(path).map_err(|e| anyhow!(e))?;
        let truncation = tokenizers::utils::truncation::TruncationParams {
            max_length,
            ..Default::default()
        };
        // `with_truncation` returns a `Result` in recent `tokenizers` releases;
        // propagate the error instead of silently dropping it.
        tokenizer
            .with_truncation(Some(truncation))
            .map_err(|e| anyhow!(e))?;
        Ok(Tokener { tokenizer })
    }

    /// Encode a batch of texts. Token ids and attention masks are right-padded
    /// with zeros to the length of the longest encoding in the batch.
    pub fn encode_batch(
        &self,
        txt_li: impl ExactSizeIterator + Iterator<Item = impl AsRef<str>>,
    ) -> Result<VecIdsMask> {
        let len = txt_li.len();
        let mut id_li_li = Vec::with_capacity(len);
        let mut mask_li = Vec::with_capacity(len);
        if len > 0 {
            let li = self
                .tokenizer
                .encode_batch(
                    txt_li.map(|x| x.as_ref().to_owned()).collect::<Vec<_>>(),
                    true,
                )
                .map_err(|e| anyhow!(e))?;
            // Safe to unwrap: `len > 0` guarantees at least one encoding.
            let max = li.iter().map(|item| item.get_ids().len()).max().unwrap();
            for encoding in li {
                let id_li = encoding.get_ids();
                let mask = encoding.get_attention_mask();

                let diff = max - id_li.len();
                let (id_li, mask) = if diff == 0 {
                    (Box::from(id_li), Box::from(mask))
                } else {
                    // Right-pad shorter encodings with zero ids and a
                    // zeroed attention mask.
                    let mut id_li = id_li.to_vec();
                    let mut mask = mask.to_vec();
                    id_li.resize(max, 0);
                    mask.resize(max, 0);
                    (id_li.into(), mask.into())
                };

                id_li_li.push(id_li);
                mask_li.push(mask);
            }
        }
        Ok((id_li_li, mask_li))
    }

    /// Encode a single text into token ids and an attention mask.
    pub fn encode(&self, txt: impl AsRef<str>) -> Result<IdsMask> {
        let encoding = self
            .tokenizer
            .encode(txt.as_ref(), true)
            .map_err(|e| anyhow!(e))?;
        let id_li = Box::from(encoding.get_ids());
        let mask = Box::from(encoding.get_attention_mask());
        Ok((id_li, mask))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test() -> Result<()> {
        let mut process_tokenizer_json = std::env::current_dir()?;
        process_tokenizer_json.push("process");
        process_tokenizer_json.push("tokenizer.json");
        let tokener = Tokener::from_file(process_tokenizer_json, 77)?;
        let li = [
            "a photo of dog",
            "a photo of chinese woman",
        ];
        for word in li {
            let vec = tokener.encode(word)?;
            println!("\n❯ {}\n{:?}\n", word, vec);
        }
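
        // Minimal sketch of a batch check: given encode_batch's padding logic
        // above, every returned id slice and mask should share the batch's max
        // length, and from_file(.., 77) caps each encoding at 77 tokens.
        let (ids, masks) = tokener.encode_batch(li.iter())?;
        assert_eq!(ids.len(), li.len());
        assert_eq!(masks.len(), li.len());
        assert!(ids.iter().all(|x| x.len() == ids[0].len()));
        assert!(masks.iter().all(|x| x.len() == ids[0].len()));
        assert!(ids[0].len() <= 77);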
        Ok(())
    }
}