use std::collections::HashMap;
use std::future::IntoFuture;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use tokio::sync::Mutex;
use claude_wrapper::Claude;
use crate::cli_parsing::extract_failure_details;
use crate::error::{Error, Result};
use crate::messaging::MessageBus;
use crate::store::PoolStore;
use crate::types::*;
use crate::utils::new_id;
/// Shared state behind every `Pool` handle; all clones of a `Pool` point at
/// one `PoolInner` through an `Arc`. Fields are either immutable after
/// construction or internally synchronized (atomics, `DashMap`, tokio mutex).
pub(crate) struct PoolInner<S: PoolStore> {
// Claude client used by the executor to run tasks.
pub(crate) claude: Claude,
// Pool-wide configuration (budget, scaling limits, worktree isolation, ...).
pub(crate) config: PoolConfig,
// Persistence backend for slot and task records.
pub(crate) store: S,
// Cumulative spend across all tasks, in microdollars.
pub(crate) total_spend: AtomicU64,
// Set by `drain`; once true, all entry points reject new work.
pub(crate) shutdown: AtomicBool,
// Ad-hoc shared key/value context, readable/writable by callers.
pub(crate) context: dashmap::DashMap<String, String>,
// Serializes slot assignment so two tasks cannot claim the same idle slot.
pub(crate) assignment_lock: Mutex<()>,
// Manages git worktrees/clones for isolation; `None` when the repo does not
// support worktrees and isolation was not strictly required at build time.
pub(crate) worktree_manager: Option<crate::worktree::WorktreeManager>,
// Per-chain progress snapshots, keyed by the chain task id string.
pub(crate) chain_progress: dashmap::DashMap<String, crate::chain::ChainProgress>,
// In-process slot-to-slot messaging.
pub(crate) message_bus: MessageBus,
// Pool creation timestamp (ms since epoch); used by session metrics.
pub(crate) created_at_ms: u64,
}
/// Public handle to a pool of Claude execution slots. Cloning is cheap
/// (an `Arc` refcount bump); all clones share the same underlying state.
pub struct Pool<S: PoolStore> {
inner: Arc<PoolInner<S>>,
}
impl<S: PoolStore> Clone for Pool<S> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
/// Builder for [`Pool`]; obtained via `Pool::builder` /
/// `Pool::builder_with_store` and finalized with `build()`.
pub struct PoolBuilder<S: PoolStore> {
claude: Claude,
// Number of execution slots to create (defaults to 1).
slot_count: usize,
config: PoolConfig,
store: S,
// Optional per-slot configs, applied positionally; missing entries fall
// back to `SlotConfig::default()`.
slot_configs: Vec<SlotConfig>,
}
impl<S: PoolStore + 'static> PoolBuilder<S> {
/// Set how many execution slots the pool will create.
pub fn slots(mut self, count: usize) -> Self {
self.slot_count = count;
self
}
/// Replace the pool-wide configuration.
pub fn config(mut self, config: PoolConfig) -> Self {
self.config = config;
self
}
/// Append a per-slot configuration; configs are applied positionally to
/// slots in creation order.
pub fn slot_config(mut self, config: SlotConfig) -> Self {
self.slot_configs.push(config);
self
}
/// Construct the pool: set up the (optional) worktree manager and create
/// the configured number of slot records in the store.
///
/// # Errors
/// Fails when worktree validation fails while `worktree_isolation` is
/// required, or when slot/worktree creation or the store write fails.
pub async fn build(self) -> Result<Pool<S>> {
// Base repo dir: the Claude client's working dir, else the process CWD.
let repo_dir = self
.claude
.working_dir()
.map(|p| p.to_path_buf())
.unwrap_or_else(|| std::env::current_dir().unwrap_or_default());
let worktree_base = self
.config
.worktree_base_dir
.clone()
.unwrap_or_else(|| repo_dir.join(".claude").join("pool-worktrees"));
let worktree_manager = match crate::worktree::WorktreeManager::new_validated(
&repo_dir,
Some(worktree_base),
)
.await
{
Ok(mgr) => Some(mgr),
Err(e) => {
// Worktree support is mandatory only when isolation was requested;
// otherwise degrade gracefully to a shared working directory.
if self.config.worktree_isolation {
return Err(e);
}
tracing::warn!(
repo_dir = %repo_dir.display(),
error = %e,
"worktree manager unavailable; per-chain worktree isolation will fall back to shared CWD"
);
None
}
};
let inner = Arc::new(PoolInner {
claude: self.claude,
config: self.config,
store: self.store,
total_spend: AtomicU64::new(0),
shutdown: AtomicBool::new(false),
context: dashmap::DashMap::new(),
assignment_lock: Mutex::new(()),
worktree_manager,
chain_progress: dashmap::DashMap::new(),
message_bus: MessageBus::default(),
created_at_ms: now_ms(),
});
// Create the slot records, each optionally backed by its own worktree.
for i in 0..self.slot_count {
let slot_config = self.slot_configs.get(i).cloned().unwrap_or_default();
let slot_id = SlotId(format!("slot-{i}"));
let worktree_path = if inner.config.worktree_isolation {
if let Some(ref mgr) = inner.worktree_manager {
let path = mgr.create(&slot_id).await?;
Some(path.to_string_lossy().into_owned())
} else {
None
}
} else {
None
};
let record = SlotRecord {
id: slot_id,
state: SlotState::Idle,
config: slot_config,
current_task: None,
session_id: None,
tasks_completed: 0,
cost_microdollars: 0,
restart_count: 0,
worktree_path,
mcp_config_path: None,
};
inner.store.put_slot(record).await?;
}
Ok(Pool { inner })
}
}
impl Pool<crate::store::InMemoryStore> {
    /// Convenience entry point: a builder backed by the default in-memory
    /// store, equivalent to `builder_with_store(claude, InMemoryStore::new())`.
    pub fn builder(claude: Claude) -> PoolBuilder<crate::store::InMemoryStore> {
        Self::builder_with_store(claude, crate::store::InMemoryStore::new())
    }
}
/// Lazily-executed options builder returned by [`Pool::run`]; configure it
/// via its methods, then `.await` it (it implements `IntoFuture`).
pub struct RunOptions<'pool, S: PoolStore + 'static> {
pool: &'pool Pool<S>,
prompt: String,
// Per-task overrides applied on top of the pool config.
config: Option<TaskOverrides>,
// Optional working-directory override for this run.
working_dir: Option<std::path::PathBuf>,
// Optional streaming callback invoked with incremental output chunks.
on_output: Option<crate::chain::OnOutputChunk>,
}
impl<'pool, S: PoolStore + 'static> RunOptions<'pool, S> {
    /// Apply per-task overrides to this run.
    pub fn config(self, config: TaskOverrides) -> Self {
        Self { config: Some(config), ..self }
    }

    /// Execute in `dir` rather than the assigned slot's default directory.
    pub fn working_dir(self, dir: impl Into<std::path::PathBuf>) -> Self {
        Self { working_dir: Some(dir.into()), ..self }
    }

    /// Register a callback that receives each output chunk as it streams.
    pub fn on_output(self, f: impl Fn(&str) + Send + Sync + 'static) -> Self {
        Self { on_output: Some(Arc::new(f)), ..self }
    }
}
impl<'pool, S: PoolStore + 'static> IntoFuture for RunOptions<'pool, S> {
type Output = Result<TaskResult>;
type IntoFuture = Pin<Box<dyn std::future::Future<Output = Result<TaskResult>> + Send + 'pool>>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
self.pool
.run_with_config_streaming(
&self.prompt,
self.config,
self.on_output,
self.working_dir,
)
.await
})
}
}
impl<S: PoolStore + 'static> Pool<S> {
/// Create a builder backed by a caller-supplied [`PoolStore`] implementation
/// (defaults: 1 slot, `PoolConfig::default()`).
pub fn builder_with_store(claude: Claude, store: S) -> PoolBuilder<S> {
PoolBuilder {
claude,
slot_count: 1,
config: PoolConfig::default(),
store,
slot_configs: Vec::new(),
}
}
/// Begin a foreground run of `prompt`; the returned [`RunOptions`] is
/// configured fluently and then awaited (via its `IntoFuture` impl).
pub fn run<'pool>(&'pool self, prompt: impl Into<String>) -> RunOptions<'pool, S> {
RunOptions {
pool: self,
prompt: prompt.into(),
config: None,
working_dir: None,
on_output: None,
}
}
#[deprecated(since = "0.1.0", note = "use pool.run(prompt).config(config).await")]
/// Deprecated shim over the fluent `run` builder: forwards the optional
/// per-task overrides and awaits the run.
pub async fn run_with_config(
    &self,
    prompt: &str,
    task_config: Option<TaskOverrides>,
) -> Result<TaskResult> {
    match task_config {
        Some(cfg) => self.run(prompt).config(cfg).await,
        None => self.run(prompt).await,
    }
}
#[deprecated(
    since = "0.1.0",
    note = "use pool.run(prompt).config(config).working_dir(dir).await"
)]
/// Deprecated shim over the fluent `run` builder: applies the optional
/// overrides and working directory, then awaits the run.
pub async fn run_with_config_and_dir(
    &self,
    prompt: &str,
    task_config: Option<TaskOverrides>,
    working_dir: Option<std::path::PathBuf>,
) -> Result<TaskResult> {
    let opts = match task_config {
        Some(cfg) => self.run(prompt).config(cfg),
        None => self.run(prompt),
    };
    match working_dir {
        Some(dir) => opts.working_dir(dir).await,
        None => opts.await,
    }
}
/// Foreground execution path backing [`Pool::run`]: checks shutdown and
/// budgets, persists a `Pending` record, assigns a slot, executes
/// (optionally streaming chunks to `on_output`), releases the slot, then
/// marks the record `Completed` on success.
///
/// NOTE(review): on executor failure the error is propagated after the
/// slot is released, but the task record is NOT transitioned to `Failed`
/// here — unlike the background `submit_*` paths. Confirm whether leaving
/// the record in its last persisted state is intended.
pub(crate) async fn run_with_config_streaming(
&self,
prompt: &str,
task_config: Option<TaskOverrides>,
on_output: Option<crate::chain::OnOutputChunk>,
working_dir: Option<std::path::PathBuf>,
) -> Result<TaskResult> {
self.check_shutdown()?;
self.check_budget()?;
self.check_task_budget(task_config.as_ref())?;
let task_id = TaskId(format!("task-{}", new_id()));
let record = TaskRecord::new_pending(task_id.clone(), prompt).with_config(task_config);
self.inner.store.put_task(record).await?;
let (slot_id, slot_config) = self.assign_slot(&task_id).await?;
let result = crate::executor::execute_task_streaming(
&self.inner,
&task_id,
prompt,
&slot_id,
&slot_config,
on_output,
working_dir.as_deref(),
)
.await;
// Release the slot regardless of outcome before inspecting the result.
self.release_slot(&slot_id, &task_id, &result).await?;
let task_result = result?;
let mut task = self
.inner
.store
.get_task(&task_id)
.await?
.ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
task.transition_to(TaskState::Completed);
task.result = Some(task_result.clone());
self.inner.store.put_task(task).await?;
Ok(task_result)
}
/// Submit a task for background execution with default config and no tags;
/// returns the task id immediately (poll it via [`Pool::result`]).
pub async fn submit(&self, prompt: &str) -> Result<TaskId> {
self.submit_with_config(prompt, None, vec![]).await
}
/// Submit a task for background execution.
///
/// Validates shutdown/budget state, persists a `Pending` record, then
/// spawns a detached tokio task that assigns a slot, runs the prompt, and
/// writes the terminal state (`Completed` / `Failed`) back to the store.
///
/// # Errors
/// Fails fast on shutdown, an exhausted pool budget, or a task budget that
/// exceeds the remaining pool budget. Execution errors are reported only
/// through the stored task record, not through this return value.
pub async fn submit_with_config(
&self,
prompt: &str,
task_config: Option<TaskOverrides>,
tags: Vec<String>,
) -> Result<TaskId> {
self.check_shutdown()?;
self.check_budget()?;
self.check_task_budget(task_config.as_ref())?;
let task_id = TaskId(format!("task-{}", new_id()));
let prompt = prompt.to_string();
let record = TaskRecord::new_pending(task_id.clone(), prompt.clone())
.with_tags(tags)
.with_config(task_config);
self.inner.store.put_task(record).await?;
let pool = self.clone();
let tid = task_id.clone();
tokio::spawn(async move {
// If the record vanished from the store, there is nothing to run.
let task = match pool.inner.store.get_task(&tid).await {
Ok(Some(t)) => t,
_ => return,
};
match pool.assign_slot(&tid).await {
Ok((slot_id, slot_config)) => {
let result = crate::executor::execute_task(
&pool.inner,
&tid,
&prompt,
&slot_id,
&slot_config,
None,
)
.await;
// Free the slot before persisting the outcome; release errors
// are deliberately best-effort.
let _ = pool.release_slot(&slot_id, &tid, &result).await;
let mut updated = task;
match result {
Ok(task_result) => {
updated.transition_to(TaskState::Completed);
updated.result = Some(task_result);
}
Err(e) => {
// Preserve command / exit-code / stderr for diagnostics.
let details = extract_failure_details(&e);
updated.transition_to(TaskState::Failed);
updated.result =
Some(TaskResult::failure(e.to_string()).with_failure_details(
details.failed_command,
details.exit_code,
details.stderr,
));
}
}
let _ = pool.inner.store.put_task(updated).await;
}
Err(e) => {
// Slot assignment failed (e.g. timeout): fail the task record.
let mut updated = task;
updated.transition_to(TaskState::Failed);
updated.result = Some(TaskResult::failure(e.to_string()));
let _ = pool.inner.store.put_task(updated).await;
}
}
});
Ok(task_id)
}
/// Like [`Pool::submit_with_config`], but a successful result parks in
/// `PendingReview` (when the record's `review_required` flag is set) until
/// [`Pool::approve_result`] or [`Pool::reject_result`] is called.
/// `max_rejections` defaults to 3.
pub async fn submit_with_review(
&self,
prompt: &str,
task_config: Option<TaskOverrides>,
tags: Vec<String>,
max_rejections: Option<u32>,
) -> Result<TaskId> {
self.check_shutdown()?;
self.check_budget()?;
self.check_task_budget(task_config.as_ref())?;
let task_id = TaskId(format!("task-{}", new_id()));
let prompt = prompt.to_string();
let max_rej = max_rejections.unwrap_or(3);
let record = TaskRecord::new_pending(task_id.clone(), prompt.clone())
.with_tags(tags)
.with_config(task_config)
.with_review(max_rej);
self.inner.store.put_task(record).await?;
let pool = self.clone();
let tid = task_id.clone();
tokio::spawn(async move {
// If the record vanished from the store, there is nothing to run.
let task = match pool.inner.store.get_task(&tid).await {
Ok(Some(t)) => t,
_ => return,
};
match pool.assign_slot(&tid).await {
Ok((slot_id, slot_config)) => {
let result = crate::executor::execute_task(
&pool.inner,
&tid,
&task.prompt,
&slot_id,
&slot_config,
None,
)
.await;
// Free the slot before persisting the outcome (best effort).
let _ = pool.release_slot(&slot_id, &tid, &result).await;
let mut updated = task;
match result {
Ok(task_result) => {
// Successful runs park for human review when required.
if updated.review_required {
updated.transition_to(TaskState::PendingReview);
} else {
updated.transition_to(TaskState::Completed);
}
updated.result = Some(task_result);
}
Err(e) => {
let details = extract_failure_details(&e);
updated.transition_to(TaskState::Failed);
updated.result =
Some(TaskResult::failure(e.to_string()).with_failure_details(
details.failed_command,
details.exit_code,
details.stderr,
));
}
}
let _ = pool.inner.store.put_task(updated).await;
}
Err(e) => {
// Slot assignment failed: fail the task record.
let mut updated = task;
updated.transition_to(TaskState::Failed);
updated.result = Some(TaskResult::failure(e.to_string()));
let _ = pool.inner.store.put_task(updated).await;
}
}
});
Ok(task_id)
}
pub async fn approve_result(&self, task_id: &TaskId) -> Result<()> {
let mut task = self
.inner
.store
.get_task(task_id)
.await?
.ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
if task.state != TaskState::PendingReview {
return Err(Error::Store(format!(
"task {} is not pending review (state: {:?})",
task_id.0, task.state
)));
}
task.transition_to(TaskState::Completed);
self.inner.store.put_task(task).await
}
/// Reject a `PendingReview` task with reviewer feedback.
///
/// Increments the rejection count; once it reaches `max_rejections` the
/// task is failed permanently. Otherwise the feedback is appended to the
/// original prompt, the task is reset to `Pending`, and it is re-executed
/// on a freshly assigned slot in a detached background task.
///
/// # Errors
/// `Error::TaskNotFound` when the task does not exist; `Error::Store`
/// when it is not currently pending review.
pub async fn reject_result(&self, task_id: &TaskId, feedback: &str) -> Result<()> {
let mut task = self
.inner
.store
.get_task(task_id)
.await?
.ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
if task.state != TaskState::PendingReview {
return Err(Error::Store(format!(
"task {} is not pending review (state: {:?})",
task_id.0, task.state
)));
}
task.rejection_count += 1;
// Too many rejections: give up and record a permanent failure.
if task.rejection_count >= task.max_rejections {
task.transition_to(TaskState::Failed);
task.result = Some(TaskResult::failure(format!(
"task rejected {} times (max: {}). Last feedback: {}",
task.rejection_count, task.max_rejections, feedback
)));
self.inner.store.put_task(task).await?;
return Ok(());
}
// Rebuild the prompt from the ORIGINAL prompt plus the latest feedback,
// so repeated rejections do not stack feedback sections.
let original = task
.original_prompt
.clone()
.unwrap_or_else(|| task.prompt.clone());
task.prompt = format!(
"{}\n\n--- Rejection feedback (attempt {}/{}) ---\n{}",
original, task.rejection_count, task.max_rejections, feedback
);
task.transition_to(TaskState::Pending);
task.slot_id = None;
task.result = None;
self.inner.store.put_task(task.clone()).await?;
// Re-run the task in the background, mirroring `submit_with_review`.
let pool = self.clone();
let tid = task_id.clone();
tokio::spawn(async move {
let task = match pool.inner.store.get_task(&tid).await {
Ok(Some(t)) => t,
_ => return,
};
match pool.assign_slot(&tid).await {
Ok((slot_id, slot_config)) => {
let result = crate::executor::execute_task(
&pool.inner,
&tid,
&task.prompt,
&slot_id,
&slot_config,
None,
)
.await;
let _ = pool.release_slot(&slot_id, &tid, &result).await;
let mut updated = task;
match result {
Ok(task_result) => {
if updated.review_required {
updated.transition_to(TaskState::PendingReview);
} else {
updated.transition_to(TaskState::Completed);
}
updated.result = Some(task_result);
}
Err(e) => {
let details = extract_failure_details(&e);
updated.transition_to(TaskState::Failed);
updated.result =
Some(TaskResult::failure(e.to_string()).with_failure_details(
details.failed_command,
details.exit_code,
details.stderr,
));
}
}
let _ = pool.inner.store.put_task(updated).await;
}
Err(e) => {
let mut updated = task;
updated.transition_to(TaskState::Failed);
updated.result = Some(TaskResult::failure(e.to_string()));
let _ = pool.inner.store.put_task(updated).await;
}
}
});
Ok(())
}
/// Fetch a task's stored result, or `None` while it is still in flight.
///
/// # Errors
/// `Error::TaskNotFound` when no record exists for `task_id`.
pub async fn result(&self, task_id: &TaskId) -> Result<Option<TaskResult>> {
    let task = self
        .inner
        .store
        .get_task(task_id)
        .await?
        .ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
    // Only surface results once execution has finished (or parked for review).
    let finished = matches!(
        task.state,
        TaskState::Completed | TaskState::Failed | TaskState::PendingReview
    );
    Ok(if finished { task.result } else { None })
}
/// Cancel a task by marking its record `Cancelled`.
///
/// Pending, pending-review, and running tasks are all handled the same
/// way (the original code had two match arms with byte-identical bodies;
/// they are merged here). Tasks already in a terminal state are left
/// untouched and this returns `Ok(())`.
///
/// NOTE(review): for a `Running` task this only updates the stored
/// record — nothing here interrupts the in-flight execution. Confirm the
/// executor observes the state change if hard cancellation is expected.
///
/// # Errors
/// `Error::TaskNotFound` when no record exists for `task_id`.
pub async fn cancel(&self, task_id: &TaskId) -> Result<()> {
    let mut task = self
        .inner
        .store
        .get_task(task_id)
        .await?
        .ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
    match task.state {
        TaskState::Pending | TaskState::PendingReview | TaskState::Running => {
            task.transition_to(TaskState::Cancelled);
            self.inner.store.put_task(task).await?;
            Ok(())
        }
        _ => Ok(()),
    }
}
/// Pull-based dispatch: have `slot_id` claim the first unassigned
/// `Pending` task and execute it on a detached background task.
///
/// Returns `Ok(None)` when the slot is not idle or no unassigned pending
/// task exists; `Err(Error::SlotNotFound)` when the slot does not exist.
///
/// NOTE(review): unlike `assign_slot`, this path does not take
/// `assignment_lock`, so it can race with concurrent assignment — confirm
/// the store's write semantics make double-assignment impossible.
pub async fn claim(&self, slot_id: &SlotId) -> Result<Option<TaskId>> {
self.check_shutdown()?;
let slot = self
.inner
.store
.get_slot(slot_id)
.await?
.ok_or_else(|| Error::SlotNotFound(slot_id.0.clone()))?;
if slot.state != SlotState::Idle {
return Ok(None);
}
let pending = self
.inner
.store
.list_tasks(&TaskFilter {
state: Some(TaskState::Pending),
..Default::default()
})
.await?;
// Only tasks that no slot has been bound to yet are claimable.
let task = match pending.into_iter().find(|t| t.slot_id.is_none()) {
Some(t) => t,
None => return Ok(None),
};
let task_id = task.id.clone();
let prompt = task.prompt.clone();
let slot_config = slot.config.clone();
// Persist the binding before spawning so observers see Running/Busy.
let mut updated_task = task;
updated_task.transition_to(TaskState::Running);
updated_task.slot_id = Some(slot_id.clone());
self.inner.store.put_task(updated_task.clone()).await?;
let mut updated_slot = slot;
updated_slot.state = SlotState::Busy;
updated_slot.current_task = Some(task_id.clone());
self.inner.store.put_slot(updated_slot).await?;
let pool = self.clone();
let tid = task_id.clone();
let sid = slot_id.clone();
tokio::spawn(async move {
let result =
crate::executor::execute_task(&pool.inner, &tid, &prompt, &sid, &slot_config, None)
.await;
// Free the slot, then persist the terminal task state (best effort).
let _ = pool.release_slot(&sid, &tid, &result).await;
if let Ok(Some(mut task)) = pool.inner.store.get_task(&tid).await {
match result {
Ok(task_result) => {
task.transition_to(TaskState::Completed);
task.result = Some(task_result);
}
Err(e) => {
let details = extract_failure_details(&e);
task.transition_to(TaskState::Failed);
task.result =
Some(TaskResult::failure(e.to_string()).with_failure_details(
details.failed_command,
details.exit_code,
details.stderr,
));
}
}
let _ = pool.inner.store.put_task(task).await;
}
});
Ok(Some(task_id))
}
/// Cancel a chain task: marks the record `Cancelled` and flips the
/// in-memory chain progress to `Cancelled` (when tracked) so step
/// execution can observe it. Tasks in other states are left untouched.
///
/// # Errors
/// `Error::TaskNotFound` when no record exists for `task_id`.
pub async fn cancel_chain(&self, task_id: &TaskId) -> Result<()> {
let mut task = self
.inner
.store
.get_task(task_id)
.await?
.ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
match task.state {
TaskState::Running | TaskState::Pending => {
task.transition_to(TaskState::Cancelled);
self.inner.store.put_task(task).await?;
if let Some(mut progress) = self.inner.chain_progress.get_mut(&task_id.0) {
progress.status = crate::chain::ChainStatus::Cancelled;
}
Ok(())
}
_ => Ok(()), }
}
/// Run several prompts concurrently (one spawned task each) and wait for
/// all of them. Results come back in input order; the first task error
/// (or join error, surfaced as `Error::Store`) fails the whole call.
pub async fn fan_out(&self, prompts: &[&str]) -> Result<Vec<TaskResult>> {
self.check_shutdown()?;
self.check_budget()?;
let mut handles = Vec::with_capacity(prompts.len());
for prompt in prompts {
let pool = self.clone();
let prompt = prompt.to_string();
handles.push(tokio::spawn(async move { pool.run(&prompt).await }));
}
let mut results = Vec::with_capacity(handles.len());
for handle in handles {
results.push(
handle
.await
.map_err(|e| Error::Store(format!("task join error: {e}")))?,
);
}
// Collapse Vec<Result<_>> into Result<Vec<_>>, short-circuiting on the
// first execution error.
results.into_iter().collect()
}
/// Submit a multi-step chain for background execution.
///
/// Persists a task record (immediately moved to `Running`), registers
/// in-memory progress tracking, sets up the chain's isolated working
/// directory (worktree or clone, best effort with fallback to the slot
/// directory), then spawns a detached task that runs the steps, cleans up
/// the isolation directory, and writes the final `TaskResult` (the chain
/// result serialized as JSON) back to the store.
pub async fn submit_chain(
&self,
steps: Vec<crate::chain::ChainStep>,
options: crate::chain::ChainOptions,
) -> Result<TaskId> {
self.check_shutdown()?;
self.check_budget()?;
let task_id = TaskId(format!("chain-{}", new_id()));
let isolation = options.isolation;
let record =
TaskRecord::new_pending(task_id.clone(), format!("chain: {} steps", steps.len()))
.with_tags(options.tags);
self.inner.store.put_task(record).await?;
// Register progress before execution so observers can poll immediately.
let progress = crate::chain::ChainProgress {
total_steps: steps.len(),
current_step: None,
current_step_name: None,
current_step_partial_output: None,
current_step_started_at: None,
completed_steps: vec![],
status: crate::chain::ChainStatus::Running,
};
self.inner
.chain_progress
.insert(task_id.0.clone(), progress);
if let Some(mut task) = self.inner.store.get_task(&task_id).await? {
task.transition_to(TaskState::Running);
self.inner.store.put_task(task).await?;
}
// Resolve the chain's working directory per the requested isolation
// mode; failures fall back to the slot directory with a warning.
let chain_working_dir = match isolation {
crate::chain::ChainIsolation::Worktree => {
if let Some(ref mgr) = self.inner.worktree_manager {
match mgr.create_for_chain(&task_id).await {
Ok(path) => Some(path),
Err(e) => {
tracing::warn!(
task_id = %task_id.0,
error = %e,
"failed to create chain worktree, falling back to slot dir"
);
None
}
}
} else {
None
}
}
crate::chain::ChainIsolation::Clone => {
if let Some(ref mgr) = self.inner.worktree_manager {
match mgr.create_clone_for_chain(&task_id).await {
Ok(path) => Some(path),
Err(e) => {
tracing::warn!(
task_id = %task_id.0,
error = %e,
"failed to create chain clone, falling back to slot dir"
);
None
}
}
} else {
None
}
}
crate::chain::ChainIsolation::None => None,
};
let pool = self.clone();
let tid = task_id.clone();
tokio::spawn(async move {
let result = crate::chain::execute_chain_with_progress(
&pool,
&steps,
Some(&tid),
chain_working_dir.as_deref(),
)
.await;
// Tear down the isolation dir we created (if any); cleanup failures
// are logged but do not affect the chain outcome.
if chain_working_dir.is_some()
&& let Some(ref mgr) = pool.inner.worktree_manager
{
match isolation {
crate::chain::ChainIsolation::Worktree => {
if let Err(e) = mgr.remove_chain(&tid).await {
tracing::warn!(
task_id = %tid.0,
error = %e,
"failed to clean up chain worktree"
);
}
}
crate::chain::ChainIsolation::Clone => {
if let Err(e) = mgr.remove_clone(&tid).await {
tracing::warn!(
task_id = %tid.0,
error = %e,
"failed to clean up chain clone"
);
}
}
crate::chain::ChainIsolation::None => {}
}
}
// Persist the terminal task state with the serialized chain result.
if let Some(mut task) = pool.inner.store.get_task(&tid).await.ok().flatten() {
match result {
Ok(chain_result) => {
let success = chain_result.success;
if success {
task.transition_to(TaskState::Completed);
} else {
task.transition_to(TaskState::Failed);
}
let output = serde_json::to_string(&chain_result).unwrap_or_default();
task.result = Some(if success {
TaskResult::success(output, chain_result.total_cost_microdollars, 0)
} else {
let mut r = TaskResult::failure(output);
r.cost_microdollars = chain_result.total_cost_microdollars;
r
});
}
Err(e) => {
let details = extract_failure_details(&e);
task.transition_to(TaskState::Failed);
task.result =
Some(TaskResult::failure(e.to_string()).with_failure_details(
details.failed_command,
details.exit_code,
details.stderr,
));
}
}
let _ = pool.inner.store.put_task(task).await;
}
});
Ok(task_id)
}
/// Submit several chains concurrently, all sharing the same options.
///
/// Failed submissions are logged and skipped rather than failing the
/// whole batch, so the returned id list may be shorter than `chains`.
pub async fn fan_out_chains(
&self,
chains: Vec<Vec<crate::chain::ChainStep>>,
options: crate::chain::ChainOptions,
) -> Result<Vec<TaskId>> {
self.check_shutdown()?;
self.check_budget()?;
let mut handles = Vec::with_capacity(chains.len());
for chain_steps in chains {
let pool = self.clone();
let options = options.clone();
handles.push(tokio::spawn(async move {
pool.submit_chain(chain_steps, options).await
}));
}
let mut task_ids = Vec::with_capacity(handles.len());
for handle in handles {
match handle.await {
Ok(Ok(task_id)) => task_ids.push(task_id),
Ok(Err(e)) => {
tracing::warn!("failed to submit chain: {}", e);
}
Err(e) => {
tracing::warn!("chain submission task panicked: {}", e);
}
}
}
Ok(task_ids)
}
/// Snapshot of a chain's progress, if the chain is tracked by this process.
pub fn chain_progress(&self, task_id: &TaskId) -> Option<crate::chain::ChainProgress> {
self.inner
.chain_progress
.get(&task_id.0)
.map(|v| v.value().clone())
}
/// Snapshot of every tracked chain's progress.
pub fn list_chain_progress(&self) -> Vec<(TaskId, crate::chain::ChainProgress)> {
self.inner
.chain_progress
.iter()
.map(|entry| (TaskId(entry.key().clone()), entry.value().clone()))
.collect()
}
/// Replace the tracked progress for a chain.
/// NOTE(review): declared `async` but contains no awaits — presumably kept
/// async for interface stability; confirm before changing the signature.
pub(crate) async fn set_chain_progress(
&self,
task_id: &TaskId,
progress: crate::chain::ChainProgress,
) {
self.inner
.chain_progress
.insert(task_id.0.clone(), progress);
}
/// Append a streamed output chunk to the chain's current-step buffer.
/// No-op when the chain is untracked or no partial-output buffer is set.
pub(crate) fn append_chain_partial_output(&self, task_id: &TaskId, chunk: &str) {
if let Some(mut progress) = self.inner.chain_progress.get_mut(&task_id.0)
&& let Some(ref mut partial) = progress.current_step_partial_output
{
partial.push_str(chunk);
}
}
/// Insert (or overwrite) a shared context key/value pair.
pub fn set_context(&self, key: impl Into<String>, value: impl Into<String>) {
self.inner.context.insert(key.into(), value.into());
}
/// Look up a shared context value by key.
pub fn get_context(&self, key: &str) -> Option<String> {
self.inner.context.get(key).map(|v| v.value().clone())
}
/// Remove a shared context entry, returning its previous value.
pub fn delete_context(&self, key: &str) -> Option<String> {
self.inner.context.remove(key).map(|(_, v)| v)
}
/// Snapshot all shared context entries as (key, value) pairs.
pub fn list_context(&self) -> Vec<(String, String)> {
self.inner
.context
.iter()
.map(|r| (r.key().clone(), r.value().clone()))
.collect()
}
/// Send a direct message between slots; returns the bus-assigned id.
pub fn send_message(&self, from: SlotId, to: SlotId, content: String) -> String {
self.inner.message_bus.send(from, to, content)
}
/// Broadcast `content` from `from` to every slot currently in the store;
/// returns the per-recipient message ids.
/// NOTE(review): the recipient list is built from ALL slots, apparently
/// including the sender — confirm the bus (or this) should exclude it.
pub async fn broadcast_message(&self, from: SlotId, content: String) -> Result<Vec<String>> {
let slots = self.inner.store.list_slots().await?;
let recipients: Vec<SlotId> = slots.into_iter().map(|s| s.id).collect();
Ok(self.inner.message_bus.broadcast(from, &recipients, content))
}
/// List slots matching all supplied filters; a `None` filter matches
/// every slot for that dimension.
pub async fn find_slots(
    &self,
    name: Option<&str>,
    role: Option<&str>,
    state: Option<SlotState>,
) -> Result<Vec<SlotRecord>> {
    let slots = self.inner.store.list_slots().await?;
    // Each filter is AND-ed; `is_none_or` treats an absent filter as a match.
    let keep = |s: &SlotRecord| {
        name.is_none_or(|n| s.config.name.as_deref() == Some(n))
            && role.is_none_or(|r| s.config.role.as_deref() == Some(r))
            && state.is_none_or(|st| s.state == st)
    };
    Ok(slots.into_iter().filter(keep).collect())
}
/// Fetch the slot's queued messages via `MessageBus::read`.
pub fn read_messages(&self, slot_id: &SlotId) -> Vec<crate::messaging::Message> {
self.inner.message_bus.read(slot_id)
}
/// Inspect the slot's queued messages via `MessageBus::peek`.
pub fn peek_messages(&self, slot_id: &SlotId) -> Vec<crate::messaging::Message> {
self.inner.message_bus.peek(slot_id)
}
/// Number of messages currently queued for the slot.
pub fn message_count(&self, slot_id: &SlotId) -> usize {
self.inner.message_bus.count(slot_id)
}
/// Graceful shutdown: stop accepting new work, wait for running tasks to
/// finish (polling every 100ms), mark every slot `Stopped`, clean up
/// worktrees and per-slot MCP config files, and return aggregate totals.
///
/// NOTE(review): the wait loop has no upper bound — a task stuck in
/// `Running` blocks `drain` forever. Confirm an external timeout exists.
pub async fn drain(&self) -> Result<DrainSummary> {
self.inner.shutdown.store(true, Ordering::SeqCst);
loop {
let running = self
.inner
.store
.list_tasks(&TaskFilter {
state: Some(TaskState::Running),
..Default::default()
})
.await?;
if running.is_empty() {
break;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
let slots = self.inner.store.list_slots().await?;
let mut total_cost = 0u64;
let mut total_tasks = 0u64;
let slot_ids: Vec<_> = slots.iter().map(|w| w.id.clone()).collect();
// Accumulate totals and stop each slot.
for mut slot in slots {
total_cost += slot.cost_microdollars;
total_tasks += slot.tasks_completed;
slot.state = SlotState::Stopped;
self.inner.store.put_slot(slot).await?;
}
if let Some(ref mgr) = self.inner.worktree_manager {
mgr.cleanup_all(&slot_ids).await?;
}
// Best-effort removal of per-slot MCP config files; failures are logged.
for slot_id in &slot_ids {
if let Some(slot) = self.inner.store.get_slot(slot_id).await?
&& let Some(ref path) = slot.mcp_config_path
&& let Err(e) = std::fs::remove_file(path)
{
tracing::warn!(
slot_id = %slot_id.0,
path = %path.display(),
error = %e,
"failed to clean up slot MCP config"
);
}
}
Ok(DrainSummary {
total_cost_microdollars: total_cost,
total_tasks_completed: total_tasks,
})
}
/// Snapshot of pool health: slot occupancy, per-state task counts, and
/// budget usage.
///
/// Task states are tallied in a single pass over the task list (the
/// original performed six separate filtered scans). The match is
/// exhaustive, so adding a `TaskState` variant forces an update here.
pub async fn status(&self) -> Result<PoolStatus> {
    let slots = self.inner.store.list_slots().await?;
    let idle = slots.iter().filter(|w| w.state == SlotState::Idle).count();
    let busy = slots.iter().filter(|w| w.state == SlotState::Busy).count();
    let all_tasks = self.inner.store.list_tasks(&TaskFilter::default()).await?;
    let (mut running_tasks, mut pending_tasks, mut pending_review_tasks) = (0, 0, 0);
    let (mut completed_tasks, mut failed_tasks, mut cancelled_tasks) = (0, 0, 0);
    for task in &all_tasks {
        match task.state {
            TaskState::Running => running_tasks += 1,
            TaskState::Pending => pending_tasks += 1,
            TaskState::PendingReview => pending_review_tasks += 1,
            TaskState::Completed => completed_tasks += 1,
            TaskState::Failed => failed_tasks += 1,
            TaskState::Cancelled => cancelled_tasks += 1,
        }
    }
    Ok(PoolStatus {
        total_slots: slots.len(),
        idle_slots: idle,
        busy_slots: busy,
        running_tasks,
        pending_tasks,
        pending_review_tasks,
        completed_tasks,
        failed_tasks,
        cancelled_tasks,
        total_spend_microdollars: self.inner.total_spend.load(Ordering::Relaxed),
        budget_microdollars: self.inner.config.budget_microdollars,
        shutdown: self.inner.shutdown.load(Ordering::Relaxed),
    })
}
pub fn store(&self) -> &S {
&self.inner.store
}
pub fn config(&self) -> &PoolConfig {
&self.inner.config
}
pub fn claude(&self) -> &Claude {
&self.inner.claude
}
/// Aggregate metrics for tasks matching `filter` (time window, tags,
/// model), including a per-model breakdown sorted by total cost
/// descending. "Completed" for averaging purposes includes tasks parked
/// in `PendingReview`.
///
/// NOTE(review): `avg_cost_microdollars` divides total spend from ALL
/// filtered tasks with a result (including failed ones) by the count of
/// completed/review tasks only — confirm that mismatch is intended.
pub async fn session_metrics(&self, filter: &MetricsFilter) -> Result<SessionMetrics> {
let all_tasks = self.inner.store.list_tasks(&TaskFilter::default()).await?;
let filtered: Vec<&TaskRecord> = all_tasks
.iter()
.filter(|t| {
// Time-window filter; tasks without a timestamp count as 0 ms.
if let Some(since) = filter.since_ms
&& t.created_at_ms.unwrap_or(0) < since
{
return false;
}
if let Some(until) = filter.until_ms
&& t.created_at_ms.unwrap_or(0) > until
{
return false;
}
// Tag filter: any overlap with the requested tags matches.
if let Some(ref tags) = filter.tags
&& !tags.iter().any(|tag| t.tags.contains(tag))
{
return false;
}
// Model filter: requires a result that records a matching model.
if let Some(ref model) = filter.model {
match t.result {
Some(ref result) if result.model.as_deref() == Some(model) => {}
_ => return false,
}
}
true
})
.collect();
let mut metrics = SessionMetrics {
session_start_ms: self.inner.created_at_ms,
session_duration_ms: now_ms().saturating_sub(self.inner.created_at_ms),
total_tasks: filtered.len() as u64,
..Default::default()
};
let mut elapsed_values: Vec<u64> = Vec::new();
let mut total_turns: u64 = 0;
let mut completed_count: u64 = 0;
// Per-model accumulator: (task count, cost, elapsed ms, turns).
let mut model_accum: HashMap<String, (u64, u64, u64, u64)> = HashMap::new();
for task in &filtered {
match task.state {
TaskState::Pending => metrics.pending_tasks += 1,
TaskState::Running => metrics.running_tasks += 1,
TaskState::Completed | TaskState::PendingReview => metrics.completed_tasks += 1,
TaskState::Failed => metrics.failed_tasks += 1,
TaskState::Cancelled => metrics.cancelled_tasks += 1,
}
if let Some(ref result) = task.result {
metrics.total_spend_microdollars += result.cost_microdollars;
if result.cost_microdollars > metrics.max_cost_microdollars {
metrics.max_cost_microdollars = result.cost_microdollars;
}
// Timing/turn stats only cover completed (or in-review) tasks.
if task.state == TaskState::Completed || task.state == TaskState::PendingReview {
completed_count += 1;
total_turns += result.turns_used as u64;
if result.elapsed_ms > 0 {
elapsed_values.push(result.elapsed_ms);
}
if result.elapsed_ms > metrics.max_elapsed_ms {
metrics.max_elapsed_ms = result.elapsed_ms;
}
}
if let Some(ref model) = result.model {
*metrics.tasks_by_model.entry(model.clone()).or_insert(0) += 1;
let acc = model_accum.entry(model.clone()).or_default();
acc.0 += 1;
acc.1 += result.cost_microdollars;
acc.2 += result.elapsed_ms;
acc.3 += result.turns_used as u64;
}
}
}
if completed_count > 0 {
metrics.avg_cost_microdollars = metrics.total_spend_microdollars / completed_count;
metrics.avg_turns = total_turns as f64 / completed_count as f64;
}
if !elapsed_values.is_empty() {
let sum: u64 = elapsed_values.iter().sum();
metrics.avg_elapsed_ms = sum / elapsed_values.len() as u64;
metrics.min_elapsed_ms = elapsed_values.iter().copied().min().unwrap_or(0);
elapsed_values.sort_unstable();
let mid = elapsed_values.len() / 2;
// Even-length lists average the two middle samples; odd take the middle.
metrics.median_elapsed_ms = if elapsed_values.len().is_multiple_of(2) && mid > 0 {
(elapsed_values[mid - 1] + elapsed_values[mid]) / 2
} else {
elapsed_values[mid]
};
}
metrics.model_breakdown = model_accum
.into_iter()
.map(|(model, (count, cost, elapsed, turns))| ModelMetrics {
model,
task_count: count,
total_cost_microdollars: cost,
avg_cost_microdollars: if count > 0 { cost / count } else { 0 },
avg_elapsed_ms: if count > 0 { elapsed / count } else { 0 },
total_turns: turns,
})
.collect();
metrics
.model_breakdown
.sort_by(|a, b| b.total_cost_microdollars.cmp(&a.total_cost_microdollars));
Ok(metrics)
}
/// Spawn the background supervisor when enabled by config; returns `None`
/// when `supervisor_enabled` is false.
pub fn start_supervisor(&self) -> Option<crate::supervisor::SupervisorHandle> {
    self.inner.config.supervisor_enabled.then(|| {
        crate::supervisor::spawn_supervisor(
            self.clone(),
            self.inner.config.supervisor_interval_secs,
        )
    })
}
/// Add `count` new slots (a no-op returning the current count when 0).
/// New slot ids continue numerically from the highest existing "slot-N"
/// suffix, so deleted ids are not reused.
///
/// # Errors
/// Fails when the result would exceed `scaling.max_slots`, or on store /
/// worktree errors.
pub async fn scale_up(&self, count: usize) -> Result<usize> {
if count == 0 {
return Ok(self.inner.store.list_slots().await?.len());
}
let current_slots = self.inner.store.list_slots().await?;
let current_count = current_slots.len();
let new_count = current_count + count;
if new_count > self.inner.config.scaling.max_slots {
return Err(Error::Store(format!(
"cannot scale up to {} slots: exceeds max_slots ({})",
new_count, self.inner.config.scaling.max_slots
)));
}
// Parse the numeric suffixes of existing slot ids to find the next one.
let existing_ids: Vec<usize> = current_slots
.iter()
.filter_map(|w| w.id.0.strip_prefix("slot-").and_then(|s| s.parse().ok()))
.collect();
let mut next_id = existing_ids.iter().max().unwrap_or(&0) + 1;
for _ in 0..count {
let slot_id = SlotId(format!("slot-{next_id}"));
next_id += 1;
let worktree_path = if self.inner.config.worktree_isolation {
if let Some(ref mgr) = self.inner.worktree_manager {
let path = mgr.create(&slot_id).await?;
Some(path.to_string_lossy().into_owned())
} else {
None
}
} else {
None
};
let record = SlotRecord {
id: slot_id,
state: SlotState::Idle,
config: SlotConfig::default(),
current_task: None,
session_id: None,
tasks_completed: 0,
cost_microdollars: 0,
restart_count: 0,
worktree_path,
mcp_config_path: None,
};
self.inner.store.put_slot(record).await?;
}
Ok(new_count)
}
/// Remove up to `count` slots, waiting (max 30s per slot) for busy slots
/// to go idle before deleting them. Returns the resulting slot count.
///
/// Fixes a potential panic in the original: when `count` exceeded the
/// current slot count (reachable when `scaling.min_slots` is 0, since the
/// guard uses `saturating_sub`), slicing `&slots[..count]` went out of
/// bounds. The removal count is now clamped to the number of slots.
///
/// NOTE(review): victims are the slots with the MOST completed tasks
/// first (descending sort) — confirm that selection policy is intended.
///
/// # Errors
/// Fails when the result would drop below `scaling.min_slots`, or on
/// store errors.
pub async fn scale_down(&self, count: usize) -> Result<usize> {
    if count == 0 {
        return Ok(self.inner.store.list_slots().await?.len());
    }
    let mut slots = self.inner.store.list_slots().await?;
    let current_count = slots.len();
    let new_count = current_count.saturating_sub(count);
    if new_count < self.inner.config.scaling.min_slots {
        return Err(Error::Store(format!(
            "cannot scale down to {} slots: below min_slots ({})",
            new_count, self.inner.config.scaling.min_slots
        )));
    }
    slots.sort_by_key(|w| std::cmp::Reverse(w.tasks_completed));
    // Clamp so `count > current_count` cannot slice out of bounds.
    let slots_to_remove = &slots[..count.min(current_count)];
    let timeout = std::time::Duration::from_secs(30);
    for slot in slots_to_remove {
        // Give an in-flight task up to `timeout` to finish; a missing slot
        // record means it is already gone and can be skipped past.
        let deadline = std::time::Instant::now() + timeout;
        loop {
            if let Some(w) = self.inner.store.get_slot(&slot.id).await? {
                if w.state != SlotState::Busy {
                    break;
                }
                if std::time::Instant::now() >= deadline {
                    break;
                }
            } else {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        // Best-effort worktree cleanup before the record is deleted.
        if let Some(ref mgr) = self.inner.worktree_manager
            && slot.worktree_path.is_some()
        {
            let _ = mgr.cleanup_all(std::slice::from_ref(&slot.id)).await;
        }
        self.inner.store.delete_slot(&slot.id).await?;
    }
    Ok(new_count)
}
/// Scale the pool up or down so the slot count equals `target`; returns
/// the resulting count (unchanged when already at target).
pub async fn set_target_slots(&self, target: usize) -> Result<usize> {
    let current = self.inner.store.list_slots().await?.len();
    match target.cmp(&current) {
        std::cmp::Ordering::Greater => self.scale_up(target - current).await,
        std::cmp::Ordering::Less => self.scale_down(current - target).await,
        std::cmp::Ordering::Equal => Ok(current),
    }
}
fn check_shutdown(&self) -> Result<()> {
if self.inner.shutdown.load(Ordering::SeqCst) {
Err(Error::PoolShutdown)
} else {
Ok(())
}
}
fn check_budget(&self) -> Result<()> {
if let Some(limit) = self.inner.config.budget_microdollars {
let spent = self.inner.total_spend.load(Ordering::Relaxed);
if spent >= limit {
return Err(Error::BudgetExhausted {
spent_microdollars: spent,
limit_microdollars: limit,
});
}
}
Ok(())
}
/// Reject a task whose own USD budget exceeds the pool budget remaining.
/// Only enforced when BOTH a task budget and a pool budget are set.
fn check_task_budget(&self, task_config: Option<&TaskOverrides>) -> Result<()> {
let task_budget_usd = task_config.and_then(|t| t.max_budget_usd);
let pool_limit = self.inner.config.budget_microdollars;
if let (Some(task_budget), Some(limit)) = (task_budget_usd, pool_limit) {
let spent = self.inner.total_spend.load(Ordering::Relaxed);
let remaining = limit.saturating_sub(spent);
// NOTE(review): the float-to-u64 cast saturates, so a negative or NaN
// budget becomes 0 and always passes — confirm upstream validation.
let task_microdollars = (task_budget * 1_000_000.0) as u64;
if task_microdollars > remaining {
return Err(Error::TaskBudgetExceedsRemaining {
task_budget_usd: task_budget,
remaining_usd: remaining as f64 / 1_000_000.0,
});
}
}
Ok(())
}
async fn wait_for_idle_slot_with_timeout(&self, timeout_secs: u64) -> Result<SlotRecord> {
use std::time::{Duration, Instant};
let deadline = Instant::now() + Duration::from_secs(timeout_secs);
let mut backoff_ms = 10u64;
const MAX_BACKOFF_MS: u64 = 500;
loop {
self.check_shutdown()?;
let slots = self.inner.store.list_slots().await?;
for slot in slots {
if slot.state == SlotState::Idle {
return Ok(slot);
}
}
if Instant::now() >= deadline {
return Err(Error::NoSlotAvailable { timeout_secs });
}
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
backoff_ms = std::cmp::min((backoff_ms as f64 * 1.5) as u64, MAX_BACKOFF_MS);
}
}
    /// Reserves an idle slot for `task_id`: marks the slot busy, records the
    /// assignment on the slot, and transitions the task to `Running`.
    ///
    /// The assignment mutex is held across the whole section — including the
    /// wait for an idle slot — so two concurrent submissions cannot claim
    /// the same slot.
    async fn assign_slot(&self, task_id: &TaskId) -> Result<(SlotId, SlotConfig)> {
        let _lock = self.inner.assignment_lock.lock().await;
        let timeout = self.inner.config.slot_assignment_timeout_secs;
        let mut slot = self.wait_for_idle_slot_with_timeout(timeout).await?;
        let config = slot.config.clone();
        // Persist the slot's Busy state before updating the task record, so
        // a concurrent scan never observes the slot idle once claimed.
        slot.state = SlotState::Busy;
        slot.current_task = Some(task_id.clone());
        self.inner.store.put_slot(slot.clone()).await?;
        // If the task record is missing (presumably removed concurrently —
        // TODO confirm), the slot is still assigned and returned as usual.
        if let Some(mut task) = self.inner.store.get_task(task_id).await? {
            task.transition_to(TaskState::Running);
            task.slot_id = Some(slot.id.clone());
            self.inner.store.put_task(task).await?;
        }
        Ok((slot.id, config))
    }
    /// Returns a slot to `Idle` after a task attempt and, for successful
    /// results, records accounting.
    ///
    /// On `Ok` results: bumps the slot's completion count, adds the cost to
    /// both the slot and the pool-wide spend counter, carries the session id
    /// forward on the slot, and flags the stored task result if it exceeded
    /// its per-task budget cap. `Err` results only free the slot.
    async fn release_slot(
        &self,
        slot_id: &SlotId,
        task_id: &TaskId,
        result: &std::result::Result<TaskResult, Error>,
    ) -> Result<()> {
        // Missing slot (e.g. scaled down meanwhile) makes this a no-op.
        if let Some(mut slot) = self.inner.store.get_slot(slot_id).await? {
            slot.state = SlotState::Idle;
            slot.current_task = None;
            if let Ok(task_result) = result {
                slot.tasks_completed += 1;
                slot.cost_microdollars += task_result.cost_microdollars;
                slot.session_id = task_result.session_id.clone();
                self.inner
                    .total_spend
                    .fetch_add(task_result.cost_microdollars, Ordering::Relaxed);
                // The per-task budget is advisory at this point: overruns are
                // logged and flagged on the stored result, not turned into
                // errors.
                if let Some(task_record) = self.inner.store.get_task(task_id).await?
                    && let Some(ref config) = task_record.config
                    && let Some(max_budget_usd) = config.max_budget_usd
                {
                    let max_microdollars = (max_budget_usd * 1_000_000.0) as u64;
                    if task_result.cost_microdollars > max_microdollars {
                        tracing::warn!(
                            task_id = %task_id.0,
                            cost_microdollars = task_result.cost_microdollars,
                            budget_microdollars = max_microdollars,
                            "task exceeded its per-task budget cap"
                        );
                        let mut updated_task = task_record;
                        if let Some(ref mut r) = updated_task.result {
                            r.budget_exceeded = true;
                        }
                        self.inner.store.put_task(updated_task).await?;
                    }
                }
            }
            // Persist the idle slot (with updated accounting) last.
            self.inner.store.put_slot(slot).await?;
        }
        Ok(())
    }
}
/// Aggregate totals reported when the pool is drained.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DrainSummary {
    /// Total spend across all slots, in microdollars.
    pub total_cost_microdollars: u64,
    /// Total number of tasks completed across all slots.
    pub total_tasks_completed: u64,
}
/// Point-in-time snapshot of pool health: slot states, task counts by
/// state, and budget usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolStatus {
    pub total_slots: usize,
    pub idle_slots: usize,
    pub busy_slots: usize,
    pub running_tasks: usize,
    pub pending_tasks: usize,
    pub pending_review_tasks: usize,
    pub completed_tasks: usize,
    pub failed_tasks: usize,
    pub cancelled_tasks: usize,
    /// Cumulative pool spend in microdollars.
    pub total_spend_microdollars: u64,
    /// Configured pool budget cap in microdollars, if any.
    pub budget_microdollars: Option<u64>,
    /// True once shutdown has been initiated.
    pub shutdown: bool,
}
use serde::{Deserialize, Serialize};
// Unit tests. `mock_claude` points the CLI binary at /usr/bin/false, so any
// real invocation fails fast; these tests exercise pool bookkeeping (slots,
// budgets, scaling, chains, parsing helpers), not actual Claude runs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli_parsing::{
        detect_permission_prompt, extract_failure_details, extract_tool_name,
    };

    // Claude handle whose binary always fails; pool construction still works.
    fn mock_claude() -> Claude {
        Claude::builder().binary("/usr/bin/false").build().unwrap()
    }

    // Building a pool registers the requested number of slots, all Idle.
    #[tokio::test]
    async fn build_pool_registers_slots() {
        let pool = Pool::builder(mock_claude()).slots(3).build().await.unwrap();
        let slots = pool.store().list_slots().await.unwrap();
        assert_eq!(slots.len(), 3);
        for slot in &slots {
            assert_eq!(slot.state, SlotState::Idle);
        }
    }

    // A single slot_config applies to slot-0 only; slot-1 keeps defaults.
    #[tokio::test]
    async fn pool_with_slot_configs() {
        let pool = Pool::builder(mock_claude())
            .slots(2)
            .slot_config(SlotConfig {
                model: Some("opus".into()),
                role: Some("reviewer".into()),
                ..Default::default()
            })
            .build()
            .await
            .unwrap();
        let slots = pool.store().list_slots().await.unwrap();
        let w0 = slots.iter().find(|w| w.id.0 == "slot-0").unwrap();
        let w1 = slots.iter().find(|w| w.id.0 == "slot-1").unwrap();
        assert_eq!(w0.config.model.as_deref(), Some("opus"));
        assert_eq!(w0.config.role.as_deref(), Some("reviewer"));
        assert!(w1.config.model.is_none());
    }

    // Shared context map supports set/get/list/delete round trips.
    #[tokio::test]
    async fn context_operations() {
        let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();
        pool.set_context("repo", "claude-wrapper");
        pool.set_context("branch", "main");
        assert_eq!(pool.get_context("repo").as_deref(), Some("claude-wrapper"));
        assert_eq!(pool.list_context().len(), 2);
        pool.delete_context("branch");
        assert!(pool.get_context("branch").is_none());
    }

    // drain() stops every slot and the pool rejects new work afterwards.
    #[tokio::test]
    async fn drain_marks_slots_stopped() {
        let pool = Pool::builder(mock_claude()).slots(2).build().await.unwrap();
        let summary = pool.drain().await.unwrap();
        assert_eq!(summary.total_tasks_completed, 0);
        let slots = pool.store().list_slots().await.unwrap();
        for w in &slots {
            assert_eq!(w.state, SlotState::Stopped);
        }
        assert!(pool.run("hello").await.is_err());
    }

    // Spending up to the configured budget makes run() fail with
    // BudgetExhausted before any slot is assigned.
    #[tokio::test]
    async fn budget_enforcement() {
        let pool = Pool::builder(mock_claude())
            .slots(1)
            .config(PoolConfig {
                budget_microdollars: Some(100),
                ..Default::default()
            })
            .build()
            .await
            .unwrap();
        pool.inner.total_spend.store(100, Ordering::Relaxed);
        let err = pool.run("hello").await.unwrap_err();
        assert!(matches!(err, Error::BudgetExhausted { .. }));
    }

    // status() reflects slot counts, budget, and shutdown flag of a fresh pool.
    #[tokio::test]
    async fn status_snapshot() {
        let pool = Pool::builder(mock_claude())
            .slots(3)
            .config(PoolConfig {
                budget_microdollars: Some(1_000_000),
                ..Default::default()
            })
            .build()
            .await
            .unwrap();
        let status = pool.status().await.unwrap();
        assert_eq!(status.total_slots, 3);
        assert_eq!(status.idle_slots, 3);
        assert_eq!(status.busy_slots, 0);
        assert_eq!(status.budget_microdollars, Some(1_000_000));
        assert!(!status.shutdown);
    }

    // With every slot busy, run() times out with NoSlotAvailable after the
    // configured assignment timeout.
    #[tokio::test]
    async fn no_idle_slots_timeout() {
        let pool = Pool::builder(mock_claude())
            .slots(1)
            .config(PoolConfig {
                slot_assignment_timeout_secs: 1,
                ..Default::default()
            })
            .build()
            .await
            .unwrap();
        let mut slots = pool.store().list_slots().await.unwrap();
        slots[0].state = SlotState::Busy;
        pool.store().put_slot(slots[0].clone()).await.unwrap();
        let err = pool.run("hello").await.unwrap_err();
        assert!(matches!(err, Error::NoSlotAvailable { timeout_secs: 1 }));
    }

    // fan_out with more prompts than slots must not panic; the outcome value
    // itself is not asserted (the mock binary fails all real runs).
    #[tokio::test]
    async fn fan_out_with_excess_prompts() {
        let pool = Pool::builder(mock_claude()).slots(2).build().await.unwrap();
        let prompts = vec!["prompt1", "prompt2", "prompt3", "prompt4"];
        let results = pool.fan_out(&prompts).await;
        match results {
            Ok(_) | Err(_) => {
            }
        }
    }

    // name/role/description from SlotConfig survive the store round trip.
    #[tokio::test]
    async fn slot_identity_fields_persisted() {
        let pool = Pool::builder(mock_claude())
            .slots(1)
            .slot_config(SlotConfig {
                name: Some("reviewer".into()),
                role: Some("code_review".into()),
                description: Some("Reviews PRs for correctness and style".into()),
                ..Default::default()
            })
            .build()
            .await
            .unwrap();
        let slots = pool.store().list_slots().await.unwrap();
        let slot = slots.iter().find(|w| w.id.0 == "slot-0").unwrap();
        assert_eq!(slot.config.name.as_deref(), Some("reviewer"));
        assert_eq!(slot.config.role.as_deref(), Some("code_review"));
        assert_eq!(
            slot.config.description.as_deref(),
            Some("Reviews PRs for correctness and style")
        );
    }

    // find_slots filters independently by name, role, and state; no filters
    // returns everything, and a non-matching name returns nothing.
    #[tokio::test]
    async fn find_slots_filters_by_name_role_state() {
        let pool = Pool::builder(mock_claude())
            .slots(1)
            .slot_config(SlotConfig {
                name: Some("reviewer".into()),
                role: Some("code_review".into()),
                ..Default::default()
            })
            .build()
            .await
            .unwrap();
        pool.scale_up(1).await.unwrap();
        let mut slots = pool.store().list_slots().await.unwrap();
        if let Some(s) = slots.iter_mut().find(|s| s.id.0 == "slot-1") {
            s.config.name = Some("writer".into());
            s.config.role = Some("implementation".into());
            pool.store().put_slot(s.clone()).await.unwrap();
        }
        let found = pool.find_slots(Some("reviewer"), None, None).await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].id.0, "slot-0");
        let found = pool
            .find_slots(None, Some("implementation"), None)
            .await
            .unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].id.0, "slot-1");
        let found = pool
            .find_slots(None, None, Some(SlotState::Idle))
            .await
            .unwrap();
        assert_eq!(found.len(), 2);
        let found = pool.find_slots(None, None, None).await.unwrap();
        assert_eq!(found.len(), 2);
        let found = pool
            .find_slots(Some("nonexistent"), None, None)
            .await
            .unwrap();
        assert!(found.is_empty());
    }

    // Broadcast delivers to every slot except the sender.
    #[tokio::test]
    async fn broadcast_sends_to_all_except_sender() {
        let pool = Pool::builder(mock_claude()).slots(3).build().await.unwrap();
        let from = SlotId("slot-0".into());
        let ids = pool
            .broadcast_message(from.clone(), "hello everyone".into())
            .await
            .unwrap();
        assert_eq!(ids.len(), 2);
        assert_eq!(pool.message_count(&SlotId("slot-1".into())), 1);
        assert_eq!(pool.message_count(&SlotId("slot-2".into())), 1);
        assert_eq!(pool.message_count(&from), 0);
    }

    // scale_up adds the requested number of new Idle slots.
    #[tokio::test]
    async fn scale_up_increases_slot_count() {
        let pool = Pool::builder(mock_claude()).slots(2).build().await.unwrap();
        let initial_count = pool.store().list_slots().await.unwrap().len();
        assert_eq!(initial_count, 2);
        let new_count = pool.scale_up(3).await.unwrap();
        assert_eq!(new_count, 5);
        let slots = pool.store().list_slots().await.unwrap();
        assert_eq!(slots.len(), 5);
        for slot in slots.iter().skip(2) {
            assert_eq!(slot.state, SlotState::Idle);
        }
    }

    // scale_up past scaling.max_slots errors and leaves the count unchanged.
    #[tokio::test]
    async fn scale_up_respects_max_slots() {
        let mut config = PoolConfig::default();
        config.scaling.max_slots = 4;
        let pool = Pool::builder(mock_claude())
            .slots(2)
            .config(config)
            .build()
            .await
            .unwrap();
        let result = pool.scale_up(5).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds max_slots")
        );
        assert_eq!(pool.store().list_slots().await.unwrap().len(), 2);
    }

    // scale_down removes exactly `count` slots.
    #[tokio::test]
    async fn scale_down_reduces_slot_count() {
        let pool = Pool::builder(mock_claude()).slots(4).build().await.unwrap();
        let initial = pool.store().list_slots().await.unwrap().len();
        assert_eq!(initial, 4);
        let new_count = pool.scale_down(2).await.unwrap();
        assert_eq!(new_count, 2);
        assert_eq!(pool.store().list_slots().await.unwrap().len(), 2);
    }

    // scale_down below scaling.min_slots errors and leaves the count unchanged.
    #[tokio::test]
    async fn scale_down_respects_min_slots() {
        let mut config = PoolConfig::default();
        config.scaling.min_slots = 2;
        let pool = Pool::builder(mock_claude())
            .slots(3)
            .config(config)
            .build()
            .await
            .unwrap();
        let result = pool.scale_down(2).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("below min_slots"));
        assert_eq!(pool.store().list_slots().await.unwrap().len(), 3);
    }

    // set_target_slots above current count scales up to exactly the target.
    #[tokio::test]
    async fn set_target_slots_scales_up() {
        let pool = Pool::builder(mock_claude()).slots(2).build().await.unwrap();
        let new_count = pool.set_target_slots(5).await.unwrap();
        assert_eq!(new_count, 5);
        assert_eq!(pool.store().list_slots().await.unwrap().len(), 5);
    }

    // set_target_slots below current count scales down to exactly the target.
    #[tokio::test]
    async fn set_target_slots_scales_down() {
        let pool = Pool::builder(mock_claude()).slots(5).build().await.unwrap();
        let new_count = pool.set_target_slots(2).await.unwrap();
        assert_eq!(new_count, 2);
        assert_eq!(pool.store().list_slots().await.unwrap().len(), 2);
    }

    // set_target_slots equal to the current count is a no-op.
    #[tokio::test]
    async fn set_target_slots_no_op_when_equal() {
        let pool = Pool::builder(mock_claude()).slots(3).build().await.unwrap();
        let new_count = pool.set_target_slots(3).await.unwrap();
        assert_eq!(new_count, 3);
    }

    // fan_out_chains submits one task per chain with distinct task ids, and
    // each task is retrievable from the store.
    #[tokio::test]
    async fn fan_out_chains_submits_all_chains() {
        let pool = Pool::builder(mock_claude()).slots(2).build().await.unwrap();
        let options = crate::chain::ChainOptions {
            tags: vec![],
            ..Default::default()
        };
        let chain1 = vec![crate::chain::ChainStep {
            name: "step1".into(),
            action: crate::chain::StepAction::Prompt {
                prompt: "prompt 1".into(),
            },
            config: None,
            failure_policy: crate::chain::StepFailurePolicy {
                retries: 0,
                recovery_prompt: None,
            },
            output_vars: Default::default(),
        }];
        let chain2 = vec![crate::chain::ChainStep {
            name: "step1".into(),
            action: crate::chain::StepAction::Prompt {
                prompt: "prompt 2".into(),
            },
            config: None,
            failure_policy: crate::chain::StepFailurePolicy {
                retries: 0,
                recovery_prompt: None,
            },
            output_vars: Default::default(),
        }];
        let chains = vec![chain1, chain2];
        let task_ids = pool.fan_out_chains(chains, options).await.unwrap();
        assert_eq!(task_ids.len(), 2);
        assert_ne!(task_ids[0].0, task_ids[1].0);
        for task_id in &task_ids {
            let task = pool.store().get_task(task_id).await.unwrap();
            assert!(task.is_some());
        }
    }

    // "Allow <Tool> tool?" in stderr is detected as a permission prompt.
    #[test]
    fn detect_allow_bash_in_stderr() {
        let err = claude_wrapper::Error::CommandFailed {
            command: "claude --print".into(),
            exit_code: 1,
            stdout: String::new(),
            stderr: "Allow Bash tool? (y/n)".into(),
            working_dir: None,
        };
        let result = detect_permission_prompt(&err, "slot-1");
        assert!(result.is_some());
        let err = result.unwrap();
        match err {
            Error::PermissionPromptDetected {
                tool_name, slot_id, ..
            } => {
                assert_eq!(tool_name, "Bash");
                assert_eq!(slot_id, "slot-1");
            }
            other => panic!("expected PermissionPromptDetected, got: {other}"),
        }
    }

    // "wants to use <Tool>" phrasing is also detected.
    #[test]
    fn detect_wants_to_use_pattern() {
        let err = claude_wrapper::Error::CommandFailed {
            command: "claude --print".into(),
            exit_code: 1,
            stdout: String::new(),
            stderr: "Claude wants to use Edit tool.".into(),
            working_dir: None,
        };
        let result = detect_permission_prompt(&err, "slot-2");
        assert!(result.is_some());
        match result.unwrap() {
            Error::PermissionPromptDetected { tool_name, .. } => {
                assert_eq!(tool_name, "Edit");
            }
            other => panic!("expected PermissionPromptDetected, got: {other}"),
        }
    }

    // Unrelated stderr does not trigger permission-prompt detection.
    #[test]
    fn no_detection_on_clean_stderr() {
        let err = claude_wrapper::Error::CommandFailed {
            command: "claude --print".into(),
            exit_code: 1,
            stdout: String::new(),
            stderr: "some unrelated error output".into(),
            working_dir: None,
        };
        assert!(detect_permission_prompt(&err, "slot-1").is_none());
    }

    // Empty stderr does not trigger detection.
    #[test]
    fn no_detection_on_empty_stderr() {
        let err = claude_wrapper::Error::CommandFailed {
            command: "claude --print".into(),
            exit_code: 1,
            stdout: String::new(),
            stderr: String::new(),
            working_dir: None,
        };
        assert!(detect_permission_prompt(&err, "slot-1").is_none());
    }

    // Non-CommandFailed errors (e.g. timeouts) never trigger detection.
    #[test]
    fn no_detection_on_timeout() {
        let err = claude_wrapper::Error::Timeout {
            timeout_seconds: 30,
        };
        assert!(detect_permission_prompt(&err, "slot-1").is_none());
    }

    // extract_tool_name falls back to "unknown" when no pattern matches.
    #[test]
    fn extract_tool_name_unknown_fallback() {
        assert_eq!(extract_tool_name("some random text"), "unknown");
    }

    // extract_tool_name parses the "Allow <Tool> tool?" prefix form.
    #[test]
    fn extract_tool_name_allow_prefix() {
        assert_eq!(extract_tool_name("Allow Write tool?"), "Write");
    }

    // extract_tool_name parses the "wants to use <Tool>" form.
    #[test]
    fn extract_tool_name_wants_to_use() {
        assert_eq!(
            extract_tool_name("Claude wants to use Bash, proceed?"),
            "Bash"
        );
    }

    // CommandFailed errors yield command, exit code, and stderr details.
    #[test]
    fn extract_details_from_command_failed() {
        let err = Error::Wrapper(claude_wrapper::Error::CommandFailed {
            command: "claude --print -p test".into(),
            exit_code: 1,
            stdout: String::new(),
            stderr: "error: something went wrong".into(),
            working_dir: None,
        });
        let details = extract_failure_details(&err);
        assert_eq!(
            details.failed_command.as_deref(),
            Some("claude --print -p test")
        );
        assert_eq!(details.exit_code, Some(1));
        assert_eq!(
            details.stderr.as_deref(),
            Some("error: something went wrong")
        );
    }

    // Non-command errors produce empty failure details.
    #[test]
    fn extract_details_from_non_command_error() {
        let err = Error::TaskNotFound("task-123".into());
        let details = extract_failure_details(&err);
        assert!(details.failed_command.is_none());
        assert!(details.exit_code.is_none());
        assert!(details.stderr.is_none());
    }

    // Empty stderr is normalized to None in the extracted details.
    #[test]
    fn extract_details_empty_stderr_is_none() {
        let err = Error::Wrapper(claude_wrapper::Error::CommandFailed {
            command: "claude --print".into(),
            exit_code: 2,
            stdout: String::new(),
            stderr: String::new(),
            working_dir: None,
        });
        let details = extract_failure_details(&err);
        assert_eq!(details.failed_command.as_deref(), Some("claude --print"));
        assert_eq!(details.exit_code, Some(2));
        assert!(details.stderr.is_none());
    }

    // Cancelling a running chain marks both the task and its progress record
    // as Cancelled.
    #[tokio::test]
    async fn cancel_chain_marks_task_cancelled() {
        let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();
        let task_id = TaskId("chain-test-1".into());
        let record = TaskRecord {
            id: task_id.clone(),
            prompt: "chain: 3 steps".into(),
            state: TaskState::Running,
            slot_id: None,
            result: None,
            tags: vec![],
            config: None,
            review_required: false,
            max_rejections: 3,
            rejection_count: 0,
            original_prompt: None,
            created_at_ms: None,
            started_at_ms: None,
            completed_at_ms: None,
        };
        pool.store().put_task(record).await.unwrap();
        pool.set_chain_progress(
            &task_id,
            crate::chain::ChainProgress {
                total_steps: 3,
                current_step: Some(1),
                current_step_name: Some("implement".into()),
                current_step_partial_output: None,
                current_step_started_at: None,
                completed_steps: vec![],
                status: crate::chain::ChainStatus::Running,
            },
        )
        .await;
        pool.cancel_chain(&task_id).await.unwrap();
        let task = pool.store().get_task(&task_id).await.unwrap().unwrap();
        assert_eq!(task.state, TaskState::Cancelled);
        let progress = pool.chain_progress(&task_id).unwrap();
        assert_eq!(progress.status, crate::chain::ChainStatus::Cancelled);
    }

    // Cancelling an already-completed chain leaves the task state untouched.
    #[tokio::test]
    async fn cancel_chain_noop_for_completed() {
        let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();
        let task_id = TaskId("chain-done".into());
        let record = TaskRecord {
            id: task_id.clone(),
            prompt: "chain: 1 steps".into(),
            state: TaskState::Completed,
            slot_id: None,
            result: Some(TaskResult {
                output: "done".into(),
                success: true,
                cost_microdollars: 100,
                turns_used: 0,
                elapsed_ms: 0,
                model: None,
                session_id: None,
                failed_command: None,
                exit_code: None,
                stderr: None,
                budget_exceeded: false,
            }),
            tags: vec![],
            config: None,
            review_required: false,
            max_rejections: 3,
            rejection_count: 0,
            original_prompt: None,
            created_at_ms: None,
            started_at_ms: None,
            completed_at_ms: None,
        };
        pool.store().put_task(record).await.unwrap();
        pool.cancel_chain(&task_id).await.unwrap();
        let task = pool.store().get_task(&task_id).await.unwrap().unwrap();
        assert_eq!(task.state, TaskState::Completed);
    }

    // Cancelling an unknown chain id surfaces TaskNotFound.
    #[tokio::test]
    async fn cancel_chain_not_found() {
        let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();
        let result = pool.cancel_chain(&TaskId("nonexistent".into())).await;
        assert!(matches!(result, Err(Error::TaskNotFound(_))));
    }

    // Partial output appends accumulate in order when a buffer exists.
    #[tokio::test]
    async fn append_chain_partial_output_accumulates() {
        let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();
        let task_id = TaskId("chain-test".into());
        let progress = crate::chain::ChainProgress {
            total_steps: 2,
            current_step: Some(0),
            current_step_name: Some("plan".into()),
            current_step_partial_output: Some(String::new()),
            current_step_started_at: Some(1700000000),
            completed_steps: vec![],
            status: crate::chain::ChainStatus::Running,
        };
        pool.set_chain_progress(&task_id, progress).await;
        pool.append_chain_partial_output(&task_id, "hello ");
        pool.append_chain_partial_output(&task_id, "world");
        let progress = pool.chain_progress(&task_id).unwrap();
        assert_eq!(
            progress.current_step_partial_output.as_deref(),
            Some("hello world")
        );
    }

    // Appending is a no-op when no partial-output buffer was initialized.
    #[tokio::test]
    async fn append_chain_partial_output_noop_when_none() {
        let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();
        let task_id = TaskId("chain-test-2".into());
        let progress = crate::chain::ChainProgress {
            total_steps: 1,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps: vec![],
            status: crate::chain::ChainStatus::Completed,
        };
        pool.set_chain_progress(&task_id, progress).await;
        pool.append_chain_partial_output(&task_id, "ignored");
        let progress = pool.chain_progress(&task_id).unwrap();
        assert!(progress.current_step_partial_output.is_none());
    }

    // Appending for an unknown task must not panic.
    #[tokio::test]
    async fn append_chain_partial_output_noop_for_missing_task() {
        let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();
        let task_id = TaskId("nonexistent".into());
        pool.append_chain_partial_output(&task_id, "ignored");
    }

    // A task budget larger than the remaining pool budget is rejected at
    // submission time.
    #[tokio::test]
    async fn task_budget_exceeds_remaining_pool_budget() {
        let pool = Pool::builder(mock_claude())
            .slots(1)
            .config(PoolConfig {
                budget_microdollars: Some(1_000_000),
                ..Default::default()
            })
            .build()
            .await
            .unwrap();
        pool.inner.total_spend.store(800_000, Ordering::Relaxed);
        let task_config = TaskOverrides {
            max_budget_usd: Some(0.50),
            ..Default::default()
        };
        let err = pool
            .submit_with_config("expensive task", Some(task_config), vec![])
            .await
            .unwrap_err();
        assert!(matches!(err, Error::TaskBudgetExceedsRemaining { .. }));
    }

    // A task budget that fits within the remaining pool budget is accepted.
    #[tokio::test]
    async fn task_budget_within_remaining_pool_budget() {
        let pool = Pool::builder(mock_claude())
            .slots(1)
            .config(PoolConfig {
                budget_microdollars: Some(1_000_000),
                ..Default::default()
            })
            .build()
            .await
            .unwrap();
        pool.inner.total_spend.store(400_000, Ordering::Relaxed);
        let task_config = TaskOverrides {
            max_budget_usd: Some(0.50),
            ..Default::default()
        };
        let result = pool
            .submit_with_config("task", Some(task_config), vec![])
            .await;
        assert!(result.is_ok());
    }

    // With no pool budget configured, per-task budgets are not checked.
    #[tokio::test]
    async fn task_budget_check_skipped_without_pool_budget() {
        let pool = Pool::builder(mock_claude())
            .slots(1)
            .config(PoolConfig {
                budget_microdollars: None,
                ..Default::default()
            })
            .build()
            .await
            .unwrap();
        let task_config = TaskOverrides {
            max_budget_usd: Some(100.0),
            ..Default::default()
        };
        let result = pool
            .submit_with_config("task", Some(task_config), vec![])
            .await;
        assert!(result.is_ok());
    }

    // budget_exceeded defaults to false and can be toggled on a result.
    #[tokio::test]
    async fn budget_exceeded_flag_set_on_result() {
        let result = TaskResult::success("done", 500_000, 3);
        assert!(!result.budget_exceeded);
        let mut result_with_flag = result;
        result_with_flag.budget_exceeded = true;
        assert!(result_with_flag.budget_exceeded);
    }

    // budget_exceeded survives serde; the field is omitted when false.
    #[tokio::test]
    async fn budget_exceeded_serde_roundtrip() {
        let mut result = TaskResult::success("done", 500_000, 3);
        result.budget_exceeded = true;
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("budget_exceeded"));
        let parsed: TaskResult = serde_json::from_str(&json).unwrap();
        assert!(parsed.budget_exceeded);
        let result_ok = TaskResult::success("done", 100, 1);
        let json_ok = serde_json::to_string(&result_ok).unwrap();
        assert!(!json_ok.contains("budget_exceeded"));
    }

    // The TaskBudgetExceedsRemaining message includes both dollar amounts.
    #[tokio::test]
    async fn task_budget_error_message() {
        let err = Error::TaskBudgetExceedsRemaining {
            task_budget_usd: 0.50,
            remaining_usd: 0.20,
        };
        let msg = err.to_string();
        assert!(msg.contains("0.50"));
        assert!(msg.contains("0.20"));
    }
}