Skip to main content

batuta/agent/driver/
realizar.rs

1//! RealizarDriver — sovereign local inference via GGUF/APR models.
2//!
3//! Uses the `realizar` crate for local LLM inference. All data
4//! stays on-device (Sovereign privacy tier, Genchi Genbutsu).
5//!
6//! Tool call parsing: local models output `<tool_call>` JSON blocks
7//! in their text. The driver extracts these into `ToolCall` structs.
8//!
9//! Feature-gated behind `inference`.
10
11use async_trait::async_trait;
12use std::path::PathBuf;
13
14use super::chat_template::{format_prompt_with_template, ChatTemplate};
15use super::validate::validate_model_file;
16use super::{CompletionRequest, CompletionResponse, LlmDriver, ToolCall};
17use crate::agent::result::{AgentError, DriverError, StopReason, TokenUsage};
18use crate::serve::backends::PrivacyTier;
19
20/// Local inference driver using realizar (GGUF/APR/SafeTensors).
21pub struct RealizarDriver {
22    /// Path to model file.
23    model_path: PathBuf,
24    /// Context window size.
25    context_window_size: usize,
26    /// Auto-detected chat template.
27    template: ChatTemplate,
28}
29
30impl RealizarDriver {
31    /// Create a new RealizarDriver from a model path.
32    ///
33    /// **Contract: `apr_model_validity` (apr-code-v1.yaml)**
34    ///
35    /// Preconditions enforced at the load boundary (Jidoka):
36    /// - File must exist
37    /// - APR files: must have embedded tokenizer (checked via header)
38    /// - GGUF files: must have valid magic bytes
39    ///
40    /// Violation → actionable error with re-conversion instructions.
41    /// No broken model ever reaches the inference loop.
42    pub fn new(model_path: PathBuf, context_window: Option<usize>) -> Result<Self, AgentError> {
43        if !model_path.exists() {
44            return Err(AgentError::Driver(DriverError::InferenceFailed(format!(
45                "model not found: {}",
46                model_path.display()
47            ))));
48        }
49
50        // ═══ CONTRACT: apr_model_validity — Jidoka boundary check ═══
51        validate_model_file(&model_path)?;
52
53        let context_window_size = context_window.unwrap_or(4096);
54        let template = ChatTemplate::from_model_path(&model_path);
55        Ok(Self { model_path, context_window_size, template })
56    }
57}
58
59#[async_trait]
60impl LlmDriver for RealizarDriver {
61    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
62        // Format messages using auto-detected chat template
63        let prompt = format_prompt_with_template(&request, self.template);
64
65        // Build inference config (explicit fields — no Default impl)
66        let config = realizar::infer::InferenceConfig {
67            model_path: self.model_path.clone(),
68            prompt: Some(prompt),
69            input_tokens: None,
70            max_tokens: request.max_tokens as usize,
71            temperature: request.temperature,
72            top_k: 0,
73            // PMAT-156/158: Disable GPU only for APR models (wgpu shader bug).
74            // GGUF models work fine with CUDA — keep GPU enabled for them.
75            no_gpu: self.model_path.extension().is_some_and(|e| e == "apr"),
76            trace: false,
77            trace_verbose: false,
78            trace_output: None,
79            trace_steps: None,
80            verbose: false,
81            use_mock_backend: false,
82            stop_tokens: vec![],
83        };
84
85        // Run inference in blocking thread (realizar is sync)
86        let result = tokio::task::spawn_blocking(move || realizar::infer::run_inference(&config))
87            .await
88            .map_err(|e| {
89                AgentError::Driver(DriverError::InferenceFailed(format!("spawn_blocking: {e}")))
90            })?
91            .map_err(|e| AgentError::Driver(DriverError::InferenceFailed(e.to_string())))?;
92
93        // Parse tool calls from text output
94        let (raw_text, tool_calls) = parse_tool_calls(&result.text);
95
96        // Sanitize output: strip echoed system prompt and chat template markers
97        let text = sanitize_output(&raw_text, request.system.as_deref());
98
99        let stop_reason =
100            if tool_calls.is_empty() { StopReason::EndTurn } else { StopReason::ToolUse };
101
102        Ok(CompletionResponse {
103            text,
104            stop_reason,
105            tool_calls,
106            usage: TokenUsage {
107                input_tokens: result.input_token_count as u64,
108                output_tokens: result.generated_token_count as u64,
109            },
110        })
111    }
112
113    fn context_window(&self) -> usize {
114        self.context_window_size
115    }
116
117    fn privacy_tier(&self) -> PrivacyTier {
118        PrivacyTier::Sovereign
119    }
120}
121
122/// Parse tool calls from model output text.
123///
124/// Supports multiple formats (PMAT-158):
125/// 1. `<tool_call>{"name":...}</tool_call>` — custom XML tags
126/// 2. `<tool_call>{"name":...}` — unclosed XML (small model fallback)
127/// 3. `` ```json\n{"name":...}\n``` `` — markdown code block (Qwen native)
128///
129/// Returns the remaining text (with tool call blocks removed)
130/// and the extracted tool calls.
131/// Public wrapper for tool call parsing (used by AprServeDriver).
132pub fn parse_tool_calls_pub(text: &str) -> (String, Vec<ToolCall>) {
133    parse_tool_calls(text)
134}
135
136fn parse_tool_calls(text: &str) -> (String, Vec<ToolCall>) {
137    let mut tool_calls = Vec::new();
138    let mut remaining = String::new();
139    let mut call_counter = 0u32;
140
141    let mut cursor = text;
142    loop {
143        // Find next tool call start — try <tool_call> first, then ```json
144        let xml_pos = cursor.find("<tool_call>");
145        let md_pos = cursor.find("```json");
146
147        let (start, tag_len, is_markdown) = match (xml_pos, md_pos) {
148            (Some(x), Some(m)) if x <= m => (x, "<tool_call>".len(), false),
149            (Some(x), None) => (x, "<tool_call>".len(), false),
150            (_, Some(m)) => (m, "```json".len(), true),
151            (None, None) => {
152                remaining.push_str(cursor);
153                break;
154            }
155        };
156
157        remaining.push_str(&cursor[..start]);
158        let after_tag = &cursor[start + tag_len..];
159
160        // Find closing tag and extract JSON
161        let (json_str, advance_past) = if is_markdown {
162            // Markdown: ```json\n...\n```
163            if let Some(end) = after_tag.find("```") {
164                (&after_tag[..end], &after_tag[end + "```".len()..])
165            } else {
166                (after_tag, "")
167            }
168        } else if let Some(end) = after_tag.find("</tool_call>") {
169            (&after_tag[..end], &after_tag[end + "</tool_call>".len()..])
170        } else {
171            // PMAT-158: No closing tag — try parsing to end-of-string
172            (after_tag, "")
173        };
174        let json_str = json_str.trim();
175
176        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
177            // Must have "name" field to be a tool call (not just any JSON)
178            if let Some(name) = parsed.get("name").and_then(|n| n.as_str()) {
179                let name = name.to_string();
180                let input = parsed.get("input").cloned().unwrap_or(serde_json::json!({}));
181                call_counter += 1;
182                tool_calls.push(ToolCall { id: format!("local-{call_counter}"), name, input });
183            } else {
184                remaining.push_str(&cursor[start..]);
185                break;
186            }
187        } else {
188            remaining.push_str(&cursor[start..]);
189            break;
190        }
191
192        cursor = advance_past;
193        if cursor.is_empty() {
194            break;
195        }
196    }
197
198    (remaining.trim().to_string(), tool_calls)
199}
200
/// Sanitize model output: strip an echoed system prompt and chat template markers.
///
/// Small models (<3B) often echo the system prompt or leak chat template
/// tokens into their response. This strips those artifacts so the agent
/// loop sees clean assistant text.
///
/// Processing order:
/// 1. If `system_prompt` is given and `text` starts with its first 80 bytes
///    (or the whole prompt when shorter), the leading `system_prompt.len()`
///    bytes of the output are dropped.
/// 2. Every known chat template marker is removed wherever it occurs.
/// 3. The result is trimmed, a leading `system\n` and then a leading
///    `assistant\n` role label are removed, and it is trimmed again.
///
/// Both byte offsets in step 1 are rounded to a UTF-8 character boundary, so
/// prompts or outputs containing multi-byte characters never cause a panic.
fn sanitize_output(text: &str, system_prompt: Option<&str>) -> String {
    /// How many leading bytes of the system prompt are compared to detect an echo.
    const ECHO_PROBE_BYTES: usize = 80;
    /// Chat template tokens that may leak into generated text.
    const TEMPLATE_MARKERS: [&str; 9] = [
        "<|im_start|>",
        "<|im_end|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>",
        "<|system|>",
        "<|user|>",
        "<|assistant|>",
        "<|end|>",
    ];

    let mut cleaned = text.to_string();

    // Strip echoed system prompt (common with small models).
    if let Some(sys) = system_prompt {
        // Round the probe length down to a char boundary; offset 0 is always
        // a boundary, so this terminates.
        let mut probe_end = sys.len().min(ECHO_PROBE_BYTES);
        while !sys.is_char_boundary(probe_end) {
            probe_end -= 1;
        }

        if cleaned.starts_with(&sys[..probe_end]) {
            // The output may diverge from the prompt after the probe, so the
            // cut can land inside a character of `cleaned`; round it up so at
            // least `sys.len()` bytes go. `cleaned.len()` is always a boundary.
            let mut cut = sys.len().min(cleaned.len());
            while !cleaned.is_char_boundary(cut) {
                cut += 1;
            }
            cleaned.replace_range(..cut, "");
        }
    }

    // Strip leaked chat template markers.
    for marker in &TEMPLATE_MARKERS {
        cleaned = cleaned.replace(marker, "");
    }

    // Strip surrounding whitespace and bare role labels left behind by the markers.
    let cleaned = cleaned.trim();
    let cleaned = cleaned.strip_prefix("system\n").unwrap_or(cleaned);
    let cleaned = cleaned.strip_prefix("assistant\n").unwrap_or(cleaned);
    cleaned.trim().to_string()
}
240
#[cfg(test)]
mod tests {
    use super::*;

    // ── Tool call parsing tests ──

    /// Plain text passes through untouched and yields no calls.
    #[test]
    fn test_parse_no_tool_calls() {
        let (remaining, extracted) = parse_tool_calls("Hello world");
        assert_eq!(remaining, "Hello world");
        assert!(extracted.is_empty());
    }

    /// One closed `<tool_call>` block surrounded by text.
    #[test]
    fn test_parse_single_tool_call() {
        let model_output = concat!(
            "Before text\n",
            "<tool_call>\n",
            "{\"name\": \"rag\", \"input\": {\"query\": \"SIMD\"}}\n",
            "</tool_call>\n",
            "After text",
        );
        let (remaining, extracted) = parse_tool_calls(model_output);
        assert_eq!(remaining, "Before text\n\nAfter text");
        assert_eq!(extracted.len(), 1);
        let call = &extracted[0];
        assert_eq!(call.name, "rag");
        assert_eq!(call.id, "local-1");
        assert_eq!(call.input, serde_json::json!({"query": "SIMD"}));
    }

    /// Two closed blocks receive sequential ids; the text between them survives.
    #[test]
    fn test_parse_multiple_tool_calls() {
        let model_output = concat!(
            "<tool_call>\n",
            "{\"name\": \"rag\", \"input\": {\"query\": \"a\"}}\n",
            "</tool_call>\n",
            "Middle\n",
            "<tool_call>\n",
            "{\"name\": \"memory\", \"input\": {\"action\": \"recall\", \"query\": \"b\"}}\n",
            "</tool_call>",
        );
        let (remaining, extracted) = parse_tool_calls(model_output);
        assert_eq!(remaining, "Middle");
        assert_eq!(extracted.len(), 2);
        assert_eq!(extracted[0].name, "rag");
        assert_eq!(extracted[0].id, "local-1");
        assert_eq!(extracted[1].name, "memory");
        assert_eq!(extracted[1].id, "local-2");
    }

    /// A closed block whose content is not JSON produces no call.
    #[test]
    fn test_parse_malformed_json() {
        let model_output = "<tool_call>\nnot valid json\n</tool_call>";
        let (_remaining, extracted) = parse_tool_calls(model_output);
        assert!(extracted.is_empty());
    }

    /// PMAT-158: Small models omit </tool_call>. Parser should still extract.
    #[test]
    fn test_parse_missing_close_tag_with_valid_json() {
        let model_output = r#"<tool_call>
{"name": "file_read", "input": {"path": "src/main.rs"}}"#;
        let (remaining, extracted) = parse_tool_calls(model_output);
        assert_eq!(extracted.len(), 1, "should extract tool call without closing tag");
        assert_eq!(extracted[0].name, "file_read");
        assert!(remaining.is_empty(), "no remaining text expected");
    }

    /// Unclosed tag with text before it — text preserved, tool call extracted.
    #[test]
    fn test_parse_missing_close_tag_with_trailing_text() {
        let model_output = r#"Let me read that.
<tool_call> {"name": "file_read", "input": {"path": "foo.rs"}}"#;
        let (remaining, extracted) = parse_tool_calls(model_output);
        assert_eq!(extracted.len(), 1);
        assert_eq!(extracted[0].name, "file_read");
        assert!(remaining.contains("Let me read that"));
    }

    /// Unclosed tag with invalid JSON — treated as plain text.
    #[test]
    fn test_parse_missing_close_tag_invalid_json() {
        let model_output = r#"<tool_call>
not valid json at all"#;
        let (remaining, extracted) = parse_tool_calls(model_output);
        assert!(extracted.is_empty(), "invalid JSON should not produce tool call");
        assert!(remaining.contains("<tool_call>"));
    }

    /// PMAT-158: Qwen2.5-Coder native format — ```json blocks.
    #[test]
    fn test_parse_markdown_code_block() {
        let model_output = r#"Let me read that file.
```json
{"name": "file_read", "input": {"path": "src/main.rs"}}
```"#;
        let (remaining, extracted) = parse_tool_calls(model_output);
        assert_eq!(extracted.len(), 1, "should extract tool call from markdown block");
        assert_eq!(extracted[0].name, "file_read");
        assert_eq!(extracted[0].input["path"], "src/main.rs");
        assert!(remaining.contains("Let me read that"));
    }

    /// JSON in code block without "name" field — not a tool call.
    #[test]
    fn test_parse_markdown_code_block_not_tool_call() {
        let model_output = r#"Here's an example:
```json
{"key": "value"}
```"#;
        let (remaining, extracted) = parse_tool_calls(model_output);
        assert!(extracted.is_empty(), "JSON without name field should not be a tool call");
        assert!(remaining.contains("example"));
    }

    /// A closed block whose JSON lacks "name" is not extracted.
    #[test]
    fn test_parse_missing_name() {
        let model_output = "<tool_call>\n{\"input\": {\"query\": \"test\"}}\n</tool_call>";
        let (_, extracted) = parse_tool_calls(model_output);
        assert!(extracted.is_empty(), "JSON without name should not be extracted");
    }

    /// NOTE(review): compares the enum variant with itself; it does not call
    /// `RealizarDriver::privacy_tier` (building a driver needs a model file).
    #[test]
    fn test_privacy_tier_always_sovereign() {
        assert_eq!(PrivacyTier::Sovereign, PrivacyTier::Sovereign);
    }

    // ── Output sanitization tests ──

    /// An echoed system prompt is removed; the continuation is kept.
    #[test]
    fn test_sanitize_strips_echoed_system_prompt() {
        let system_prompt = "You are apr code, a sovereign AI coding assistant.";
        let model_output = format!("{system_prompt} And then the model continues here.");
        let sanitized = sanitize_output(&model_output, Some(system_prompt));
        assert!(!sanitized.contains("sovereign AI coding assistant"));
        assert!(sanitized.contains("continues here"));
    }

    /// Leaked template markers and the role label they leave behind are removed.
    #[test]
    fn test_sanitize_strips_chat_markers() {
        let sanitized = sanitize_output("<|im_start|>assistant\nHello world<|im_end|>", None);
        assert_eq!(sanitized, "Hello world");
    }

    /// Output with no artifacts comes back unchanged.
    #[test]
    fn test_sanitize_preserves_clean_output() {
        let sanitized = sanitize_output("The answer is 42.", Some("You are helpful."));
        assert_eq!(sanitized, "The answer is 42.");
    }

    /// A bare leading `assistant` role label is removed.
    #[test]
    fn test_sanitize_strips_role_prefix() {
        let sanitized = sanitize_output("assistant\nHere is my response.", None);
        assert_eq!(sanitized, "Here is my response.");
    }
}