use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
/// Outcome of a single shell command run by `ExecPool`.
///
/// Results are queued inside the pool until collected, either per task
/// (`ExecPool::try_collect_by_task`) or per session
/// (`ExecPool::collect_pending_for_session`).
#[derive(Debug, Clone)]
pub struct ExecResult {
    /// Identifier of the task that produced this result.
    pub task_id: String,
    /// Identifier of the tool call this result answers; may be empty when the
    /// producer has no tool call id to record.
    pub tool_call_id: String,
    /// Shell command that was executed; may be empty if the producer did not record it.
    pub command: String,
    /// Process exit code, or `None` when no code is available (for example the
    /// command failed to spawn, timed out, or was terminated by a signal).
    pub exit_code: Option<i32>,
    /// Captured standard output as text.
    pub stdout: String,
    /// Captured standard error, or a description of why the command did not complete.
    pub stderr: String,
    /// When execution started.
    pub started_at: Instant,
    /// When execution finished; `completed_at - started_at` is the run time.
    pub completed_at: Instant,
}
/// Runs shell commands in the background and queues their `ExecResult`s until collected.
///
/// Per-task and per-session results share one map, distinguished by key prefix
/// (`task:<task_id>` vs `session:<session_key>`).
pub struct ExecPool {
    /// Tasks currently marked as running, mapped to the instant they were spawned.
    tasks: RwLock<HashMap<String, Instant>>,
    /// Results awaiting collection, keyed by `task:<id>` or `session:<key>`.
    pending_results: RwLock<HashMap<String, Vec<ExecResult>>>,
    /// Concurrency limit supplied to `ExecPool::new`.
    // Stored but never read yet, so the `dead_code` lint is silenced on purpose.
    // NOTE(review): presumably intended as a cap on concurrent `spawn`s; nothing enforces it.
    #[allow(dead_code)]
    max_concurrent: usize,
}
impl ExecPool {
pub fn new(max_concurrent: usize) -> Arc<Self> {
Arc::new(Self {
tasks: RwLock::new(HashMap::new()),
pending_results: RwLock::new(HashMap::new()),
max_concurrent,
})
}
pub async fn spawn(
self: &Arc<Self>,
task_id: String,
command: String,
cwd: PathBuf,
timeout_secs: u64,
) {
let started_at = Instant::now();
{
let mut tasks = self.tasks.write().await;
tasks.insert(task_id.clone(), started_at);
}
let pool = Arc::clone(self);
let tid = task_id.clone();
let cmd = command;
let cw = cwd;
tokio::spawn(async move {
let completed_at = Instant::now();
let (shell, shell_args) = if cfg!(target_os = "windows") {
("powershell", vec!["-NoProfile", "-Command"])
} else {
("sh", vec!["-c"])
};
let result = tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs),
tokio::process::Command::new(shell)
.args(&shell_args)
.arg(&cmd)
.current_dir(&cw)
.stdin(std::process::Stdio::null())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true)
.output(),
)
.await;
let (exit_code, stdout, stderr) = match result {
Ok(Ok(output)) => {
let exit_code = output.status.code();
let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
(exit_code, stdout, stderr)
}
Ok(Err(e)) => {
tracing::error!(task_id = %tid, "exec background spawn failed: {}", e);
(None, String::new(), format!("spawn error: {}", e))
}
Err(_) => {
tracing::warn!(task_id = %tid, timeout_secs, "exec background timed out");
(None, String::new(), format!("timed out after {} seconds", timeout_secs))
}
};
tracing::info!(
task_id = %tid,
exit_code = ?exit_code,
stdout_len = stdout.len(),
stderr_len = stderr.len(),
"exec background completed"
);
let mut tasks = pool.tasks.write().await;
tasks.remove(&tid);
drop(tasks);
let exec_result = ExecResult {
task_id: tid.clone(),
tool_call_id: String::new(), command: String::new(), exit_code,
stdout,
stderr,
started_at,
completed_at,
};
pool.add_pending_for_task(&tid, exec_result).await;
});
tracing::info!(task_id = %task_id, "exec background spawned");
}
async fn add_pending_for_task(self: &Arc<Self>, task_id: &str, result: ExecResult) {
let mut pending = self.pending_results.write().await;
pending
.entry(format!("task:{}", task_id))
.or_insert_with(Vec::new)
.push(result);
}
pub async fn is_running(&self, task_id: &str) -> bool {
let tasks = self.tasks.read().await;
let is_running = tasks.contains_key(task_id);
tracing::debug!(
task_id = %task_id,
is_running = is_running,
running_count = tasks.len(),
"exec_pool: is_running check"
);
is_running
}
pub async fn try_collect_by_task(&self, task_id: &str) -> Option<ExecResult> {
tracing::info!(task_id = %task_id, "exec_pool: trying to collect result by task_id");
let mut pending = self.pending_results.write().await;
let key = format!("task:{}", task_id);
if let Some(mut results) = pending.remove(&key) {
let result = results.pop();
tracing::info!(
task_id = %task_id,
found = result.is_some(),
remaining_in_list = results.len(),
"exec_pool: result collected from task key"
);
result
} else {
tracing::debug!(
task_id = %task_id,
pending_keys = ?pending.keys().collect::<Vec<_>>(),
"exec_pool: task key not found, showing all pending keys"
);
None
}
}
pub async fn collect_pending_for_session(
self: &Arc<Self>,
session_key: &str,
) -> Vec<ExecResult> {
tracing::info!(
session_key = %session_key,
"exec_pool: collecting pending results for session"
);
let mut pending = self.pending_results.write().await;
let key = format!("session:{}", session_key);
tracing::debug!(
session_key = %session_key,
key = %key,
all_keys = ?pending.keys().collect::<Vec<_>>(),
"exec_pool: checking pending_results keys"
);
if let Some(results) = pending.remove(&key) {
tracing::info!(
session_key = %session_key,
count = results.len(),
task_ids = ?results.iter().map(|r| &r.task_id).collect::<Vec<_>>(),
"exec_pool: collected results for session"
);
results
} else {
tracing::debug!(
session_key = %session_key,
"exec_pool: no results found for session key"
);
Vec::new()
}
}
pub async fn add_pending_for_session(
self: &Arc<Self>,
session_key: String,
result: ExecResult,
) {
tracing::info!(
session_key = %session_key,
task_id = %result.task_id,
tool_call_id = %result.tool_call_id,
exit_code = ?result.exit_code,
"exec_pool: adding pending result for session"
);
let mut pending = self.pending_results.write().await;
let key = format!("session:{}", session_key);
let entry = pending.entry(key).or_insert_with(Vec::new);
let prev_len = entry.len();
entry.push(result);
tracing::debug!(
session_key = %session_key,
prev_len = prev_len,
new_len = entry.len(),
"exec_pool: result added to pending queue"
);
}
pub async fn running_count(&self) -> usize {
let tasks = self.tasks.read().await;
tasks.len()
}
pub async fn pending_count(&self) -> usize {
let pending = self.pending_results.read().await;
pending.values().map(|v| v.len()).sum()
}
}