Skip to main content

batuta/agent/driver/
apr_serve.rs

//! AprServeDriver — first-class inference via `apr serve` subprocess.
//!
//! Spawns `apr serve run <model>` as a child process with CUDA/GPU support,
//! then connects via OpenAI-compatible HTTP API. This is the **preferred**
//! inference path for `batuta code` / `apr code`:
//!
//! - Full CUDA/GPU acceleration (apr-cli has all features)
//! - APR and GGUF format support (prefers APR)
//! - No feature flag issues (batuta doesn't need `cuda` feature)
//! - Sovereign: localhost only, no data egress
//!
//! PMAT-160: Replaces embedded RealizarDriver as primary inference.
//! RealizarDriver remains as fallback when `apr` binary is not on PATH.
14
15use async_trait::async_trait;
16use std::path::PathBuf;
17use std::process::{Child, Command, Stdio};
18
19use super::{CompletionRequest, CompletionResponse, LlmDriver, Message, ToolCall};
20use crate::agent::result::{AgentError, DriverError, StopReason, TokenUsage};
21use crate::serve::backends::PrivacyTier;
22
/// LLM driver backed by a locally spawned `apr serve` subprocess.
///
/// The driver owns the child process for its whole lifetime; the `Drop`
/// implementation shuts the server down when the driver goes away.
pub struct AprServeDriver {
    /// Root URL of the local server, e.g. `http://127.0.0.1:19384`.
    base_url: String,
    /// Model identifier sent in the `model` field of OpenAI-style requests.
    model_name: String,
    /// Handle to the spawned `apr serve` process (terminated on drop).
    _child: Child,
    /// Context window size reported through `LlmDriver::context_window`.
    context_window_size: usize,
    /// Size of the model file in bytes, used to scale the startup-ready
    /// timeout for large models. `None` when the file could not be
    /// stat'ed at launch time.
    model_size_bytes: Option<u64>,
}
37
38impl Drop for AprServeDriver {
39    /// PMAT-166: Graceful shutdown — SIGTERM first, SIGKILL after 2s timeout.
40    fn drop(&mut self) {
41        let pid = self._child.id();
42
43        // Try graceful shutdown first (SIGTERM on Unix via kill command)
44        #[cfg(unix)]
45        {
46            let _ = Command::new("kill")
47                .args(["-TERM", &pid.to_string()])
48                .stdout(Stdio::null())
49                .stderr(Stdio::null())
50                .status();
51
52            // Wait up to 2s for graceful exit
53            let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
54            loop {
55                match self._child.try_wait() {
56                    Ok(Some(_)) => return, // Exited cleanly
57                    Ok(None) if std::time::Instant::now() < deadline => {
58                        std::thread::sleep(std::time::Duration::from_millis(100));
59                    }
60                    _ => break, // Timeout or error — force kill
61                }
62            }
63        }
64
65        // Fallback: force kill (always runs on Windows, or after SIGTERM timeout)
66        let _ = self._child.kill();
67        let _ = self._child.wait();
68    }
69}
70
71impl AprServeDriver {
72    /// Launch `apr serve run` and wait for readiness.
73    ///
74    /// Picks a random port, spawns the subprocess, polls the health
75    /// endpoint until ready (max 30s). Returns error if `apr` not
76    /// found or server fails to start.
77    pub fn launch(model_path: PathBuf, context_window: Option<usize>) -> Result<Self, AgentError> {
78        let apr_path = find_apr_binary()?;
79
80        // Pick a random high port to avoid conflicts
81        let port = 19384 + (std::process::id() % 1000) as u16;
82        let base_url = format!("http://127.0.0.1:{port}");
83
84        let model_name = model_path
85            .file_stem()
86            .map(|s| s.to_string_lossy().to_string())
87            .unwrap_or_else(|| "local".to_string());
88
89        // PMAT-181: Enable GPU with serial prefill. The FP8 batched prefill produces
90        // wrong output for Qwen3 (Q6K→FP8 requantization bug). Serial prefill uses
91        // Q4K/Q6K GEMV kernels which produce correct output. BATCHED_PREFILL=0 disables
92        // the FP8 path while keeping CUDA acceleration for decode tokens.
93        let mut cmd = Command::new(&apr_path);
94        cmd.args([
95            "serve",
96            "run",
97            &model_path.to_string_lossy(),
98            "--port",
99            &port.to_string(),
100            "--host",
101            "127.0.0.1",
102            "--gpu",
103        ])
104        .env("BATCHED_PREFILL", "0")
105        .stdout(Stdio::piped())
106        .stderr(Stdio::piped());
107
108        // Issue #1712: kernel-enforced reaping. Drop on AprServeDriver only fires
109        // for graceful Rust exit — if `apr code` is killed by `timeout`, SIGTERM,
110        // SIGKILL, or an uncaught panic, the Drop never runs and `apr serve`
111        // orphans (~3 GB RSS each). PR_SET_PDEATHSIG asks the kernel to SIGTERM
112        // the child the moment its parent dies, independent of cleanup paths.
113        configure_parent_death_signal(&mut cmd);
114
115        let child = cmd.spawn().map_err(|e| {
116            AgentError::Driver(DriverError::InferenceFailed(format!(
117                "failed to spawn apr serve: {e}"
118            )))
119        })?;
120
121        eprintln!("Launched apr serve on port {port} (pid {})", child.id());
122
123        let model_size_bytes = std::fs::metadata(&model_path).ok().map(|m| m.len());
124
125        let mut driver = Self {
126            base_url,
127            model_name,
128            _child: child,
129            context_window_size: context_window.unwrap_or(4096),
130            model_size_bytes,
131        };
132
133        // Wait for server to be ready
134        driver.wait_for_ready()?;
135
136        Ok(driver)
137    }
138
139    /// Poll health endpoint until server is ready.
140    ///
141    /// Default budget is 30 seconds — fine for small models that mmap and
142    /// validate in <5s. Large MoE GGUFs (e.g., Qwen3-Coder-30B at 18.5 GB)
143    /// can exceed 30s on cold-cache loads, so the budget is overridable
144    /// via `APR_SERVE_READY_TIMEOUT_S` (an integer count of seconds).
145    /// The default also auto-scales by model file size when the model
146    /// is known: +1 second per 500 MB above 2 GB (e.g., an 18 GB model
147    /// gets ~62s budget; a 1 GB model gets the 30s baseline).
148    ///
149    /// PMAT-171: Detects subprocess death during startup. On timeout or crash,
150    /// reads stderr from the child process for actionable debug output.
151    ///
152    /// PR #1781: Configurable timeout + size-aware default.
153    fn wait_for_ready(&mut self) -> Result<(), AgentError> {
154        let addr = self.base_url.trim_start_matches("http://").to_string();
155        let sock_addr: std::net::SocketAddr =
156            addr.parse().unwrap_or_else(|_| std::net::SocketAddr::from(([127, 0, 0, 1], 19384)));
157
158        let timeout_secs = self.resolve_ready_timeout_secs();
159        let start = std::time::Instant::now();
160        let timeout = std::time::Duration::from_secs(timeout_secs);
161
162        loop {
163            if start.elapsed() > timeout {
164                let stderr = self.drain_stderr();
165                let mut msg = format!(
166                    "apr serve did not become ready within {timeout_secs}s (override via APR_SERVE_READY_TIMEOUT_S)"
167                );
168                if !stderr.is_empty() {
169                    msg.push_str(&format!("\nsubprocess stderr:\n{stderr}"));
170                }
171                msg.push_str(&format!(
172                    "\nDebug manually: apr serve run <model> --port {} --host 127.0.0.1",
173                    addr.rsplit(':').next().unwrap_or("19384")
174                ));
175                return Err(AgentError::Driver(DriverError::InferenceFailed(msg)));
176            }
177
178            // Check if subprocess died
179            if let Ok(Some(status)) = self._child.try_wait() {
180                let stderr = self.drain_stderr();
181                let mut msg = format!("apr serve exited with {status} during startup");
182                if !stderr.is_empty() {
183                    msg.push_str(&format!("\nsubprocess stderr:\n{stderr}"));
184                }
185                return Err(AgentError::Driver(DriverError::InferenceFailed(msg)));
186            }
187
188            if std::net::TcpStream::connect_timeout(
189                &sock_addr,
190                std::time::Duration::from_millis(200),
191            )
192            .is_ok()
193            {
194                eprintln!("apr serve ready ({:.1}s)", start.elapsed().as_secs_f64());
195                return Ok(());
196            }
197
198            std::thread::sleep(std::time::Duration::from_millis(500));
199        }
200    }
201
202    /// Resolve the startup-ready timeout in seconds.
203    ///
204    /// Reads `APR_SERVE_READY_TIMEOUT_S` from the env (operator override)
205    /// and falls back to a size-aware default. See
206    /// [`compute_ready_timeout_secs`] for the resolution rules + unit tests.
207    fn resolve_ready_timeout_secs(&self) -> u64 {
208        let env_override = std::env::var("APR_SERVE_READY_TIMEOUT_S").ok();
209        compute_ready_timeout_secs(self.model_size_bytes, env_override.as_deref())
210    }
211
212    /// Read available stderr from the child process (non-blocking, last 2KB).
213    fn drain_stderr(&mut self) -> String {
214        use std::io::Read;
215        let Some(stderr) = self._child.stderr.as_mut() else {
216            return String::new();
217        };
218        let mut buf = vec![0u8; 2048];
219        let n = stderr.read(&mut buf).unwrap_or(0);
220        let text = String::from_utf8_lossy(&buf[..n]).to_string();
221        // Return last few lines for concise output
222        let lines: Vec<&str> = text.lines().collect();
223        if lines.len() > 10 {
224            lines[lines.len() - 10..].join("\n")
225        } else {
226            text
227        }
228    }
229
230    /// Build OpenAI-compatible request body.
231    ///
232    /// PMAT-176: Only strips the verbose `## Available Tools` section injected
233    /// by `build_enriched_system()` (full JSON schemas ~2000 tokens). Preserves
234    /// the compact `## Tools` table from `CODE_SYSTEM_PROMPT` — that table has
235    /// tool names, use cases, and example inputs designed for 1.5B-7B models.
236    fn build_openai_body(&self, request: &CompletionRequest) -> serde_json::Value {
237        let mut messages = Vec::new();
238
239        if let Some(ref system) = request.system {
240            // PMAT-176: Only strip the verbose enriched section (full JSON schemas).
241            // Keep the compact "## Tools" table from CODE_SYSTEM_PROMPT — it has
242            // descriptions and examples that small models need for tool discovery.
243            let compact_system = system
244                .find("\n\n## Available Tools")
245                .map(|i| &system[..i])
246                .unwrap_or(system)
247                .to_string();
248
249            messages.push(serde_json::json!({
250                "role": "system",
251                "content": compact_system
252            }));
253        }
254
255        for msg in &request.messages {
256            match msg {
257                Message::User(text) => messages.push(serde_json::json!({
258                    "role": "user",
259                    "content": text
260                })),
261                Message::Assistant(text) => messages.push(serde_json::json!({
262                    "role": "assistant",
263                    "content": text
264                })),
265                Message::AssistantToolUse(call) => messages.push(serde_json::json!({
266                    "role": "assistant",
267                    "content": format!("<tool_call>\n{}\n</tool_call>",
268                        serde_json::json!({"name": call.name, "input": call.input}))
269                })),
270                Message::ToolResult(result) => messages.push(serde_json::json!({
271                    "role": "user",
272                    "content": format!("<tool_result>\n{}\n</tool_result>", result.content)
273                })),
274                _ => {}
275            }
276        }
277
278        // PMAT-170: Cap max_tokens for HTTP path. The manifest default (4096)
279        // causes very long generation on local models. 1024 accommodates:
280        // - Tool call JSON (~100-200 tokens each)
281        // - File edit content (multi-line diffs)
282        // - Explanation text alongside tool calls
283        // Previous 512 cap truncated complex edits mid-output.
284        //
285        // aprender#1789 follow-up: env-var override for large MoE models
286        // without KV cache. At ~0.5 tok/s (30B-MoE-no-KV), 1024 tokens
287        // takes ~34 min — exceeds reasonable per-turn budgets. Allow the
288        // operator (or bench harness) to dial down for slow models.
289        let max_tokens_cap = std::env::var("APR_AGENT_MAX_TOKENS_CAP")
290            .ok()
291            .and_then(|v| v.parse::<u32>().ok())
292            .unwrap_or(1024);
293        let max_tokens = request.max_tokens.min(max_tokens_cap);
294
295        // 3-knob toolkit (qwen3-moe-sampling-v1 + qwen3-moe-repetition-penalty-v1):
296        // operator env-var overrides for sampling parameters. When set, these
297        // flow from apr code → HTTP body → apr serve's try_qwen3_moe_backend
298        // → QuantizedGenerateConfig → run_qwen3_moe_generate → sample_from_logits.
299        // When UNSET, the request still uses temperature from CompletionRequest
300        // (existing behavior); other fields default to the
301        // QuantizedGenerateConfig defaults (greedy).
302        let temperature = std::env::var("APR_AGENT_TEMPERATURE")
303            .ok()
304            .and_then(|v| v.parse::<f32>().ok())
305            .unwrap_or(request.temperature);
306        let top_k = std::env::var("APR_AGENT_TOP_K").ok().and_then(|v| v.parse::<usize>().ok());
307        let top_p = std::env::var("APR_AGENT_TOP_P").ok().and_then(|v| v.parse::<f32>().ok());
308        let repeat_penalty =
309            std::env::var("APR_AGENT_REPEAT_PENALTY").ok().and_then(|v| v.parse::<f32>().ok());
310        let repeat_last_n =
311            std::env::var("APR_AGENT_REPEAT_LAST_N").ok().and_then(|v| v.parse::<usize>().ok());
312        let seed = std::env::var("APR_AGENT_SEED").ok().and_then(|v| v.parse::<u64>().ok());
313
314        let mut body = serde_json::json!({
315            "model": self.model_name,
316            "messages": messages,
317            "max_tokens": max_tokens,
318            "temperature": temperature,
319            "stream": false,
320        });
321        if let Some(v) = top_k {
322            body["top_k"] = serde_json::json!(v);
323        }
324        if let Some(v) = top_p {
325            body["top_p"] = serde_json::json!(v);
326        }
327        if let Some(v) = repeat_penalty {
328            body["repeat_penalty"] = serde_json::json!(v);
329        }
330        if let Some(v) = repeat_last_n {
331            body["repeat_last_n"] = serde_json::json!(v);
332        }
333        if let Some(v) = seed {
334            body["seed"] = serde_json::json!(v);
335        }
336        body
337    }
338}
339
/// Compute the startup-ready timeout in seconds for `apr serve`.
///
/// Resolution order:
/// 1. If `env_override`, after trimming surrounding whitespace, parses as a
///    `u64`, use it (operator override), clamped to at least 1 second.
/// 2. Otherwise compute a size-aware default: 30s baseline + 1s per full
///    500 MiB above 2 GiB. A 1 GiB model gets 30s; a 4 GiB model gets 34s;
///    an 18 GiB model gets 62s; a 30 GiB model gets 87s.
/// 3. If model size is unknown (stat failed at launch), fall back to the
///    30s baseline.
///
/// An override that does not parse (empty, negative, non-numeric) is
/// ignored and the size-aware default applies.
///
/// Always returns at least `MIN_TIMEOUT_S = 1` to avoid the pathological
/// 0-second budget case.
///
/// Extracted as a free function so the resolution logic is unit-testable
/// without spawning a subprocess. Called from
/// [`AprServeDriver::resolve_ready_timeout_secs`] with the live env.
#[must_use]
pub fn compute_ready_timeout_secs(
    model_size_bytes: Option<u64>,
    env_override: Option<&str>,
) -> u64 {
    const MIN_TIMEOUT_S: u64 = 1;
    const BASELINE_S: u64 = 30;
    const SIZE_FREE_BYTES: u64 = 2 * 1024 * 1024 * 1024; // 2 GiB
    const BYTES_PER_EXTRA_SECOND: u64 = 500 * 1024 * 1024; // 500 MiB

    // Trim first: values coming from env files or shell substitution often
    // carry a trailing newline or stray spaces, and `u64::from_str` rejects
    // those outright.
    if let Some(secs) = env_override.and_then(|raw| raw.trim().parse::<u64>().ok()) {
        return secs.max(MIN_TIMEOUT_S);
    }

    // Unknown size contributes no extra time.
    let extra_secs = model_size_bytes
        .map_or(0, |bytes| bytes.saturating_sub(SIZE_FREE_BYTES) / BYTES_PER_EXTRA_SECOND);
    BASELINE_S.saturating_add(extra_secs)
}
379
380#[async_trait]
381impl LlmDriver for AprServeDriver {
382    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
383        let url = format!("{}/v1/chat/completions", self.base_url);
384        let body = self.build_openai_body(&request);
385
386        // aprender#1789 follow-up: 120s default is too short for large MoE
387        // models without KV cache (each token is full-prefill; a 256-token
388        // generation at ~30B params can exceed 30 minutes wall). Empirical
389        // evidence: paiml/claude-code-parity-apr Phase 6 bench against
390        // Qwen3-Coder-30B-A3B saw every fixture die with "error sending
391        // request" at exactly the 120s mark. Same root-cause class as
392        // aprender#1782 (configurable + size-aware default).
393        //
394        // Override via `APR_AGENT_HTTP_TIMEOUT_S` env var. Default of 1800s
395        // (30 min) matches the bench's per-turn-timeout ceiling + leaves
396        // headroom for large MoE inference until M32d KV cache lands.
397        let http_timeout_secs = std::env::var("APR_AGENT_HTTP_TIMEOUT_S")
398            .ok()
399            .and_then(|v| v.parse::<u64>().ok())
400            .unwrap_or(1800);
401        let client = reqwest::Client::builder()
402            .timeout(std::time::Duration::from_secs(http_timeout_secs))
403            .build()
404            .map_err(|e| AgentError::Driver(DriverError::Network(format!("http client: {e}"))))?;
405        let response = client
406            .post(&url)
407            .header("content-type", "application/json")
408            .json(&body)
409            .send()
410            .await
411            .map_err(|e| AgentError::Driver(DriverError::Network(format!("apr serve: {e}"))))?;
412
413        if !response.status().is_success() {
414            let status = response.status().as_u16();
415            let text = response.text().await.unwrap_or_default();
416            return Err(AgentError::Driver(DriverError::Network(format!(
417                "apr serve HTTP {status}: {text}"
418            ))));
419        }
420
421        let json: serde_json::Value = response
422            .json()
423            .await
424            .map_err(|e| AgentError::Driver(DriverError::InferenceFailed(format!("parse: {e}"))))?;
425
426        // Extract response from OpenAI format
427        let raw_text = json["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
428
429        // PMAT-180: Strip Qwen3 thinking blocks. The model may emit
430        // <think>...</think> or bare </think> tokens. Remove them before
431        // parsing tool calls — thinking content is internal reasoning.
432        let text = strip_thinking_blocks(&raw_text);
433
434        let usage = json.get("usage").cloned().unwrap_or(serde_json::json!({}));
435        let input_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0);
436        let output_tokens = usage["completion_tokens"].as_u64().unwrap_or(0);
437
438        // Parse tool calls from text (same parser as RealizarDriver)
439        let (clean_text, tool_calls) = super::realizar::parse_tool_calls_pub(&text);
440
441        let stop_reason =
442            if tool_calls.is_empty() { StopReason::EndTurn } else { StopReason::ToolUse };
443
444        Ok(CompletionResponse {
445            text: clean_text,
446            stop_reason,
447            tool_calls,
448            usage: TokenUsage { input_tokens, output_tokens },
449        })
450    }
451
452    fn context_window(&self) -> usize {
453        self.context_window_size
454    }
455
456    fn privacy_tier(&self) -> PrivacyTier {
457        // Sovereign: apr serve runs on localhost, zero network egress
458        PrivacyTier::Sovereign
459    }
460}
461
/// Remove Qwen3 reasoning markup from model output.
///
/// Every `<think>...</think>` span is deleted (spans may cross line
/// breaks). An opening `<think>` with no matching close drops everything
/// from that tag to the end of the text. Any leftover bare `</think>` tags
/// are then removed, and the result is trimmed of leading/trailing
/// whitespace.
fn strip_thinking_blocks(text: &str) -> String {
    const OPEN_TAG: &str = "<think>";
    const CLOSE_TAG: &str = "</think>";

    let mut cleaned = String::from(text);
    loop {
        // Always rescan from the start: removing a span can bring new text
        // together.
        let Some(open_at) = cleaned.find(OPEN_TAG) else {
            break;
        };
        match cleaned[open_at..].find(CLOSE_TAG) {
            Some(rel_close) => {
                let span_end = open_at + rel_close + CLOSE_TAG.len();
                cleaned.replace_range(open_at..span_end, "");
            }
            None => {
                // Unterminated block: discard the rest of the text.
                cleaned.truncate(open_at);
                break;
            }
        }
    }

    // The model sometimes emits only the closing tag.
    cleaned.replace(CLOSE_TAG, "").trim().to_string()
}
479
480/// Issue #1712: ask the kernel to SIGTERM the child when the parent dies.
481///
482/// On Linux/Unix this uses `PR_SET_PDEATHSIG` via `pre_exec` so the child
483/// receives SIGTERM the instant its parent exits — whether the parent died
484/// gracefully, was SIGKILLed by `timeout`, or was terminated by the OOM
485/// killer. Without this, `apr serve` orphans hold ~3 GB RSS each.
486///
487/// A `getppid()==1` check immediately after `prctl` closes the small race
488/// where the parent dies between fork and prctl (in which case the death
489/// signal has already missed its window).
490#[cfg(unix)]
491#[allow(unsafe_code)] // pre_exec is unsafe-by-API; body uses only async-signal-safe calls
492fn configure_parent_death_signal(cmd: &mut Command) {
493    use std::os::unix::process::CommandExt;
494    // SAFETY: `prctl` and `getppid` are async-signal-safe; `pre_exec` runs
495    // between fork and exec where only async-signal-safe calls are allowed.
496    unsafe {
497        cmd.pre_exec(|| {
498            if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM, 0, 0, 0) == -1 {
499                return Err(std::io::Error::last_os_error());
500            }
501            if libc::getppid() == 1 {
502                return Err(std::io::Error::other(
503                    "parent died before PR_SET_PDEATHSIG took effect",
504                ));
505            }
506            Ok(())
507        });
508    }
509}
510
511#[cfg(not(unix))]
512fn configure_parent_death_signal(_cmd: &mut Command) {
513    // Windows: no equivalent — orphans on parent death still possible.
514}
515
516/// Find the `apr` binary on PATH.
517fn find_apr_binary() -> Result<PathBuf, AgentError> {
518    which::which("apr").map_err(|_| {
519        AgentError::Driver(DriverError::InferenceFailed(
520            "apr binary not found on PATH. Install: cargo install apr-cli".into(),
521        ))
522    })
523}
524
525#[cfg(test)]
526#[path = "apr_serve_tests.rs"]
527mod tests;