Skip to main content

llm_tokenizer/
chat_template.rs

1//! Chat template support for tokenizers using Jinja2 templates
2//!
3//! This module provides functionality to apply chat templates to messages,
4//! similar to HuggingFace transformers' apply_chat_template method.
5
6use std::{collections::HashMap, fs};
7
8use anyhow::{anyhow, Result};
9use minijinja::{
10    context,
11    machinery::{
12        ast::{Expr, Stmt},
13        parse, WhitespaceConfig,
14    },
15    syntax::SyntaxConfig,
16    value::Kwargs,
17    Environment, Error as MinijinjaError, ErrorKind, Value,
18};
19use serde::Serialize;
20use serde_json::{self, ser::PrettyFormatter, Value as JsonValue};
21
/// Shape of the `content` field that a chat template expects on each message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChatTemplateContentFormat {
    /// `content` is a plain string (the default).
    #[default]
    String,
    /// `content` is a list of structured parts (OpenAI format).
    OpenAI,
}

impl std::fmt::Display for ChatTemplateContentFormat {
    /// Writes the lowercase wire name of the format (`"string"` / `"openai"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::String => "string",
            Self::OpenAI => "openai",
        };
        f.write_str(label)
    }
}
40
/// The variable name the template uses for the thinking toggle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingKeyName {
    /// Template uses `enable_thinking` (Qwen3, GLM, Nemotron)
    EnableThinking,
    /// Template uses `thinking` (DeepSeek V3.1, Kimi-K2.5)
    Thinking,
}
50
/// Result of detecting the thinking/reasoning toggle in a chat template.
///
/// Describes whether a toggle exists and, if so, which state applies when the
/// caller does not pass a value. The default variant is [`ThinkingToggle::None`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ThinkingToggle {
    /// Template has no thinking toggle. The model either always reasons
    /// (e.g. DeepSeek R1) or never does — controlled by the parser's
    /// `always_in_reasoning` config.
    #[default]
    None,
    /// Template supports a thinking toggle that defaults to ON.
    /// If the user doesn't pass anything, thinking is enabled.
    /// (Qwen3, Qwen3.5, Nemotron, GLM-4.6, GLM-5, Kimi-K2.5)
    DefaultOn,
    /// Template supports a thinking toggle that defaults to OFF.
    /// Thinking only activates when the user explicitly passes `thinking=true`.
    /// (DeepSeek V3.1)
    DefaultOff,
}
67
68/// Detect whether the chat template supports a thinking/reasoning toggle
69/// and what its default value is.
70pub fn detect_thinking_toggle(template: &str) -> (ThinkingToggle, Option<ThinkingKeyName>) {
71    let has_enable_thinking = template.contains("enable_thinking");
72    // Trailing space prevents matching "thinking_mode", "thinking_budget", etc.
73    let has_thinking_var = template.contains("if thinking ")
74        || template.contains("thinking is ")
75        || template.contains("thinking ==")
76        || template.contains("set thinking ");
77
78    if !has_enable_thinking && !has_thinking_var {
79        return (ThinkingToggle::None, None);
80    }
81
82    // At least one must be true — both false returned ThinkingToggle::None above.
83    let key_name = if has_enable_thinking {
84        ThinkingKeyName::EnableThinking
85    } else {
86        ThinkingKeyName::Thinking
87    };
88
89    // Check if the template explicitly defaults thinking to false/off.
90    // DeepSeek V3.1 pattern: {% if not thinking is defined %}{% set thinking = false %}
91    if template.contains("set thinking = false") || template.contains("set thinking=false") {
92        return (ThinkingToggle::DefaultOff, Some(key_name));
93    }
94    if template.contains("set enable_thinking = false")
95        || template.contains("set enable_thinking=false")
96    {
97        return (ThinkingToggle::DefaultOff, Some(key_name));
98    }
99
100    // All other models default to thinking ON
101    (ThinkingToggle::DefaultOn, Some(key_name))
102}
103
/// Detect the content format expected by a Jinja2 chat template
///
/// This implements the same detection logic as SGLang's detect_jinja_template_content_format
/// which uses AST parsing to look for content iteration patterns.
///
/// This is the public entry point; the work is delegated to the AST-based
/// detector (`detect_format_with_ast`).
///
/// Returns:
/// - ChatTemplateContentFormat::OpenAI if template expects structured content (list of parts)
/// - ChatTemplateContentFormat::String if template expects simple string content
pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
    // AST-based detection is the only strategy implemented here.
    detect_format_with_ast(template)
}
116
/// Record of which OpenAI-style content patterns the detector has observed.
#[derive(Default, Debug, Clone, Copy)]
struct Flags {
    /// A loop iterates over message content.
    saw_iteration: bool,
    /// Content is indexed numerically or measured with `|length`.
    saw_structure: bool,
    /// Message content was assigned to a local `content` variable.
    saw_assignment: bool,
    /// A macro combines a type test with a loop.
    saw_macro: bool,
}

impl Flags {
    /// Returns `true` when the observed patterns indicate OpenAI-style content.
    ///
    /// `saw_assignment` is deliberately left out: an assignment such as
    /// `set content = message.content` also appears in string-format templates
    /// (Qwen3, etc.), which copy the content into a local and then test
    /// `content is string`. Only iteration, structural access or the macro
    /// heuristic count as evidence.
    fn any(self) -> bool {
        [self.saw_iteration, self.saw_structure, self.saw_macro].contains(&true)
    }
}
135
/// Single-pass AST detector with scope tracking
///
/// Walks a parsed template once and collects both the content-format
/// [`Flags`] and the `think_in_prefill` signal.
struct Detector<'a> {
    /// Root statement of the parsed template that `run` walks.
    ast: &'a Stmt<'a>,
    /// Message loop vars currently in scope (e.g., `message`, `m`, `msg`).
    /// Used as a stack: entries are pushed and popped at the back.
    scope: std::collections::VecDeque<String>,
    /// Same names as `scope`, kept as a set for membership checks.
    scope_set: std::collections::HashSet<String>,
    /// OpenAI-style patterns observed so far.
    flags: Flags,
    /// Whether `<think>` appears inside an `add_generation_prompt` if-block
    think_in_prefill: bool,
}
146
147impl<'a> Detector<'a> {
148    fn new(ast: &'a Stmt<'a>) -> Self {
149        Self {
150            ast,
151            scope: std::collections::VecDeque::new(),
152            scope_set: std::collections::HashSet::new(),
153            flags: Flags::default(),
154            think_in_prefill: false,
155        }
156    }
157
158    fn run(mut self) -> (Flags, bool) {
159        self.walk_stmt(self.ast);
160        (self.flags, self.think_in_prefill)
161    }
162
163    fn push_scope(&mut self, var: String) {
164        self.scope.push_back(var.clone());
165        self.scope_set.insert(var);
166    }
167
168    fn pop_scope(&mut self) {
169        if let Some(v) = self.scope.pop_back() {
170            self.scope_set.remove(&v);
171        }
172    }
173
174    fn is_var_access(expr: &Expr, varname: &str) -> bool {
175        matches!(expr, Expr::Var(v) if v.id == varname)
176    }
177
178    fn is_const_str(expr: &Expr, value: &str) -> bool {
179        matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value))
180    }
181
182    fn is_numeric_const(expr: &Expr) -> bool {
183        matches!(expr, Expr::Const(c) if c.value.is_number())
184    }
185
186    /// Check if expr is varname.content or varname["content"]
187    fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
188        match expr {
189            Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
190            Expr::GetItem(g) => {
191                Self::is_var_access(&g.expr, varname)
192                    && Self::is_const_str(&g.subscript_expr, "content")
193            }
194            // Unwrap filters/tests that just wrap the same expr
195            Expr::Filter(f) => f
196                .expr
197                .as_ref()
198                .is_some_and(|e| Self::is_var_dot_content(e, varname)),
199            Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
200            _ => false,
201        }
202    }
203
204    /// Check if expr accesses .content on any variable in our scope, or any descendant of it.
205    fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
206        let mut current_expr = expr;
207        loop {
208            // Check if current level matches <scopeVar>.content
209            if self
210                .scope_set
211                .iter()
212                .any(|v| Self::is_var_dot_content(current_expr, v))
213            {
214                return true;
215            }
216            // Walk up the expression tree
217            match current_expr {
218                Expr::GetAttr(g) => current_expr = &g.expr,
219                Expr::GetItem(g) => current_expr = &g.expr,
220                _ => return false,
221            }
222        }
223    }
224
225    /// Check if an expression references a variable by name (walks through BinOp/UnaryOp).
226    fn expr_references_var(expr: &Expr, name: &str) -> bool {
227        match expr {
228            Expr::Var(v) => v.id == name,
229            Expr::BinOp(b) => {
230                Self::expr_references_var(&b.left, name)
231                    || Self::expr_references_var(&b.right, name)
232            }
233            Expr::UnaryOp(u) => Self::expr_references_var(&u.expr, name),
234            _ => false,
235        }
236    }
237
238    /// Check if a list of statements contains `<think>` in EmitRaw or string constants.
239    fn body_has_think_tag(stmts: &[Stmt]) -> bool {
240        for stmt in stmts {
241            match stmt {
242                Stmt::EmitRaw(raw) if raw.raw.contains("<think>") => return true,
243                Stmt::EmitExpr(e) => {
244                    if let Expr::Const(c) = &e.expr {
245                        if c.value.as_str().is_some_and(|s| s.contains("<think>")) {
246                            return true;
247                        }
248                    }
249                }
250                Stmt::IfCond(ic) => {
251                    if Self::body_has_think_tag(&ic.true_body)
252                        || Self::body_has_think_tag(&ic.false_body)
253                    {
254                        return true;
255                    }
256                }
257                _ => {}
258            }
259        }
260        false
261    }
262
    /// Visit one statement, updating `flags`, `think_in_prefill` and the
    /// message-variable scope, and recurse into the bodies of templates,
    /// `for` loops and `if` blocks.
    ///
    /// Statement kinds not listed in the `match` are ignored and not
    /// descended into. Macro bodies are only scanned by `scan_macro_body`;
    /// they are not walked recursively by this method.
    fn walk_stmt(&mut self, stmt: &Stmt) {
        match stmt {
            Stmt::Template(t) => {
                for ch in &t.children {
                    self.walk_stmt(ch);
                }
            }
            // {% for message in messages %}
            Stmt::ForLoop(fl) => {
                // Detect "for X in messages" → push X into scope
                if let Expr::Var(iter) = &fl.iter {
                    if iter.id == "messages" {
                        if let Expr::Var(target) = &fl.target {
                            self.push_scope(target.id.to_string());
                        }
                    }
                }

                // Also detect "for ... in message.content" or "for ... in content"
                // - Iterating directly over <scopeVar>.content => OpenAI style
                if self.is_any_scope_var_content(&fl.iter) {
                    self.flags.saw_iteration = true;
                }
                // - Iterating over a local var named "content"
                if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
                    self.flags.saw_iteration = true;
                }

                for b in &fl.body {
                    self.walk_stmt(b);
                }

                // Pop scope if we pushed it. The condition mirrors the push
                // condition above so pushes and pops stay balanced.
                if let Expr::Var(iter) = &fl.iter {
                    if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
                        self.pop_scope();
                    }
                }
            }
            Stmt::IfCond(ic) => {
                self.inspect_expr_for_structure(&ic.expr);

                // Detect <think> inside {% if add_generation_prompt [and ...] %} body.
                // Only the true branch is checked, and the check is skipped once
                // the flag has been set by an earlier block.
                if !self.think_in_prefill
                    && Self::expr_references_var(&ic.expr, "add_generation_prompt")
                {
                    self.think_in_prefill = Self::body_has_think_tag(&ic.true_body);
                }

                for b in &ic.true_body {
                    self.walk_stmt(b);
                }
                for b in &ic.false_body {
                    self.walk_stmt(b);
                }
            }
            Stmt::EmitExpr(e) => {
                self.inspect_expr_for_structure(&e.expr);
            }
            // {% set content = message.content %}
            Stmt::Set(s) => {
                // Recorded for completeness; `Flags::any` does not count it.
                if Self::is_var_access(&s.target, "content")
                    && self.is_any_scope_var_content(&s.expr)
                {
                    self.flags.saw_assignment = true;
                }
            }
            Stmt::Macro(m) => {
                // Heuristic: macro that checks type (via `is` test) and also has any loop
                let mut has_type_check = false;
                let mut has_loop = false;
                Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
                if has_type_check && has_loop {
                    self.flags.saw_macro = true;
                }
            }
            _ => {}
        }
    }
342
343    fn inspect_expr_for_structure(&mut self, expr: &Expr) {
344        if self.flags.saw_structure {
345            return;
346        }
347
348        match expr {
349            // content[0] or message.content[0]
350            Expr::GetItem(gi) => {
351                if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
352                    || self.is_any_scope_var_content(&gi.expr))
353                    && Self::is_numeric_const(&gi.subscript_expr)
354                {
355                    self.flags.saw_structure = true;
356                }
357            }
358            // content|length or message.content|length
359            Expr::Filter(f) => {
360                if f.name == "length" {
361                    if let Some(inner) = &f.expr {
362                        // Box derefs automatically, so `&**inner` is `&Expr`
363                        let inner_ref: &Expr = inner;
364                        let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
365                        if is_content_var || self.is_any_scope_var_content(inner_ref) {
366                            self.flags.saw_structure = true;
367                        }
368                    }
369                } else if let Some(inner) = &f.expr {
370                    let inner_ref: &Expr = inner;
371                    self.inspect_expr_for_structure(inner_ref);
372                }
373            }
374            // Type tests like `content is iterable` or `message.content is string`
375            // These are used for branching (e.g., Llama 3.1 uses them for tool output formatting),
376            // not as indicators that the template expects structured content. Keep walking.
377            Expr::Test(t) => self.inspect_expr_for_structure(&t.expr),
378            Expr::GetAttr(g) => {
379                // Keep walking; nested expressions can hide structure checks
380                self.inspect_expr_for_structure(&g.expr);
381            }
382            // Handle binary operations like: if (message.content is string) and other_cond
383            Expr::BinOp(op) => {
384                self.inspect_expr_for_structure(&op.left);
385                self.inspect_expr_for_structure(&op.right);
386            }
387            // Handle unary operations like: if not (message.content is string)
388            Expr::UnaryOp(op) => {
389                self.inspect_expr_for_structure(&op.expr);
390            }
391            _ => {}
392        }
393    }
394
395    fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
396        for s in body {
397            if *has_type_check && *has_loop {
398                return;
399            }
400
401            match s {
402                Stmt::IfCond(ic) => {
403                    if matches!(&ic.expr, Expr::Test(_)) {
404                        *has_type_check = true;
405                    }
406                    Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
407                    Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
408                }
409                Stmt::ForLoop(fl) => {
410                    *has_loop = true;
411                    Self::scan_macro_body(&fl.body, has_type_check, has_loop);
412                }
413                Stmt::Template(t) => {
414                    Self::scan_macro_body(&t.children, has_type_check, has_loop);
415                }
416                _ => {}
417            }
418        }
419    }
420}
421
/// AST-based detection using minijinja's unstable machinery
/// Single-pass detector with scope tracking
///
/// Returns only the content-format half of `detect_all_with_ast`; the
/// think-in-prefill result is discarded.
fn detect_format_with_ast(template: &str) -> ChatTemplateContentFormat {
    detect_all_with_ast(template).0
}
427
428/// Single-pass detection of content format, think-in-prefill, and thinking toggle.
429fn detect_all(
430    template: &str,
431) -> (
432    ChatTemplateContentFormat,
433    bool,
434    ThinkingToggle,
435    Option<ThinkingKeyName>,
436) {
437    let (thinking_toggle, thinking_key_name) = detect_thinking_toggle(template);
438    let (content_format, think_in_prefill) = detect_all_with_ast(template);
439    (
440        content_format,
441        think_in_prefill,
442        thinking_toggle,
443        thinking_key_name,
444    )
445}
446
447/// AST detection of content format and think-in-prefill.
448fn detect_all_with_ast(template: &str) -> (ChatTemplateContentFormat, bool) {
449    let ast = match parse(
450        template,
451        "template",
452        SyntaxConfig {},
453        WhitespaceConfig::default(),
454    ) {
455        Ok(ast) => ast,
456        Err(_) => return (ChatTemplateContentFormat::String, false),
457    };
458
459    let (flags, think_in_prefill) = Detector::new(&ast).run();
460    let content_format = if flags.any() {
461        ChatTemplateContentFormat::OpenAI
462    } else {
463        ChatTemplateContentFormat::String
464    };
465    (content_format, think_in_prefill)
466}
467
/// Parameters for chat template application
///
/// All fields are optional inputs to rendering; `Default` yields
/// `add_generation_prompt = false` and `None` for everything else.
#[derive(Default)]
pub struct ChatTemplateParams<'a> {
    /// Exposed to the template as the `add_generation_prompt` variable.
    pub add_generation_prompt: bool,
    /// Exposed to the template as `tools`; left undefined when `None`.
    pub tools: Option<&'a [serde_json::Value]>,
    /// Exposed to the template as `documents`; left undefined when `None`.
    pub documents: Option<&'a [serde_json::Value]>,
    /// Extra variables merged into the template context when present.
    pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
    /// Special tokens to inject into the template context.
    /// Many templates reference `{{ bos_token }}`, `{{ eos_token }}`, etc.
    pub special_tokens: Option<&'a crate::traits::SpecialTokens>,
}
479
480/// Custom tojson filter compatible with HuggingFace transformers' implementation.
481///
482/// HuggingFace transformers registers a custom `tojson` filter that accepts additional
483/// keyword arguments beyond what standard Jinja2 provides:
484/// - `ensure_ascii` (bool): Whether to escape non-ASCII characters (ignored in Rust, always UTF-8)
485/// - `indent` (int): Number of spaces for indentation (pretty-printing)
486/// - `separators` (ignored): Custom separators for JSON output
487/// - `sort_keys` (bool): Whether to sort dictionary keys
488///
489/// This is necessary for compatibility with chat templates from HuggingFace Hub models.
490/// See: https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
491fn tojson_filter(value: Value, kwargs: Kwargs) -> std::result::Result<Value, MinijinjaError> {
492    let _ensure_ascii: Option<bool> = kwargs.get("ensure_ascii")?;
493    let indent: Option<i64> = kwargs.get("indent")?;
494    let _separators: Option<Value> = kwargs.get("separators")?;
495    let sort_keys: Option<bool> = kwargs.get("sort_keys")?;
496
497    // Ensure all kwargs are consumed to avoid "unknown keyword argument" errors
498    kwargs.assert_all_used()?;
499
500    let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| {
501        MinijinjaError::new(
502            ErrorKind::InvalidOperation,
503            format!("Failed to convert to JSON value: {e}"),
504        )
505    })?;
506
507    // Helper to serialize with custom indentation
508    fn serialize_with_indent<T: Serialize>(
509        value: &T,
510        spaces: usize,
511    ) -> std::result::Result<String, MinijinjaError> {
512        let indent_str = vec![b' '; spaces];
513        let formatter = PrettyFormatter::with_indent(&indent_str);
514        let mut buf = Vec::new();
515        let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
516        value.serialize(&mut serializer).map_err(|e| {
517            MinijinjaError::new(
518                ErrorKind::InvalidOperation,
519                format!("Failed to serialize JSON: {e}"),
520            )
521        })?;
522        String::from_utf8(buf).map_err(|e| {
523            MinijinjaError::new(
524                ErrorKind::InvalidOperation,
525                format!("Invalid UTF-8 in JSON output: {e}"),
526            )
527        })
528    }
529
530    // Serialize with options
531    let json_str: std::result::Result<String, MinijinjaError> = {
532        let sorted_json;
533        let value_to_serialize = if sort_keys.unwrap_or(false) {
534            sorted_json = sort_json_keys(&json_value);
535            &sorted_json
536        } else {
537            &json_value
538        };
539
540        if let Some(spaces) = indent {
541            if spaces < 0 {
542                return Err(MinijinjaError::new(
543                    ErrorKind::InvalidOperation,
544                    "indent cannot be negative",
545                ));
546            }
547            serialize_with_indent(value_to_serialize, spaces as usize)
548        } else {
549            serde_json::to_string(value_to_serialize).map_err(|e| {
550                MinijinjaError::new(
551                    ErrorKind::InvalidOperation,
552                    format!("Failed to serialize JSON: {e}"),
553                )
554            })
555        }
556    };
557
558    json_str.map(Value::from_safe_string)
559}
560
561/// Recursively sort all object keys in a JSON value
562fn sort_json_keys(value: &JsonValue) -> JsonValue {
563    match value {
564        JsonValue::Object(map) => {
565            let mut sorted: serde_json::Map<String, JsonValue> = serde_json::Map::new();
566            let mut keys: Vec<_> = map.keys().collect();
567            keys.sort();
568            for key in keys {
569                sorted.insert(key.clone(), sort_json_keys(&map[key]));
570            }
571            JsonValue::Object(sorted)
572        }
573        JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()),
574        _ => value.clone(),
575    }
576}
577
/// Build a pre-configured `Environment<'static>` with the given template string,
/// Python-compat method callback, and custom `tojson` filter already registered.
/// The template is stored under the name `"chat"` using owned storage so the
/// environment carries no borrows.
///
/// # Errors
///
/// Returns an error ("Failed to add template: ...") when registering the
/// template with the environment fails.
fn build_environment(template: String) -> Result<Environment<'static>> {
    let mut env = Environment::new();

    // Match HuggingFace's Jinja2 defaults: trim_blocks and lstrip_blocks are
    // enabled in Python's transformers but default to false in minijinja.
    // Without these, templates like GLM-5's produce incorrect whitespace.
    // NOTE(review): these are set before the template is added; presumably the
    // template is compiled at add time, so keep this ordering.
    env.set_trim_blocks(true);
    env.set_lstrip_blocks(true);

    // Register the template with owned storage (no lifetime dependency on caller)
    env.add_template_owned("chat".to_owned(), template)
        .map_err(|e| anyhow!("Failed to add template: {e}"))?;

    // Enable Python method compatibility (e.g., str.startswith, str.endswith)
    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);

    // Register custom tojson filter compatible with HuggingFace transformers
    // This overrides minijinja's built-in tojson to support additional kwargs
    // like ensure_ascii, separators, and sort_keys that HuggingFace templates use
    env.add_filter("tojson", tojson_filter);

    Ok(env)
}
605
606/// Render the `"chat"` template in the given environment against messages and params.
607/// Convert an optional token string to a minijinja Value.
608/// Present tokens become strings; absent tokens become UNDEFINED
609/// so templates can use `{% if bos_token is defined %}` guards.
610fn special_token_value(token: Option<&str>) -> Value {
611    token.map_or(Value::UNDEFINED, Value::from)
612}
613
614fn render_chat_template(
615    env: &Environment<'_>,
616    messages: &[serde_json::Value],
617    params: ChatTemplateParams,
618) -> Result<String> {
619    let tmpl = env
620        .get_template("chat")
621        .map_err(|e| anyhow!("Failed to get template: {e}"))?;
622
623    // Convert messages to minijinja::Value (messages already processed by router)
624    let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
625
626    // Use Value::UNDEFINED for missing optional params so they are truly "undefined"
627    // in the template context, matching HuggingFace Python behavior. Many chat templates
628    // use `{% if tools is defined %}` guards — passing null (none) instead of undefined
629    // would bypass those guards since `none` IS defined, causing `tools | length` to fail.
630    let tools_value = params.tools.map_or(Value::UNDEFINED, Value::from_serialize);
631    let documents_value = params
632        .documents
633        .map_or(Value::UNDEFINED, Value::from_serialize);
634
635    // Inject special tokens (bos_token, eos_token, etc.) into context.
636    // Use UNDEFINED for missing tokens so `{% if bos_token is defined %}` works correctly.
637    // This matches HuggingFace Python which passes self.special_tokens_map to the renderer.
638    let bos_value =
639        special_token_value(params.special_tokens.and_then(|st| st.bos_token.as_deref()));
640    let eos_value =
641        special_token_value(params.special_tokens.and_then(|st| st.eos_token.as_deref()));
642    let unk_value =
643        special_token_value(params.special_tokens.and_then(|st| st.unk_token.as_deref()));
644    let pad_value =
645        special_token_value(params.special_tokens.and_then(|st| st.pad_token.as_deref()));
646
647    let base_context = context! {
648        messages => &minijinja_messages,
649        add_generation_prompt => params.add_generation_prompt,
650        tools => tools_value,
651        documents => documents_value,
652        bos_token => bos_value,
653        eos_token => eos_value,
654        unk_token => unk_value,
655        pad_token => pad_value,
656    };
657
658    // Merge with template_kwargs if provided (caller kwargs override special tokens)
659    let ctx = if let Some(kwargs) = params.template_kwargs {
660        context! {
661            ..base_context,
662            ..Value::from_serialize(kwargs)
663        }
664    } else {
665        base_context
666    };
667
668    // Render the template
669    let rendered = tmpl
670        .render(&ctx)
671        .map_err(|e| anyhow!("Failed to render template: {e}"))?;
672
673    Ok(rendered)
674}
675
676/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
677pub struct ChatTemplateProcessor {
678    env: Environment<'static>,
679}
680
681impl ChatTemplateProcessor {
682    /// Create a new chat template processor.
683    ///
684    /// Returns an error if the template fails to parse, so callers get an
685    /// actionable message immediately rather than a confusing "template not
686    /// found" error on the first render.
687    pub fn new(template: String) -> Result<Self> {
688        let env = build_environment(template)?;
689        Ok(ChatTemplateProcessor { env })
690    }
691
692    /// Apply the chat template to a list of messages
693    ///
694    /// This mimics the behavior of HuggingFace's apply_chat_template method
695    /// but returns the formatted string instead of token IDs.
696    /// Messages should be pre-processed into the format expected by the template.
697    pub fn apply_chat_template(
698        &self,
699        messages: &[serde_json::Value],
700        params: ChatTemplateParams,
701    ) -> Result<String> {
702        render_chat_template(&self.env, messages, params)
703    }
704}
705
706/// Load chat template from tokenizer config JSON
707pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
708    let content = fs::read_to_string(config_path)?;
709    let config: serde_json::Value = serde_json::from_str(&content)?;
710
711    // Look for chat_template in the config
712    if let Some(template) = config.get("chat_template") {
713        if let Some(template_str) = template.as_str() {
714            return Ok(Some(template_str.to_string()));
715        }
716    }
717
718    Ok(None)
719}
720
721/// Load chat template from a file (.jinja or .json containing Jinja).
722/// Shared between all tokenizer backends.
723pub fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
724    let content = fs::read_to_string(template_path)
725        .map_err(|e| anyhow!("Failed to read chat template file: {e}"))?;
726
727    if template_path.ends_with(".json") {
728        let json_value: serde_json::Value = serde_json::from_str(&content)
729            .map_err(|e| anyhow!("Failed to parse chat_template.json: {e}"))?;
730
731        if let Some(template_str) = json_value.as_str() {
732            return Ok(Some(template_str.to_string()));
733        } else if let Some(obj) = json_value.as_object() {
734            if let Some(template_value) = obj.get("chat_template") {
735                if let Some(template_str) = template_value.as_str() {
736                    return Ok(Some(template_str.to_string()));
737                }
738            }
739        }
740
741        return Err(anyhow!(
742            "chat_template.json does not contain a valid template",
743        ));
744    }
745
746    // Plain .jinja file
747    let template = content.trim().replace("\\n", "\n");
748    Ok(Some(template))
749}
750
/// Chat template state that can be embedded in any tokenizer struct.
/// Eliminates duplicated apply/set/format methods across tokenizer backends.
///
/// The compiled `minijinja::Environment` (with the template parsed, filters
/// registered, and Python-compat callback installed) is cached so that
/// `apply()` only performs rendering -- no parsing or environment setup.
/// The cache is rebuilt whenever `set()` is called.
///
/// `Environment<'static>` is both `Send` and `Sync`, so embedding this in
/// tokenizer structs shared across threads is safe.
pub struct ChatTemplateState {
    /// Cached, fully-configured environment. `None` when no template is set.
    env: Option<Environment<'static>>,
    /// Content format detected from the template (default when none is set).
    content_format: ChatTemplateContentFormat,
    /// Thinking toggle support detected from the template.
    thinking_toggle: ThinkingToggle,
    /// The variable name used for the thinking toggle (if any).
    thinking_key_name: Option<ThinkingKeyName>,
    /// Whether the template injects `<think>` in the generation prompt.
    think_in_prefill: bool,
}
772
773impl std::fmt::Debug for ChatTemplateState {
774    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
775        f.debug_struct("ChatTemplateState")
776            .field("has_template", &self.env.is_some())
777            .field("content_format", &self.content_format)
778            .field("thinking_toggle", &self.thinking_toggle)
779            .field("think_in_prefill", &self.think_in_prefill)
780            .finish()
781    }
782}
783
784impl ChatTemplateState {
785    pub fn new(template: Option<String>) -> Result<Self> {
786        let (content_format, think_in_prefill, thinking_toggle, thinking_key_name) =
787            template.as_ref().map(|t| detect_all(t)).unwrap_or_default();
788        let env = template.map(build_environment).transpose()?;
789        Ok(Self {
790            env,
791            content_format,
792            thinking_toggle,
793            thinking_key_name,
794            think_in_prefill,
795        })
796    }
797
798    /// Create a `ChatTemplateState` with no template set.
799    ///
800    /// Unlike `new(None)`, this is infallible since there is no template to
801    /// parse — useful in constructors that don't return `Result`.
802    pub fn empty() -> Self {
803        Self {
804            env: None,
805            content_format: ChatTemplateContentFormat::default(),
806            thinking_toggle: ThinkingToggle::None,
807            thinking_key_name: None,
808            think_in_prefill: false,
809        }
810    }
811
812    pub fn apply(
813        &self,
814        messages: &[serde_json::Value],
815        params: ChatTemplateParams,
816    ) -> Result<String> {
817        let env = self.env.as_ref().ok_or_else(|| {
818            anyhow!(
819                "Cannot use chat template functions because tokenizer.chat_template is not set \
820                 and no template argument was passed! For information about writing templates and \
821                 setting the tokenizer.chat_template attribute, please see the documentation at \
822                 https://huggingface.co/docs/transformers/main/en/chat_templating",
823            )
824        })?;
825        render_chat_template(env, messages, params)
826    }
827
828    pub fn set(&mut self, template: String) -> Result<()> {
829        let (content_format, think_in_prefill, thinking_toggle, thinking_key_name) =
830            detect_all(&template);
831        let env = build_environment(template)?;
832        self.content_format = content_format;
833        self.thinking_toggle = thinking_toggle;
834        self.thinking_key_name = thinking_key_name;
835        self.think_in_prefill = think_in_prefill;
836        self.env = Some(env);
837        Ok(())
838    }
839
840    pub fn content_format(&self) -> ChatTemplateContentFormat {
841        self.content_format
842    }
843
844    pub fn thinking_toggle(&self) -> ThinkingToggle {
845        self.thinking_toggle
846    }
847
848    pub fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
849        self.thinking_key_name
850    }
851
852    pub fn think_in_prefill(&self) -> bool {
853        self.think_in_prefill
854    }
855}
856
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `source` into a state, panicking if the template is invalid.
    fn state_with(source: &str) -> ChatTemplateState {
        ChatTemplateState::new(Some(source.to_string())).unwrap()
    }

    /// With no template the format is `String` and `apply` reports an error.
    #[test]
    fn test_chat_template_state_no_template() {
        let empty_state = ChatTemplateState::new(None).unwrap();
        assert_eq!(
            empty_state.content_format(),
            ChatTemplateContentFormat::String
        );
        assert!(empty_state
            .apply(&[], ChatTemplateParams::default())
            .is_err());
    }

    /// `set` accepts a valid template on a previously empty state.
    #[test]
    fn test_chat_template_state_set() {
        let mut state = ChatTemplateState::new(None).unwrap();
        let template = String::from("{{ messages }}");
        state.set(template).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
    }

    /// A syntactically broken template is rejected with an explanatory error.
    #[test]
    fn test_chat_template_state_invalid_template() {
        let outcome = ChatTemplateState::new(Some(String::from("{% invalid")));
        assert!(outcome.is_err());
        let err = outcome.unwrap_err().to_string();
        assert!(
            err.contains("Failed to add template"),
            "Error should explain parse failure, got: {err}"
        );
    }

    /// The processor constructor rejects the same broken template.
    #[test]
    fn test_chat_template_processor_invalid_template() {
        assert!(ChatTemplateProcessor::new(String::from("{% invalid")).is_err());
    }

    /// Provided special tokens are visible to the template as variables.
    #[test]
    fn test_special_tokens_injected_into_context() {
        let state = state_with(
            "{{ bos_token }}{% for message in messages %}{{ message.content }}{% endfor %}{{ eos_token }}",
        );

        let tokens = crate::traits::SpecialTokens {
            bos_token: Some(String::from("<s>")),
            eos_token: Some(String::from("</s>")),
            ..Default::default()
        };
        let params = ChatTemplateParams {
            special_tokens: Some(&tokens),
            ..Default::default()
        };
        let conversation = [serde_json::json!({"role": "user", "content": "hello"})];

        let rendered = state.apply(&conversation, params).unwrap();

        assert_eq!(rendered, "<s>hello</s>");
    }

    /// Without special tokens, `bos_token` is undefined in the template.
    #[test]
    fn test_special_tokens_undefined_when_not_provided() {
        let state = state_with("{% if bos_token is defined %}{{ bos_token }}{% endif %}hello");

        let rendered = state.apply(&[], ChatTemplateParams::default()).unwrap();

        assert_eq!(rendered, "hello");
    }

    /// Only the tokens that are `Some` become defined template variables.
    #[test]
    fn test_special_tokens_partial() {
        let state = state_with(
            "{{ bos_token }}hello{% if eos_token is defined %}{{ eos_token }}{% endif %}",
        );

        let tokens = crate::traits::SpecialTokens {
            bos_token: Some(String::from("<s>")),
            eos_token: None,
            ..Default::default()
        };
        let params = ChatTemplateParams {
            special_tokens: Some(&tokens),
            ..Default::default()
        };

        let rendered = state.apply(&[], params).unwrap();

        assert_eq!(rendered, "<s>hello");
    }
}