gliclass/input/
prompt.rs

1use composable::Composable;
2use crate::{util::error::InputError, params::Parameters};
3
4use super::text::{Labels, TextInput};
5
/// Marker written in front of every label inside a prompt.
const LABEL_PREFIX: &str = "<<LABEL>>";
/// Marker written exactly once, after the final label of a prompt.
const PROMPT_SEPARATOR: &str = "<<SEP>>";
8
9/// Prompts build from input texts and labels
10pub struct PromptInput {
11    pub prompts: Vec<String>,
12    pub labels: Labels,
13}
14
/// Transformation that turns a `TextInput` into a `PromptInput` by combining
/// each text with a prompt listing its candidate labels.
///
/// The only setting is `prompt_first`: when `true` the label prompt precedes
/// the text, otherwise it is appended after the text.
#[derive(Debug, Clone, Copy)]
pub struct InputToPrompt {
    prompt_first: bool,
}
18
19/// Transformation from text input to prompts.
20///
21/// Prompt format: `[sequence]<<LABEL>>label1<<LABEL>label2...<<SEP>>[sequence]`. 
22/// The actual text comes before or after, depending on the `prompt_first` parameter.
23impl InputToPrompt {
24    pub fn new(prompt_first: bool) -> Self {
25        Self { prompt_first }
26    }
27
28    pub fn with_params(params: &Parameters) -> Self {
29        Self::new(params.prompt_first())
30    }
31
32    fn make_prompt(labels: &Vec<String>) -> String {
33        let mut result = String::new();
34        for label in labels {
35            result.push_str(LABEL_PREFIX);
36            result.push_str(&label.to_lowercase());
37        }
38        result.push_str(PROMPT_SEPARATOR);
39        result
40    }
41
42    fn get_labels(labels: &Labels, index: usize) -> Result<&Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
43        labels.get(index).ok_or_else(|| InputError::new("per-text labels must be aligned with texts").boxed())
44    }
45}
46
47impl Default for InputToPrompt {
48    fn default() -> Self {
49        Self { prompt_first: false }
50    }
51}
52
53/// Transformation from input text to prompts
54impl Composable<TextInput, PromptInput> for InputToPrompt {
55    fn apply(&self, input: TextInput) -> composable::Result<PromptInput> {
56        let mut prompts = Vec::with_capacity(input.texts.len());
57        for (index, text) in input.texts.into_iter().enumerate() {
58            let labels = Self::get_labels(&input.labels, index)?;
59            let mut prompt = Self::make_prompt(labels);
60            if self.prompt_first {
61                prompt.push_str(&text);
62                prompts.push(prompt);
63            }
64            else {
65                let mut text = text;
66                text.push_str(&prompt);
67                prompts.push(text);                
68            }
69        }
70        Ok(PromptInput { 
71            prompts, 
72            labels: input.labels 
73        })
74    }
75}