1use composable::Composable;
2use crate::{util::error::InputError, params::Parameters};
3
4use super::text::{Labels, TextInput};
5
/// Marker written immediately before every label in a generated prompt.
const LABEL_PREFIX: &str = "<<LABEL>>";
/// Marker written once after the last label, closing the label section of a prompt.
const PROMPT_SEPARATOR: &str = "<<SEP>>";
8
/// Texts rendered as prompt strings, together with their labels.
pub struct PromptInput {
    /// Prompt strings. When produced by `InputToPrompt`, there is one prompt per source text, in the same order.
    pub prompts: Vec<String>,
    /// Labels carried over unchanged from the originating `TextInput` (when produced by `InputToPrompt`).
    pub labels: Labels,
}
14
/// Converts a [`TextInput`] into a [`PromptInput`] by combining each text with a
/// prompt that lists the labels attached to that text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InputToPrompt {
    /// Controls where the label prompt goes: before the text (`true`) or after it (`false`).
    prompt_first: bool,
}
18
19impl InputToPrompt {
24 pub fn new(prompt_first: bool) -> Self {
25 Self { prompt_first }
26 }
27
28 pub fn with_params(params: &Parameters) -> Self {
29 Self::new(params.prompt_first())
30 }
31
32 fn make_prompt(labels: &Vec<String>) -> String {
33 let mut result = String::new();
34 for label in labels {
35 result.push_str(LABEL_PREFIX);
36 result.push_str(&label.to_lowercase());
37 }
38 result.push_str(PROMPT_SEPARATOR);
39 result
40 }
41
42 fn get_labels(labels: &Labels, index: usize) -> Result<&Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
43 labels.get(index).ok_or_else(|| InputError::new("per-text labels must be aligned with texts").boxed())
44 }
45}
46
47impl Default for InputToPrompt {
48 fn default() -> Self {
49 Self { prompt_first: false }
50 }
51}
52
53impl Composable<TextInput, PromptInput> for InputToPrompt {
55 fn apply(&self, input: TextInput) -> composable::Result<PromptInput> {
56 let mut prompts = Vec::with_capacity(input.texts.len());
57 for (index, text) in input.texts.into_iter().enumerate() {
58 let labels = Self::get_labels(&input.labels, index)?;
59 let mut prompt = Self::make_prompt(labels);
60 if self.prompt_first {
61 prompt.push_str(&text);
62 prompts.push(prompt);
63 }
64 else {
65 let mut text = text;
66 text.push_str(&prompt);
67 prompts.push(text);
68 }
69 }
70 Ok(PromptInput {
71 prompts,
72 labels: input.labels
73 })
74 }
75}