1use std::{collections::HashMap, fs};
7
8use anyhow::{anyhow, Result};
9use minijinja::{
10 context,
11 machinery::{
12 ast::{Expr, Stmt},
13 parse, WhitespaceConfig,
14 },
15 syntax::SyntaxConfig,
16 value::Kwargs,
17 Environment, Error as MinijinjaError, ErrorKind, Value,
18};
19use serde::Serialize;
20use serde_json::{self, ser::PrettyFormatter, Value as JsonValue};
21
/// Shape a chat template expects for each message's `content` value.
///
/// `String` is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChatTemplateContentFormat {
    /// Displayed as `string`.
    #[default]
    String,
    /// Displayed as `openai`.
    OpenAI,
}

impl std::fmt::Display for ChatTemplateContentFormat {
    /// Writes the lowercase label of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::String => "string",
            Self::OpenAI => "openai",
        };
        f.write_str(label)
    }
}
40
41pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
50 detect_format_with_ast(template)
52}
53
/// Signals collected while scanning a template; all start out `false`.
#[derive(Default, Debug, Clone, Copy)]
struct Flags {
    /// A `for` loop iterates over message content.
    saw_iteration: bool,
    /// Content is indexed numerically or passed through `length`.
    saw_structure: bool,
    /// `content` is assigned from a message's content.
    saw_assignment: bool,
    /// A macro contains both a test-based condition and a loop.
    saw_macro: bool,
}

impl Flags {
    /// Returns `true` as soon as at least one signal is set.
    fn any(self) -> bool {
        let Self {
            saw_iteration,
            saw_structure,
            saw_assignment,
            saw_macro,
        } = self;
        [saw_iteration, saw_structure, saw_assignment, saw_macro].contains(&true)
    }
}
68
69struct Detector<'a> {
71 ast: &'a Stmt<'a>,
72 scope: std::collections::VecDeque<String>,
74 scope_set: std::collections::HashSet<String>,
75 flags: Flags,
76}
77
78impl<'a> Detector<'a> {
79 fn new(ast: &'a Stmt<'a>) -> Self {
80 Self {
81 ast,
82 scope: std::collections::VecDeque::new(),
83 scope_set: std::collections::HashSet::new(),
84 flags: Flags::default(),
85 }
86 }
87
88 fn run(mut self) -> Flags {
89 self.walk_stmt(self.ast);
90 self.flags
91 }
92
93 fn push_scope(&mut self, var: String) {
94 self.scope.push_back(var.clone());
95 self.scope_set.insert(var);
96 }
97
98 fn pop_scope(&mut self) {
99 if let Some(v) = self.scope.pop_back() {
100 self.scope_set.remove(&v);
101 }
102 }
103
104 fn is_var_access(expr: &Expr, varname: &str) -> bool {
105 matches!(expr, Expr::Var(v) if v.id == varname)
106 }
107
108 fn is_const_str(expr: &Expr, value: &str) -> bool {
109 matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value))
110 }
111
112 fn is_numeric_const(expr: &Expr) -> bool {
113 matches!(expr, Expr::Const(c) if c.value.is_number())
114 }
115
116 fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
118 match expr {
119 Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
120 Expr::GetItem(g) => {
121 Self::is_var_access(&g.expr, varname)
122 && Self::is_const_str(&g.subscript_expr, "content")
123 }
124 Expr::Filter(f) => f
126 .expr
127 .as_ref()
128 .is_some_and(|e| Self::is_var_dot_content(e, varname)),
129 Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
130 _ => false,
131 }
132 }
133
134 fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
136 let mut current_expr = expr;
137 loop {
138 if self
140 .scope_set
141 .iter()
142 .any(|v| Self::is_var_dot_content(current_expr, v))
143 {
144 return true;
145 }
146 match current_expr {
148 Expr::GetAttr(g) => current_expr = &g.expr,
149 Expr::GetItem(g) => current_expr = &g.expr,
150 _ => return false,
151 }
152 }
153 }
154
155 fn walk_stmt(&mut self, stmt: &Stmt) {
156 if self.flags.any() {
158 return;
159 }
160
161 match stmt {
162 Stmt::Template(t) => {
163 for ch in &t.children {
164 self.walk_stmt(ch);
165 }
166 }
167 Stmt::ForLoop(fl) => {
169 if let Expr::Var(iter) = &fl.iter {
171 if iter.id == "messages" {
172 if let Expr::Var(target) = &fl.target {
173 self.push_scope(target.id.to_string());
174 }
175 }
176 }
177
178 if self.is_any_scope_var_content(&fl.iter) {
181 self.flags.saw_iteration = true;
182 }
183 if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
185 self.flags.saw_iteration = true;
186 }
187
188 for b in &fl.body {
189 self.walk_stmt(b);
190 }
191
192 if let Expr::Var(iter) = &fl.iter {
194 if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
195 self.pop_scope();
196 }
197 }
198 }
199 Stmt::IfCond(ic) => {
200 self.inspect_expr_for_structure(&ic.expr);
201 for b in &ic.true_body {
202 self.walk_stmt(b);
203 }
204 for b in &ic.false_body {
205 self.walk_stmt(b);
206 }
207 }
208 Stmt::EmitExpr(e) => {
209 self.inspect_expr_for_structure(&e.expr);
210 }
211 Stmt::Set(s) => {
213 if Self::is_var_access(&s.target, "content")
214 && self.is_any_scope_var_content(&s.expr)
215 {
216 self.flags.saw_assignment = true;
217 }
218 }
219 Stmt::Macro(m) => {
220 let mut has_type_check = false;
222 let mut has_loop = false;
223 Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
224 if has_type_check && has_loop {
225 self.flags.saw_macro = true;
226 }
227 }
228 _ => {}
229 }
230 }
231
232 fn inspect_expr_for_structure(&mut self, expr: &Expr) {
233 if self.flags.saw_structure {
234 return;
235 }
236
237 match expr {
238 Expr::GetItem(gi) => {
240 if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
241 || self.is_any_scope_var_content(&gi.expr))
242 && Self::is_numeric_const(&gi.subscript_expr)
243 {
244 self.flags.saw_structure = true;
245 }
246 }
247 Expr::Filter(f) => {
249 if f.name == "length" {
250 if let Some(inner) = &f.expr {
251 let inner_ref: &Expr = inner;
253 let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
254 if is_content_var || self.is_any_scope_var_content(inner_ref) {
255 self.flags.saw_structure = true;
256 }
257 }
258 } else if let Some(inner) = &f.expr {
259 let inner_ref: &Expr = inner;
260 self.inspect_expr_for_structure(inner_ref);
261 }
262 }
263 Expr::Test(t) => self.inspect_expr_for_structure(&t.expr),
267 Expr::GetAttr(g) => {
268 self.inspect_expr_for_structure(&g.expr);
270 }
271 Expr::BinOp(op) => {
273 self.inspect_expr_for_structure(&op.left);
274 self.inspect_expr_for_structure(&op.right);
275 }
276 Expr::UnaryOp(op) => {
278 self.inspect_expr_for_structure(&op.expr);
279 }
280 _ => {}
281 }
282 }
283
284 fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
285 for s in body {
286 if *has_type_check && *has_loop {
287 return;
288 }
289
290 match s {
291 Stmt::IfCond(ic) => {
292 if matches!(&ic.expr, Expr::Test(_)) {
293 *has_type_check = true;
294 }
295 Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
296 Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
297 }
298 Stmt::ForLoop(fl) => {
299 *has_loop = true;
300 Self::scan_macro_body(&fl.body, has_type_check, has_loop);
301 }
302 Stmt::Template(t) => {
303 Self::scan_macro_body(&t.children, has_type_check, has_loop);
304 }
305 _ => {}
306 }
307 }
308 }
309}
310
311fn detect_format_with_ast(template: &str) -> ChatTemplateContentFormat {
314 let ast = match parse(
315 template,
316 "template",
317 SyntaxConfig {},
318 WhitespaceConfig::default(),
319 ) {
320 Ok(ast) => ast,
321 Err(_) => return ChatTemplateContentFormat::String,
322 };
323
324 let flags = Detector::new(&ast).run();
325 if flags.any() {
326 ChatTemplateContentFormat::OpenAI
327 } else {
328 ChatTemplateContentFormat::String
329 }
330}
331
332#[derive(Default)]
334pub struct ChatTemplateParams<'a> {
335 pub add_generation_prompt: bool,
336 pub tools: Option<&'a [serde_json::Value]>,
337 pub documents: Option<&'a [serde_json::Value]>,
338 pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
339}
340
341fn tojson_filter(value: Value, kwargs: Kwargs) -> std::result::Result<Value, MinijinjaError> {
353 let _ensure_ascii: Option<bool> = kwargs.get("ensure_ascii")?;
354 let indent: Option<i64> = kwargs.get("indent")?;
355 let _separators: Option<Value> = kwargs.get("separators")?;
356 let sort_keys: Option<bool> = kwargs.get("sort_keys")?;
357
358 kwargs.assert_all_used()?;
360
361 let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| {
362 MinijinjaError::new(
363 ErrorKind::InvalidOperation,
364 format!("Failed to convert to JSON value: {e}"),
365 )
366 })?;
367
368 fn serialize_with_indent<T: Serialize>(
370 value: &T,
371 spaces: usize,
372 ) -> std::result::Result<String, MinijinjaError> {
373 let indent_str = vec![b' '; spaces];
374 let formatter = PrettyFormatter::with_indent(&indent_str);
375 let mut buf = Vec::new();
376 let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
377 value.serialize(&mut serializer).map_err(|e| {
378 MinijinjaError::new(
379 ErrorKind::InvalidOperation,
380 format!("Failed to serialize JSON: {e}"),
381 )
382 })?;
383 String::from_utf8(buf).map_err(|e| {
384 MinijinjaError::new(
385 ErrorKind::InvalidOperation,
386 format!("Invalid UTF-8 in JSON output: {e}"),
387 )
388 })
389 }
390
391 let json_str: std::result::Result<String, MinijinjaError> = {
393 let sorted_json;
394 let value_to_serialize = if sort_keys.unwrap_or(false) {
395 sorted_json = sort_json_keys(&json_value);
396 &sorted_json
397 } else {
398 &json_value
399 };
400
401 if let Some(spaces) = indent {
402 if spaces < 0 {
403 return Err(MinijinjaError::new(
404 ErrorKind::InvalidOperation,
405 "indent cannot be negative",
406 ));
407 }
408 serialize_with_indent(value_to_serialize, spaces as usize)
409 } else {
410 serde_json::to_string(value_to_serialize).map_err(|e| {
411 MinijinjaError::new(
412 ErrorKind::InvalidOperation,
413 format!("Failed to serialize JSON: {e}"),
414 )
415 })
416 }
417 };
418
419 json_str.map(Value::from_safe_string)
420}
421
422fn sort_json_keys(value: &JsonValue) -> JsonValue {
424 match value {
425 JsonValue::Object(map) => {
426 let mut sorted: serde_json::Map<String, JsonValue> = serde_json::Map::new();
427 let mut keys: Vec<_> = map.keys().collect();
428 keys.sort();
429 for key in keys {
430 sorted.insert(key.clone(), sort_json_keys(&map[key]));
431 }
432 JsonValue::Object(sorted)
433 }
434 JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()),
435 _ => value.clone(),
436 }
437}
438
439fn build_environment(template: String) -> Result<Environment<'static>> {
444 let mut env = Environment::new();
445
446 env.add_template_owned("chat".to_owned(), template)
448 .map_err(|e| anyhow!("Failed to add template: {e}"))?;
449
450 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
452
453 env.add_filter("tojson", tojson_filter);
457
458 Ok(env)
459}
460
461fn render_chat_template(
463 env: &Environment<'_>,
464 messages: &[serde_json::Value],
465 params: ChatTemplateParams,
466) -> Result<String> {
467 let tmpl = env
468 .get_template("chat")
469 .map_err(|e| anyhow!("Failed to get template: {e}"))?;
470
471 let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
473
474 let base_context = context! {
475 messages => &minijinja_messages,
476 add_generation_prompt => params.add_generation_prompt,
477 tools => params.tools,
478 documents => params.documents,
479 };
480
481 let ctx = if let Some(kwargs) = params.template_kwargs {
483 context! {
484 ..base_context,
485 ..Value::from_serialize(kwargs)
486 }
487 } else {
488 base_context
489 };
490
491 let rendered = tmpl
493 .render(&ctx)
494 .map_err(|e| anyhow!("Failed to render template: {e}"))?;
495
496 Ok(rendered)
497}
498
499pub struct ChatTemplateProcessor {
501 env: Environment<'static>,
502}
503
504impl ChatTemplateProcessor {
505 pub fn new(template: String) -> Result<Self> {
511 let env = build_environment(template)?;
512 Ok(ChatTemplateProcessor { env })
513 }
514
515 pub fn apply_chat_template(
521 &self,
522 messages: &[serde_json::Value],
523 params: ChatTemplateParams,
524 ) -> Result<String> {
525 render_chat_template(&self.env, messages, params)
526 }
527}
528
529pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
531 let content = fs::read_to_string(config_path)?;
532 let config: serde_json::Value = serde_json::from_str(&content)?;
533
534 if let Some(template) = config.get("chat_template") {
536 if let Some(template_str) = template.as_str() {
537 return Ok(Some(template_str.to_string()));
538 }
539 }
540
541 Ok(None)
542}
543
544pub fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
547 let content = fs::read_to_string(template_path)
548 .map_err(|e| anyhow!("Failed to read chat template file: {e}"))?;
549
550 if template_path.ends_with(".json") {
551 let json_value: serde_json::Value = serde_json::from_str(&content)
552 .map_err(|e| anyhow!("Failed to parse chat_template.json: {e}"))?;
553
554 if let Some(template_str) = json_value.as_str() {
555 return Ok(Some(template_str.to_string()));
556 } else if let Some(obj) = json_value.as_object() {
557 if let Some(template_value) = obj.get("chat_template") {
558 if let Some(template_str) = template_value.as_str() {
559 return Ok(Some(template_str.to_string()));
560 }
561 }
562 }
563
564 return Err(anyhow!(
565 "chat_template.json does not contain a valid template",
566 ));
567 }
568
569 let template = content.trim().replace("\\n", "\n");
571 Ok(Some(template))
572}
573
574pub struct ChatTemplateState {
585 env: Option<Environment<'static>>,
587 content_format: ChatTemplateContentFormat,
588}
589
590impl std::fmt::Debug for ChatTemplateState {
591 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592 f.debug_struct("ChatTemplateState")
593 .field("has_template", &self.env.is_some())
594 .field("content_format", &self.content_format)
595 .finish()
596 }
597}
598
599impl ChatTemplateState {
600 pub fn new(template: Option<String>) -> Result<Self> {
601 let content_format = template
602 .as_ref()
603 .map(|t| detect_chat_template_content_format(t))
604 .unwrap_or_default();
605 let env = template.map(build_environment).transpose()?;
606 Ok(Self {
607 env,
608 content_format,
609 })
610 }
611
612 pub fn empty() -> Self {
617 Self {
618 env: None,
619 content_format: ChatTemplateContentFormat::default(),
620 }
621 }
622
623 pub fn apply(
624 &self,
625 messages: &[serde_json::Value],
626 params: ChatTemplateParams,
627 ) -> Result<String> {
628 let env = self.env.as_ref().ok_or_else(|| {
629 anyhow!(
630 "Cannot use chat template functions because tokenizer.chat_template is not set \
631 and no template argument was passed! For information about writing templates and \
632 setting the tokenizer.chat_template attribute, please see the documentation at \
633 https://huggingface.co/docs/transformers/main/en/chat_templating",
634 )
635 })?;
636 render_chat_template(env, messages, params)
637 }
638
639 pub fn set(&mut self, template: String) -> Result<()> {
640 let content_format = detect_chat_template_content_format(&template);
641 let env = build_environment(template)?;
642 self.content_format = content_format;
643 self.env = Some(env);
644 Ok(())
645 }
646
647 pub fn content_format(&self) -> ChatTemplateContentFormat {
648 self.content_format
649 }
650}
651
#[cfg(test)]
mod tests {
    use super::*;

    /// Without a template the format is the default and rendering is refused.
    #[test]
    fn test_chat_template_state_no_template() {
        let state = ChatTemplateState::new(None).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);

        let no_messages: &[serde_json::Value] = &[];
        assert!(state
            .apply(no_messages, ChatTemplateParams::default())
            .is_err());
    }

    /// A template that only emits `messages` is detected as the string format.
    #[test]
    fn test_chat_template_state_set() {
        let mut state = ChatTemplateState::new(None).unwrap();
        let template = String::from("{{ messages }}");
        state.set(template).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
    }

    /// A syntactically broken template is rejected with a descriptive error.
    #[test]
    fn test_chat_template_state_invalid_template() {
        let broken = String::from("{% invalid");
        let result = ChatTemplateState::new(Some(broken));
        assert!(result.is_err());

        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Failed to add template"),
            "Error should explain parse failure, got: {err}"
        );
    }

    /// The processor constructor rejects a broken template as well.
    #[test]
    fn test_chat_template_processor_invalid_template() {
        let broken = String::from("{% invalid");
        assert!(ChatTemplateProcessor::new(broken).is_err());
    }
}
687}