alith_prompt/
llm_prompt.rs1use std::sync::Arc;
2use std::{
3 collections::HashMap,
4 sync::{Mutex, MutexGuard},
5};
6
7use crate::prompt_message::PromptMessages;
8use crate::{
9 ApiPrompt, LocalPrompt, PromptMessage, PromptMessageType, PromptTokenizer, TextConcatenator,
10 TextConcatenatorTrait,
11};
12
13pub struct LLMPrompt {
20 local_prompt: Option<LocalPrompt>,
21 api_prompt: Option<ApiPrompt>,
22 pub messages: PromptMessages,
23 pub concatenator: TextConcatenator,
24 pub built_prompt_messages: Mutex<Option<Vec<HashMap<String, String>>>>,
25}
26
27impl LLMPrompt {
28 pub fn new_local_prompt(
43 tokenizer: std::sync::Arc<dyn PromptTokenizer>,
44 chat_template: &str,
45 bos_token: Option<&str>,
46 eos_token: &str,
47 unk_token: Option<&str>,
48 base_generation_prefix: Option<&str>,
49 ) -> Self {
50 Self {
51 local_prompt: Some(LocalPrompt::new(
52 tokenizer,
53 chat_template,
54 bos_token,
55 eos_token,
56 unk_token,
57 base_generation_prefix,
58 )),
59 ..Default::default()
60 }
61 }
62
63 pub fn new_api_prompt(
75 tokenizer: std::sync::Arc<dyn PromptTokenizer>,
76 tokens_per_message: Option<u32>,
77 tokens_per_name: Option<i32>,
78 ) -> Self {
79 Self {
80 api_prompt: Some(ApiPrompt::new(
81 tokenizer,
82 tokens_per_message,
83 tokens_per_name,
84 )),
85 ..Default::default()
86 }
87 }
88
89 pub fn add_system_message(&self) -> Result<Arc<PromptMessage>, crate::Error> {
101 {
102 let mut messages = self.messages();
103
104 if !messages.is_empty() {
105 crate::bail!("System message must be first message.");
106 };
107
108 let message = Arc::new(PromptMessage::new(
109 PromptMessageType::System,
110 &self.concatenator,
111 ));
112 messages.push(message);
113 }
114 self.clear_built_prompt();
115 Ok(self.last_message())
116 }
117
118 pub fn add_user_message(&self) -> Result<Arc<PromptMessage>, crate::Error> {
127 {
128 let mut messages = self.messages();
129
130 if let Some(last) = messages.last() {
131 if last.message_type == PromptMessageType::User {
132 crate::bail!("Cannot add user message when previous message is user message.");
133 }
134 }
135
136 let message = Arc::new(PromptMessage::new(
137 PromptMessageType::User,
138 &self.concatenator,
139 ));
140 messages.push(message);
141 }
142 self.clear_built_prompt();
143 Ok(self.last_message())
144 }
145
146 pub fn add_assistant_message(&self) -> Result<Arc<PromptMessage>, crate::Error> {
155 {
156 let mut messages = self.messages();
157
158 if messages.is_empty() {
159 crate::bail!("Cannot add assistant message as first message.");
160 } else if let Some(last) = messages.last() {
161 if last.message_type == PromptMessageType::Assistant {
162 crate::bail!(
163 "Cannot add assistant message when previous message is assistant message."
164 );
165 }
166 };
167
168 let message = Arc::new(PromptMessage::new(
169 PromptMessageType::Assistant,
170 &self.concatenator,
171 ));
172 messages.push(message);
173 }
174 self.clear_built_prompt();
175 Ok(self.last_message())
176 }
177
178 pub fn set_generation_prefix<T: AsRef<str>>(&self, generation_prefix: T) {
187 self.clear_built_prompt();
188 if let Some(local_prompt) = &self.local_prompt {
189 local_prompt.set_generation_prefix(generation_prefix);
190 };
191 }
192
193 pub fn clear_generation_prefix(&self) {
195 self.clear_built_prompt();
196 if let Some(local_prompt) = &self.local_prompt {
197 local_prompt.clear_generation_prefix();
198 };
199 }
200
201 pub fn reset_prompt(&self) {
203 self.messages().clear();
204 self.clear_built_prompt();
205 }
206
207 pub fn clear_built_prompt(&self) {
209 if let Some(api_prompt) = &self.api_prompt {
210 api_prompt.clear_built_prompt();
211 };
212 if let Some(local_prompt) = &self.local_prompt {
213 local_prompt.clear_built_prompt();
214 };
215 }
216
217 pub fn local_prompt(&self) -> Result<&LocalPrompt, crate::Error> {
226 if let Some(local_prompt) = &self.local_prompt {
227 if local_prompt.get_built_prompt().is_err() {
228 self.precheck_build()?;
229 self.build_prompt()?;
230 }
231 Ok(local_prompt)
232 } else {
233 crate::bail!("LocalPrompt is None");
234 }
235 }
236
237 pub fn api_prompt(&self) -> Result<&ApiPrompt, crate::Error> {
243 if let Some(api_prompt) = &self.api_prompt {
244 if api_prompt.get_built_prompt().is_err() {
245 self.precheck_build()?;
246 self.build_prompt()?;
247 }
248 Ok(api_prompt)
249 } else {
250 crate::bail!("ApiPrompt is None");
251 }
252 }
253
254 pub fn get_built_prompt_messages(&self) -> Result<Vec<HashMap<String, String>>, crate::Error> {
272 let built_prompt_messages = self.built_prompt_messages();
273
274 if let Some(built_prompt_messages) = &*built_prompt_messages {
275 return Ok(built_prompt_messages.clone());
276 };
277
278 self.precheck_build()?;
279 self.build_prompt()?;
280 if let Some(built_prompt_messages) = &*built_prompt_messages {
281 Ok(built_prompt_messages.clone())
282 } else {
283 crate::bail!("built_prompt_messages is None after building!");
284 }
285 }
286
287 fn precheck_build(&self) -> crate::Result<()> {
291 if let Some(last) = self.messages().last() {
292 if last.message_type == PromptMessageType::Assistant {
293 crate::bail!(
294 "Cannot build prompt when the current inference message is PromptMessageType::Assistant"
295 )
296 } else if last.message_type == PromptMessageType::System {
297 crate::bail!(
298 "Cannot build prompt when the current inference message is PromptMessageType::System"
299 )
300 } else {
301 Ok(())
302 }
303 } else {
304 crate::bail!("Cannot build prompt when there are no messages.")
305 }
306 }
307
308 fn build_prompt(&self) -> crate::Result<()> {
309 let messages = self.messages();
310 let mut built_prompt_messages: Vec<HashMap<String, String>> = Vec::new();
311 let mut last_message_type = None;
312
313 for (i, message) in messages.iter().enumerate() {
314 let message_type = &message.message_type;
315 if *message_type == PromptMessageType::System && i != 0 {
318 panic!("System message can only be the first message.");
319 }
320 if i == 0
322 && *message_type != PromptMessageType::System
323 && *message_type != PromptMessageType::User
324 {
325 panic!("Conversation must start with either a System or User message.");
326 }
327 if i > 0 {
329 match (last_message_type, message_type) {
330 (Some(PromptMessageType::User), PromptMessageType::Assistant) => {}
331 (Some(PromptMessageType::Assistant), PromptMessageType::User) => {}
332 (Some(PromptMessageType::System), PromptMessageType::User) => {}
333 _ => panic!(
334 "Messages must alternate between User and Assistant after the first message (which can be System)."
335 ),
336 }
337 }
338 last_message_type = Some(message_type.clone());
339
340 if let Some(built_message_string) = &*message.built_prompt_message() {
341 built_prompt_messages.push(HashMap::from([
342 ("role".to_string(), message.message_type.as_str().to_owned()),
343 ("content".to_string(), built_message_string.to_owned()),
344 ]));
345 } else {
346 crate::bail!("message.built_content is empty and skipped");
347 }
348 }
349
350 *self.built_prompt_messages.lock().unwrap_or_else(|e| {
351 panic!(
352 "LlmPrompt Error - built_prompt_messages not available: {:?}",
353 e
354 )
355 }) = Some(built_prompt_messages.clone());
356
357 if let Some(api_prompt) = &self.api_prompt {
358 api_prompt.build_prompt(&built_prompt_messages);
359 };
360 if let Some(local_prompt) = &self.local_prompt {
361 local_prompt.build_prompt(&built_prompt_messages);
362 };
363
364 Ok(())
365 }
366
367 fn messages(&self) -> MutexGuard<'_, Vec<Arc<PromptMessage>>> {
371 self.messages.messages()
372 }
373
374 fn last_message(&self) -> Arc<PromptMessage> {
375 self.messages()
376 .last()
377 .expect("LlmPrompt Error - last message not available")
378 .clone()
379 }
380
381 fn built_prompt_messages(&self) -> MutexGuard<'_, Option<Vec<HashMap<String, String>>>> {
382 self.built_prompt_messages.lock().unwrap_or_else(|e| {
383 panic!(
384 "LlmPrompt Error - built_prompt_messages not available: {:?}",
385 e
386 )
387 })
388 }
389}
390
391impl Default for LLMPrompt {
392 fn default() -> Self {
393 Self {
394 local_prompt: None,
395 api_prompt: None,
396 messages: PromptMessages::default(),
397 concatenator: TextConcatenator::default(),
398 built_prompt_messages: Mutex::new(None),
399 }
400 }
401}
402
403impl Clone for LLMPrompt {
404 fn clone(&self) -> Self {
405 Self {
406 local_prompt: self.local_prompt.clone(),
407 api_prompt: self.api_prompt.clone(),
408 messages: self.messages.clone(),
409 concatenator: self.concatenator.clone(),
410 built_prompt_messages: self.built_prompt_messages().clone().into(),
411 }
412 }
413}
414
415impl std::fmt::Display for LLMPrompt {
416 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417 writeln!(f)?;
418 writeln!(f, "LlmPrompt")?;
419
420 if self.get_built_prompt_messages().is_err() {
422 match self.build_prompt() {
423 Ok(_) => {}
424 Err(e) => {
425 writeln!(f, "Error building prompt: {:?}", e)?;
426 }
427 }
428 }
429
430 if let Some(local_prompt) = &self.local_prompt {
431 write!(f, "{}", local_prompt)?;
432 }
433
434 if let Some(api_prompt) = &self.api_prompt {
435 write!(f, "{}", api_prompt)?;
436 }
437
438 Ok(())
439 }
440}
441
442impl TextConcatenatorTrait for LLMPrompt {
443 fn concatenator_mut(&mut self) -> &mut TextConcatenator {
444 &mut self.concatenator
445 }
446
447 fn clear_built(&self) {
448 self.clear_built_prompt();
449 }
450}