// gliner/model/input/prompt.rs

1use composable::Composable;
2use crate::util::result::Result;
3use crate::text::{prompt::Prompt, token::Token};
4use super::tokenized::TokenizedInput;
5
6
/// Prepared prompts, appending entity and text tokens.
///
/// Output form:
/// ```text
/// [<<ENT>>, type1, <<ENT>>, type2, ..., <<ENT>>, typeK, <<SEP>>, token1, token2, ..., tokenN]
/// ```
pub struct PromptInput {
    /// Texts (moved from input)
    pub texts: Vec<String>,
    /// Tokens (moved from input), one sequence per text
    pub tokens: Vec<Vec<Token>>,
    /// Entity type labels (moved from input)
    pub entities: Vec<String>,
    /// Number of tokens of the text part for each prompt
    pub text_lengths: Vec<usize>,
    /// Maximum number of words in a prompt excluding entities
    /// (number of tokens in the largest sequence in the batch)
    pub num_words: usize,
    /// The actual prompts, one per input token sequence
    pub prompts: Vec<Prompt>,
}
27
28
29impl PromptInput {
30
31    pub fn from(input: TokenizedInput) -> Self {
32        // prepare the entities part of the prompt (will be copied into each actual prompt)
33        let entities_prompt = Self::entities_prompt(&input.entities);        
34        // the text lengths for each sequence (number of actual tokens beside the entities part)
35        let mut text_lengths = Vec::<usize>::new();
36        // the maximum number of words in a prompt excluding entities (number of tokens in the largest sequence in the batch)
37        let mut num_words = 0;
38        // the actual prompts that will be created for each token sequence
39        let mut prompts = Vec::new();    
40        
41        // iterate over each sequence of tokens
42        for tokens in &input.tokens {
43            // prepare the sequence of tokens for this prompt
44            let mut prompt = Vec::with_capacity(entities_prompt.len() + tokens.len());
45            // copy the entities part
46            prompt.extend(entities_prompt.clone());
47            // append each text token of the current sequence
48            prompt.extend(tokens.iter().map(|token| token.text().to_string()));
49            // update output data
50            prompts.push(Prompt::new(prompt, tokens.len(), entities_prompt.len()));        
51            text_lengths.push(tokens.len());
52            num_words = std::cmp::max(num_words, tokens.len());
53    
54        }
55    
56        // job's done
57        Self {
58            texts: input.texts,
59            tokens: input.tokens,
60            entities: input.entities,
61            text_lengths,
62            num_words,
63            prompts,            
64        }
65    
66    }
67
68
69    /// Create the entities part of the prompt.
70    fn entities_prompt(entities: &Vec<String>) -> Vec<String> {
71        const ENTITY_TOKEN: &str = "<<ENT>>";
72        const SEP_TOKEN: &str = "<<SEP>>";
73
74        let mut result = Vec::with_capacity(entities.len() * 2 + 1);
75        for entity in entities {
76            result.push(ENTITY_TOKEN.to_string());
77            result.push(entity.clone());
78        }
79
80        result.push(SEP_TOKEN.to_string());
81        result
82    }
83
84}
85
86
/// Composable pipeline step: TokenizedInput => PromptInput.
#[derive(Default)]
pub struct TokenizedToPrompt {
}
91
92
93impl Composable<TokenizedInput, PromptInput> for TokenizedToPrompt {
94    fn apply(&self, input: TokenizedInput) -> Result<PromptInput> {
95        Ok(PromptInput::from(input))
96    }
97}
98
/// Unit tests
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test() -> Result<()> {
        // `unwrap` is acceptable in unit tests
        #![allow(clippy::unwrap_used)]
        // Processing: tokenize a two-text batch and build the prompts
        let splitter = crate::text::splitter::RegexSplitter::default();
        let batch = ["This is a text !", "This is a longer one."];
        let entities = ["Person", "Place"];
        let input = super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = PromptInput::from(tokenized);
        // Assertions
        assert_eq!(prepared.prompts.len(), 2);
        let prompt1 = prepared.prompts.first().unwrap();
        let prompt2 = prepared.prompts.get(1).unwrap();
        // entities part is [<<ENT>>, Person, <<ENT>>, Place, <<SEP>>] = 5 tokens
        assert_eq!(prompt1.tokens().len(), 10);
        assert_eq!(prompt2.tokens().len(), 11);
        assert_eq!(prompt1.text_len(), 5);
        assert_eq!(prompt2.text_len(), 6);
        assert_eq!(prompt1.entities_len(), prompt2.entities_len());
        assert_eq!(prompt1.tokens().get(4).unwrap(), "<<SEP>>");
        assert_eq!(prompt2.tokens().get(5).unwrap(), "This");
        assert_eq!(prompt2.tokens().get(1).unwrap(), entities[0]);
        assert_eq!(prompt2.tokens().get(3).unwrap(), entities[1]);
        assert_eq!(prepared.num_words, prompt2.text_len()); // second prompt has the most tokens
        assert_eq!(prepared.text_lengths.len(), 2);
        assert_eq!(*prepared.text_lengths.first().unwrap(), prompt1.text_len());
        assert_eq!(*prepared.text_lengths.get(1).unwrap(), prompt2.text_len());
        // Everything rules
        Ok(())
    }
}