use serde_json::{json, Value};
use super::message::ChatMessage;
use super::tool_call::ToolDefinition;
use crate::error::{LlamaError, Result};
/// Inputs for [`apply_chat_template_oaicompat`]; presumably mirrors an OpenAI-compatible
/// chat-completion request.
///
/// The conversation and tool list are borrowed slices, so building the parameters does not
/// copy them. `Default` yields empty slices, `None` options and `false` flags.
#[derive(Clone, Debug, Default)]
pub struct OpenAIChatTemplateParams<'a> {
    /// Conversation to format, in order.
    pub messages: &'a [ChatMessage],
    /// Tool definitions offered to the model.
    pub tools: &'a [ToolDefinition],
    /// Requested tool-choice value. Not reflected in the generated prompt text.
    pub tool_choice: Option<String>,
    /// Presumably a JSON schema the response should follow. Not reflected in the generated
    /// prompt text.
    pub json_schema: Option<Value>,
    /// Reasoning-output format name. NOTE(review): not read by the formatter in this file.
    pub reasoning_format: Option<String>,
    /// Extra template keyword arguments. NOTE(review): not read by the formatter in this file.
    pub chat_template_kwargs: Option<Value>,
    /// When `true`, an assistant header is appended so generation continues as the assistant.
    pub add_generation_prompt: bool,
    /// Presumably selects Jinja rendering. NOTE(review): not read by the formatter in this file.
    pub use_jinja: bool,
    /// Presumably allows parallel tool calls. NOTE(review): not read by the formatter in this file.
    pub parallel_tool_calls: bool,
    /// Presumably requests a leading BOS token. NOTE(review): not read by the formatter in this file.
    pub add_bos: bool,
    /// Presumably requests a trailing EOS token. NOTE(review): not read by the formatter in this file.
    pub add_eos: bool,
}
/// Output of [`apply_chat_template_oaicompat`].
///
/// `Default` yields an empty prompt, no grammar and no extra stop strings.
#[derive(Clone, Debug, Default)]
pub struct ChatTemplateResult {
    /// Fully formatted prompt text.
    pub prompt: String,
    /// Optional grammar to constrain sampling, when one is produced.
    pub grammar: Option<String>,
    /// Extra stop strings the caller should presumably apply during generation.
    pub additional_stops: Vec<String>,
}
/// Formats OpenAI-style chat parameters into a prompt string.
///
/// The prompt comes from a plain-text formatter with bracketed role headers; no Jinja
/// template is rendered here. Layout, in order:
///
/// 1. every `Role::System` message, in original order, as `[SYSTEM]\n<content>\n`
///    (system instructions are hoisted ahead of the conversation);
/// 2. every other message, in original order, as `[<ROLE>]\n<content>\n`, where `<ROLE>` is
///    `role.as_str()` upper-cased;
/// 3. if `tools` is non-empty, `\n[TOOLS]\n` followed by `- <name>: <description>\n` per tool;
/// 4. if `add_generation_prompt` is set, `\n[ASSISTANT]\n`.
///
/// Only message `role`/`content`, tool `name`/`description` and `add_generation_prompt`
/// affect the output. Message `name`, `tool_call_id` and `tool_calls` are ignored. So are the
/// `tool_choice`, `json_schema`, `reasoning_format`, `chat_template_kwargs`, `use_jinja`,
/// `parallel_tool_calls`, `add_bos` and `add_eos` parameters. As a result, the returned
/// `grammar` is always `None` and `additional_stops` is always empty.
///
/// # Errors
///
/// Never returns `Err` at present; the `Result` return type is kept for API stability.
pub fn apply_chat_template_oaicompat(
    params: OpenAIChatTemplateParams<'_>,
) -> Result<ChatTemplateResult> {
    Ok(ChatTemplateResult {
        prompt: render_plain_prompt(&params),
        ..Default::default()
    })
}

/// Renders the bracketed plain-text prompt described on `apply_chat_template_oaicompat`.
fn render_plain_prompt(params: &OpenAIChatTemplateParams<'_>) -> String {
    use super::message::Role;
    use std::fmt::Write as _;

    let mut prompt = String::new();

    // `fmt::Write` for `String` never fails, so the `writeln!` results below carry no
    // information and are discarded.

    // Every system message is kept, in order, and hoisted above the conversation.
    for m in params.messages.iter().filter(|m| m.role == Role::System) {
        let _ = writeln!(prompt, "[SYSTEM]\n{}", m.content);
    }
    for m in params.messages.iter().filter(|m| m.role != Role::System) {
        let _ = writeln!(prompt, "[{}]\n{}", m.role.as_str().to_uppercase(), m.content);
    }

    if !params.tools.is_empty() {
        prompt.push_str("\n[TOOLS]\n");
        for t in params.tools {
            let _ = writeln!(prompt, "- {}: {}", t.name, t.description);
        }
    }

    if params.add_generation_prompt {
        prompt.push_str("\n[ASSISTANT]\n");
    }
    prompt
}

/// Assembles the Jinja template and its rendering context for `params`.
///
/// The context is a JSON object with these keys:
/// - `messages`: each carries `role`, `content`, `name`, `tool_call_id`, and `tool_calls`
///   as `{id, type: "function", function: {name, arguments}}` with `arguments` stringified.
/// - `tools`: built from `ToolDefinition::to_openai_function`.
/// - `add_generation_prompt`.
/// - `tool_choice` and `json_schema`: present only when set.
///
/// NOTE(review): not called yet. No Jinja renderer is wired in, so building this on every
/// request would be wasted work. It is kept so the context shape lives in one place for when
/// rendering is added.
#[allow(dead_code)] // Intentionally unused until a Jinja renderer is integrated.
fn build_jinja_inputs(params: &OpenAIChatTemplateParams<'_>) -> (&'static str, Value) {
    let messages: Vec<Value> = params
        .messages
        .iter()
        .map(|m| {
            json!({
                "role": m.role.as_str(),
                "content": m.content,
                "name": m.name,
                "tool_call_id": m.tool_call_id,
                "tool_calls": m.tool_calls.iter().map(|c| json!({
                    "id": c.id,
                    "type": "function",
                    "function": {
                        "name": c.name,
                        "arguments": c.arguments.to_string(),
                    }
                })).collect::<Vec<_>>(),
            })
        })
        .collect();
    let tools: Vec<Value> = params
        .tools
        .iter()
        .map(|t| t.to_openai_function())
        .collect();

    let mut context = serde_json::Map::new();
    context.insert("messages".into(), Value::Array(messages));
    context.insert("tools".into(), Value::Array(tools));
    context.insert(
        "add_generation_prompt".into(),
        json!(params.add_generation_prompt),
    );
    if let Some(tc) = &params.tool_choice {
        context.insert("tool_choice".into(), json!(tc));
    }
    if let Some(schema) = &params.json_schema {
        context.insert("json_schema".into(), schema.clone());
    }
    (jinja_template_with_tools(), Value::Object(context))
}
/// Returns a ChatML-style (`<|im_start|>` / `<|im_end|>`) Jinja template for a conversation
/// plus an optional tool list.
///
/// The template reads these context values:
/// - `messages`: each with `role` and `content`.
/// - `tools`: each with `function.name` and `function.description`.
/// - `add_generation_prompt`.
///
/// Only the roles `system`, `user`, `assistant` and `tool` produce output. The tool list is
/// emitted as a separate system turn *after* all messages, followed by an open assistant turn
/// when `add_generation_prompt` is truthy.
///
/// NOTE(review): the tags use no whitespace control (`{%-` / `-%}`). Unless the renderer
/// enables `trim_blocks`, the newline after each block tag is likely to appear in the output
/// as blank lines; confirm against the renderer's settings.
fn jinja_template_with_tools() -> &'static str {
    // Raw string: every byte, including the leading and trailing newline, is template text.
    const CHATML_WITH_TOOLS: &str = r#"
{% for m in messages %}
{% if m.role == "system" %}<|im_start|>system
{{ m.content }}<|im_end|>
{% elif m.role == "user" %}<|im_start|>user
{{ m.content }}<|im_end|>
{% elif m.role == "assistant" %}<|im_start|>assistant
{{ m.content }}<|im_end|>
{% elif m.role == "tool" %}<|im_start|>tool
{{ m.content }}<|im_end|>
{% endif %}
{% endfor %}{% if tools %}<|im_start|>system
{% for t in tools %}{{ t.function.name }}: {{ t.function.description }}
{% endfor %}<|im_end|>
{% endif %}{% if add_generation_prompt %}<|im_start|>assistant
{% endif %}
"#;
    CHATML_WITH_TOOLS
}