1use async_trait::async_trait;
12use std::path::PathBuf;
13
14use super::chat_template::{format_prompt_with_template, ChatTemplate};
15use super::validate::validate_model_file;
16use super::{CompletionRequest, CompletionResponse, LlmDriver, ToolCall};
17use crate::agent::result::{AgentError, DriverError, StopReason, TokenUsage};
18use crate::serve::backends::PrivacyTier;
19
/// LLM driver that runs inference locally through the `realizar` crate.
///
/// Construct it with `RealizarDriver::new`, which checks that the model file exists
/// and passes `validate_model_file` before building the driver.
pub struct RealizarDriver {
    /// Path to the model file; cloned into every `realizar::infer::InferenceConfig`.
    model_path: PathBuf,
    /// Value reported by `LlmDriver::context_window` (defaults to 4096 in `new`).
    context_window_size: usize,
    /// Chat template used to render requests into a prompt, chosen from `model_path`.
    template: ChatTemplate,
}
29
30impl RealizarDriver {
31 pub fn new(model_path: PathBuf, context_window: Option<usize>) -> Result<Self, AgentError> {
43 if !model_path.exists() {
44 return Err(AgentError::Driver(DriverError::InferenceFailed(format!(
45 "model not found: {}",
46 model_path.display()
47 ))));
48 }
49
50 validate_model_file(&model_path)?;
52
53 let context_window_size = context_window.unwrap_or(4096);
54 let template = ChatTemplate::from_model_path(&model_path);
55 Ok(Self { model_path, context_window_size, template })
56 }
57}
58
59#[async_trait]
60impl LlmDriver for RealizarDriver {
61 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
62 let prompt = format_prompt_with_template(&request, self.template);
64
65 let config = realizar::infer::InferenceConfig {
67 model_path: self.model_path.clone(),
68 prompt: Some(prompt),
69 input_tokens: None,
70 max_tokens: request.max_tokens as usize,
71 temperature: request.temperature,
72 top_k: 0,
73 no_gpu: self.model_path.extension().is_some_and(|e| e == "apr"),
76 trace: false,
77 trace_verbose: false,
78 trace_output: None,
79 trace_steps: None,
80 verbose: false,
81 use_mock_backend: false,
82 stop_tokens: vec![],
83 };
84
85 let result = tokio::task::spawn_blocking(move || realizar::infer::run_inference(&config))
87 .await
88 .map_err(|e| {
89 AgentError::Driver(DriverError::InferenceFailed(format!("spawn_blocking: {e}")))
90 })?
91 .map_err(|e| AgentError::Driver(DriverError::InferenceFailed(e.to_string())))?;
92
93 let (raw_text, tool_calls) = parse_tool_calls(&result.text);
95
96 let text = sanitize_output(&raw_text, request.system.as_deref());
98
99 let stop_reason =
100 if tool_calls.is_empty() { StopReason::EndTurn } else { StopReason::ToolUse };
101
102 Ok(CompletionResponse {
103 text,
104 stop_reason,
105 tool_calls,
106 usage: TokenUsage {
107 input_tokens: result.input_token_count as u64,
108 output_tokens: result.generated_token_count as u64,
109 },
110 })
111 }
112
113 fn context_window(&self) -> usize {
114 self.context_window_size
115 }
116
117 fn privacy_tier(&self) -> PrivacyTier {
118 PrivacyTier::Sovereign
119 }
120}
121
/// Public entry point to the module's private tool-call parser.
///
/// Forwards directly to `parse_tool_calls`. It returns the remaining plain text
/// (trimmed) and the tool calls found in `text`.
// NOTE(review): presumably exposed so callers outside this module (tests, other
// drivers) can reuse the parser — confirm against call sites.
pub fn parse_tool_calls_pub(text: &str) -> (String, Vec<ToolCall>) {
    parse_tool_calls(text)
}
135
136fn parse_tool_calls(text: &str) -> (String, Vec<ToolCall>) {
137 let mut tool_calls = Vec::new();
138 let mut remaining = String::new();
139 let mut call_counter = 0u32;
140
141 let mut cursor = text;
142 loop {
143 let xml_pos = cursor.find("<tool_call>");
145 let md_pos = cursor.find("```json");
146
147 let (start, tag_len, is_markdown) = match (xml_pos, md_pos) {
148 (Some(x), Some(m)) if x <= m => (x, "<tool_call>".len(), false),
149 (Some(x), None) => (x, "<tool_call>".len(), false),
150 (_, Some(m)) => (m, "```json".len(), true),
151 (None, None) => {
152 remaining.push_str(cursor);
153 break;
154 }
155 };
156
157 remaining.push_str(&cursor[..start]);
158 let after_tag = &cursor[start + tag_len..];
159
160 let (json_str, advance_past) = if is_markdown {
162 if let Some(end) = after_tag.find("```") {
164 (&after_tag[..end], &after_tag[end + "```".len()..])
165 } else {
166 (after_tag, "")
167 }
168 } else if let Some(end) = after_tag.find("</tool_call>") {
169 (&after_tag[..end], &after_tag[end + "</tool_call>".len()..])
170 } else {
171 (after_tag, "")
173 };
174 let json_str = json_str.trim();
175
176 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
177 if let Some(name) = parsed.get("name").and_then(|n| n.as_str()) {
179 let name = name.to_string();
180 let input = parsed.get("input").cloned().unwrap_or(serde_json::json!({}));
181 call_counter += 1;
182 tool_calls.push(ToolCall { id: format!("local-{call_counter}"), name, input });
183 } else {
184 remaining.push_str(&cursor[start..]);
185 break;
186 }
187 } else {
188 remaining.push_str(&cursor[start..]);
189 break;
190 }
191
192 cursor = advance_past;
193 if cursor.is_empty() {
194 break;
195 }
196 }
197
198 (remaining.trim().to_string(), tool_calls)
199}
200
/// Clean up raw model output before returning it to the caller.
///
/// Steps, in order:
///
/// 1. If `system_prompt` is given and the output starts by echoing it, the echoed
///    span is removed. An echo is detected when the output begins with the first
///    80 bytes of the prompt (rounded down to a UTF-8 character boundary), or with
///    the whole prompt if it is shorter. Only the part that actually matches the
///    prompt is removed, so model text after a partial echo is preserved.
/// 2. Chat-template control markers (`<|im_start|>`, `<|eot_id|>`, ...) are deleted.
/// 3. Whitespace is trimmed, and a leading `system\n` and then `assistant\n` role
///    label is dropped.
fn sanitize_output(text: &str, system_prompt: Option<&str>) -> String {
    // Number of leading system-prompt bytes that must be echoed to trigger stripping.
    const ECHO_DETECT_BYTES: usize = 80;

    let mut cleaned = match system_prompt {
        Some(sys) => strip_echoed_prefix(text, sys, ECHO_DETECT_BYTES).to_string(),
        None => text.to_string(),
    };

    for marker in &[
        "<|im_start|>",
        "<|im_end|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>",
        "<|system|>",
        "<|user|>",
        "<|assistant|>",
        "<|end|>",
    ] {
        cleaned = cleaned.replace(marker, "");
    }

    let cleaned = cleaned.trim();
    // Role labels can survive marker removal, e.g. "<|im_start|>assistant\n...".
    let cleaned = cleaned.strip_prefix("system\n").unwrap_or(cleaned);
    let cleaned = cleaned.strip_prefix("assistant\n").unwrap_or(cleaned);
    cleaned.trim().to_string()
}

/// Remove the leading part of `text` that repeats `sys`.
///
/// Stripping happens only if `text` starts with the first `min_bytes` bytes of `sys`
/// (rounded down to a character boundary; the whole of `sys` if it is shorter). The
/// removed span is the longest common prefix of `text` and `sys`, measured in whole
/// characters, so slicing never splits a UTF-8 sequence.
fn strip_echoed_prefix<'a>(text: &'a str, sys: &str, min_bytes: usize) -> &'a str {
    let mut threshold = sys.len().min(min_bytes);
    // Byte-indexing a `str` off a char boundary panics; index 0 is always a boundary.
    while !sys.is_char_boundary(threshold) {
        threshold -= 1;
    }

    let common: usize = text
        .chars()
        .zip(sys.chars())
        .take_while(|(a, b)| a == b)
        .map(|(c, _)| c.len_utf8())
        .sum();

    if common >= threshold {
        &text[common..]
    } else {
        text
    }
}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a driver directly from its fields, bypassing `new`'s filesystem checks.
    fn driver_for(path: &str, context_window_size: usize) -> RealizarDriver {
        let model_path = PathBuf::from(path);
        let template = ChatTemplate::from_model_path(&model_path);
        RealizarDriver { model_path, context_window_size, template }
    }

    #[test]
    fn test_parse_no_tool_calls() {
        let (text, calls) = parse_tool_calls("Hello world");
        assert_eq!(text, "Hello world");
        assert!(calls.is_empty());
    }

    #[test]
    fn test_parse_single_tool_call() {
        let input = r#"Before text
<tool_call>
{"name": "rag", "input": {"query": "SIMD"}}
</tool_call>
After text"#;
        let (text, calls) = parse_tool_calls(input);
        assert_eq!(text, "Before text\n\nAfter text");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "rag");
        assert_eq!(calls[0].id, "local-1");
        assert_eq!(calls[0].input, serde_json::json!({"query": "SIMD"}));
    }

    #[test]
    fn test_parse_multiple_tool_calls() {
        let input = r#"<tool_call>
{"name": "rag", "input": {"query": "a"}}
</tool_call>
Middle
<tool_call>
{"name": "memory", "input": {"action": "recall", "query": "b"}}
</tool_call>"#;
        let (text, calls) = parse_tool_calls(input);
        assert_eq!(text, "Middle");
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "rag");
        assert_eq!(calls[0].id, "local-1");
        assert_eq!(calls[1].name, "memory");
        assert_eq!(calls[1].id, "local-2");
    }

    #[test]
    fn test_parse_malformed_json() {
        let input = r#"<tool_call>
not valid json
</tool_call>"#;
        let (_text, calls) = parse_tool_calls(input);
        assert!(calls.is_empty());
    }

    #[test]
    fn test_parse_missing_close_tag_with_valid_json() {
        let input =
            "<tool_call>\n{\"name\": \"file_read\", \"input\": {\"path\": \"src/main.rs\"}}";
        let (text, calls) = parse_tool_calls(input);
        assert_eq!(calls.len(), 1, "should extract tool call without closing tag");
        assert_eq!(calls[0].name, "file_read");
        assert!(text.is_empty(), "no remaining text expected");
    }

    #[test]
    fn test_parse_missing_close_tag_with_trailing_text() {
        let input =
            "Let me read that.\n<tool_call> {\"name\": \"file_read\", \"input\": {\"path\": \"foo.rs\"}}";
        let (text, calls) = parse_tool_calls(input);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "file_read");
        assert!(text.contains("Let me read that"));
    }

    #[test]
    fn test_parse_missing_close_tag_invalid_json() {
        let input = "<tool_call>\nnot valid json at all";
        let (text, calls) = parse_tool_calls(input);
        assert!(calls.is_empty(), "invalid JSON should not produce tool call");
        assert!(text.contains("<tool_call>"));
    }

    #[test]
    fn test_parse_markdown_code_block() {
        let input = "Let me read that file.\n```json\n{\"name\": \"file_read\", \"input\": {\"path\": \"src/main.rs\"}}\n```";
        let (text, calls) = parse_tool_calls(input);
        assert_eq!(calls.len(), 1, "should extract tool call from markdown block");
        assert_eq!(calls[0].name, "file_read");
        assert_eq!(calls[0].input["path"], "src/main.rs");
        assert!(text.contains("Let me read that"));
    }

    #[test]
    fn test_parse_markdown_code_block_not_tool_call() {
        let input = "Here's an example:\n```json\n{\"key\": \"value\"}\n```";
        let (text, calls) = parse_tool_calls(input);
        assert!(calls.is_empty(), "JSON without name field should not be a tool call");
        assert!(text.contains("example"));
    }

    #[test]
    fn test_parse_missing_name() {
        let input = r#"<tool_call>
{"input": {"query": "test"}}
</tool_call>"#;
        let (_, calls) = parse_tool_calls(input);
        assert!(calls.is_empty(), "JSON without name should not be extracted");
    }

    // Exercises the driver's trait method rather than comparing a constant to itself.
    #[test]
    fn test_privacy_tier_always_sovereign() {
        let driver = driver_for("model.gguf", 4096);
        assert_eq!(driver.privacy_tier(), PrivacyTier::Sovereign);
    }

    #[test]
    fn test_context_window_reports_configured_size() {
        let driver = driver_for("model.gguf", 2048);
        assert_eq!(driver.context_window(), 2048);
    }

    #[test]
    fn test_new_rejects_missing_model() {
        let result =
            RealizarDriver::new(PathBuf::from("/nonexistent/realizar-test/model.gguf"), None);
        assert!(matches!(result, Err(AgentError::Driver(DriverError::InferenceFailed(_)))));
    }

    #[test]
    fn test_sanitize_strips_echoed_system_prompt() {
        let sys = "You are apr code, a sovereign AI coding assistant.";
        let output = format!("{sys} And then the model continues here.");
        let cleaned = sanitize_output(&output, Some(sys));
        assert!(!cleaned.contains("sovereign AI coding assistant"));
        assert!(cleaned.contains("continues here"));
    }

    #[test]
    fn test_sanitize_strips_chat_markers() {
        let output = "<|im_start|>assistant\nHello world<|im_end|>";
        let cleaned = sanitize_output(output, None);
        assert_eq!(cleaned, "Hello world");
    }

    #[test]
    fn test_sanitize_preserves_clean_output() {
        let output = "The answer is 42.";
        let cleaned = sanitize_output(output, Some("You are helpful."));
        assert_eq!(cleaned, "The answer is 42.");
    }

    #[test]
    fn test_sanitize_strips_role_prefix() {
        let output = "assistant\nHere is my response.";
        let cleaned = sanitize_output(output, None);
        assert_eq!(cleaned, "Here is my response.");
    }
}