alith_client/components/
instruct_prompt.rs

1use alith_prompt::{PromptMessage, PromptMessageType, TextConcatenator, TextConcatenatorTrait};
2use anyhow::{Result, anyhow};
3
4#[derive(Clone)]
5pub struct InstructPrompt {
6    pub instructions: Option<PromptMessage>,
7    pub supporting_material: Option<PromptMessage>,
8    pub concatenator: TextConcatenator,
9}
10
11impl Default for InstructPrompt {
12    fn default() -> Self {
13        Self::new()
14    }
15}
16
17impl InstructPrompt {
18    pub fn new() -> Self {
19        Self {
20            instructions: None,
21            supporting_material: None,
22            concatenator: TextConcatenator::default(),
23        }
24    }
25
26    pub fn reset_instruct_prompt(&mut self) {
27        self.instructions = None;
28        self.supporting_material = None;
29    }
30
31    pub fn build_instructions(&self) -> Option<String> {
32        if let Some(instructions) = &self.instructions {
33            instructions.get_built_prompt_message().ok()
34        } else {
35            None
36        }
37    }
38
39    pub fn build_supporting_material(&self) -> Option<String> {
40        if let Some(supporting_material) = &self.supporting_material {
41            supporting_material.get_built_prompt_message().ok()
42        } else {
43            None
44        }
45    }
46
47    pub fn build_instruct_prompt(&mut self, supporting_material_first: bool) -> Result<String> {
48        Ok(
49            match (self.build_instructions(), self.build_supporting_material()) {
50                (Some(instructions), Some(supporting_material)) => {
51                    if supporting_material_first {
52                        format!(
53                            "{}{}{}",
54                            supporting_material,
55                            self.concatenator.as_str(),
56                            instructions
57                        )
58                    } else {
59                        format!(
60                            "{}{}{}",
61                            instructions,
62                            self.concatenator.as_str(),
63                            supporting_material
64                        )
65                    }
66                }
67                (Some(instructions), None) => instructions,
68                (None, Some(supporting_material)) => supporting_material,
69
70                (None, None) => {
71                    return Err(anyhow!("No instructions or supporting material found"));
72                }
73            },
74        )
75    }
76}
77
78impl TextConcatenatorTrait for InstructPrompt {
79    fn concatenator_mut(&mut self) -> &mut TextConcatenator {
80        &mut self.concatenator
81    }
82
83    fn clear_built(&self) {}
84}
85
86pub trait InstructPromptTrait {
87    fn instruct_prompt_mut(&mut self) -> &mut InstructPrompt;
88
89    fn set_instructions<T: AsRef<str>>(&mut self, instructions: T) -> &mut Self {
90        self.instructions().set_content(instructions);
91        self
92    }
93
94    fn instructions(&mut self) -> &mut PromptMessage {
95        if self.instruct_prompt_mut().instructions.is_none() {
96            self.instruct_prompt_mut().instructions = Some(PromptMessage::new(
97                PromptMessageType::User,
98                &self.instruct_prompt_mut().concatenator,
99            ));
100        }
101        self.instruct_prompt_mut().instructions.as_mut().unwrap()
102    }
103
104    fn set_supporting_material<T: AsRef<str>>(&mut self, supporting_material: T) -> &mut Self {
105        self.supporting_material().set_content(supporting_material);
106        self
107    }
108
109    fn supporting_material(&mut self) -> &mut PromptMessage {
110        if self.instruct_prompt_mut().supporting_material.is_none() {
111            self.instruct_prompt_mut().supporting_material = Some(PromptMessage::new(
112                PromptMessageType::User,
113                &self.instruct_prompt_mut().concatenator,
114            ));
115        }
116
117        self.instruct_prompt_mut()
118            .supporting_material
119            .as_mut()
120            .unwrap()
121    }
122}