1use crate::PromptTokenizer;
2use minijinja::value::{Value, ValueKind, from_args};
3use minijinja::{Environment, Error, ErrorKind, context};
4use serde::Serialize;
5use std::collections::HashMap;
6use std::sync::Mutex;
7use std::sync::{Arc, MutexGuard};
8
9#[derive(Serialize)]
20pub struct LocalPrompt {
21 #[serde(skip)]
23 tokenizer: Arc<dyn PromptTokenizer>,
24 chat_template: String,
25 bos_token: Option<String>,
26 eos_token: String,
27 unk_token: Option<String>,
28 base_generation_prefix: Option<String>,
29 pub generation_prefix: Mutex<Option<String>>,
30 pub built_prompt_string: Mutex<Option<String>>,
31 pub built_prompt_as_tokens: Mutex<Option<Vec<u32>>>,
32 pub total_prompt_tokens: Mutex<Option<usize>>,
33}
34
35impl LocalPrompt {
36 pub(crate) fn new(
37 tokenizer: Arc<dyn PromptTokenizer>,
38 chat_template: &str,
39 bos_token: Option<&str>,
40 eos_token: &str,
41 unk_token: Option<&str>,
42 base_generation_prefix: Option<&str>,
43 ) -> Self {
44 Self {
45 tokenizer,
46 chat_template: chat_template.to_owned(),
47 bos_token: bos_token.map(|s| s.to_owned()),
48 eos_token: eos_token.to_owned(),
49 unk_token: unk_token.map(|s| s.to_owned()),
50 base_generation_prefix: base_generation_prefix.map(|s| s.to_owned()),
51 generation_prefix: None.into(),
52 built_prompt_string: None.into(),
53 built_prompt_as_tokens: None.into(),
54 total_prompt_tokens: None.into(),
55 }
56 }
57
58 pub(crate) fn set_generation_prefix<T: AsRef<str>>(&self, generation_prefix: T) {
62 let mut self_generation_prefix = self.generation_prefix();
63 if self_generation_prefix.is_none()
64 || self_generation_prefix.as_deref() != Some(generation_prefix.as_ref())
65 {
66 *self_generation_prefix = Some(generation_prefix.as_ref().to_string());
67 }
68 }
69
70 pub(crate) fn clear_generation_prefix(&self) {
71 *self.generation_prefix() = None;
72 }
73
74 pub(crate) fn clear_built_prompt(&self) {
75 *self.built_prompt_string() = None;
76 *self.built_prompt_as_tokens() = None;
77 *self.total_prompt_tokens() = None;
78 }
79
80 pub fn get_built_prompt(&self) -> Result<String, crate::Error> {
96 match &*self.built_prompt_string() {
97 Some(prompt) => Ok(prompt.clone()),
98 None => crate::bail!(
99 "LocalPrompt Error - built_prompt_string not available - prompt not built"
100 ),
101 }
102 }
103
104 pub fn get_built_prompt_as_tokens(&self) -> Result<Vec<u32>, crate::Error> {
118 match &*self.built_prompt_as_tokens() {
119 Some(prompt) => Ok(prompt.clone()),
120 None => crate::bail!(
121 "LocalPrompt Error - built_prompt_as_tokens not available - prompt not built"
122 ),
123 }
124 }
125
126 pub fn get_total_prompt_tokens(&self) -> Result<usize, crate::Error> {
140 match &*self.total_prompt_tokens() {
141 Some(prompt) => Ok(*prompt),
142 None => crate::bail!(
143 "LocalPrompt Error - total_prompt_tokens not available - prompt not built"
144 ),
145 }
146 }
147
148 pub(crate) fn build_prompt(&self, built_prompt_messages: &[HashMap<String, String>]) {
152 let mut built_prompt_string = apply_chat_template(
153 built_prompt_messages,
154 &self.chat_template,
155 self.bos_token.as_deref(),
156 &self.eos_token,
157 self.unk_token.as_deref(),
158 );
159
160 {
161 if let Some(generation_prefix) = &*self.generation_prefix() {
162 if let Some(base_generation_prefix) = &self.base_generation_prefix {
163 built_prompt_string.push_str(base_generation_prefix);
164 }
165 built_prompt_string.push_str(generation_prefix);
166 }
167 }
168
169 let built_prompt_as_tokens = self.tokenizer.tokenize(&built_prompt_string);
170 *self.total_prompt_tokens() = Some(built_prompt_as_tokens.len());
171 *self.built_prompt_as_tokens() = Some(built_prompt_as_tokens);
172 *self.built_prompt_string() = Some(built_prompt_string);
173 }
174
175 fn generation_prefix(&self) -> MutexGuard<'_, Option<String>> {
179 self.generation_prefix.lock().unwrap_or_else(|e| {
180 panic!(
181 "LocalPrompt Error - generation_prefix not available: {:?}",
182 e
183 )
184 })
185 }
186
187 fn built_prompt_string(&self) -> MutexGuard<'_, Option<String>> {
188 self.built_prompt_string.lock().unwrap_or_else(|e| {
189 panic!(
190 "LocalPrompt Error - built_prompt_string not available: {:?}",
191 e
192 )
193 })
194 }
195
196 fn built_prompt_as_tokens(&self) -> MutexGuard<'_, Option<Vec<u32>>> {
197 self.built_prompt_as_tokens.lock().unwrap_or_else(|e| {
198 panic!(
199 "LocalPrompt Error - built_prompt_as_tokens not available: {:?}",
200 e
201 )
202 })
203 }
204
205 fn total_prompt_tokens(&self) -> MutexGuard<'_, Option<usize>> {
206 self.total_prompt_tokens.lock().unwrap_or_else(|e| {
207 panic!(
208 "LocalPrompt Error - total_prompt_tokens not available: {:?}",
209 e
210 )
211 })
212 }
213}
214
215impl Clone for LocalPrompt {
216 fn clone(&self) -> Self {
217 Self {
218 built_prompt_string: self.built_prompt_string().clone().into(),
219 built_prompt_as_tokens: self.built_prompt_as_tokens().clone().into(),
220 total_prompt_tokens: (*self.total_prompt_tokens()).into(),
221 generation_prefix: self.generation_prefix().clone().into(),
222 tokenizer: self.tokenizer.clone(),
223 chat_template: self.chat_template.clone(),
224 bos_token: self.bos_token.clone(),
225 eos_token: self.eos_token.clone(),
226 unk_token: self.unk_token.clone(),
227 base_generation_prefix: self.base_generation_prefix.clone(),
228 }
229 }
230}
231
232impl std::fmt::Display for LocalPrompt {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 writeln!(f)?;
235 writeln!(f, "LocalPrompt")?;
236
237 match *self.built_prompt_string() {
238 Some(ref prompt) => {
239 writeln!(f, "built_prompt_string:\n\n{}", prompt)?;
240 writeln!(f)?;
241 }
242 None => writeln!(f, "built_prompt_string: None")?,
243 };
244
245 match *self.total_prompt_tokens() {
246 Some(ref prompt) => {
247 writeln!(f, "total_prompt_tokens: {}", prompt)?;
248 writeln!(f)?;
249 }
250 None => writeln!(f, "total_prompt_tokens: None")?,
251 };
252
253 Ok(())
254 }
255}
256
257pub fn apply_chat_template(
268 messages: &[HashMap<String, String>],
269 chat_template: &str,
270 bos_token: Option<&str>,
271 eos_token: &str,
272 unk_token: Option<&str>,
273) -> String {
274 let mut env = Environment::new();
275 env.set_lstrip_blocks(true);
276 env.set_trim_blocks(true);
277 env.add_template("chat_template", chat_template)
278 .expect("Failed to add template");
279 env.add_function("raise_exception", raise_exception);
280
281 env.set_unknown_method_callback(|state, value, method, args| match (value.kind(), method) {
282 (ValueKind::String, "strip") => {
283 let _: () = from_args(args)?;
284 Ok(Value::from(value.as_str().unwrap_or("").trim()))
285 }
286 (ValueKind::Map, "items") => {
287 let _: () = from_args(args)?;
288 state.apply_filter("items", &[value.clone()])
289 }
290 _ => Err(Error::new(
291 ErrorKind::UnknownMethod,
292 format!("object has no method named {}", method),
293 )),
294 });
295
296 let tmpl = env
297 .get_template("chat_template")
298 .expect("Failed to get template");
299
300 let unk_token = unk_token.unwrap_or("");
301 let bos_token = bos_token.unwrap_or("");
302
303 tmpl.render(context! {
304 messages => messages,
305 add_generation_prompt => false,
306 bos_token => bos_token,
307 eos_token => eos_token,
308 unk_token => unk_token,
309 })
310 .expect("Failed to render template without system prompt")
311}
312
313fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
315 Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
316}