//! gliner/model/input/encoded.rs
1use composable::Composable;
2use crate::util::result::Result;
3use crate::text::{token::Token, tokenizer::Tokenizer};
4use super::prompt::PromptInput;
5use ndarray::{Array, Array2, ArrayView};
6
/// Represents encoded prompts (after sub-word tokenization).
///
/// Produced by [`EncodedInput::from`]: all 2-D arrays have one row per prompt
/// and are zero-padded to `num_tokens` columns.
pub struct EncodedInput {
    /// Original input texts, carried over from the upstream input.
    pub texts: Vec<String>,
    /// Word-level tokens of each text, carried over from the upstream input.
    pub tokens: Vec<Vec<Token>>,
    /// Entity labels, carried over from the upstream input.
    pub entities: Vec<String>,
    /// Word count, carried over from the upstream input.
    pub num_words: usize,
    /// Width of the padded tensors: the sub-word token count of the longest prompt
    /// (including the initial and terminal tokens).
    pub num_tokens: usize,
    /// Sub-word token ids; `1` = start token, `2` = end token, `0` = padding.
    pub input_ids: Array2<i64>,
    /// `1` for every real token (including start/end), `0` for padding.
    pub attention_masks: Array2<i64>,
    /// For the text part only: word index (starting at 1) on the *first* sub-word
    /// token of each word, `0` everywhere else (label tokens, continuations, padding).
    pub word_masks: Array2<i64>,
    /// One row per text, single column: the text length as `i64`
    /// (2-D only because the model expects a two-dimensional input).
    pub text_lengths: Array2<i64>,
}
19
/// Utility struct: intermediate per-prompt result used while building [`EncodedInput`].
struct EncodedPrompt {
    /// Encodings of each word: one `Vec<u32>` of sub-word token ids per word,
    /// in prompt order (entity labels first, then the actual text).
    encoding: Vec<Vec<u32>>,
    /// Offset of the first token of the actual text (past the entity labels
    /// and the start token).
    text_offset: usize,
}
27
28impl EncodedInput {
29
30    // Each word of each prompt is encoded *one by one*. So each word generates an encoding as 
31    // a Vec<u32> (sub-word tokenization). So for each prompt we get a Vec<Vec<u32>> (which is 
32    // stored in the 'encoding' field).
33    pub fn from(input: PromptInput, tokenizer: &impl Tokenizer) -> Result<Self> {
34        // prepare the result vector
35        let mut encodings: Vec<EncodedPrompt> = Vec::with_capacity(input.prompts.len());
36        // maximum number of sub-word tokens found in one prompt (will be the width of the input tensor)
37        let mut max_tokens: usize = 0;
38        // process each prompt
39        for prompt in &input.prompts {
40            // resulting sequence of encodings for each word of the current prompt
41            let mut prompt_tokens: Vec<Vec<u32>> = Vec::with_capacity(prompt.tokens().len());
42            // total number of sub-word tokens for the current prompt (adding 2 for initial and terminal tokens)
43            let mut total_tokens: usize = 2;
44            // number of sub-word tokens for the entities part only (before the actual text)
45            let mut total_entity_tokens = 0;
46            // encode each token of the current prompt
47            for (pos, word) in prompt.tokens().iter().enumerate() {
48                // actually encode the word
49                let encoding = tokenizer.encode(word)?;
50                // increment the number of sub-word tokens accordingly
51                total_tokens += encoding.len();
52                // increment the number of sub-word tokens in the entity part (will be used to start the word masks at the right place)
53                if pos < prompt.entities_len() {
54                    total_entity_tokens += encoding.len();
55                }
56                prompt_tokens.push(encoding);
57            }
58
59            // Adding 1 for the start token
60            let text_offset = total_entity_tokens + 1;
61
62            // update global result: push encoded prompt and update max_tokens
63            encodings.push(EncodedPrompt { encoding: prompt_tokens, text_offset });
64            max_tokens = std::cmp::max(max_tokens, total_tokens);
65        }
66
67        // Compute vectors for each prompt. The `encoding` structure (which is
68        // word by word) gets flattened, but the word-level information is
69        // still represented by the "word mask".
70        let mut input_ids = Array::zeros((0, max_tokens));
71        let mut attention_masks = Array::zeros((0, max_tokens));
72        let mut word_masks = Array::zeros((0, max_tokens));
73        for encoded_prompt in encodings {
74            let encoding = encoded_prompt.encoding;
75            let mut input_id = vec!(0i64; max_tokens);
76            let mut attn_mask = vec!(0i64; max_tokens);
77            let mut word_mask = vec!(0i64; max_tokens);
78
79            let mut idx: usize = 0;
80            let mut word_id: i64 = 0;
81
82            // add initial token
83            input_id[idx] = 1;
84            attn_mask[idx] = 1;
85            idx += 1;
86
87            // process each encoded (sub-word) token
88            for word in encoding {
89                for (token_idx, token) in word.iter().enumerate() {
90                    input_id[idx] = *token as i64;
91                    // attention mask
92                    attn_mask[idx] = 1;
93                    // word mask (only for non-label tokens and first token of the word)
94                    if idx >= encoded_prompt.text_offset && token_idx == 0 {
95                        word_mask[idx] = word_id;
96                    }
97                    // update position
98                    idx += 1;
99                }
100                // increment word mask (if we are over the label tokens)
101                if idx >= encoded_prompt.text_offset {
102                    word_id += 1;
103                }
104            }
105
106            // add terminal token
107            input_id[idx] = 2;
108            attn_mask[idx] = 1;
109
110            // update final results
111            input_ids.push_row(ArrayView::from(&input_id))?;
112            attention_masks.push_row(ArrayView::from(&attn_mask))?;
113            word_masks.push_row(ArrayView::from(&word_mask))?;
114        }
115
116        // text lengths (this data is fundamentally one-dimensional, but the model expects a two-dimensional one)
117        let mut text_lengths = Array::zeros((0, 1));
118        for text_length in input.text_lengths {
119            text_lengths.push_row(ArrayView::from(&vec![text_length as i64]))?;
120        }
121
122        // job's done
123        Ok(Self {
124            texts: input.texts,
125            tokens: input.tokens,
126            entities: input.entities,
127            num_words: input.num_words,
128            num_tokens: max_tokens,
129            input_ids,
130            attention_masks,
131            word_masks,
132            text_lengths,
133        })
134    }
135
136}
137
138
139
/// Composable: Prompts => Encoded.
///
/// Wraps a borrowed tokenizer so the prompt-to-encoded step can be chained
/// with other `Composable` stages.
pub struct PromptsToEncoded<'a, T> {
    /// Tokenizer used for sub-word encoding; borrowed for the composable's lifetime.
    tokenizer: &'a T,
}
144
145impl<'a, T> PromptsToEncoded<'a, T> {
146    pub fn new(tokenizer: &'a T) -> Self {
147        Self { tokenizer }
148    }
149}
150
impl<T: Tokenizer> Composable<PromptInput, EncodedInput> for PromptsToEncoded<'_, T> {
    /// Delegates to [`EncodedInput::from`] with the wrapped tokenizer.
    fn apply(&self, input: PromptInput) -> Result<EncodedInput> {
        EncodedInput::from(input, self.tokenizer)
    }
}
156
157
/// Unit tests
///
/// NOTE(review): these tests read a tokenizer model from
/// `models/gliner_small-v2.1/tokenizer.json` on disk, so they only run in a
/// checkout that ships that model. All expected token ids below were produced
/// by that specific tokenizer — they must be regenerated if the model changes.
#[cfg(test)]
mod tests {
    use super::*;

    /// Padding behavior: two texts of different lengths share one tensor width.
    #[test]
    fn test() -> Result<()> {
        let splitter = crate::text::splitter::RegexSplitter::default();
        let tokenizer = crate::text::tokenizer::HFTokenizer::from_file("models/gliner_small-v2.1/tokenizer.json")?;
        let batch = [ "Short text", "This is a longer one, to test padding and gloubiboulga."];
        let entities = [ "Person", "Place" ];
        // Full pipeline: text -> tokenized -> prompts -> encoded
        let input = super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = PromptInput::from(tokenized);
        let encoded = EncodedInput::from(prepared, &tokenizer)?;
        // Some prints (debug aid, disabled; flip to `true` when investigating)
        if false {
            println!("### {:?}", encoded.num_tokens);
            println!("Tokens: {:?}", encoded.input_ids);
            println!("Attn Masks: {:?}", encoded.attention_masks);
            println!("Word masks: {:?}", encoded.word_masks);
        }
        // Assertions on input ids
        // Special markers emitted by the prompt builder: entity-label and separator token ids
        const ENT_ID: i64 = 128002;
        const SEP_ID: i64 = 128003;
        assert_eq!(encoded.num_tokens, 22);
        let ids1 = encoded.input_ids.row(0);
        let ids2 = encoded.input_ids.row(1);
        // Both rows are padded to the same width
        assert_eq!(ids1.len(), encoded.num_tokens);
        assert_eq!(ids2.len(), encoded.num_tokens);
        // Short text: 13 padding zeros; long text: none
        assert_eq!(ids1.iter().filter(|id| **id == 0).count(), 13);
        assert_eq!(ids1.iter().filter(|id| **id == ENT_ID).count(), 2);
        assert_eq!(ids1.iter().filter(|id| **id == SEP_ID).count(), 1);
        assert_eq!(ids2.iter().filter(|id| **id == 0).count(), 0);
        assert_eq!(ids2.iter().filter(|id| **id == ENT_ID).count(), 2);
        assert_eq!(ids2.iter().filter(|id| **id == SEP_ID).count(), 1);
        // Assertions on attention mask: count of 1s == real (non-padding) tokens
        let attn1 = encoded.attention_masks.row(0);
        let attn2 = encoded.attention_masks.row(1);
        assert_eq!(attn1.iter().filter(|id| **id == 1).count(), 9);
        assert_eq!(attn2.iter().filter(|id| **id == 1).count(), 22);
        // Everything rules
        Ok(())
    }

    /// Exact row contents (ids, attention, word masks, text lengths) for a 3-text batch.
    #[test]
    fn test2() -> Result<()> {
        let splitter = crate::text::splitter::RegexSplitter::default();
        // NOTE(review): same file as `test()` but passed as a `Path` here —
        // presumably `from_file` accepts `impl AsRef<Path>`; consider unifying the call style.
        let tokenizer = crate::text::tokenizer::HFTokenizer::from_file(std::path::Path::new("models/gliner_small-v2.1/tokenizer.json"))?;
        let batch = [ "My name is James Bond", "I like to drive my Aston Martin", "The villain in the movie is Auric Goldfinger"];
        let entities = [ "movie character", "vehicle" ];
        let input = super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = PromptInput::from(tokenized);
        let encoded = EncodedInput::from(prepared, &tokenizer)?;
        // Some prints (debug aid, disabled)
        if false {
            println!("### {:?}", encoded.num_tokens);
            println!("Tokens: {:?}", encoded.input_ids);
            println!("Attn Masks: {:?}", encoded.attention_masks);
            println!("Word masks: {:?}", encoded.word_masks);
            println!("Text length: {:?}", encoded.text_lengths);
        }
        // Assertions on first sequence
        // Layout: [start=1, ENT, label…, ENT, label…, SEP, text tokens…, end=2, padding zeros]
        let ids1 = encoded.input_ids.row(0);
        let attn1 = encoded.attention_masks.row(0);
        let word1 = encoded.word_masks.row(0);
        let len1 = encoded.text_lengths.row(0);
        assert_eq!(ids1.to_vec(), vec![1, 128002, 1421, 1470, 128002, 1508, 128003, 573, 601, 269, 1749, 8728, 2, 0, 0, 0, 0]);
        assert_eq!(attn1.to_vec(), vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]);
        // Word mask: 0 over the label part, then 1-based word ids on the text tokens
        assert_eq!(word1.to_vec(), vec![0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0]);
        assert_eq!(len1.to_vec(), vec![5]);
        // Assertions on second sequence
        let ids2 = encoded.input_ids.row(1);
        let attn2 = encoded.attention_masks.row(1);
        let word2 = encoded.word_masks.row(1);
        let len2 = encoded.text_lengths.row(1);
        assert_eq!(ids2.to_vec(), vec![1, 128002, 1421, 1470, 128002, 1508, 128003, 273, 334, 264, 1168, 312, 20844, 2963, 2, 0, 0]);
        assert_eq!(attn2.to_vec(), vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]);
        assert_eq!(word2.to_vec(), vec![0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0]);
        assert_eq!(len2.to_vec(), vec![7]);
        // Assertions on third sequence
        // "Goldfinger" splits into two sub-word tokens: only the first carries word id 8
        let ids3 = encoded.input_ids.row(2);
        let attn3 = encoded.attention_masks.row(2);
        let word3 = encoded.word_masks.row(2);
        let len3 = encoded.text_lengths.row(2);
        assert_eq!(ids3.to_vec(), vec! [1, 128002, 1421, 1470, 128002, 1508, 128003, 279, 14701, 267, 262, 1421, 269, 336, 49530, 117349, 2]);
        assert_eq!(attn3.to_vec(), vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]);
        assert_eq!(word3.to_vec(), vec![0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 0, 8, 0]);
        assert_eq!(len3.to_vec(), vec![8]);
        Ok(())
    }

    /// A multi-word entity label must still count entirely as the "label part":
    /// the word mask stays 0 over both label words and starts at 1 on the text.
    #[test]
    fn test_multiword_entity_label() -> Result<()> {
        let splitter = crate::text::splitter::RegexSplitter::default();
        let tokenizer = crate::text::tokenizer::HFTokenizer::from_file("models/gliner_small-v2.1/tokenizer.json")?;
        let batch = [ "this is a test"];
        let entities = [ "multi label" ];
        let input = super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = PromptInput::from(tokenized);
        let encoded = EncodedInput::from(prepared, &tokenizer)?;
        // Some prints (debug aid, disabled)
        if false {
            println!("### {:?}", encoded.num_tokens);
            println!("Tokens: {:?}", encoded.input_ids);
            println!("Attn Masks: {:?}", encoded.attention_masks);
            println!("Word masks: {:?}", encoded.word_masks);
        }
        // Assertions
        let ids = encoded.input_ids.row(0);
        assert_eq!(ids.len(), 10);
        let word_masks = encoded.word_masks.row(0);
        assert_eq!(word_masks.to_vec(), vec![0, 0, 0, 0, 0, 1, 2, 3, 4, 0]);
        // Everything rules
        Ok(())
    }

    /// A multi-token FIRST word of the text: only its first sub-word token
    /// carries word id 1; the continuation token stays 0.
    #[test]
    fn test_words_mask_multi_token_first_word() -> Result<()> {
        let splitter = crate::text::splitter::RegexSplitter::default();
        let tokenizer = crate::text::tokenizer::HFTokenizer::from_file("models/gliner_small-v2.1/tokenizer.json")?;
        // "1a" is encoded with 2 tokens, the rest are 1
        let batch = [ "1a John Doe"];
        let entities = ["name"];
        let input = super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = PromptInput::from(tokenized);
        let encoded = EncodedInput::from(prepared, &tokenizer)?;

        assert_eq!(encoded.input_ids.row(0).len(), 9);
        assert_eq!(encoded.word_masks.row(0).to_vec(), vec![0, 0, 0, 0, 1, 0, 2, 3, 0]);

        Ok(())
    }

}