Skip to main content

llm_tokenizer/
chat_template.rs

1//! Chat template support for tokenizers using Jinja2 templates
2//!
3//! This module provides functionality to apply chat templates to messages,
4//! similar to HuggingFace transformers' apply_chat_template method.
5
6use std::{collections::HashMap, fs};
7
8use anyhow::{anyhow, Result};
9use minijinja::{
10    context,
11    machinery::{
12        ast::{Expr, Stmt},
13        parse, WhitespaceConfig,
14    },
15    syntax::SyntaxConfig,
16    value::Kwargs,
17    Environment, Error as MinijinjaError, ErrorKind, Value,
18};
19use serde::Serialize;
20use serde_json::{self, ser::PrettyFormatter, Value as JsonValue};
21
/// Shape of the `content` field that a chat template expects on each message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChatTemplateContentFormat {
    /// `content` is a plain string (this is the default variant).
    #[default]
    String,
    /// `content` is a list of structured parts (OpenAI format).
    OpenAI,
}

impl std::fmt::Display for ChatTemplateContentFormat {
    /// Writes the lowercase name of the format: `string` or `openai`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::String => "string",
            Self::OpenAI => "openai",
        };
        f.write_str(label)
    }
}
40
/// Detect the content format expected by a Jinja2 chat template.
///
/// This implements the same detection logic as SGLang's
/// `detect_jinja_template_content_format`, which uses AST parsing to look for
/// content iteration patterns. The work is delegated to
/// `detect_format_with_ast`.
///
/// Returns:
/// - `ChatTemplateContentFormat::OpenAI` if the template expects structured
///   content (a list of parts)
/// - `ChatTemplateContentFormat::String` if the template expects simple string
///   content
pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
    // Use AST-based detection (enabled by default)
    detect_format_with_ast(template)
}
53
/// Records which OpenAI-style content patterns the detector has observed.
#[derive(Default, Debug, Clone, Copy)]
struct Flags {
    saw_iteration: bool,
    saw_structure: bool,
    saw_assignment: bool,
    saw_macro: bool,
}

impl Flags {
    /// Returns `true` once a pattern that implies structured content was seen.
    fn any(self) -> bool {
        // `saw_assignment` is deliberately left out: a bare
        // `set content = message.content` is common in string-format templates
        // (Qwen3, etc.) that copy the content into a local and then test
        // `content is string`. Without iteration or structural access the
        // template only handles string content.
        [self.saw_iteration, self.saw_structure, self.saw_macro]
            .into_iter()
            .any(|seen| seen)
    }
}
72
/// Single-pass AST detector with scope tracking
struct Detector<'a> {
    /// Root of the parsed template AST that `run` walks.
    ast: &'a Stmt<'a>,
    /// Message loop vars currently in scope (e.g., `message`, `m`, `msg`),
    /// in the order they were pushed.
    scope: std::collections::VecDeque<String>,
    /// Set view of the names held in `scope`, scanned when matching
    /// `<scopeVar>.content` expressions.
    scope_set: std::collections::HashSet<String>,
    /// OpenAI-style patterns observed so far.
    flags: Flags,
}
81
82impl<'a> Detector<'a> {
83    fn new(ast: &'a Stmt<'a>) -> Self {
84        Self {
85            ast,
86            scope: std::collections::VecDeque::new(),
87            scope_set: std::collections::HashSet::new(),
88            flags: Flags::default(),
89        }
90    }
91
92    fn run(mut self) -> Flags {
93        self.walk_stmt(self.ast);
94        self.flags
95    }
96
97    fn push_scope(&mut self, var: String) {
98        self.scope.push_back(var.clone());
99        self.scope_set.insert(var);
100    }
101
102    fn pop_scope(&mut self) {
103        if let Some(v) = self.scope.pop_back() {
104            self.scope_set.remove(&v);
105        }
106    }
107
108    fn is_var_access(expr: &Expr, varname: &str) -> bool {
109        matches!(expr, Expr::Var(v) if v.id == varname)
110    }
111
112    fn is_const_str(expr: &Expr, value: &str) -> bool {
113        matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value))
114    }
115
116    fn is_numeric_const(expr: &Expr) -> bool {
117        matches!(expr, Expr::Const(c) if c.value.is_number())
118    }
119
120    /// Check if expr is varname.content or varname["content"]
121    fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
122        match expr {
123            Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
124            Expr::GetItem(g) => {
125                Self::is_var_access(&g.expr, varname)
126                    && Self::is_const_str(&g.subscript_expr, "content")
127            }
128            // Unwrap filters/tests that just wrap the same expr
129            Expr::Filter(f) => f
130                .expr
131                .as_ref()
132                .is_some_and(|e| Self::is_var_dot_content(e, varname)),
133            Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
134            _ => false,
135        }
136    }
137
138    /// Check if expr accesses .content on any variable in our scope, or any descendant of it.
139    fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
140        let mut current_expr = expr;
141        loop {
142            // Check if current level matches <scopeVar>.content
143            if self
144                .scope_set
145                .iter()
146                .any(|v| Self::is_var_dot_content(current_expr, v))
147            {
148                return true;
149            }
150            // Walk up the expression tree
151            match current_expr {
152                Expr::GetAttr(g) => current_expr = &g.expr,
153                Expr::GetItem(g) => current_expr = &g.expr,
154                _ => return false,
155            }
156        }
157    }
158
159    fn walk_stmt(&mut self, stmt: &Stmt) {
160        // Early exit if we've already detected an OpenAI pattern
161        if self.flags.any() {
162            return;
163        }
164
165        match stmt {
166            Stmt::Template(t) => {
167                for ch in &t.children {
168                    self.walk_stmt(ch);
169                }
170            }
171            // {% for message in messages %}
172            Stmt::ForLoop(fl) => {
173                // Detect "for X in messages" → push X into scope
174                if let Expr::Var(iter) = &fl.iter {
175                    if iter.id == "messages" {
176                        if let Expr::Var(target) = &fl.target {
177                            self.push_scope(target.id.to_string());
178                        }
179                    }
180                }
181
182                // Also detect "for ... in message.content" or "for ... in content"
183                // - Iterating directly over <scopeVar>.content => OpenAI style
184                if self.is_any_scope_var_content(&fl.iter) {
185                    self.flags.saw_iteration = true;
186                }
187                // - Iterating over a local var named "content"
188                if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
189                    self.flags.saw_iteration = true;
190                }
191
192                for b in &fl.body {
193                    self.walk_stmt(b);
194                }
195
196                // Pop scope if we pushed it
197                if let Expr::Var(iter) = &fl.iter {
198                    if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
199                        self.pop_scope();
200                    }
201                }
202            }
203            Stmt::IfCond(ic) => {
204                self.inspect_expr_for_structure(&ic.expr);
205                for b in &ic.true_body {
206                    self.walk_stmt(b);
207                }
208                for b in &ic.false_body {
209                    self.walk_stmt(b);
210                }
211            }
212            Stmt::EmitExpr(e) => {
213                self.inspect_expr_for_structure(&e.expr);
214            }
215            // {% set content = message.content %}
216            Stmt::Set(s) => {
217                if Self::is_var_access(&s.target, "content")
218                    && self.is_any_scope_var_content(&s.expr)
219                {
220                    self.flags.saw_assignment = true;
221                }
222            }
223            Stmt::Macro(m) => {
224                // Heuristic: macro that checks type (via `is` test) and also has any loop
225                let mut has_type_check = false;
226                let mut has_loop = false;
227                Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
228                if has_type_check && has_loop {
229                    self.flags.saw_macro = true;
230                }
231            }
232            _ => {}
233        }
234    }
235
236    fn inspect_expr_for_structure(&mut self, expr: &Expr) {
237        if self.flags.saw_structure {
238            return;
239        }
240
241        match expr {
242            // content[0] or message.content[0]
243            Expr::GetItem(gi) => {
244                if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
245                    || self.is_any_scope_var_content(&gi.expr))
246                    && Self::is_numeric_const(&gi.subscript_expr)
247                {
248                    self.flags.saw_structure = true;
249                }
250            }
251            // content|length or message.content|length
252            Expr::Filter(f) => {
253                if f.name == "length" {
254                    if let Some(inner) = &f.expr {
255                        // Box derefs automatically, so `&**inner` is `&Expr`
256                        let inner_ref: &Expr = inner;
257                        let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
258                        if is_content_var || self.is_any_scope_var_content(inner_ref) {
259                            self.flags.saw_structure = true;
260                        }
261                    }
262                } else if let Some(inner) = &f.expr {
263                    let inner_ref: &Expr = inner;
264                    self.inspect_expr_for_structure(inner_ref);
265                }
266            }
267            // Type tests like `content is iterable` or `message.content is string`
268            // These are used for branching (e.g., Llama 3.1 uses them for tool output formatting),
269            // not as indicators that the template expects structured content. Keep walking.
270            Expr::Test(t) => self.inspect_expr_for_structure(&t.expr),
271            Expr::GetAttr(g) => {
272                // Keep walking; nested expressions can hide structure checks
273                self.inspect_expr_for_structure(&g.expr);
274            }
275            // Handle binary operations like: if (message.content is string) and other_cond
276            Expr::BinOp(op) => {
277                self.inspect_expr_for_structure(&op.left);
278                self.inspect_expr_for_structure(&op.right);
279            }
280            // Handle unary operations like: if not (message.content is string)
281            Expr::UnaryOp(op) => {
282                self.inspect_expr_for_structure(&op.expr);
283            }
284            _ => {}
285        }
286    }
287
288    fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
289        for s in body {
290            if *has_type_check && *has_loop {
291                return;
292            }
293
294            match s {
295                Stmt::IfCond(ic) => {
296                    if matches!(&ic.expr, Expr::Test(_)) {
297                        *has_type_check = true;
298                    }
299                    Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
300                    Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
301                }
302                Stmt::ForLoop(fl) => {
303                    *has_loop = true;
304                    Self::scan_macro_body(&fl.body, has_type_check, has_loop);
305                }
306                Stmt::Template(t) => {
307                    Self::scan_macro_body(&t.children, has_type_check, has_loop);
308                }
309                _ => {}
310            }
311        }
312    }
313}
314
315/// AST-based detection using minijinja's unstable machinery
316/// Single-pass detector with scope tracking
317fn detect_format_with_ast(template: &str) -> ChatTemplateContentFormat {
318    let ast = match parse(
319        template,
320        "template",
321        SyntaxConfig {},
322        WhitespaceConfig::default(),
323    ) {
324        Ok(ast) => ast,
325        Err(_) => return ChatTemplateContentFormat::String,
326    };
327
328    let flags = Detector::new(&ast).run();
329    if flags.any() {
330        ChatTemplateContentFormat::OpenAI
331    } else {
332        ChatTemplateContentFormat::String
333    }
334}
335
/// Parameters for chat template application
///
/// Every field maps to a variable made available to the template when it is
/// rendered. `Default` gives `add_generation_prompt = false` and leaves all
/// optional inputs unset.
#[derive(Default)]
pub struct ChatTemplateParams<'a> {
    /// Exposed to the template as `add_generation_prompt`.
    pub add_generation_prompt: bool,
    /// Exposed to the template as `tools`; when `None` the variable is left
    /// undefined rather than set to none.
    pub tools: Option<&'a [serde_json::Value]>,
    /// Exposed to the template as `documents`; when `None` the variable is
    /// left undefined rather than set to none.
    pub documents: Option<&'a [serde_json::Value]>,
    /// Extra caller-supplied variables merged into the render context.
    pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
    /// Special tokens to inject into the template context.
    /// Many templates reference `{{ bos_token }}`, `{{ eos_token }}`, etc.
    pub special_tokens: Option<&'a crate::traits::SpecialTokens>,
}
347
348/// Custom tojson filter compatible with HuggingFace transformers' implementation.
349///
350/// HuggingFace transformers registers a custom `tojson` filter that accepts additional
351/// keyword arguments beyond what standard Jinja2 provides:
352/// - `ensure_ascii` (bool): Whether to escape non-ASCII characters (ignored in Rust, always UTF-8)
353/// - `indent` (int): Number of spaces for indentation (pretty-printing)
354/// - `separators` (ignored): Custom separators for JSON output
355/// - `sort_keys` (bool): Whether to sort dictionary keys
356///
357/// This is necessary for compatibility with chat templates from HuggingFace Hub models.
358/// See: https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
359fn tojson_filter(value: Value, kwargs: Kwargs) -> std::result::Result<Value, MinijinjaError> {
360    let _ensure_ascii: Option<bool> = kwargs.get("ensure_ascii")?;
361    let indent: Option<i64> = kwargs.get("indent")?;
362    let _separators: Option<Value> = kwargs.get("separators")?;
363    let sort_keys: Option<bool> = kwargs.get("sort_keys")?;
364
365    // Ensure all kwargs are consumed to avoid "unknown keyword argument" errors
366    kwargs.assert_all_used()?;
367
368    let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| {
369        MinijinjaError::new(
370            ErrorKind::InvalidOperation,
371            format!("Failed to convert to JSON value: {e}"),
372        )
373    })?;
374
375    // Helper to serialize with custom indentation
376    fn serialize_with_indent<T: Serialize>(
377        value: &T,
378        spaces: usize,
379    ) -> std::result::Result<String, MinijinjaError> {
380        let indent_str = vec![b' '; spaces];
381        let formatter = PrettyFormatter::with_indent(&indent_str);
382        let mut buf = Vec::new();
383        let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
384        value.serialize(&mut serializer).map_err(|e| {
385            MinijinjaError::new(
386                ErrorKind::InvalidOperation,
387                format!("Failed to serialize JSON: {e}"),
388            )
389        })?;
390        String::from_utf8(buf).map_err(|e| {
391            MinijinjaError::new(
392                ErrorKind::InvalidOperation,
393                format!("Invalid UTF-8 in JSON output: {e}"),
394            )
395        })
396    }
397
398    // Serialize with options
399    let json_str: std::result::Result<String, MinijinjaError> = {
400        let sorted_json;
401        let value_to_serialize = if sort_keys.unwrap_or(false) {
402            sorted_json = sort_json_keys(&json_value);
403            &sorted_json
404        } else {
405            &json_value
406        };
407
408        if let Some(spaces) = indent {
409            if spaces < 0 {
410                return Err(MinijinjaError::new(
411                    ErrorKind::InvalidOperation,
412                    "indent cannot be negative",
413                ));
414            }
415            serialize_with_indent(value_to_serialize, spaces as usize)
416        } else {
417            serde_json::to_string(value_to_serialize).map_err(|e| {
418                MinijinjaError::new(
419                    ErrorKind::InvalidOperation,
420                    format!("Failed to serialize JSON: {e}"),
421                )
422            })
423        }
424    };
425
426    json_str.map(Value::from_safe_string)
427}
428
429/// Recursively sort all object keys in a JSON value
430fn sort_json_keys(value: &JsonValue) -> JsonValue {
431    match value {
432        JsonValue::Object(map) => {
433            let mut sorted: serde_json::Map<String, JsonValue> = serde_json::Map::new();
434            let mut keys: Vec<_> = map.keys().collect();
435            keys.sort();
436            for key in keys {
437                sorted.insert(key.clone(), sort_json_keys(&map[key]));
438            }
439            JsonValue::Object(sorted)
440        }
441        JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()),
442        _ => value.clone(),
443    }
444}
445
446/// Build a pre-configured `Environment<'static>` with the given template string,
447/// Python-compat method callback, and custom `tojson` filter already registered.
448/// The template is stored under the name `"chat"` using owned storage so the
449/// environment carries no borrows.
450fn build_environment(template: String) -> Result<Environment<'static>> {
451    let mut env = Environment::new();
452
453    // Match HuggingFace's Jinja2 defaults: trim_blocks and lstrip_blocks are
454    // enabled in Python's transformers but default to false in minijinja.
455    // Without these, templates like GLM-5's produce incorrect whitespace.
456    env.set_trim_blocks(true);
457    env.set_lstrip_blocks(true);
458
459    // Register the template with owned storage (no lifetime dependency on caller)
460    env.add_template_owned("chat".to_owned(), template)
461        .map_err(|e| anyhow!("Failed to add template: {e}"))?;
462
463    // Enable Python method compatibility (e.g., str.startswith, str.endswith)
464    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
465
466    // Register custom tojson filter compatible with HuggingFace transformers
467    // This overrides minijinja's built-in tojson to support additional kwargs
468    // like ensure_ascii, separators, and sort_keys that HuggingFace templates use
469    env.add_filter("tojson", tojson_filter);
470
471    Ok(env)
472}
473
474/// Render the `"chat"` template in the given environment against messages and params.
475/// Convert an optional token string to a minijinja Value.
476/// Present tokens become strings; absent tokens become UNDEFINED
477/// so templates can use `{% if bos_token is defined %}` guards.
478fn special_token_value(token: Option<&str>) -> Value {
479    token.map_or(Value::UNDEFINED, Value::from)
480}
481
482fn render_chat_template(
483    env: &Environment<'_>,
484    messages: &[serde_json::Value],
485    params: ChatTemplateParams,
486) -> Result<String> {
487    let tmpl = env
488        .get_template("chat")
489        .map_err(|e| anyhow!("Failed to get template: {e}"))?;
490
491    // Convert messages to minijinja::Value (messages already processed by router)
492    let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
493
494    // Use Value::UNDEFINED for missing optional params so they are truly "undefined"
495    // in the template context, matching HuggingFace Python behavior. Many chat templates
496    // use `{% if tools is defined %}` guards — passing null (none) instead of undefined
497    // would bypass those guards since `none` IS defined, causing `tools | length` to fail.
498    let tools_value = params.tools.map_or(Value::UNDEFINED, Value::from_serialize);
499    let documents_value = params
500        .documents
501        .map_or(Value::UNDEFINED, Value::from_serialize);
502
503    // Inject special tokens (bos_token, eos_token, etc.) into context.
504    // Use UNDEFINED for missing tokens so `{% if bos_token is defined %}` works correctly.
505    // This matches HuggingFace Python which passes self.special_tokens_map to the renderer.
506    let bos_value =
507        special_token_value(params.special_tokens.and_then(|st| st.bos_token.as_deref()));
508    let eos_value =
509        special_token_value(params.special_tokens.and_then(|st| st.eos_token.as_deref()));
510    let unk_value =
511        special_token_value(params.special_tokens.and_then(|st| st.unk_token.as_deref()));
512    let pad_value =
513        special_token_value(params.special_tokens.and_then(|st| st.pad_token.as_deref()));
514
515    let base_context = context! {
516        messages => &minijinja_messages,
517        add_generation_prompt => params.add_generation_prompt,
518        tools => tools_value,
519        documents => documents_value,
520        bos_token => bos_value,
521        eos_token => eos_value,
522        unk_token => unk_value,
523        pad_token => pad_value,
524    };
525
526    // Merge with template_kwargs if provided (caller kwargs override special tokens)
527    let ctx = if let Some(kwargs) = params.template_kwargs {
528        context! {
529            ..base_context,
530            ..Value::from_serialize(kwargs)
531        }
532    } else {
533        base_context
534    };
535
536    // Render the template
537    let rendered = tmpl
538        .render(&ctx)
539        .map_err(|e| anyhow!("Failed to render template: {e}"))?;
540
541    Ok(rendered)
542}
543
544/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
545pub struct ChatTemplateProcessor {
546    env: Environment<'static>,
547}
548
549impl ChatTemplateProcessor {
550    /// Create a new chat template processor.
551    ///
552    /// Returns an error if the template fails to parse, so callers get an
553    /// actionable message immediately rather than a confusing "template not
554    /// found" error on the first render.
555    pub fn new(template: String) -> Result<Self> {
556        let env = build_environment(template)?;
557        Ok(ChatTemplateProcessor { env })
558    }
559
560    /// Apply the chat template to a list of messages
561    ///
562    /// This mimics the behavior of HuggingFace's apply_chat_template method
563    /// but returns the formatted string instead of token IDs.
564    /// Messages should be pre-processed into the format expected by the template.
565    pub fn apply_chat_template(
566        &self,
567        messages: &[serde_json::Value],
568        params: ChatTemplateParams,
569    ) -> Result<String> {
570        render_chat_template(&self.env, messages, params)
571    }
572}
573
574/// Load chat template from tokenizer config JSON
575pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
576    let content = fs::read_to_string(config_path)?;
577    let config: serde_json::Value = serde_json::from_str(&content)?;
578
579    // Look for chat_template in the config
580    if let Some(template) = config.get("chat_template") {
581        if let Some(template_str) = template.as_str() {
582            return Ok(Some(template_str.to_string()));
583        }
584    }
585
586    Ok(None)
587}
588
589/// Load chat template from a file (.jinja or .json containing Jinja).
590/// Shared between all tokenizer backends.
591pub fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
592    let content = fs::read_to_string(template_path)
593        .map_err(|e| anyhow!("Failed to read chat template file: {e}"))?;
594
595    if template_path.ends_with(".json") {
596        let json_value: serde_json::Value = serde_json::from_str(&content)
597            .map_err(|e| anyhow!("Failed to parse chat_template.json: {e}"))?;
598
599        if let Some(template_str) = json_value.as_str() {
600            return Ok(Some(template_str.to_string()));
601        } else if let Some(obj) = json_value.as_object() {
602            if let Some(template_value) = obj.get("chat_template") {
603                if let Some(template_str) = template_value.as_str() {
604                    return Ok(Some(template_str.to_string()));
605                }
606            }
607        }
608
609        return Err(anyhow!(
610            "chat_template.json does not contain a valid template",
611        ));
612    }
613
614    // Plain .jinja file
615    let template = content.trim().replace("\\n", "\n");
616    Ok(Some(template))
617}
618
/// Chat template state that can be embedded in any tokenizer struct.
/// Eliminates duplicated apply/set/format methods across tokenizer backends.
///
/// The compiled `minijinja::Environment` (with the template parsed, filters
/// registered, and Python-compat callback installed) is cached so that
/// `apply()` only performs rendering -- no parsing or environment setup.
/// The cache is rebuilt whenever `set()` is called.
///
/// `Environment<'static>` is both `Send` and `Sync`, so embedding this in
/// tokenizer structs shared across threads is safe.
pub struct ChatTemplateState {
    /// Cached, fully-configured environment. `None` when no template is set.
    env: Option<Environment<'static>>,
    /// Content format detected from the current template; the default
    /// (`String`) when no template is set.
    content_format: ChatTemplateContentFormat,
}
634
635impl std::fmt::Debug for ChatTemplateState {
636    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
637        f.debug_struct("ChatTemplateState")
638            .field("has_template", &self.env.is_some())
639            .field("content_format", &self.content_format)
640            .finish()
641    }
642}
643
644impl ChatTemplateState {
645    pub fn new(template: Option<String>) -> Result<Self> {
646        let content_format = template
647            .as_ref()
648            .map(|t| detect_chat_template_content_format(t))
649            .unwrap_or_default();
650        let env = template.map(build_environment).transpose()?;
651        Ok(Self {
652            env,
653            content_format,
654        })
655    }
656
657    /// Create a `ChatTemplateState` with no template set.
658    ///
659    /// Unlike `new(None)`, this is infallible since there is no template to
660    /// parse — useful in constructors that don't return `Result`.
661    pub fn empty() -> Self {
662        Self {
663            env: None,
664            content_format: ChatTemplateContentFormat::default(),
665        }
666    }
667
668    pub fn apply(
669        &self,
670        messages: &[serde_json::Value],
671        params: ChatTemplateParams,
672    ) -> Result<String> {
673        let env = self.env.as_ref().ok_or_else(|| {
674            anyhow!(
675                "Cannot use chat template functions because tokenizer.chat_template is not set \
676                 and no template argument was passed! For information about writing templates and \
677                 setting the tokenizer.chat_template attribute, please see the documentation at \
678                 https://huggingface.co/docs/transformers/main/en/chat_templating",
679            )
680        })?;
681        render_chat_template(env, messages, params)
682    }
683
684    pub fn set(&mut self, template: String) -> Result<()> {
685        let content_format = detect_chat_template_content_format(&template);
686        let env = build_environment(template)?;
687        self.content_format = content_format;
688        self.env = Some(env);
689        Ok(())
690    }
691
692    pub fn content_format(&self) -> ChatTemplateContentFormat {
693        self.content_format
694    }
695}
696
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a state from a template that is expected to be valid.
    fn state_for(template: &str) -> ChatTemplateState {
        ChatTemplateState::new(Some(template.to_string())).unwrap()
    }

    #[test]
    fn test_chat_template_state_no_template() {
        let state = ChatTemplateState::new(None).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
        // Rendering without a template must fail.
        assert!(state.apply(&[], ChatTemplateParams::default()).is_err());
    }

    #[test]
    fn test_chat_template_state_set() {
        let mut state = ChatTemplateState::new(None).unwrap();
        state.set("{{ messages }}".to_string()).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
    }

    #[test]
    fn test_chat_template_state_invalid_template() {
        let outcome = ChatTemplateState::new(Some("{% invalid".to_string()));
        assert!(outcome.is_err());
        let err = outcome.unwrap_err().to_string();
        assert!(
            err.contains("Failed to add template"),
            "Error should explain parse failure, got: {err}"
        );
    }

    #[test]
    fn test_chat_template_processor_invalid_template() {
        assert!(ChatTemplateProcessor::new("{% invalid".to_string()).is_err());
    }

    #[test]
    fn test_special_tokens_injected_into_context() {
        let state = state_for(
            "{{ bos_token }}{% for message in messages %}{{ message.content }}{% endfor %}{{ eos_token }}",
        );

        let messages = vec![serde_json::json!({"role": "user", "content": "hello"})];
        let special_tokens = crate::traits::SpecialTokens {
            bos_token: Some("<s>".to_string()),
            eos_token: Some("</s>".to_string()),
            ..Default::default()
        };
        let params = ChatTemplateParams {
            special_tokens: Some(&special_tokens),
            ..Default::default()
        };

        let rendered = state.apply(&messages, params).unwrap();

        assert_eq!(rendered, "<s>hello</s>");
    }

    #[test]
    fn test_special_tokens_undefined_when_not_provided() {
        let state = state_for("{% if bos_token is defined %}{{ bos_token }}{% endif %}hello");

        let rendered = state.apply(&[], ChatTemplateParams::default()).unwrap();
        assert_eq!(rendered, "hello");
    }

    #[test]
    fn test_special_tokens_partial() {
        let state = state_for(
            "{{ bos_token }}hello{% if eos_token is defined %}{{ eos_token }}{% endif %}",
        );

        // Only bos_token is provided; eos_token must stay undefined.
        let special_tokens = crate::traits::SpecialTokens {
            bos_token: Some("<s>".to_string()),
            eos_token: None,
            ..Default::default()
        };
        let params = ChatTemplateParams {
            special_tokens: Some(&special_tokens),
            ..Default::default()
        };

        let rendered = state.apply(&[], params).unwrap();

        assert_eq!(rendered, "<s>hello");
    }
}