batuta/agent/driver/
apr_serve.rs1use async_trait::async_trait;
16use std::path::PathBuf;
17use std::process::{Child, Command, Stdio};
18
19use super::{CompletionRequest, CompletionResponse, LlmDriver, Message, ToolCall};
20use crate::agent::result::{AgentError, DriverError, StopReason, TokenUsage};
21use crate::serve::backends::PrivacyTier;
22
/// LLM driver backed by a locally spawned `apr serve` subprocess.
///
/// The driver owns the child process handle; the process is shut down when
/// the driver is dropped.
pub struct AprServeDriver {
    /// Root URL of the spawned server, in the form `http://127.0.0.1:<port>`.
    base_url: String,
    /// Model identifier sent in the `model` field of each completion request.
    model_name: String,
    /// Handle to the `apr serve` child process.
    _child: Child,
    /// Value reported by `context_window()`.
    context_window_size: usize,
}
34
35impl Drop for AprServeDriver {
36 fn drop(&mut self) {
38 let pid = self._child.id();
39
40 #[cfg(unix)]
42 {
43 let _ = Command::new("kill")
44 .args(["-TERM", &pid.to_string()])
45 .stdout(Stdio::null())
46 .stderr(Stdio::null())
47 .status();
48
49 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
51 loop {
52 match self._child.try_wait() {
53 Ok(Some(_)) => return, Ok(None) if std::time::Instant::now() < deadline => {
55 std::thread::sleep(std::time::Duration::from_millis(100));
56 }
57 _ => break, }
59 }
60 }
61
62 let _ = self._child.kill();
64 let _ = self._child.wait();
65 }
66}
67
68impl AprServeDriver {
69 pub fn launch(model_path: PathBuf, context_window: Option<usize>) -> Result<Self, AgentError> {
75 let apr_path = find_apr_binary()?;
76
77 let port = 19384 + (std::process::id() % 1000) as u16;
79 let base_url = format!("http://127.0.0.1:{port}");
80
81 let model_name = model_path
82 .file_stem()
83 .map(|s| s.to_string_lossy().to_string())
84 .unwrap_or_else(|| "local".to_string());
85
86 let mut cmd = Command::new(&apr_path);
91 cmd.args([
92 "serve",
93 "run",
94 &model_path.to_string_lossy(),
95 "--port",
96 &port.to_string(),
97 "--host",
98 "127.0.0.1",
99 "--gpu",
100 ])
101 .env("BATCHED_PREFILL", "0")
102 .stdout(Stdio::piped())
103 .stderr(Stdio::piped());
104
105 configure_parent_death_signal(&mut cmd);
111
112 let child = cmd.spawn().map_err(|e| {
113 AgentError::Driver(DriverError::InferenceFailed(format!(
114 "failed to spawn apr serve: {e}"
115 )))
116 })?;
117
118 eprintln!("Launched apr serve on port {port} (pid {})", child.id());
119
120 let mut driver = Self {
121 base_url,
122 model_name,
123 _child: child,
124 context_window_size: context_window.unwrap_or(4096),
125 };
126
127 driver.wait_for_ready()?;
129
130 Ok(driver)
131 }
132
133 fn wait_for_ready(&mut self) -> Result<(), AgentError> {
138 let addr = self.base_url.trim_start_matches("http://").to_string();
139 let sock_addr: std::net::SocketAddr =
140 addr.parse().unwrap_or_else(|_| std::net::SocketAddr::from(([127, 0, 0, 1], 19384)));
141
142 let start = std::time::Instant::now();
143 let timeout = std::time::Duration::from_secs(30);
144
145 loop {
146 if start.elapsed() > timeout {
147 let stderr = self.drain_stderr();
148 let mut msg = "apr serve did not become ready within 30s".to_string();
149 if !stderr.is_empty() {
150 msg.push_str(&format!("\nsubprocess stderr:\n{stderr}"));
151 }
152 msg.push_str(&format!(
153 "\nDebug manually: apr serve run <model> --port {} --host 127.0.0.1",
154 addr.rsplit(':').next().unwrap_or("19384")
155 ));
156 return Err(AgentError::Driver(DriverError::InferenceFailed(msg)));
157 }
158
159 if let Ok(Some(status)) = self._child.try_wait() {
161 let stderr = self.drain_stderr();
162 let mut msg = format!("apr serve exited with {status} during startup");
163 if !stderr.is_empty() {
164 msg.push_str(&format!("\nsubprocess stderr:\n{stderr}"));
165 }
166 return Err(AgentError::Driver(DriverError::InferenceFailed(msg)));
167 }
168
169 if std::net::TcpStream::connect_timeout(
170 &sock_addr,
171 std::time::Duration::from_millis(200),
172 )
173 .is_ok()
174 {
175 eprintln!("apr serve ready ({:.1}s)", start.elapsed().as_secs_f64());
176 return Ok(());
177 }
178
179 std::thread::sleep(std::time::Duration::from_millis(500));
180 }
181 }
182
183 fn drain_stderr(&mut self) -> String {
185 use std::io::Read;
186 let Some(stderr) = self._child.stderr.as_mut() else {
187 return String::new();
188 };
189 let mut buf = vec![0u8; 2048];
190 let n = stderr.read(&mut buf).unwrap_or(0);
191 let text = String::from_utf8_lossy(&buf[..n]).to_string();
192 let lines: Vec<&str> = text.lines().collect();
194 if lines.len() > 10 {
195 lines[lines.len() - 10..].join("\n")
196 } else {
197 text
198 }
199 }
200
201 fn build_openai_body(&self, request: &CompletionRequest) -> serde_json::Value {
208 let mut messages = Vec::new();
209
210 if let Some(ref system) = request.system {
211 let compact_system = system
215 .find("\n\n## Available Tools")
216 .map(|i| &system[..i])
217 .unwrap_or(system)
218 .to_string();
219
220 messages.push(serde_json::json!({
221 "role": "system",
222 "content": compact_system
223 }));
224 }
225
226 for msg in &request.messages {
227 match msg {
228 Message::User(text) => messages.push(serde_json::json!({
229 "role": "user",
230 "content": text
231 })),
232 Message::Assistant(text) => messages.push(serde_json::json!({
233 "role": "assistant",
234 "content": text
235 })),
236 Message::AssistantToolUse(call) => messages.push(serde_json::json!({
237 "role": "assistant",
238 "content": format!("<tool_call>\n{}\n</tool_call>",
239 serde_json::json!({"name": call.name, "input": call.input}))
240 })),
241 Message::ToolResult(result) => messages.push(serde_json::json!({
242 "role": "user",
243 "content": format!("<tool_result>\n{}\n</tool_result>", result.content)
244 })),
245 _ => {}
246 }
247 }
248
249 let max_tokens = request.max_tokens.min(1024);
256
257 serde_json::json!({
258 "model": self.model_name,
259 "messages": messages,
260 "max_tokens": max_tokens,
261 "temperature": request.temperature,
262 "stream": false
263 })
264 }
265}
266
267#[async_trait]
268impl LlmDriver for AprServeDriver {
269 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
270 let url = format!("{}/v1/chat/completions", self.base_url);
271 let body = self.build_openai_body(&request);
272
273 let client = reqwest::Client::builder()
274 .timeout(std::time::Duration::from_secs(120))
275 .build()
276 .map_err(|e| AgentError::Driver(DriverError::Network(format!("http client: {e}"))))?;
277 let response = client
278 .post(&url)
279 .header("content-type", "application/json")
280 .json(&body)
281 .send()
282 .await
283 .map_err(|e| AgentError::Driver(DriverError::Network(format!("apr serve: {e}"))))?;
284
285 if !response.status().is_success() {
286 let status = response.status().as_u16();
287 let text = response.text().await.unwrap_or_default();
288 return Err(AgentError::Driver(DriverError::Network(format!(
289 "apr serve HTTP {status}: {text}"
290 ))));
291 }
292
293 let json: serde_json::Value = response
294 .json()
295 .await
296 .map_err(|e| AgentError::Driver(DriverError::InferenceFailed(format!("parse: {e}"))))?;
297
298 let raw_text = json["choices"][0]["message"]["content"].as_str().unwrap_or("").to_string();
300
301 let text = strip_thinking_blocks(&raw_text);
305
306 let usage = json.get("usage").cloned().unwrap_or(serde_json::json!({}));
307 let input_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0);
308 let output_tokens = usage["completion_tokens"].as_u64().unwrap_or(0);
309
310 let (clean_text, tool_calls) = super::realizar::parse_tool_calls_pub(&text);
312
313 let stop_reason =
314 if tool_calls.is_empty() { StopReason::EndTurn } else { StopReason::ToolUse };
315
316 Ok(CompletionResponse {
317 text: clean_text,
318 stop_reason,
319 tool_calls,
320 usage: TokenUsage { input_tokens, output_tokens },
321 })
322 }
323
324 fn context_window(&self) -> usize {
325 self.context_window_size
326 }
327
328 fn privacy_tier(&self) -> PrivacyTier {
329 PrivacyTier::Sovereign
331 }
332}
333
/// Removes `<think>…</think>` sections from model output.
///
/// Every complete block is deleted. An opening tag with no matching closing
/// tag discards everything from that tag to the end of the text. Leftover
/// stray `</think>` tags are then removed and the result is trimmed of
/// surrounding whitespace.
fn strip_thinking_blocks(text: &str) -> String {
    const OPEN: &str = "<think>";
    const CLOSE: &str = "</think>";

    let mut cleaned = String::from(text);
    loop {
        let Some(open_at) = cleaned.find(OPEN) else {
            break;
        };
        match cleaned[open_at..].find(CLOSE) {
            Some(rel_close) => {
                // Drop the whole block, closing tag included.
                cleaned.drain(open_at..open_at + rel_close + CLOSE.len());
            }
            None => {
                // Unterminated block: nothing after the tag is usable.
                cleaned.truncate(open_at);
                break;
            }
        }
    }

    cleaned.replace(CLOSE, "").trim().to_owned()
}
351
352#[cfg(unix)]
363#[allow(unsafe_code)] fn configure_parent_death_signal(cmd: &mut Command) {
365 use std::os::unix::process::CommandExt;
366 unsafe {
369 cmd.pre_exec(|| {
370 if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM, 0, 0, 0) == -1 {
371 return Err(std::io::Error::last_os_error());
372 }
373 if libc::getppid() == 1 {
374 return Err(std::io::Error::other(
375 "parent died before PR_SET_PDEATHSIG took effect",
376 ));
377 }
378 Ok(())
379 });
380 }
381}
382
383#[cfg(not(unix))]
384fn configure_parent_death_signal(_cmd: &mut Command) {
385 }
387
388fn find_apr_binary() -> Result<PathBuf, AgentError> {
390 which::which("apr").map_err(|_| {
391 AgentError::Driver(DriverError::InferenceFailed(
392 "apr binary not found on PATH. Install: cargo install apr-cli".into(),
393 ))
394 })
395}
396
397#[cfg(test)]
398#[path = "apr_serve_tests.rs"]
399mod tests;