1use std::sync::Arc;
27use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
28
29use tokio::sync::Mutex;
30
31use claude_wrapper::Claude;
32use claude_wrapper::types::OutputFormat;
33
34use crate::config::ResolvedConfig;
35use crate::error::{Error, Result};
36use crate::skill::SkillRegistry;
37use crate::store::PoolStore;
38use crate::types::*;
39
/// State shared by every clone of a [`Pool`] handle.
struct PoolInner<S: PoolStore> {
    /// Claude client used to execute queries (cloned or re-rooted per task).
    claude: Claude,
    /// Pool-wide configuration (budget, worktree isolation, scaling limits, ...).
    config: PoolConfig,
    /// Backing store for slot and task records.
    store: S,
    /// Running total of task cost in microdollars; incremented when a slot is
    /// released after a successful task.
    total_spend: AtomicU64,
    /// Set by `drain`; checked before new work is accepted.
    shutdown: AtomicBool,
    /// Key/value context strings managed through the `*_context` methods.
    context: dashmap::DashMap<String, String>,
    /// Held while a task is being matched to an idle slot.
    assignment_lock: Mutex<()>,
    /// Worktree manager, or `None` when it could not be created at build time.
    worktree_manager: Option<crate::worktree::WorktreeManager>,
    /// Progress of submitted chains, keyed by the chain's task id string.
    chain_progress: dashmap::DashMap<String, crate::chain::ChainProgress>,
}
56
/// Handle to a pool of slots backed by a store `S`.
///
/// The handle only wraps an `Arc`, so clones share the same underlying state.
pub struct Pool<S: PoolStore> {
    /// Reference-counted shared state.
    inner: Arc<PoolInner<S>>,
}
64
65impl<S: PoolStore> Clone for Pool<S> {
67 fn clone(&self) -> Self {
68 Self {
69 inner: Arc::clone(&self.inner),
70 }
71 }
72}
73
/// Builder that collects the settings for a [`Pool`] before `build` creates it.
pub struct PoolBuilder<S: PoolStore> {
    /// Claude client the pool will use.
    claude: Claude,
    /// Number of slots `build` registers in the store.
    slot_count: usize,
    /// Pool-wide configuration.
    config: PoolConfig,
    /// Store that will hold slot and task records.
    store: S,
    /// Per-slot configs, applied by index; missing indices fall back to the default.
    slot_configs: Vec<SlotConfig>,
}
82
83impl<S: PoolStore + 'static> PoolBuilder<S> {
84 pub fn slots(mut self, count: usize) -> Self {
86 self.slot_count = count;
87 self
88 }
89
90 pub fn config(mut self, config: PoolConfig) -> Self {
92 self.config = config;
93 self
94 }
95
96 pub fn slot_config(mut self, config: SlotConfig) -> Self {
102 self.slot_configs.push(config);
103 self
104 }
105
106 pub async fn build(self) -> Result<Pool<S>> {
108 let repo_dir = self
110 .claude
111 .working_dir()
112 .map(|p| p.to_path_buf())
113 .unwrap_or_else(|| std::env::current_dir().unwrap_or_default());
114
115 let worktree_manager = match crate::worktree::WorktreeManager::new_validated(
119 &repo_dir, None,
120 )
121 .await
122 {
123 Ok(mgr) => Some(mgr),
124 Err(e) => {
125 if self.config.worktree_isolation {
126 return Err(e);
127 }
128 tracing::warn!(
129 repo_dir = %repo_dir.display(),
130 error = %e,
131 "worktree manager unavailable; per-chain worktree isolation will fall back to shared CWD"
132 );
133 None
134 }
135 };
136
137 let inner = Arc::new(PoolInner {
138 claude: self.claude,
139 config: self.config,
140 store: self.store,
141 total_spend: AtomicU64::new(0),
142 shutdown: AtomicBool::new(false),
143 context: dashmap::DashMap::new(),
144 assignment_lock: Mutex::new(()),
145 worktree_manager,
146 chain_progress: dashmap::DashMap::new(),
147 });
148
149 for i in 0..self.slot_count {
151 let slot_config = self.slot_configs.get(i).cloned().unwrap_or_default();
152
153 let slot_id = SlotId(format!("slot-{i}"));
154
155 let worktree_path = if inner.config.worktree_isolation {
157 if let Some(ref mgr) = inner.worktree_manager {
158 let path = mgr.create(&slot_id).await?;
159 Some(path.to_string_lossy().into_owned())
160 } else {
161 None
162 }
163 } else {
164 None
165 };
166
167 let record = SlotRecord {
168 id: slot_id,
169 state: SlotState::Idle,
170 config: slot_config,
171 current_task: None,
172 session_id: None,
173 tasks_completed: 0,
174 cost_microdollars: 0,
175 restart_count: 0,
176 worktree_path,
177 mcp_config_path: None,
178 };
179 inner.store.put_slot(record).await?;
180 }
181
182 Ok(Pool { inner })
183 }
184}
185
186impl Pool<crate::store::InMemoryStore> {
187 pub fn builder(claude: Claude) -> PoolBuilder<crate::store::InMemoryStore> {
189 PoolBuilder {
190 claude,
191 slot_count: 1,
192 config: PoolConfig::default(),
193 store: crate::store::InMemoryStore::new(),
194 slot_configs: Vec::new(),
195 }
196 }
197}
198
199impl<S: PoolStore + 'static> Pool<S> {
200 pub fn builder_with_store(claude: Claude, store: S) -> PoolBuilder<S> {
202 PoolBuilder {
203 claude,
204 slot_count: 1,
205 config: PoolConfig::default(),
206 store,
207 slot_configs: Vec::new(),
208 }
209 }
210
211 pub async fn run(&self, prompt: &str) -> Result<TaskResult> {
216 self.run_with_config(prompt, None).await
217 }
218
219 pub async fn run_with_config(
221 &self,
222 prompt: &str,
223 task_config: Option<SlotConfig>,
224 ) -> Result<TaskResult> {
225 self.run_with_config_and_dir(prompt, task_config, None)
226 .await
227 }
228
229 pub async fn run_with_config_and_dir(
231 &self,
232 prompt: &str,
233 task_config: Option<SlotConfig>,
234 working_dir: Option<std::path::PathBuf>,
235 ) -> Result<TaskResult> {
236 self.check_shutdown()?;
237 self.check_budget()?;
238
239 let task_id = TaskId(format!("task-{}", new_id()));
240
241 let record = TaskRecord {
242 id: task_id.clone(),
243 prompt: prompt.to_string(),
244 state: TaskState::Pending,
245 slot_id: None,
246 result: None,
247 tags: vec![],
248 config: task_config,
249 };
250 self.inner.store.put_task(record).await?;
251
252 let (slot_id, slot_config) = self.assign_slot(&task_id).await?;
253 let result = self
254 .execute_task(
255 &task_id,
256 prompt,
257 &slot_id,
258 &slot_config,
259 working_dir.as_deref(),
260 )
261 .await;
262
263 self.release_slot(&slot_id, &task_id, &result).await?;
264
265 let task_result = result?;
266 let mut task = self
268 .inner
269 .store
270 .get_task(&task_id)
271 .await?
272 .ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
273 task.state = TaskState::Completed;
274 task.result = Some(task_result.clone());
275 self.inner.store.put_task(task).await?;
276
277 Ok(task_result)
278 }
279
280 pub(crate) async fn run_with_config_streaming(
287 &self,
288 prompt: &str,
289 task_config: Option<SlotConfig>,
290 on_output: Option<crate::chain::OnOutputChunk>,
291 working_dir: Option<std::path::PathBuf>,
292 ) -> Result<TaskResult> {
293 self.check_shutdown()?;
294 self.check_budget()?;
295
296 let task_id = TaskId(format!("task-{}", new_id()));
297
298 let record = TaskRecord {
299 id: task_id.clone(),
300 prompt: prompt.to_string(),
301 state: TaskState::Pending,
302 slot_id: None,
303 result: None,
304 tags: vec![],
305 config: task_config,
306 };
307 self.inner.store.put_task(record).await?;
308
309 let (slot_id, slot_config) = self.assign_slot(&task_id).await?;
310 let result = self
311 .execute_task_streaming(
312 &task_id,
313 prompt,
314 &slot_id,
315 &slot_config,
316 on_output,
317 working_dir.as_deref(),
318 )
319 .await;
320
321 self.release_slot(&slot_id, &task_id, &result).await?;
322
323 let task_result = result?;
324 let mut task = self
325 .inner
326 .store
327 .get_task(&task_id)
328 .await?
329 .ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
330 task.state = TaskState::Completed;
331 task.result = Some(task_result.clone());
332 self.inner.store.put_task(task).await?;
333
334 Ok(task_result)
335 }
336
337 pub async fn submit(&self, prompt: &str) -> Result<TaskId> {
341 self.submit_with_config(prompt, None, vec![]).await
342 }
343
344 pub async fn submit_with_config(
346 &self,
347 prompt: &str,
348 task_config: Option<SlotConfig>,
349 tags: Vec<String>,
350 ) -> Result<TaskId> {
351 self.check_shutdown()?;
352 self.check_budget()?;
353
354 let task_id = TaskId(format!("task-{}", new_id()));
355 let prompt = prompt.to_string();
356
357 let record = TaskRecord {
358 id: task_id.clone(),
359 prompt: prompt.clone(),
360 state: TaskState::Pending,
361 slot_id: None,
362 result: None,
363 tags,
364 config: task_config,
365 };
366 self.inner.store.put_task(record).await?;
367
368 let pool = self.clone();
370 let tid = task_id.clone();
371 tokio::spawn(async move {
372 let task = match pool.inner.store.get_task(&tid).await {
373 Ok(Some(t)) => t,
374 _ => return,
375 };
376
377 match pool.assign_slot(&tid).await {
378 Ok((slot_id, slot_config)) => {
379 let result = pool
380 .execute_task(&tid, &prompt, &slot_id, &slot_config, None)
381 .await;
382
383 let _ = pool.release_slot(&slot_id, &tid, &result).await;
384
385 let mut updated = task;
386 match result {
387 Ok(task_result) => {
388 updated.state = TaskState::Completed;
389 updated.result = Some(task_result);
390 }
391 Err(e) => {
392 updated.state = TaskState::Failed;
393 updated.result = Some(TaskResult {
394 output: e.to_string(),
395 success: false,
396 cost_microdollars: 0,
397 turns_used: 0,
398 session_id: None,
399 });
400 }
401 }
402 let _ = pool.inner.store.put_task(updated).await;
403 }
404 Err(e) => {
405 let mut updated = task;
406 updated.state = TaskState::Failed;
407 updated.result = Some(TaskResult {
408 output: e.to_string(),
409 success: false,
410 cost_microdollars: 0,
411 turns_used: 0,
412 session_id: None,
413 });
414 let _ = pool.inner.store.put_task(updated).await;
415 }
416 }
417 });
418
419 Ok(task_id)
420 }
421
422 pub async fn result(&self, task_id: &TaskId) -> Result<Option<TaskResult>> {
426 let task = self
427 .inner
428 .store
429 .get_task(task_id)
430 .await?
431 .ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
432
433 match task.state {
434 TaskState::Completed | TaskState::Failed => Ok(task.result),
435 _ => Ok(None),
436 }
437 }
438
439 pub async fn cancel(&self, task_id: &TaskId) -> Result<()> {
441 let mut task = self
442 .inner
443 .store
444 .get_task(task_id)
445 .await?
446 .ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
447
448 match task.state {
449 TaskState::Pending => {
450 task.state = TaskState::Cancelled;
451 self.inner.store.put_task(task).await?;
452 Ok(())
453 }
454 TaskState::Running => {
455 task.state = TaskState::Cancelled;
457 self.inner.store.put_task(task).await?;
458 Ok(())
459 }
460 _ => Ok(()), }
462 }
463
464 pub async fn cancel_chain(&self, task_id: &TaskId) -> Result<()> {
470 let mut task = self
471 .inner
472 .store
473 .get_task(task_id)
474 .await?
475 .ok_or_else(|| Error::TaskNotFound(task_id.0.clone()))?;
476
477 match task.state {
478 TaskState::Running | TaskState::Pending => {
479 task.state = TaskState::Cancelled;
480 self.inner.store.put_task(task).await?;
481 if let Some(mut progress) = self.inner.chain_progress.get_mut(&task_id.0) {
483 progress.status = crate::chain::ChainStatus::Cancelled;
484 }
485 Ok(())
486 }
487 _ => Ok(()), }
489 }
490
491 pub async fn fan_out(&self, prompts: &[&str]) -> Result<Vec<TaskResult>> {
496 self.check_shutdown()?;
497 self.check_budget()?;
498
499 let mut handles = Vec::with_capacity(prompts.len());
500
501 for prompt in prompts {
502 let pool = self.clone();
503 let prompt = prompt.to_string();
504 handles.push(tokio::spawn(async move { pool.run(&prompt).await }));
505 }
506
507 let mut results = Vec::with_capacity(handles.len());
508 for handle in handles {
509 results.push(
510 handle
511 .await
512 .map_err(|e| Error::Store(format!("task join error: {e}")))?,
513 );
514 }
515
516 results.into_iter().collect()
517 }
518
519 pub async fn submit_chain(
525 &self,
526 steps: Vec<crate::chain::ChainStep>,
527 skills: &SkillRegistry,
528 options: crate::chain::ChainOptions,
529 ) -> Result<TaskId> {
530 self.check_shutdown()?;
531 self.check_budget()?;
532
533 let task_id = TaskId(format!("chain-{}", new_id()));
534
535 let isolation = options.isolation;
536
537 let record = TaskRecord {
538 id: task_id.clone(),
539 prompt: format!("chain: {} steps", steps.len()),
540 state: TaskState::Pending,
541 slot_id: None,
542 result: None,
543 tags: options.tags,
544 config: None,
545 };
546 self.inner.store.put_task(record).await?;
547
548 let progress = crate::chain::ChainProgress {
550 total_steps: steps.len(),
551 current_step: None,
552 current_step_name: None,
553 current_step_partial_output: None,
554 current_step_started_at: None,
555 completed_steps: vec![],
556 status: crate::chain::ChainStatus::Running,
557 };
558 self.inner
559 .chain_progress
560 .insert(task_id.0.clone(), progress);
561
562 if let Some(mut task) = self.inner.store.get_task(&task_id).await? {
564 task.state = TaskState::Running;
565 self.inner.store.put_task(task).await?;
566 }
567
568 let chain_working_dir = if isolation == crate::chain::ChainIsolation::Worktree {
570 if let Some(ref mgr) = self.inner.worktree_manager {
571 match mgr.create_for_chain(&task_id).await {
572 Ok(path) => Some(path),
573 Err(e) => {
574 tracing::warn!(
575 task_id = %task_id.0,
576 error = %e,
577 "failed to create chain worktree, falling back to slot dir"
578 );
579 None
580 }
581 }
582 } else {
583 None
584 }
585 } else {
586 None
587 };
588
589 let pool = self.clone();
590 let tid = task_id.clone();
591 let skills = skills.clone();
592 tokio::spawn(async move {
593 let result = crate::chain::execute_chain_with_progress(
594 &pool,
595 &skills,
596 &steps,
597 Some(&tid),
598 chain_working_dir.as_deref(),
599 )
600 .await;
601
602 if chain_working_dir.is_some()
604 && let Some(ref mgr) = pool.inner.worktree_manager
605 && let Err(e) = mgr.remove_chain(&tid).await
606 {
607 tracing::warn!(
608 task_id = %tid.0,
609 error = %e,
610 "failed to clean up chain worktree"
611 );
612 }
613
614 if let Some(mut task) = pool.inner.store.get_task(&tid).await.ok().flatten() {
616 match result {
617 Ok(chain_result) => {
618 let success = chain_result.success;
619 task.state = if success {
620 TaskState::Completed
621 } else {
622 TaskState::Failed
623 };
624 task.result = Some(TaskResult {
625 output: serde_json::to_string(&chain_result).unwrap_or_default(),
626 success,
627 cost_microdollars: chain_result.total_cost_microdollars,
628 turns_used: 0,
629 session_id: None,
630 });
631 }
632 Err(e) => {
633 task.state = TaskState::Failed;
634 task.result = Some(TaskResult {
635 output: e.to_string(),
636 success: false,
637 cost_microdollars: 0,
638 turns_used: 0,
639 session_id: None,
640 });
641 }
642 }
643 let _ = pool.inner.store.put_task(task).await;
644 }
645 });
646
647 Ok(task_id)
648 }
649
650 pub async fn fan_out_chains(
655 &self,
656 chains: Vec<Vec<crate::chain::ChainStep>>,
657 skills: &SkillRegistry,
658 options: crate::chain::ChainOptions,
659 ) -> Result<Vec<TaskId>> {
660 self.check_shutdown()?;
661 self.check_budget()?;
662
663 let mut handles = Vec::with_capacity(chains.len());
664
665 for chain_steps in chains {
666 let pool = self.clone();
667 let skills = skills.clone();
668 let options = options.clone();
669 handles.push(tokio::spawn(async move {
670 pool.submit_chain(chain_steps, &skills, options).await
671 }));
672 }
673
674 let mut task_ids = Vec::with_capacity(handles.len());
675 for handle in handles {
676 match handle.await {
677 Ok(Ok(task_id)) => task_ids.push(task_id),
678 Ok(Err(e)) => {
679 tracing::warn!("failed to submit chain: {}", e);
681 }
682 Err(e) => {
683 tracing::warn!("chain submission task panicked: {}", e);
684 }
685 }
686 }
687
688 Ok(task_ids)
689 }
690
691 pub async fn submit_workflow(
696 &self,
697 workflow_name: &str,
698 arguments: std::collections::HashMap<String, String>,
699 skills: &SkillRegistry,
700 workflows: &crate::workflow::WorkflowRegistry,
701 tags: Vec<String>,
702 ) -> Result<TaskId> {
703 let workflow = workflows
705 .get(workflow_name)
706 .ok_or_else(|| Error::Store(format!("workflow '{}' not found", workflow_name)))?;
707
708 let steps = workflow.instantiate(&arguments)?;
709
710 let options = crate::chain::ChainOptions {
712 tags,
713 ..Default::default()
714 };
715 self.submit_chain(steps, skills, options).await
716 }
717
718 pub fn chain_progress(&self, task_id: &TaskId) -> Option<crate::chain::ChainProgress> {
722 self.inner
723 .chain_progress
724 .get(&task_id.0)
725 .map(|v| v.value().clone())
726 }
727
728 pub(crate) async fn set_chain_progress(
730 &self,
731 task_id: &TaskId,
732 progress: crate::chain::ChainProgress,
733 ) {
734 self.inner
735 .chain_progress
736 .insert(task_id.0.clone(), progress);
737 }
738
739 pub(crate) fn append_chain_partial_output(&self, task_id: &TaskId, chunk: &str) {
744 if let Some(mut progress) = self.inner.chain_progress.get_mut(&task_id.0)
745 && let Some(ref mut partial) = progress.current_step_partial_output
746 {
747 partial.push_str(chunk);
748 }
749 }
750
751 pub fn set_context(&self, key: impl Into<String>, value: impl Into<String>) {
755 self.inner.context.insert(key.into(), value.into());
756 }
757
758 pub fn get_context(&self, key: &str) -> Option<String> {
760 self.inner.context.get(key).map(|v| v.value().clone())
761 }
762
763 pub fn delete_context(&self, key: &str) -> Option<String> {
765 self.inner.context.remove(key).map(|(_, v)| v)
766 }
767
768 pub fn list_context(&self) -> Vec<(String, String)> {
770 self.inner
771 .context
772 .iter()
773 .map(|r| (r.key().clone(), r.value().clone()))
774 .collect()
775 }
776
777 pub async fn drain(&self) -> Result<DrainSummary> {
782 self.inner.shutdown.store(true, Ordering::SeqCst);
783
784 loop {
786 let running = self
787 .inner
788 .store
789 .list_tasks(&TaskFilter {
790 state: Some(TaskState::Running),
791 ..Default::default()
792 })
793 .await?;
794 if running.is_empty() {
795 break;
796 }
797 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
798 }
799
800 let slots = self.inner.store.list_slots().await?;
802 let mut total_cost = 0u64;
803 let mut total_tasks = 0u64;
804 let slot_ids: Vec<_> = slots.iter().map(|w| w.id.clone()).collect();
805
806 for mut slot in slots {
807 total_cost += slot.cost_microdollars;
808 total_tasks += slot.tasks_completed;
809 slot.state = SlotState::Stopped;
810 self.inner.store.put_slot(slot).await?;
811 }
812
813 if let Some(ref mgr) = self.inner.worktree_manager {
815 mgr.cleanup_all(&slot_ids).await?;
816 }
817
818 for slot_id in &slot_ids {
820 if let Some(slot) = self.inner.store.get_slot(slot_id).await?
821 && let Some(ref path) = slot.mcp_config_path
822 && let Err(e) = std::fs::remove_file(path)
823 {
824 tracing::warn!(
825 slot_id = %slot_id.0,
826 path = %path.display(),
827 error = %e,
828 "failed to clean up slot MCP config"
829 );
830 }
831 }
832
833 Ok(DrainSummary {
834 total_cost_microdollars: total_cost,
835 total_tasks_completed: total_tasks,
836 })
837 }
838
839 pub async fn status(&self) -> Result<PoolStatus> {
841 let slots = self.inner.store.list_slots().await?;
842 let idle = slots.iter().filter(|w| w.state == SlotState::Idle).count();
843 let busy = slots.iter().filter(|w| w.state == SlotState::Busy).count();
844
845 let running_tasks = self
846 .inner
847 .store
848 .list_tasks(&TaskFilter {
849 state: Some(TaskState::Running),
850 ..Default::default()
851 })
852 .await?
853 .len();
854
855 let pending_tasks = self
856 .inner
857 .store
858 .list_tasks(&TaskFilter {
859 state: Some(TaskState::Pending),
860 ..Default::default()
861 })
862 .await?
863 .len();
864
865 Ok(PoolStatus {
866 total_slots: slots.len(),
867 idle_slots: idle,
868 busy_slots: busy,
869 running_tasks,
870 pending_tasks,
871 total_spend_microdollars: self.inner.total_spend.load(Ordering::Relaxed),
872 budget_microdollars: self.inner.config.budget_microdollars,
873 shutdown: self.inner.shutdown.load(Ordering::Relaxed),
874 })
875 }
876
/// Returns a reference to the pool's backing store.
pub fn store(&self) -> &S {
    &self.inner.store
}
881
/// Returns a reference to the pool-wide configuration.
pub fn config(&self) -> &PoolConfig {
    &self.inner.config
}
886
887 pub fn start_supervisor(&self) -> Option<crate::supervisor::SupervisorHandle> {
897 if !self.inner.config.supervisor_enabled {
898 return None;
899 }
900 Some(crate::supervisor::spawn_supervisor(
901 self.clone(),
902 self.inner.config.supervisor_interval_secs,
903 ))
904 }
905
906 pub async fn scale_up(&self, count: usize) -> Result<usize> {
911 if count == 0 {
912 return Ok(self.inner.store.list_slots().await?.len());
913 }
914
915 let current_slots = self.inner.store.list_slots().await?;
916 let current_count = current_slots.len();
917 let new_count = current_count + count;
918
919 if new_count > self.inner.config.scaling.max_slots {
920 return Err(Error::Store(format!(
921 "cannot scale up to {} slots: exceeds max_slots ({})",
922 new_count, self.inner.config.scaling.max_slots
923 )));
924 }
925
926 let existing_ids: Vec<usize> = current_slots
928 .iter()
929 .filter_map(|w| w.id.0.strip_prefix("slot-").and_then(|s| s.parse().ok()))
930 .collect();
931 let mut next_id = existing_ids.iter().max().unwrap_or(&0) + 1;
932
933 for _ in 0..count {
935 let slot_id = SlotId(format!("slot-{next_id}"));
936 next_id += 1;
937
938 let worktree_path = if self.inner.config.worktree_isolation {
940 if let Some(ref mgr) = self.inner.worktree_manager {
941 let path = mgr.create(&slot_id).await?;
942 Some(path.to_string_lossy().into_owned())
943 } else {
944 None
945 }
946 } else {
947 None
948 };
949
950 let record = SlotRecord {
951 id: slot_id,
952 state: SlotState::Idle,
953 config: SlotConfig::default(),
954 current_task: None,
955 session_id: None,
956 tasks_completed: 0,
957 cost_microdollars: 0,
958 restart_count: 0,
959 worktree_path,
960 mcp_config_path: None,
961 };
962 self.inner.store.put_slot(record).await?;
963 }
964
965 Ok(new_count)
966 }
967
968 pub async fn scale_down(&self, count: usize) -> Result<usize> {
975 if count == 0 {
976 return Ok(self.inner.store.list_slots().await?.len());
977 }
978
979 let mut slots = self.inner.store.list_slots().await?;
980 let current_count = slots.len();
981 let new_count = current_count.saturating_sub(count);
982
983 if new_count < self.inner.config.scaling.min_slots {
984 return Err(Error::Store(format!(
985 "cannot scale down to {} slots: below min_slots ({})",
986 new_count, self.inner.config.scaling.min_slots
987 )));
988 }
989
990 slots.sort_by_key(|w| std::cmp::Reverse(w.tasks_completed));
992
993 let slots_to_remove = &slots[..count];
994 let timeout = std::time::Duration::from_secs(30);
995
996 for slot in slots_to_remove {
997 let deadline = std::time::Instant::now() + timeout;
999 loop {
1000 if let Some(w) = self.inner.store.get_slot(&slot.id).await? {
1001 if w.state != SlotState::Busy {
1002 break;
1003 }
1004 if std::time::Instant::now() >= deadline {
1005 break;
1007 }
1008 } else {
1009 break;
1010 }
1011 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1012 }
1013
1014 if let Some(ref mgr) = self.inner.worktree_manager
1016 && slot.worktree_path.is_some()
1017 {
1018 let _ = mgr.cleanup_all(std::slice::from_ref(&slot.id)).await;
1019 }
1020
1021 self.inner.store.delete_slot(&slot.id).await?;
1023 }
1024
1025 Ok(new_count)
1026 }
1027
1028 pub async fn set_target_slots(&self, target: usize) -> Result<usize> {
1030 let current = self.inner.store.list_slots().await?.len();
1031 if target > current {
1032 self.scale_up(target - current).await
1033 } else if target < current {
1034 self.scale_down(current - target).await
1035 } else {
1036 Ok(current)
1037 }
1038 }
1039
1040 fn check_shutdown(&self) -> Result<()> {
1043 if self.inner.shutdown.load(Ordering::SeqCst) {
1044 Err(Error::PoolShutdown)
1045 } else {
1046 Ok(())
1047 }
1048 }
1049
1050 fn check_budget(&self) -> Result<()> {
1051 if let Some(limit) = self.inner.config.budget_microdollars {
1052 let spent = self.inner.total_spend.load(Ordering::Relaxed);
1053 if spent >= limit {
1054 return Err(Error::BudgetExhausted {
1055 spent_microdollars: spent,
1056 limit_microdollars: limit,
1057 });
1058 }
1059 }
1060 Ok(())
1061 }
1062
1063 async fn wait_for_idle_slot_with_timeout(&self, timeout_secs: u64) -> Result<SlotRecord> {
1065 use std::time::{Duration, Instant};
1066
1067 let deadline = Instant::now() + Duration::from_secs(timeout_secs);
1068 let mut backoff_ms = 10u64;
1069 const MAX_BACKOFF_MS: u64 = 500;
1070
1071 loop {
1072 self.check_shutdown()?;
1073
1074 let slots = self.inner.store.list_slots().await?;
1075 for slot in slots {
1076 if slot.state == SlotState::Idle {
1077 return Ok(slot);
1078 }
1079 }
1080
1081 if Instant::now() >= deadline {
1082 return Err(Error::NoSlotAvailable { timeout_secs });
1083 }
1084
1085 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
1086 backoff_ms = std::cmp::min((backoff_ms as f64 * 1.5) as u64, MAX_BACKOFF_MS);
1087 }
1088 }
1089
1090 async fn assign_slot(&self, task_id: &TaskId) -> Result<(SlotId, SlotConfig)> {
1092 let _lock = self.inner.assignment_lock.lock().await;
1093
1094 let timeout = self.inner.config.slot_assignment_timeout_secs;
1095 let mut slot = self.wait_for_idle_slot_with_timeout(timeout).await?;
1096 let config = slot.config.clone();
1097
1098 slot.state = SlotState::Busy;
1099 slot.current_task = Some(task_id.clone());
1100 self.inner.store.put_slot(slot.clone()).await?;
1101
1102 if let Some(mut task) = self.inner.store.get_task(task_id).await? {
1104 task.state = TaskState::Running;
1105 task.slot_id = Some(slot.id.clone());
1106 self.inner.store.put_task(task).await?;
1107 }
1108
1109 Ok((slot.id, config))
1110 }
1111
1112 async fn release_slot(
1114 &self,
1115 slot_id: &SlotId,
1116 _task_id: &TaskId,
1117 result: &std::result::Result<TaskResult, Error>,
1118 ) -> Result<()> {
1119 if let Some(mut slot) = self.inner.store.get_slot(slot_id).await? {
1120 slot.state = SlotState::Idle;
1121 slot.current_task = None;
1122
1123 if let Ok(task_result) = result {
1124 slot.tasks_completed += 1;
1125 slot.cost_microdollars += task_result.cost_microdollars;
1126 slot.session_id = task_result.session_id.clone();
1127
1128 self.inner
1130 .total_spend
1131 .fetch_add(task_result.cost_microdollars, Ordering::Relaxed);
1132 }
1133
1134 self.inner.store.put_slot(slot).await?;
1135 }
1136 Ok(())
1137 }
1138
1139 async fn ensure_slot_mcp_config(
1144 &self,
1145 slot_id: &SlotId,
1146 servers: &std::collections::HashMap<String, serde_json::Value>,
1147 ) -> Result<std::path::PathBuf> {
1148 if let Some(slot) = self.inner.store.get_slot(slot_id).await?
1150 && let Some(ref path) = slot.mcp_config_path
1151 {
1152 let json = serde_json::to_string_pretty(&serde_json::json!({
1154 "mcpServers": servers
1155 }))?;
1156 std::fs::write(path, json)?;
1157 return Ok(path.clone());
1158 }
1159
1160 use std::io::Write as _;
1162 let json = serde_json::to_string_pretty(&serde_json::json!({
1163 "mcpServers": servers
1164 }))?;
1165 let mut file = tempfile::Builder::new()
1166 .prefix(&format!("claude-pool-{}-", slot_id.0))
1167 .suffix(".mcp.json")
1168 .tempfile()?;
1169 file.write_all(json.as_bytes())?;
1170
1171 let path = file
1173 .into_temp_path()
1174 .keep()
1175 .map_err(std::io::Error::other)?
1176 .to_path_buf();
1177
1178 if let Some(mut slot) = self.inner.store.get_slot(slot_id).await? {
1179 slot.mcp_config_path = Some(path.clone());
1180 self.inner.store.put_slot(slot).await?;
1181 }
1182
1183 tracing::debug!(
1184 slot_id = %slot_id.0,
1185 path = %path.display(),
1186 servers = servers.len(),
1187 "created slot MCP config"
1188 );
1189
1190 Ok(path)
1191 }
1192
1193 async fn execute_task(
1195 &self,
1196 _task_id: &TaskId,
1197 prompt: &str,
1198 slot_id: &SlotId,
1199 slot_config: &SlotConfig,
1200 override_working_dir: Option<&std::path::Path>,
1201 ) -> Result<TaskResult> {
1202 let task_record = self.inner.store.get_task(_task_id).await?;
1203 let task_cfg = task_record.as_ref().and_then(|t| t.config.as_ref());
1204
1205 let resolved = ResolvedConfig::resolve(&self.inner.config, slot_config, task_cfg);
1206
1207 let system_prompt = self.build_system_prompt(&resolved, slot_config);
1209
1210 let mut cmd = claude_wrapper::QueryCommand::new(prompt)
1212 .output_format(OutputFormat::Json)
1213 .permission_mode(resolved.permission_mode);
1214
1215 if resolved.permission_mode == PermissionMode::BypassPermissions {
1216 cmd = cmd.dangerously_skip_permissions();
1217 }
1218
1219 if let Some(ref model) = resolved.model {
1220 cmd = cmd.model(model);
1221 }
1222 if let Some(max_turns) = resolved.max_turns {
1223 cmd = cmd.max_turns(max_turns);
1224 }
1225 if let Some(ref sp) = system_prompt {
1226 cmd = cmd.system_prompt(sp);
1227 }
1228 if let Some(effort) = resolved.effort {
1229 cmd = cmd.effort(effort);
1230 }
1231 if !resolved.allowed_tools.is_empty() {
1232 cmd = cmd.allowed_tools(&resolved.allowed_tools);
1233 }
1234
1235 if !resolved.mcp_servers.is_empty() {
1239 let mcp_path = self
1240 .ensure_slot_mcp_config(slot_id, &resolved.mcp_servers)
1241 .await?;
1242 cmd = cmd.mcp_config(mcp_path.to_string_lossy());
1243 if resolved.strict_mcp_config {
1244 cmd = cmd.strict_mcp_config();
1245 }
1246 }
1247
1248 let claude_instance = if let Some(slot) = self.inner.store.get_slot(slot_id).await? {
1250 if let Some(ref session_id) = slot.session_id {
1252 cmd = cmd.resume(session_id);
1253 }
1254
1255 if let Some(dir) = override_working_dir {
1256 self.inner.claude.with_working_dir(dir)
1257 } else if let Some(ref wt_path) = slot.worktree_path {
1258 self.inner.claude.with_working_dir(wt_path)
1259 } else {
1260 self.inner.claude.clone()
1261 }
1262 } else {
1263 self.inner.claude.clone()
1264 };
1265
1266 tracing::debug!(
1267 slot_id = %slot_id.0,
1268 model = ?resolved.model,
1269 effort = ?resolved.effort,
1270 mcp_servers = resolved.mcp_servers.len(),
1271 "executing task"
1272 );
1273
1274 let query_result = match cmd.execute_json(&claude_instance).await {
1275 Ok(r) => r,
1276 Err(e) if self.inner.config.detect_permission_prompts => {
1277 if let Some(detected) = detect_permission_prompt(&e, &slot_id.0) {
1278 return Err(detected);
1279 }
1280 return Err(e.into());
1281 }
1282 Err(e) => return Err(e.into()),
1283 };
1284
1285 let cost_microdollars = query_result
1286 .cost_usd
1287 .map(|c| (c * 1_000_000.0) as u64)
1288 .unwrap_or(0);
1289
1290 Ok(TaskResult {
1291 output: query_result.result,
1292 success: !query_result.is_error,
1293 cost_microdollars,
1294 turns_used: query_result.num_turns.unwrap_or(0),
1295 session_id: Some(query_result.session_id),
1296 })
1297 }
1298
1299 async fn execute_task_streaming(
1305 &self,
1306 task_id: &TaskId,
1307 prompt: &str,
1308 slot_id: &SlotId,
1309 slot_config: &SlotConfig,
1310 on_output: Option<crate::chain::OnOutputChunk>,
1311 override_working_dir: Option<&std::path::Path>,
1312 ) -> Result<TaskResult> {
1313 let on_output = match on_output {
1315 Some(cb) => cb,
1316 None => {
1317 return self
1318 .execute_task(task_id, prompt, slot_id, slot_config, override_working_dir)
1319 .await;
1320 }
1321 };
1322
1323 let task_record = self.inner.store.get_task(task_id).await?;
1324 let task_cfg = task_record.as_ref().and_then(|t| t.config.as_ref());
1325 let resolved = ResolvedConfig::resolve(&self.inner.config, slot_config, task_cfg);
1326
1327 let system_prompt = self.build_system_prompt(&resolved, slot_config);
1328
1329 let mut cmd = claude_wrapper::QueryCommand::new(prompt)
1331 .output_format(OutputFormat::StreamJson)
1332 .permission_mode(resolved.permission_mode);
1333
1334 if resolved.permission_mode == PermissionMode::BypassPermissions {
1335 cmd = cmd.dangerously_skip_permissions();
1336 }
1337 if let Some(ref model) = resolved.model {
1338 cmd = cmd.model(model);
1339 }
1340 if let Some(max_turns) = resolved.max_turns {
1341 cmd = cmd.max_turns(max_turns);
1342 }
1343 if let Some(ref sp) = system_prompt {
1344 cmd = cmd.system_prompt(sp);
1345 }
1346 if let Some(effort) = resolved.effort {
1347 cmd = cmd.effort(effort);
1348 }
1349 if !resolved.allowed_tools.is_empty() {
1350 cmd = cmd.allowed_tools(&resolved.allowed_tools);
1351 }
1352
1353 if !resolved.mcp_servers.is_empty() {
1354 let mcp_path = self
1355 .ensure_slot_mcp_config(slot_id, &resolved.mcp_servers)
1356 .await?;
1357 cmd = cmd.mcp_config(mcp_path.to_string_lossy());
1358 if resolved.strict_mcp_config {
1359 cmd = cmd.strict_mcp_config();
1360 }
1361 }
1362
1363 let claude_instance = if let Some(slot) = self.inner.store.get_slot(slot_id).await? {
1365 if let Some(ref session_id) = slot.session_id {
1366 cmd = cmd.resume(session_id);
1367 }
1368 if let Some(dir) = override_working_dir {
1369 self.inner.claude.with_working_dir(dir)
1370 } else if let Some(ref wt_path) = slot.worktree_path {
1371 self.inner.claude.with_working_dir(wt_path)
1372 } else {
1373 self.inner.claude.clone()
1374 }
1375 } else {
1376 self.inner.claude.clone()
1377 };
1378
1379 tracing::debug!(
1380 slot_id = %slot_id.0,
1381 model = ?resolved.model,
1382 effort = ?resolved.effort,
1383 mcp_servers = resolved.mcp_servers.len(),
1384 "executing task (streaming)"
1385 );
1386
1387 let mut result_text = String::new();
1389 let mut session_id = String::new();
1390 let mut cost_usd: Option<f64> = None;
1391 let mut is_error = false;
1392
1393 let stream_result = claude_wrapper::streaming::stream_query(
1394 &claude_instance,
1395 &cmd,
1396 |event: claude_wrapper::streaming::StreamEvent| {
1397 match event.event_type() {
1398 Some("result") => {
1399 if let Some(text) = event.result_text() {
1400 result_text = text.to_string();
1401 }
1402 if let Some(sid) = event.session_id() {
1403 session_id = sid.to_string();
1404 }
1405 cost_usd = event.cost_usd();
1406 is_error = event
1407 .data
1408 .get("is_error")
1409 .and_then(|v| v.as_bool())
1410 .unwrap_or(false);
1411 }
1412 Some("assistant") => {
1413 let content_sources = [
1416 event.data.get("content"),
1417 event.data.get("message").and_then(|m| m.get("content")),
1418 ];
1419 for content in content_sources.into_iter().flatten() {
1420 for block in content.as_array().into_iter().flatten() {
1421 if block.get("type").and_then(|t| t.as_str()) == Some("text")
1422 && let Some(text) = block.get("text").and_then(|t| t.as_str())
1423 {
1424 on_output(text);
1425 }
1426 }
1427 }
1428 }
1429 Some("content_block_delta") => {
1430 if let Some(delta) = event.data.get("delta")
1432 && delta.get("type").and_then(|t| t.as_str()) == Some("text_delta")
1433 && let Some(text) = delta.get("text").and_then(|t| t.as_str())
1434 {
1435 on_output(text);
1436 }
1437 }
1438 _ => {}
1439 }
1440 },
1441 )
1442 .await;
1443
1444 match stream_result {
1445 Ok(_) => {}
1446 Err(e) if self.inner.config.detect_permission_prompts => {
1447 if let Some(detected) = detect_permission_prompt(&e, &slot_id.0) {
1448 return Err(detected);
1449 }
1450 return Err(e.into());
1451 }
1452 Err(e) => return Err(e.into()),
1453 }
1454
1455 let cost_microdollars = cost_usd.map(|c| (c * 1_000_000.0) as u64).unwrap_or(0);
1456
1457 Ok(TaskResult {
1458 output: result_text,
1459 success: !is_error,
1460 cost_microdollars,
1461 turns_used: 0,
1462 session_id: Some(session_id),
1463 })
1464 }
1465
1466 fn build_system_prompt(
1468 &self,
1469 resolved: &ResolvedConfig,
1470 slot_config: &SlotConfig,
1471 ) -> Option<String> {
1472 let context_entries: Vec<_> = self.list_context();
1473
1474 let has_identity = slot_config.name.is_some()
1476 || slot_config.role.is_some()
1477 || slot_config.description.is_some();
1478
1479 if resolved.system_prompt.is_none() && context_entries.is_empty() && !has_identity {
1480 return None;
1481 }
1482
1483 let mut parts = Vec::new();
1484
1485 if has_identity {
1487 let mut identity = String::new();
1488 identity.push_str("You are ");
1489
1490 if let Some(ref name) = slot_config.name {
1491 identity.push_str(name);
1492 } else {
1493 identity.push_str("a slot");
1494 }
1495
1496 if let Some(ref role) = slot_config.role {
1497 identity.push_str(", a ");
1498 identity.push_str(role);
1499 }
1500
1501 if let Some(ref description) = slot_config.description {
1502 identity.push_str(". ");
1503 identity.push_str(description);
1504 } else if slot_config.role.is_some() {
1505 identity.push('.');
1506 }
1507
1508 parts.push(identity);
1509 }
1510
1511 if let Some(ref sp) = resolved.system_prompt {
1512 parts.push(sp.clone());
1513 }
1514
1515 if !context_entries.is_empty() {
1516 parts.push("\n\n## Shared Context\n".to_string());
1517 for (key, value) in &context_entries {
1518 parts.push(format!("- **{key}**: {value}"));
1519 }
1520 }
1521
1522 Some(parts.join("\n"))
1523 }
1524}
1525
/// Totals reported after a pool has been drained.
///
/// NOTE(review): presumably the value returned by `Pool::drain` — the
/// `drain_marks_slots_stopped` test reads `total_tasks_completed` from that
/// call's result; confirm against the method signature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DrainSummary {
    /// Accumulated cost, in microdollars (millionths of a US dollar).
    pub total_cost_microdollars: u64,
    /// Number of tasks counted as completed.
    pub total_tasks_completed: u64,
}
1534
/// Point-in-time view of the pool's slots, tasks and spend.
///
/// NOTE(review): presumably what `Pool::status` returns — the
/// `status_snapshot` test reads these fields from that call's result; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolStatus {
    /// Total number of slots known to the pool.
    pub total_slots: usize,
    /// Number of slots counted as idle.
    pub idle_slots: usize,
    /// Number of slots counted as busy.
    pub busy_slots: usize,
    /// Number of tasks counted as running.
    pub running_tasks: usize,
    /// Number of tasks counted as pending.
    pub pending_tasks: usize,
    /// Spend recorded so far, in microdollars (millionths of a US dollar).
    pub total_spend_microdollars: u64,
    /// Configured budget in microdollars; `None` when no budget is set.
    pub budget_microdollars: Option<u64>,
    /// Shutdown flag at the time the snapshot was taken.
    pub shutdown: bool,
}
1555
1556use serde::{Deserialize, Serialize};
1557
/// Returns a process-unique identifier rendered as lowercase hexadecimal.
///
/// The value is the wall-clock time in nanoseconds since the Unix epoch, so
/// the format matches what a plain timestamp would produce. A timestamp alone
/// is not unique, though: two calls within the same clock tick (coarse clocks,
/// concurrent callers) or a clock set before the epoch would yield identical
/// IDs. To rule that out, the last issued value is tracked in a process-wide
/// atomic and every new ID is forced to be strictly greater than it.
fn new_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Highest value handed out so far by this process.
    static LAST_ISSUED: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // Nanoseconds since 1970 fit in a u64 for centuries; clamp rather than wrap.
    let now = nanos.min(u128::from(u64::MAX)) as u64;

    let mut previous = LAST_ISSUED.load(Ordering::Relaxed);
    loop {
        // Use the clock when it has advanced, otherwise step past the last ID.
        let candidate = now.max(previous.wrapping_add(1));
        match LAST_ISSUED.compare_exchange_weak(
            previous,
            candidate,
            Ordering::Relaxed,
            Ordering::Relaxed,
        ) {
            Ok(_) => return format!("{candidate:x}"),
            // Another caller won the race; retry against the value it stored.
            Err(actual) => previous = actual,
        }
    }
}
1567
/// Substrings that mark a failed command's stderr as an interactive
/// permission prompt.
///
/// `detect_permission_prompt` checks each entry with a case-sensitive
/// `str::contains`, which is why both `y/n` and `Y/n` are listed.
///
/// NOTE(review): several entries ("Allow", "approve", "permission") are broad
/// and will also match unrelated messages such as "permission denied" —
/// confirm that this false-positive rate is acceptable.
const PERMISSION_PATTERNS: &[&str] = &[
    "Allow",
    "allow this action",
    "approve",
    "permission",
    "Do you want to",
    "tool requires approval",
    "wants to use",
    "Press Enter",
    "y/n",
    "Y/n",
    "(yes/no)",
];
1584
1585fn detect_permission_prompt(err: &claude_wrapper::Error, slot_id: &str) -> Option<Error> {
1590 let stderr = match err {
1591 claude_wrapper::Error::CommandFailed { stderr, .. } => stderr,
1592 _ => return None,
1593 };
1594
1595 if stderr.is_empty() {
1596 return None;
1597 }
1598
1599 for pattern in PERMISSION_PATTERNS {
1600 if stderr.contains(pattern) {
1601 let tool_name = extract_tool_name(stderr);
1602 tracing::warn!(
1603 slot_id,
1604 tool = %tool_name,
1605 "permission prompt detected in slot stderr"
1606 );
1607 return Some(Error::PermissionPromptDetected {
1608 tool_name,
1609 stderr: stderr.clone(),
1610 slot_id: slot_id.to_string(),
1611 });
1612 }
1613 }
1614
1615 None
1616}
1617
/// Best-effort extraction of the tool named in a permission prompt.
///
/// Lines are scanned in order; within each (trimmed) line two shapes are tried:
///
/// * `Allow <tool> ...` at the start of the line — the first word after
///   `Allow `, with trailing `?` removed;
/// * `... wants to use <tool> ...` anywhere in the line — the first word after
///   the marker, with trailing `.`, `?` and `,` removed.
///
/// The first line that yields a tool wins. Returns `"unknown"` when no line
/// matches either shape.
fn extract_tool_name(stderr: &str) -> String {
    const USE_MARKER: &str = "wants to use ";

    stderr
        .lines()
        .map(str::trim)
        .find_map(|line| {
            let after_allow = line
                .strip_prefix("Allow ")
                .and_then(|rest| rest.split_whitespace().next())
                .map(|tool| tool.trim_end_matches('?'));

            after_allow.or_else(|| {
                line.find(USE_MARKER)
                    .and_then(|at| line[at + USE_MARKER.len()..].split_whitespace().next())
                    .map(|tool| tool.trim_end_matches(['.', '?', ',']))
            })
        })
        .unwrap_or("unknown")
        .to_string()
}
1636
1637#[cfg(test)]
1638mod tests {
1639 use super::*;
1640
/// Builds a `Claude` client whose binary is `/usr/bin/false`.
///
/// NOTE(review): presumably chosen so any real invocation exits non-zero
/// immediately instead of requiring the actual CLI — confirm.
fn mock_claude() -> Claude {
    Claude::builder().binary("/usr/bin/false").build().unwrap()
}
1646
1647 #[tokio::test]
1648 async fn build_pool_registers_slots() {
1649 let pool = Pool::builder(mock_claude()).slots(3).build().await.unwrap();
1650
1651 let slots = pool.store().list_slots().await.unwrap();
1652 assert_eq!(slots.len(), 3);
1653
1654 for slot in &slots {
1655 assert_eq!(slot.state, SlotState::Idle);
1656 }
1657 }
1658
/// A single `slot_config` on a two-slot pool is applied to `slot-0` only;
/// `slot-1` keeps no model override.
#[tokio::test]
async fn pool_with_slot_configs() {
    let pool = Pool::builder(mock_claude())
        .slots(2)
        .slot_config(SlotConfig {
            model: Some("opus".into()),
            role: Some("reviewer".into()),
            ..Default::default()
        })
        .build()
        .await
        .unwrap();

    let slots = pool.store().list_slots().await.unwrap();
    let w0 = slots.iter().find(|w| w.id.0 == "slot-0").unwrap();
    let w1 = slots.iter().find(|w| w.id.0 == "slot-1").unwrap();
    assert_eq!(w0.config.model.as_deref(), Some("opus"));
    assert_eq!(w0.config.role.as_deref(), Some("reviewer"));
    // The second slot had no explicit config pushed for it.
    assert!(w1.config.model.is_none());
}
1680
/// Exercises the shared-context API: set, get, list and delete.
#[tokio::test]
async fn context_operations() {
    let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();

    pool.set_context("repo", "claude-wrapper");
    pool.set_context("branch", "main");

    assert_eq!(pool.get_context("repo").as_deref(), Some("claude-wrapper"));
    assert_eq!(pool.list_context().len(), 2);

    // A deleted key must no longer be readable.
    pool.delete_context("branch");
    assert!(pool.get_context("branch").is_none());
}
1694
1695 #[tokio::test]
1696 async fn drain_marks_slots_stopped() {
1697 let pool = Pool::builder(mock_claude()).slots(2).build().await.unwrap();
1698
1699 let summary = pool.drain().await.unwrap();
1700 assert_eq!(summary.total_tasks_completed, 0);
1701
1702 let slots = pool.store().list_slots().await.unwrap();
1703 for w in &slots {
1704 assert_eq!(w.state, SlotState::Stopped);
1705 }
1706
1707 assert!(pool.run("hello").await.is_err());
1709 }
1710
/// With a budget of 100 microdollars and recorded spend forced to 100,
/// `run` must fail with `Error::BudgetExhausted`.
#[tokio::test]
async fn budget_enforcement() {
    let pool = Pool::builder(mock_claude())
        .slots(1)
        .config(PoolConfig {
            budget_microdollars: Some(100),
            ..Default::default()
        })
        .build()
        .await
        .unwrap();

    // Write the spend counter directly so the budget is already used up.
    pool.inner.total_spend.store(100, Ordering::Relaxed);

    let err = pool.run("hello").await.unwrap_err();
    assert!(matches!(err, Error::BudgetExhausted { .. }));
}
1729
/// A freshly built three-slot pool reports all slots idle, none busy, the
/// configured budget, and no shutdown.
#[tokio::test]
async fn status_snapshot() {
    let pool = Pool::builder(mock_claude())
        .slots(3)
        .config(PoolConfig {
            budget_microdollars: Some(1_000_000),
            ..Default::default()
        })
        .build()
        .await
        .unwrap();

    let status = pool.status().await.unwrap();
    assert_eq!(status.total_slots, 3);
    assert_eq!(status.idle_slots, 3);
    assert_eq!(status.busy_slots, 0);
    assert_eq!(status.budget_microdollars, Some(1_000_000));
    assert!(!status.shutdown);
}
1749
/// When the only slot is busy, `run` gives up with
/// `Error::NoSlotAvailable` carrying the configured one-second timeout.
#[tokio::test]
async fn no_idle_slots_timeout() {
    let pool = Pool::builder(mock_claude())
        .slots(1)
        .config(PoolConfig {
            slot_assignment_timeout_secs: 1,
            ..Default::default()
        })
        .build()
        .await
        .unwrap();

    let mut slots = pool.store().list_slots().await.unwrap();
    // Mark the single slot busy directly in the store so nothing is assignable.
    slots[0].state = SlotState::Busy;
    pool.store().put_slot(slots[0].clone()).await.unwrap();

    let err = pool.run("hello").await.unwrap_err();
    assert!(matches!(err, Error::NoSlotAvailable { timeout_secs: 1 }));
}
1770
1771 #[tokio::test]
1772 async fn fan_out_with_excess_prompts() {
1773 let pool = Pool::builder(mock_claude()).slots(2).build().await.unwrap();
1778
1779 let prompts = vec!["prompt1", "prompt2", "prompt3", "prompt4"];
1780
1781 let results = pool.fan_out(&prompts).await;
1786
1787 match results {
1790 Ok(_) | Err(_) => {
1791 }
1794 }
1795 }
1796
/// The `name`, `role` and `description` given in a `SlotConfig` are readable
/// back from the stored record of `slot-0`.
#[tokio::test]
async fn slot_identity_fields_persisted() {
    let pool = Pool::builder(mock_claude())
        .slots(1)
        .slot_config(SlotConfig {
            name: Some("reviewer".into()),
            role: Some("code_review".into()),
            description: Some("Reviews PRs for correctness and style".into()),
            ..Default::default()
        })
        .build()
        .await
        .unwrap();

    let slots = pool.store().list_slots().await.unwrap();
    let slot = slots.iter().find(|w| w.id.0 == "slot-0").unwrap();

    assert_eq!(slot.config.name.as_deref(), Some("reviewer"));
    assert_eq!(slot.config.role.as_deref(), Some("code_review"));
    assert_eq!(
        slot.config.description.as_deref(),
        Some("Reviews PRs for correctness and style")
    );
}
1821
1822 #[tokio::test]
1823 async fn scale_up_increases_slot_count() {
1824 let pool = Pool::builder(mock_claude()).slots(2).build().await.unwrap();
1825
1826 let initial_count = pool.store().list_slots().await.unwrap().len();
1827 assert_eq!(initial_count, 2);
1828
1829 let new_count = pool.scale_up(3).await.unwrap();
1830 assert_eq!(new_count, 5);
1831
1832 let slots = pool.store().list_slots().await.unwrap();
1833 assert_eq!(slots.len(), 5);
1834
1835 for slot in slots.iter().skip(2) {
1837 assert_eq!(slot.state, SlotState::Idle);
1838 }
1839 }
1840
/// Scaling up past `scaling.max_slots` is rejected and leaves the slot count
/// unchanged.
#[tokio::test]
async fn scale_up_respects_max_slots() {
    let mut config = PoolConfig::default();
    config.scaling.max_slots = 4;

    let pool = Pool::builder(mock_claude())
        .slots(2)
        .config(config)
        .build()
        .await
        .unwrap();

    // 2 existing + 5 requested would exceed the limit of 4.
    let result = pool.scale_up(5).await;
    assert!(result.is_err());
    assert!(
        result
            .unwrap_err()
            .to_string()
            .contains("exceeds max_slots")
    );

    // The failed request must not have added any slots.
    assert_eq!(pool.store().list_slots().await.unwrap().len(), 2);
}
1866
/// Scaling a four-slot pool down by two leaves two slots in the store.
#[tokio::test]
async fn scale_down_reduces_slot_count() {
    let pool = Pool::builder(mock_claude()).slots(4).build().await.unwrap();

    let initial = pool.store().list_slots().await.unwrap().len();
    assert_eq!(initial, 4);

    let new_count = pool.scale_down(2).await.unwrap();
    assert_eq!(new_count, 2);

    assert_eq!(pool.store().list_slots().await.unwrap().len(), 2);
}
1879
/// Scaling down below `scaling.min_slots` is rejected and leaves the slot
/// count unchanged.
#[tokio::test]
async fn scale_down_respects_min_slots() {
    let mut config = PoolConfig::default();
    config.scaling.min_slots = 2;

    let pool = Pool::builder(mock_claude())
        .slots(3)
        .config(config)
        .build()
        .await
        .unwrap();

    // 3 existing - 2 requested would drop below the minimum of 2.
    let result = pool.scale_down(2).await;
    assert!(result.is_err());
    assert!(result.unwrap_err().to_string().contains("below min_slots"));

    // The failed request must not have removed any slots.
    assert_eq!(pool.store().list_slots().await.unwrap().len(), 3);
}
1900
/// A target above the current slot count grows the pool to that target.
#[tokio::test]
async fn set_target_slots_scales_up() {
    let pool = Pool::builder(mock_claude()).slots(2).build().await.unwrap();

    let new_count = pool.set_target_slots(5).await.unwrap();
    assert_eq!(new_count, 5);
    assert_eq!(pool.store().list_slots().await.unwrap().len(), 5);
}
1909
/// A target below the current slot count shrinks the pool to that target.
#[tokio::test]
async fn set_target_slots_scales_down() {
    let pool = Pool::builder(mock_claude()).slots(5).build().await.unwrap();

    let new_count = pool.set_target_slots(2).await.unwrap();
    assert_eq!(new_count, 2);
    assert_eq!(pool.store().list_slots().await.unwrap().len(), 2);
}
1918
/// A target equal to the current slot count succeeds and reports that count.
#[tokio::test]
async fn set_target_slots_no_op_when_equal() {
    let pool = Pool::builder(mock_claude()).slots(3).build().await.unwrap();

    let new_count = pool.set_target_slots(3).await.unwrap();
    assert_eq!(new_count, 3);
}
1926
1927 #[tokio::test]
1928 async fn fan_out_chains_submits_all_chains() {
1929 let pool = Pool::builder(mock_claude()).slots(2).build().await.unwrap();
1930
1931 let skills = crate::skill::SkillRegistry::new();
1932 let options = crate::chain::ChainOptions {
1933 tags: vec![],
1934 ..Default::default()
1935 };
1936
1937 let chain1 = vec![crate::chain::ChainStep {
1939 name: "step1".into(),
1940 action: crate::chain::StepAction::Prompt {
1941 prompt: "prompt 1".into(),
1942 },
1943 config: None,
1944 failure_policy: crate::chain::StepFailurePolicy {
1945 retries: 0,
1946 recovery_prompt: None,
1947 },
1948 output_vars: Default::default(),
1949 }];
1950
1951 let chain2 = vec![crate::chain::ChainStep {
1952 name: "step1".into(),
1953 action: crate::chain::StepAction::Prompt {
1954 prompt: "prompt 2".into(),
1955 },
1956 config: None,
1957 failure_policy: crate::chain::StepFailurePolicy {
1958 retries: 0,
1959 recovery_prompt: None,
1960 },
1961 output_vars: Default::default(),
1962 }];
1963
1964 let chains = vec![chain1, chain2];
1965
1966 let task_ids = pool.fan_out_chains(chains, &skills, options).await.unwrap();
1968
1969 assert_eq!(task_ids.len(), 2);
1971
1972 assert_ne!(task_ids[0].0, task_ids[1].0);
1974
1975 for task_id in &task_ids {
1977 let task = pool.store().get_task(task_id).await.unwrap();
1978 assert!(task.is_some());
1979 }
1980 }
1981
/// An `Allow <tool>` prompt in stderr is detected, and both the tool name and
/// the slot ID are carried into the resulting error.
#[test]
fn detect_allow_bash_in_stderr() {
    let err = claude_wrapper::Error::CommandFailed {
        command: "claude --print".into(),
        exit_code: 1,
        stdout: String::new(),
        stderr: "Allow Bash tool? (y/n)".into(),
        working_dir: None,
    };
    let result = detect_permission_prompt(&err, "slot-1");
    assert!(result.is_some());
    let err = result.unwrap();
    match err {
        Error::PermissionPromptDetected {
            tool_name, slot_id, ..
        } => {
            assert_eq!(tool_name, "Bash");
            assert_eq!(slot_id, "slot-1");
        }
        other => panic!("expected PermissionPromptDetected, got: {other}"),
    }
}
2006
/// A `wants to use <tool>` message in stderr is detected and the tool name is
/// extracted from it.
#[test]
fn detect_wants_to_use_pattern() {
    let err = claude_wrapper::Error::CommandFailed {
        command: "claude --print".into(),
        exit_code: 1,
        stdout: String::new(),
        stderr: "Claude wants to use Edit tool.".into(),
        working_dir: None,
    };
    let result = detect_permission_prompt(&err, "slot-2");
    assert!(result.is_some());
    match result.unwrap() {
        Error::PermissionPromptDetected { tool_name, .. } => {
            assert_eq!(tool_name, "Edit");
        }
        other => panic!("expected PermissionPromptDetected, got: {other}"),
    }
}
2025
/// Stderr without any permission pattern does not trigger detection.
#[test]
fn no_detection_on_clean_stderr() {
    let err = claude_wrapper::Error::CommandFailed {
        command: "claude --print".into(),
        exit_code: 1,
        stdout: String::new(),
        stderr: "some unrelated error output".into(),
        working_dir: None,
    };
    assert!(detect_permission_prompt(&err, "slot-1").is_none());
}
2037
/// A command failure with empty stderr does not trigger detection.
#[test]
fn no_detection_on_empty_stderr() {
    let err = claude_wrapper::Error::CommandFailed {
        command: "claude --print".into(),
        exit_code: 1,
        stdout: String::new(),
        stderr: String::new(),
        working_dir: None,
    };
    assert!(detect_permission_prompt(&err, "slot-1").is_none());
}
2049
/// Error variants other than `CommandFailed` (here: `Timeout`) are never
/// classified as permission prompts.
#[test]
fn no_detection_on_timeout() {
    let err = claude_wrapper::Error::Timeout {
        timeout_seconds: 30,
    };
    assert!(detect_permission_prompt(&err, "slot-1").is_none());
}
2057
/// Text matching neither prompt shape falls back to `"unknown"`.
#[test]
fn extract_tool_name_unknown_fallback() {
    assert_eq!(extract_tool_name("some random text"), "unknown");
}
2062
/// The word following a leading `Allow ` is returned as the tool name.
#[test]
fn extract_tool_name_allow_prefix() {
    assert_eq!(extract_tool_name("Allow Write tool?"), "Write");
}
2067
/// The word following `wants to use ` is returned with its trailing comma
/// stripped.
#[test]
fn extract_tool_name_wants_to_use() {
    assert_eq!(
        extract_tool_name("Claude wants to use Bash, proceed?"),
        "Bash"
    );
}
2075
/// Cancelling a running chain flips both the stored task state and the
/// in-memory chain progress status to `Cancelled`.
#[tokio::test]
async fn cancel_chain_marks_task_cancelled() {
    let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();

    let task_id = TaskId("chain-test-1".into());
    // Insert a running task record directly; no real chain is executed.
    let record = TaskRecord {
        id: task_id.clone(),
        prompt: "chain: 3 steps".into(),
        state: TaskState::Running,
        slot_id: None,
        result: None,
        tags: vec![],
        config: None,
    };
    pool.store().put_task(record).await.unwrap();

    // Register matching progress so cancellation has something to update.
    pool.set_chain_progress(
        &task_id,
        crate::chain::ChainProgress {
            total_steps: 3,
            current_step: Some(1),
            current_step_name: Some("implement".into()),
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps: vec![],
            status: crate::chain::ChainStatus::Running,
        },
    )
    .await;

    pool.cancel_chain(&task_id).await.unwrap();

    let task = pool.store().get_task(&task_id).await.unwrap().unwrap();
    assert_eq!(task.state, TaskState::Cancelled);

    let progress = pool.chain_progress(&task_id).unwrap();
    assert_eq!(progress.status, crate::chain::ChainStatus::Cancelled);
}
2121
/// Cancelling an already completed chain succeeds and leaves the task in
/// `Completed`.
#[tokio::test]
async fn cancel_chain_noop_for_completed() {
    let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();

    let task_id = TaskId("chain-done".into());
    let record = TaskRecord {
        id: task_id.clone(),
        prompt: "chain: 1 steps".into(),
        state: TaskState::Completed,
        slot_id: None,
        result: Some(TaskResult {
            output: "done".into(),
            success: true,
            cost_microdollars: 100,
            turns_used: 0,
            session_id: None,
        }),
        tags: vec![],
        config: None,
    };
    pool.store().put_task(record).await.unwrap();

    pool.cancel_chain(&task_id).await.unwrap();
    // The terminal state must be left untouched.
    let task = pool.store().get_task(&task_id).await.unwrap().unwrap();
    assert_eq!(task.state, TaskState::Completed);
}
2149
/// Cancelling an unknown task ID fails with `Error::TaskNotFound`.
#[tokio::test]
async fn cancel_chain_not_found() {
    let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();
    let result = pool.cancel_chain(&TaskId("nonexistent".into())).await;
    assert!(matches!(result, Err(Error::TaskNotFound(_))));
}
2156
/// Successive appends are concatenated into the current step's partial
/// output when that buffer exists (`Some`).
#[tokio::test]
async fn append_chain_partial_output_accumulates() {
    let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();

    let task_id = TaskId("chain-test".into());
    let progress = crate::chain::ChainProgress {
        total_steps: 2,
        current_step: Some(0),
        current_step_name: Some("plan".into()),
        // Start from an empty buffer so appends have somewhere to go.
        current_step_partial_output: Some(String::new()),
        current_step_started_at: Some(1700000000),
        completed_steps: vec![],
        status: crate::chain::ChainStatus::Running,
    };
    pool.set_chain_progress(&task_id, progress).await;

    pool.append_chain_partial_output(&task_id, "hello ");
    pool.append_chain_partial_output(&task_id, "world");

    let progress = pool.chain_progress(&task_id).unwrap();
    assert_eq!(
        progress.current_step_partial_output.as_deref(),
        Some("hello world")
    );
}
2184
/// Appending is ignored when the progress entry has no partial-output buffer
/// (`None`); the field stays `None`.
#[tokio::test]
async fn append_chain_partial_output_noop_when_none() {
    let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();

    let task_id = TaskId("chain-test-2".into());
    let progress = crate::chain::ChainProgress {
        total_steps: 1,
        current_step: None,
        current_step_name: None,
        current_step_partial_output: None,
        current_step_started_at: None,
        completed_steps: vec![],
        status: crate::chain::ChainStatus::Completed,
    };
    pool.set_chain_progress(&task_id, progress).await;

    pool.append_chain_partial_output(&task_id, "ignored");

    let progress = pool.chain_progress(&task_id).unwrap();
    assert!(progress.current_step_partial_output.is_none());
}
2208
/// Appending for a task ID with no registered progress must not panic.
#[tokio::test]
async fn append_chain_partial_output_noop_for_missing_task() {
    let pool = Pool::builder(mock_claude()).slots(1).build().await.unwrap();

    let task_id = TaskId("nonexistent".into());
    // No progress was registered for this ID; the call should simply return.
    pool.append_chain_partial_output(&task_id, "ignored");
}
2217}