Skip to main content

llm_tokenizer/
chat_template.rs

1//! Chat template support for tokenizers using Jinja2 templates
2//!
3//! This module provides functionality to apply chat templates to messages,
4//! similar to HuggingFace transformers' apply_chat_template method.
5
6use std::{collections::HashMap, fs};
7
8use anyhow::{anyhow, Result};
9use minijinja::{
10    context,
11    machinery::{
12        ast::{Expr, Stmt},
13        parse, WhitespaceConfig,
14    },
15    syntax::SyntaxConfig,
16    value::Kwargs,
17    Environment, Error as MinijinjaError, ErrorKind, Value,
18};
19use serde::Serialize;
20use serde_json::{self, ser::PrettyFormatter, Value as JsonValue};
21
/// The shape in which a chat template expects each message's `content` field.
///
/// Produced by [`detect_chat_template_content_format`]; its `Display` form is the
/// lowercase label (`"string"` or `"openai"`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ChatTemplateContentFormat {
    /// `content` is a plain string. This is the default.
    #[default]
    String,
    /// `content` is a list of structured parts, as in the OpenAI chat format.
    OpenAI,
}

impl std::fmt::Display for ChatTemplateContentFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::String => "string",
            Self::OpenAI => "openai",
        };
        f.write_str(label)
    }
}
40
41/// Detect the content format expected by a Jinja2 chat template
42///
43/// This implements the same detection logic as SGLang's detect_jinja_template_content_format
44/// which uses AST parsing to look for content iteration patterns.
45///
46/// Returns:
47/// - ChatTemplateContentFormat::OpenAI if template expects structured content (list of parts)
48/// - ChatTemplateContentFormat::String if template expects simple string content
49pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
50    // Use AST-based detection (enabled by default)
51    detect_format_with_ast(template)
52}
53
/// Record of which OpenAI-style content patterns the detector has encountered.
#[derive(Debug, Default, Clone, Copy)]
struct Flags {
    /// A loop iterates over message content.
    saw_iteration: bool,
    /// Content is indexed with a number or measured with `|length`.
    saw_structure: bool,
    /// A variable named `content` is assigned from a message's content.
    saw_assignment: bool,
    /// A macro body contains both a type test and a loop.
    saw_macro: bool,
}

impl Flags {
    /// Returns `true` once at least one pattern has been recorded.
    fn any(self) -> bool {
        [
            self.saw_iteration,
            self.saw_structure,
            self.saw_assignment,
            self.saw_macro,
        ]
        .contains(&true)
    }
}
68
69/// Single-pass AST detector with scope tracking
70struct Detector<'a> {
71    ast: &'a Stmt<'a>,
72    /// Message loop vars currently in scope (e.g., `message`, `m`, `msg`)
73    scope: std::collections::VecDeque<String>,
74    scope_set: std::collections::HashSet<String>,
75    flags: Flags,
76}
77
78impl<'a> Detector<'a> {
79    fn new(ast: &'a Stmt<'a>) -> Self {
80        Self {
81            ast,
82            scope: std::collections::VecDeque::new(),
83            scope_set: std::collections::HashSet::new(),
84            flags: Flags::default(),
85        }
86    }
87
88    fn run(mut self) -> Flags {
89        self.walk_stmt(self.ast);
90        self.flags
91    }
92
93    fn push_scope(&mut self, var: String) {
94        self.scope.push_back(var.clone());
95        self.scope_set.insert(var);
96    }
97
98    fn pop_scope(&mut self) {
99        if let Some(v) = self.scope.pop_back() {
100            self.scope_set.remove(&v);
101        }
102    }
103
104    fn is_var_access(expr: &Expr, varname: &str) -> bool {
105        matches!(expr, Expr::Var(v) if v.id == varname)
106    }
107
108    fn is_const_str(expr: &Expr, value: &str) -> bool {
109        matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value))
110    }
111
112    fn is_numeric_const(expr: &Expr) -> bool {
113        matches!(expr, Expr::Const(c) if c.value.is_number())
114    }
115
116    /// Check if expr is varname.content or varname["content"]
117    fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
118        match expr {
119            Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
120            Expr::GetItem(g) => {
121                Self::is_var_access(&g.expr, varname)
122                    && Self::is_const_str(&g.subscript_expr, "content")
123            }
124            // Unwrap filters/tests that just wrap the same expr
125            Expr::Filter(f) => f
126                .expr
127                .as_ref()
128                .is_some_and(|e| Self::is_var_dot_content(e, varname)),
129            Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
130            _ => false,
131        }
132    }
133
134    /// Check if expr accesses .content on any variable in our scope, or any descendant of it.
135    fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
136        let mut current_expr = expr;
137        loop {
138            // Check if current level matches <scopeVar>.content
139            if self
140                .scope_set
141                .iter()
142                .any(|v| Self::is_var_dot_content(current_expr, v))
143            {
144                return true;
145            }
146            // Walk up the expression tree
147            match current_expr {
148                Expr::GetAttr(g) => current_expr = &g.expr,
149                Expr::GetItem(g) => current_expr = &g.expr,
150                _ => return false,
151            }
152        }
153    }
154
155    fn walk_stmt(&mut self, stmt: &Stmt) {
156        // Early exit if we've already detected an OpenAI pattern
157        if self.flags.any() {
158            return;
159        }
160
161        match stmt {
162            Stmt::Template(t) => {
163                for ch in &t.children {
164                    self.walk_stmt(ch);
165                }
166            }
167            // {% for message in messages %}
168            Stmt::ForLoop(fl) => {
169                // Detect "for X in messages" → push X into scope
170                if let Expr::Var(iter) = &fl.iter {
171                    if iter.id == "messages" {
172                        if let Expr::Var(target) = &fl.target {
173                            self.push_scope(target.id.to_string());
174                        }
175                    }
176                }
177
178                // Also detect "for ... in message.content" or "for ... in content"
179                // - Iterating directly over <scopeVar>.content => OpenAI style
180                if self.is_any_scope_var_content(&fl.iter) {
181                    self.flags.saw_iteration = true;
182                }
183                // - Iterating over a local var named "content"
184                if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
185                    self.flags.saw_iteration = true;
186                }
187
188                for b in &fl.body {
189                    self.walk_stmt(b);
190                }
191
192                // Pop scope if we pushed it
193                if let Expr::Var(iter) = &fl.iter {
194                    if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
195                        self.pop_scope();
196                    }
197                }
198            }
199            Stmt::IfCond(ic) => {
200                self.inspect_expr_for_structure(&ic.expr);
201                for b in &ic.true_body {
202                    self.walk_stmt(b);
203                }
204                for b in &ic.false_body {
205                    self.walk_stmt(b);
206                }
207            }
208            Stmt::EmitExpr(e) => {
209                self.inspect_expr_for_structure(&e.expr);
210            }
211            // {% set content = message.content %}
212            Stmt::Set(s) => {
213                if Self::is_var_access(&s.target, "content")
214                    && self.is_any_scope_var_content(&s.expr)
215                {
216                    self.flags.saw_assignment = true;
217                }
218            }
219            Stmt::Macro(m) => {
220                // Heuristic: macro that checks type (via `is` test) and also has any loop
221                let mut has_type_check = false;
222                let mut has_loop = false;
223                Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
224                if has_type_check && has_loop {
225                    self.flags.saw_macro = true;
226                }
227            }
228            _ => {}
229        }
230    }
231
232    fn inspect_expr_for_structure(&mut self, expr: &Expr) {
233        if self.flags.saw_structure {
234            return;
235        }
236
237        match expr {
238            // content[0] or message.content[0]
239            Expr::GetItem(gi) => {
240                if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
241                    || self.is_any_scope_var_content(&gi.expr))
242                    && Self::is_numeric_const(&gi.subscript_expr)
243                {
244                    self.flags.saw_structure = true;
245                }
246            }
247            // content|length or message.content|length
248            Expr::Filter(f) => {
249                if f.name == "length" {
250                    if let Some(inner) = &f.expr {
251                        // Box derefs automatically, so `&**inner` is `&Expr`
252                        let inner_ref: &Expr = inner;
253                        let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
254                        if is_content_var || self.is_any_scope_var_content(inner_ref) {
255                            self.flags.saw_structure = true;
256                        }
257                    }
258                } else if let Some(inner) = &f.expr {
259                    let inner_ref: &Expr = inner;
260                    self.inspect_expr_for_structure(inner_ref);
261                }
262            }
263            // Type tests like `content is iterable` or `message.content is string`
264            // These are used for branching (e.g., Llama 3.1 uses them for tool output formatting),
265            // not as indicators that the template expects structured content. Keep walking.
266            Expr::Test(t) => self.inspect_expr_for_structure(&t.expr),
267            Expr::GetAttr(g) => {
268                // Keep walking; nested expressions can hide structure checks
269                self.inspect_expr_for_structure(&g.expr);
270            }
271            // Handle binary operations like: if (message.content is string) and other_cond
272            Expr::BinOp(op) => {
273                self.inspect_expr_for_structure(&op.left);
274                self.inspect_expr_for_structure(&op.right);
275            }
276            // Handle unary operations like: if not (message.content is string)
277            Expr::UnaryOp(op) => {
278                self.inspect_expr_for_structure(&op.expr);
279            }
280            _ => {}
281        }
282    }
283
284    fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
285        for s in body {
286            if *has_type_check && *has_loop {
287                return;
288            }
289
290            match s {
291                Stmt::IfCond(ic) => {
292                    if matches!(&ic.expr, Expr::Test(_)) {
293                        *has_type_check = true;
294                    }
295                    Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
296                    Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
297                }
298                Stmt::ForLoop(fl) => {
299                    *has_loop = true;
300                    Self::scan_macro_body(&fl.body, has_type_check, has_loop);
301                }
302                Stmt::Template(t) => {
303                    Self::scan_macro_body(&t.children, has_type_check, has_loop);
304                }
305                _ => {}
306            }
307        }
308    }
309}
310
311/// AST-based detection using minijinja's unstable machinery
312/// Single-pass detector with scope tracking
313fn detect_format_with_ast(template: &str) -> ChatTemplateContentFormat {
314    let ast = match parse(
315        template,
316        "template",
317        SyntaxConfig {},
318        WhitespaceConfig::default(),
319    ) {
320        Ok(ast) => ast,
321        Err(_) => return ChatTemplateContentFormat::String,
322    };
323
324    let flags = Detector::new(&ast).run();
325    if flags.any() {
326        ChatTemplateContentFormat::OpenAI
327    } else {
328        ChatTemplateContentFormat::String
329    }
330}
331
332/// Parameters for chat template application
333#[derive(Default)]
334pub struct ChatTemplateParams<'a> {
335    pub add_generation_prompt: bool,
336    pub tools: Option<&'a [serde_json::Value]>,
337    pub documents: Option<&'a [serde_json::Value]>,
338    pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
339}
340
341/// Custom tojson filter compatible with HuggingFace transformers' implementation.
342///
343/// HuggingFace transformers registers a custom `tojson` filter that accepts additional
344/// keyword arguments beyond what standard Jinja2 provides:
345/// - `ensure_ascii` (bool): Whether to escape non-ASCII characters (ignored in Rust, always UTF-8)
346/// - `indent` (int): Number of spaces for indentation (pretty-printing)
347/// - `separators` (ignored): Custom separators for JSON output
348/// - `sort_keys` (bool): Whether to sort dictionary keys
349///
350/// This is necessary for compatibility with chat templates from HuggingFace Hub models.
351/// See: https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
352fn tojson_filter(value: Value, kwargs: Kwargs) -> std::result::Result<Value, MinijinjaError> {
353    let _ensure_ascii: Option<bool> = kwargs.get("ensure_ascii")?;
354    let indent: Option<i64> = kwargs.get("indent")?;
355    let _separators: Option<Value> = kwargs.get("separators")?;
356    let sort_keys: Option<bool> = kwargs.get("sort_keys")?;
357
358    // Ensure all kwargs are consumed to avoid "unknown keyword argument" errors
359    kwargs.assert_all_used()?;
360
361    let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| {
362        MinijinjaError::new(
363            ErrorKind::InvalidOperation,
364            format!("Failed to convert to JSON value: {e}"),
365        )
366    })?;
367
368    // Helper to serialize with custom indentation
369    fn serialize_with_indent<T: Serialize>(
370        value: &T,
371        spaces: usize,
372    ) -> std::result::Result<String, MinijinjaError> {
373        let indent_str = vec![b' '; spaces];
374        let formatter = PrettyFormatter::with_indent(&indent_str);
375        let mut buf = Vec::new();
376        let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
377        value.serialize(&mut serializer).map_err(|e| {
378            MinijinjaError::new(
379                ErrorKind::InvalidOperation,
380                format!("Failed to serialize JSON: {e}"),
381            )
382        })?;
383        String::from_utf8(buf).map_err(|e| {
384            MinijinjaError::new(
385                ErrorKind::InvalidOperation,
386                format!("Invalid UTF-8 in JSON output: {e}"),
387            )
388        })
389    }
390
391    // Serialize with options
392    let json_str: std::result::Result<String, MinijinjaError> = {
393        let sorted_json;
394        let value_to_serialize = if sort_keys.unwrap_or(false) {
395            sorted_json = sort_json_keys(&json_value);
396            &sorted_json
397        } else {
398            &json_value
399        };
400
401        if let Some(spaces) = indent {
402            if spaces < 0 {
403                return Err(MinijinjaError::new(
404                    ErrorKind::InvalidOperation,
405                    "indent cannot be negative",
406                ));
407            }
408            serialize_with_indent(value_to_serialize, spaces as usize)
409        } else {
410            serde_json::to_string(value_to_serialize).map_err(|e| {
411                MinijinjaError::new(
412                    ErrorKind::InvalidOperation,
413                    format!("Failed to serialize JSON: {e}"),
414                )
415            })
416        }
417    };
418
419    json_str.map(Value::from_safe_string)
420}
421
422/// Recursively sort all object keys in a JSON value
423fn sort_json_keys(value: &JsonValue) -> JsonValue {
424    match value {
425        JsonValue::Object(map) => {
426            let mut sorted: serde_json::Map<String, JsonValue> = serde_json::Map::new();
427            let mut keys: Vec<_> = map.keys().collect();
428            keys.sort();
429            for key in keys {
430                sorted.insert(key.clone(), sort_json_keys(&map[key]));
431            }
432            JsonValue::Object(sorted)
433        }
434        JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()),
435        _ => value.clone(),
436    }
437}
438
439/// Build a pre-configured `Environment<'static>` with the given template string,
440/// Python-compat method callback, and custom `tojson` filter already registered.
441/// The template is stored under the name `"chat"` using owned storage so the
442/// environment carries no borrows.
443fn build_environment(template: String) -> Result<Environment<'static>> {
444    let mut env = Environment::new();
445
446    // Register the template with owned storage (no lifetime dependency on caller)
447    env.add_template_owned("chat".to_owned(), template)
448        .map_err(|e| anyhow!("Failed to add template: {e}"))?;
449
450    // Enable Python method compatibility (e.g., str.startswith, str.endswith)
451    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
452
453    // Register custom tojson filter compatible with HuggingFace transformers
454    // This overrides minijinja's built-in tojson to support additional kwargs
455    // like ensure_ascii, separators, and sort_keys that HuggingFace templates use
456    env.add_filter("tojson", tojson_filter);
457
458    Ok(env)
459}
460
461/// Render the `"chat"` template in the given environment against messages and params.
462fn render_chat_template(
463    env: &Environment<'_>,
464    messages: &[serde_json::Value],
465    params: ChatTemplateParams,
466) -> Result<String> {
467    let tmpl = env
468        .get_template("chat")
469        .map_err(|e| anyhow!("Failed to get template: {e}"))?;
470
471    // Convert messages to minijinja::Value (messages already processed by router)
472    let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
473
474    let base_context = context! {
475        messages => &minijinja_messages,
476        add_generation_prompt => params.add_generation_prompt,
477        tools => params.tools,
478        documents => params.documents,
479    };
480
481    // Merge with template_kwargs if provided
482    let ctx = if let Some(kwargs) = params.template_kwargs {
483        context! {
484            ..base_context,
485            ..Value::from_serialize(kwargs)
486        }
487    } else {
488        base_context
489    };
490
491    // Render the template
492    let rendered = tmpl
493        .render(&ctx)
494        .map_err(|e| anyhow!("Failed to render template: {e}"))?;
495
496    Ok(rendered)
497}
498
499/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
500pub struct ChatTemplateProcessor {
501    env: Environment<'static>,
502}
503
504impl ChatTemplateProcessor {
505    /// Create a new chat template processor.
506    ///
507    /// Returns an error if the template fails to parse, so callers get an
508    /// actionable message immediately rather than a confusing "template not
509    /// found" error on the first render.
510    pub fn new(template: String) -> Result<Self> {
511        let env = build_environment(template)?;
512        Ok(ChatTemplateProcessor { env })
513    }
514
515    /// Apply the chat template to a list of messages
516    ///
517    /// This mimics the behavior of HuggingFace's apply_chat_template method
518    /// but returns the formatted string instead of token IDs.
519    /// Messages should be pre-processed into the format expected by the template.
520    pub fn apply_chat_template(
521        &self,
522        messages: &[serde_json::Value],
523        params: ChatTemplateParams,
524    ) -> Result<String> {
525        render_chat_template(&self.env, messages, params)
526    }
527}
528
529/// Load chat template from tokenizer config JSON
530pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
531    let content = fs::read_to_string(config_path)?;
532    let config: serde_json::Value = serde_json::from_str(&content)?;
533
534    // Look for chat_template in the config
535    if let Some(template) = config.get("chat_template") {
536        if let Some(template_str) = template.as_str() {
537            return Ok(Some(template_str.to_string()));
538        }
539    }
540
541    Ok(None)
542}
543
544/// Load chat template from a file (.jinja or .json containing Jinja).
545/// Shared between all tokenizer backends.
546pub fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
547    let content = fs::read_to_string(template_path)
548        .map_err(|e| anyhow!("Failed to read chat template file: {e}"))?;
549
550    if template_path.ends_with(".json") {
551        let json_value: serde_json::Value = serde_json::from_str(&content)
552            .map_err(|e| anyhow!("Failed to parse chat_template.json: {e}"))?;
553
554        if let Some(template_str) = json_value.as_str() {
555            return Ok(Some(template_str.to_string()));
556        } else if let Some(obj) = json_value.as_object() {
557            if let Some(template_value) = obj.get("chat_template") {
558                if let Some(template_str) = template_value.as_str() {
559                    return Ok(Some(template_str.to_string()));
560                }
561            }
562        }
563
564        return Err(anyhow!(
565            "chat_template.json does not contain a valid template",
566        ));
567    }
568
569    // Plain .jinja file
570    let template = content.trim().replace("\\n", "\n");
571    Ok(Some(template))
572}
573
574/// Chat template state that can be embedded in any tokenizer struct.
575/// Eliminates duplicated apply/set/format methods across tokenizer backends.
576///
577/// The compiled `minijinja::Environment` (with the template parsed, filters
578/// registered, and Python-compat callback installed) is cached so that
579/// `apply()` only performs rendering -- no parsing or environment setup.
580/// The cache is rebuilt whenever `set()` is called.
581///
582/// `Environment<'static>` is both `Send` and `Sync`, so embedding this in
583/// tokenizer structs shared across threads is safe.
584pub struct ChatTemplateState {
585    /// Cached, fully-configured environment. `None` when no template is set.
586    env: Option<Environment<'static>>,
587    content_format: ChatTemplateContentFormat,
588}
589
590impl std::fmt::Debug for ChatTemplateState {
591    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592        f.debug_struct("ChatTemplateState")
593            .field("has_template", &self.env.is_some())
594            .field("content_format", &self.content_format)
595            .finish()
596    }
597}
598
599impl ChatTemplateState {
600    pub fn new(template: Option<String>) -> Result<Self> {
601        let content_format = template
602            .as_ref()
603            .map(|t| detect_chat_template_content_format(t))
604            .unwrap_or_default();
605        let env = template.map(build_environment).transpose()?;
606        Ok(Self {
607            env,
608            content_format,
609        })
610    }
611
612    /// Create a `ChatTemplateState` with no template set.
613    ///
614    /// Unlike `new(None)`, this is infallible since there is no template to
615    /// parse — useful in constructors that don't return `Result`.
616    pub fn empty() -> Self {
617        Self {
618            env: None,
619            content_format: ChatTemplateContentFormat::default(),
620        }
621    }
622
623    pub fn apply(
624        &self,
625        messages: &[serde_json::Value],
626        params: ChatTemplateParams,
627    ) -> Result<String> {
628        let env = self.env.as_ref().ok_or_else(|| {
629            anyhow!(
630                "Cannot use chat template functions because tokenizer.chat_template is not set \
631                 and no template argument was passed! For information about writing templates and \
632                 setting the tokenizer.chat_template attribute, please see the documentation at \
633                 https://huggingface.co/docs/transformers/main/en/chat_templating",
634            )
635        })?;
636        render_chat_template(env, messages, params)
637    }
638
639    pub fn set(&mut self, template: String) -> Result<()> {
640        let content_format = detect_chat_template_content_format(&template);
641        let env = build_environment(template)?;
642        self.content_format = content_format;
643        self.env = Some(env);
644        Ok(())
645    }
646
647    pub fn content_format(&self) -> ChatTemplateContentFormat {
648        self.content_format
649    }
650}
651
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `template` and renders it for `messages` with `params`.
    fn render(
        template: &str,
        messages: &[serde_json::Value],
        params: ChatTemplateParams,
    ) -> Result<String> {
        ChatTemplateProcessor::new(template.to_string())?.apply_chat_template(messages, params)
    }

    const OPENAI_TEMPLATE: &str = "{% for message in messages %}\
        {% for part in message.content %}{{ part.text }}{% endfor %}\
        {% endfor %}";

    #[test]
    fn test_chat_template_state_no_template() {
        let state = ChatTemplateState::new(None).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
        let result = state.apply(&[], ChatTemplateParams::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_chat_template_state_set() {
        let mut state = ChatTemplateState::new(None).unwrap();
        state.set("{{ messages }}".to_string()).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
    }

    #[test]
    fn test_chat_template_state_invalid_template() {
        let result = ChatTemplateState::new(Some("{% invalid".to_string()));
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Failed to add template"),
            "Error should explain parse failure, got: {err}"
        );
    }

    #[test]
    fn test_chat_template_processor_invalid_template() {
        let result = ChatTemplateProcessor::new("{% invalid".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_detect_string_content() {
        let tmpl =
            "{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}";
        assert_eq!(
            detect_chat_template_content_format(tmpl),
            ChatTemplateContentFormat::String
        );
    }

    #[test]
    fn test_detect_iteration_over_content() {
        assert_eq!(
            detect_chat_template_content_format(OPENAI_TEMPLATE),
            ChatTemplateContentFormat::OpenAI
        );
    }

    #[test]
    fn test_detect_type_test_is_not_structure() {
        let tmpl = "{% for m in messages %}{% if m.content is string %}\
                    {{ m.content }}{% endif %}{% endfor %}";
        assert_eq!(
            detect_chat_template_content_format(tmpl),
            ChatTemplateContentFormat::String
        );
    }

    #[test]
    fn test_detect_unparseable_template_defaults_to_string() {
        assert_eq!(
            detect_chat_template_content_format("{% for"),
            ChatTemplateContentFormat::String
        );
    }

    #[test]
    fn test_state_tracks_detected_format() {
        let state = ChatTemplateState::new(Some(OPENAI_TEMPLATE.to_string())).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::OpenAI);
    }

    #[test]
    fn test_apply_renders_messages_and_generation_prompt() {
        let messages = vec![serde_json::json!({"role": "user", "content": "hi"})];
        let tmpl = "{% for m in messages %}{{ m.role }}:{{ m.content }};{% endfor %}\
                    {% if add_generation_prompt %}assistant:{% endif %}";
        let params = ChatTemplateParams {
            add_generation_prompt: true,
            ..Default::default()
        };
        assert_eq!(render(tmpl, &messages, params).unwrap(), "user:hi;assistant:");
    }

    #[test]
    fn test_template_kwargs_are_exposed() {
        let mut kwargs = HashMap::new();
        kwargs.insert("enable_thinking".to_string(), serde_json::json!(false));
        let params = ChatTemplateParams {
            template_kwargs: Some(&kwargs),
            ..Default::default()
        };
        assert_eq!(render("{{ enable_thinking }}", &[], params).unwrap(), "false");
    }

    #[test]
    fn test_tojson_default_is_compact() {
        let out = render("{{ [1, 'a'] | tojson }}", &[], Default::default()).unwrap();
        assert_eq!(out, "[1,\"a\"]");
    }

    #[test]
    fn test_tojson_indent_and_sort_keys() {
        let out = render(
            "{{ {'b': 1, 'a': [1, 2]} | tojson(indent=2, sort_keys=true) }}",
            &[],
            Default::default(),
        )
        .unwrap();
        assert_eq!(out, "{\n  \"a\": [\n    1,\n    2\n  ],\n  \"b\": 1\n}");
    }

    #[test]
    fn test_tojson_rejects_negative_indent() {
        let result = render("{{ [1] | tojson(indent=-1) }}", &[], Default::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_tojson_rejects_unknown_kwargs() {
        let result = render("{{ 1 | tojson(bogus=true) }}", &[], Default::default());
        assert!(result.is_err());
    }
}