gliner/model/input/
prompt.rs1use composable::Composable;
2use crate::util::result::Result;
3use crate::text::{prompt::Prompt, token::Token};
4use super::tokenized::TokenizedInput;
5
6
pub struct PromptInput {
 /// Original input texts, carried through from `TokenizedInput`.
 pub texts: Vec<String>,
 /// Per-text token sequences, carried through from `TokenizedInput`.
 pub tokens: Vec<Vec<Token>>,
 /// Entity labels to detect; these are woven into each prompt.
 pub entities: Vec<String>,
 /// Token count of each text (parallel to `tokens` / `prompts`).
 pub text_lengths: Vec<usize>,
 /// Maximum token count over all texts in the batch.
 pub num_words: usize,
 /// One prompt per text: the entity prefix followed by the text's tokens.
 pub prompts: Vec<Prompt>,
}
27
28
29impl PromptInput {
30
31 pub fn from(input: TokenizedInput) -> Self {
32 let entities_prompt = Self::entities_prompt(&input.entities);
34 let mut text_lengths = Vec::<usize>::new();
36 let mut num_words = 0;
38 let mut prompts = Vec::new();
40
41 for tokens in &input.tokens {
43 let mut prompt = Vec::with_capacity(entities_prompt.len() + tokens.len());
45 prompt.extend(entities_prompt.clone());
47 prompt.extend(tokens.iter().map(|token| token.text().to_string()));
49 prompts.push(Prompt::new(prompt, tokens.len(), entities_prompt.len()));
51 text_lengths.push(tokens.len());
52 num_words = std::cmp::max(num_words, tokens.len());
53
54 }
55
56 Self {
58 texts: input.texts,
59 tokens: input.tokens,
60 entities: input.entities,
61 text_lengths,
62 num_words,
63 prompts,
64 }
65
66 }
67
68
69 fn entities_prompt(entities: &Vec<String>) -> Vec<String> {
71 const ENTITY_TOKEN: &str = "<<ENT>>";
72 const SEP_TOKEN: &str = "<<SEP>>";
73
74 let mut result = Vec::with_capacity(entities.len() * 2 + 1);
75 for entity in entities {
76 result.push(ENTITY_TOKEN.to_string());
77 result.push(entity.clone());
78 }
79
80 result.push(SEP_TOKEN.to_string());
81 result
82 }
83
84}
85
86
/// Stateless pipeline stage that converts a `TokenizedInput` into a
/// `PromptInput` (see the `Composable` impl below for the wiring).
#[derive(Default)]
pub struct TokenizedToPrompt {
}
91
92
impl Composable<TokenizedInput, PromptInput> for TokenizedToPrompt {
 /// Delegates to `PromptInput::from`; the conversion itself is infallible,
 /// so this always returns `Ok`.
 fn apply(&self, input: TokenizedInput) -> Result<PromptInput> {
 Ok(PromptInput::from(input))
 }
}
98
#[cfg(test)]
mod tests {
 use super::*;

 /// End-to-end check: text -> tokenized -> prompt, verifying prompt
 /// layout, per-text lengths, and batch-level aggregates.
 #[test]
 fn test() -> Result<()> {
 #![allow(clippy::get_first)]
 #![allow(clippy::unwrap_used)]
 let splitter = crate::text::splitter::RegexSplitter::default();
 let batch = ["This is a text !", "This is a longer one."];
 let entities = ["Person", "Place"];
 let input = super::super::text::TextInput::from_str(&batch, &entities)?;
 let tokenized = super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
 let prepared = PromptInput::from(tokenized);

 // One prompt per input text.
 assert_eq!(prepared.prompts.len(), 2);
 let first = prepared.prompts.get(0).unwrap();
 let second = prepared.prompts.get(1).unwrap();

 // Prompt = 5-token entity prefix (<<ENT>> Person <<ENT>> Place <<SEP>>)
 // followed by the text tokens.
 assert_eq!(first.tokens().len(), 10);
 assert_eq!(second.tokens().len(), 11);
 assert_eq!(first.text_len(), 5);
 assert_eq!(second.text_len(), 6);
 assert_eq!(first.entities_len(), second.entities_len());
 assert_eq!(first.tokens().get(4).unwrap(), "<<SEP>>");
 assert_eq!(second.tokens().get(5).unwrap(), "This");
 assert_eq!(second.tokens().get(1).unwrap(), entities[0]);
 assert_eq!(second.tokens().get(3).unwrap(), entities[1]);

 // Batch aggregates: num_words is the longest text's token count,
 // text_lengths records each text's count in order.
 assert_eq!(prepared.num_words, second.text_len());
 assert_eq!(prepared.text_lengths.len(), 2);
 assert_eq!(*prepared.text_lengths.get(0).unwrap(), first.text_len());
 assert_eq!(*prepared.text_lengths.get(1).unwrap(), second.text_len());
 Ok(())
 }
}