Skip to main content

mistralrs_core/pipeline/
chat_template.rs

1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use tokenizers::Tokenizer;
11use tracing::info;
12
13use crate::{MessageContent, Tool};
14
/// Extra end-of-sequence token strings recognised in addition to the chat
/// template's own `eos_token`.
///
/// `calculate_eos_tokens` adds each entry to the EOS set only when the
/// tokenizer's vocabulary actually contains it.
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    "<|im_end|>",      // Handle ChatML case
    "<end_of_turn>",   // Handle Gemma2 chat case
    "<|end_of_text|>", // Hermes
];
20
/// Object form of a special token as it appears in a deserialized tokenizer
/// configuration.
///
/// Only `content` (the literal token text) is read in this file; the other
/// fields are deserialized but otherwise unused here, hence `dead_code`.
// NOTE(review): the flag fields presumably mirror the Hugging Face
// `AddedToken` attributes of the same names — confirm against the upstream
// tokenizer config schema.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
    __type: Option<String>,
    /// The literal text of the token.
    pub content: String,
    lstrip: bool,
    normalized: bool,
    rstrip: bool,
    single_word: bool,
    special: Option<bool>,
}
32
33fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
34    Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
35}
36
/// A special token (bos/eos/unk/pad) as stored in the chat template config.
///
/// Deserializes, untagged, from either a plain string (`Left`) or an
/// [`AddedTokensDecoder`] object (`Right`) whose `content` holds the text.
#[derive(Debug, Deserialize)]
pub struct BeginEndUnkPadTok(
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
41
/// The `chat_template` entry of a tokenizer configuration.
///
/// Deserializes, untagged, from either:
/// - `Left`: a single Jinja template string, or
/// - `Right`: a list of string-to-string maps describing several templates.
///   `apply_chat_template_to` understands entries shaped as
///   `{"name": ..., "template": ...}` as well as entries keyed directly by
///   `"tool_use"` / `"default"`.
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);
46
// Most fields exist only so that deserialization accepts them; this file
// reads the `pub` token fields and `chat_template`, hence `dead_code`.
// NOTE(review): the field set presumably mirrors Hugging Face's
// `tokenizer_config.json` — confirm against the loader that builds this type.
#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
/// Template for chat models including bos/eos/unk as well as the chat template.
pub struct ChatTemplate {
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
    additional_special_tokens: Option<Vec<String>>,
    /// Beginning-of-sequence token, if configured. See [`ChatTemplate::bos_tok`].
    pub bos_token: Option<BeginEndUnkPadTok>,

    /// Jinja format [chat templating] for chat completion.
    ///
    /// [chat templating]: https://huggingface.co/docs/transformers/chat_templating
    pub chat_template: Option<ChatTemplateValue>,
    clean_up_tokenization_spaces: Option<bool>,
    device_map: Option<String>,
    /// End-of-sequence token, if configured. See [`ChatTemplate::eos_tok`].
    pub eos_token: Option<BeginEndUnkPadTok>,
    legacy: Option<bool>,
    model_max_length: Option<f64>,
    /// Padding token, if configured. Not read anywhere in this file.
    pub pad_token: Option<BeginEndUnkPadTok>,
    sp_model_kwargs: Option<HashMap<String, String>>,
    spaces_between_special_tokens: Option<bool>,
    tokenizer_class: Option<String>,
    truncation_size: Option<String>,
    /// Unknown token, if configured. See [`ChatTemplate::unk_tok`].
    pub unk_token: Option<BeginEndUnkPadTok>,
    use_default_system_prompt: Option<bool>,
}
74
75impl ChatTemplate {
76    pub fn has_chat_template(&self) -> bool {
77        self.chat_template.is_some()
78    }
79
80    /// Check if this chat template uses OpenAI Harmony format.
81    pub fn is_harmony_format(&self) -> bool {
82        if let Some(ref template_value) = self.chat_template {
83            let template_str = match &template_value.0 {
84                Either::Left(s) => s.as_str(),
85                Either::Right(vec) => {
86                    // For multi-template format, check if any template contains Harmony markers
87                    return vec
88                        .iter()
89                        .any(|t| t.values().any(|v| crate::harmony::is_harmony_template(v)));
90                }
91            };
92            crate::harmony::is_harmony_template(template_str)
93        } else {
94            false
95        }
96    }
97
98    /// Check if this chat template uses `<think>...</think>` tags for reasoning.
99    ///
100    /// This is mutually exclusive with Harmony format - if the template uses
101    /// Harmony format, this returns false even if think tags are present.
102    pub fn uses_think_tags(&self) -> bool {
103        // Don't enable if Harmony format is detected (mutual exclusivity)
104        if self.is_harmony_format() {
105            return false;
106        }
107
108        if let Some(ref template_value) = self.chat_template {
109            let template_str = match &template_value.0 {
110                Either::Left(s) => s.as_str(),
111                Either::Right(vec) => {
112                    // For multi-template format, check if any template contains think tags
113                    return vec.iter().any(|t| {
114                        t.values()
115                            .any(|v| crate::think_tags::is_think_tag_template(v))
116                    });
117                }
118            };
119            crate::think_tags::is_think_tag_template(template_str)
120        } else {
121            false
122        }
123    }
124
125    pub fn eos_tok(&self) -> Option<String> {
126        match self.eos_token.as_ref()?.0 {
127            Either::Left(ref lit) => Some(lit.clone()),
128            Either::Right(ref added) => Some(added.content.clone()),
129        }
130    }
131
132    pub fn bos_tok(&self) -> Option<String> {
133        match self.bos_token.as_ref()?.0 {
134            Either::Left(ref lit) => Some(lit.clone()),
135            Either::Right(ref added) => Some(added.content.clone()),
136        }
137    }
138
139    pub fn unk_tok(&self) -> Option<String> {
140        match self.unk_token.as_ref()?.0 {
141            Either::Left(ref lit) => Some(lit.clone()),
142            Either::Right(ref added) => Some(added.content.clone()),
143        }
144    }
145}
146
147pub fn calculate_eos_tokens(
148    chat_template: &ChatTemplate,
149    gen_conf: Option<GenerationConfig>,
150    tokenizer: &Tokenizer,
151) -> Vec<u32> {
152    let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
153    let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
154
155    for alternate in SUPPORTED_ALTERNATE_EOS {
156        if tokenizer.get_vocab(true).contains_key(*alternate) {
157            eos_tok_ids.push(alternate.to_string())
158        }
159    }
160
161    if let Some(gen_conf) = gen_conf {
162        if let Some(eos_field) = gen_conf.eos_token_id {
163            let ids = match eos_field {
164                Either::Left(id) => vec![id],
165                Either::Right(ids) => ids,
166            };
167            for id in ids {
168                let s = tokenizer
169                    .decode(&[id], false)
170                    .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
171                if !eos_tok_ids.contains(&s) {
172                    eos_tok_ids.push(s);
173                }
174            }
175        }
176
177        if let Some(bos_field) = gen_conf.bos_token_id {
178            let ids = match bos_field {
179                Either::Left(id) => vec![id],
180                Either::Right(ids) => ids,
181            };
182            for id in ids {
183                let s = tokenizer
184                    .decode(&[id], false)
185                    .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
186                if !bos_tok_ids.contains(&s) {
187                    bos_tok_ids.push(s);
188                }
189            }
190        }
191    }
192
193    eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
194    bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
195
196    let bos_render = bos_tok_ids
197        .iter()
198        .map(|val| format!("{val:?}"))
199        .collect::<Vec<String>>()
200        .join(", ");
201    let eos_render = eos_tok_ids
202        .iter()
203        .map(|val| format!("{val:?}"))
204        .collect::<Vec<String>>()
205        .join(", ");
206
207    info!(
208        "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
209        chat_template.unk_tok().unwrap_or("`None`".to_string()),
210    );
211
212    let mut eos_toks = Vec::new();
213    for eos_tok in eos_tok_ids {
214        eos_toks.push(
215            tokenizer
216                .get_vocab(true)
217                .get(&eos_tok)
218                .copied()
219                .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
220        )
221    }
222    eos_toks
223}
224
/// The BOS/EOS token id overrides read from a deserialized generation config.
///
/// Each field accepts either a single id or a list of ids (untagged), and is
/// `None` when the key is absent thanks to `#[serde(default)]`.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct GenerationConfig {
    /// Beginning-of-sequence token id(s), if given.
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    bos_token_id: Option<Either<u32, Vec<u32>>>,
    /// End-of-sequence token id(s), if given.
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    eos_token_id: Option<Either<u32, Vec<u32>>>,
}
235
236fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
237    if let Ok(indent) = kwargs.get("indent") {
238        let mut buf = Vec::new();
239        let repeat = b" ".repeat(indent);
240        let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
241        let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
242        value.serialize(&mut ser).unwrap();
243        String::from_utf8(buf).map_err(|err| {
244            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
245        })
246    } else {
247        serde_json::to_string(&value).map_err(|err| {
248            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
249        })
250    }
251    .map_err(|err| {
252        Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
253    })
254    .map(|s| {
255        // When this filter is used the return value is safe for both HTML and JSON
256        let mut rv = String::with_capacity(s.len());
257        for c in s.chars() {
258            match c {
259                '<' => rv.push_str("\\u003c"),
260                '>' => rv.push_str("\\u003e"),
261                '&' => rv.push_str("\\u0026"),
262                '\'' => rv.push_str("\\u0027"),
263                _ => rv.push(c),
264            }
265        }
266        Value::from_safe_string(rv)
267    })
268}
269
270fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
271    let date = chrono::Utc::now();
272    let date_string = date.format(&fmt).to_string();
273    Ok(date_string)
274}
275
276use crate::request::ReasoningEffort;
277
278#[allow(clippy::too_many_arguments)]
279pub fn apply_chat_template_to(
280    messages: Vec<IndexMap<String, MessageContent>>,
281    add_generation_prompt: bool,
282    enable_thinking: Option<bool>,
283    reasoning_effort: Option<ReasoningEffort>,
284    template: &ChatTemplateValue,
285    bos_tok: Option<String>,
286    eos_tok: Option<String>,
287    unk_tok: Option<String>,
288    tools: Vec<Tool>,
289) -> Result<String> {
290    let mut env = Environment::new();
291
292    // enable python methods such as .strip()
293    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
294
295    // https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/tokenization_utils_base.py#L1842
296    env.set_lstrip_blocks(true);
297    env.set_trim_blocks(true);
298
299    #[derive(Serialize, Deserialize)]
300    struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
301    let mut new_messages = Vec::new();
302    for message in messages {
303        let mut new_message = IndexMap::new();
304        for (k, v) in message {
305            new_message.insert(k, UntaggedContent(v));
306        }
307        new_messages.push(new_message);
308    }
309
310    let template = match &template.0 {
311        Either::Left(x) => x.clone(),
312        Either::Right(map) => {
313            let mut template = None;
314            let has_tool_use = map.iter().any(|t| {
315                t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
316            });
317            let must_use_tool_template = !tools.is_empty();
318
319            if must_use_tool_template && !has_tool_use {
320                anyhow::bail!(
321                    "Tools were provided but this chat template does not handle tool usage"
322                );
323            }
324
325            for t in map {
326                let name = t.get("name");
327                if let Some(name) = name {
328                    template = Some(t["template"].clone());
329                    #[allow(clippy::if_same_then_else)]
330                    if name == "tool_use" && !tools.is_empty() {
331                        break;
332                    } else if name == "default" && !must_use_tool_template {
333                        break;
334                    }
335                } else if t.contains_key("tool_use") && !tools.is_empty() {
336                    template = Some(t["tool_use"].clone());
337                    break;
338                } else if t.contains_key("default") && !must_use_tool_template {
339                    template = Some(t["default"].clone());
340                    break;
341                }
342            }
343
344            let Some(template) = template else {
345                anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
346            };
347            template
348        }
349    };
350    let mut template = template.replace("[::-1]", "|reverse");
351    // Convert Python‑style descending ranges `range(..., -1, -1)` to a forward
352    // range followed by Jinja’s `|reverse` filter so it works even when
353    // negative‑step ranges aren’t supported.
354    let re = Regex::new(r"range\((?P<expr>[^,]+),\s*-1,\s*-1\)").unwrap();
355    template = re
356        .replace_all(&template, |caps: &regex::Captures| {
357            format!("range({})|reverse", &caps["expr"])
358        })
359        .into_owned();
360
361    if template.contains("{{ meta }}") {
362        // Fix for GLM4 models
363        template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
364        template = template.replace("{{ meta }}", "");
365    }
366    if template.contains("{% generation %}") && template.contains("{% endgeneration %}") {
367        // Strip for smollm3 models
368        template = template.replace("{% generation %}", "");
369        template = template.replace("{% endgeneration %}", "");
370    }
371
372    env.add_template("chat_template", &template)?;
373    env.add_function("raise_exception", raise_exception);
374    env.add_filter("tojson", tojson);
375    env.add_function("strftime_now", strftime_now);
376    let tmpl = env.get_template("chat_template").unwrap();
377
378    let date = chrono::Utc::now();
379    let date_string = date.format("%d, %B, %Y").to_string();
380
381    // Convert reasoning effort to string for template
382    let reasoning_effort_str = reasoning_effort.map(|r| r.as_str()).unwrap_or("medium");
383
384    // Detect builtin tools from the tools list
385    // Known builtin tools for GPT-OSS/Harmony format: "browser", "python"
386    // Known builtin tools for Llama 3.x: "wolfram_alpha", "web_search", "brave_search", "python", "code_interpreter"
387    let builtin_tool_names = [
388        "browser",
389        "python",
390        "code_interpreter",
391        "web_search",
392        "brave_search",
393        "wolfram_alpha",
394    ];
395    let builtin_tools: Vec<&str> = tools
396        .iter()
397        .filter_map(|t| {
398            let name = t.function.name.as_str();
399            if builtin_tool_names.contains(&name) {
400                Some(name)
401            } else {
402                None
403            }
404        })
405        .collect();
406
407    if tools.is_empty() {
408        Ok(tmpl.render(context! {
409            messages => new_messages,
410            add_generation_prompt => add_generation_prompt,
411            bos_token => bos_tok,
412            eos_token => eos_tok,
413            unk_token => unk_tok,
414            date_string => date_string,
415            enable_thinking => enable_thinking.unwrap_or(true),
416            reasoning_effort => reasoning_effort_str,
417        })?)
418    } else {
419        Ok(tmpl.render(context! {
420            messages => new_messages,
421            add_generation_prompt => add_generation_prompt,
422            bos_token => bos_tok,
423            eos_token => eos_tok,
424            unk_token => unk_tok,
425            xml_tools => tools.clone(), // SmolLM3
426            tools => tools,
427            builtin_tools => builtin_tools,
428            date_string => date_string,
429            enable_thinking => enable_thinking.unwrap_or(true),
430            reasoning_effort => reasoning_effort_str,
431        })?)
432    }
433}