// alith_prompt/
prompt_message.rs1use super::TextConcatenator;
2use serde::{Deserialize, Serialize};
3use std::sync::{Arc, Mutex, MutexGuard};
4
5#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
11pub enum PromptMessageType {
12 System,
15 User,
17 Assistant,
20 Function,
22}
23
24impl PromptMessageType {
25 pub fn as_str(&self) -> &str {
26 match self {
27 PromptMessageType::System => "system",
28 PromptMessageType::User => "user",
29 PromptMessageType::Assistant => "assistant",
30 PromptMessageType::Function => "function",
31 }
32 }
33}
34
35#[derive(Default, Debug)]
41pub struct PromptMessages(Mutex<Vec<Arc<PromptMessage>>>);
42
43impl PromptMessages {
44 pub(crate) fn messages(&self) -> MutexGuard<'_, Vec<Arc<PromptMessage>>> {
45 self.0
46 .lock()
47 .unwrap_or_else(|e| panic!("PromptMessages Error - messages not available: {:?}", e))
48 }
49}
50
51impl Clone for PromptMessages {
52 fn clone(&self) -> Self {
53 let cloned_messages: Vec<Arc<PromptMessage>> = self
54 .messages()
55 .iter()
56 .map(|message| Arc::new((**message).clone()))
57 .collect();
58 Self(cloned_messages.into())
59 }
60}
61
62impl std::fmt::Display for PromptMessages {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 let messages = self.messages();
65 for message in messages.iter() {
66 writeln!(f, "{}", message)?;
67 }
68 Ok(())
69 }
70}
71
72#[derive(Serialize, Deserialize, Debug)]
78pub struct PromptMessage {
79 pub content: Mutex<Vec<String>>,
80 pub built_prompt_message: Mutex<Option<String>>,
81 pub message_type: PromptMessageType,
82 pub concatenator: TextConcatenator,
83}
84
85impl PromptMessage {
86 pub fn new(message_type: PromptMessageType, concatenator: &TextConcatenator) -> Self {
87 Self {
88 content: Vec::new().into(),
89 built_prompt_message: None.into(),
90 message_type,
91 concatenator: concatenator.clone(),
92 }
93 }
94
95 pub fn set_content<T: AsRef<str>>(&self, content: T) -> &Self {
111 if content.as_ref().is_empty() {
112 return self;
113 }
114
115 let mut content_guard = self.content();
116 let should_update = content_guard
117 .first()
118 .is_none_or(|first| first != content.as_ref());
119
120 if should_update {
121 *content_guard = vec![content.as_ref().to_owned()];
122 self.build(content_guard);
123 }
124
125 self
126 }
127
128 pub fn prepend_content<T: AsRef<str>>(&self, content: T) -> &Self {
142 if content.as_ref().is_empty() {
143 return self;
144 }
145
146 let mut content_guard = self.content();
147 let should_update = content_guard
148 .first()
149 .is_none_or(|first| first != content.as_ref());
150
151 if should_update {
152 content_guard.insert(0, content.as_ref().to_owned());
153 self.build(content_guard);
154 }
155
156 self
157 }
158
159 pub fn append_content<T: AsRef<str>>(&self, content: T) -> &Self {
172 if content.as_ref().is_empty() {
173 return self;
174 }
175
176 let mut content_guard = self.content();
177 let should_update = content_guard
178 .last()
179 .is_none_or(|last| last != content.as_ref());
180
181 if should_update {
182 content_guard.push(content.as_ref().to_owned());
183 self.build(content_guard);
184 }
185
186 self
187 }
188
189 pub fn get_built_prompt_message(&self) -> Result<String, crate::Error> {
205 match &*self.built_prompt_message() {
206 Some(prompt) => Ok(prompt.clone()),
207 None => crate::bail!(
208 " PromptMessage Error - built_prompt_string not available - message not built"
209 ),
210 }
211 }
212
213 fn build(&self, content_guard: MutexGuard<'_, Vec<String>>) {
217 let mut built_prompt_message = String::new();
218
219 for c in content_guard.iter() {
220 if !built_prompt_message.is_empty() {
221 built_prompt_message.push_str(self.concatenator.as_str());
222 }
223 built_prompt_message.push_str(c.as_str());
224 }
225
226 *self.built_prompt_message() = Some(built_prompt_message);
227 }
228
229 fn content(&self) -> MutexGuard<'_, Vec<String>> {
233 self.content
234 .lock()
235 .unwrap_or_else(|e| panic!("PromptMessage Error - content not available: {:?}", e))
236 }
237
238 pub(crate) fn built_prompt_message(&self) -> MutexGuard<'_, Option<String>> {
239 self.built_prompt_message.lock().unwrap_or_else(|e| {
240 panic!(
241 "PromptMessage Error - built_prompt_message not available: {:?}",
242 e
243 )
244 })
245 }
246}
247
248impl Clone for PromptMessage {
249 fn clone(&self) -> Self {
250 Self {
251 content: self.content().clone().into(),
252 built_prompt_message: self.built_prompt_message().clone().into(),
253 message_type: self.message_type.clone(),
254 concatenator: self.concatenator.clone(),
255 }
256 }
257}
258
259impl std::fmt::Display for PromptMessage {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 let message_type = match self.message_type {
262 PromptMessageType::System => "System",
263 PromptMessageType::User => "User",
264 PromptMessageType::Assistant => "Assistant",
265 PromptMessageType::Function => "Function",
266 };
267 let message = match &*self.built_prompt_message() {
268 Some(built_message_string) => {
269 if built_message_string.len() > 300 {
270 format!(
271 "{}...",
272 built_message_string.chars().take(300).collect::<String>()
273 )
274 } else {
275 built_message_string.clone()
276 }
277 }
278 None => "debug message: empty or unbuilt".to_owned(),
279 };
280
281 writeln!(f, "\x1b[1m{message_type}\x1b[0m:\n{:?}", message)
282 }
283}