use async_trait::async_trait;
use std::path::PathBuf;
use std::process::{Child, Command, Stdio};
use super::{CompletionRequest, CompletionResponse, LlmDriver, Message, ToolCall};
use crate::agent::result::{AgentError, DriverError, StopReason, TokenUsage};
use crate::serve::backends::PrivacyTier;
/// LLM driver backed by a locally spawned `apr serve` subprocess, queried over its
/// `/v1/chat/completions` HTTP endpoint with OpenAI-style request bodies.
pub struct AprServeDriver {
    /// Root URL of the server, `http://127.0.0.1:<port>` (no trailing slash).
    base_url: String,
    /// Value sent as `"model"` in each request; derived from the model file's stem.
    model_name: String,
    /// Handle to the server process; it is shut down in this type's `Drop` impl.
    _child: Child,
    /// Value reported by `LlmDriver::context_window`; `launch` defaults it to 4096.
    context_window_size: usize,
    /// Size of the model file on disk, used to scale the startup timeout; `None` when
    /// the file metadata could not be read.
    model_size_bytes: Option<u64>,
}
impl Drop for AprServeDriver {
    /// Shut down the `apr serve` subprocess.
    ///
    /// On Unix the child first receives SIGTERM and gets up to 2 s (polled every 100 ms)
    /// to exit on its own. If it is still alive after that, or on other platforms, it is
    /// killed with `Child::kill` and reaped so no zombie is left behind.
    fn drop(&mut self) {
        // Once the child has been reaped (e.g. `wait_for_ready` saw it exit during
        // startup), its PID may already belong to an unrelated process, so it must never
        // be signalled by number. A status from `try_wait` means the child is reaped and
        // there is nothing left to clean up.
        if matches!(self._child.try_wait(), Ok(Some(_))) {
            return;
        }
        #[cfg(unix)]
        {
            let pid = self._child.id();
            // Ask for a graceful shutdown first.
            let _ = Command::new("kill")
                .args(["-TERM", &pid.to_string()])
                .stdout(Stdio::null())
                .stderr(Stdio::null())
                .status();
            let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
            loop {
                match self._child.try_wait() {
                    Ok(Some(_)) => return,
                    Ok(None) if std::time::Instant::now() < deadline => {
                        std::thread::sleep(std::time::Duration::from_millis(100));
                    }
                    _ => break,
                }
            }
        }
        // Escalate (SIGKILL on Unix) and reap.
        let _ = self._child.kill();
        let _ = self._child.wait();
    }
}
impl AprServeDriver {
    /// Spawn `apr serve run <model_path>` on a free loopback port and block until it
    /// accepts TCP connections.
    ///
    /// `context_window` defaults to 4096 when `None`.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Driver`] if the `apr` binary is not on `PATH`, the
    /// subprocess cannot be spawned, exits during startup, or does not become ready
    /// within the timeout computed by [`compute_ready_timeout_secs`].
    pub fn launch(model_path: PathBuf, context_window: Option<usize>) -> Result<Self, AgentError> {
        let apr_path = find_apr_binary()?;
        let port = Self::pick_port();
        let base_url = format!("http://127.0.0.1:{port}");
        let model_name = model_path
            .file_stem()
            .map(|s| s.to_string_lossy().to_string())
            .unwrap_or_else(|| "local".to_string());
        let mut cmd = Command::new(&apr_path);
        cmd.args([
            "serve",
            "run",
            &model_path.to_string_lossy(),
            "--port",
            &port.to_string(),
            "--host",
            "127.0.0.1",
            "--gpu",
        ])
        .env("BATCHED_PREFILL", "0")
        // Nothing ever reads the server's stdout. An unread pipe eventually fills and
        // blocks the server on write, so discard it instead.
        .stdout(Stdio::null())
        // stderr is captured so startup failures can be reported. Once the server is
        // ready it is handed to a background drain thread (`spawn_stderr_drain`).
        .stderr(Stdio::piped());
        configure_parent_death_signal(&mut cmd);
        let child = cmd.spawn().map_err(|e| {
            AgentError::Driver(DriverError::InferenceFailed(format!(
                "failed to spawn apr serve: {e}"
            )))
        })?;
        eprintln!("Launched apr serve on port {port} (pid {})", child.id());
        let model_size_bytes = std::fs::metadata(&model_path).ok().map(|m| m.len());
        let mut driver = Self {
            base_url,
            model_name,
            _child: child,
            context_window_size: context_window.unwrap_or(4096),
            model_size_bytes,
        };
        driver.wait_for_ready()?;
        driver.spawn_stderr_drain();
        Ok(driver)
    }

    /// Ask the OS for a currently free loopback port.
    ///
    /// Binding port 0 and releasing it immediately leaves a small window in which another
    /// process could take the port. Unlike a fixed PID-derived port, though, it avoids
    /// colliding with a stale server that would otherwise satisfy the readiness probe.
    /// Falls back to the PID-derived port if the probe bind fails.
    fn pick_port() -> u16 {
        std::net::TcpListener::bind(("127.0.0.1", 0))
            .and_then(|listener| listener.local_addr())
            .map(|addr| addr.port())
            .unwrap_or_else(|_| 19384 + (std::process::id() % 1000) as u16)
    }

    /// Poll until the server's port accepts a TCP connection.
    ///
    /// Each iteration checks the timeout and whether the child has exited, then tries a
    /// 200 ms connect and sleeps 500 ms between attempts.
    ///
    /// # Errors
    ///
    /// `DriverError::InferenceFailed` if the child exits first or the timeout elapses.
    /// The message includes the tail of the subprocess's stderr when available.
    fn wait_for_ready(&mut self) -> Result<(), AgentError> {
        let addr = self.base_url.trim_start_matches("http://").to_string();
        let sock_addr: std::net::SocketAddr =
            addr.parse().unwrap_or_else(|_| std::net::SocketAddr::from(([127, 0, 0, 1], 19384)));
        let timeout_secs = self.resolve_ready_timeout_secs();
        let start = std::time::Instant::now();
        let timeout = std::time::Duration::from_secs(timeout_secs);
        loop {
            if start.elapsed() > timeout {
                // The server is still running, so its stderr pipe is still open and a
                // read would block indefinitely. The launch is failing anyway, so stop
                // and reap the child first to let the pipe reach EOF.
                let _ = self._child.kill();
                let _ = self._child.wait();
                let stderr = self.drain_stderr();
                let mut msg = format!(
                    "apr serve did not become ready within {timeout_secs}s (override via APR_SERVE_READY_TIMEOUT_S)"
                );
                if !stderr.is_empty() {
                    msg.push_str(&format!("\nsubprocess stderr:\n{stderr}"));
                }
                msg.push_str(&format!(
                    "\nDebug manually: apr serve run <model> --port {} --host 127.0.0.1",
                    addr.rsplit(':').next().unwrap_or("19384")
                ));
                return Err(AgentError::Driver(DriverError::InferenceFailed(msg)));
            }
            if let Ok(Some(status)) = self._child.try_wait() {
                let stderr = self.drain_stderr();
                let mut msg = format!("apr serve exited with {status} during startup");
                if !stderr.is_empty() {
                    msg.push_str(&format!("\nsubprocess stderr:\n{stderr}"));
                }
                return Err(AgentError::Driver(DriverError::InferenceFailed(msg)));
            }
            if std::net::TcpStream::connect_timeout(
                &sock_addr,
                std::time::Duration::from_millis(200),
            )
            .is_ok()
            {
                eprintln!("apr serve ready ({:.1}s)", start.elapsed().as_secs_f64());
                return Ok(());
            }
            std::thread::sleep(std::time::Duration::from_millis(500));
        }
    }

    /// Startup timeout in seconds. `APR_SERVE_READY_TIMEOUT_S` overrides it when set to a
    /// valid integer; otherwise it scales with the model file size.
    fn resolve_ready_timeout_secs(&self) -> u64 {
        let env_override = std::env::var("APR_SERVE_READY_TIMEOUT_S").ok();
        compute_ready_timeout_secs(self.model_size_bytes, env_override.as_deref())
    }

    /// Keep reading and discarding the server's stderr once it is running.
    ///
    /// Without a reader, a chatty server fills the OS pipe buffer and then blocks on its
    /// next stderr write.
    fn spawn_stderr_drain(&mut self) {
        if let Some(mut stderr) = self._child.stderr.take() {
            // Detached: the thread finishes when the pipe reaches EOF (server exit).
            std::thread::spawn(move || {
                let _ = std::io::copy(&mut stderr, &mut std::io::sink());
            });
        }
    }

    /// Collect the last (up to) 10 lines of the subprocess's stderr.
    ///
    /// Reads until EOF, so call it only after the child has exited. Otherwise the read
    /// blocks for as long as the server keeps stderr open. Returns an empty string if
    /// stderr was not captured or has already been taken.
    fn drain_stderr(&mut self) -> String {
        use std::io::Read;
        const TAIL_LINES: usize = 10;
        let Some(stderr) = self._child.stderr.as_mut() else {
            return String::new();
        };
        let mut buf = Vec::new();
        // Best-effort diagnostics: on error, keep whatever was read so far.
        let _ = stderr.read_to_end(&mut buf);
        let text = String::from_utf8_lossy(&buf);
        let lines: Vec<&str> = text.lines().collect();
        lines[lines.len().saturating_sub(TAIL_LINES)..].join("\n")
    }

    /// Build the OpenAI-style `/v1/chat/completions` request body for `request`.
    ///
    /// - The system prompt is cut at the `## Available Tools` heading, if present.
    ///   NOTE(review): presumably this keeps the prompt small for local models; confirm
    ///   that tool descriptions are supplied some other way.
    /// - Tool calls are sent as assistant text wrapped in `<tool_call>` tags with a
    ///   `{"name", "input"}` JSON payload.
    /// - Tool results are sent as user messages wrapped in `<tool_result>` tags.
    /// - Any other message variants are skipped.
    ///
    /// `max_tokens` is capped by `APR_AGENT_MAX_TOKENS_CAP` (default 1024).
    /// `APR_AGENT_TEMPERATURE` overrides the request temperature. The optional sampling
    /// keys are added only when their env vars parse: `APR_AGENT_TOP_K`,
    /// `APR_AGENT_TOP_P`, `APR_AGENT_REPEAT_PENALTY`, `APR_AGENT_REPEAT_LAST_N` and
    /// `APR_AGENT_SEED`.
    fn build_openai_body(&self, request: &CompletionRequest) -> serde_json::Value {
        let mut messages = Vec::new();
        if let Some(ref system) = request.system {
            let compact_system = system
                .find("\n\n## Available Tools")
                .map(|i| &system[..i])
                .unwrap_or(system)
                .to_string();
            messages.push(serde_json::json!({
                "role": "system",
                "content": compact_system
            }));
        }
        for msg in &request.messages {
            match msg {
                Message::User(text) => messages.push(serde_json::json!({
                    "role": "user",
                    "content": text
                })),
                Message::Assistant(text) => messages.push(serde_json::json!({
                    "role": "assistant",
                    "content": text
                })),
                Message::AssistantToolUse(call) => messages.push(serde_json::json!({
                    "role": "assistant",
                    "content": format!("<tool_call>\n{}\n</tool_call>",
                        serde_json::json!({"name": call.name, "input": call.input}))
                })),
                Message::ToolResult(result) => messages.push(serde_json::json!({
                    "role": "user",
                    "content": format!("<tool_result>\n{}\n</tool_result>", result.content)
                })),
                _ => {}
            }
        }
        let max_tokens_cap = std::env::var("APR_AGENT_MAX_TOKENS_CAP")
            .ok()
            .and_then(|v| v.parse::<u32>().ok())
            .unwrap_or(1024);
        let max_tokens = request.max_tokens.min(max_tokens_cap);
        let temperature = std::env::var("APR_AGENT_TEMPERATURE")
            .ok()
            .and_then(|v| v.parse::<f32>().ok())
            .unwrap_or(request.temperature);
        let top_k = std::env::var("APR_AGENT_TOP_K").ok().and_then(|v| v.parse::<usize>().ok());
        let top_p = std::env::var("APR_AGENT_TOP_P").ok().and_then(|v| v.parse::<f32>().ok());
        let repeat_penalty =
            std::env::var("APR_AGENT_REPEAT_PENALTY").ok().and_then(|v| v.parse::<f32>().ok());
        let repeat_last_n =
            std::env::var("APR_AGENT_REPEAT_LAST_N").ok().and_then(|v| v.parse::<usize>().ok());
        let seed = std::env::var("APR_AGENT_SEED").ok().and_then(|v| v.parse::<u64>().ok());
        let mut body = serde_json::json!({
            "model": self.model_name,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": false,
        });
        if let Some(v) = top_k {
            body["top_k"] = serde_json::json!(v);
        }
        if let Some(v) = top_p {
            body["top_p"] = serde_json::json!(v);
        }
        if let Some(v) = repeat_penalty {
            body["repeat_penalty"] = serde_json::json!(v);
        }
        if let Some(v) = repeat_last_n {
            body["repeat_last_n"] = serde_json::json!(v);
        }
        if let Some(v) = seed {
            body["seed"] = serde_json::json!(v);
        }
        body
    }
}
/// Compute how long to wait for `apr serve` to start accepting connections.
///
/// A parseable `env_override` wins, clamped to at least 1 second. Without one, the base
/// budget is 30 seconds. Models larger than 2 GiB get one extra second for every full
/// 500 MiB beyond that threshold. An unknown model size uses the base budget.
#[must_use]
pub fn compute_ready_timeout_secs(
    model_size_bytes: Option<u64>,
    env_override: Option<&str>,
) -> u64 {
    const MIB: u64 = 1024 * 1024;
    const GIB: u64 = 1024 * MIB;
    const MIN_TIMEOUT_S: u64 = 1;
    const BASELINE_S: u64 = 30;
    const SIZE_FREE_BYTES: u64 = 2 * GIB;
    const BYTES_PER_EXTRA_SECOND: u64 = 500 * MIB;

    if let Some(explicit) = env_override.and_then(|raw| raw.parse::<u64>().ok()) {
        return explicit.max(MIN_TIMEOUT_S);
    }

    model_size_bytes.map_or(BASELINE_S, |bytes| {
        let bonus = bytes.saturating_sub(SIZE_FREE_BYTES) / BYTES_PER_EXTRA_SECOND;
        BASELINE_S.saturating_add(bonus)
    })
}
#[async_trait]
impl LlmDriver for AprServeDriver {
    /// Send one non-streaming chat completion to the local `apr serve` process and
    /// convert the reply into a [`CompletionResponse`].
    ///
    /// `<think>` blocks are stripped from the model output before tool calls are parsed.
    /// The stop reason is `ToolUse` when at least one tool call was parsed, otherwise
    /// `EndTurn`. Missing content becomes empty text, and missing `usage` counters are
    /// reported as 0.
    ///
    /// # Errors
    ///
    /// - `DriverError::Network` if the HTTP client cannot be built, the request fails,
    ///   or the server answers with a non-2xx status.
    /// - `DriverError::InferenceFailed` if the response body is not valid JSON.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
        let endpoint = format!("{}/v1/chat/completions", self.base_url);
        let payload = self.build_openai_body(&request);

        // Local generation can be slow, hence the generous default budget.
        let timeout_secs: u64 = match std::env::var("APR_AGENT_HTTP_TIMEOUT_S") {
            Ok(raw) => raw.parse().unwrap_or(1800),
            Err(_) => 1800,
        };
        let http = match reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(timeout_secs))
            .build()
        {
            Ok(client) => client,
            Err(e) => {
                return Err(AgentError::Driver(DriverError::Network(format!("http client: {e}"))));
            }
        };

        let reply = match http
            .post(&endpoint)
            .header("content-type", "application/json")
            .json(&payload)
            .send()
            .await
        {
            Ok(resp) => resp,
            Err(e) => {
                return Err(AgentError::Driver(DriverError::Network(format!("apr serve: {e}"))));
            }
        };

        let status = reply.status();
        if !status.is_success() {
            let code = status.as_u16();
            let detail = reply.text().await.unwrap_or_default();
            return Err(AgentError::Driver(DriverError::Network(format!(
                "apr serve HTTP {code}: {detail}"
            ))));
        }

        let parsed: serde_json::Value = match reply.json().await {
            Ok(value) => value,
            Err(e) => {
                return Err(AgentError::Driver(DriverError::InferenceFailed(format!("parse: {e}"))));
            }
        };

        let content = parsed["choices"][0]["message"]["content"].as_str().unwrap_or("");
        let visible = strip_thinking_blocks(content);

        let usage = &parsed["usage"];
        let token_usage = TokenUsage {
            input_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0),
            output_tokens: usage["completion_tokens"].as_u64().unwrap_or(0),
        };

        let (clean_text, tool_calls) = super::realizar::parse_tool_calls_pub(&visible);
        let stop_reason = if tool_calls.is_empty() {
            StopReason::EndTurn
        } else {
            StopReason::ToolUse
        };

        Ok(CompletionResponse {
            text: clean_text,
            stop_reason,
            tool_calls,
            usage: token_usage,
        })
    }

    /// Context window configured at launch.
    fn context_window(&self) -> usize {
        self.context_window_size
    }

    /// Always `Sovereign`: requests go only to the locally spawned server on 127.0.0.1.
    fn privacy_tier(&self) -> PrivacyTier {
        PrivacyTier::Sovereign
    }
}
/// Remove `<think>…</think>` reasoning sections from model output.
///
/// The text is re-scanned from the start after every removal. An opening tag with no
/// matching close discards everything from that tag onward. Any remaining stray
/// `</think>` tags are deleted, and the result is trimmed.
fn strip_thinking_blocks(text: &str) -> String {
    const OPEN: &str = "<think>";
    const CLOSE: &str = "</think>";

    let mut buf = String::from(text);
    loop {
        let Some(open_at) = buf.find(OPEN) else {
            break;
        };
        match buf[open_at..].find(CLOSE) {
            Some(rel_close) => {
                let close_end = open_at + rel_close + CLOSE.len();
                buf.replace_range(open_at..close_end, "");
            }
            None => {
                // Unterminated block: everything after the tag is still "thinking".
                buf.truncate(open_at);
                break;
            }
        }
    }
    buf.replace(CLOSE, "").trim().to_owned()
}
#[cfg(unix)]
#[allow(unsafe_code)] fn configure_parent_death_signal(cmd: &mut Command) {
use std::os::unix::process::CommandExt;
unsafe {
cmd.pre_exec(|| {
if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM, 0, 0, 0) == -1 {
return Err(std::io::Error::last_os_error());
}
if libc::getppid() == 1 {
return Err(std::io::Error::other(
"parent died before PR_SET_PDEATHSIG took effect",
));
}
Ok(())
});
}
}
#[cfg(not(unix))]
fn configure_parent_death_signal(_cmd: &mut Command) {
}
/// Locate the `apr` CLI on `PATH`.
///
/// # Errors
///
/// Returns a driver error with installation instructions when no `apr` executable is
/// found.
fn find_apr_binary() -> Result<PathBuf, AgentError> {
    let Ok(path) = which::which("apr") else {
        return Err(AgentError::Driver(DriverError::InferenceFailed(
            "apr binary not found on PATH. Install: cargo install apr-cli".into(),
        )));
    };
    Ok(path)
}
#[cfg(test)]
#[path = "apr_serve_tests.rs"]
mod tests;