1use super::{CompletionRequest, Message, ToolDefinition};
16
/// Prompt layout used when rendering a conversation for a local chat model.
///
/// Derives `Eq` alongside `PartialEq` (the equality is total for a fieldless
/// enum), which clippy's `derive_partial_eq_without_eq` recommends and which
/// lets the type be used as a `HashMap`/`BTreeMap` key if ever needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChatTemplate {
    /// ChatML turns: `<|im_start|>role ... <|im_end|>` (Qwen, DeepSeek, Yi).
    ChatMl,
    /// Llama 3 turns: `<|start_header_id|>role<|end_header_id|> ... <|eot_id|>`.
    Llama3,
    /// Minimal `<|role|> ... <|end|>` fallback for unknown model families.
    Generic,
}
27
28impl ChatTemplate {
29 pub fn from_model_path(path: &std::path::Path) -> Self {
31 let name = path.file_stem().map(|s| s.to_string_lossy().to_lowercase()).unwrap_or_default();
32
33 if name.contains("qwen") || name.contains("deepseek") || name.contains("yi-") {
34 Self::ChatMl
35 } else if name.contains("llama") {
36 Self::Llama3
37 } else {
38 Self::ChatMl
39 }
40 }
41}
42
43pub fn format_prompt_with_template(request: &CompletionRequest, template: ChatTemplate) -> String {
50 let enriched_system = build_enriched_system(&request.system, &request.tools);
52 let enriched_request = CompletionRequest {
53 system: Some(enriched_system),
54 model: request.model.clone(),
55 messages: request.messages.clone(),
56 tools: request.tools.clone(),
57 max_tokens: request.max_tokens,
58 temperature: request.temperature,
59 };
60
61 match template {
62 ChatTemplate::ChatMl => format_chatml(&enriched_request),
63 ChatTemplate::Llama3 => format_llama3(&enriched_request),
64 ChatTemplate::Generic => format_generic(&enriched_request),
65 }
66}
67
68fn build_enriched_system(base_system: &Option<String>, tools: &[ToolDefinition]) -> String {
75 let mut system = base_system.clone().unwrap_or_default();
76
77 if tools.is_empty() {
78 return system;
79 }
80
81 system.push_str("\n\n## Available Tools\n\n");
83 system.push_str(
84 "To use a tool, output a <tool_call> block with JSON inside. \
85 You will receive the result in a <tool_result> block.\n\n",
86 );
87 system.push_str("Format:\n```\n<tool_call>\n{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n");
88
89 for tool in tools {
90 system.push_str(&format!("### {}\n{}\n", tool.name, tool.description));
91 if let Some(props) = tool.input_schema.get("properties") {
93 system.push_str(&format!("Parameters: {}\n\n", compact_schema(props)));
94 } else {
95 system.push('\n');
96 }
97 }
98
99 system.push_str(
100 "After receiving a <tool_result>, analyze it and either use another tool or respond to the user.\n",
101 );
102
103 system
104}
105
106fn compact_schema(props: &serde_json::Value) -> String {
108 if let Some(obj) = props.as_object() {
109 let params: Vec<String> = obj
110 .iter()
111 .map(|(k, v)| {
112 let typ = v.get("type").and_then(|t| t.as_str()).unwrap_or("string");
113 let desc = v.get("description").and_then(|d| d.as_str()).unwrap_or("");
114 if desc.is_empty() {
115 format!("{k}: {typ}")
116 } else {
117 format!("{k} ({typ}): {desc}")
118 }
119 })
120 .collect();
121 format!("{{{}}}", params.join(", "))
122 } else {
123 props.to_string()
124 }
125}
126
127fn format_chatml(request: &CompletionRequest) -> String {
129 let mut prompt = String::new();
130
131 if let Some(ref system) = request.system {
132 prompt.push_str(&format!("<|im_start|>system\n{system}<|im_end|>\n"));
133 }
134
135 for msg in &request.messages {
136 match msg {
137 Message::System(s) => {
138 prompt.push_str(&format!("<|im_start|>system\n{s}<|im_end|>\n"));
139 }
140 Message::User(s) => {
141 prompt.push_str(&format!("<|im_start|>user\n{s}<|im_end|>\n"));
142 }
143 Message::Assistant(s) => {
144 prompt.push_str(&format!("<|im_start|>assistant\n{s}<|im_end|>\n"));
145 }
146 Message::AssistantToolUse(call) => {
147 prompt.push_str(&format!(
148 "<|im_start|>assistant\n<tool_call>\n{}\n</tool_call><|im_end|>\n",
149 serde_json::json!({"name": call.name, "input": call.input})
150 ));
151 }
152 Message::ToolResult(result) => {
153 prompt.push_str(&format!(
154 "<|im_start|>user\n<tool_result>{}</tool_result><|im_end|>\n",
155 result.content
156 ));
157 }
158 }
159 }
160
161 prompt.push_str("<|im_start|>assistant\n");
162 prompt
163}
164
165fn format_llama3(request: &CompletionRequest) -> String {
167 let mut prompt = String::new();
168 prompt.push_str("<|begin_of_text|>");
169
170 if let Some(ref system) = request.system {
171 prompt
172 .push_str(&format!("<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"));
173 }
174
175 for msg in &request.messages {
176 match msg {
177 Message::System(s) => {
178 prompt.push_str(&format!(
179 "<|start_header_id|>system<|end_header_id|>\n\n{s}<|eot_id|>"
180 ));
181 }
182 Message::User(s) => {
183 prompt.push_str(&format!(
184 "<|start_header_id|>user<|end_header_id|>\n\n{s}<|eot_id|>"
185 ));
186 }
187 Message::Assistant(s) => {
188 prompt.push_str(&format!(
189 "<|start_header_id|>assistant<|end_header_id|>\n\n{s}<|eot_id|>"
190 ));
191 }
192 Message::AssistantToolUse(call) => {
193 prompt.push_str(&format!(
194 "<|start_header_id|>assistant<|end_header_id|>\n\n<tool_call>\n{}\n</tool_call><|eot_id|>",
195 serde_json::json!({"name": call.name, "input": call.input})
196 ));
197 }
198 Message::ToolResult(result) => {
199 prompt.push_str(&format!(
200 "<|start_header_id|>user<|end_header_id|>\n\n<tool_result>{}</tool_result><|eot_id|>",
201 result.content
202 ));
203 }
204 }
205 }
206
207 prompt.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
208 prompt
209}
210
211fn format_generic(request: &CompletionRequest) -> String {
213 let mut prompt = String::new();
214
215 if let Some(ref system) = request.system {
216 prompt.push_str(&format!("<|system|>\n{system}\n<|end|>\n"));
217 }
218
219 for msg in &request.messages {
220 match msg {
221 Message::System(s) => {
222 prompt.push_str(&format!("<|system|>\n{s}\n<|end|>\n"));
223 }
224 Message::User(s) => {
225 prompt.push_str(&format!("<|user|>\n{s}\n<|end|>\n"));
226 }
227 Message::Assistant(s) => {
228 prompt.push_str(&format!("<|assistant|>\n{s}\n<|end|>\n"));
229 }
230 Message::AssistantToolUse(call) => {
231 prompt.push_str(&format!(
232 "<|assistant|>\n<tool_call>\n{}\n</tool_call>\n<|end|>\n",
233 serde_json::json!({"name": call.name, "input": call.input})
234 ));
235 }
236 Message::ToolResult(result) => {
237 prompt.push_str(&format!(
238 "<|user|>\n<tool_result>{}</tool_result>\n<|end|>\n",
239 result.content
240 ));
241 }
242 }
243 }
244
245 prompt.push_str("<|assistant|>\n");
246 prompt
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::driver::ToolCall;

    // Two representative tool definitions (one string parameter each) shared
    // by the injection tests below.
    fn sample_tools() -> Vec<ToolDefinition> {
        vec![
            ToolDefinition {
                name: "file_read".into(),
                description: "Read file contents".into(),
                input_schema: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "path": {"type": "string", "description": "File path to read"}
                    }
                }),
            },
            ToolDefinition {
                name: "shell".into(),
                description: "Execute shell command".into(),
                input_schema: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "command": {"type": "string", "description": "Command to run"}
                    }
                }),
            },
        ]
    }

    // Tool names, descriptions, the call/result protocol markers, and the
    // compacted parameter schema must all surface in the rendered prompt.
    #[test]
    fn test_tool_definitions_injected_into_system() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::User("Hello".into())],
            tools: sample_tools(),
            max_tokens: 100,
            temperature: 0.5,
            system: Some("You are helpful".into()),
        };
        let prompt = format_prompt_with_template(&request, ChatTemplate::ChatMl);
        assert!(prompt.contains("file_read"), "tool name missing");
        assert!(prompt.contains("Read file contents"), "tool description missing");
        assert!(prompt.contains("shell"), "second tool missing");
        assert!(prompt.contains("<tool_call>"), "tool call format missing");
        assert!(prompt.contains("tool_result"), "tool result format missing");
        assert!(prompt.contains("path (string): File path to read"), "schema missing");
    }

    // With an empty tool list the system prompt must pass through unchanged.
    #[test]
    fn test_no_tools_no_injection() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::User("Hello".into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: Some("You are helpful".into()),
        };
        let prompt = format_prompt_with_template(&request, ChatTemplate::ChatMl);
        assert!(prompt.contains("You are helpful"));
        assert!(!prompt.contains("Available Tools"), "no tools = no injection");
    }

    // Properties with a description render as `name (type): desc`; without
    // one they render as `name: type`.
    #[test]
    fn test_compact_schema() {
        let props = serde_json::json!({
            "path": {"type": "string", "description": "File to read"},
            "limit": {"type": "integer"}
        });
        let result = compact_schema(&props);
        assert!(result.contains("path (string): File to read"));
        assert!(result.contains("limit: integer"));
    }

    // ChatML output must carry role markers and end with an open assistant turn.
    #[test]
    fn test_format_prompt_chatml() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::User("Hello".into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: Some("You are helpful".into()),
        };
        let prompt = format_chatml(&request);
        assert!(prompt.contains("<|im_start|>system"));
        assert!(prompt.contains("You are helpful"));
        assert!(prompt.contains("<|im_start|>user"));
        assert!(prompt.contains("Hello"));
        assert!(prompt.ends_with("<|im_start|>assistant\n"));
    }

    // Llama 3 output must begin with BOS and end with an open assistant header.
    #[test]
    fn test_format_prompt_llama3() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::User("Hello".into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: Some("Be helpful".into()),
        };
        let prompt = format_llama3(&request);
        assert!(prompt.starts_with("<|begin_of_text|>"));
        assert!(prompt.contains("<|start_header_id|>system<|end_header_id|>"));
        assert!(prompt.contains("Be helpful"));
        assert!(prompt.contains("<|start_header_id|>user<|end_header_id|>"));
        assert!(prompt.contains("Hello"));
        assert!(prompt.ends_with("<|start_header_id|>assistant<|end_header_id|>\n\n"));
    }

    // Generic fallback uses the `<|role|>` markers.
    #[test]
    fn test_format_prompt_generic_fallback() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![Message::User("Hello".into())],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: Some("You are helpful".into()),
        };
        let prompt = format_generic(&request);
        assert!(prompt.contains("<|system|>"));
        assert!(prompt.contains("<|user|>"));
        assert!(prompt.ends_with("<|assistant|>\n"));
    }

    // Tool-use and tool-result messages must be serialized by every template.
    #[test]
    fn test_format_prompt_tool_messages() {
        let request = CompletionRequest {
            model: "test".into(),
            messages: vec![
                Message::AssistantToolUse(ToolCall {
                    id: "1".into(),
                    name: "rag".into(),
                    input: serde_json::json!({"query": "test"}),
                }),
                Message::ToolResult(crate::agent::driver::ToolResultMsg {
                    tool_use_id: "1".into(),
                    content: "result data".into(),
                    is_error: false,
                }),
            ],
            tools: vec![],
            max_tokens: 100,
            temperature: 0.5,
            system: None,
        };
        for template in [ChatTemplate::ChatMl, ChatTemplate::Llama3, ChatTemplate::Generic] {
            let prompt = format_prompt_with_template(&request, template);
            assert!(prompt.contains("<tool_call>"), "missing tool_call in {template:?}");
            assert!(prompt.contains("<tool_result>"), "missing tool_result in {template:?}");
            assert!(prompt.contains("result data"), "missing result data in {template:?}");
        }
    }

    // Detection is case-insensitive (via lowercasing) and defaults to ChatML
    // for unrecognized model names.
    #[test]
    fn test_chat_template_detection() {
        use std::path::Path;
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("qwen2.5-coder-7b.gguf")),
            ChatTemplate::ChatMl
        );
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("Qwen3-8B-Q4K.apr")),
            ChatTemplate::ChatMl
        );
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("deepseek-coder-v2.gguf")),
            ChatTemplate::ChatMl
        );
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("llama-3.2-3b.gguf")),
            ChatTemplate::Llama3
        );
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("Meta-Llama-3-8B.apr")),
            ChatTemplate::Llama3
        );
        assert_eq!(ChatTemplate::from_model_path(Path::new("yi-34b.gguf")), ChatTemplate::ChatMl);
        assert_eq!(
            ChatTemplate::from_model_path(Path::new("custom-model.gguf")),
            ChatTemplate::ChatMl
        );
    }
}
436
// Cross-template contract tests live in a sibling file to keep this one focused.
#[cfg(test)]
#[path = "chat_template_contract_tests.rs"]
mod contract_tests;