// alith_prompt/local_prompt.rs

1use crate::PromptTokenizer;
2use minijinja::value::{Value, ValueKind, from_args};
3use minijinja::{Environment, Error, ErrorKind, context};
4use serde::Serialize;
5use std::collections::HashMap;
6use std::sync::Mutex;
7use std::sync::{Arc, MutexGuard};
8
9/// A prompt formatter for local LLMs that use chat templates.
10///
11/// `LocalPrompt` handles formatting messages according to a model's chat template,
12/// managing special tokens (BOS, EOS, UNK), and supporting generation prefixes.
13/// Unlike API prompts, local prompts need to handle the specific formatting requirements
14/// and token conventions of locally-run models.
15///
16/// The struct maintains both string and tokenized representations of the built prompt,
17/// along with thread-safe interior mutability for managing prompt state. It supports
18/// token counting and generation prefix management for model outputs.
19#[derive(Serialize)]
20pub struct LocalPrompt {
21    // Skip the tokenizer field
22    #[serde(skip)]
23    tokenizer: Arc<dyn PromptTokenizer>,
24    chat_template: String,
25    bos_token: Option<String>,
26    eos_token: String,
27    unk_token: Option<String>,
28    base_generation_prefix: Option<String>,
29    pub generation_prefix: Mutex<Option<String>>,
30    pub built_prompt_string: Mutex<Option<String>>,
31    pub built_prompt_as_tokens: Mutex<Option<Vec<u32>>>,
32    pub total_prompt_tokens: Mutex<Option<usize>>,
33}
34
35impl LocalPrompt {
36    pub(crate) fn new(
37        tokenizer: Arc<dyn PromptTokenizer>,
38        chat_template: &str,
39        bos_token: Option<&str>,
40        eos_token: &str,
41        unk_token: Option<&str>,
42        base_generation_prefix: Option<&str>,
43    ) -> Self {
44        Self {
45            tokenizer,
46            chat_template: chat_template.to_owned(),
47            bos_token: bos_token.map(|s| s.to_owned()),
48            eos_token: eos_token.to_owned(),
49            unk_token: unk_token.map(|s| s.to_owned()),
50            base_generation_prefix: base_generation_prefix.map(|s| s.to_owned()),
51            generation_prefix: None.into(),
52            built_prompt_string: None.into(),
53            built_prompt_as_tokens: None.into(),
54            total_prompt_tokens: None.into(),
55        }
56    }
57
58    // Setter methods
59    //
60
61    pub(crate) fn set_generation_prefix<T: AsRef<str>>(&self, generation_prefix: T) {
62        let mut self_generation_prefix = self.generation_prefix();
63        if self_generation_prefix.is_none()
64            || self_generation_prefix.as_deref() != Some(generation_prefix.as_ref())
65        {
66            *self_generation_prefix = Some(generation_prefix.as_ref().to_string());
67        }
68    }
69
70    pub(crate) fn clear_generation_prefix(&self) {
71        *self.generation_prefix() = None;
72    }
73
74    pub(crate) fn clear_built_prompt(&self) {
75        *self.built_prompt_string() = None;
76        *self.built_prompt_as_tokens() = None;
77        *self.total_prompt_tokens() = None;
78    }
79
80    // Getter methods
81    //
82
83    /// Retrieves the built prompt as a formatted string.
84    ///
85    /// Returns the complete prompt string with all messages formatted according to
86    /// the chat template, including any special tokens and generation prefix.
87    ///
88    /// # Returns
89    ///
90    /// Returns `Ok(String)` containing the formatted prompt string.
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if the prompt has not been built yet.
95    pub fn get_built_prompt(&self) -> Result<String, crate::Error> {
96        match &*self.built_prompt_string() {
97            Some(prompt) => Ok(prompt.clone()),
98            None => crate::bail!(
99                "LocalPrompt Error - built_prompt_string not available - prompt not built"
100            ),
101        }
102    }
103
104    /// Retrieves the built prompt as a vector of tokens.
105    ///
106    /// Returns the complete prompt converted to model tokens using the configured
107    /// tokenizer. This is useful for operations that need to work directly with
108    /// token IDs rather than text.
109    ///
110    /// # Returns
111    ///
112    /// Returns `Ok(Vec<u32>)` containing the token IDs for the prompt.
113    ///
114    /// # Errors
115    ///
116    /// Returns an error if the prompt has not been built yet.
117    pub fn get_built_prompt_as_tokens(&self) -> Result<Vec<u32>, crate::Error> {
118        match &*self.built_prompt_as_tokens() {
119            Some(prompt) => Ok(prompt.clone()),
120            None => crate::bail!(
121                "LocalPrompt Error - built_prompt_as_tokens not available - prompt not built"
122            ),
123        }
124    }
125
126    /// Gets the total number of tokens in the built prompt.
127    ///
128    /// Returns the exact token count of the built prompt, which is useful for
129    /// ensuring prompts stay within model context limits. This count reflects
130    /// all content, special tokens, and any generation prefix.
131    ///
132    /// # Returns
133    ///
134    /// Returns `Ok(usize)` containing the total token count.
135    ///
136    /// # Errors
137    ///
138    /// Returns an error if the prompt has not been built yet.
139    pub fn get_total_prompt_tokens(&self) -> Result<usize, crate::Error> {
140        match &*self.total_prompt_tokens() {
141            Some(prompt) => Ok(*prompt),
142            None => crate::bail!(
143                "LocalPrompt Error - total_prompt_tokens not available - prompt not built"
144            ),
145        }
146    }
147
148    // Builder methods
149    //
150
151    pub(crate) fn build_prompt(&self, built_prompt_messages: &[HashMap<String, String>]) {
152        let mut built_prompt_string = apply_chat_template(
153            built_prompt_messages,
154            &self.chat_template,
155            self.bos_token.as_deref(),
156            &self.eos_token,
157            self.unk_token.as_deref(),
158        );
159
160        {
161            if let Some(generation_prefix) = &*self.generation_prefix() {
162                if let Some(base_generation_prefix) = &self.base_generation_prefix {
163                    built_prompt_string.push_str(base_generation_prefix);
164                }
165                built_prompt_string.push_str(generation_prefix);
166            }
167        }
168
169        let built_prompt_as_tokens = self.tokenizer.tokenize(&built_prompt_string);
170        *self.total_prompt_tokens() = Some(built_prompt_as_tokens.len());
171        *self.built_prompt_as_tokens() = Some(built_prompt_as_tokens);
172        *self.built_prompt_string() = Some(built_prompt_string);
173    }
174
175    // Helper methods
176    //
177
178    fn generation_prefix(&self) -> MutexGuard<'_, Option<String>> {
179        self.generation_prefix.lock().unwrap_or_else(|e| {
180            panic!(
181                "LocalPrompt Error - generation_prefix not available: {:?}",
182                e
183            )
184        })
185    }
186
187    fn built_prompt_string(&self) -> MutexGuard<'_, Option<String>> {
188        self.built_prompt_string.lock().unwrap_or_else(|e| {
189            panic!(
190                "LocalPrompt Error - built_prompt_string not available: {:?}",
191                e
192            )
193        })
194    }
195
196    fn built_prompt_as_tokens(&self) -> MutexGuard<'_, Option<Vec<u32>>> {
197        self.built_prompt_as_tokens.lock().unwrap_or_else(|e| {
198            panic!(
199                "LocalPrompt Error - built_prompt_as_tokens not available: {:?}",
200                e
201            )
202        })
203    }
204
205    fn total_prompt_tokens(&self) -> MutexGuard<'_, Option<usize>> {
206        self.total_prompt_tokens.lock().unwrap_or_else(|e| {
207            panic!(
208                "LocalPrompt Error - total_prompt_tokens not available: {:?}",
209                e
210            )
211        })
212    }
213}
214
215impl Clone for LocalPrompt {
216    fn clone(&self) -> Self {
217        Self {
218            built_prompt_string: self.built_prompt_string().clone().into(),
219            built_prompt_as_tokens: self.built_prompt_as_tokens().clone().into(),
220            total_prompt_tokens: (*self.total_prompt_tokens()).into(),
221            generation_prefix: self.generation_prefix().clone().into(),
222            tokenizer: self.tokenizer.clone(),
223            chat_template: self.chat_template.clone(),
224            bos_token: self.bos_token.clone(),
225            eos_token: self.eos_token.clone(),
226            unk_token: self.unk_token.clone(),
227            base_generation_prefix: self.base_generation_prefix.clone(),
228        }
229    }
230}
231
232impl std::fmt::Display for LocalPrompt {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        writeln!(f)?;
235        writeln!(f, "LocalPrompt")?;
236
237        match *self.built_prompt_string() {
238            Some(ref prompt) => {
239                writeln!(f, "built_prompt_string:\n\n{}", prompt)?;
240                writeln!(f)?;
241            }
242            None => writeln!(f, "built_prompt_string: None")?,
243        };
244
245        match *self.total_prompt_tokens() {
246            Some(ref prompt) => {
247                writeln!(f, "total_prompt_tokens: {}", prompt)?;
248                writeln!(f)?;
249            }
250            None => writeln!(f, "total_prompt_tokens: None")?,
251        };
252
253        Ok(())
254    }
255}
256
257/// Applies a chat template to a message, given a message and a chat template.
258///
259/// # Arguments
260///
261/// * `message` - The message as a HashMap.
262/// * `chat_template` - The chat template as a String.
263///
264/// # Returns
265///
266/// The formatted message as a String.
267pub fn apply_chat_template(
268    messages: &[HashMap<String, String>],
269    chat_template: &str,
270    bos_token: Option<&str>,
271    eos_token: &str,
272    unk_token: Option<&str>,
273) -> String {
274    let mut env = Environment::new();
275    env.set_lstrip_blocks(true);
276    env.set_trim_blocks(true);
277    env.add_template("chat_template", chat_template)
278        .expect("Failed to add template");
279    env.add_function("raise_exception", raise_exception);
280
281    env.set_unknown_method_callback(|state, value, method, args| match (value.kind(), method) {
282        (ValueKind::String, "strip") => {
283            let _: () = from_args(args)?;
284            Ok(Value::from(value.as_str().unwrap_or("").trim()))
285        }
286        (ValueKind::Map, "items") => {
287            let _: () = from_args(args)?;
288            state.apply_filter("items", &[value.clone()])
289        }
290        _ => Err(Error::new(
291            ErrorKind::UnknownMethod,
292            format!("object has no method named {}", method),
293        )),
294    });
295
296    let tmpl = env
297        .get_template("chat_template")
298        .expect("Failed to get template");
299
300    let unk_token = unk_token.unwrap_or("");
301    let bos_token = bos_token.unwrap_or("");
302
303    tmpl.render(context! {
304        messages => messages,
305        add_generation_prompt => false,
306        bos_token => bos_token,
307        eos_token => eos_token,
308        unk_token => unk_token,
309    })
310    .expect("Failed to render template without system prompt")
311}
312
313/// This exists specifically for the minijinja template engine to raise an exception.
314fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
315    Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
316}