1use std::{collections::HashMap, fs, io};
7
8use anyhow::{anyhow, Result};
9use minijinja::{
10 context,
11 machinery::{
12 ast::{Expr, Stmt},
13 parse, WhitespaceConfig,
14 },
15 syntax::SyntaxConfig,
16 value::Kwargs,
17 Environment, Error as MinijinjaError, ErrorKind, Value,
18};
19use serde::Serialize;
20use serde_json::{self, ser::Formatter, Value as JsonValue};
21
/// Shape in which a chat template consumes `message.content`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChatTemplateContentFormat {
    /// Content is rendered as a plain string (the default).
    #[default]
    String,
    /// Content is treated as structured data: the template iterates over it,
    /// indexes into it, or inspects it inside a type-checking macro.
    OpenAI,
}

impl std::fmt::Display for ChatTemplateContentFormat {
    /// Writes the lowercase format name: `string` or `openai`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::String => "string",
            Self::OpenAI => "openai",
        };
        f.write_str(label)
    }
}
40
/// Name of the template variable that switches reasoning ("thinking") output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingKeyName {
    /// The template reads `enable_thinking`.
    EnableThinking,
    /// The template reads `thinking`.
    Thinking,
}

/// Whether a template exposes a thinking toggle, and which way it defaults.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ThinkingToggle {
    /// No thinking variable was found in the template.
    #[default]
    None,
    /// A thinking variable is used and the template never assigns it `false`.
    DefaultOn,
    /// The template itself assigns the thinking variable `false`.
    DefaultOff,
}

/// Heuristically detects a thinking toggle in raw chat-template source.
///
/// Returns `(ThinkingToggle::None, None)` when the source mentions neither
/// `enable_thinking` nor a `thinking` variable usage. Otherwise the key name
/// is `EnableThinking` if `enable_thinking` appears anywhere, else `Thinking`.
/// The toggle is `DefaultOff` when the source contains a Jinja assignment
/// `set thinking = false` or `set enable_thinking = false`, where `False` and
/// any whitespace around `=` are accepted. Otherwise it is `DefaultOn`.
pub fn detect_thinking_toggle(template: &str) -> (ThinkingToggle, Option<ThinkingKeyName>) {
    let has_enable_thinking = template.contains("enable_thinking");
    let has_thinking_var = template.contains("if thinking ")
        || template.contains("thinking is ")
        || template.contains("thinking ==")
        || template.contains("set thinking ");

    if !has_enable_thinking && !has_thinking_var {
        return (ThinkingToggle::None, None);
    }

    // `enable_thinking` wins when both spellings are present.
    let key_name = if has_enable_thinking {
        ThinkingKeyName::EnableThinking
    } else {
        ThinkingKeyName::Thinking
    };

    let defaults_off =
        assigns_false(template, "thinking") || assigns_false(template, "enable_thinking");
    let toggle = if defaults_off {
        ThinkingToggle::DefaultOff
    } else {
        ThinkingToggle::DefaultOn
    };
    (toggle, Some(key_name))
}

/// Returns `true` if `template` contains `set <var> = false` (or `False`).
///
/// `set` must be a standalone word followed by whitespace. Any whitespace is
/// allowed around `=`. The boolean literal must not continue into a longer
/// identifier, so `false_default` does not count.
fn assigns_false(template: &str, var: &str) -> bool {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    template.match_indices("set").any(|(start, keyword)| {
        // Reject `set` embedded in a longer identifier such as `reset`.
        if matches!(template[..start].chars().next_back(), Some(c) if is_ident_char(c)) {
            return false;
        }
        let after_keyword = &template[start + keyword.len()..];
        let after_space = after_keyword.trim_start();
        // `set` must be followed by at least one whitespace character.
        if after_space.len() == after_keyword.len() {
            return false;
        }
        let Some(after_var) = after_space.strip_prefix(var) else {
            return false;
        };
        let Some(value) = after_var.trim_start().strip_prefix('=') else {
            return false;
        };
        let value = value.trim_start();
        ["false", "False"].iter().any(|&literal| {
            value
                .strip_prefix(literal)
                .is_some_and(|rest| !rest.starts_with(is_ident_char))
        })
    })
}
103
104pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
113 detect_format_with_ast(template)
115}
116
/// Signals collected while walking a template AST.
#[derive(Debug, Default, Clone, Copy)]
struct Flags {
    /// A `for` loop iterates over message content.
    saw_iteration: bool,
    /// Message content is indexed numerically or measured with `length`.
    saw_structure: bool,
    /// `content` is assigned from a message's content.
    saw_assignment: bool,
    /// A macro contains both a test-based `if` and a `for` loop.
    saw_macro: bool,
}

impl Flags {
    /// Returns `true` when a signal that selects the structured format was seen.
    ///
    /// `saw_assignment` is not consulted; it alone does not select the
    /// structured format.
    fn any(self) -> bool {
        let Flags {
            saw_iteration,
            saw_structure,
            saw_macro,
            ..
        } = self;
        saw_iteration || saw_structure || saw_macro
    }
}
135
/// AST walker that infers a chat template's content format and whether it
/// pre-fills a `<think>` tag in the generation prompt.
struct Detector<'a> {
    /// Root statement the walk starts from.
    ast: &'a Stmt<'a>,
    /// Target names of enclosing `for ... in messages` loops, innermost last.
    scope: std::collections::VecDeque<String>,
    /// Names from `scope`, kept in a set for lookups.
    scope_set: std::collections::HashSet<String>,
    /// Signals collected so far.
    flags: Flags,
    /// Set once an `add_generation_prompt` branch emits a literal `<think>`.
    think_in_prefill: bool,
}
146
147impl<'a> Detector<'a> {
148 fn new(ast: &'a Stmt<'a>) -> Self {
149 Self {
150 ast,
151 scope: std::collections::VecDeque::new(),
152 scope_set: std::collections::HashSet::new(),
153 flags: Flags::default(),
154 think_in_prefill: false,
155 }
156 }
157
158 fn run(mut self) -> (Flags, bool) {
159 self.walk_stmt(self.ast);
160 (self.flags, self.think_in_prefill)
161 }
162
163 fn push_scope(&mut self, var: String) {
164 self.scope.push_back(var.clone());
165 self.scope_set.insert(var);
166 }
167
168 fn pop_scope(&mut self) {
169 if let Some(v) = self.scope.pop_back() {
170 self.scope_set.remove(&v);
171 }
172 }
173
174 fn is_var_access(expr: &Expr, varname: &str) -> bool {
175 matches!(expr, Expr::Var(v) if v.id == varname)
176 }
177
178 fn is_const_str(expr: &Expr, value: &str) -> bool {
179 matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value))
180 }
181
182 fn is_numeric_const(expr: &Expr) -> bool {
183 matches!(expr, Expr::Const(c) if c.value.is_number())
184 }
185
186 fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
188 match expr {
189 Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
190 Expr::GetItem(g) => {
191 Self::is_var_access(&g.expr, varname)
192 && Self::is_const_str(&g.subscript_expr, "content")
193 }
194 Expr::Filter(f) => f
196 .expr
197 .as_ref()
198 .is_some_and(|e| Self::is_var_dot_content(e, varname)),
199 Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
200 _ => false,
201 }
202 }
203
204 fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
206 let mut current_expr = expr;
207 loop {
208 if self
210 .scope_set
211 .iter()
212 .any(|v| Self::is_var_dot_content(current_expr, v))
213 {
214 return true;
215 }
216 match current_expr {
218 Expr::GetAttr(g) => current_expr = &g.expr,
219 Expr::GetItem(g) => current_expr = &g.expr,
220 _ => return false,
221 }
222 }
223 }
224
225 fn expr_references_var(expr: &Expr, name: &str) -> bool {
227 match expr {
228 Expr::Var(v) => v.id == name,
229 Expr::BinOp(b) => {
230 Self::expr_references_var(&b.left, name)
231 || Self::expr_references_var(&b.right, name)
232 }
233 Expr::UnaryOp(u) => Self::expr_references_var(&u.expr, name),
234 _ => false,
235 }
236 }
237
238 fn body_has_think_tag(stmts: &[Stmt]) -> bool {
240 for stmt in stmts {
241 match stmt {
242 Stmt::EmitRaw(raw) if raw.raw.contains("<think>") => return true,
243 Stmt::EmitExpr(e) => {
244 if let Expr::Const(c) = &e.expr {
245 if c.value.as_str().is_some_and(|s| s.contains("<think>")) {
246 return true;
247 }
248 }
249 }
250 Stmt::IfCond(ic)
251 if Self::body_has_think_tag(&ic.true_body)
252 || Self::body_has_think_tag(&ic.false_body) =>
253 {
254 return true;
255 }
256 _ => {}
257 }
258 }
259 false
260 }
261
262 fn walk_stmt(&mut self, stmt: &Stmt) {
263 match stmt {
264 Stmt::Template(t) => {
265 for ch in &t.children {
266 self.walk_stmt(ch);
267 }
268 }
269 Stmt::ForLoop(fl) => {
271 if let Expr::Var(iter) = &fl.iter {
273 if iter.id == "messages" {
274 if let Expr::Var(target) = &fl.target {
275 self.push_scope(target.id.to_string());
276 }
277 }
278 }
279
280 if self.is_any_scope_var_content(&fl.iter) {
283 self.flags.saw_iteration = true;
284 }
285 if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
287 self.flags.saw_iteration = true;
288 }
289
290 for b in &fl.body {
291 self.walk_stmt(b);
292 }
293
294 if let Expr::Var(iter) = &fl.iter {
296 if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
297 self.pop_scope();
298 }
299 }
300 }
301 Stmt::IfCond(ic) => {
302 self.inspect_expr_for_structure(&ic.expr);
303
304 if !self.think_in_prefill
306 && Self::expr_references_var(&ic.expr, "add_generation_prompt")
307 {
308 self.think_in_prefill = Self::body_has_think_tag(&ic.true_body);
309 }
310
311 for b in &ic.true_body {
312 self.walk_stmt(b);
313 }
314 for b in &ic.false_body {
315 self.walk_stmt(b);
316 }
317 }
318 Stmt::EmitExpr(e) => {
319 self.inspect_expr_for_structure(&e.expr);
320 }
321 Stmt::Set(s)
323 if Self::is_var_access(&s.target, "content")
324 && self.is_any_scope_var_content(&s.expr) =>
325 {
326 self.flags.saw_assignment = true;
327 }
328 Stmt::Macro(m) => {
329 let mut has_type_check = false;
331 let mut has_loop = false;
332 Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
333 if has_type_check && has_loop {
334 self.flags.saw_macro = true;
335 }
336 }
337 _ => {}
338 }
339 }
340
341 fn inspect_expr_for_structure(&mut self, expr: &Expr) {
342 if self.flags.saw_structure {
343 return;
344 }
345
346 match expr {
347 Expr::GetItem(gi)
349 if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
350 || self.is_any_scope_var_content(&gi.expr))
351 && Self::is_numeric_const(&gi.subscript_expr) =>
352 {
353 self.flags.saw_structure = true;
354 }
355 Expr::Filter(f) => {
357 if f.name == "length" {
358 if let Some(inner) = &f.expr {
359 let inner_ref: &Expr = inner;
361 let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
362 if is_content_var || self.is_any_scope_var_content(inner_ref) {
363 self.flags.saw_structure = true;
364 }
365 }
366 } else if let Some(inner) = &f.expr {
367 let inner_ref: &Expr = inner;
368 self.inspect_expr_for_structure(inner_ref);
369 }
370 }
371 Expr::Test(t) => self.inspect_expr_for_structure(&t.expr),
375 Expr::GetAttr(g) => {
376 self.inspect_expr_for_structure(&g.expr);
378 }
379 Expr::BinOp(op) => {
381 self.inspect_expr_for_structure(&op.left);
382 self.inspect_expr_for_structure(&op.right);
383 }
384 Expr::UnaryOp(op) => {
386 self.inspect_expr_for_structure(&op.expr);
387 }
388 _ => {}
389 }
390 }
391
392 fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
393 for s in body {
394 if *has_type_check && *has_loop {
395 return;
396 }
397
398 match s {
399 Stmt::IfCond(ic) => {
400 if matches!(&ic.expr, Expr::Test(_)) {
401 *has_type_check = true;
402 }
403 Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
404 Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
405 }
406 Stmt::ForLoop(fl) => {
407 *has_loop = true;
408 Self::scan_macro_body(&fl.body, has_type_check, has_loop);
409 }
410 Stmt::Template(t) => {
411 Self::scan_macro_body(&t.children, has_type_check, has_loop);
412 }
413 _ => {}
414 }
415 }
416 }
417}
418
419fn detect_format_with_ast(template: &str) -> ChatTemplateContentFormat {
422 detect_all_with_ast(template).0
423}
424
425fn detect_all(
427 template: &str,
428) -> (
429 ChatTemplateContentFormat,
430 bool,
431 ThinkingToggle,
432 Option<ThinkingKeyName>,
433) {
434 let (thinking_toggle, thinking_key_name) = detect_thinking_toggle(template);
435 let (content_format, think_in_prefill) = detect_all_with_ast(template);
436 (
437 content_format,
438 think_in_prefill,
439 thinking_toggle,
440 thinking_key_name,
441 )
442}
443
444fn detect_all_with_ast(template: &str) -> (ChatTemplateContentFormat, bool) {
446 let ast = match parse(
447 template,
448 "template",
449 SyntaxConfig {},
450 WhitespaceConfig::default(),
451 ) {
452 Ok(ast) => ast,
453 Err(_) => return (ChatTemplateContentFormat::String, false),
454 };
455
456 let (flags, think_in_prefill) = Detector::new(&ast).run();
457 let content_format = if flags.any() {
458 ChatTemplateContentFormat::OpenAI
459 } else {
460 ChatTemplateContentFormat::String
461 };
462 (content_format, think_in_prefill)
463}
464
/// Per-request inputs used when rendering a chat template.
#[derive(Default)]
pub struct ChatTemplateParams<'a> {
    /// Exposed to the template as `add_generation_prompt`.
    pub add_generation_prompt: bool,
    /// Exposed as `tools`; undefined in the template when `None`.
    pub tools: Option<&'a [serde_json::Value]>,
    /// Exposed as `documents`; undefined in the template when `None`.
    pub documents: Option<&'a [serde_json::Value]>,
    /// Extra variables merged into the render context alongside the built-in ones.
    pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
    /// Source of `bos_token` / `eos_token` / `unk_token` / `pad_token`; each is
    /// undefined in the template when absent.
    pub special_tokens: Option<&'a crate::traits::SpecialTokens>,
}
476
/// Raw bytes written between JSON items and between keys and values.
#[derive(Debug, Clone)]
struct JsonSeparators {
    /// Written between array elements and between object entries.
    item: Vec<u8>,
    /// Written between an object key and its value.
    key: Vec<u8>,
}

impl JsonSeparators {
    /// Python `json.dumps` defaults. The item separator is `", "` without an
    /// indent and `","` with any indent (including 0). The key separator is
    /// always `": "`.
    fn python_default(indent: Option<i64>) -> Self {
        let item: &[u8] = match indent {
            Some(_) => b",",
            None => b", ",
        };
        Self {
            item: item.to_vec(),
            key: b": ".to_vec(),
        }
    }
}
495
496#[derive(Debug, Clone)]
498struct PythonJsonFormatter {
499 current_indent: usize,
500 has_value: bool,
501 indent: Option<Vec<u8>>,
502 separators: JsonSeparators,
503 ensure_ascii: bool,
504}
505
506impl PythonJsonFormatter {
507 fn new(indent: Option<usize>, separators: JsonSeparators, ensure_ascii: bool) -> Self {
508 Self {
509 current_indent: 0,
510 has_value: false,
511 indent: indent.map(|spaces| vec![b' '; spaces]),
512 separators,
513 ensure_ascii,
514 }
515 }
516}
517
/// Writes `indent` to `writer` `count` times. Nothing is written when
/// `count` is 0.
fn write_indent<W>(writer: &mut W, count: usize, indent: &[u8]) -> io::Result<()>
where
    W: ?Sized + io::Write,
{
    (0..count).try_for_each(|_| writer.write_all(indent))
}
527
/// Writes `code` as a six-byte JSON escape `\uXXXX` with lowercase hex
/// digits, in a single `write_all` call.
fn write_u_escape<W>(writer: &mut W, code: u16) -> io::Result<()>
where
    W: ?Sized + io::Write,
{
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let nibble = |shift: u32| HEX_DIGITS[usize::from((code >> shift) & 0xF)];
    let escape = [b'\\', b'u', nibble(12), nibble(8), nibble(4), nibble(0)];
    writer.write_all(&escape)
}
542
543impl Formatter for PythonJsonFormatter {
544 fn write_string_fragment<W>(&mut self, writer: &mut W, fragment: &str) -> io::Result<()>
545 where
546 W: ?Sized + io::Write,
547 {
548 if !self.ensure_ascii {
549 return writer.write_all(fragment.as_bytes());
550 }
551
552 for ch in fragment.chars() {
553 if ch.is_ascii() {
554 let mut buf = [0; 4];
555 writer.write_all(ch.encode_utf8(&mut buf).as_bytes())?;
556 continue;
557 }
558
559 let code = ch as u32;
560 if code <= 0xFFFF {
561 write_u_escape(writer, code as u16)?;
562 } else {
563 let shifted = code - 0x1_0000;
564 let high = 0xD800 + ((shifted >> 10) as u16);
565 let low = 0xDC00 + ((shifted & 0x3FF) as u16);
566 write_u_escape(writer, high)?;
567 write_u_escape(writer, low)?;
568 }
569 }
570 Ok(())
571 }
572
573 fn begin_array<W>(&mut self, writer: &mut W) -> io::Result<()>
574 where
575 W: ?Sized + io::Write,
576 {
577 if self.indent.is_some() {
578 self.current_indent += 1;
579 self.has_value = false;
580 }
581 writer.write_all(b"[")
582 }
583
584 fn end_array<W>(&mut self, writer: &mut W) -> io::Result<()>
585 where
586 W: ?Sized + io::Write,
587 {
588 if let Some(indent) = self.indent.as_deref() {
589 self.current_indent -= 1;
590 if self.has_value {
591 writer.write_all(b"\n")?;
592 write_indent(writer, self.current_indent, indent)?;
593 }
594 }
595 writer.write_all(b"]")
596 }
597
598 fn begin_array_value<W>(&mut self, writer: &mut W, first: bool) -> io::Result<()>
599 where
600 W: ?Sized + io::Write,
601 {
602 if let Some(indent) = self.indent.as_deref() {
603 if first {
604 writer.write_all(b"\n")?;
605 } else {
606 writer.write_all(&self.separators.item)?;
607 writer.write_all(b"\n")?;
608 }
609 write_indent(writer, self.current_indent, indent)
610 } else if first {
611 Ok(())
612 } else {
613 writer.write_all(&self.separators.item)
614 }
615 }
616
617 fn end_array_value<W>(&mut self, _writer: &mut W) -> io::Result<()>
618 where
619 W: ?Sized + io::Write,
620 {
621 self.has_value = true;
622 Ok(())
623 }
624
625 fn begin_object<W>(&mut self, writer: &mut W) -> io::Result<()>
626 where
627 W: ?Sized + io::Write,
628 {
629 if self.indent.is_some() {
630 self.current_indent += 1;
631 self.has_value = false;
632 }
633 writer.write_all(b"{")
634 }
635
636 fn end_object<W>(&mut self, writer: &mut W) -> io::Result<()>
637 where
638 W: ?Sized + io::Write,
639 {
640 if let Some(indent) = self.indent.as_deref() {
641 self.current_indent -= 1;
642 if self.has_value {
643 writer.write_all(b"\n")?;
644 write_indent(writer, self.current_indent, indent)?;
645 }
646 }
647 writer.write_all(b"}")
648 }
649
650 fn begin_object_key<W>(&mut self, writer: &mut W, first: bool) -> io::Result<()>
651 where
652 W: ?Sized + io::Write,
653 {
654 if let Some(indent) = self.indent.as_deref() {
655 if first {
656 writer.write_all(b"\n")?;
657 } else {
658 writer.write_all(&self.separators.item)?;
659 writer.write_all(b"\n")?;
660 }
661 write_indent(writer, self.current_indent, indent)
662 } else if first {
663 Ok(())
664 } else {
665 writer.write_all(&self.separators.item)
666 }
667 }
668
669 fn begin_object_value<W>(&mut self, writer: &mut W) -> io::Result<()>
670 where
671 W: ?Sized + io::Write,
672 {
673 writer.write_all(&self.separators.key)
674 }
675
676 fn end_object_value<W>(&mut self, _writer: &mut W) -> io::Result<()>
677 where
678 W: ?Sized + io::Write,
679 {
680 self.has_value = true;
681 Ok(())
682 }
683}
684
685fn invalid_tojson_option(message: impl Into<String>) -> MinijinjaError {
686 MinijinjaError::new(ErrorKind::InvalidOperation, message.into())
687}
688
689fn parse_separators(
690 separators: Option<Value>,
691 indent: Option<i64>,
692) -> std::result::Result<JsonSeparators, MinijinjaError> {
693 let Some(separators) = separators else {
694 return Ok(JsonSeparators::python_default(indent));
695 };
696 if separators.is_none() || separators.is_undefined() {
697 return Ok(JsonSeparators::python_default(indent));
698 }
699
700 let parsed: serde_json::Value = serde_json::to_value(&separators).map_err(|e| {
701 invalid_tojson_option(format!("Failed to convert separators to JSON value: {e}"))
702 })?;
703 let JsonValue::Array(values) = parsed else {
704 return Err(invalid_tojson_option(
705 "separators must be a two-item sequence",
706 ));
707 };
708 if values.len() != 2 {
709 return Err(invalid_tojson_option(
710 "separators must be a two-item sequence",
711 ));
712 }
713
714 let item = values[0]
715 .as_str()
716 .ok_or_else(|| invalid_tojson_option("item separator must be a string"))?;
717 let key = values[1]
718 .as_str()
719 .ok_or_else(|| invalid_tojson_option("key separator must be a string"))?;
720
721 Ok(JsonSeparators {
722 item: item.as_bytes().to_vec(),
723 key: key.as_bytes().to_vec(),
724 })
725}
726
727fn serialize_with_python_json<T: Serialize>(
728 value: &T,
729 indent: Option<i64>,
730 separators: JsonSeparators,
731 ensure_ascii: bool,
732) -> std::result::Result<String, MinijinjaError> {
733 let indent = indent
734 .map(|spaces| {
735 if spaces < 0 {
736 Err(invalid_tojson_option("indent cannot be negative"))
737 } else {
738 Ok(spaces as usize)
739 }
740 })
741 .transpose()?;
742
743 let formatter = PythonJsonFormatter::new(indent, separators, ensure_ascii);
744 let mut buf = Vec::new();
745 let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
746 value.serialize(&mut serializer).map_err(|e| {
747 MinijinjaError::new(
748 ErrorKind::InvalidOperation,
749 format!("Failed to serialize JSON: {e}"),
750 )
751 })?;
752 String::from_utf8(buf).map_err(|e| {
753 MinijinjaError::new(
754 ErrorKind::InvalidOperation,
755 format!("Invalid UTF-8 in JSON output: {e}"),
756 )
757 })
758}
759
760fn tojson_filter(value: Value, kwargs: Kwargs) -> std::result::Result<Value, MinijinjaError> {
772 let ensure_ascii: Option<bool> = kwargs.get("ensure_ascii")?;
773 let indent: Option<i64> = kwargs.get("indent")?;
774 let separators: Option<Value> = kwargs.get("separators")?;
775 let sort_keys: Option<bool> = kwargs.get("sort_keys")?;
776
777 kwargs.assert_all_used()?;
779
780 let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| {
781 MinijinjaError::new(
782 ErrorKind::InvalidOperation,
783 format!("Failed to convert to JSON value: {e}"),
784 )
785 })?;
786
787 let json_str: std::result::Result<String, MinijinjaError> = {
789 let sorted_json;
790 let value_to_serialize = if sort_keys.unwrap_or(false) {
791 sorted_json = sort_json_keys(&json_value);
792 &sorted_json
793 } else {
794 &json_value
795 };
796
797 let separators = parse_separators(separators, indent)?;
798 serialize_with_python_json(
799 value_to_serialize,
800 indent,
801 separators,
802 ensure_ascii.unwrap_or(false),
803 )
804 };
805
806 json_str.map(Value::from_safe_string)
807}
808
809fn sort_json_keys(value: &JsonValue) -> JsonValue {
811 match value {
812 JsonValue::Object(map) => {
813 let mut sorted: serde_json::Map<String, JsonValue> = serde_json::Map::new();
814 let mut keys: Vec<_> = map.keys().collect();
815 keys.sort();
816 for key in keys {
817 sorted.insert(key.clone(), sort_json_keys(&map[key]));
818 }
819 JsonValue::Object(sorted)
820 }
821 JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()),
822 _ => value.clone(),
823 }
824}
825
826fn build_environment(template: String) -> Result<Environment<'static>> {
831 let mut env = Environment::new();
832
833 env.set_trim_blocks(true);
837 env.set_lstrip_blocks(true);
838
839 env.add_template_owned("chat".to_owned(), template)
841 .map_err(|e| anyhow!("Failed to add template: {e}"))?;
842
843 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
845
846 env.add_filter("tojson", tojson_filter);
850
851 Ok(env)
852}
853
854fn special_token_value(token: Option<&str>) -> Value {
859 token.map_or(Value::UNDEFINED, Value::from)
860}
861
862fn render_chat_template(
863 env: &Environment<'_>,
864 messages: &[serde_json::Value],
865 params: ChatTemplateParams,
866) -> Result<String> {
867 let tmpl = env
868 .get_template("chat")
869 .map_err(|e| anyhow!("Failed to get template: {e}"))?;
870
871 let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
873
874 let tools_value = params.tools.map_or(Value::UNDEFINED, Value::from_serialize);
879 let documents_value = params
880 .documents
881 .map_or(Value::UNDEFINED, Value::from_serialize);
882
883 let bos_value =
887 special_token_value(params.special_tokens.and_then(|st| st.bos_token.as_deref()));
888 let eos_value =
889 special_token_value(params.special_tokens.and_then(|st| st.eos_token.as_deref()));
890 let unk_value =
891 special_token_value(params.special_tokens.and_then(|st| st.unk_token.as_deref()));
892 let pad_value =
893 special_token_value(params.special_tokens.and_then(|st| st.pad_token.as_deref()));
894
895 let base_context = context! {
896 messages => &minijinja_messages,
897 add_generation_prompt => params.add_generation_prompt,
898 tools => tools_value,
899 documents => documents_value,
900 bos_token => bos_value,
901 eos_token => eos_value,
902 unk_token => unk_value,
903 pad_token => pad_value,
904 };
905
906 let ctx = if let Some(kwargs) = params.template_kwargs {
908 context! {
909 ..base_context,
910 ..Value::from_serialize(kwargs)
911 }
912 } else {
913 base_context
914 };
915
916 let rendered = tmpl
918 .render(&ctx)
919 .map_err(|e| anyhow!("Failed to render template: {e}"))?;
920
921 Ok(rendered)
922}
923
924pub struct ChatTemplateProcessor {
926 env: Environment<'static>,
927}
928
929impl ChatTemplateProcessor {
930 pub fn new(template: String) -> Result<Self> {
936 let env = build_environment(template)?;
937 Ok(ChatTemplateProcessor { env })
938 }
939
940 pub fn apply_chat_template(
946 &self,
947 messages: &[serde_json::Value],
948 params: ChatTemplateParams,
949 ) -> Result<String> {
950 render_chat_template(&self.env, messages, params)
951 }
952}
953
954pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
956 let content = fs::read_to_string(config_path)?;
957 let config: serde_json::Value = serde_json::from_str(&content)?;
958
959 if let Some(template) = config.get("chat_template") {
961 if let Some(template_str) = template.as_str() {
962 return Ok(Some(template_str.to_string()));
963 }
964 }
965
966 Ok(None)
967}
968
969pub fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
972 let content = fs::read_to_string(template_path)
973 .map_err(|e| anyhow!("Failed to read chat template file: {e}"))?;
974
975 if template_path.ends_with(".json") {
976 let json_value: serde_json::Value = serde_json::from_str(&content)
977 .map_err(|e| anyhow!("Failed to parse chat_template.json: {e}"))?;
978
979 if let Some(template_str) = json_value.as_str() {
980 return Ok(Some(template_str.to_string()));
981 } else if let Some(obj) = json_value.as_object() {
982 if let Some(template_value) = obj.get("chat_template") {
983 if let Some(template_str) = template_value.as_str() {
984 return Ok(Some(template_str.to_string()));
985 }
986 }
987 }
988
989 return Err(anyhow!(
990 "chat_template.json does not contain a valid template",
991 ));
992 }
993
994 let template = content.trim().replace("\\n", "\n");
996 Ok(Some(template))
997}
998
/// A chat template (if any) together with the properties detected from its source.
pub struct ChatTemplateState {
    /// Compiled environment; `None` when no template has been provided.
    env: Option<Environment<'static>>,
    /// Whether the template treats message content as a string or as structured data.
    content_format: ChatTemplateContentFormat,
    /// Detected thinking toggle and its default.
    thinking_toggle: ThinkingToggle,
    /// Variable name that toggles thinking, when one was detected.
    thinking_key_name: Option<ThinkingKeyName>,
    /// Whether an `add_generation_prompt` branch emits a literal `<think>` tag.
    think_in_prefill: bool,
}
1020
1021impl std::fmt::Debug for ChatTemplateState {
1022 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1023 f.debug_struct("ChatTemplateState")
1024 .field("has_template", &self.env.is_some())
1025 .field("content_format", &self.content_format)
1026 .field("thinking_toggle", &self.thinking_toggle)
1027 .field("think_in_prefill", &self.think_in_prefill)
1028 .finish()
1029 }
1030}
1031
1032impl ChatTemplateState {
1033 pub fn new(template: Option<String>) -> Result<Self> {
1034 let (content_format, think_in_prefill, thinking_toggle, thinking_key_name) =
1035 template.as_ref().map(|t| detect_all(t)).unwrap_or_default();
1036 let env = template.map(build_environment).transpose()?;
1037 Ok(Self {
1038 env,
1039 content_format,
1040 thinking_toggle,
1041 thinking_key_name,
1042 think_in_prefill,
1043 })
1044 }
1045
1046 pub fn empty() -> Self {
1051 Self {
1052 env: None,
1053 content_format: ChatTemplateContentFormat::default(),
1054 thinking_toggle: ThinkingToggle::None,
1055 thinking_key_name: None,
1056 think_in_prefill: false,
1057 }
1058 }
1059
1060 pub fn apply(
1061 &self,
1062 messages: &[serde_json::Value],
1063 params: ChatTemplateParams,
1064 ) -> Result<String> {
1065 let env = self.env.as_ref().ok_or_else(|| {
1066 anyhow!(
1067 "Cannot use chat template functions because tokenizer.chat_template is not set \
1068 and no template argument was passed! For information about writing templates and \
1069 setting the tokenizer.chat_template attribute, please see the documentation at \
1070 https://huggingface.co/docs/transformers/main/en/chat_templating",
1071 )
1072 })?;
1073 render_chat_template(env, messages, params)
1074 }
1075
1076 pub fn set(&mut self, template: String) -> Result<()> {
1077 let (content_format, think_in_prefill, thinking_toggle, thinking_key_name) =
1078 detect_all(&template);
1079 let env = build_environment(template)?;
1080 self.content_format = content_format;
1081 self.thinking_toggle = thinking_toggle;
1082 self.thinking_key_name = thinking_key_name;
1083 self.think_in_prefill = think_in_prefill;
1084 self.env = Some(env);
1085 Ok(())
1086 }
1087
1088 pub fn content_format(&self) -> ChatTemplateContentFormat {
1089 self.content_format
1090 }
1091
1092 pub fn thinking_toggle(&self) -> ThinkingToggle {
1093 self.thinking_toggle
1094 }
1095
1096 pub fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
1097 self.thinking_key_name
1098 }
1099
1100 pub fn think_in_prefill(&self) -> bool {
1101 self.think_in_prefill
1102 }
1103}
1104
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `template` against `messages` with default parameters.
    fn render(template: &str, messages: &[serde_json::Value]) -> anyhow::Result<String> {
        ChatTemplateState::new(Some(template.to_string()))?
            .apply(messages, ChatTemplateParams::default())
    }

    #[test]
    fn test_chat_template_state_no_template() {
        let state = ChatTemplateState::new(None).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
        let result = state.apply(&[], ChatTemplateParams::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_chat_template_state_set() {
        let mut state = ChatTemplateState::new(None).unwrap();
        state.set("{{ messages }}".to_string()).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
    }

    #[test]
    fn test_chat_template_state_invalid_template() {
        let result = ChatTemplateState::new(Some("{% invalid".to_string()));
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Failed to add template"),
            "Error should explain parse failure, got: {err}"
        );
    }

    #[test]
    fn test_chat_template_processor_invalid_template() {
        let result = ChatTemplateProcessor::new("{% invalid".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_special_tokens_injected_into_context() {
        let template = "{{ bos_token }}{% for message in messages %}{{ message.content }}{% endfor %}{{ eos_token }}";
        let state = ChatTemplateState::new(Some(template.to_string())).unwrap();

        let messages = vec![serde_json::json!({"role": "user", "content": "hello"})];
        let special_tokens = crate::traits::SpecialTokens {
            bos_token: Some("<s>".to_string()),
            eos_token: Some("</s>".to_string()),
            ..Default::default()
        };

        let result = state
            .apply(
                &messages,
                ChatTemplateParams {
                    special_tokens: Some(&special_tokens),
                    ..Default::default()
                },
            )
            .unwrap();

        assert_eq!(result, "<s>hello</s>");
    }

    #[test]
    fn test_special_tokens_undefined_when_not_provided() {
        let template = "{% if bos_token is defined %}{{ bos_token }}{% endif %}hello";
        let state = ChatTemplateState::new(Some(template.to_string())).unwrap();

        let result = state.apply(&[], ChatTemplateParams::default()).unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_special_tokens_partial() {
        let template =
            "{{ bos_token }}hello{% if eos_token is defined %}{{ eos_token }}{% endif %}";
        let state = ChatTemplateState::new(Some(template.to_string())).unwrap();

        let special_tokens = crate::traits::SpecialTokens {
            bos_token: Some("<s>".to_string()),
            eos_token: None,
            ..Default::default()
        };

        let result = state
            .apply(
                &[],
                ChatTemplateParams {
                    special_tokens: Some(&special_tokens),
                    ..Default::default()
                },
            )
            .unwrap();

        assert_eq!(result, "<s>hello");
    }

    #[test]
    fn test_chat_template_state_empty() {
        let state = ChatTemplateState::empty();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
        assert_eq!(state.thinking_toggle(), ThinkingToggle::None);
        assert_eq!(state.thinking_key_name(), None);
        assert!(!state.think_in_prefill());
        assert!(state.apply(&[], ChatTemplateParams::default()).is_err());
    }

    #[test]
    fn test_content_format_string_for_plain_content() {
        let template = "{% for message in messages %}{{ message.content }}{% endfor %}";
        assert_eq!(
            detect_chat_template_content_format(template),
            ChatTemplateContentFormat::String
        );
    }

    #[test]
    fn test_content_format_openai_when_iterating_content() {
        let template = "{% for message in messages %}{% for part in message.content %}{{ part.text }}{% endfor %}{% endfor %}";
        assert_eq!(
            detect_chat_template_content_format(template),
            ChatTemplateContentFormat::OpenAI
        );
    }

    #[test]
    fn test_content_format_openai_when_indexing_content() {
        let template = "{% for message in messages %}{{ message.content[0].text }}{% endfor %}";
        assert_eq!(
            detect_chat_template_content_format(template),
            ChatTemplateContentFormat::OpenAI
        );
    }

    #[test]
    fn test_content_format_defaults_to_string_on_parse_error() {
        assert_eq!(
            detect_chat_template_content_format("{% for"),
            ChatTemplateContentFormat::String
        );
    }

    #[test]
    fn test_think_in_prefill() {
        let with_think = "{% for message in messages %}{{ message.content }}{% endfor %}{% if add_generation_prompt %}<|assistant|><think>\n{% endif %}";
        let state = ChatTemplateState::new(Some(with_think.to_string())).unwrap();
        assert!(state.think_in_prefill());

        let without_think = "{% if add_generation_prompt %}<|assistant|>{% endif %}";
        let state = ChatTemplateState::new(Some(without_think.to_string())).unwrap();
        assert!(!state.think_in_prefill());
    }

    #[test]
    fn test_thinking_toggle_detection() {
        assert_eq!(
            detect_thinking_toggle("{{ messages }}"),
            (ThinkingToggle::None, None)
        );
        assert_eq!(
            detect_thinking_toggle(
                "{% if enable_thinking is defined and enable_thinking is false %}<think></think>{% endif %}"
            ),
            (ThinkingToggle::DefaultOn, Some(ThinkingKeyName::EnableThinking))
        );
        assert_eq!(
            detect_thinking_toggle(
                "{% if not thinking is defined %}{% set thinking = false %}{% endif %}"
            ),
            (ThinkingToggle::DefaultOff, Some(ThinkingKeyName::Thinking))
        );

        let state = ChatTemplateState::new(Some(
            "{% set enable_thinking = false %}{{ messages }}".to_string(),
        ))
        .unwrap();
        assert_eq!(state.thinking_toggle(), ThinkingToggle::DefaultOff);
        assert_eq!(
            state.thinking_key_name(),
            Some(ThinkingKeyName::EnableThinking)
        );
    }

    #[test]
    fn test_template_kwargs_are_visible() {
        let template = "{% if enable_thinking %}on{% else %}off{% endif %}";
        let state = ChatTemplateState::new(Some(template.to_string())).unwrap();

        let mut kwargs = std::collections::HashMap::new();
        kwargs.insert("enable_thinking".to_string(), serde_json::json!(true));
        let rendered = state
            .apply(
                &[],
                ChatTemplateParams {
                    template_kwargs: Some(&kwargs),
                    ..Default::default()
                },
            )
            .unwrap();
        assert_eq!(rendered, "on");

        let rendered = state.apply(&[], ChatTemplateParams::default()).unwrap();
        assert_eq!(rendered, "off");
    }

    #[test]
    fn test_trim_and_lstrip_blocks() {
        let rendered = render("{% if true %}\nhi\n    {% endif %}done", &[]).unwrap();
        assert_eq!(rendered, "hi\ndone");
    }

    #[test]
    fn test_tojson_python_default_separators() {
        let messages = vec![serde_json::json!({"content": "hi"})];
        assert_eq!(
            render("{{ messages | tojson }}", &messages).unwrap(),
            "[{\"content\": \"hi\"}]"
        );
    }

    #[test]
    fn test_tojson_indent() {
        let messages = vec![serde_json::json!({"content": "hi"})];
        assert_eq!(
            render("{{ messages | tojson(indent=2) }}", &messages).unwrap(),
            "[\n  {\n    \"content\": \"hi\"\n  }\n]"
        );
    }

    #[test]
    fn test_tojson_ensure_ascii() {
        let messages = vec![serde_json::json!({"content": "é"})];
        assert_eq!(
            render("{{ messages | tojson(ensure_ascii=true) }}", &messages).unwrap(),
            "[{\"content\": \"\\u00e9\"}]"
        );
        assert_eq!(
            render("{{ messages | tojson }}", &messages).unwrap(),
            "[{\"content\": \"é\"}]"
        );
    }

    #[test]
    fn test_tojson_sort_keys_and_separators() {
        let messages = vec![serde_json::json!({"b": 1, "a": 2})];
        assert_eq!(
            render(
                "{{ messages | tojson(sort_keys=true, separators=[',', ':']) }}",
                &messages
            )
            .unwrap(),
            "[{\"a\":2,\"b\":1}]"
        );
    }

    #[test]
    fn test_tojson_rejects_negative_indent() {
        assert!(render("{{ messages | tojson(indent=-1) }}", &[]).is_err());
    }

    #[test]
    fn test_load_chat_template_from_json_file() {
        let path = std::env::temp_dir().join(format!(
            "chat_template_test_{}.json",
            std::process::id()
        ));
        std::fs::write(&path, r#"{"chat_template": "{{ messages }}"}"#).unwrap();
        let loaded = load_chat_template_from_file(path.to_str().unwrap());
        std::fs::remove_file(&path).ok();
        assert_eq!(loaded.unwrap().as_deref(), Some("{{ messages }}"));
    }
}