1use std::{collections::HashMap, fs};
7
8use anyhow::{anyhow, Result};
9use minijinja::{
10 context,
11 machinery::{
12 ast::{Expr, Stmt},
13 parse, WhitespaceConfig,
14 },
15 syntax::SyntaxConfig,
16 value::Kwargs,
17 Environment, Error as MinijinjaError, ErrorKind, Value,
18};
19use serde::Serialize;
20use serde_json::{self, ser::PrettyFormatter, Value as JsonValue};
21
/// Shape in which message `content` is expected by a chat template.
///
/// `String` is the default. It is also what AST-based detection reports when the
/// template cannot be parsed or shows no sign of structured-content access.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChatTemplateContentFormat {
    /// Default variant; displayed as `"string"`.
    #[default]
    String,
    /// Reported when the template iterates over, indexes into, or takes the length of
    /// message content (presumably the OpenAI list-of-parts layout); displayed as
    /// `"openai"`.
    OpenAI,
}

impl std::fmt::Display for ChatTemplateContentFormat {
    /// Writes the lowercase label of the variant (`"string"` or `"openai"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::String => "string",
            Self::OpenAI => "openai",
        };
        f.write_str(label)
    }
}
40
/// Which template variable a chat template uses to control thinking.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingKeyName {
    /// The template text mentions `enable_thinking`.
    EnableThinking,
    /// The template text only uses a bare `thinking` variable.
    Thinking,
}
50
/// Whether a chat template exposes a thinking switch and, if so, its default state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ThinkingToggle {
    /// No thinking variable was found in the template. This is the default.
    #[default]
    None,
    /// A thinking variable is referenced and the template does not set it to `false`.
    DefaultOn,
    /// The template explicitly sets the thinking variable to `false`.
    DefaultOff,
}
67
68pub fn detect_thinking_toggle(template: &str) -> (ThinkingToggle, Option<ThinkingKeyName>) {
71 let has_enable_thinking = template.contains("enable_thinking");
72 let has_thinking_var = template.contains("if thinking ")
74 || template.contains("thinking is ")
75 || template.contains("thinking ==")
76 || template.contains("set thinking ");
77
78 if !has_enable_thinking && !has_thinking_var {
79 return (ThinkingToggle::None, None);
80 }
81
82 let key_name = if has_enable_thinking {
84 ThinkingKeyName::EnableThinking
85 } else {
86 ThinkingKeyName::Thinking
87 };
88
89 if template.contains("set thinking = false") || template.contains("set thinking=false") {
92 return (ThinkingToggle::DefaultOff, Some(key_name));
93 }
94 if template.contains("set enable_thinking = false")
95 || template.contains("set enable_thinking=false")
96 {
97 return (ThinkingToggle::DefaultOff, Some(key_name));
98 }
99
100 (ThinkingToggle::DefaultOn, Some(key_name))
102}
103
104pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
113 detect_format_with_ast(template)
115}
116
/// Signals collected while walking a template AST.
#[derive(Default, Debug, Clone, Copy)]
struct Flags {
    /// A `for` loop iterates over message content.
    saw_iteration: bool,
    /// Content is indexed with a numeric constant or passed through `length`.
    saw_structure: bool,
    /// `content` is assigned from message content. Recorded, but not consulted by
    /// [`Flags::any`].
    saw_assignment: bool,
    /// A macro contains both a test-based `if` and a `for` loop.
    saw_macro: bool,
}

impl Flags {
    /// Returns `true` when iteration, structural access, or a matching macro was seen.
    ///
    /// `saw_assignment` does not contribute to the result.
    fn any(self) -> bool {
        let Self {
            saw_iteration,
            saw_structure,
            saw_macro,
            ..
        } = self;
        saw_iteration || saw_structure || saw_macro
    }
}
135
136struct Detector<'a> {
138 ast: &'a Stmt<'a>,
139 scope: std::collections::VecDeque<String>,
141 scope_set: std::collections::HashSet<String>,
142 flags: Flags,
143 think_in_prefill: bool,
145}
146
147impl<'a> Detector<'a> {
148 fn new(ast: &'a Stmt<'a>) -> Self {
149 Self {
150 ast,
151 scope: std::collections::VecDeque::new(),
152 scope_set: std::collections::HashSet::new(),
153 flags: Flags::default(),
154 think_in_prefill: false,
155 }
156 }
157
158 fn run(mut self) -> (Flags, bool) {
159 self.walk_stmt(self.ast);
160 (self.flags, self.think_in_prefill)
161 }
162
163 fn push_scope(&mut self, var: String) {
164 self.scope.push_back(var.clone());
165 self.scope_set.insert(var);
166 }
167
168 fn pop_scope(&mut self) {
169 if let Some(v) = self.scope.pop_back() {
170 self.scope_set.remove(&v);
171 }
172 }
173
174 fn is_var_access(expr: &Expr, varname: &str) -> bool {
175 matches!(expr, Expr::Var(v) if v.id == varname)
176 }
177
178 fn is_const_str(expr: &Expr, value: &str) -> bool {
179 matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value))
180 }
181
182 fn is_numeric_const(expr: &Expr) -> bool {
183 matches!(expr, Expr::Const(c) if c.value.is_number())
184 }
185
186 fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
188 match expr {
189 Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
190 Expr::GetItem(g) => {
191 Self::is_var_access(&g.expr, varname)
192 && Self::is_const_str(&g.subscript_expr, "content")
193 }
194 Expr::Filter(f) => f
196 .expr
197 .as_ref()
198 .is_some_and(|e| Self::is_var_dot_content(e, varname)),
199 Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
200 _ => false,
201 }
202 }
203
204 fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
206 let mut current_expr = expr;
207 loop {
208 if self
210 .scope_set
211 .iter()
212 .any(|v| Self::is_var_dot_content(current_expr, v))
213 {
214 return true;
215 }
216 match current_expr {
218 Expr::GetAttr(g) => current_expr = &g.expr,
219 Expr::GetItem(g) => current_expr = &g.expr,
220 _ => return false,
221 }
222 }
223 }
224
225 fn expr_references_var(expr: &Expr, name: &str) -> bool {
227 match expr {
228 Expr::Var(v) => v.id == name,
229 Expr::BinOp(b) => {
230 Self::expr_references_var(&b.left, name)
231 || Self::expr_references_var(&b.right, name)
232 }
233 Expr::UnaryOp(u) => Self::expr_references_var(&u.expr, name),
234 _ => false,
235 }
236 }
237
238 fn body_has_think_tag(stmts: &[Stmt]) -> bool {
240 for stmt in stmts {
241 match stmt {
242 Stmt::EmitRaw(raw) if raw.raw.contains("<think>") => return true,
243 Stmt::EmitExpr(e) => {
244 if let Expr::Const(c) = &e.expr {
245 if c.value.as_str().is_some_and(|s| s.contains("<think>")) {
246 return true;
247 }
248 }
249 }
250 Stmt::IfCond(ic) => {
251 if Self::body_has_think_tag(&ic.true_body)
252 || Self::body_has_think_tag(&ic.false_body)
253 {
254 return true;
255 }
256 }
257 _ => {}
258 }
259 }
260 false
261 }
262
263 fn walk_stmt(&mut self, stmt: &Stmt) {
264 match stmt {
265 Stmt::Template(t) => {
266 for ch in &t.children {
267 self.walk_stmt(ch);
268 }
269 }
270 Stmt::ForLoop(fl) => {
272 if let Expr::Var(iter) = &fl.iter {
274 if iter.id == "messages" {
275 if let Expr::Var(target) = &fl.target {
276 self.push_scope(target.id.to_string());
277 }
278 }
279 }
280
281 if self.is_any_scope_var_content(&fl.iter) {
284 self.flags.saw_iteration = true;
285 }
286 if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
288 self.flags.saw_iteration = true;
289 }
290
291 for b in &fl.body {
292 self.walk_stmt(b);
293 }
294
295 if let Expr::Var(iter) = &fl.iter {
297 if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
298 self.pop_scope();
299 }
300 }
301 }
302 Stmt::IfCond(ic) => {
303 self.inspect_expr_for_structure(&ic.expr);
304
305 if !self.think_in_prefill
307 && Self::expr_references_var(&ic.expr, "add_generation_prompt")
308 {
309 self.think_in_prefill = Self::body_has_think_tag(&ic.true_body);
310 }
311
312 for b in &ic.true_body {
313 self.walk_stmt(b);
314 }
315 for b in &ic.false_body {
316 self.walk_stmt(b);
317 }
318 }
319 Stmt::EmitExpr(e) => {
320 self.inspect_expr_for_structure(&e.expr);
321 }
322 Stmt::Set(s) => {
324 if Self::is_var_access(&s.target, "content")
325 && self.is_any_scope_var_content(&s.expr)
326 {
327 self.flags.saw_assignment = true;
328 }
329 }
330 Stmt::Macro(m) => {
331 let mut has_type_check = false;
333 let mut has_loop = false;
334 Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
335 if has_type_check && has_loop {
336 self.flags.saw_macro = true;
337 }
338 }
339 _ => {}
340 }
341 }
342
343 fn inspect_expr_for_structure(&mut self, expr: &Expr) {
344 if self.flags.saw_structure {
345 return;
346 }
347
348 match expr {
349 Expr::GetItem(gi) => {
351 if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
352 || self.is_any_scope_var_content(&gi.expr))
353 && Self::is_numeric_const(&gi.subscript_expr)
354 {
355 self.flags.saw_structure = true;
356 }
357 }
358 Expr::Filter(f) => {
360 if f.name == "length" {
361 if let Some(inner) = &f.expr {
362 let inner_ref: &Expr = inner;
364 let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
365 if is_content_var || self.is_any_scope_var_content(inner_ref) {
366 self.flags.saw_structure = true;
367 }
368 }
369 } else if let Some(inner) = &f.expr {
370 let inner_ref: &Expr = inner;
371 self.inspect_expr_for_structure(inner_ref);
372 }
373 }
374 Expr::Test(t) => self.inspect_expr_for_structure(&t.expr),
378 Expr::GetAttr(g) => {
379 self.inspect_expr_for_structure(&g.expr);
381 }
382 Expr::BinOp(op) => {
384 self.inspect_expr_for_structure(&op.left);
385 self.inspect_expr_for_structure(&op.right);
386 }
387 Expr::UnaryOp(op) => {
389 self.inspect_expr_for_structure(&op.expr);
390 }
391 _ => {}
392 }
393 }
394
395 fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
396 for s in body {
397 if *has_type_check && *has_loop {
398 return;
399 }
400
401 match s {
402 Stmt::IfCond(ic) => {
403 if matches!(&ic.expr, Expr::Test(_)) {
404 *has_type_check = true;
405 }
406 Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
407 Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
408 }
409 Stmt::ForLoop(fl) => {
410 *has_loop = true;
411 Self::scan_macro_body(&fl.body, has_type_check, has_loop);
412 }
413 Stmt::Template(t) => {
414 Self::scan_macro_body(&t.children, has_type_check, has_loop);
415 }
416 _ => {}
417 }
418 }
419 }
420}
421
422fn detect_format_with_ast(template: &str) -> ChatTemplateContentFormat {
425 detect_all_with_ast(template).0
426}
427
428fn detect_all(
430 template: &str,
431) -> (
432 ChatTemplateContentFormat,
433 bool,
434 ThinkingToggle,
435 Option<ThinkingKeyName>,
436) {
437 let (thinking_toggle, thinking_key_name) = detect_thinking_toggle(template);
438 let (content_format, think_in_prefill) = detect_all_with_ast(template);
439 (
440 content_format,
441 think_in_prefill,
442 thinking_toggle,
443 thinking_key_name,
444 )
445}
446
447fn detect_all_with_ast(template: &str) -> (ChatTemplateContentFormat, bool) {
449 let ast = match parse(
450 template,
451 "template",
452 SyntaxConfig {},
453 WhitespaceConfig::default(),
454 ) {
455 Ok(ast) => ast,
456 Err(_) => return (ChatTemplateContentFormat::String, false),
457 };
458
459 let (flags, think_in_prefill) = Detector::new(&ast).run();
460 let content_format = if flags.any() {
461 ChatTemplateContentFormat::OpenAI
462 } else {
463 ChatTemplateContentFormat::String
464 };
465 (content_format, think_in_prefill)
466}
467
468#[derive(Default)]
470pub struct ChatTemplateParams<'a> {
471 pub add_generation_prompt: bool,
472 pub tools: Option<&'a [serde_json::Value]>,
473 pub documents: Option<&'a [serde_json::Value]>,
474 pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
475 pub special_tokens: Option<&'a crate::traits::SpecialTokens>,
478}
479
480fn tojson_filter(value: Value, kwargs: Kwargs) -> std::result::Result<Value, MinijinjaError> {
492 let _ensure_ascii: Option<bool> = kwargs.get("ensure_ascii")?;
493 let indent: Option<i64> = kwargs.get("indent")?;
494 let _separators: Option<Value> = kwargs.get("separators")?;
495 let sort_keys: Option<bool> = kwargs.get("sort_keys")?;
496
497 kwargs.assert_all_used()?;
499
500 let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| {
501 MinijinjaError::new(
502 ErrorKind::InvalidOperation,
503 format!("Failed to convert to JSON value: {e}"),
504 )
505 })?;
506
507 fn serialize_with_indent<T: Serialize>(
509 value: &T,
510 spaces: usize,
511 ) -> std::result::Result<String, MinijinjaError> {
512 let indent_str = vec![b' '; spaces];
513 let formatter = PrettyFormatter::with_indent(&indent_str);
514 let mut buf = Vec::new();
515 let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
516 value.serialize(&mut serializer).map_err(|e| {
517 MinijinjaError::new(
518 ErrorKind::InvalidOperation,
519 format!("Failed to serialize JSON: {e}"),
520 )
521 })?;
522 String::from_utf8(buf).map_err(|e| {
523 MinijinjaError::new(
524 ErrorKind::InvalidOperation,
525 format!("Invalid UTF-8 in JSON output: {e}"),
526 )
527 })
528 }
529
530 let json_str: std::result::Result<String, MinijinjaError> = {
532 let sorted_json;
533 let value_to_serialize = if sort_keys.unwrap_or(false) {
534 sorted_json = sort_json_keys(&json_value);
535 &sorted_json
536 } else {
537 &json_value
538 };
539
540 if let Some(spaces) = indent {
541 if spaces < 0 {
542 return Err(MinijinjaError::new(
543 ErrorKind::InvalidOperation,
544 "indent cannot be negative",
545 ));
546 }
547 serialize_with_indent(value_to_serialize, spaces as usize)
548 } else {
549 serde_json::to_string(value_to_serialize).map_err(|e| {
550 MinijinjaError::new(
551 ErrorKind::InvalidOperation,
552 format!("Failed to serialize JSON: {e}"),
553 )
554 })
555 }
556 };
557
558 json_str.map(Value::from_safe_string)
559}
560
561fn sort_json_keys(value: &JsonValue) -> JsonValue {
563 match value {
564 JsonValue::Object(map) => {
565 let mut sorted: serde_json::Map<String, JsonValue> = serde_json::Map::new();
566 let mut keys: Vec<_> = map.keys().collect();
567 keys.sort();
568 for key in keys {
569 sorted.insert(key.clone(), sort_json_keys(&map[key]));
570 }
571 JsonValue::Object(sorted)
572 }
573 JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()),
574 _ => value.clone(),
575 }
576}
577
578fn build_environment(template: String) -> Result<Environment<'static>> {
583 let mut env = Environment::new();
584
585 env.set_trim_blocks(true);
589 env.set_lstrip_blocks(true);
590
591 env.add_template_owned("chat".to_owned(), template)
593 .map_err(|e| anyhow!("Failed to add template: {e}"))?;
594
595 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
597
598 env.add_filter("tojson", tojson_filter);
602
603 Ok(env)
604}
605
606fn special_token_value(token: Option<&str>) -> Value {
611 token.map_or(Value::UNDEFINED, Value::from)
612}
613
614fn render_chat_template(
615 env: &Environment<'_>,
616 messages: &[serde_json::Value],
617 params: ChatTemplateParams,
618) -> Result<String> {
619 let tmpl = env
620 .get_template("chat")
621 .map_err(|e| anyhow!("Failed to get template: {e}"))?;
622
623 let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
625
626 let tools_value = params.tools.map_or(Value::UNDEFINED, Value::from_serialize);
631 let documents_value = params
632 .documents
633 .map_or(Value::UNDEFINED, Value::from_serialize);
634
635 let bos_value =
639 special_token_value(params.special_tokens.and_then(|st| st.bos_token.as_deref()));
640 let eos_value =
641 special_token_value(params.special_tokens.and_then(|st| st.eos_token.as_deref()));
642 let unk_value =
643 special_token_value(params.special_tokens.and_then(|st| st.unk_token.as_deref()));
644 let pad_value =
645 special_token_value(params.special_tokens.and_then(|st| st.pad_token.as_deref()));
646
647 let base_context = context! {
648 messages => &minijinja_messages,
649 add_generation_prompt => params.add_generation_prompt,
650 tools => tools_value,
651 documents => documents_value,
652 bos_token => bos_value,
653 eos_token => eos_value,
654 unk_token => unk_value,
655 pad_token => pad_value,
656 };
657
658 let ctx = if let Some(kwargs) = params.template_kwargs {
660 context! {
661 ..base_context,
662 ..Value::from_serialize(kwargs)
663 }
664 } else {
665 base_context
666 };
667
668 let rendered = tmpl
670 .render(&ctx)
671 .map_err(|e| anyhow!("Failed to render template: {e}"))?;
672
673 Ok(rendered)
674}
675
676pub struct ChatTemplateProcessor {
678 env: Environment<'static>,
679}
680
681impl ChatTemplateProcessor {
682 pub fn new(template: String) -> Result<Self> {
688 let env = build_environment(template)?;
689 Ok(ChatTemplateProcessor { env })
690 }
691
692 pub fn apply_chat_template(
698 &self,
699 messages: &[serde_json::Value],
700 params: ChatTemplateParams,
701 ) -> Result<String> {
702 render_chat_template(&self.env, messages, params)
703 }
704}
705
706pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
708 let content = fs::read_to_string(config_path)?;
709 let config: serde_json::Value = serde_json::from_str(&content)?;
710
711 if let Some(template) = config.get("chat_template") {
713 if let Some(template_str) = template.as_str() {
714 return Ok(Some(template_str.to_string()));
715 }
716 }
717
718 Ok(None)
719}
720
721pub fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
724 let content = fs::read_to_string(template_path)
725 .map_err(|e| anyhow!("Failed to read chat template file: {e}"))?;
726
727 if template_path.ends_with(".json") {
728 let json_value: serde_json::Value = serde_json::from_str(&content)
729 .map_err(|e| anyhow!("Failed to parse chat_template.json: {e}"))?;
730
731 if let Some(template_str) = json_value.as_str() {
732 return Ok(Some(template_str.to_string()));
733 } else if let Some(obj) = json_value.as_object() {
734 if let Some(template_value) = obj.get("chat_template") {
735 if let Some(template_str) = template_value.as_str() {
736 return Ok(Some(template_str.to_string()));
737 }
738 }
739 }
740
741 return Err(anyhow!(
742 "chat_template.json does not contain a valid template",
743 ));
744 }
745
746 let template = content.trim().replace("\\n", "\n");
748 Ok(Some(template))
749}
750
751pub struct ChatTemplateState {
762 env: Option<Environment<'static>>,
764 content_format: ChatTemplateContentFormat,
765 thinking_toggle: ThinkingToggle,
767 thinking_key_name: Option<ThinkingKeyName>,
769 think_in_prefill: bool,
771}
772
773impl std::fmt::Debug for ChatTemplateState {
774 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
775 f.debug_struct("ChatTemplateState")
776 .field("has_template", &self.env.is_some())
777 .field("content_format", &self.content_format)
778 .field("thinking_toggle", &self.thinking_toggle)
779 .field("think_in_prefill", &self.think_in_prefill)
780 .finish()
781 }
782}
783
784impl ChatTemplateState {
785 pub fn new(template: Option<String>) -> Result<Self> {
786 let (content_format, think_in_prefill, thinking_toggle, thinking_key_name) =
787 template.as_ref().map(|t| detect_all(t)).unwrap_or_default();
788 let env = template.map(build_environment).transpose()?;
789 Ok(Self {
790 env,
791 content_format,
792 thinking_toggle,
793 thinking_key_name,
794 think_in_prefill,
795 })
796 }
797
798 pub fn empty() -> Self {
803 Self {
804 env: None,
805 content_format: ChatTemplateContentFormat::default(),
806 thinking_toggle: ThinkingToggle::None,
807 thinking_key_name: None,
808 think_in_prefill: false,
809 }
810 }
811
812 pub fn apply(
813 &self,
814 messages: &[serde_json::Value],
815 params: ChatTemplateParams,
816 ) -> Result<String> {
817 let env = self.env.as_ref().ok_or_else(|| {
818 anyhow!(
819 "Cannot use chat template functions because tokenizer.chat_template is not set \
820 and no template argument was passed! For information about writing templates and \
821 setting the tokenizer.chat_template attribute, please see the documentation at \
822 https://huggingface.co/docs/transformers/main/en/chat_templating",
823 )
824 })?;
825 render_chat_template(env, messages, params)
826 }
827
828 pub fn set(&mut self, template: String) -> Result<()> {
829 let (content_format, think_in_prefill, thinking_toggle, thinking_key_name) =
830 detect_all(&template);
831 let env = build_environment(template)?;
832 self.content_format = content_format;
833 self.thinking_toggle = thinking_toggle;
834 self.thinking_key_name = thinking_key_name;
835 self.think_in_prefill = think_in_prefill;
836 self.env = Some(env);
837 Ok(())
838 }
839
840 pub fn content_format(&self) -> ChatTemplateContentFormat {
841 self.content_format
842 }
843
844 pub fn thinking_toggle(&self) -> ThinkingToggle {
845 self.thinking_toggle
846 }
847
848 pub fn thinking_key_name(&self) -> Option<ThinkingKeyName> {
849 self.thinking_key_name
850 }
851
852 pub fn think_in_prefill(&self) -> bool {
853 self.think_in_prefill
854 }
855}
856
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state from a template that is expected to compile.
    fn state_from(template: &str) -> ChatTemplateState {
        ChatTemplateState::new(Some(template.to_string())).unwrap()
    }

    /// Renders with no messages and only the given special tokens.
    fn render_with_tokens(
        state: &ChatTemplateState,
        messages: &[serde_json::Value],
        special_tokens: &crate::traits::SpecialTokens,
    ) -> String {
        let params = ChatTemplateParams {
            special_tokens: Some(special_tokens),
            ..Default::default()
        };
        state.apply(messages, params).unwrap()
    }

    #[test]
    fn test_chat_template_state_no_template() {
        let state = ChatTemplateState::new(None).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);

        // Without a template, rendering must fail rather than produce output.
        let outcome = state.apply(&[], ChatTemplateParams::default());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_chat_template_state_set() {
        let mut state = ChatTemplateState::new(None).unwrap();
        state.set("{{ messages }}".to_string()).unwrap();
        assert_eq!(state.content_format(), ChatTemplateContentFormat::String);
    }

    #[test]
    fn test_chat_template_state_invalid_template() {
        let result = ChatTemplateState::new(Some("{% invalid".to_string()));
        assert!(result.is_err());

        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Failed to add template"),
            "Error should explain parse failure, got: {err}"
        );
    }

    #[test]
    fn test_chat_template_processor_invalid_template() {
        let result = ChatTemplateProcessor::new("{% invalid".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_special_tokens_injected_into_context() {
        let state = state_from(
            "{{ bos_token }}{% for message in messages %}{{ message.content }}{% endfor %}{{ eos_token }}",
        );

        let messages = vec![serde_json::json!({"role": "user", "content": "hello"})];
        let special_tokens = crate::traits::SpecialTokens {
            bos_token: Some("<s>".to_string()),
            eos_token: Some("</s>".to_string()),
            ..Default::default()
        };

        let rendered = render_with_tokens(&state, &messages, &special_tokens);
        assert_eq!(rendered, "<s>hello</s>");
    }

    #[test]
    fn test_special_tokens_undefined_when_not_provided() {
        let state = state_from("{% if bos_token is defined %}{{ bos_token }}{% endif %}hello");

        let rendered = state.apply(&[], ChatTemplateParams::default()).unwrap();
        assert_eq!(rendered, "hello");
    }

    #[test]
    fn test_special_tokens_partial() {
        let state = state_from(
            "{{ bos_token }}hello{% if eos_token is defined %}{{ eos_token }}{% endif %}",
        );

        // Only the BOS token is provided; EOS must stay undefined in the template.
        let special_tokens = crate::traits::SpecialTokens {
            bos_token: Some("<s>".to_string()),
            eos_token: None,
            ..Default::default()
        };

        let rendered = render_with_tokens(&state, &[], &special_tokens);
        assert_eq!(rendered, "<s>hello");
    }
}