use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use bitrouter_config::{AgentConfig, AgentSessionConfig, Distribution};
use bitrouter_core::agents::event::{AgentEvent, PermissionRequestId, PermissionResponse};
use bitrouter_core::agents::provider::AgentProvider;
use bitrouter_core::agents::session::AgentSessionInfo;
use bitrouter_core::errors::{BitrouterError, Result};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc};
use super::connection::{HandshakeResult, spawn_agent_thread};
use super::types::AgentCommand;
/// A fully resolved command line for starting an agent process.
///
/// Produced by `resolve_launch` and handed to `spawn_agent_thread`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct LaunchCommand {
    /// Program to spawn. This is either a path located on `PATH` or a bare name
    /// such as `npx`, `uvx`, or the configured agent binary.
    pub binary: PathBuf,
    /// Arguments passed to `binary`, in order.
    pub args: Vec<String>,
}
/// Book-keeping for one live agent session, stored in `AcpAgentProvider::sessions`
/// under the session id returned by the agent handshake.
///
/// Field order matters: Rust drops struct fields in declaration order, so the
/// concurrency permit (`_permit`) is released only after the command sender and
/// the thread handle have been dropped.
struct SessionEntry {
// Sender for commands (prompt, permission reply, disconnect) to the agent thread.
command_tx: mpsc::Sender<AgentCommand>,
// Handle to the agent's OS thread; never joined here, so dropping it detaches the thread.
_thread_handle: std::thread::JoinHandle<()>,
// Set on connect and on every submit/respond_permission; compared against the idle timeout.
last_active: Instant,
// Slot taken from `connect_semaphore`; dropping the entry frees it for another `connect`.
_permit: OwnedSemaphorePermit,
}
/// `AgentProvider` that talks to an agent over ACP (`protocol_name()` is `"acp"`),
/// spawning one agent thread per connected session.
///
/// The number of live sessions is capped by `max_concurrent`. Each successful
/// `connect` holds a semaphore permit inside its session entry until that entry
/// is removed.
pub struct AcpAgentProvider {
// Name reported by `agent_name()` and attached to every transport error.
agent_name: String,
// Launch configuration: binary, args, and distribution fallbacks.
config: AgentConfig,
// Session limits; `AgentSessionConfig::default()` when `config.session` is absent.
session_config: AgentSessionConfig,
// Live sessions keyed by session id.
sessions: Mutex<HashMap<String, SessionEntry>>,
// Bounds concurrent sessions; sized from `session_config.max_concurrent`.
connect_semaphore: Arc<Semaphore>,
}
impl AcpAgentProvider {
    /// Creates a provider named `agent_name` that launches agents according to `config`.
    ///
    /// Session limits come from `config.session`, or `AgentSessionConfig::default()`
    /// when that is absent.
    pub fn new(agent_name: String, config: AgentConfig) -> Self {
        let session_config = config.session.clone().unwrap_or_default();
        // `Semaphore::new` panics above `Semaphore::MAX_PERMITS`; clamp so an
        // oversized config value cannot crash construction.
        let permits = session_config.max_concurrent.min(Semaphore::MAX_PERMITS);
        Self {
            agent_name,
            config,
            session_config,
            sessions: Mutex::new(HashMap::new()),
            connect_semaphore: Arc::new(Semaphore::new(permits)),
        }
    }

    /// How long a session may go unused before `cleanup_idle_sessions` evicts it.
    pub fn idle_timeout(&self) -> Duration {
        Duration::from_secs(self.session_config.idle_timeout_secs)
    }

    /// Configured maximum number of simultaneously connected sessions.
    pub fn max_concurrent(&self) -> usize {
        self.session_config.max_concurrent
    }

    /// Number of currently registered sessions (0 if the session lock is poisoned).
    pub fn session_count(&self) -> usize {
        self.sessions.lock().map(|s| s.len()).unwrap_or(0)
    }

    /// Evicts every session that has been idle longer than `idle_timeout()`, or whose
    /// agent thread no longer receives commands. Each evicted agent is then sent a
    /// best-effort `AgentCommand::Disconnect`.
    ///
    /// Returns the number of evicted sessions, or 0 if the session lock is poisoned.
    pub async fn cleanup_idle_sessions(&self) -> usize {
        let idle_timeout = self.idle_timeout();
        let now = Instant::now();

        // Collect senders under the lock, then send outside it: a std `MutexGuard`
        // must not be held across `.await`.
        let evicted: Vec<mpsc::Sender<AgentCommand>> = {
            let Ok(mut sessions) = self.sessions.lock() else {
                return 0;
            };
            let mut evicted = Vec::new();
            sessions.retain(|_, entry| {
                // `now` is taken before locking, so a concurrent `submit` can make
                // `last_active` newer than `now`; saturate instead of relying on that.
                let idle = now.saturating_duration_since(entry.last_active) > idle_timeout;
                // A closed channel means the agent thread dropped its receiver. Keeping
                // the entry would pin a concurrency permit until the idle timeout.
                let dead = entry.command_tx.is_closed();
                let evict = idle || dead;
                if evict {
                    evicted.push(entry.command_tx.clone());
                }
                !evict
            });
            evicted
        };

        let count = evicted.len();
        for tx in evicted {
            // Best effort: a dead agent cannot receive this, and that is fine.
            let _ = tx.send(AgentCommand::Disconnect).await;
        }
        count
    }
}
impl AgentProvider for AcpAgentProvider {
fn agent_name(&self) -> &str {
&self.agent_name
}
fn protocol_name(&self) -> &str {
"acp"
}
async fn connect(&self) -> Result<AgentSessionInfo> {
let permit = Arc::clone(&self.connect_semaphore)
.try_acquire_owned()
.map_err(|_| {
let max = self.session_config.max_concurrent;
BitrouterError::transport(
Some(&self.agent_name),
format!("max concurrent sessions ({max}) reached"),
)
})?;
let launch = resolve_launch(&self.config);
let (handshake_tx, handshake_rx) = tokio::sync::oneshot::channel();
let thread_handle = spawn_agent_thread(
self.agent_name.clone(),
launch.binary,
launch.args,
handshake_tx,
);
let handshake = handshake_rx.await.map_err(|_| {
BitrouterError::transport(
Some(&self.agent_name),
"agent thread exited before handshake",
)
})?;
let HandshakeResult {
session_info,
command_tx,
} = handshake.map_err(|msg| BitrouterError::transport(Some(&self.agent_name), msg))?;
{
let mut sessions = self.sessions.lock().map_err(|_| {
BitrouterError::transport(Some(&self.agent_name), "session lock poisoned")
})?;
sessions.insert(
session_info.session_id.clone(),
SessionEntry {
command_tx,
_thread_handle: thread_handle,
last_active: Instant::now(),
_permit: permit,
},
);
}
Ok(session_info)
}
async fn submit(&self, session_id: &str, text: String) -> Result<mpsc::Receiver<AgentEvent>> {
let command_tx = {
let mut sessions = self.sessions.lock().map_err(|_| {
BitrouterError::transport(Some(&self.agent_name), "session lock poisoned")
})?;
match sessions.get_mut(session_id) {
Some(entry) => {
entry.last_active = Instant::now();
entry.command_tx.clone()
}
None => {
return Err(BitrouterError::transport(
Some(&self.agent_name),
format!("session '{session_id}' not found — call connect() first"),
));
}
}
};
let (reply_tx, reply_rx) = mpsc::channel(64);
command_tx
.send(AgentCommand::Prompt { text, reply_tx })
.await
.map_err(|_| {
BitrouterError::transport(Some(&self.agent_name), "agent thread not running")
})?;
Ok(reply_rx)
}
async fn respond_permission(
&self,
session_id: &str,
request_id: PermissionRequestId,
response: PermissionResponse,
) -> Result<()> {
let command_tx = {
let mut sessions = self.sessions.lock().map_err(|_| {
BitrouterError::transport(Some(&self.agent_name), "session lock poisoned")
})?;
match sessions.get_mut(session_id) {
Some(entry) => {
entry.last_active = Instant::now();
entry.command_tx.clone()
}
None => {
return Err(BitrouterError::transport(
Some(&self.agent_name),
format!("session '{session_id}' not found"),
));
}
}
};
command_tx
.send(AgentCommand::RespondPermission {
request_id,
response,
})
.await
.map_err(|_| {
BitrouterError::transport(Some(&self.agent_name), "agent thread not running")
})?;
Ok(())
}
async fn disconnect(&self, session_id: &str) -> Result<()> {
let command_tx = {
let mut sessions = self.sessions.lock().map_err(|_| {
BitrouterError::transport(Some(&self.agent_name), "session lock poisoned")
})?;
sessions.remove(session_id).map(|entry| entry.command_tx)
};
if let Some(tx) = command_tx {
let _ = tx.send(AgentCommand::Disconnect).await;
}
Ok(())
}
}
impl Drop for AcpAgentProvider {
fn drop(&mut self) {
}
}
/// Decides how to launch the agent described by `config`.
///
/// Preference order:
/// 1. `config.binary`, if `find_on_path` resolves it, with `config.args`.
/// 2. The first `Npx`/`Uvx` distribution whose runner (`npx`/`uvx`) is on `PATH`,
///    invoked as `<runner> <package> <args...>`. `Binary` distributions are skipped.
/// 3. Otherwise `config.binary` verbatim with `config.args`, even though it was not found.
fn resolve_launch(config: &AgentConfig) -> LaunchCommand {
    if let Some(found) = find_on_path(&config.binary) {
        return LaunchCommand {
            binary: found,
            args: config.args.clone(),
        };
    }

    let via_runner = config.distribution.iter().find_map(|dist| {
        let (runner, package, extra) = match dist {
            Distribution::Npx { package, args } => ("npx", package, args),
            Distribution::Uvx { package, args } => ("uvx", package, args),
            Distribution::Binary { .. } => return None,
        };
        if find_on_path(runner).is_none() {
            return None;
        }
        let args = std::iter::once(package.clone())
            .chain(extra.iter().cloned())
            .collect();
        Some(LaunchCommand {
            binary: PathBuf::from(runner),
            args,
        })
    });

    via_runner.unwrap_or_else(|| LaunchCommand {
        binary: PathBuf::from(&config.binary),
        args: config.args.clone(),
    })
}
/// Locates `name` the way a shell would for a bare command name.
///
/// - Names with more than one path component (e.g. `./agent`, `/usr/bin/agent`,
///   `dir/agent`) are returned unchanged, without checking that they exist.
/// - Otherwise each `PATH` directory is searched in order, and the first candidate
///   that is a regular file is returned. On Unix the file must also have at least
///   one execute bit, so a stray non-executable file cannot shadow the real binary.
///
/// Returns `None` if `PATH` is unset or no directory contains a match.
fn find_on_path(name: &str) -> Option<PathBuf> {
    let path = PathBuf::from(name);
    if path.components().count() > 1 {
        return Some(path);
    }
    let path_var = std::env::var_os("PATH")?;
    std::env::split_paths(&path_var)
        .map(|dir| dir.join(name))
        .find(|candidate| is_executable_file(candidate))
}

/// True if `path` is a regular file with at least one execute permission bit set
/// (symlinks are followed).
#[cfg(unix)]
fn is_executable_file(path: &std::path::Path) -> bool {
    use std::os::unix::fs::PermissionsExt;
    std::fs::metadata(path)
        .map(|meta| meta.is_file() && (meta.permissions().mode() & 0o111) != 0)
        .unwrap_or(false)
}

/// True if `path` is a regular file. There is no portable execute bit off Unix.
#[cfg(not(unix))]
fn is_executable_file(path: &std::path::Path) -> bool {
    path.is_file()
}
// Compile-time check that the provider can be shared between threads and tasks.
// The build fails here if a non-`Send` or non-`Sync` field is ever added.
const _: () = {
    const fn require_send_sync<T: Send + Sync>() {}
    require_send_sync::<AcpAgentProvider>();
};
#[cfg(test)]
mod tests {
    use super::*;
    use bitrouter_config::{AgentConfig, AgentProtocol, AgentSessionConfig};

    /// Builds an enabled ACP config whose binary cannot be found on `PATH`.
    fn make_config(session: Option<AgentSessionConfig>) -> AgentConfig {
        AgentConfig {
            protocol: AgentProtocol::Acp,
            binary: "nonexistent-agent-binary".to_owned(),
            args: Vec::new(),
            enabled: true,
            distribution: Vec::new(),
            session,
            a2a: None,
        }
    }

    #[test]
    fn provider_defaults_to_single_session() {
        let provider = AcpAgentProvider::new("test".to_owned(), make_config(None));
        assert_eq!(provider.max_concurrent(), 1);
        assert_eq!(provider.idle_timeout(), Duration::from_secs(600));
        assert_eq!(provider.session_count(), 0);
    }

    #[test]
    fn provider_respects_session_config() {
        let config = make_config(Some(AgentSessionConfig {
            idle_timeout_secs: 120,
            max_concurrent: 8,
        }));
        let provider = AcpAgentProvider::new("test".to_owned(), config);
        assert_eq!(provider.max_concurrent(), 8);
        assert_eq!(provider.idle_timeout(), Duration::from_secs(120));
    }

    #[test]
    fn provider_agent_name() {
        let provider = AcpAgentProvider::new("claude-code".to_owned(), make_config(None));
        assert_eq!(provider.agent_name(), "claude-code");
        assert_eq!(provider.protocol_name(), "acp");
    }

    #[tokio::test]
    async fn submit_without_connect_errors() {
        let provider = AcpAgentProvider::new("test".to_owned(), make_config(None));
        let result = provider
            .submit("nonexistent-session", "hello".to_owned())
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn disconnect_unknown_session_is_noop() {
        let provider = AcpAgentProvider::new("test".to_owned(), make_config(None));
        let result = provider.disconnect("nonexistent-session").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn cleanup_idle_sessions_empty_pool() {
        let provider = AcpAgentProvider::new("test".to_owned(), make_config(None));
        let cleaned = provider.cleanup_idle_sessions().await;
        assert_eq!(cleaned, 0);
    }

    #[tokio::test]
    async fn connect_rejects_when_no_session_slots() {
        let config = make_config(Some(AgentSessionConfig {
            idle_timeout_secs: 60,
            max_concurrent: 0,
        }));
        let provider = AcpAgentProvider::new("test".to_owned(), config);
        // With zero permits the semaphore check fails before any agent is spawned.
        assert!(provider.connect().await.is_err());
        assert_eq!(provider.session_count(), 0);
    }

    #[test]
    fn resolve_launch_falls_back_to_configured_binary() {
        // Binary not on PATH and no distributions: the configured binary is used verbatim.
        let launch = resolve_launch(&make_config(None));
        assert_eq!(launch.binary, PathBuf::from("nonexistent-agent-binary"));
        assert!(launch.args.is_empty());
    }

    #[test]
    fn find_on_path_returns_explicit_paths_verbatim() {
        assert_eq!(
            find_on_path("some/dir/agent"),
            Some(PathBuf::from("some/dir/agent"))
        );
    }

    #[test]
    fn find_on_path_misses_unknown_binaries() {
        assert_eq!(find_on_path("nonexistent-agent-binary"), None);
    }
}