Skip to main content

batuta/agent/driver/
apr_serve.rs

1//! AprServeDriver — first-class inference via `apr serve` subprocess.
2//!
3//! Spawns `apr serve run <model>` as a child process with CUDA/GPU support,
4//! then connects via OpenAI-compatible HTTP API. This is the **preferred**
5//! inference path for `batuta code` / `apr code`:
6//!
7//! - Full CUDA/GPU acceleration (apr-cli has all features)
8//! - APR and GGUF format support (prefers APR)
9//! - No feature flag issues (batuta doesn't need `cuda` feature)
10//! - Sovereign: localhost only, no data egress
11//!
12//! PMAT-160: Replaces embedded RealizarDriver as primary inference.
13//! RealizarDriver remains as fallback when `apr` binary is not on PATH.
14
15use async_trait::async_trait;
16use std::path::PathBuf;
17use std::process::{Child, Command, Stdio};
18
19use super::{CompletionRequest, CompletionResponse, LlmDriver, Message, ToolCall};
20use crate::agent::result::{AgentError, DriverError, StopReason, TokenUsage};
21use crate::serve::backends::PrivacyTier;
22
/// Driver that uses `apr serve` subprocess for inference.
///
/// Owns the spawned server process for its whole lifetime: the child is
/// started by `launch` and terminated by the `Drop` implementation.
pub struct AprServeDriver {
    /// Base URL for the local server (e.g., `http://127.0.0.1:19384`).
    /// Built by `launch` as `http://127.0.0.1:<port>`; request paths are
    /// appended to it verbatim.
    base_url: String,
    /// Model name for OpenAI API requests (the `model` field of the body).
    /// Derived from the model file stem, or `"local"` when the path has none.
    model_name: String,
    /// Child process handle (killed on drop).
    /// Despite the leading underscore this field is actively used: it is
    /// polled during startup and signalled/reaped when the driver is dropped.
    _child: Child,
    /// Context window size returned by `context_window()`.
    /// `launch` falls back to 4096 when the caller passes `None`.
    context_window_size: usize,
}
34
35impl Drop for AprServeDriver {
36    /// PMAT-166: Graceful shutdown — SIGTERM first, SIGKILL after 2s timeout.
37    fn drop(&mut self) {
38        let pid = self._child.id();
39
40        // Try graceful shutdown first (SIGTERM on Unix via kill command)
41        #[cfg(unix)]
42        {
43            let _ = Command::new("kill")
44                .args(["-TERM", &pid.to_string()])
45                .stdout(Stdio::null())
46                .stderr(Stdio::null())
47                .status();
48
49            // Wait up to 2s for graceful exit
50            let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
51            loop {
52                match self._child.try_wait() {
53                    Ok(Some(_)) => return, // Exited cleanly
54                    Ok(None) if std::time::Instant::now() < deadline => {
55                        std::thread::sleep(std::time::Duration::from_millis(100));
56                    }
57                    _ => break, // Timeout or error — force kill
58                }
59            }
60        }
61
62        // Fallback: force kill (always runs on Windows, or after SIGTERM timeout)
63        let _ = self._child.kill();
64        let _ = self._child.wait();
65    }
66}
67
68impl AprServeDriver {
69    /// Launch `apr serve run` and wait for readiness.
70    ///
71    /// Picks a random port, spawns the subprocess, polls the health
72    /// endpoint until ready (max 30s). Returns error if `apr` not
73    /// found or server fails to start.
74    pub fn launch(model_path: PathBuf, context_window: Option<usize>) -> Result<Self, AgentError> {
75        let apr_path = find_apr_binary()?;
76
77        // Pick a random high port to avoid conflicts
78        let port = 19384 + (std::process::id() % 1000) as u16;
79        let base_url = format!("http://127.0.0.1:{port}");
80
81        let model_name = model_path
82            .file_stem()
83            .map(|s| s.to_string_lossy().to_string())
84            .unwrap_or_else(|| "local".to_string());
85
86        // PMAT-181: Enable GPU with serial prefill. The FP8 batched prefill produces
87        // wrong output for Qwen3 (Q6K→FP8 requantization bug). Serial prefill uses
88        // Q4K/Q6K GEMV kernels which produce correct output. BATCHED_PREFILL=0 disables
89        // the FP8 path while keeping CUDA acceleration for decode tokens.
90        let mut cmd = Command::new(&apr_path);
91        cmd.args([
92            "serve",
93            "run",
94            &model_path.to_string_lossy(),
95            "--port",
96            &port.to_string(),
97            "--host",
98            "127.0.0.1",
99            "--gpu",
100        ])
101        .env("BATCHED_PREFILL", "0")
102        .stdout(Stdio::piped())
103        .stderr(Stdio::piped());
104
105        // Issue #1712: kernel-enforced reaping. Drop on AprServeDriver only fires
106        // for graceful Rust exit — if `apr code` is killed by `timeout`, SIGTERM,
107        // SIGKILL, or an uncaught panic, the Drop never runs and `apr serve`
108        // orphans (~3 GB RSS each). PR_SET_PDEATHSIG asks the kernel to SIGTERM
109        // the child the moment its parent dies, independent of cleanup paths.
110        configure_parent_death_signal(&mut cmd);
111
112        let child = cmd.spawn().map_err(|e| {
113            AgentError::Driver(DriverError::InferenceFailed(format!(
114                "failed to spawn apr serve: {e}"
115            )))
116        })?;
117
118        eprintln!("Launched apr serve on port {port} (pid {})", child.id());
119
120        let mut driver = Self {
121            base_url,
122            model_name,
123            _child: child,
124            context_window_size: context_window.unwrap_or(4096),
125        };
126
127        // Wait for server to be ready
128        driver.wait_for_ready()?;
129
130        Ok(driver)
131    }
132
133    /// Poll health endpoint until server is ready (max 30s).
134    ///
135    /// PMAT-171: Detects subprocess death during startup. On timeout or crash,
136    /// reads stderr from the child process for actionable debug output.
137    fn wait_for_ready(&mut self) -> Result<(), AgentError> {
138        let addr = self.base_url.trim_start_matches("http://").to_string();
139        let sock_addr: std::net::SocketAddr =
140            addr.parse().unwrap_or_else(|_| std::net::SocketAddr::from(([127, 0, 0, 1], 19384)));
141
142        let start = std::time::Instant::now();
143        let timeout = std::time::Duration::from_secs(30);
144
145        loop {
146            if start.elapsed() > timeout {
147                let stderr = self.drain_stderr();
148                let mut msg = "apr serve did not become ready within 30s".to_string();
149                if !stderr.is_empty() {
150                    msg.push_str(&format!("\nsubprocess stderr:\n{stderr}"));
151                }
152                msg.push_str(&format!(
153                    "\nDebug manually: apr serve run <model> --port {} --host 127.0.0.1",
154                    addr.rsplit(':').next().unwrap_or("19384")
155                ));
156                return Err(AgentError::Driver(DriverError::InferenceFailed(msg)));
157            }
158
159            // Check if subprocess died
160            if let Ok(Some(status)) = self._child.try_wait() {
161                let stderr = self.drain_stderr();
162                let mut msg = format!("apr serve exited with {status} during startup");
163                if !stderr.is_empty() {
164                    msg.push_str(&format!("\nsubprocess stderr:\n{stderr}"));
165                }
166                return Err(AgentError::Driver(DriverError::InferenceFailed(msg)));
167            }
168
169            if std::net::TcpStream::connect_timeout(
170                &sock_addr,
171                std::time::Duration::from_millis(200),
172            )
173            .is_ok()
174            {
175                eprintln!("apr serve ready ({:.1}s)", start.elapsed().as_secs_f64());
176                return Ok(());
177            }
178
179            std::thread::sleep(std::time::Duration::from_millis(500));
180        }
181    }
182
183    /// Read available stderr from the child process (non-blocking, last 2KB).
184    fn drain_stderr(&mut self) -> String {
185        use std::io::Read;
186        let Some(stderr) = self._child.stderr.as_mut() else {
187            return String::new();
188        };
189        let mut buf = vec![0u8; 2048];
190        let n = stderr.read(&mut buf).unwrap_or(0);
191        let text = String::from_utf8_lossy(&buf[..n]).to_string();
192        // Return last few lines for concise output
193        let lines: Vec<&str> = text.lines().collect();
194        if lines.len() > 10 {
195            lines[lines.len() - 10..].join("\n")
196        } else {
197            text
198        }
199    }
200
201    /// Build OpenAI-compatible request body.
202    ///
203    /// PMAT-176: Only strips the verbose `## Available Tools` section injected
204    /// by `build_enriched_system()` (full JSON schemas ~2000 tokens). Preserves
205    /// the compact `## Tools` table from `CODE_SYSTEM_PROMPT` — that table has
206    /// tool names, use cases, and example inputs designed for 1.5B-7B models.
207    fn build_openai_body(&self, request: &CompletionRequest) -> serde_json::Value {
208        let mut messages = Vec::new();
209
210        if let Some(ref system) = request.system {
211            // PMAT-176: Only strip the verbose enriched section (full JSON schemas).
212            // Keep the compact "## Tools" table from CODE_SYSTEM_PROMPT — it has
213            // descriptions and examples that small models need for tool discovery.
214            let compact_system = system
215                .find("\n\n## Available Tools")
216                .map(|i| &system[..i])
217                .unwrap_or(system)
218                .to_string();
219
220            messages.push(serde_json::json!({
221                "role": "system",
222                "content": compact_system
223            }));
224        }
225
226        for msg in &request.messages {
227            match msg {
228                Message::User(text) => messages.push(serde_json::json!({
229                    "role": "user",
230                    "content": text
231                })),
232                Message::Assistant(text) => messages.push(serde_json::json!({
233                    "role": "assistant",
234                    "content": text
235                })),
236                Message::AssistantToolUse(call) => messages.push(serde_json::json!({
237                    "role": "assistant",
238                    "content": format!("<tool_call>\n{}\n</tool_call>",
239                        serde_json::json!({"name": call.name, "input": call.input}))
240                })),
241                Message::ToolResult(result) => messages.push(serde_json::json!({
242                    "role": "user",
243                    "content": format!("<tool_result>\n{}\n</tool_result>", result.content)
244                })),
245                _ => {}
246            }
247        }
248
249        // PMAT-170: Cap max_tokens for HTTP path. The manifest default (4096)
250        // causes very long generation on local models. 1024 accommodates:
251        // - Tool call JSON (~100-200 tokens each)
252        // - File edit content (multi-line diffs)
253        // - Explanation text alongside tool calls
254        // Previous 512 cap truncated complex edits mid-output.
255        let max_tokens = request.max_tokens.min(1024);
256
257        serde_json::json!({
258            "model": self.model_name,
259            "messages": messages,
260            "max_tokens": max_tokens,
261            "temperature": request.temperature,
262            "stream": false
263        })
264    }
265}
266
267#[async_trait]
268impl LlmDriver for AprServeDriver {
269    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
270        let url = format!("{}/v1/chat/completions", self.base_url);
271        let body = self.build_openai_body(&request);
272
273        let client = reqwest::Client::builder()
274            .timeout(std::time::Duration::from_secs(120))
275            .build()
276            .map_err(|e| AgentError::Driver(DriverError::Network(format!("http client: {e}"))))?;
277        let response = client
278            .post(&url)
279            .header("content-type", "application/json")
280            .json(&body)
281            .send()
282            .await
283            .map_err(|e| AgentError::Driver(DriverError::Network(format!("apr serve: {e}"))))?;
284
285        if !response.status().is_success() {
286            let status = response.status().as_u16();
287            let text = response.text().await.unwrap_or_default();
288            return Err(AgentError::Driver(DriverError::Network(format!(
289                "apr serve HTTP {status}: {text}"
290            ))));
291        }
292
293        let json: serde_json::Value = response
294            .json()
295            .await
296            .map_err(|e| AgentError::Driver(DriverError::InferenceFailed(format!("parse: {e}"))))?;
297
298        // Extract response from OpenAI format
299        let raw_text = json["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
300
301        // PMAT-180: Strip Qwen3 thinking blocks. The model may emit
302        // <think>...</think> or bare </think> tokens. Remove them before
303        // parsing tool calls — thinking content is internal reasoning.
304        let text = strip_thinking_blocks(&raw_text);
305
306        let usage = json.get("usage").cloned().unwrap_or(serde_json::json!({}));
307        let input_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0);
308        let output_tokens = usage["completion_tokens"].as_u64().unwrap_or(0);
309
310        // Parse tool calls from text (same parser as RealizarDriver)
311        let (clean_text, tool_calls) = super::realizar::parse_tool_calls_pub(&text);
312
313        let stop_reason =
314            if tool_calls.is_empty() { StopReason::EndTurn } else { StopReason::ToolUse };
315
316        Ok(CompletionResponse {
317            text: clean_text,
318            stop_reason,
319            tool_calls,
320            usage: TokenUsage { input_tokens, output_tokens },
321        })
322    }
323
324    fn context_window(&self) -> usize {
325        self.context_window_size
326    }
327
328    fn privacy_tier(&self) -> PrivacyTier {
329        // Sovereign: apr serve runs on localhost, zero network egress
330        PrivacyTier::Sovereign
331    }
332}
333
/// Remove Qwen3 reasoning markup from model output.
///
/// Every `<think>...</think>` block is deleted (blocks may span lines). An
/// opening `<think>` without a matching close discards everything from that
/// tag to the end of the text. Any stray `</think>` tags left afterwards are
/// dropped, and the result is trimmed of surrounding whitespace.
fn strip_thinking_blocks(text: &str) -> String {
    const OPEN: &str = "<think>";
    const CLOSE: &str = "</think>";

    let mut cleaned = String::from(text);
    loop {
        // Always restart from the first opening tag still present.
        let Some(open_at) = cleaned.find(OPEN) else {
            break;
        };
        match cleaned[open_at..].find(CLOSE) {
            Some(rel_close) => {
                // Cut the whole block, closing tag included.
                let block_end = open_at + rel_close + CLOSE.len();
                cleaned.drain(open_at..block_end);
            }
            None => {
                // Unclosed block: nothing after the tag is kept.
                cleaned.truncate(open_at);
                break;
            }
        }
    }

    // The model sometimes emits only the closing tag.
    cleaned.replace(CLOSE, "").trim().to_owned()
}
351
352/// Issue #1712: ask the kernel to SIGTERM the child when the parent dies.
353///
354/// On Linux/Unix this uses `PR_SET_PDEATHSIG` via `pre_exec` so the child
355/// receives SIGTERM the instant its parent exits — whether the parent died
356/// gracefully, was SIGKILLed by `timeout`, or was terminated by the OOM
357/// killer. Without this, `apr serve` orphans hold ~3 GB RSS each.
358///
359/// A `getppid()==1` check immediately after `prctl` closes the small race
360/// where the parent dies between fork and prctl (in which case the death
361/// signal has already missed its window).
362#[cfg(unix)]
363#[allow(unsafe_code)] // pre_exec is unsafe-by-API; body uses only async-signal-safe calls
364fn configure_parent_death_signal(cmd: &mut Command) {
365    use std::os::unix::process::CommandExt;
366    // SAFETY: `prctl` and `getppid` are async-signal-safe; `pre_exec` runs
367    // between fork and exec where only async-signal-safe calls are allowed.
368    unsafe {
369        cmd.pre_exec(|| {
370            if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM, 0, 0, 0) == -1 {
371                return Err(std::io::Error::last_os_error());
372            }
373            if libc::getppid() == 1 {
374                return Err(std::io::Error::other(
375                    "parent died before PR_SET_PDEATHSIG took effect",
376                ));
377            }
378            Ok(())
379        });
380    }
381}
382
383#[cfg(not(unix))]
384fn configure_parent_death_signal(_cmd: &mut Command) {
385    // Windows: no equivalent — orphans on parent death still possible.
386}
387
388/// Find the `apr` binary on PATH.
389fn find_apr_binary() -> Result<PathBuf, AgentError> {
390    which::which("apr").map_err(|_| {
391        AgentError::Driver(DriverError::InferenceFailed(
392            "apr binary not found on PATH. Install: cargo install apr-cli".into(),
393        ))
394    })
395}
396
397#[cfg(test)]
398#[path = "apr_serve_tests.rs"]
399mod tests;