1use composable::Composable;
2use crate::util::result::Result;
3use crate::text::{token::Token, tokenizer::Tokenizer};
4use super::prompt::PromptInput;
5use ndarray::{Array, Array2, ArrayView};
6
/// Batch of prompts encoded into padded, model-ready 2-D tensors.
///
/// Produced by [`EncodedInput::from`]; every `Array2` below has one row per
/// prompt, padded on the right with zeros to the width of the longest prompt.
pub struct EncodedInput {
    // Original input texts, carried through unchanged from `PromptInput`.
    pub texts: Vec<String>,
    // Word-level tokens of each text, carried through unchanged.
    pub tokens: Vec<Vec<Token>>,
    // Entity labels that were used to build the prompts.
    pub entities: Vec<String>,
    // Total word count across the batch (forwarded from `PromptInput`).
    pub num_words: usize,
    // Padded row width: token count of the longest prompt, incl. BOS/EOS.
    pub num_tokens: usize,
    // Sub-token ids, one padded row per prompt (0 = padding, 1 = BOS, 2 = EOS).
    pub input_ids: Array2<i64>,
    // 1 where `input_ids` holds a real token, 0 on padding positions.
    pub attention_masks: Array2<i64>,
    // Word index (1-based) at the first sub-token of each *text* word; 0
    // elsewhere (prompt/entity tokens, continuation sub-tokens, padding).
    pub word_masks: Array2<i64>,
    // One-column array holding each text's word count.
    pub text_lengths: Array2<i64>,
}
19
/// Intermediate per-prompt result of the first encoding pass in
/// [`EncodedInput::from`]: raw sub-token ids plus where the text section starts.
struct EncodedPrompt {
    // Sub-token ids for each word of the prompt, in word order.
    encoding: Vec<Vec<u32>>,
    // Row index of the first *text* (non-entity) token, counting the BOS slot.
    text_offset: usize,
}
27
28impl EncodedInput {
29
30 pub fn from(input: PromptInput, tokenizer: &impl Tokenizer) -> Result<Self> {
34 let mut encodings: Vec<EncodedPrompt> = Vec::with_capacity(input.prompts.len());
36 let mut max_tokens: usize = 0;
38 for prompt in &input.prompts {
40 let mut prompt_tokens: Vec<Vec<u32>> = Vec::with_capacity(prompt.tokens().len());
42 let mut total_tokens: usize = 2;
44 let mut total_entity_tokens = 0;
46 for (pos, word) in prompt.tokens().iter().enumerate() {
48 let encoding = tokenizer.encode(word)?;
50 total_tokens += encoding.len();
52 if pos < prompt.entities_len() {
54 total_entity_tokens += encoding.len();
55 }
56 prompt_tokens.push(encoding);
57 }
58
59 let text_offset = total_entity_tokens + 1;
61
62 encodings.push(EncodedPrompt { encoding: prompt_tokens, text_offset });
64 max_tokens = std::cmp::max(max_tokens, total_tokens);
65 }
66
67 let mut input_ids = Array::zeros((0, max_tokens));
71 let mut attention_masks = Array::zeros((0, max_tokens));
72 let mut word_masks = Array::zeros((0, max_tokens));
73 for encoded_prompt in encodings {
74 let encoding = encoded_prompt.encoding;
75 let mut input_id = vec!(0i64; max_tokens);
76 let mut attn_mask = vec!(0i64; max_tokens);
77 let mut word_mask = vec!(0i64; max_tokens);
78
79 let mut idx: usize = 0;
80 let mut word_id: i64 = 0;
81
82 input_id[idx] = 1;
84 attn_mask[idx] = 1;
85 idx += 1;
86
87 for word in encoding {
89 for (token_idx, token) in word.iter().enumerate() {
90 input_id[idx] = *token as i64;
91 attn_mask[idx] = 1;
93 if idx >= encoded_prompt.text_offset && token_idx == 0 {
95 word_mask[idx] = word_id;
96 }
97 idx += 1;
99 }
100 if idx >= encoded_prompt.text_offset {
102 word_id += 1;
103 }
104 }
105
106 input_id[idx] = 2;
108 attn_mask[idx] = 1;
109
110 input_ids.push_row(ArrayView::from(&input_id))?;
112 attention_masks.push_row(ArrayView::from(&attn_mask))?;
113 word_masks.push_row(ArrayView::from(&word_mask))?;
114 }
115
116 let mut text_lengths = Array::zeros((0, 1));
118 for text_length in input.text_lengths {
119 text_lengths.push_row(ArrayView::from(&vec![text_length as i64]))?;
120 }
121
122 Ok(Self {
124 texts: input.texts,
125 tokens: input.tokens,
126 entities: input.entities,
127 num_words: input.num_words,
128 num_tokens: max_tokens,
129 input_ids,
130 attention_masks,
131 word_masks,
132 text_lengths,
133 })
134 }
135
136}
137
138
139
/// Composable pipeline step that turns a [`PromptInput`] into an
/// [`EncodedInput`] using a borrowed tokenizer.
pub struct PromptsToEncoded<'a, T> {
    // Tokenizer borrowed for the lifetime of the step.
    tokenizer: &'a T,
}
144
145impl<'a, T> PromptsToEncoded<'a, T> {
146 pub fn new(tokenizer: &'a T) -> Self {
147 Self { tokenizer }
148 }
149}
150
151impl<T: Tokenizer> Composable<PromptInput, EncodedInput> for PromptsToEncoded<'_, T> {
152 fn apply(&self, input: PromptInput) -> Result<EncodedInput> {
153 EncodedInput::from(input, self.tokenizer)
154 }
155}
156
157
#[cfg(test)]
mod tests {
    use super::*;

    // End-to-end encoding of a two-text batch: checks padding counts,
    // special-token occurrences and attention-mask lengths.
    // NOTE(review): these tests need the tokenizer model file on disk.
    #[test]
    fn test() -> Result<()> {
        let splitter = crate::text::splitter::RegexSplitter::default();
        let tokenizer = crate::text::tokenizer::HFTokenizer::from_file("models/gliner_small-v2.1/tokenizer.json")?;
        let batch = [ "Short text", "This is a longer one, to test padding and gloubiboulga."];
        let entities = [ "Person", "Place" ];
        let input = super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = PromptInput::from(tokenized);
        let encoded = EncodedInput::from(prepared, &tokenizer)?;
        // Flip to true for ad-hoc debugging output.
        if false {
            println!("### {:?}", encoded.num_tokens);
            println!("Tokens: {:?}", encoded.input_ids);
            println!("Attn Masks: {:?}", encoded.attention_masks);
            println!("Word masks: {:?}", encoded.word_masks);
        }
        // Special token ids of the entity marker and the prompt/text separator.
        const ENT_ID: i64 = 128002;
        const SEP_ID: i64 = 128003;
        assert_eq!(encoded.num_tokens, 22);
        let ids1 = encoded.input_ids.row(0);
        let ids2 = encoded.input_ids.row(1);
        assert_eq!(ids1.len(), encoded.num_tokens);
        assert_eq!(ids2.len(), encoded.num_tokens);
        // Short text is right-padded with zeros; long text fills its row.
        assert_eq!(ids1.iter().filter(|id| **id == 0).count(), 13);
        assert_eq!(ids1.iter().filter(|id| **id == ENT_ID).count(), 2);
        assert_eq!(ids1.iter().filter(|id| **id == SEP_ID).count(), 1);
        assert_eq!(ids2.iter().filter(|id| **id == 0).count(), 0);
        assert_eq!(ids2.iter().filter(|id| **id == ENT_ID).count(), 2);
        assert_eq!(ids2.iter().filter(|id| **id == SEP_ID).count(), 1);
        // Attention covers exactly the non-padding positions of each row.
        let attn1 = encoded.attention_masks.row(0);
        let attn2 = encoded.attention_masks.row(1);
        assert_eq!(attn1.iter().filter(|id| **id == 1).count(), 9);
        assert_eq!(attn2.iter().filter(|id| **id == 1).count(), 22);
        Ok(())
    }

    // Full row-by-row check of ids, attention masks, word masks and text
    // lengths for a three-text batch, including a word that splits into
    // multiple sub-tokens ("Auric": mask only on its first sub-token).
    #[test]
    fn test2() -> Result<()> {
        let splitter = crate::text::splitter::RegexSplitter::default();
        let tokenizer = crate::text::tokenizer::HFTokenizer::from_file(std::path::Path::new("models/gliner_small-v2.1/tokenizer.json"))?;
        let batch = [ "My name is James Bond", "I like to drive my Aston Martin", "The villain in the movie is Auric Goldfinger"];
        let entities = [ "movie character", "vehicle" ];
        let input = super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = PromptInput::from(tokenized);
        let encoded = EncodedInput::from(prepared, &tokenizer)?;
        // Flip to true for ad-hoc debugging output.
        if false {
            println!("### {:?}", encoded.num_tokens);
            println!("Tokens: {:?}", encoded.input_ids);
            println!("Attn Masks: {:?}", encoded.attention_masks);
            println!("Word masks: {:?}", encoded.word_masks);
            println!("Text length: {:?}", encoded.text_lengths);
        }
        let ids1 = encoded.input_ids.row(0);
        let attn1 = encoded.attention_masks.row(0);
        let word1 = encoded.word_masks.row(0);
        let len1 = encoded.text_lengths.row(0);
        // Row layout: BOS, <ENT label ENT label SEP>, text tokens, EOS, padding.
        assert_eq!(ids1.to_vec(), vec![1, 128002, 1421, 1470, 128002, 1508, 128003, 573, 601, 269, 1749, 8728, 2, 0, 0, 0, 0]);
        assert_eq!(attn1.to_vec(), vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]);
        assert_eq!(word1.to_vec(), vec![0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0]);
        assert_eq!(len1.to_vec(), vec![5]);
        let ids2 = encoded.input_ids.row(1);
        let attn2 = encoded.attention_masks.row(1);
        let word2 = encoded.word_masks.row(1);
        let len2 = encoded.text_lengths.row(1);
        assert_eq!(ids2.to_vec(), vec![1, 128002, 1421, 1470, 128002, 1508, 128003, 273, 334, 264, 1168, 312, 20844, 2963, 2, 0, 0]);
        assert_eq!(attn2.to_vec(), vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]);
        assert_eq!(word2.to_vec(), vec![0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0]);
        assert_eq!(len2.to_vec(), vec![7]);
        let ids3 = encoded.input_ids.row(2);
        let attn3 = encoded.attention_masks.row(2);
        let word3 = encoded.word_masks.row(2);
        let len3 = encoded.text_lengths.row(2);
        assert_eq!(ids3.to_vec(), vec! [1, 128002, 1421, 1470, 128002, 1508, 128003, 279, 14701, 267, 262, 1421, 269, 336, 49530, 117349, 2]);
        assert_eq!(attn3.to_vec(), vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);
        // "Auric" spans two sub-tokens: only the first is marked (7), the
        // continuation stays 0, then "Goldfinger" gets 8.
        assert_eq!(word3.to_vec(), vec![0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0, 8, 0]);
        assert_eq!(len3.to_vec(), vec![8]);
        Ok(())
    }

    // A multi-word entity label must stay in the prompt section: word masks
    // start counting at the first *text* word, after the label's tokens.
    #[test]
    fn test_multiword_entity_label() -> Result<()> {
        let splitter = crate::text::splitter::RegexSplitter::default();
        let tokenizer = crate::text::tokenizer::HFTokenizer::from_file("models/gliner_small-v2.1/tokenizer.json")?;
        let batch = [ "this is a test"];
        let entities = [ "multi label" ];
        let input = super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = PromptInput::from(tokenized);
        let encoded = EncodedInput::from(prepared, &tokenizer)?;
        // Flip to true for ad-hoc debugging output.
        if false {
            println!("### {:?}", encoded.num_tokens);
            println!("Tokens: {:?}", encoded.input_ids);
            println!("Attn Masks: {:?}", encoded.attention_masks);
            println!("Word masks: {:?}", encoded.word_masks);
        }
        let ids = encoded.input_ids.row(0);
        assert_eq!(ids.len(), 10);
        let word_masks = encoded.word_masks.row(0);
        assert_eq!(word_masks.to_vec(), vec![0, 0, 0, 0, 0, 1, 2, 3, 4, 0]);
        Ok(())
    }

    // First text word tokenizing into several sub-tokens ("1a"): only its
    // first sub-token carries the word index 1.
    #[test]
    fn test_words_mask_multi_token_first_word() -> Result<()> {
        let splitter = crate::text::splitter::RegexSplitter::default();
        let tokenizer = crate::text::tokenizer::HFTokenizer::from_file("models/gliner_small-v2.1/tokenizer.json")?;
        let batch = [ "1a John Doe"];
        let entities = ["name"];
        let input = super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = PromptInput::from(tokenized);
        let encoded = EncodedInput::from(prepared, &tokenizer)?;

        assert_eq!(encoded.input_ids.row(0).len(), 9);
        assert_eq!(encoded.word_masks.row(0).to_vec(), vec![0, 0, 0, 0, 1, 0, 2, 3, 0]);

        Ok(())
    }

}