1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use tokenizers::Tokenizer;
11use tracing::info;
12
13use crate::{MessageContent, Tool};
14
/// Token strings that `calculate_eos_tokens` adds to the end-of-sequence set
/// whenever they are present in the tokenizer vocabulary, in addition to the
/// EOS token declared by the chat template itself.
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    "<|im_end|>", "<end_of_turn>", "<|end_of_text|>", ];
20
/// Object form of a special-token entry, deserialized with serde.
///
/// Only `content` (the literal token text) is public; the remaining fields are
/// parsed but not read in this file, hence `allow(dead_code)`.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
    __type: Option<String>,
    /// The literal text of the token.
    pub content: String,
    lstrip: bool,
    normalized: bool,
    rstrip: bool,
    single_word: bool,
    special: Option<bool>,
}
32
33fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
34 Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
35}
36
/// A BOS / EOS / UNK / PAD token value.
///
/// Deserialized untagged: either a plain string (`Left`) or a full
/// [`AddedTokensDecoder`] object (`Right`) whose `content` holds the text.
#[derive(Debug, Deserialize)]
pub struct BeginEndUnkPadTok(
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
41
/// The chat template as it appears in the configuration.
///
/// Deserialized untagged: either a single template string (`Left`) or a list of
/// string-to-string maps (`Right`). `apply_chat_template_to` understands two
/// shapes for the list entries: `{"name": ..., "template": ...}` and maps keyed
/// directly by `"default"` / `"tool_use"`.
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);
46
/// Deserialized tokenizer/chat-template configuration.
///
/// Every field is optional. Only the special tokens and the chat template are
/// public; the other fields are parsed but not read in this file, hence
/// `allow(dead_code)`.
#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
pub struct ChatTemplate {
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
    additional_special_tokens: Option<Vec<String>>,
    /// Beginning-of-sequence token; see [`ChatTemplate::bos_tok`].
    pub bos_token: Option<BeginEndUnkPadTok>,

    /// The chat template, either a single string or a list of named templates.
    pub chat_template: Option<ChatTemplateValue>,
    clean_up_tokenization_spaces: Option<bool>,
    device_map: Option<String>,
    /// End-of-sequence token; see [`ChatTemplate::eos_tok`].
    pub eos_token: Option<BeginEndUnkPadTok>,
    legacy: Option<bool>,
    model_max_length: Option<f64>,
    /// Padding token.
    pub pad_token: Option<BeginEndUnkPadTok>,
    sp_model_kwargs: Option<HashMap<String, String>>,
    spaces_between_special_tokens: Option<bool>,
    tokenizer_class: Option<String>,
    truncation_size: Option<String>,
    /// Unknown token; see [`ChatTemplate::unk_tok`].
    pub unk_token: Option<BeginEndUnkPadTok>,
    use_default_system_prompt: Option<bool>,
}
74
75impl ChatTemplate {
76 pub fn has_chat_template(&self) -> bool {
77 self.chat_template.is_some()
78 }
79
80 pub fn is_harmony_format(&self) -> bool {
82 if let Some(ref template_value) = self.chat_template {
83 let template_str = match &template_value.0 {
84 Either::Left(s) => s.as_str(),
85 Either::Right(vec) => {
86 return vec
88 .iter()
89 .any(|t| t.values().any(|v| crate::harmony::is_harmony_template(v)));
90 }
91 };
92 crate::harmony::is_harmony_template(template_str)
93 } else {
94 false
95 }
96 }
97
98 pub fn uses_think_tags(&self) -> bool {
103 if self.is_harmony_format() {
105 return false;
106 }
107
108 if let Some(ref template_value) = self.chat_template {
109 let template_str = match &template_value.0 {
110 Either::Left(s) => s.as_str(),
111 Either::Right(vec) => {
112 return vec.iter().any(|t| {
114 t.values()
115 .any(|v| crate::think_tags::is_think_tag_template(v))
116 });
117 }
118 };
119 crate::think_tags::is_think_tag_template(template_str)
120 } else {
121 false
122 }
123 }
124
125 pub fn eos_tok(&self) -> Option<String> {
126 match self.eos_token.as_ref()?.0 {
127 Either::Left(ref lit) => Some(lit.clone()),
128 Either::Right(ref added) => Some(added.content.clone()),
129 }
130 }
131
132 pub fn bos_tok(&self) -> Option<String> {
133 match self.bos_token.as_ref()?.0 {
134 Either::Left(ref lit) => Some(lit.clone()),
135 Either::Right(ref added) => Some(added.content.clone()),
136 }
137 }
138
139 pub fn unk_tok(&self) -> Option<String> {
140 match self.unk_token.as_ref()?.0 {
141 Either::Left(ref lit) => Some(lit.clone()),
142 Either::Right(ref added) => Some(added.content.clone()),
143 }
144 }
145}
146
147pub fn calculate_eos_tokens(
148 chat_template: &ChatTemplate,
149 gen_conf: Option<GenerationConfig>,
150 tokenizer: &Tokenizer,
151) -> Vec<u32> {
152 let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
153 let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
154
155 for alternate in SUPPORTED_ALTERNATE_EOS {
156 if tokenizer.get_vocab(true).contains_key(*alternate) {
157 eos_tok_ids.push(alternate.to_string())
158 }
159 }
160
161 if let Some(gen_conf) = gen_conf {
162 if let Some(eos_field) = gen_conf.eos_token_id {
163 let ids = match eos_field {
164 Either::Left(id) => vec![id],
165 Either::Right(ids) => ids,
166 };
167 for id in ids {
168 let s = tokenizer
169 .decode(&[id], false)
170 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
171 if !eos_tok_ids.contains(&s) {
172 eos_tok_ids.push(s);
173 }
174 }
175 }
176
177 if let Some(bos_field) = gen_conf.bos_token_id {
178 let ids = match bos_field {
179 Either::Left(id) => vec![id],
180 Either::Right(ids) => ids,
181 };
182 for id in ids {
183 let s = tokenizer
184 .decode(&[id], false)
185 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
186 if !bos_tok_ids.contains(&s) {
187 bos_tok_ids.push(s);
188 }
189 }
190 }
191 }
192
193 eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
194 bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
195
196 let bos_render = bos_tok_ids
197 .iter()
198 .map(|val| format!("{val:?}"))
199 .collect::<Vec<String>>()
200 .join(", ");
201 let eos_render = eos_tok_ids
202 .iter()
203 .map(|val| format!("{val:?}"))
204 .collect::<Vec<String>>()
205 .join(", ");
206
207 info!(
208 "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
209 chat_template.unk_tok().unwrap_or("`None`".to_string()),
210 );
211
212 let mut eos_toks = Vec::new();
213 for eos_tok in eos_tok_ids {
214 eos_toks.push(
215 tokenizer
216 .get_vocab(true)
217 .get(&eos_tok)
218 .copied()
219 .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
220 )
221 }
222 eos_toks
223}
224
/// Deserialized generation configuration.
///
/// Each id field may be absent, a single id, or a list of ids; the untagged
/// `Either` accepts both of the latter shapes. `calculate_eos_tokens` reads
/// both fields.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct GenerationConfig {
    /// BOS token id(s): a single id (`Left`) or several (`Right`).
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    bos_token_id: Option<Either<u32, Vec<u32>>>,
    /// EOS token id(s): a single id (`Left`) or several (`Right`).
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    eos_token_id: Option<Either<u32, Vec<u32>>>,
}
235
236fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
237 if let Ok(indent) = kwargs.get("indent") {
238 let mut buf = Vec::new();
239 let repeat = b" ".repeat(indent);
240 let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
241 let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
242 value.serialize(&mut ser).unwrap();
243 String::from_utf8(buf).map_err(|err| {
244 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
245 })
246 } else {
247 serde_json::to_string(&value).map_err(|err| {
248 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
249 })
250 }
251 .map_err(|err| {
252 Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
253 })
254 .map(|s| {
255 let mut rv = String::with_capacity(s.len());
257 for c in s.chars() {
258 match c {
259 '<' => rv.push_str("\\u003c"),
260 '>' => rv.push_str("\\u003e"),
261 '&' => rv.push_str("\\u0026"),
262 '\'' => rv.push_str("\\u0027"),
263 _ => rv.push(c),
264 }
265 }
266 Value::from_safe_string(rv)
267 })
268}
269
270fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
271 let date = chrono::Utc::now();
272 let date_string = date.format(&fmt).to_string();
273 Ok(date_string)
274}
275
276use crate::request::ReasoningEffort;
277
278#[allow(clippy::too_many_arguments)]
279pub fn apply_chat_template_to(
280 messages: Vec<IndexMap<String, MessageContent>>,
281 add_generation_prompt: bool,
282 enable_thinking: Option<bool>,
283 reasoning_effort: Option<ReasoningEffort>,
284 template: &ChatTemplateValue,
285 bos_tok: Option<String>,
286 eos_tok: Option<String>,
287 unk_tok: Option<String>,
288 tools: Vec<Tool>,
289) -> Result<String> {
290 let mut env = Environment::new();
291
292 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
294
295 env.set_lstrip_blocks(true);
297 env.set_trim_blocks(true);
298
299 #[derive(Serialize, Deserialize)]
300 struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
301 let mut new_messages = Vec::new();
302 for message in messages {
303 let mut new_message = IndexMap::new();
304 for (k, v) in message {
305 new_message.insert(k, UntaggedContent(v));
306 }
307 new_messages.push(new_message);
308 }
309
310 let template = match &template.0 {
311 Either::Left(x) => x.clone(),
312 Either::Right(map) => {
313 let mut template = None;
314 let has_tool_use = map.iter().any(|t| {
315 t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
316 });
317 let must_use_tool_template = !tools.is_empty();
318
319 if must_use_tool_template && !has_tool_use {
320 anyhow::bail!(
321 "Tools were provided but this chat template does not handle tool usage"
322 );
323 }
324
325 for t in map {
326 let name = t.get("name");
327 if let Some(name) = name {
328 template = Some(t["template"].clone());
329 #[allow(clippy::if_same_then_else)]
330 if name == "tool_use" && !tools.is_empty() {
331 break;
332 } else if name == "default" && !must_use_tool_template {
333 break;
334 }
335 } else if t.contains_key("tool_use") && !tools.is_empty() {
336 template = Some(t["tool_use"].clone());
337 break;
338 } else if t.contains_key("default") && !must_use_tool_template {
339 template = Some(t["default"].clone());
340 break;
341 }
342 }
343
344 let Some(template) = template else {
345 anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
346 };
347 template
348 }
349 };
350 let mut template = template.replace("[::-1]", "|reverse");
351 let re = Regex::new(r"range\((?P<expr>[^,]+),\s*-1,\s*-1\)").unwrap();
355 template = re
356 .replace_all(&template, |caps: ®ex::Captures| {
357 format!("range({})|reverse", &caps["expr"])
358 })
359 .into_owned();
360
361 if template.contains("{{ meta }}") {
362 template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
364 template = template.replace("{{ meta }}", "");
365 }
366 if template.contains("{% generation %}") && template.contains("{% endgeneration %}") {
367 template = template.replace("{% generation %}", "");
369 template = template.replace("{% endgeneration %}", "");
370 }
371
372 env.add_template("chat_template", &template)?;
373 env.add_function("raise_exception", raise_exception);
374 env.add_filter("tojson", tojson);
375 env.add_function("strftime_now", strftime_now);
376 let tmpl = env.get_template("chat_template").unwrap();
377
378 let date = chrono::Utc::now();
379 let date_string = date.format("%d, %B, %Y").to_string();
380
381 let reasoning_effort_str = reasoning_effort.map(|r| r.as_str()).unwrap_or("medium");
383
384 let builtin_tool_names = [
388 "browser",
389 "python",
390 "code_interpreter",
391 "web_search",
392 "brave_search",
393 "wolfram_alpha",
394 ];
395 let builtin_tools: Vec<&str> = tools
396 .iter()
397 .filter_map(|t| {
398 let name = t.function.name.as_str();
399 if builtin_tool_names.contains(&name) {
400 Some(name)
401 } else {
402 None
403 }
404 })
405 .collect();
406
407 if tools.is_empty() {
408 Ok(tmpl.render(context! {
409 messages => new_messages,
410 add_generation_prompt => add_generation_prompt,
411 bos_token => bos_tok,
412 eos_token => eos_tok,
413 unk_token => unk_tok,
414 date_string => date_string,
415 enable_thinking => enable_thinking.unwrap_or(true),
416 reasoning_effort => reasoning_effort_str,
417 })?)
418 } else {
419 Ok(tmpl.render(context! {
420 messages => new_messages,
421 add_generation_prompt => add_generation_prompt,
422 bos_token => bos_tok,
423 eos_token => eos_tok,
424 unk_token => unk_tok,
425 xml_tools => tools.clone(), tools => tools,
427 builtin_tools => builtin_tools,
428 date_string => date_string,
429 enable_thinking => enable_thinking.unwrap_or(true),
430 reasoning_effort => reasoning_effort_str,
431 })?)
432 }
433}