llm_tokenizer/
chat_template.rs

1//! Chat template support for tokenizers using Jinja2 templates
2//!
3//! This module provides functionality to apply chat templates to messages,
4//! similar to HuggingFace transformers' apply_chat_template method.
5
6use std::{collections::HashMap, fs};
7
8use anyhow::{anyhow, Result};
9use minijinja::{
10    context,
11    machinery::{
12        ast::{Expr, Stmt},
13        parse, WhitespaceConfig,
14    },
15    syntax::SyntaxConfig,
16    value::Kwargs,
17    Environment, Error as MinijinjaError, ErrorKind, Value,
18};
19use serde::Serialize;
20use serde_json::{self, ser::PrettyFormatter, Value as JsonValue};
21
/// Shape of the `content` field a chat template expects on each message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChatTemplateContentFormat {
    /// `content` is a plain string (the default).
    #[default]
    String,
    /// `content` is a list of structured parts (OpenAI format).
    OpenAI,
}

impl std::fmt::Display for ChatTemplateContentFormat {
    /// Writes the lowercase label of the variant (`"string"` or `"openai"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::String => "string",
            Self::OpenAI => "openai",
        };
        // `write_str` emits the label verbatim, without applying width/fill flags.
        f.write_str(label)
    }
}
40
41/// Detect the content format expected by a Jinja2 chat template
42///
43/// This implements the same detection logic as SGLang's detect_jinja_template_content_format
44/// which uses AST parsing to look for content iteration patterns.
45///
46/// Returns:
47/// - ChatTemplateContentFormat::OpenAI if template expects structured content (list of parts)
48/// - ChatTemplateContentFormat::String if template expects simple string content
49pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
50    // Use AST-based detection (enabled by default)
51    if let Some(format) = detect_format_with_ast(template) {
52        return format;
53    }
54
55    // Default to string format if AST parsing fails
56    ChatTemplateContentFormat::String
57}
58
/// Records which OpenAI-style content patterns the detector has observed.
#[derive(Default, Debug, Clone, Copy)]
struct Flags {
    /// A `for` loop iterates over message content.
    saw_iteration: bool,
    /// Content is indexed, measured, or type-tested.
    saw_structure: bool,
    /// Message content is assigned to a local `content` variable.
    saw_assignment: bool,
    /// A macro both type-checks and loops.
    saw_macro: bool,
}

impl Flags {
    /// Returns `true` as soon as at least one pattern has been recorded.
    fn any(self) -> bool {
        let Flags {
            saw_iteration,
            saw_structure,
            saw_assignment,
            saw_macro,
        } = self;
        saw_iteration || saw_structure || saw_assignment || saw_macro
    }
}
73
74/// Single-pass AST detector with scope tracking
75struct Detector<'a> {
76    ast: &'a Stmt<'a>,
77    /// Message loop vars currently in scope (e.g., `message`, `m`, `msg`)
78    scope: std::collections::VecDeque<String>,
79    scope_set: std::collections::HashSet<String>,
80    flags: Flags,
81}
82
83impl<'a> Detector<'a> {
84    fn new(ast: &'a Stmt<'a>) -> Self {
85        Self {
86            ast,
87            scope: std::collections::VecDeque::new(),
88            scope_set: std::collections::HashSet::new(),
89            flags: Flags::default(),
90        }
91    }
92
93    fn run(mut self) -> Flags {
94        self.walk_stmt(self.ast);
95        self.flags
96    }
97
98    fn push_scope(&mut self, var: String) {
99        self.scope.push_back(var.clone());
100        self.scope_set.insert(var);
101    }
102
103    fn pop_scope(&mut self) {
104        if let Some(v) = self.scope.pop_back() {
105            self.scope_set.remove(&v);
106        }
107    }
108
109    fn is_var_access(expr: &Expr, varname: &str) -> bool {
110        matches!(expr, Expr::Var(v) if v.id == varname)
111    }
112
113    fn is_const_str(expr: &Expr, value: &str) -> bool {
114        matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value))
115    }
116
117    fn is_numeric_const(expr: &Expr) -> bool {
118        matches!(expr, Expr::Const(c) if c.value.is_number())
119    }
120
121    /// Check if expr is varname.content or varname["content"]
122    fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
123        match expr {
124            Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
125            Expr::GetItem(g) => {
126                Self::is_var_access(&g.expr, varname)
127                    && Self::is_const_str(&g.subscript_expr, "content")
128            }
129            // Unwrap filters/tests that just wrap the same expr
130            Expr::Filter(f) => f
131                .expr
132                .as_ref()
133                .is_some_and(|e| Self::is_var_dot_content(e, varname)),
134            Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
135            _ => false,
136        }
137    }
138
139    /// Check if expr accesses .content on any variable in our scope, or any descendant of it.
140    fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
141        let mut current_expr = expr;
142        loop {
143            // Check if current level matches <scopeVar>.content
144            if self
145                .scope_set
146                .iter()
147                .any(|v| Self::is_var_dot_content(current_expr, v))
148            {
149                return true;
150            }
151            // Walk up the expression tree
152            match current_expr {
153                Expr::GetAttr(g) => current_expr = &g.expr,
154                Expr::GetItem(g) => current_expr = &g.expr,
155                _ => return false,
156            }
157        }
158    }
159
160    fn walk_stmt(&mut self, stmt: &Stmt) {
161        // Early exit if we've already detected an OpenAI pattern
162        if self.flags.any() {
163            return;
164        }
165
166        match stmt {
167            Stmt::Template(t) => {
168                for ch in &t.children {
169                    self.walk_stmt(ch);
170                }
171            }
172            // {% for message in messages %}
173            Stmt::ForLoop(fl) => {
174                // Detect "for X in messages" → push X into scope
175                if let Expr::Var(iter) = &fl.iter {
176                    if iter.id == "messages" {
177                        if let Expr::Var(target) = &fl.target {
178                            self.push_scope(target.id.to_string());
179                        }
180                    }
181                }
182
183                // Also detect "for ... in message.content" or "for ... in content"
184                // - Iterating directly over <scopeVar>.content => OpenAI style
185                if self.is_any_scope_var_content(&fl.iter) {
186                    self.flags.saw_iteration = true;
187                }
188                // - Iterating over a local var named "content"
189                if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
190                    self.flags.saw_iteration = true;
191                }
192
193                for b in &fl.body {
194                    self.walk_stmt(b);
195                }
196
197                // Pop scope if we pushed it
198                if let Expr::Var(iter) = &fl.iter {
199                    if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
200                        self.pop_scope();
201                    }
202                }
203            }
204            Stmt::IfCond(ic) => {
205                self.inspect_expr_for_structure(&ic.expr);
206                for b in &ic.true_body {
207                    self.walk_stmt(b);
208                }
209                for b in &ic.false_body {
210                    self.walk_stmt(b);
211                }
212            }
213            Stmt::EmitExpr(e) => {
214                self.inspect_expr_for_structure(&e.expr);
215            }
216            // {% set content = message.content %}
217            Stmt::Set(s) => {
218                if Self::is_var_access(&s.target, "content")
219                    && self.is_any_scope_var_content(&s.expr)
220                {
221                    self.flags.saw_assignment = true;
222                }
223            }
224            Stmt::Macro(m) => {
225                // Heuristic: macro that checks type (via `is` test) and also has any loop
226                let mut has_type_check = false;
227                let mut has_loop = false;
228                Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
229                if has_type_check && has_loop {
230                    self.flags.saw_macro = true;
231                }
232            }
233            _ => {}
234        }
235    }
236
237    fn inspect_expr_for_structure(&mut self, expr: &Expr) {
238        if self.flags.saw_structure {
239            return;
240        }
241
242        match expr {
243            // content[0] or message.content[0]
244            Expr::GetItem(gi) => {
245                if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
246                    || self.is_any_scope_var_content(&gi.expr))
247                    && Self::is_numeric_const(&gi.subscript_expr)
248                {
249                    self.flags.saw_structure = true;
250                }
251            }
252            // content|length or message.content|length
253            Expr::Filter(f) => {
254                if f.name == "length" {
255                    if let Some(inner) = &f.expr {
256                        // Box derefs automatically, so `&**inner` is `&Expr`
257                        let inner_ref: &Expr = inner;
258                        let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
259                        if is_content_var || self.is_any_scope_var_content(inner_ref) {
260                            self.flags.saw_structure = true;
261                        }
262                    }
263                } else if let Some(inner) = &f.expr {
264                    let inner_ref: &Expr = inner;
265                    self.inspect_expr_for_structure(inner_ref);
266                }
267            }
268            // content is sequence/iterable OR message.content is sequence/iterable
269            Expr::Test(t) => {
270                if t.name == "sequence" || t.name == "iterable" || t.name == "string" {
271                    if matches!(&t.expr, Expr::Var(v) if v.id == "content")
272                        || self.is_any_scope_var_content(&t.expr)
273                    {
274                        self.flags.saw_structure = true;
275                    }
276                } else {
277                    self.inspect_expr_for_structure(&t.expr);
278                }
279            }
280            Expr::GetAttr(g) => {
281                // Keep walking; nested expressions can hide structure checks
282                self.inspect_expr_for_structure(&g.expr);
283            }
284            // Handle binary operations like: if (message.content is string) and other_cond
285            Expr::BinOp(op) => {
286                self.inspect_expr_for_structure(&op.left);
287                self.inspect_expr_for_structure(&op.right);
288            }
289            // Handle unary operations like: if not (message.content is string)
290            Expr::UnaryOp(op) => {
291                self.inspect_expr_for_structure(&op.expr);
292            }
293            _ => {}
294        }
295    }
296
297    fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
298        for s in body {
299            if *has_type_check && *has_loop {
300                return;
301            }
302
303            match s {
304                Stmt::IfCond(ic) => {
305                    if matches!(&ic.expr, Expr::Test(_)) {
306                        *has_type_check = true;
307                    }
308                    Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
309                    Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
310                }
311                Stmt::ForLoop(fl) => {
312                    *has_loop = true;
313                    Self::scan_macro_body(&fl.body, has_type_check, has_loop);
314                }
315                Stmt::Template(t) => {
316                    Self::scan_macro_body(&t.children, has_type_check, has_loop);
317                }
318                _ => {}
319            }
320        }
321    }
322}
323
324/// AST-based detection using minijinja's unstable machinery
325/// Single-pass detector with scope tracking
326fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
327    let ast = match parse(
328        template,
329        "template",
330        SyntaxConfig {},
331        WhitespaceConfig::default(),
332    ) {
333        Ok(ast) => ast,
334        Err(_) => return Some(ChatTemplateContentFormat::String),
335    };
336
337    let flags = Detector::new(&ast).run();
338    Some(if flags.any() {
339        ChatTemplateContentFormat::OpenAI
340    } else {
341        ChatTemplateContentFormat::String
342    })
343}
344
345/// Parameters for chat template application
346#[derive(Default)]
347pub struct ChatTemplateParams<'a> {
348    pub add_generation_prompt: bool,
349    pub tools: Option<&'a [serde_json::Value]>,
350    pub documents: Option<&'a [serde_json::Value]>,
351    pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
352}
353
354/// Custom tojson filter compatible with HuggingFace transformers' implementation.
355///
356/// HuggingFace transformers registers a custom `tojson` filter that accepts additional
357/// keyword arguments beyond what standard Jinja2 provides:
358/// - `ensure_ascii` (bool): Whether to escape non-ASCII characters (ignored in Rust, always UTF-8)
359/// - `indent` (int): Number of spaces for indentation (pretty-printing)
360/// - `separators` (ignored): Custom separators for JSON output
361/// - `sort_keys` (bool): Whether to sort dictionary keys
362///
363/// This is necessary for compatibility with chat templates from HuggingFace Hub models.
364/// See: https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
365fn tojson_filter(value: Value, kwargs: Kwargs) -> std::result::Result<Value, MinijinjaError> {
366    let _ensure_ascii: Option<bool> = kwargs.get("ensure_ascii")?;
367    let indent: Option<i64> = kwargs.get("indent")?;
368    let _separators: Option<Value> = kwargs.get("separators")?;
369    let sort_keys: Option<bool> = kwargs.get("sort_keys")?;
370
371    // Ensure all kwargs are consumed to avoid "unknown keyword argument" errors
372    kwargs.assert_all_used()?;
373
374    let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| {
375        MinijinjaError::new(
376            ErrorKind::InvalidOperation,
377            format!("Failed to convert to JSON value: {}", e),
378        )
379    })?;
380
381    // Helper to serialize with custom indentation
382    fn serialize_with_indent<T: Serialize>(
383        value: &T,
384        spaces: usize,
385    ) -> std::result::Result<String, MinijinjaError> {
386        let indent_str = vec![b' '; spaces];
387        let formatter = PrettyFormatter::with_indent(&indent_str);
388        let mut buf = Vec::new();
389        let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
390        value.serialize(&mut serializer).map_err(|e| {
391            MinijinjaError::new(
392                ErrorKind::InvalidOperation,
393                format!("Failed to serialize JSON: {}", e),
394            )
395        })?;
396        String::from_utf8(buf).map_err(|e| {
397            MinijinjaError::new(
398                ErrorKind::InvalidOperation,
399                format!("Invalid UTF-8 in JSON output: {}", e),
400            )
401        })
402    }
403
404    // Serialize with options
405    let json_str: std::result::Result<String, MinijinjaError> = {
406        let sorted_json;
407        let value_to_serialize = if sort_keys.unwrap_or(false) {
408            sorted_json = sort_json_keys(&json_value);
409            &sorted_json
410        } else {
411            &json_value
412        };
413
414        if let Some(spaces) = indent {
415            if spaces < 0 {
416                return Err(MinijinjaError::new(
417                    ErrorKind::InvalidOperation,
418                    "indent cannot be negative",
419                ));
420            }
421            serialize_with_indent(value_to_serialize, spaces as usize)
422        } else {
423            serde_json::to_string(value_to_serialize).map_err(|e| {
424                MinijinjaError::new(
425                    ErrorKind::InvalidOperation,
426                    format!("Failed to serialize JSON: {}", e),
427                )
428            })
429        }
430    };
431
432    json_str.map(Value::from_safe_string)
433}
434
435/// Recursively sort all object keys in a JSON value
436fn sort_json_keys(value: &JsonValue) -> JsonValue {
437    match value {
438        JsonValue::Object(map) => {
439            let mut sorted: serde_json::Map<String, JsonValue> = serde_json::Map::new();
440            let mut keys: Vec<_> = map.keys().collect();
441            keys.sort();
442            for key in keys {
443                sorted.insert(key.clone(), sort_json_keys(&map[key]));
444            }
445            JsonValue::Object(sorted)
446        }
447        JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()),
448        _ => value.clone(),
449    }
450}
451
452/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
453pub struct ChatTemplateProcessor {
454    template: String,
455}
456
457impl ChatTemplateProcessor {
458    /// Create a new chat template processor
459    pub fn new(template: String) -> Self {
460        ChatTemplateProcessor { template }
461    }
462
463    /// Apply the chat template to a list of messages
464    ///
465    /// This mimics the behavior of HuggingFace's apply_chat_template method
466    /// but returns the formatted string instead of token IDs.
467    /// Messages should be pre-processed into the format expected by the template.
468    pub fn apply_chat_template(
469        &self,
470        messages: &[serde_json::Value],
471        params: ChatTemplateParams,
472    ) -> Result<String> {
473        let mut env = Environment::new();
474
475        // Register the template
476        env.add_template("chat", &self.template)
477            .map_err(|e| anyhow!("Failed to add template: {}", e))?;
478
479        // Enable Python method compatibility (e.g., str.startswith, str.endswith)
480        env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
481
482        // Register custom tojson filter compatible with HuggingFace transformers
483        // This overrides minijinja's built-in tojson to support additional kwargs
484        // like ensure_ascii, separators, and sort_keys that HuggingFace templates use
485        env.add_filter("tojson", tojson_filter);
486
487        // Get the template
488        let tmpl = env
489            .get_template("chat")
490            .map_err(|e| anyhow!("Failed to get template: {}", e))?;
491
492        // Convert messages to minijinja::Value (messages already processed by router)
493        let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
494
495        let base_context = context! {
496            messages => &minijinja_messages,
497            add_generation_prompt => params.add_generation_prompt,
498            tools => params.tools,
499            documents => params.documents,
500        };
501
502        // Merge with template_kwargs if provided
503        let ctx = if let Some(kwargs) = params.template_kwargs {
504            context! {
505                ..base_context,
506                ..Value::from_serialize(kwargs)
507            }
508        } else {
509            base_context
510        };
511
512        // Render the template
513        let rendered = tmpl
514            .render(&ctx)
515            .map_err(|e| anyhow!("Failed to render template: {}", e))?;
516
517        Ok(rendered)
518    }
519}
520
521/// Load chat template from tokenizer config JSON
522pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
523    let content = fs::read_to_string(config_path)?;
524    let config: serde_json::Value = serde_json::from_str(&content)?;
525
526    // Look for chat_template in the config
527    if let Some(template) = config.get("chat_template") {
528        if let Some(template_str) = template.as_str() {
529            return Ok(Some(template_str.to_string()));
530        }
531    }
532
533    Ok(None)
534}