// alith_prompt/llm_prompt.rs

1use std::sync::Arc;
2use std::{
3    collections::HashMap,
4    sync::{Mutex, MutexGuard},
5};
6
7use crate::prompt_message::PromptMessages;
8use crate::{
9    ApiPrompt, LocalPrompt, PromptMessage, PromptMessageType, PromptTokenizer, TextConcatenator,
10    TextConcatenatorTrait,
11};
12
/// A prompt management system that supports both API-based LLMs (like OpenAI) and local LLMs.
///
/// `LlmPrompt` provides a unified interface for building and managing prompts in different formats,
/// with support for both API-style messaging (system/user/assistant) and local LLM chat templates.
/// It handles token counting, message validation, and proper prompt formatting.
pub struct LLMPrompt {
    // Chat-template state for local models; `None` when this prompt is API-only.
    local_prompt: Option<LocalPrompt>,
    // API-style state (e.g. OpenAI); `None` when this prompt is local-only.
    api_prompt: Option<ApiPrompt>,
    // Ordered conversation messages (system/user/assistant).
    pub messages: PromptMessages,
    // Controls how pieces of message content are joined together.
    pub concatenator: TextConcatenator,
    // Cache of the built messages; `None` until built, cleared on any mutation.
    pub built_prompt_messages: Mutex<Option<Vec<HashMap<String, String>>>>,
}
26
27impl LLMPrompt {
28    /// Creates a new prompt instance configured for local LLMs using chat templates.
29    ///
30    /// # Arguments
31    ///
32    /// * `tokenizer` - A tokenizer implementation for counting tokens
33    /// * `chat_template` - The chat template string used to format messages
34    /// * `bos_token` - Optional beginning of sequence token
35    /// * `eos_token` - End of sequence token
36    /// * `unk_token` - Optional unknown token
37    /// * `base_generation_prefix` - Optional prefix to add before generation
38    ///
39    /// # Returns
40    ///
41    /// A new `LlmPrompt` instance configured for local LLM usage.
42    pub fn new_local_prompt(
43        tokenizer: std::sync::Arc<dyn PromptTokenizer>,
44        chat_template: &str,
45        bos_token: Option<&str>,
46        eos_token: &str,
47        unk_token: Option<&str>,
48        base_generation_prefix: Option<&str>,
49    ) -> Self {
50        Self {
51            local_prompt: Some(LocalPrompt::new(
52                tokenizer,
53                chat_template,
54                bos_token,
55                eos_token,
56                unk_token,
57                base_generation_prefix,
58            )),
59            ..Default::default()
60        }
61    }
62
63    /// Creates a new prompt instance configured for API-based LLMs like OpenAI.
64    ///
65    /// # Arguments
66    ///
67    /// * `tokenizer` - A tokenizer implementation for counting tokens
68    /// * `tokens_per_message` - Optional number of tokens to add per message (model-specific)
69    /// * `tokens_per_name` - Optional number of tokens to add for names (model-specific)
70    ///
71    /// # Returns
72    ///
73    /// A new `LlmPrompt` instance configured for API usage.
74    pub fn new_api_prompt(
75        tokenizer: std::sync::Arc<dyn PromptTokenizer>,
76        tokens_per_message: Option<u32>,
77        tokens_per_name: Option<i32>,
78    ) -> Self {
79        Self {
80            api_prompt: Some(ApiPrompt::new(
81                tokenizer,
82                tokens_per_message,
83                tokens_per_name,
84            )),
85            ..Default::default()
86        }
87    }
88
89    // Setter methods
90    //
91
92    /// Adds a system message to the prompt.
93    ///
94    /// System messages must be the first message in the sequence.
95    /// Returns an error if attempting to add a system message after other messages.
96    ///
97    /// # Returns
98    ///
99    /// A reference to the newly created message for setting content, or an error if validation fails.
100    pub fn add_system_message(&self) -> Result<Arc<PromptMessage>, crate::Error> {
101        {
102            let mut messages = self.messages();
103
104            if !messages.is_empty() {
105                crate::bail!("System message must be first message.");
106            };
107
108            let message = Arc::new(PromptMessage::new(
109                PromptMessageType::System,
110                &self.concatenator,
111            ));
112            messages.push(message);
113        }
114        self.clear_built_prompt();
115        Ok(self.last_message())
116    }
117
118    /// Adds a user message to the prompt.
119    ///
120    /// Cannot add a user message directly after another user message.
121    /// Returns an error if attempting to add consecutive user messages.
122    ///
123    /// # Returns
124    ///
125    /// A reference to the newly created message for setting content, or an error if validation fails.
126    pub fn add_user_message(&self) -> Result<Arc<PromptMessage>, crate::Error> {
127        {
128            let mut messages = self.messages();
129
130            if let Some(last) = messages.last() {
131                if last.message_type == PromptMessageType::User {
132                    crate::bail!("Cannot add user message when previous message is user message.");
133                }
134            }
135
136            let message = Arc::new(PromptMessage::new(
137                PromptMessageType::User,
138                &self.concatenator,
139            ));
140            messages.push(message);
141        }
142        self.clear_built_prompt();
143        Ok(self.last_message())
144    }
145
146    /// Adds an assistant message to the prompt.
147    ///
148    /// Cannot be the first message or follow another assistant message.
149    /// Returns an error if attempting to add as first message or after another assistant message.
150    ///
151    /// # Returns
152    ///
153    /// A reference to the newly created message for setting content, or an error if validation fails.
154    pub fn add_assistant_message(&self) -> Result<Arc<PromptMessage>, crate::Error> {
155        {
156            let mut messages = self.messages();
157
158            if messages.is_empty() {
159                crate::bail!("Cannot add assistant message as first message.");
160            } else if let Some(last) = messages.last() {
161                if last.message_type == PromptMessageType::Assistant {
162                    crate::bail!(
163                        "Cannot add assistant message when previous message is assistant message."
164                    );
165                }
166            };
167
168            let message = Arc::new(PromptMessage::new(
169                PromptMessageType::Assistant,
170                &self.concatenator,
171            ));
172            messages.push(message);
173        }
174        self.clear_built_prompt();
175        Ok(self.last_message())
176    }
177
178    /// Sets a prefix to be added before generation for local LLMs.
179    ///
180    /// This is typically used to prime the model's response.
181    /// Only applies to local LLM prompts, has no effect on API prompts.
182    ///
183    /// # Arguments
184    ///
185    /// * `generation_prefix` - The text to add before generation
186    pub fn set_generation_prefix<T: AsRef<str>>(&self, generation_prefix: T) {
187        self.clear_built_prompt();
188        if let Some(local_prompt) = &self.local_prompt {
189            local_prompt.set_generation_prefix(generation_prefix);
190        };
191    }
192
193    /// Clears any previously set generation prefix.
194    pub fn clear_generation_prefix(&self) {
195        self.clear_built_prompt();
196        if let Some(local_prompt) = &self.local_prompt {
197            local_prompt.clear_generation_prefix();
198        };
199    }
200
201    /// Resets the prompt, clearing all messages and built state.
202    pub fn reset_prompt(&self) {
203        self.messages().clear();
204        self.clear_built_prompt();
205    }
206
207    /// Clears any built prompt state, forcing a rebuild on next access.
208    pub fn clear_built_prompt(&self) {
209        if let Some(api_prompt) = &self.api_prompt {
210            api_prompt.clear_built_prompt();
211        };
212        if let Some(local_prompt) = &self.local_prompt {
213            local_prompt.clear_built_prompt();
214        };
215    }
216
217    // Getter methods
218    //
219
220    /// Gets and builds the local prompt if this is prompt has one. This method is required to unwrap the prompt and build it.
221    ///
222    /// # Returns
223    ///
224    /// A reference to the `LocalPrompt` if present, otherwise returns an error
225    pub fn local_prompt(&self) -> Result<&LocalPrompt, crate::Error> {
226        if let Some(local_prompt) = &self.local_prompt {
227            if local_prompt.get_built_prompt().is_err() {
228                self.precheck_build()?;
229                self.build_prompt()?;
230            }
231            Ok(local_prompt)
232        } else {
233            crate::bail!("LocalPrompt is None");
234        }
235    }
236
237    /// Gets and builds the API prompt if this is prompt has one. This method is required to unwrap the prompt and build it.
238    ///
239    /// # Returns
240    ///
241    /// A reference to the `ApiPrompt` if present, otherwise returns an error
242    pub fn api_prompt(&self) -> Result<&ApiPrompt, crate::Error> {
243        if let Some(api_prompt) = &self.api_prompt {
244            if api_prompt.get_built_prompt().is_err() {
245                self.precheck_build()?;
246                self.build_prompt()?;
247            }
248            Ok(api_prompt)
249        } else {
250            crate::bail!("ApiPrompt is None");
251        }
252    }
253
254    /// Retrieves the prompt messages in a standardized format compatible with API calls.
255    ///
256    /// This method returns messages in the same format as `ApiPrompt::get_built_prompt()`,
257    /// making it useful for consistent message handling across different LLM implementations.
258    /// The method handles lazy building of the prompt - if the messages haven't been built yet,
259    /// it will trigger the build process automatically.
260    ///
261    /// # Returns
262    ///
263    /// Returns `Ok(Vec<HashMap<String, String>>)` containing the formatted messages on success.
264    ///
265    /// # Errors
266    ///
267    /// Returns an error if:
268    /// - The current message sequence violates prompt rules (e.g., assistant message first)
269    /// - The build process fails
270    /// - The built messages are unexpectedly None after building
271    pub fn get_built_prompt_messages(&self) -> Result<Vec<HashMap<String, String>>, crate::Error> {
272        let built_prompt_messages = self.built_prompt_messages();
273
274        if let Some(built_prompt_messages) = &*built_prompt_messages {
275            return Ok(built_prompt_messages.clone());
276        };
277
278        self.precheck_build()?;
279        self.build_prompt()?;
280        if let Some(built_prompt_messages) = &*built_prompt_messages {
281            Ok(built_prompt_messages.clone())
282        } else {
283            crate::bail!("built_prompt_messages is None after building!");
284        }
285    }
286
287    // Builder methods
288    //
289
290    fn precheck_build(&self) -> crate::Result<()> {
291        if let Some(last) = self.messages().last() {
292            if last.message_type == PromptMessageType::Assistant {
293                crate::bail!(
294                    "Cannot build prompt when the current inference message is PromptMessageType::Assistant"
295                )
296            } else if last.message_type == PromptMessageType::System {
297                crate::bail!(
298                    "Cannot build prompt when the current inference message is PromptMessageType::System"
299                )
300            } else {
301                Ok(())
302            }
303        } else {
304            crate::bail!("Cannot build prompt when there are no messages.")
305        }
306    }
307
308    fn build_prompt(&self) -> crate::Result<()> {
309        let messages = self.messages();
310        let mut built_prompt_messages: Vec<HashMap<String, String>> = Vec::new();
311        let mut last_message_type = None;
312
313        for (i, message) in messages.iter().enumerate() {
314            let message_type = &message.message_type;
315            // Should these checks be moved elsewhere?
316            // Rule 1: System message can only be the first message
317            if *message_type == PromptMessageType::System && i != 0 {
318                panic!("System message can only be the first message.");
319            }
320            // Rule 2: First message must be either System or User
321            if i == 0
322                && *message_type != PromptMessageType::System
323                && *message_type != PromptMessageType::User
324            {
325                panic!("Conversation must start with either a System or User message.");
326            }
327            // Rule 3: Ensure alternating User/Assistant messages after the first message
328            if i > 0 {
329                match (last_message_type, message_type) {
330                    (Some(PromptMessageType::User), PromptMessageType::Assistant) => {}
331                    (Some(PromptMessageType::Assistant), PromptMessageType::User) => {}
332                    (Some(PromptMessageType::System), PromptMessageType::User) => {}
333                    _ => panic!(
334                        "Messages must alternate between User and Assistant after the first message (which can be System)."
335                    ),
336                }
337            }
338            last_message_type = Some(message_type.clone());
339
340            if let Some(built_message_string) = &*message.built_prompt_message() {
341                built_prompt_messages.push(HashMap::from([
342                    ("role".to_string(), message.message_type.as_str().to_owned()),
343                    ("content".to_string(), built_message_string.to_owned()),
344                ]));
345            } else {
346                crate::bail!("message.built_content is empty and skipped");
347            }
348        }
349
350        *self.built_prompt_messages.lock().unwrap_or_else(|e| {
351            panic!(
352                "LlmPrompt Error - built_prompt_messages not available: {:?}",
353                e
354            )
355        }) = Some(built_prompt_messages.clone());
356
357        if let Some(api_prompt) = &self.api_prompt {
358            api_prompt.build_prompt(&built_prompt_messages);
359        };
360        if let Some(local_prompt) = &self.local_prompt {
361            local_prompt.build_prompt(&built_prompt_messages);
362        };
363
364        Ok(())
365    }
366
367    // Helper methods
368    //
369
370    fn messages(&self) -> MutexGuard<'_, Vec<Arc<PromptMessage>>> {
371        self.messages.messages()
372    }
373
374    fn last_message(&self) -> Arc<PromptMessage> {
375        self.messages()
376            .last()
377            .expect("LlmPrompt Error - last message not available")
378            .clone()
379    }
380
381    fn built_prompt_messages(&self) -> MutexGuard<'_, Option<Vec<HashMap<String, String>>>> {
382        self.built_prompt_messages.lock().unwrap_or_else(|e| {
383            panic!(
384                "LlmPrompt Error - built_prompt_messages not available: {:?}",
385                e
386            )
387        })
388    }
389}
390
391impl Default for LLMPrompt {
392    fn default() -> Self {
393        Self {
394            local_prompt: None,
395            api_prompt: None,
396            messages: PromptMessages::default(),
397            concatenator: TextConcatenator::default(),
398            built_prompt_messages: Mutex::new(None),
399        }
400    }
401}
402
403impl Clone for LLMPrompt {
404    fn clone(&self) -> Self {
405        Self {
406            local_prompt: self.local_prompt.clone(),
407            api_prompt: self.api_prompt.clone(),
408            messages: self.messages.clone(),
409            concatenator: self.concatenator.clone(),
410            built_prompt_messages: self.built_prompt_messages().clone().into(),
411        }
412    }
413}
414
415impl std::fmt::Display for LLMPrompt {
416    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417        writeln!(f)?;
418        writeln!(f, "LlmPrompt")?;
419
420        // Builds prompt if not already built, but skips the precheck.
421        if self.get_built_prompt_messages().is_err() {
422            match self.build_prompt() {
423                Ok(_) => {}
424                Err(e) => {
425                    writeln!(f, "Error building prompt: {:?}", e)?;
426                }
427            }
428        }
429
430        if let Some(local_prompt) = &self.local_prompt {
431            write!(f, "{}", local_prompt)?;
432        }
433
434        if let Some(api_prompt) = &self.api_prompt {
435            write!(f, "{}", api_prompt)?;
436        }
437
438        Ok(())
439    }
440}
441
impl TextConcatenatorTrait for LLMPrompt {
    /// Exposes the concatenator for mutation by the trait's helper methods.
    fn concatenator_mut(&mut self) -> &mut TextConcatenator {
        &mut self.concatenator
    }

    /// Invalidates built prompt state after the concatenator changes.
    fn clear_built(&self) {
        self.clear_built_prompt();
    }
}