alith_prompt/
api_prompt.rs

1use crate::{PromptTokenizer, token_count::total_prompt_tokens_openai_format};
2use serde::Serialize;
3use std::{
4    collections::HashMap,
5    sync::{Arc, Mutex, MutexGuard},
6};
7
/// A prompt formatter for API-based language models that follow OpenAI's message format.
///
/// `ApiPrompt` handles formatting messages into the standard role/content pairs used by
/// API-based LLMs. It manages token counting specific to these
/// models, including per-message and per-name token overhead.
///
/// The struct maintains thread-safe interior mutability for built messages and token counts,
/// rebuilding them as needed when the prompt content changes.
#[derive(Serialize)]
pub struct ApiPrompt {
    /// Tokenizer used to count tokens in built messages; not serialized
    /// (trait objects have no meaningful serde representation).
    #[serde(skip)]
    tokenizer: Arc<dyn PromptTokenizer>,
    /// Fixed token overhead added per message by the target model, if known.
    tokens_per_message: Option<u32>,
    /// Token overhead applied when a message carries a "name" field, if known.
    /// NOTE(review): this is `i32` while `tokens_per_message` is `u32` — presumably
    /// because a name can *reduce* the count on some models, but confirm; unifying
    /// the types would change the public `new` signature.
    tokens_per_name: Option<i32>,
    /// Cache of the last built role/content message maps; `None` until built
    /// or after `clear_built_prompt`.
    built_prompt_messages: Mutex<Option<Vec<HashMap<String, String>>>>,
    /// Cache of the total token count for the built prompt; `None` until built
    /// or after `clear_built_prompt`.
    total_prompt_tokens: Mutex<Option<u64>>,
}
25
26impl ApiPrompt {
27    pub fn new(
28        tokenizer: Arc<dyn PromptTokenizer>,
29        tokens_per_message: Option<u32>,
30        tokens_per_name: Option<i32>,
31    ) -> Self {
32        Self {
33            tokenizer,
34            tokens_per_message,
35            tokens_per_name,
36            total_prompt_tokens: None.into(),
37            built_prompt_messages: None.into(),
38        }
39    }
40
41    // Setter methods
42    //
43
44    pub(crate) fn clear_built_prompt(&self) {
45        *self.built_prompt_messages() = None;
46        *self.total_prompt_tokens() = None;
47    }
48
49    // Getter methods
50    //
51
52    /// Retrieves the built prompt messages in OpenAI API format.
53    ///
54    /// Returns the messages as a vector of hashmaps, where each message contains
55    /// a "role" key (system/user/assistant) and a "content" key with the message text.
56    ///
57    /// # Returns
58    ///
59    /// Returns `Ok(Vec<HashMap<String, String>>)` containing the formatted messages.
60    ///
61    /// # Errors
62    ///
63    /// Returns an error if the prompt has not been built yet.
64    pub fn get_built_prompt(&self) -> Result<Vec<HashMap<String, String>>, crate::Error> {
65        match &*self.built_prompt_messages() {
66            Some(prompt) => Ok(prompt.clone()),
67            None => crate::bail!(
68                "ApiPrompt Error - built_prompt_messages not available - prompt not built"
69            ),
70        }
71    }
72
73    /// Gets the total number of tokens in the prompt, including any model-specific overhead.
74    ///
75    /// The total includes the base tokens from all messages plus any additional tokens
76    /// specified by `tokens_per_message` and `tokens_per_name`. This count is useful for
77    /// ensuring prompts stay within model context limits.
78    ///
79    /// # Returns
80    ///
81    /// Returns `Ok(u64)` containing the total token count.
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if the prompt has not been built yet.
86    pub fn get_total_prompt_tokens(&self) -> Result<u64, crate::Error> {
87        match &*self.total_prompt_tokens() {
88            Some(prompt) => Ok(*prompt),
89            None => crate::bail!(
90                "ApiPrompt Error - total_prompt_tokens not available - prompt not built"
91            ),
92        }
93    }
94
95    // Builder methods
96    //
97
98    pub(crate) fn build_prompt(&self, built_prompt_messages: &[HashMap<String, String>]) {
99        *self.total_prompt_tokens() = Some(total_prompt_tokens_openai_format(
100            built_prompt_messages,
101            self.tokens_per_message,
102            self.tokens_per_name,
103            &self.tokenizer,
104        ));
105
106        *self.built_prompt_messages() = Some(built_prompt_messages.to_vec());
107    }
108
109    // Helper methods
110    //
111
112    fn built_prompt_messages(&self) -> MutexGuard<'_, Option<Vec<HashMap<String, String>>>> {
113        self.built_prompt_messages.lock().unwrap_or_else(|e| {
114            panic!(
115                "ApiPrompt Error - built_prompt_messages not available: {:?}",
116                e
117            )
118        })
119    }
120
121    fn total_prompt_tokens(&self) -> MutexGuard<'_, Option<u64>> {
122        self.total_prompt_tokens.lock().unwrap_or_else(|e| {
123            panic!(
124                "ApiPrompt Error - total_prompt_tokens not available: {:?}",
125                e
126            )
127        })
128    }
129}
130
131impl Clone for ApiPrompt {
132    fn clone(&self) -> Self {
133        Self {
134            tokenizer: self.tokenizer.clone(),
135            tokens_per_message: self.tokens_per_message,
136            tokens_per_name: self.tokens_per_name,
137            total_prompt_tokens: (*self.total_prompt_tokens()).into(),
138            built_prompt_messages: self.built_prompt_messages().clone().into(),
139        }
140    }
141}
142
143impl std::fmt::Display for ApiPrompt {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        writeln!(f)?;
146        writeln!(f, "ApiPrompt")?;
147
148        match *self.total_prompt_tokens() {
149            Some(ref prompt) => {
150                writeln!(f, "total_prompt_tokens:\n\n{}", prompt)?;
151                writeln!(f)?;
152            }
153            None => writeln!(f, "total_prompt_tokens: None")?,
154        };
155
156        Ok(())
157    }
158}