Skip to main content

llm_tokenizer/
chat_template.rs

1//! Chat template support for tokenizers using Jinja2 templates
2//!
3//! This module provides functionality to apply chat templates to messages,
4//! similar to HuggingFace transformers' apply_chat_template method.
5
6use std::{collections::HashMap, fs, io};
7
8use anyhow::{anyhow, Result};
9use minijinja::{
10    context,
11    machinery::{
12        ast::{Expr, Stmt},
13        parse, WhitespaceConfig,
14    },
15    syntax::SyntaxConfig,
16    value::Kwargs,
17    Environment, Error as MinijinjaError, ErrorKind, Value,
18};
19use serde::Serialize;
20use serde_json::{self, ser::Formatter, Value as JsonValue};
21
/// Shape of the `content` field a chat template expects in each message.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChatTemplateContentFormat {
    /// Content is a plain string (the default).
    #[default]
    String,
    /// Content is a list of structured parts (OpenAI format).
    OpenAI,
}

impl std::fmt::Display for ChatTemplateContentFormat {
    /// Writes the lowercase identifier: `string` or `openai`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::String => "string",
            Self::OpenAI => "openai",
        };
        f.write_str(label)
    }
}
40
/// The variable name a chat template uses for its thinking toggle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingKeyName {
    /// Template uses `enable_thinking` (Qwen3, GLM, Nemotron)
    EnableThinking,
    /// Template uses `thinking` (DeepSeek V3.1, Kimi-K2.5)
    Thinking,
}

/// Result of detecting the thinking/reasoning toggle in a chat template.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ThinkingToggle {
    /// Template has no thinking toggle. The model either always reasons
    /// (e.g. DeepSeek R1) or never does — controlled by the parser's
    /// `always_in_reasoning` config.
    #[default]
    None,
    /// Template supports a thinking toggle that defaults to ON.
    /// If the user doesn't pass anything, thinking is enabled.
    /// (Qwen3, Qwen3.5, Nemotron, GLM-4.6, GLM-5, Kimi-K2.5)
    DefaultOn,
    /// Template supports a thinking toggle that defaults to OFF.
    /// Thinking only activates when the user explicitly passes `thinking=true`.
    /// (DeepSeek V3.1)
    DefaultOff,
}
67
68/// Detect whether the chat template supports a thinking/reasoning toggle
69/// and what its default value is.
70pub fn detect_thinking_toggle(template: &str) -> (ThinkingToggle, Option<ThinkingKeyName>) {
71    let has_enable_thinking = template.contains("enable_thinking");
72    // Trailing space prevents matching "thinking_mode", "thinking_budget", etc.
73    let has_thinking_var = template.contains("if thinking ")
74        || template.contains("thinking is ")
75        || template.contains("thinking ==")
76        || template.contains("set thinking ");
77
78    if !has_enable_thinking && !has_thinking_var {
79        return (ThinkingToggle::None, None);
80    }
81
82    // At least one must be true — both false returned ThinkingToggle::None above.
83    let key_name = if has_enable_thinking {
84        ThinkingKeyName::EnableThinking
85    } else {
86        ThinkingKeyName::Thinking
87    };
88
89    // Check if the template explicitly defaults thinking to false/off.
90    // DeepSeek V3.1 pattern: {% if not thinking is defined %}{% set thinking = false %}
91    if template.contains("set thinking = false") || template.contains("set thinking=false") {
92        return (ThinkingToggle::DefaultOff, Some(key_name));
93    }
94    if template.contains("set enable_thinking = false")
95        || template.contains("set enable_thinking=false")
96    {
97        return (ThinkingToggle::DefaultOff, Some(key_name));
98    }
99
100    // All other models default to thinking ON
101    (ThinkingToggle::DefaultOn, Some(key_name))
102}
103
104/// Detect the content format expected by a Jinja2 chat template
105///
106/// This implements the same detection logic as SGLang's detect_jinja_template_content_format
107/// which uses AST parsing to look for content iteration patterns.
108///
109/// Returns:
110/// - ChatTemplateContentFormat::OpenAI if template expects structured content (list of parts)
111/// - ChatTemplateContentFormat::String if template expects simple string content
112pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
113    // Use AST-based detection (enabled by default)
114    detect_format_with_ast(template)
115}
116
/// Which OpenAI-style (structured content) patterns the AST walk has observed.
#[derive(Default, Debug, Clone, Copy)]
struct Flags {
    /// A loop iterated over a message's `content` or a local named `content`.
    saw_iteration: bool,
    /// `content` was indexed with a number or measured with `|length`.
    saw_structure: bool,
    /// `content` was assigned from a message's `content`; informational only.
    saw_assignment: bool,
    /// A macro combined a type test with a loop.
    saw_macro: bool,
}

impl Flags {
    /// Whether any signal strong enough to classify the template as OpenAI format was seen.
    ///
    /// `saw_assignment` is deliberately ignored. Many string-format templates (Qwen3, etc.)
    /// copy `message.content` into a local `content` and then test `content is string`.
    /// Without iteration, structural access, or a type-dispatching macro, such a
    /// template only handles string content.
    fn any(self) -> bool {
        let Self {
            saw_iteration,
            saw_structure,
            saw_assignment: _,
            saw_macro,
        } = self;
        saw_iteration || saw_structure || saw_macro
    }
}
135
136/// Single-pass AST detector with scope tracking
137struct Detector<'a> {
138    ast: &'a Stmt<'a>,
139    /// Message loop vars currently in scope (e.g., `message`, `m`, `msg`)
140    scope: std::collections::VecDeque<String>,
141    scope_set: std::collections::HashSet<String>,
142    flags: Flags,
143    /// Whether `<think>` appears inside an `add_generation_prompt` if-block
144    think_in_prefill: bool,
145}
146
147impl<'a> Detector<'a> {
148    fn new(ast: &'a Stmt<'a>) -> Self {
149        Self {
150            ast,
151            scope: std::collections::VecDeque::new(),
152            scope_set: std::collections::HashSet::new(),
153            flags: Flags::default(),
154            think_in_prefill: false,
155        }
156    }
157
158    fn run(mut self) -> (Flags, bool) {
159        self.walk_stmt(self.ast);
160        (self.flags, self.think_in_prefill)
161    }
162
163    fn push_scope(&mut self, var: String) {
164        self.scope.push_back(var.clone());
165        self.scope_set.insert(var);
166    }
167
168    fn pop_scope(&mut self) {
169        if let Some(v) = self.scope.pop_back() {
170            self.scope_set.remove(&v);
171        }
172    }
173
174    fn is_var_access(expr: &Expr, varname: &str) -> bool {
175        matches!(expr, Expr::Var(v) if v.id == varname)
176    }
177
178    fn is_const_str(expr: &Expr, value: &str) -> bool {
179        matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value))
180    }
181
182    fn is_numeric_const(expr: &Expr) -> bool {
183        matches!(expr, Expr::Const(c) if c.value.is_number())
184    }
185
186    /// Check if expr is varname.content or varname["content"]
187    fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
188        match expr {
189            Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
190            Expr::GetItem(g) => {
191                Self::is_var_access(&g.expr, varname)
192                    && Self::is_const_str(&g.subscript_expr, "content")
193            }
194            // Unwrap filters/tests that just wrap the same expr
195            Expr::Filter(f) => f
196                .expr
197                .as_ref()
198                .is_some_and(|e| Self::is_var_dot_content(e, varname)),
199            Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
200            _ => false,
201        }
202    }
203
204    /// Check if expr accesses .content on any variable in our scope, or any descendant of it.
205    fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
206        let mut current_expr = expr;
207        loop {
208            // Check if current level matches <scopeVar>.content
209            if self
210                .scope_set
211                .iter()
212                .any(|v| Self::is_var_dot_content(current_expr, v))
213            {
214                return true;
215            }
216            // Walk up the expression tree
217            match current_expr {
218                Expr::GetAttr(g) => current_expr = &g.expr,
219                Expr::GetItem(g) => current_expr = &g.expr,
220                _ => return false,
221            }
222        }
223    }
224
225    /// Check if an expression references a variable by name (walks through BinOp/UnaryOp).
226    fn expr_references_var(expr: &Expr, name: &str) -> bool {
227        match expr {
228            Expr::Var(v) => v.id == name,
229            Expr::BinOp(b) => {
230                Self::expr_references_var(&b.left, name)
231                    || Self::expr_references_var(&b.right, name)
232            }
233            Expr::UnaryOp(u) => Self::expr_references_var(&u.expr, name),
234            _ => false,
235        }
236    }
237
238    /// Check if a list of statements contains `<think>` in EmitRaw or string constants.
239    fn body_has_think_tag(stmts: &[Stmt]) -> bool {
240        for stmt in stmts {
241            match stmt {
242                Stmt::EmitRaw(raw) if raw.raw.contains("<think>") => return true,
243                Stmt::EmitExpr(e) => {
244                    if let Expr::Const(c) = &e.expr {
245                        if c.value.as_str().is_some_and(|s| s.contains("<think>")) {
246                            return true;
247                        }
248                    }
249                }
250                Stmt::IfCond(ic)
251                    if Self::body_has_think_tag(&ic.true_body)
252                        || Self::body_has_think_tag(&ic.false_body) =>
253                {
254                    return true;
255                }
256                _ => {}
257            }
258        }
259        false
260    }
261
262    fn walk_stmt(&mut self, stmt: &Stmt) {
263        match stmt {
264            Stmt::Template(t) => {
265                for ch in &t.children {
266                    self.walk_stmt(ch);
267                }
268            }
269            // {% for message in messages %}
270            Stmt::ForLoop(fl) => {
271                // Detect "for X in messages" → push X into scope
272                if let Expr::Var(iter) = &fl.iter {
273                    if iter.id == "messages" {
274                        if let Expr::Var(target) = &fl.target {
275                            self.push_scope(target.id.to_string());
276                        }
277                    }
278                }
279
280                // Also detect "for ... in message.content" or "for ... in content"
281                // - Iterating directly over <scopeVar>.content => OpenAI style
282                if self.is_any_scope_var_content(&fl.iter) {
283                    self.flags.saw_iteration = true;
284                }
285                // - Iterating over a local var named "content"
286                if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
287                    self.flags.saw_iteration = true;
288                }
289
290                for b in &fl.body {
291                    self.walk_stmt(b);
292                }
293
294                // Pop scope if we pushed it
295                if let Expr::Var(iter) = &fl.iter {
296                    if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
297                        self.pop_scope();
298                    }
299                }
300            }
301            Stmt::IfCond(ic) => {
302                self.inspect_expr_for_structure(&ic.expr);
303
304                // Detect <think> inside {% if add_generation_prompt [and ...] %} body
305                if !self.think_in_prefill
306                    && Self::expr_references_var(&ic.expr, "add_generation_prompt")
307                {
308                    self.think_in_prefill = Self::body_has_think_tag(&ic.true_body);
309                }
310
311                for b in &ic.true_body {
312                    self.walk_stmt(b);
313                }
314                for b in &ic.false_body {
315                    self.walk_stmt(b);
316                }
317            }
318            Stmt::EmitExpr(e) => {
319                self.inspect_expr_for_structure(&e.expr);
320            }
321            // {% set content = message.content %}
322            Stmt::Set(s)
323                if Self::is_var_access(&s.target, "content")
324                    && self.is_any_scope_var_content(&s.expr) =>
325            {
326                self.flags.saw_assignment = true;
327            }
328            Stmt::Macro(m) => {
329                // Heuristic: macro that checks type (via `is` test) and also has any loop
330                let mut has_type_check = false;
331                let mut has_loop = false;
332                Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
333                if has_type_check && has_loop {
334                    self.flags.saw_macro = true;
335                }
336            }
337            _ => {}
338        }
339    }
340
341    fn inspect_expr_for_structure(&mut self, expr: &Expr) {
342        if self.flags.saw_structure {
343            return;
344        }
345
346        match expr {
347            // content[0] or message.content[0]
348            Expr::GetItem(gi)
349                if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
350                    || self.is_any_scope_var_content(&gi.expr))
351                    && Self::is_numeric_const(&gi.subscript_expr) =>
352            {
353                self.flags.saw_structure = true;
354            }
355            // content|length or message.content|length
356            Expr::Filter(f) => {
357                if f.name == "length" {
358                    if let Some(inner) = &f.expr {
359                        // Box derefs automatically, so `&**inner` is `&Expr`
360                        let inner_ref: &Expr = inner;
361                        let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
362                        if is_content_var || self.is_any_scope_var_content(inner_ref) {
363                            self.flags.saw_structure = true;
364                        }
365                    }
366                } else if let Some(inner) = &f.expr {
367                    let inner_ref: &Expr = inner;
368                    self.inspect_expr_for_structure(inner_ref);
369                }
370            }
371            // Type tests like `content is iterable` or `message.content is string`
372            // These are used for branching (e.g., Llama 3.1 uses them for tool output formatting),
373            // not as indicators that the template expects structured content. Keep walking.
374            Expr::Test(t) => self.inspect_expr_for_structure(&t.expr),
375            Expr::GetAttr(g) => {
376                // Keep walking; nested expressions can hide structure checks
377                self.inspect_expr_for_structure(&g.expr);
378            }
379            // Handle binary operations like: if (message.content is string) and other_cond
380            Expr::BinOp(op) => {
381                self.inspect_expr_for_structure(&op.left);
382                self.inspect_expr_for_structure(&op.right);
383            }
384            // Handle unary operations like: if not (message.content is string)
385            Expr::UnaryOp(op) => {
386                self.inspect_expr_for_structure(&op.expr);
387            }
388            _ => {}
389        }
390    }
391
392    fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
393        for s in body {
394            if *has_type_check && *has_loop {
395                return;
396            }
397
398            match s {
399                Stmt::IfCond(ic) => {
400                    if matches!(&ic.expr, Expr::Test(_)) {
401                        *has_type_check = true;
402                    }
403                    Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
404                    Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
405                }
406                Stmt::ForLoop(fl) => {
407                    *has_loop = true;
408                    Self::scan_macro_body(&fl.body, has_type_check, has_loop);
409                }
410                Stmt::Template(t) => {
411                    Self::scan_macro_body(&t.children, has_type_check, has_loop);
412                }
413                _ => {}
414            }
415        }
416    }
417}
418
419/// AST-based detection using minijinja's unstable machinery
420/// Single-pass detector with scope tracking
421fn detect_format_with_ast(template: &str) -> ChatTemplateContentFormat {
422    detect_all_with_ast(template).0
423}
424
425/// Single-pass detection of content format, think-in-prefill, and thinking toggle.
426fn detect_all(
427    template: &str,
428) -> (
429    ChatTemplateContentFormat,
430    bool,
431    ThinkingToggle,
432    Option<ThinkingKeyName>,
433) {
434    let (thinking_toggle, thinking_key_name) = detect_thinking_toggle(template);
435    let (content_format, think_in_prefill) = detect_all_with_ast(template);
436    (
437        content_format,
438        think_in_prefill,
439        thinking_toggle,
440        thinking_key_name,
441    )
442}
443
444/// AST detection of content format and think-in-prefill.
445fn detect_all_with_ast(template: &str) -> (ChatTemplateContentFormat, bool) {
446    let ast = match parse(
447        template,
448        "template",
449        SyntaxConfig {},
450        WhitespaceConfig::default(),
451    ) {
452        Ok(ast) => ast,
453        Err(_) => return (ChatTemplateContentFormat::String, false),
454    };
455
456    let (flags, think_in_prefill) = Detector::new(&ast).run();
457    let content_format = if flags.any() {
458        ChatTemplateContentFormat::OpenAI
459    } else {
460        ChatTemplateContentFormat::String
461    };
462    (content_format, think_in_prefill)
463}
464
465/// Parameters for chat template application
466#[derive(Default)]
467pub struct ChatTemplateParams<'a> {
468    pub add_generation_prompt: bool,
469    pub tools: Option<&'a [serde_json::Value]>,
470    pub documents: Option<&'a [serde_json::Value]>,
471    pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
472    /// Special tokens to inject into the template context.
473    /// Many templates reference `{{ bos_token }}`, `{{ eos_token }}`, etc.
474    pub special_tokens: Option<&'a crate::traits::SpecialTokens>,
475}
476
/// Item/key separator pair accepted by HuggingFace's `tojson` filter
/// (the `separators` argument of Python's `json.dumps`).
#[derive(Debug, Clone)]
struct JsonSeparators {
    /// Written between array items and between object members.
    item: Vec<u8>,
    /// Written between an object key and its value.
    key: Vec<u8>,
}

impl JsonSeparators {
    /// Separators Python's `json.dumps` uses when none are given.
    ///
    /// - item separator: `", "` for compact output, `","` once `indent` is set
    ///   (each item is then followed by a newline);
    /// - key separator: always `": "`.
    fn python_default(indent: Option<i64>) -> Self {
        let item: &[u8] = match indent {
            Some(_) => b",",
            None => b", ",
        };
        Self {
            item: item.to_vec(),
            key: b": ".to_vec(),
        }
    }
}
495
496/// Formatter matching Python's `json.dumps` separator and ASCII escaping rules.
497#[derive(Debug, Clone)]
498struct PythonJsonFormatter {
499    current_indent: usize,
500    has_value: bool,
501    indent: Option<Vec<u8>>,
502    separators: JsonSeparators,
503    ensure_ascii: bool,
504}
505
506impl PythonJsonFormatter {
507    fn new(indent: Option<usize>, separators: JsonSeparators, ensure_ascii: bool) -> Self {
508        Self {
509            current_indent: 0,
510            has_value: false,
511            indent: indent.map(|spaces| vec![b' '; spaces]),
512            separators,
513            ensure_ascii,
514        }
515    }
516}
517
/// Write `indent` to `writer` `count` times, once per nesting level.
/// Stops at the first write error.
fn write_indent<W>(writer: &mut W, count: usize, indent: &[u8]) -> io::Result<()>
where
    W: ?Sized + io::Write,
{
    std::iter::repeat(indent)
        .take(count)
        .try_for_each(|level| writer.write_all(level))
}
527
/// Write `code` as a JSON `\uXXXX` escape using four lowercase hex digits,
/// in a single `write_all` call.
fn write_u_escape<W>(writer: &mut W, code: u16) -> io::Result<()>
where
    W: ?Sized + io::Write,
{
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut escape = *b"\\u0000";
    // Fill the four digit slots from the most significant nibble down.
    for (slot, shift) in escape[2..].iter_mut().zip([12u16, 8, 4, 0]) {
        *slot = HEX_DIGITS[usize::from((code >> shift) & 0xF)];
    }
    writer.write_all(&escape)
}
542
543impl Formatter for PythonJsonFormatter {
544    fn write_string_fragment<W>(&mut self, writer: &mut W, fragment: &str) -> io::Result<()>
545    where
546        W: ?Sized + io::Write,
547    {
548        if !self.ensure_ascii {
549            return writer.write_all(fragment.as_bytes());
550        }
551
552        for ch in fragment.chars() {
553            if ch.is_ascii() {
554                let mut buf = [0; 4];
555                writer.write_all(ch.encode_utf8(&mut buf).as_bytes())?;
556                continue;
557            }
558
559            let code = ch as u32;
560            if code <= 0xFFFF {
561                write_u_escape(writer, code as u16)?;
562            } else {
563                let shifted = code - 0x1_0000;
564                let high = 0xD800 + ((shifted >> 10) as u16);
565                let low = 0xDC00 + ((shifted & 0x3FF) as u16);
566                write_u_escape(writer, high)?;
567                write_u_escape(writer, low)?;
568            }
569        }
570        Ok(())
571    }
572
573    fn begin_array<W>(&mut self, writer: &mut W) -> io::Result<()>
574    where
575        W: ?Sized + io::Write,
576    {
577        if self.indent.is_some() {
578            self.current_indent += 1;
579            self.has_value = false;
580        }
581        writer.write_all(b"[")
582    }
583
584    fn end_array<W>(&mut self, writer: &mut W) -> io::Result<()>
585    where
586        W: ?Sized + io::Write,
587    {
588        if let Some(indent) = self.indent.as_deref() {
589            self.current_indent -= 1;
590            if self.has_value {
591                writer.write_all(b"\n")?;
592                write_indent(writer, self.current_indent, indent)?;
593            }
594        }
595        writer.write_all(b"]")
596    }
597
598    fn begin_array_value<W>(&mut self, writer: &mut W, first: bool) -> io::Result<()>
599    where
600        W: ?Sized + io::Write,
601    {
602        if let Some(indent) = self.indent.as_deref() {
603            if first {
604                writer.write_all(b"\n")?;
605            } else {
606                writer.write_all(&self.separators.item)?;
607                writer.write_all(b"\n")?;
608            }
609            write_indent(writer, self.current_indent, indent)
610        } else if first {
611            Ok(())
612        } else {
613            writer.write_all(&self.separators.item)
614        }
615    }
616
617    fn end_array_value<W>(&mut self, _writer: &mut W) -> io::Result<()>
618    where
619        W: ?Sized + io::Write,
620    {
621        self.has_value = true;
622        Ok(())
623    }
624
625    fn begin_object<W>(&mut self, writer: &mut W) -> io::Result<()>
626    where
627        W: ?Sized + io::Write,
628    {
629        if self.indent.is_some() {
630            self.current_indent += 1;
631            self.has_value = false;
632        }
633        writer.write_all(b"{")
634    }
635
636    fn end_object<W>(&mut self, writer: &mut W) -> io::Result<()>
637    where
638        W: ?Sized + io::Write,
639    {
640        if let Some(indent) = self.indent.as_deref() {
641            self.current_indent -= 1;
642            if self.has_value {
643                writer.write_all(b"\n")?;
644                write_indent(writer, self.current_indent, indent)?;
645            }
646        }
647        writer.write_all(b"}")
648    }
649
650    fn begin_object_key<W>(&mut self, writer: &mut W, first: bool) -> io::Result<()>
651    where
652        W: ?Sized + io::Write,
653    {
654        if let Some(indent) = self.indent.as_deref() {
655            if first {
656                writer.write_all(b"\n")?;
657            } else {
658                writer.write_all(&self.separators.item)?;
659                writer.write_all(b"\n")?;
660            }
661            write_indent(writer, self.current_indent, indent)
662        } else if first {
663            Ok(())
664        } else {
665            writer.write_all(&self.separators.item)
666        }
667    }
668
669    fn begin_object_value<W>(&mut self, writer: &mut W) -> io::Result<()>
670    where
671        W: ?Sized + io::Write,
672    {
673        writer.write_all(&self.separators.key)
674    }
675
676    fn end_object_value<W>(&mut self, _writer: &mut W) -> io::Result<()>
677    where
678        W: ?Sized + io::Write,
679    {
680        self.has_value = true;
681        Ok(())
682    }
683}
684
685fn invalid_tojson_option(message: impl Into<String>) -> MinijinjaError {
686    MinijinjaError::new(ErrorKind::InvalidOperation, message.into())
687}
688
689fn parse_separators(
690    separators: Option<Value>,
691    indent: Option<i64>,
692) -> std::result::Result<JsonSeparators, MinijinjaError> {
693    let Some(separators) = separators else {
694        return Ok(JsonSeparators::python_default(indent));
695    };
696    if separators.is_none() || separators.is_undefined() {
697        return Ok(JsonSeparators::python_default(indent));
698    }
699
700    let parsed: serde_json::Value = serde_json::to_value(&separators).map_err(|e| {
701        invalid_tojson_option(format!("Failed to convert separators to JSON value: {e}"))
702    })?;
703    let JsonValue::Array(values) = parsed else {
704        return Err(invalid_tojson_option(
705            "separators must be a two-item sequence",
706        ));
707    };
708    if values.len() != 2 {
709        return Err(invalid_tojson_option(
710            "separators must be a two-item sequence",
711        ));
712    }
713
714    let item = values[0]
715        .as_str()
716        .ok_or_else(|| invalid_tojson_option("item separator must be a string"))?;
717    let key = values[1]
718        .as_str()
719        .ok_or_else(|| invalid_tojson_option("key separator must be a string"))?;
720
721    Ok(JsonSeparators {
722        item: item.as_bytes().to_vec(),
723        key: key.as_bytes().to_vec(),
724    })
725}
726
727fn serialize_with_python_json<T: Serialize>(
728    value: &T,
729    indent: Option<i64>,
730    separators: JsonSeparators,
731    ensure_ascii: bool,
732) -> std::result::Result<String, MinijinjaError> {
733    let indent = indent
734        .map(|spaces| {
735            if spaces < 0 {
736                Err(invalid_tojson_option("indent cannot be negative"))
737            } else {
738                Ok(spaces as usize)
739            }
740        })
741        .transpose()?;
742
743    let formatter = PythonJsonFormatter::new(indent, separators, ensure_ascii);
744    let mut buf = Vec::new();
745    let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
746    value.serialize(&mut serializer).map_err(|e| {
747        MinijinjaError::new(
748            ErrorKind::InvalidOperation,
749            format!("Failed to serialize JSON: {e}"),
750        )
751    })?;
752    String::from_utf8(buf).map_err(|e| {
753        MinijinjaError::new(
754            ErrorKind::InvalidOperation,
755            format!("Invalid UTF-8 in JSON output: {e}"),
756        )
757    })
758}
759
760/// Custom tojson filter compatible with HuggingFace transformers' implementation.
761///
762/// HuggingFace transformers registers a custom `tojson` filter that accepts additional
763/// keyword arguments beyond what standard Jinja2 provides:
764/// - `ensure_ascii` (bool): Whether to escape non-ASCII characters
765/// - `indent` (int): Number of spaces for indentation (pretty-printing)
766/// - `separators`: Custom item/key separators for JSON output
767/// - `sort_keys` (bool): Whether to sort dictionary keys
768///
769/// This is necessary for compatibility with chat templates from HuggingFace Hub models.
770/// See: https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
771fn tojson_filter(value: Value, kwargs: Kwargs) -> std::result::Result<Value, MinijinjaError> {
772    let ensure_ascii: Option<bool> = kwargs.get("ensure_ascii")?;
773    let indent: Option<i64> = kwargs.get("indent")?;
774    let separators: Option<Value> = kwargs.get("separators")?;
775    let sort_keys: Option<bool> = kwargs.get("sort_keys")?;
776
777    // Ensure all kwargs are consumed to avoid "unknown keyword argument" errors
778    kwargs.assert_all_used()?;
779
780    let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| {
781        MinijinjaError::new(
782            ErrorKind::InvalidOperation,
783            format!("Failed to convert to JSON value: {e}"),
784        )
785    })?;
786
787    // Serialize with options
788    let json_str: std::result::Result<String, MinijinjaError> = {
789        let sorted_json;
790        let value_to_serialize = if sort_keys.unwrap_or(false) {
791            sorted_json = sort_json_keys(&json_value);
792            &sorted_json
793        } else {
794            &json_value
795        };
796
797        let separators = parse_separators(separators, indent)?;
798        serialize_with_python_json(
799            value_to_serialize,
800            indent,
801            separators,
802            ensure_ascii.unwrap_or(false),
803        )
804    };
805
806    json_str.map(Value::from_safe_string)
807}
808
809/// Recursively sort all object keys in a JSON value
810fn sort_json_keys(value: &JsonValue) -> JsonValue {
811    match value {
812        JsonValue::Object(map) => {
813            let mut sorted: serde_json::Map<String, JsonValue> = serde_json::Map::new();
814            let mut keys: Vec<_> = map.keys().collect();
815            keys.sort();
816            for key in keys {
817                sorted.insert(key.clone(), sort_json_keys(&map[key]));
818            }
819            JsonValue::Object(sorted)
820        }
821        JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()),
822        _ => value.clone(),
823    }
824}
825
826/// Build a pre-configured `Environment<'static>` with the given template string,
827/// Python-compat method callback, and custom `tojson` filter already registered.
828/// The template is stored under the name `"chat"` using owned storage so the
829/// environment carries no borrows.
830fn build_environment(template: String) -> Result<Environment<'static>> {
831    let mut env = Environment::new();
832
833    // Match HuggingFace's Jinja2 defaults: trim_blocks and lstrip_blocks are
834    // enabled in Python's transformers but default to false in minijinja.
835    // Without these, templates like GLM-5's produce incorrect whitespace.
836    env.set_trim_blocks(true);
837    env.set_lstrip_blocks(true);
838
839    // Register the template with owned storage (no lifetime dependency on caller)
840    env.add_template_owned("chat".to_owned(), template)
841        .map_err(|e| anyhow!("Failed to add template: {e}"))?;
842
843    // Enable Python method compatibility (e.g., str.startswith, str.endswith)
844    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
845
846    // Register custom tojson filter compatible with HuggingFace transformers
847    // This overrides minijinja's built-in tojson to support additional kwargs
848    // like ensure_ascii, separators, and sort_keys that HuggingFace templates use
849    env.add_filter("tojson", tojson_filter);
850
851    Ok(env)
852}
853
854/// Render the `"chat"` template in the given environment against messages and params.
855/// Convert an optional token string to a minijinja Value.
856/// Present tokens become strings; absent tokens become UNDEFINED
857/// so templates can use `{% if bos_token is defined %}` guards.
858fn special_token_value(token: Option<&str>) -> Value {
859    token.map_or(Value::UNDEFINED, Value::from)
860}
861
862fn render_chat_template(
863    env: &Environment<'_>,
864    messages: &[serde_json::Value],
865    params: ChatTemplateParams,
866) -> Result<String> {
867    let tmpl = env
868        .get_template("chat")
869        .map_err(|e| anyhow!("Failed to get template: {e}"))?;
870
871    // Convert messages to minijinja::Value (messages already processed by router)
872    let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
873
874    // Use Value::UNDEFINED for missing optional params so they are truly "undefined"
875    // in the template context, matching HuggingFace Python behavior. Many chat templates
876    // use `{% if tools is defined %}` guards — passing null (none) instead of undefined
877    // would bypass those guards since `none` IS defined, causing `tools | length` to fail.
878    let tools_value = params.tools.map_or(Value::UNDEFINED, Value::from_serialize);
879    let documents_value = params
880        .documents
881        .map_or(Value::UNDEFINED, Value::from_serialize);
882
883    // Inject special tokens (bos_token, eos_token, etc.) into context.
884    // Use UNDEFINED for missing tokens so `{% if bos_token is defined %}` works correctly.
885    // This matches HuggingFace Python which passes self.special_tokens_map to the renderer.
886    let bos_value =
887        special_token_value(params.special_tokens.and_then(|st| st.bos_token.as_deref()));
888    let eos_value =
889        special_token_value(params.special_tokens.and_then(|st| st.eos_token.as_deref()));
890    let unk_value =
891        special_token_value(params.special_tokens.and_then(|st| st.unk_token.as_deref()));
892    let pad_value =
893        special_token_value(params.special_tokens.and_then(|st| st.pad_token.as_deref()));
894
895    let base_context = context! {
896        messages => &minijinja_messages,
897        add_generation_prompt => params.add_generation_prompt,
898        tools => tools_value,
899        documents => documents_value,
900        bos_token => bos_value,
901        eos_token => eos_value,
902        unk_token => unk_value,
903        pad_token => pad_value,
904    };
905
906    // Merge with template_kwargs if provided (caller kwargs override special tokens)
907    let ctx = if let Some(kwargs) = params.template_kwargs {
908        context! {
909            ..base_context,
910            ..Value::from_serialize(kwargs)
911        }
912    } else {
913        base_context
914    };
915
916    // Render the template
917    let rendered = tmpl
918        .render(&ctx)
919        .map_err(|e| anyhow!("Failed to render template: {e}"))?;
920
921    Ok(rendered)
922}
923
924/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
925pub struct ChatTemplateProcessor {
926    env: Environment<'static>,
927}
928
929impl ChatTemplateProcessor {
930    /// Create a new chat template processor.
931    ///
932    /// Returns an error if the template fails to parse, so callers get an
933    /// actionable message immediately rather than a confusing "template not
934    /// found" error on the first render.
935    pub fn new(template: String) -> Result<Self> {
936        let env = build_environment(template)?;
937        Ok(ChatTemplateProcessor { env })
938    }
939
940    /// Apply the chat template to a list of messages
941    ///
942    /// This mimics the behavior of HuggingFace's apply_chat_template method
943    /// but returns the formatted string instead of token IDs.
944    /// Messages should be pre-processed into the format expected by the template.
945    pub fn apply_chat_template(
946        &self,
947        messages: &[serde_json::Value],
948        params: ChatTemplateParams,
949    ) -> Result<String> {
950        render_chat_template(&self.env, messages, params)
951    }
952}
953
954/// Load chat template from tokenizer config JSON
955pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
956    let content = fs::read_to_string(config_path)?;
957    let config: serde_json::Value = serde_json::from_str(&content)?;
958
959    // Look for chat_template in the config
960    if let Some(template) = config.get("chat_template") {
961        if let Some(template_str) = template.as_str() {
962            return Ok(Some(template_str.to_string()));
963        }
964    }
965
966    Ok(None)
967}
968
969/// Load chat template from a file (.jinja or .json containing Jinja).
970/// Shared between all tokenizer backends.
971pub fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
972    let content = fs::read_to_string(template_path)
973        .map_err(|e| anyhow!("Failed to read chat template file: {e}"))?;
974
975    if template_path.ends_with(".json") {
976        let json_value: serde_json::Value = serde_json::from_str(&content)
977            .map_err(|e| anyhow!("Failed to parse chat_template.json: {e}"))?;
978
979        if let Some(template_str) = json_value.as_str() {
980            return Ok(Some(template_str.to_string()));
981        } else if let Some(obj) = json_value.as_object() {
982            if let Some(template_value) = obj.get("chat_template") {
983                if let Some(template_str) = template_value.as_str() {
984                    return Ok(Some(template_str.to_string()));
985                }
986            }
987        }
988
989        return Err(anyhow!(
990            "chat_template.json does not contain a valid template",
991        ));
992    }
993
994    // Plain .jinja file
995    let template = content.trim().replace("\\n", "\n");
996    Ok(Some(template))
997}
998
999/// Chat template state that can be embedded in any tokenizer struct.
1000/// Eliminates duplicated apply/set/format methods across tokenizer backends.
1001///
1002/// The compiled `minijinja::Environment` (with the template parsed, filters
1003/// registered, and Python-compat callback installed) is cached so that
1004/// `apply()` only performs rendering -- no parsing or environment setup.
1005/// The cache is rebuilt whenever `set()` is called.
1006///
1007/// `Environment<'static>` is both `Send` and `Sync`, so embedding this in
1008/// tokenizer structs shared across threads is safe.
1009pub struct ChatTemplateState {
1010    /// Cached, fully-configured environment. `None` when no template is set.
1011    env: Option<Environment<'static>>,
1012    content_format: ChatTemplateContentFormat,
1013    /// Thinking toggle support detected from the template.
1014    thinking_toggle: ThinkingToggle,
1015    /// The variable name used for the thinking toggle (if any).
1016    thinking_key_name: Option<ThinkingKeyName>,
1017    /// Whether the template injects `<think>` in the generation prompt.
1018    think_in_prefill: bool,
1019}
1020
1021impl std::fmt::Debug for ChatTemplateState {
1022    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1023        f.debug_struct("ChatTemplateState")
1024            .field("has_template", &self.env.is_some())
1025            .field("content_format", &self.content_format)
1026            .field("thinking_toggle", &self.thinking_toggle)
1027            .field("think_in_prefill", &self.think_in_prefill)
1028            .finish()
1029    }
1030}
1031
1032impl ChatTemplateState {
1033    pub fn new(template: Option<String>) -> Result<Self> {
1034        let (content_format, think_in_prefill, thinking_toggle, thinking_key_name) =
1035            template.as_ref().map(|t| detect_all(t)).unwrap_or_default();
1036        let env = template.map(build_environment).transpose()?;
1037        Ok(Self {
1038            env,
1039            content_format,
1040            thinking_toggle,
1041            thinking_key_name,
1042            think_in_prefill,
1043        })
1044    }
1045
1046    /// Create a `ChatTemplateState` with no template set.
1047    ///
1048    /// Unlike `new(None)`, this is infallible since there is no template to
1049    /// parse — useful in constructors that don't return `Result`.
1050    pub fn empty() -> Self {
1051        Self {
1052            env: None,
1053            content_format: ChatTemplateContentFormat::default(),
1054            thinking_toggle: ThinkingToggle::None,
1055            thinking_key_name: None,
1056            think_in_prefill: false,
1057        }
1058    }
1059
1060    pub fn apply(
1061        &self,
1062        messages: &[serde_json::Value],
1063        params: ChatTemplateParams,
1064    ) -> Result<String> {
1065        let env = self.env.as_ref().ok_or_else(|| {
1066            anyhow!(
1067                "Cannot use chat template functions because tokenizer.chat_template is not set \
1068                 and no template argument was passed! For information about writing templates and \
1069                 setting the tokenizer.chat_template attribute, please see the documentation at \
1070                 https://huggingface.co/docs/transformers/main/en/chat_templating",
1071            )
1072        })?;
1073        render_chat_template(env, messages, params)
1074    }
1075
1076    pub fn set(&mut self, template: String) -> Result<()> {
1077        let (content_format, think_in_prefill, thinking_toggle, thinking_key_name) =
1078            detect_all(&template);
1079        let env = build_environment(template)?;
1080        self.content_format = content_format;
1081        self.thinking_toggle = thinking_toggle;
1082        self.thinking_key_name = thinking_key_name;
1083        self.think_in_prefill = think_in_prefill;
1084        self.env = Some(env);
1085        Ok(())
1086    }
1087
1088    pub fn content_format(&self) -> ChatTemplateContentFormat {
1089        self.content_format
1090    }
1091
1092    pub fn thinking_toggle(&self) -> ThinkingToggle {
1093        self.thinking_toggle
1094    }
1095
1096    pub fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
1097        self.thinking_key_name
1098    }
1099
1100    pub fn think_in_prefill(&self) -> bool {
1101        self.think_in_prefill
1102    }
1103}
1104
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile `template` into a state, failing the test if it does not parse.
    fn compiled(template: &str) -> ChatTemplateState {
        ChatTemplateState::new(Some(template.to_owned())).unwrap()
    }

    /// Render `messages` with `tokens` as the only non-default parameter.
    fn render_with_tokens(
        state: &ChatTemplateState,
        messages: &[serde_json::Value],
        tokens: &crate::traits::SpecialTokens,
    ) -> String {
        let params = ChatTemplateParams {
            special_tokens: Some(tokens),
            ..Default::default()
        };
        state.apply(messages, params).unwrap()
    }

    #[test]
    fn test_chat_template_state_no_template() {
        let state = ChatTemplateState::new(None).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
        assert!(state.apply(&[], ChatTemplateParams::default()).is_err());
    }

    #[test]
    fn test_chat_template_state_set() {
        let mut state = ChatTemplateState::new(None).unwrap();
        state.set("{{ messages }}".to_string()).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
    }

    #[test]
    fn test_chat_template_state_invalid_template() {
        let err = ChatTemplateState::new(Some("{% invalid".to_string()))
            .expect_err("an unterminated block tag must be rejected")
            .to_string();
        assert!(
            err.contains("Failed to add template"),
            "Error should explain parse failure, got: {err}"
        );
    }

    #[test]
    fn test_chat_template_processor_invalid_template() {
        assert!(ChatTemplateProcessor::new("{% invalid".to_string()).is_err());
    }

    #[test]
    fn test_special_tokens_injected_into_context() {
        let state = compiled(
            "{{ bos_token }}{% for message in messages %}{{ message.content }}{% endfor %}{{ eos_token }}",
        );
        let tokens = crate::traits::SpecialTokens {
            bos_token: Some("<s>".to_string()),
            eos_token: Some("</s>".to_string()),
            ..Default::default()
        };
        let messages = [serde_json::json!({"role": "user", "content": "hello"})];

        assert_eq!(render_with_tokens(&state, &messages, &tokens), "<s>hello</s>");
    }

    #[test]
    fn test_special_tokens_undefined_when_not_provided() {
        let state = compiled("{% if bos_token is defined %}{{ bos_token }}{% endif %}hello");
        let rendered = state.apply(&[], ChatTemplateParams::default()).unwrap();
        assert_eq!(rendered, "hello");
    }

    #[test]
    fn test_special_tokens_partial() {
        let state =
            compiled("{{ bos_token }}hello{% if eos_token is defined %}{{ eos_token }}{% endif %}");
        let tokens = crate::traits::SpecialTokens {
            bos_token: Some("<s>".to_string()),
            eos_token: None,
            ..Default::default()
        };

        assert_eq!(render_with_tokens(&state, &[], &tokens), "<s>hello");
    }
}