1use std::{collections::HashMap, fs};
7
8use anyhow::{anyhow, Result};
9use minijinja::{
10 context,
11 machinery::{
12 ast::{Expr, Stmt},
13 parse, WhitespaceConfig,
14 },
15 syntax::SyntaxConfig,
16 value::Kwargs,
17 Environment, Error as MinijinjaError, ErrorKind, Value,
18};
19use serde::Serialize;
20use serde_json::{self, ser::PrettyFormatter, Value as JsonValue};
21
/// How a chat template expects each message's `content` to be shaped.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ChatTemplateContentFormat {
    /// `content` is rendered as a plain string. This is the default.
    #[default]
    String,
    /// `content` is treated as structured data (iterated, indexed or measured),
    /// as with OpenAI-style lists of content parts.
    OpenAI,
}

impl std::fmt::Display for ChatTemplateContentFormat {
    /// Writes the lowercase name of the format (`"string"` or `"openai"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::String => "string",
            Self::OpenAI => "openai",
        };
        f.write_str(label)
    }
}
40
41pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
50 detect_format_with_ast(template)
52}
53
54#[derive(Default, Debug, Clone, Copy)]
56struct Flags {
57 saw_iteration: bool,
58 saw_structure: bool,
59 saw_assignment: bool,
60 saw_macro: bool,
61}
62
63impl Flags {
64 fn any(self) -> bool {
65 self.saw_iteration || self.saw_structure || self.saw_macro
70 }
71}
72
73struct Detector<'a> {
75 ast: &'a Stmt<'a>,
76 scope: std::collections::VecDeque<String>,
78 scope_set: std::collections::HashSet<String>,
79 flags: Flags,
80}
81
82impl<'a> Detector<'a> {
83 fn new(ast: &'a Stmt<'a>) -> Self {
84 Self {
85 ast,
86 scope: std::collections::VecDeque::new(),
87 scope_set: std::collections::HashSet::new(),
88 flags: Flags::default(),
89 }
90 }
91
92 fn run(mut self) -> Flags {
93 self.walk_stmt(self.ast);
94 self.flags
95 }
96
97 fn push_scope(&mut self, var: String) {
98 self.scope.push_back(var.clone());
99 self.scope_set.insert(var);
100 }
101
102 fn pop_scope(&mut self) {
103 if let Some(v) = self.scope.pop_back() {
104 self.scope_set.remove(&v);
105 }
106 }
107
108 fn is_var_access(expr: &Expr, varname: &str) -> bool {
109 matches!(expr, Expr::Var(v) if v.id == varname)
110 }
111
112 fn is_const_str(expr: &Expr, value: &str) -> bool {
113 matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value))
114 }
115
116 fn is_numeric_const(expr: &Expr) -> bool {
117 matches!(expr, Expr::Const(c) if c.value.is_number())
118 }
119
120 fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
122 match expr {
123 Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
124 Expr::GetItem(g) => {
125 Self::is_var_access(&g.expr, varname)
126 && Self::is_const_str(&g.subscript_expr, "content")
127 }
128 Expr::Filter(f) => f
130 .expr
131 .as_ref()
132 .is_some_and(|e| Self::is_var_dot_content(e, varname)),
133 Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
134 _ => false,
135 }
136 }
137
138 fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
140 let mut current_expr = expr;
141 loop {
142 if self
144 .scope_set
145 .iter()
146 .any(|v| Self::is_var_dot_content(current_expr, v))
147 {
148 return true;
149 }
150 match current_expr {
152 Expr::GetAttr(g) => current_expr = &g.expr,
153 Expr::GetItem(g) => current_expr = &g.expr,
154 _ => return false,
155 }
156 }
157 }
158
159 fn walk_stmt(&mut self, stmt: &Stmt) {
160 if self.flags.any() {
162 return;
163 }
164
165 match stmt {
166 Stmt::Template(t) => {
167 for ch in &t.children {
168 self.walk_stmt(ch);
169 }
170 }
171 Stmt::ForLoop(fl) => {
173 if let Expr::Var(iter) = &fl.iter {
175 if iter.id == "messages" {
176 if let Expr::Var(target) = &fl.target {
177 self.push_scope(target.id.to_string());
178 }
179 }
180 }
181
182 if self.is_any_scope_var_content(&fl.iter) {
185 self.flags.saw_iteration = true;
186 }
187 if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
189 self.flags.saw_iteration = true;
190 }
191
192 for b in &fl.body {
193 self.walk_stmt(b);
194 }
195
196 if let Expr::Var(iter) = &fl.iter {
198 if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
199 self.pop_scope();
200 }
201 }
202 }
203 Stmt::IfCond(ic) => {
204 self.inspect_expr_for_structure(&ic.expr);
205 for b in &ic.true_body {
206 self.walk_stmt(b);
207 }
208 for b in &ic.false_body {
209 self.walk_stmt(b);
210 }
211 }
212 Stmt::EmitExpr(e) => {
213 self.inspect_expr_for_structure(&e.expr);
214 }
215 Stmt::Set(s) => {
217 if Self::is_var_access(&s.target, "content")
218 && self.is_any_scope_var_content(&s.expr)
219 {
220 self.flags.saw_assignment = true;
221 }
222 }
223 Stmt::Macro(m) => {
224 let mut has_type_check = false;
226 let mut has_loop = false;
227 Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
228 if has_type_check && has_loop {
229 self.flags.saw_macro = true;
230 }
231 }
232 _ => {}
233 }
234 }
235
236 fn inspect_expr_for_structure(&mut self, expr: &Expr) {
237 if self.flags.saw_structure {
238 return;
239 }
240
241 match expr {
242 Expr::GetItem(gi) => {
244 if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
245 || self.is_any_scope_var_content(&gi.expr))
246 && Self::is_numeric_const(&gi.subscript_expr)
247 {
248 self.flags.saw_structure = true;
249 }
250 }
251 Expr::Filter(f) => {
253 if f.name == "length" {
254 if let Some(inner) = &f.expr {
255 let inner_ref: &Expr = inner;
257 let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
258 if is_content_var || self.is_any_scope_var_content(inner_ref) {
259 self.flags.saw_structure = true;
260 }
261 }
262 } else if let Some(inner) = &f.expr {
263 let inner_ref: &Expr = inner;
264 self.inspect_expr_for_structure(inner_ref);
265 }
266 }
267 Expr::Test(t) => self.inspect_expr_for_structure(&t.expr),
271 Expr::GetAttr(g) => {
272 self.inspect_expr_for_structure(&g.expr);
274 }
275 Expr::BinOp(op) => {
277 self.inspect_expr_for_structure(&op.left);
278 self.inspect_expr_for_structure(&op.right);
279 }
280 Expr::UnaryOp(op) => {
282 self.inspect_expr_for_structure(&op.expr);
283 }
284 _ => {}
285 }
286 }
287
288 fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
289 for s in body {
290 if *has_type_check && *has_loop {
291 return;
292 }
293
294 match s {
295 Stmt::IfCond(ic) => {
296 if matches!(&ic.expr, Expr::Test(_)) {
297 *has_type_check = true;
298 }
299 Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
300 Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
301 }
302 Stmt::ForLoop(fl) => {
303 *has_loop = true;
304 Self::scan_macro_body(&fl.body, has_type_check, has_loop);
305 }
306 Stmt::Template(t) => {
307 Self::scan_macro_body(&t.children, has_type_check, has_loop);
308 }
309 _ => {}
310 }
311 }
312 }
313}
314
315fn detect_format_with_ast(template: &str) -> ChatTemplateContentFormat {
318 let ast = match parse(
319 template,
320 "template",
321 SyntaxConfig {},
322 WhitespaceConfig::default(),
323 ) {
324 Ok(ast) => ast,
325 Err(_) => return ChatTemplateContentFormat::String,
326 };
327
328 let flags = Detector::new(&ast).run();
329 if flags.any() {
330 ChatTemplateContentFormat::OpenAI
331 } else {
332 ChatTemplateContentFormat::String
333 }
334}
335
336#[derive(Default)]
338pub struct ChatTemplateParams<'a> {
339 pub add_generation_prompt: bool,
340 pub tools: Option<&'a [serde_json::Value]>,
341 pub documents: Option<&'a [serde_json::Value]>,
342 pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
343 pub special_tokens: Option<&'a crate::traits::SpecialTokens>,
346}
347
348fn tojson_filter(value: Value, kwargs: Kwargs) -> std::result::Result<Value, MinijinjaError> {
360 let _ensure_ascii: Option<bool> = kwargs.get("ensure_ascii")?;
361 let indent: Option<i64> = kwargs.get("indent")?;
362 let _separators: Option<Value> = kwargs.get("separators")?;
363 let sort_keys: Option<bool> = kwargs.get("sort_keys")?;
364
365 kwargs.assert_all_used()?;
367
368 let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| {
369 MinijinjaError::new(
370 ErrorKind::InvalidOperation,
371 format!("Failed to convert to JSON value: {e}"),
372 )
373 })?;
374
375 fn serialize_with_indent<T: Serialize>(
377 value: &T,
378 spaces: usize,
379 ) -> std::result::Result<String, MinijinjaError> {
380 let indent_str = vec![b' '; spaces];
381 let formatter = PrettyFormatter::with_indent(&indent_str);
382 let mut buf = Vec::new();
383 let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
384 value.serialize(&mut serializer).map_err(|e| {
385 MinijinjaError::new(
386 ErrorKind::InvalidOperation,
387 format!("Failed to serialize JSON: {e}"),
388 )
389 })?;
390 String::from_utf8(buf).map_err(|e| {
391 MinijinjaError::new(
392 ErrorKind::InvalidOperation,
393 format!("Invalid UTF-8 in JSON output: {e}"),
394 )
395 })
396 }
397
398 let json_str: std::result::Result<String, MinijinjaError> = {
400 let sorted_json;
401 let value_to_serialize = if sort_keys.unwrap_or(false) {
402 sorted_json = sort_json_keys(&json_value);
403 &sorted_json
404 } else {
405 &json_value
406 };
407
408 if let Some(spaces) = indent {
409 if spaces < 0 {
410 return Err(MinijinjaError::new(
411 ErrorKind::InvalidOperation,
412 "indent cannot be negative",
413 ));
414 }
415 serialize_with_indent(value_to_serialize, spaces as usize)
416 } else {
417 serde_json::to_string(value_to_serialize).map_err(|e| {
418 MinijinjaError::new(
419 ErrorKind::InvalidOperation,
420 format!("Failed to serialize JSON: {e}"),
421 )
422 })
423 }
424 };
425
426 json_str.map(Value::from_safe_string)
427}
428
429fn sort_json_keys(value: &JsonValue) -> JsonValue {
431 match value {
432 JsonValue::Object(map) => {
433 let mut sorted: serde_json::Map<String, JsonValue> = serde_json::Map::new();
434 let mut keys: Vec<_> = map.keys().collect();
435 keys.sort();
436 for key in keys {
437 sorted.insert(key.clone(), sort_json_keys(&map[key]));
438 }
439 JsonValue::Object(sorted)
440 }
441 JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()),
442 _ => value.clone(),
443 }
444}
445
446fn build_environment(template: String) -> Result<Environment<'static>> {
451 let mut env = Environment::new();
452
453 env.set_trim_blocks(true);
457 env.set_lstrip_blocks(true);
458
459 env.add_template_owned("chat".to_owned(), template)
461 .map_err(|e| anyhow!("Failed to add template: {e}"))?;
462
463 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
465
466 env.add_filter("tojson", tojson_filter);
470
471 Ok(env)
472}
473
474fn special_token_value(token: Option<&str>) -> Value {
479 token.map_or(Value::UNDEFINED, Value::from)
480}
481
482fn render_chat_template(
483 env: &Environment<'_>,
484 messages: &[serde_json::Value],
485 params: ChatTemplateParams,
486) -> Result<String> {
487 let tmpl = env
488 .get_template("chat")
489 .map_err(|e| anyhow!("Failed to get template: {e}"))?;
490
491 let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
493
494 let tools_value = params.tools.map_or(Value::UNDEFINED, Value::from_serialize);
499 let documents_value = params
500 .documents
501 .map_or(Value::UNDEFINED, Value::from_serialize);
502
503 let bos_value =
507 special_token_value(params.special_tokens.and_then(|st| st.bos_token.as_deref()));
508 let eos_value =
509 special_token_value(params.special_tokens.and_then(|st| st.eos_token.as_deref()));
510 let unk_value =
511 special_token_value(params.special_tokens.and_then(|st| st.unk_token.as_deref()));
512 let pad_value =
513 special_token_value(params.special_tokens.and_then(|st| st.pad_token.as_deref()));
514
515 let base_context = context! {
516 messages => &minijinja_messages,
517 add_generation_prompt => params.add_generation_prompt,
518 tools => tools_value,
519 documents => documents_value,
520 bos_token => bos_value,
521 eos_token => eos_value,
522 unk_token => unk_value,
523 pad_token => pad_value,
524 };
525
526 let ctx = if let Some(kwargs) = params.template_kwargs {
528 context! {
529 ..base_context,
530 ..Value::from_serialize(kwargs)
531 }
532 } else {
533 base_context
534 };
535
536 let rendered = tmpl
538 .render(&ctx)
539 .map_err(|e| anyhow!("Failed to render template: {e}"))?;
540
541 Ok(rendered)
542}
543
544pub struct ChatTemplateProcessor {
546 env: Environment<'static>,
547}
548
549impl ChatTemplateProcessor {
550 pub fn new(template: String) -> Result<Self> {
556 let env = build_environment(template)?;
557 Ok(ChatTemplateProcessor { env })
558 }
559
560 pub fn apply_chat_template(
566 &self,
567 messages: &[serde_json::Value],
568 params: ChatTemplateParams,
569 ) -> Result<String> {
570 render_chat_template(&self.env, messages, params)
571 }
572}
573
574pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
576 let content = fs::read_to_string(config_path)?;
577 let config: serde_json::Value = serde_json::from_str(&content)?;
578
579 if let Some(template) = config.get("chat_template") {
581 if let Some(template_str) = template.as_str() {
582 return Ok(Some(template_str.to_string()));
583 }
584 }
585
586 Ok(None)
587}
588
589pub fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
592 let content = fs::read_to_string(template_path)
593 .map_err(|e| anyhow!("Failed to read chat template file: {e}"))?;
594
595 if template_path.ends_with(".json") {
596 let json_value: serde_json::Value = serde_json::from_str(&content)
597 .map_err(|e| anyhow!("Failed to parse chat_template.json: {e}"))?;
598
599 if let Some(template_str) = json_value.as_str() {
600 return Ok(Some(template_str.to_string()));
601 } else if let Some(obj) = json_value.as_object() {
602 if let Some(template_value) = obj.get("chat_template") {
603 if let Some(template_str) = template_value.as_str() {
604 return Ok(Some(template_str.to_string()));
605 }
606 }
607 }
608
609 return Err(anyhow!(
610 "chat_template.json does not contain a valid template",
611 ));
612 }
613
614 let template = content.trim().replace("\\n", "\n");
616 Ok(Some(template))
617}
618
619pub struct ChatTemplateState {
630 env: Option<Environment<'static>>,
632 content_format: ChatTemplateContentFormat,
633}
634
635impl std::fmt::Debug for ChatTemplateState {
636 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
637 f.debug_struct("ChatTemplateState")
638 .field("has_template", &self.env.is_some())
639 .field("content_format", &self.content_format)
640 .finish()
641 }
642}
643
644impl ChatTemplateState {
645 pub fn new(template: Option<String>) -> Result<Self> {
646 let content_format = template
647 .as_ref()
648 .map(|t| detect_chat_template_content_format(t))
649 .unwrap_or_default();
650 let env = template.map(build_environment).transpose()?;
651 Ok(Self {
652 env,
653 content_format,
654 })
655 }
656
657 pub fn empty() -> Self {
662 Self {
663 env: None,
664 content_format: ChatTemplateContentFormat::default(),
665 }
666 }
667
668 pub fn apply(
669 &self,
670 messages: &[serde_json::Value],
671 params: ChatTemplateParams,
672 ) -> Result<String> {
673 let env = self.env.as_ref().ok_or_else(|| {
674 anyhow!(
675 "Cannot use chat template functions because tokenizer.chat_template is not set \
676 and no template argument was passed! For information about writing templates and \
677 setting the tokenizer.chat_template attribute, please see the documentation at \
678 https://huggingface.co/docs/transformers/main/en/chat_templating",
679 )
680 })?;
681 render_chat_template(env, messages, params)
682 }
683
684 pub fn set(&mut self, template: String) -> Result<()> {
685 let content_format = detect_chat_template_content_format(&template);
686 let env = build_environment(template)?;
687 self.content_format = content_format;
688 self.env = Some(env);
689 Ok(())
690 }
691
692 pub fn content_format(&self) -> ChatTemplateContentFormat {
693 self.content_format
694 }
695}
696
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ChatTemplateState` from `template`, panicking if it does not compile.
    fn state_for(template: &str) -> ChatTemplateState {
        ChatTemplateState::new(Some(template.to_string())).unwrap()
    }

    #[test]
    fn test_chat_template_state_no_template() {
        let state = ChatTemplateState::new(None).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
        assert!(state.apply(&[], ChatTemplateParams::default()).is_err());
    }

    #[test]
    fn test_chat_template_state_set() {
        let mut state = ChatTemplateState::new(None).unwrap();
        state.set("{{ messages }}".to_string()).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
    }

    #[test]
    fn test_chat_template_state_invalid_template() {
        let err = ChatTemplateState::new(Some("{% invalid".to_string()))
            .expect_err("an unterminated block tag must not compile")
            .to_string();
        assert!(
            err.contains("Failed to add template"),
            "Error should explain parse failure, got: {err}"
        );
    }

    #[test]
    fn test_chat_template_processor_invalid_template() {
        assert!(ChatTemplateProcessor::new("{% invalid".to_string()).is_err());
    }

    #[test]
    fn test_special_tokens_injected_into_context() {
        let state = state_for(
            "{{ bos_token }}{% for message in messages %}{{ message.content }}{% endfor %}{{ eos_token }}",
        );
        let messages = [serde_json::json!({"role": "user", "content": "hello"})];
        let special_tokens = crate::traits::SpecialTokens {
            bos_token: Some("<s>".to_string()),
            eos_token: Some("</s>".to_string()),
            ..Default::default()
        };
        let params = ChatTemplateParams {
            special_tokens: Some(&special_tokens),
            ..Default::default()
        };

        let rendered = state.apply(&messages, params).unwrap();
        assert_eq!(rendered, "<s>hello</s>");
    }

    #[test]
    fn test_special_tokens_undefined_when_not_provided() {
        let state = state_for("{% if bos_token is defined %}{{ bos_token }}{% endif %}hello");
        let rendered = state.apply(&[], ChatTemplateParams::default()).unwrap();
        assert_eq!(rendered, "hello");
    }

    #[test]
    fn test_special_tokens_partial() {
        let state =
            state_for("{{ bos_token }}hello{% if eos_token is defined %}{{ eos_token }}{% endif %}");
        let special_tokens = crate::traits::SpecialTokens {
            bos_token: Some("<s>".to_string()),
            eos_token: None,
            ..Default::default()
        };
        let params = ChatTemplateParams {
            special_tokens: Some(&special_tokens),
            ..Default::default()
        };

        let rendered = state.apply(&[], params).unwrap();
        assert_eq!(rendered, "<s>hello");
    }
}