1use std::{collections::HashMap, fs};
7
8use anyhow::{anyhow, Result};
9use minijinja::{
10 context,
11 machinery::{
12 ast::{Expr, Stmt},
13 parse, WhitespaceConfig,
14 },
15 syntax::SyntaxConfig,
16 value::Kwargs,
17 Environment, Error as MinijinjaError, ErrorKind, Value,
18};
19use serde::Serialize;
20use serde_json::{self, ser::PrettyFormatter, Value as JsonValue};
21
/// How a chat template expects each message's `content` field to be shaped.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChatTemplateContentFormat {
    /// `content` is a plain string. Used when nothing indicates structured content.
    #[default]
    String,
    /// `content` is an OpenAI-style list of content parts.
    OpenAI,
}

impl std::fmt::Display for ChatTemplateContentFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::String => "string",
            Self::OpenAI => "openai",
        };
        f.write_str(label)
    }
}
40
41pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat {
50 if let Some(format) = detect_format_with_ast(template) {
52 return format;
53 }
54
55 ChatTemplateContentFormat::String
57}
58
/// Evidence gathered while scanning a chat template for structured-content usage.
/// Any single flag being set is enough to classify the template as OpenAI-style.
#[derive(Default, Debug, Clone, Copy)]
struct Flags {
    /// A `{% for %}` loop iterates over message content (or a variable named `content`).
    saw_iteration: bool,
    /// Message content is indexed with a number, passed to `length`, or tested with
    /// `sequence` / `iterable` / `string`.
    saw_structure: bool,
    /// Message content is assigned to a variable named `content`.
    saw_assignment: bool,
    /// A macro body contains both an `{% if %}` whose condition is a test and a loop.
    saw_macro: bool,
}

impl Flags {
    /// Returns `true` when at least one kind of evidence was recorded.
    fn any(self) -> bool {
        [
            self.saw_iteration,
            self.saw_structure,
            self.saw_assignment,
            self.saw_macro,
        ]
        .contains(&true)
    }
}
73
74struct Detector<'a> {
76 ast: &'a Stmt<'a>,
77 scope: std::collections::VecDeque<String>,
79 scope_set: std::collections::HashSet<String>,
80 flags: Flags,
81}
82
83impl<'a> Detector<'a> {
84 fn new(ast: &'a Stmt<'a>) -> Self {
85 Self {
86 ast,
87 scope: std::collections::VecDeque::new(),
88 scope_set: std::collections::HashSet::new(),
89 flags: Flags::default(),
90 }
91 }
92
93 fn run(mut self) -> Flags {
94 self.walk_stmt(self.ast);
95 self.flags
96 }
97
98 fn push_scope(&mut self, var: String) {
99 self.scope.push_back(var.clone());
100 self.scope_set.insert(var);
101 }
102
103 fn pop_scope(&mut self) {
104 if let Some(v) = self.scope.pop_back() {
105 self.scope_set.remove(&v);
106 }
107 }
108
109 fn is_var_access(expr: &Expr, varname: &str) -> bool {
110 matches!(expr, Expr::Var(v) if v.id == varname)
111 }
112
113 fn is_const_str(expr: &Expr, value: &str) -> bool {
114 matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value))
115 }
116
117 fn is_numeric_const(expr: &Expr) -> bool {
118 matches!(expr, Expr::Const(c) if c.value.is_number())
119 }
120
121 fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
123 match expr {
124 Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
125 Expr::GetItem(g) => {
126 Self::is_var_access(&g.expr, varname)
127 && Self::is_const_str(&g.subscript_expr, "content")
128 }
129 Expr::Filter(f) => f
131 .expr
132 .as_ref()
133 .is_some_and(|e| Self::is_var_dot_content(e, varname)),
134 Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
135 _ => false,
136 }
137 }
138
139 fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
141 let mut current_expr = expr;
142 loop {
143 if self
145 .scope_set
146 .iter()
147 .any(|v| Self::is_var_dot_content(current_expr, v))
148 {
149 return true;
150 }
151 match current_expr {
153 Expr::GetAttr(g) => current_expr = &g.expr,
154 Expr::GetItem(g) => current_expr = &g.expr,
155 _ => return false,
156 }
157 }
158 }
159
160 fn walk_stmt(&mut self, stmt: &Stmt) {
161 if self.flags.any() {
163 return;
164 }
165
166 match stmt {
167 Stmt::Template(t) => {
168 for ch in &t.children {
169 self.walk_stmt(ch);
170 }
171 }
172 Stmt::ForLoop(fl) => {
174 if let Expr::Var(iter) = &fl.iter {
176 if iter.id == "messages" {
177 if let Expr::Var(target) = &fl.target {
178 self.push_scope(target.id.to_string());
179 }
180 }
181 }
182
183 if self.is_any_scope_var_content(&fl.iter) {
186 self.flags.saw_iteration = true;
187 }
188 if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
190 self.flags.saw_iteration = true;
191 }
192
193 for b in &fl.body {
194 self.walk_stmt(b);
195 }
196
197 if let Expr::Var(iter) = &fl.iter {
199 if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
200 self.pop_scope();
201 }
202 }
203 }
204 Stmt::IfCond(ic) => {
205 self.inspect_expr_for_structure(&ic.expr);
206 for b in &ic.true_body {
207 self.walk_stmt(b);
208 }
209 for b in &ic.false_body {
210 self.walk_stmt(b);
211 }
212 }
213 Stmt::EmitExpr(e) => {
214 self.inspect_expr_for_structure(&e.expr);
215 }
216 Stmt::Set(s) => {
218 if Self::is_var_access(&s.target, "content")
219 && self.is_any_scope_var_content(&s.expr)
220 {
221 self.flags.saw_assignment = true;
222 }
223 }
224 Stmt::Macro(m) => {
225 let mut has_type_check = false;
227 let mut has_loop = false;
228 Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
229 if has_type_check && has_loop {
230 self.flags.saw_macro = true;
231 }
232 }
233 _ => {}
234 }
235 }
236
237 fn inspect_expr_for_structure(&mut self, expr: &Expr) {
238 if self.flags.saw_structure {
239 return;
240 }
241
242 match expr {
243 Expr::GetItem(gi) => {
245 if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
246 || self.is_any_scope_var_content(&gi.expr))
247 && Self::is_numeric_const(&gi.subscript_expr)
248 {
249 self.flags.saw_structure = true;
250 }
251 }
252 Expr::Filter(f) => {
254 if f.name == "length" {
255 if let Some(inner) = &f.expr {
256 let inner_ref: &Expr = inner;
258 let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
259 if is_content_var || self.is_any_scope_var_content(inner_ref) {
260 self.flags.saw_structure = true;
261 }
262 }
263 } else if let Some(inner) = &f.expr {
264 let inner_ref: &Expr = inner;
265 self.inspect_expr_for_structure(inner_ref);
266 }
267 }
268 Expr::Test(t) => {
270 if t.name == "sequence" || t.name == "iterable" || t.name == "string" {
271 if matches!(&t.expr, Expr::Var(v) if v.id == "content")
272 || self.is_any_scope_var_content(&t.expr)
273 {
274 self.flags.saw_structure = true;
275 }
276 } else {
277 self.inspect_expr_for_structure(&t.expr);
278 }
279 }
280 Expr::GetAttr(g) => {
281 self.inspect_expr_for_structure(&g.expr);
283 }
284 Expr::BinOp(op) => {
286 self.inspect_expr_for_structure(&op.left);
287 self.inspect_expr_for_structure(&op.right);
288 }
289 Expr::UnaryOp(op) => {
291 self.inspect_expr_for_structure(&op.expr);
292 }
293 _ => {}
294 }
295 }
296
297 fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
298 for s in body {
299 if *has_type_check && *has_loop {
300 return;
301 }
302
303 match s {
304 Stmt::IfCond(ic) => {
305 if matches!(&ic.expr, Expr::Test(_)) {
306 *has_type_check = true;
307 }
308 Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
309 Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
310 }
311 Stmt::ForLoop(fl) => {
312 *has_loop = true;
313 Self::scan_macro_body(&fl.body, has_type_check, has_loop);
314 }
315 Stmt::Template(t) => {
316 Self::scan_macro_body(&t.children, has_type_check, has_loop);
317 }
318 _ => {}
319 }
320 }
321 }
322}
323
324fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
327 let ast = match parse(
328 template,
329 "template",
330 SyntaxConfig {},
331 WhitespaceConfig::default(),
332 ) {
333 Ok(ast) => ast,
334 Err(_) => return Some(ChatTemplateContentFormat::String),
335 };
336
337 let flags = Detector::new(&ast).run();
338 Some(if flags.any() {
339 ChatTemplateContentFormat::OpenAI
340 } else {
341 ChatTemplateContentFormat::String
342 })
343}
344
345#[derive(Default)]
347pub struct ChatTemplateParams<'a> {
348 pub add_generation_prompt: bool,
349 pub tools: Option<&'a [serde_json::Value]>,
350 pub documents: Option<&'a [serde_json::Value]>,
351 pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
352}
353
354fn tojson_filter(value: Value, kwargs: Kwargs) -> std::result::Result<Value, MinijinjaError> {
366 let _ensure_ascii: Option<bool> = kwargs.get("ensure_ascii")?;
367 let indent: Option<i64> = kwargs.get("indent")?;
368 let _separators: Option<Value> = kwargs.get("separators")?;
369 let sort_keys: Option<bool> = kwargs.get("sort_keys")?;
370
371 kwargs.assert_all_used()?;
373
374 let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| {
375 MinijinjaError::new(
376 ErrorKind::InvalidOperation,
377 format!("Failed to convert to JSON value: {}", e),
378 )
379 })?;
380
381 fn serialize_with_indent<T: Serialize>(
383 value: &T,
384 spaces: usize,
385 ) -> std::result::Result<String, MinijinjaError> {
386 let indent_str = vec![b' '; spaces];
387 let formatter = PrettyFormatter::with_indent(&indent_str);
388 let mut buf = Vec::new();
389 let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
390 value.serialize(&mut serializer).map_err(|e| {
391 MinijinjaError::new(
392 ErrorKind::InvalidOperation,
393 format!("Failed to serialize JSON: {}", e),
394 )
395 })?;
396 String::from_utf8(buf).map_err(|e| {
397 MinijinjaError::new(
398 ErrorKind::InvalidOperation,
399 format!("Invalid UTF-8 in JSON output: {}", e),
400 )
401 })
402 }
403
404 let json_str: std::result::Result<String, MinijinjaError> = {
406 let sorted_json;
407 let value_to_serialize = if sort_keys.unwrap_or(false) {
408 sorted_json = sort_json_keys(&json_value);
409 &sorted_json
410 } else {
411 &json_value
412 };
413
414 if let Some(spaces) = indent {
415 if spaces < 0 {
416 return Err(MinijinjaError::new(
417 ErrorKind::InvalidOperation,
418 "indent cannot be negative",
419 ));
420 }
421 serialize_with_indent(value_to_serialize, spaces as usize)
422 } else {
423 serde_json::to_string(value_to_serialize).map_err(|e| {
424 MinijinjaError::new(
425 ErrorKind::InvalidOperation,
426 format!("Failed to serialize JSON: {}", e),
427 )
428 })
429 }
430 };
431
432 json_str.map(Value::from_safe_string)
433}
434
435fn sort_json_keys(value: &JsonValue) -> JsonValue {
437 match value {
438 JsonValue::Object(map) => {
439 let mut sorted: serde_json::Map<String, JsonValue> = serde_json::Map::new();
440 let mut keys: Vec<_> = map.keys().collect();
441 keys.sort();
442 for key in keys {
443 sorted.insert(key.clone(), sort_json_keys(&map[key]));
444 }
445 JsonValue::Object(sorted)
446 }
447 JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()),
448 _ => value.clone(),
449 }
450}
451
/// Renders conversations with a single Jinja chat template.
pub struct ChatTemplateProcessor {
    /// Raw Jinja source of the chat template.
    template: String,
}
456
457impl ChatTemplateProcessor {
458 pub fn new(template: String) -> Self {
460 ChatTemplateProcessor { template }
461 }
462
463 pub fn apply_chat_template(
469 &self,
470 messages: &[serde_json::Value],
471 params: ChatTemplateParams,
472 ) -> Result<String> {
473 let mut env = Environment::new();
474
475 env.add_template("chat", &self.template)
477 .map_err(|e| anyhow!("Failed to add template: {}", e))?;
478
479 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
481
482 env.add_filter("tojson", tojson_filter);
486
487 let tmpl = env
489 .get_template("chat")
490 .map_err(|e| anyhow!("Failed to get template: {}", e))?;
491
492 let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
494
495 let base_context = context! {
496 messages => &minijinja_messages,
497 add_generation_prompt => params.add_generation_prompt,
498 tools => params.tools,
499 documents => params.documents,
500 };
501
502 let ctx = if let Some(kwargs) = params.template_kwargs {
504 context! {
505 ..base_context,
506 ..Value::from_serialize(kwargs)
507 }
508 } else {
509 base_context
510 };
511
512 let rendered = tmpl
514 .render(&ctx)
515 .map_err(|e| anyhow!("Failed to render template: {}", e))?;
516
517 Ok(rendered)
518 }
519}
520
521pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
523 let content = fs::read_to_string(config_path)?;
524 let config: serde_json::Value = serde_json::from_str(&content)?;
525
526 if let Some(template) = config.get("chat_template") {
528 if let Some(template_str) = template.as_str() {
529 return Ok(Some(template_str.to_string()));
530 }
531 }
532
533 Ok(None)
534}