use std::path::Path;
use minijinja::{context, Environment, Value as JinjaValue};
use serde_json::{json, Value};
use crate::tasks::generate::{ContentBlock, GenerateRequest, Message, ThinkingMode, ToolCall};
use crate::InferenceError;
/// A compiled Jinja chat template together with the special tokens it may
/// reference while rendering a prompt.
pub struct ChatTemplate {
/// minijinja environment holding the compiled template, registered under
/// the name `"chat"`.
env: Environment<'static>,
/// Exposed to the template as `bos_token`; left empty unless `load` finds a
/// value in `tokenizer_config.json`.
bos_token: String,
/// Exposed to the template as `eos_token`; left empty unless `load` finds a
/// value in `tokenizer_config.json`.
eos_token: String,
}
impl ChatTemplate {
pub fn load(model_dir: &Path) -> Result<Option<Self>, InferenceError> {
let Some(src) = Self::read_source(model_dir)? else {
return Ok(None);
};
let mut t = Self::from_source(src)?;
let tok_cfg = model_dir.join("tokenizer_config.json");
if let Ok(raw) = std::fs::read_to_string(&tok_cfg) {
if let Ok(v) = serde_json::from_str::<Value>(&raw) {
let s = |k: &str| {
v.get(k)
.and_then(|x| x.as_str().map(str::to_string).or_else(|| {
x.get("content").and_then(|c| c.as_str()).map(str::to_string)
}))
.unwrap_or_default()
};
t.bos_token = s("bos_token");
t.eos_token = s("eos_token");
}
}
Ok(Some(t))
}
pub fn from_source(src: String) -> Result<Self, InferenceError> {
let mut env = Environment::new();
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
env.add_function(
"raise_exception",
|msg: String| -> Result<JinjaValue, minijinja::Error> {
Err(minijinja::Error::new(
minijinja::ErrorKind::InvalidOperation,
msg,
))
},
);
env.add_template_owned("chat", src)
.map_err(|e| InferenceError::InferenceFailed(format!("chat_template parse: {e}")))?;
Ok(Self {
env,
bos_token: String::new(),
eos_token: String::new(),
})
}
fn read_source(model_dir: &Path) -> Result<Option<String>, InferenceError> {
let jinja = model_dir.join("chat_template.jinja");
if jinja.exists() {
return std::fs::read_to_string(&jinja).map(Some).map_err(|e| {
InferenceError::InferenceFailed(format!("read {}: {e}", jinja.display()))
});
}
let tok_cfg = model_dir.join("tokenizer_config.json");
if tok_cfg.exists() {
let raw = std::fs::read_to_string(&tok_cfg).map_err(|e| {
InferenceError::InferenceFailed(format!("read {}: {e}", tok_cfg.display()))
})?;
let v: Value = serde_json::from_str(&raw).map_err(|e| {
InferenceError::InferenceFailed(format!("parse {}: {e}", tok_cfg.display()))
})?;
if let Some(t) = v.get("chat_template").and_then(|t| t.as_str()) {
return Ok(Some(t.to_string()));
}
}
Ok(None)
}
pub fn render(
&self,
messages: &[Message],
tools: Option<&[Value]>,
thinking: ThinkingMode,
) -> Result<String, InferenceError> {
let msgs: Vec<Value> = messages.iter().filter_map(message_to_json).collect();
let tools: Option<Vec<Value>> = tools.map(|ts| ts.iter().map(normalize_tool).collect());
let enable_thinking = match thinking {
ThinkingMode::On => JinjaValue::from(true),
ThinkingMode::Off => JinjaValue::from(false),
ThinkingMode::Auto => JinjaValue::UNDEFINED,
};
let tmpl = self
.env
.get_template("chat")
.map_err(|e| InferenceError::InferenceFailed(format!("chat_template missing: {e}")))?;
let ctx = context! {
messages => JinjaValue::from_serialize(&msgs),
tools => JinjaValue::from_serialize(&tools),
add_generation_prompt => true,
enable_thinking => enable_thinking,
bos_token => self.bos_token.clone(),
eos_token => self.eos_token.clone(),
};
tmpl.render(ctx)
.map_err(|e| InferenceError::InferenceFailed(format!("chat_template render: {e}")))
}
pub fn render_request(&self, req: &GenerateRequest) -> Result<String, InferenceError> {
let messages = request_messages(req);
self.render(&messages, req.tools.as_deref(), req.params.thinking)
}
}
/// Produces the message list to render for `req`.
///
/// A non-empty `req.messages` is returned as a copy. Otherwise the list is
/// synthesized: a system message carrying `req.context` (only when it is
/// present and non-empty) followed by a user message carrying `req.prompt`.
fn request_messages(req: &GenerateRequest) -> Vec<Message> {
    if let Some(explicit) = req.messages.as_ref().filter(|list| !list.is_empty()) {
        return explicit.clone();
    }

    let system_turn = req
        .context
        .iter()
        .filter(|text| !text.is_empty())
        .map(|text| Message::System {
            content: text.clone(),
        });
    let user_turn = std::iter::once(Message::User {
        content: req.prompt.clone(),
    });
    system_turn.chain(user_turn).collect()
}
fn normalize_tool(t: &Value) -> Value {
let wrapped = t.get("type").and_then(|v| v.as_str()) == Some("function")
&& t.get("function").is_some();
if wrapped {
t.clone()
} else {
json!({ "type": "function", "function": t })
}
}
/// Converts a [`Message`] into the role/content JSON object handed to the
/// template.
///
/// Returns `None` for `Message::ProviderOutputItems`. For multimodal user
/// messages only the text blocks are kept, concatenated without a separator.
fn message_to_json(m: &Message) -> Option<Value> {
    let rendered = match m {
        Message::ProviderOutputItems { .. } => return None,
        Message::System { content } => json!({ "role": "system", "content": content }),
        Message::User { content } => json!({ "role": "user", "content": content }),
        Message::ToolResult {
            tool_use_id,
            content,
        } => json!({
            "role": "tool",
            "tool_call_id": tool_use_id,
            "content": content,
        }),
        Message::Assistant {
            content,
            tool_calls,
        } => {
            let mut entry = json!({ "role": "assistant", "content": content });
            // The key is only attached when there is at least one call.
            if !tool_calls.is_empty() {
                let calls: Vec<Value> = tool_calls.iter().map(toolcall_to_json).collect();
                entry["tool_calls"] = Value::Array(calls);
            }
            entry
        }
        Message::UserMultimodal { content } => {
            let mut text = String::new();
            for block in content.iter() {
                if let ContentBlock::Text { text: chunk } = block {
                    text.push_str(chunk.as_str());
                }
            }
            json!({ "role": "user", "content": text })
        }
    };
    Some(rendered)
}
/// Serializes a [`ToolCall`] as a `"type": "function"` object whose
/// `"function"` member holds the call's name and arguments. The `"id"` key
/// is added only when the call has an id.
fn toolcall_to_json(tc: &ToolCall) -> Value {
    let function = json!({
        "name": tc.name,
        "arguments": tc.arguments,
    });
    let mut call = json!({
        "type": "function",
        "function": function,
    });
    match &tc.id {
        Some(id) => call["id"] = json!(id),
        None => {}
    }
    call
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tasks::generate::Message;
    use crate::tasks::generate::ToolCall;
    use std::collections::HashMap;

    /// Compiles `src` into a [`ChatTemplate`], panicking if it does not parse.
    fn inline_env(src: &str) -> ChatTemplate {
        ChatTemplate::from_source(src.to_string()).unwrap()
    }

    /// Compiles the Gemma 4 chat template fixture.
    fn gemma4_template() -> ChatTemplate {
        inline_env(include_str!("../../tests/fixtures/gemma4_chat_template.jinja"))
    }

    /// A minimal inline template sees messages, normalized tools, the
    /// generation prompt and the thinking flag in each mode.
    #[test]
    fn renders_messages_tools_and_generation_prompt() {
        let src = "\
{%- for m in messages %}<|turn>{{ m.role }}\n{{ m.content }}<turn|>\n{%- endfor %}\
{%- if tools %}TOOLS:{% for t in tools %}{{ t.function.name }} {% endfor %}{% endif %}\
{%- if add_generation_prompt %}<|turn>model\n{% endif %}\
{%- if enable_thinking is defined and enable_thinking %}THINK{% endif %}";
        let template = inline_env(src);
        let conversation = vec![
            Message::System {
                content: "be brief".into(),
            },
            Message::User {
                content: "hi".into(),
            },
        ];
        let tool_list = vec![serde_json::json!({ "name": "get_weather" })];

        let off = template
            .render(&conversation, Some(tool_list.as_slice()), ThinkingMode::Off)
            .unwrap();
        for expected in [
            "<|turn>system\nbe brief<turn|>",
            "<|turn>user\nhi<turn|>",
            "TOOLS:get_weather",
        ] {
            assert!(off.contains(expected));
        }
        assert!(off.trim_end().ends_with("<|turn>model"));
        assert!(!off.contains("THINK"));

        let auto = template
            .render(&conversation, None, ThinkingMode::Auto)
            .unwrap();
        assert!(!auto.contains("THINK"));

        let on = template
            .render(&conversation, None, ThinkingMode::On)
            .unwrap();
        assert!(on.contains("THINK"));
    }

    /// The real Gemma 4 template renders a plain exchange and a full
    /// tool-call round trip.
    #[test]
    fn renders_real_gemma4_template() {
        let tmpl = gemma4_template();
        let declared = vec![serde_json::json!({
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": { "type": "string", "description": "City name" }
                    },
                    "required": ["city"]
                }
            }
        })];
        let tool_arr = declared.as_slice();

        let plain = vec![
            Message::System {
                content: "You are helpful.".into(),
            },
            Message::User {
                content: "Weather in NYC?".into(),
            },
        ];
        let out = tmpl
            .render(&plain, Some(tool_arr), ThinkingMode::Off)
            .unwrap();
        assert!(out.contains("<|turn>system"), "missing system turn:\n{out}");
        assert!(
            out.contains("get_weather"),
            "tool declaration not rendered:\n{out}"
        );
        assert!(out.contains("Weather in NYC?"), "missing user content:\n{out}");
        assert!(
            out.contains("<|turn>model"),
            "missing generation prompt:\n{out}"
        );

        let call_args: HashMap<String, serde_json::Value> =
            HashMap::from([("city".to_string(), serde_json::json!("NYC"))]);
        let round_trip = vec![
            Message::User {
                content: "Weather in NYC?".into(),
            },
            Message::Assistant {
                content: String::new(),
                tool_calls: vec![ToolCall {
                    id: Some("call_0".into()),
                    name: "get_weather".into(),
                    arguments: call_args,
                }],
            },
            Message::ToolResult {
                tool_use_id: "call_0".into(),
                content: "72F sunny".into(),
            },
        ];
        let out2 = tmpl
            .render(&round_trip, Some(tool_arr), ThinkingMode::Off)
            .unwrap();
        assert!(
            out2.contains("<|tool_call>call:get_weather"),
            "tool-call grammar not rendered:\n{out2}"
        );
        assert!(out2.contains("72F sunny"), "tool result not rendered:\n{out2}");
    }

    /// Flat tool definitions are wrapped before reaching the template, and
    /// already-wrapped ones pass through unchanged.
    #[test]
    fn flat_tools_normalize_for_hf_template() {
        let tmpl = gemma4_template();
        let flat = vec![serde_json::json!({
            "name": "write_file",
            "description": "Write text to a file",
            "parameters": {
                "type": "object",
                "properties": { "path": { "type": "string", "description": "Target path" } },
                "required": ["path"]
            }
        })];
        assert!(
            flat[0].get("function").is_none(),
            "fixture must be flat to exercise normalization"
        );

        let request = vec![Message::User {
            content: "Create output.txt".into(),
        }];
        let out = tmpl
            .render(&request, Some(flat.as_slice()), ThinkingMode::Off)
            .expect("flat tools must render after normalization");
        assert!(
            out.contains("write_file"),
            "flat tool declaration not rendered:\n{out}"
        );

        let nested = serde_json::json!({ "type": "function", "function": { "name": "x" } });
        assert_eq!(normalize_tool(&nested), nested);
    }
}