// File: batuta/agent/driver/chat_template.rs
//! Chat template formatting for local LLM inference.
//!
//! Different model families require different prompt formats.
//! Auto-detects the template from the model filename:
//! - Qwen (2.5, 3.x), DeepSeek, Yi → ChatML
//! - Llama → Llama 3.x
//! - Unknown → ChatML (most widely supported)
//!
//! Qwen3 uses ChatML with native `<tool_call>` support. Thinking mode
//! (`<think>...</think>`) is controlled by generation params, not template.
//! PMAT-179: Default model is Qwen3 1.7B (0.960 tool-calling score).
//!
//! See: apr-code.md §5.1

15use super::{CompletionRequest, Message, ToolDefinition};
16
/// Chat template family, auto-detected from model filename.
//
// `Eq` is derived alongside `PartialEq`: the enum is fieldless, so equality
// is trivially reflexive and the extra derive lets the type be used as a
// `HashMap` key or in exhaustive equality contexts (clippy:
// `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChatTemplate {
    /// ChatML: `<|im_start|>role\ncontent<|im_end|>` (Qwen, Yi, Deepseek)
    ChatMl,
    /// Llama 3.x: `<|start_header_id|>role<|end_header_id|>\ncontent<|eot_id|>`
    Llama3,
    /// Generic: `<|system|>\ncontent\n<|end|>` (fallback)
    Generic,
}
27
28impl ChatTemplate {
29    /// Detect template from model filename.
30    pub fn from_model_path(path: &std::path::Path) -> Self {
31        let name = path.file_stem().map(|s| s.to_string_lossy().to_lowercase()).unwrap_or_default();
32
33        if name.contains("qwen") || name.contains("deepseek") || name.contains("yi-") {
34            Self::ChatMl
35        } else if name.contains("llama") {
36            Self::Llama3
37        } else {
38            Self::ChatMl
39        }
40    }
41}
42
43/// Format messages using a specific chat template.
44///
45/// For local models, tool definitions from `request.tools` are injected
46/// into the system prompt so the model knows what tools exist and how
47/// to invoke them via `<tool_call>` blocks. API-based drivers handle
48/// tools natively; local models need this explicit injection.
49pub fn format_prompt_with_template(request: &CompletionRequest, template: ChatTemplate) -> String {
50    // Build enriched system prompt with tool definitions
51    let enriched_system = build_enriched_system(&request.system, &request.tools);
52    let enriched_request = CompletionRequest {
53        system: Some(enriched_system),
54        model: request.model.clone(),
55        messages: request.messages.clone(),
56        tools: request.tools.clone(),
57        max_tokens: request.max_tokens,
58        temperature: request.temperature,
59    };
60
61    match template {
62        ChatTemplate::ChatMl => format_chatml(&enriched_request),
63        ChatTemplate::Llama3 => format_llama3(&enriched_request),
64        ChatTemplate::Generic => format_generic(&enriched_request),
65    }
66}
67
68/// Build an enriched system prompt with tool definitions appended.
69///
70/// Local models need explicit tool definitions in text form — unlike
71/// API models (Anthropic/OpenAI) which accept tools as structured params.
72/// The format teaches the model to emit `<tool_call>` blocks that
73/// `parse_tool_calls()` in realizar.rs can extract.
74fn build_enriched_system(base_system: &Option<String>, tools: &[ToolDefinition]) -> String {
75    let mut system = base_system.clone().unwrap_or_default();
76
77    if tools.is_empty() {
78        return system;
79    }
80
81    // Append tool definitions
82    system.push_str("\n\n## Available Tools\n\n");
83    system.push_str(
84        "To use a tool, output a <tool_call> block with JSON inside. \
85         You will receive the result in a <tool_result> block.\n\n",
86    );
87    system.push_str("Format:\n```\n<tool_call>\n{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n");
88
89    for tool in tools {
90        system.push_str(&format!("### {}\n{}\n", tool.name, tool.description));
91        // Compact JSON schema — only include properties, not full schema boilerplate
92        if let Some(props) = tool.input_schema.get("properties") {
93            system.push_str(&format!("Parameters: {}\n\n", compact_schema(props)));
94        } else {
95            system.push('\n');
96        }
97    }
98
99    system.push_str(
100        "After receiving a <tool_result>, analyze it and either use another tool or respond to the user.\n",
101    );
102
103    system
104}
105
106/// Compact a JSON schema properties object into a readable summary.
107fn compact_schema(props: &serde_json::Value) -> String {
108    if let Some(obj) = props.as_object() {
109        let params: Vec<String> = obj
110            .iter()
111            .map(|(k, v)| {
112                let typ = v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
113                let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
114                if desc.is_empty() {
115                    format!("{k}: {typ}")
116                } else {
117                    format!("{k} ({typ}): {desc}")
118                }
119            })
120            .collect();
121        format!("{{{}}}", params.join(", "))
122    } else {
123        props.to_string()
124    }
125}
126
127/// ChatML format (Qwen, DeepSeek, Yi).
128fn format_chatml(request: &CompletionRequest) -> String {
129    let mut prompt = String::new();
130
131    if let Some(ref system) = request.system {
132        prompt.push_str(&format!("<|im_start|>system\n{system}<|im_end|>\n"));
133    }
134
135    for msg in &request.messages {
136        match msg {
137            Message::System(s) => {
138                prompt.push_str(&format!("<|im_start|>system\n{s}<|im_end|>\n"));
139            }
140            Message::User(s) => {
141                prompt.push_str(&format!("<|im_start|>user\n{s}<|im_end|>\n"));
142            }
143            Message::Assistant(s) => {
144                prompt.push_str(&format!("<|im_start|>assistant\n{s}<|im_end|>\n"));
145            }
146            Message::AssistantToolUse(call) => {
147                prompt.push_str(&format!(
148                    "<|im_start|>assistant\n<tool_call>\n{}\n</tool_call><|im_end|>\n",
149                    serde_json::json!({"name": call.name, "input": call.input})
150                ));
151            }
152            Message::ToolResult(result) => {
153                prompt.push_str(&format!(
154                    "<|im_start|>user\n<tool_result>{}</tool_result><|im_end|>\n",
155                    result.content
156                ));
157            }
158        }
159    }
160
161    prompt.push_str("<|im_start|>assistant\n");
162    prompt
163}
164
165/// Llama 3.x format.
166fn format_llama3(request: &CompletionRequest) -> String {
167    let mut prompt = String::new();
168    prompt.push_str("<|begin_of_text|>");
169
170    if let Some(ref system) = request.system {
171        prompt
172            .push_str(&format!("<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"));
173    }
174
175    for msg in &request.messages {
176        match msg {
177            Message::System(s) => {
178                prompt.push_str(&format!(
179                    "<|start_header_id|>system<|end_header_id|>\n\n{s}<|eot_id|>"
180                ));
181            }
182            Message::User(s) => {
183                prompt.push_str(&format!(
184                    "<|start_header_id|>user<|end_header_id|>\n\n{s}<|eot_id|>"
185                ));
186            }
187            Message::Assistant(s) => {
188                prompt.push_str(&format!(
189                    "<|start_header_id|>assistant<|end_header_id|>\n\n{s}<|eot_id|>"
190                ));
191            }
192            Message::AssistantToolUse(call) => {
193                prompt.push_str(&format!(
194                    "<|start_header_id|>assistant<|end_header_id|>\n\n<tool_call>\n{}\n</tool_call><|eot_id|>",
195                    serde_json::json!({"name": call.name, "input": call.input})
196                ));
197            }
198            Message::ToolResult(result) => {
199                prompt.push_str(&format!(
200                    "<|start_header_id|>user<|end_header_id|>\n\n<tool_result>{}</tool_result><|eot_id|>",
201                    result.content
202                ));
203            }
204        }
205    }
206
207    prompt.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
208    prompt
209}
210
211/// Generic fallback format.
212fn format_generic(request: &CompletionRequest) -> String {
213    let mut prompt = String::new();
214
215    if let Some(ref system) = request.system {
216        prompt.push_str(&format!("<|system|>\n{system}\n<|end|>\n"));
217    }
218
219    for msg in &request.messages {
220        match msg {
221            Message::System(s) => {
222                prompt.push_str(&format!("<|system|>\n{s}\n<|end|>\n"));
223            }
224            Message::User(s) => {
225                prompt.push_str(&format!("<|user|>\n{s}\n<|end|>\n"));
226            }
227            Message::Assistant(s) => {
228                prompt.push_str(&format!("<|assistant|>\n{s}\n<|end|>\n"));
229            }
230            Message::AssistantToolUse(call) => {
231                prompt.push_str(&format!(
232                    "<|assistant|>\n<tool_call>\n{}\n</tool_call>\n<|end|>\n",
233                    serde_json::json!({"name": call.name, "input": call.input})
234                ));
235            }
236            Message::ToolResult(result) => {
237                prompt.push_str(&format!(
238                    "<|user|>\n<tool_result>{}</tool_result>\n<|end|>\n",
239                    result.content
240                ));
241            }
242        }
243    }
244
245    prompt.push_str("<|assistant|>\n");
246    prompt
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::driver::ToolCall;

    /// Two representative tool definitions with one string parameter each,
    /// mirroring the shape real drivers pass in.
    fn sample_tools() -> Vec<ToolDefinition> {
        vec![
            ToolDefinition {
                name: "file_read".into(),
                description: "Read file contents".into(),
                input_schema: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "path": {"type": "string", "description": "File path to read"}
                    }
                }),
            },
            ToolDefinition {
                name: "shell".into(),
                description: "Execute shell command".into(),
                input_schema: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "command": {"type": "string", "description": "Command to run"}
                    }
                }),
            },
        ]
    }

    /// With tools present, names, descriptions, the `<tool_call>` protocol
    /// text, and the compacted parameter schema all land in the prompt.
    #[test]
    fn test_tool_definitions_injected_into_system() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::User("Hello".into())],
            tools: sample_tools(),
            max_tokens: 100,
            temperature: 0.5,
            system: Some("You are helpful".into()),
        };
        let prompt = format_prompt_with_template(&request, ChatTemplate::ChatMl);
        assert!(prompt.contains("file_read"), "tool name missing");
        assert!(prompt.contains("Read file contents"), "tool description missing");
        assert!(prompt.contains("shell"), "second tool missing");
        assert!(prompt.contains("<tool_call>"), "tool call format missing");
        assert!(prompt.contains("tool_result"), "tool result format missing");
        assert!(prompt.contains("path (string): File path to read"), "schema missing");
    }

    /// With an empty tool list the system prompt passes through untouched.
    #[test]
    fn test_no_tools_no_injection() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::User("Hello".into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: Some("You are helpful".into()),
        };
        let prompt = format_prompt_with_template(&request, ChatTemplate::ChatMl);
        assert!(prompt.contains("You are helpful"));
        assert!(!prompt.contains("Available Tools"), "no tools = no injection");
    }

    /// Properties with and without a description render in their respective
    /// compact forms.
    #[test]
    fn test_compact_schema() {
        let props = serde_json::json!({
            "path": {"type": "string", "description": "File to read"},
            "limit": {"type": "integer"}
        });
        let result = compact_schema(&props);
        assert!(result.contains("path (string): File to read"));
        assert!(result.contains("limit: integer"));
    }

    /// ChatML output carries system/user turns and ends with an open
    /// assistant header for generation to continue from.
    #[test]
    fn test_format_prompt_chatml() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::User("Hello".into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: Some("You are helpful".into()),
        };
        let prompt = format_chatml(&request);
        assert!(prompt.contains("<|im_start|>system"));
        assert!(prompt.contains("You are helpful"));
        assert!(prompt.contains("<|im_start|>user"));
        assert!(prompt.contains("Hello"));
        assert!(prompt.ends_with("<|im_start|>assistant\n"));
    }

    /// Llama 3.x output starts with `<|begin_of_text|>` and ends with an
    /// open assistant header.
    #[test]
    fn test_format_prompt_llama3() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::User("Hello".into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: Some("Be helpful".into()),
        };
        let prompt = format_llama3(&request);
        assert!(prompt.starts_with("<|begin_of_text|>"));
        assert!(prompt.contains("<|start_header_id|>system<|end_header_id|>"));
        assert!(prompt.contains("Be helpful"));
        assert!(prompt.contains("<|start_header_id|>user<|end_header_id|>"));
        assert!(prompt.contains("Hello"));
        assert!(prompt.ends_with("<|start_header_id|>assistant<|end_header_id|>\n\n"));
    }

    /// Generic fallback uses the simple `<|role|>` delimiters.
    #[test]
    fn test_format_prompt_generic_fallback() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::User("Hello".into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: Some("You are helpful".into()),
        };
        let prompt = format_generic(&request);
        assert!(prompt.contains("<|system|>"));
        assert!(prompt.contains("<|user|>"));
        assert!(prompt.ends_with("<|assistant|>\n"));
    }

    /// Tool-use and tool-result messages render as `<tool_call>` /
    /// `<tool_result>` blocks in all three templates.
    #[test]
    fn test_format_prompt_tool_messages() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![
                Message::AssistantToolUse(ToolCall {
                    id: "1".into(),
                    name: "rag".into(),
                    input: serde_json::json!({"query": "test"}),
                }),
                Message::ToolResult(crate::agent::driver::ToolResultMsg {
                    tool_use_id: "1".into(),
                    content: "result data".into(),
                    is_error: false,
                }),
            ],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: None,
        };
        for template in [ChatTemplate::ChatMl, ChatTemplate::Llama3, ChatTemplate::Generic] {
            let prompt = format_prompt_with_template(&request, template);
            assert!(prompt.contains("<tool_call>"), "missing tool_call in {template:?}");
            assert!(prompt.contains("<tool_result>"), "missing tool_result in {template:?}");
            assert!(prompt.contains("result data"), "missing result data in {template:?}");
        }
    }

    /// Filename heuristics: Qwen/DeepSeek/Yi → ChatML, Llama → Llama3,
    /// anything unrecognized falls back to ChatML. Detection is
    /// case-insensitive and ignores the file extension.
    #[test]
    fn test_chat_template_detection() {
        use std::path::Path;
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("qwen2.5-coder-7b.gguf")),
            ChatTemplate::ChatMl
        );
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("Qwen3-8B-Q4K.apr")),
            ChatTemplate::ChatMl
        );
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("deepseek-coder-v2.gguf")),
            ChatTemplate::ChatMl
        );
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("llama-3.2-3b.gguf")),
            ChatTemplate::Llama3
        );
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("Meta-Llama-3-8B.apr")),
            ChatTemplate::Llama3
        );
        assert_eq!(ChatTemplate::from_model_path(Path::new("yi-34b.gguf")), ChatTemplate::ChatMl);
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("custom-model.gguf")),
            ChatTemplate::ChatMl
        );
    }
}
436
// Cross-driver contract tests live in a sibling file; the `#[path]`
// attribute mounts it here as a test-only child module.
#[cfg(test)]
#[path = "chat_template_contract_tests.rs"]
mod contract_tests;