1use crate::extensions::types::ToolFn;
9use futures::stream::{self, StreamExt};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::path::{Path, PathBuf};
14use std::process::Stdio;
15use tokio::io::{AsyncBufReadExt, BufReader};
16use tokio::process::Command;
17use tokio::sync::mpsc;
18
19use crate::agents::AgentDef;
20use crate::commands::spawn::terminal::{find_harness_binary, Harness};
21
/// Registry of named tool functions that can be invoked by name.
///
/// Tools are stored as [`ToolFn`] values in a `HashMap` keyed by name, so registering
/// the same name twice replaces the earlier tool.
pub struct ExtensionRunner {
    /// Registered tools, keyed by tool name.
    registered_tools: HashMap<String, ToolFn>,
}
30
31impl ExtensionRunner {
32 pub fn new() -> Self {
34 ExtensionRunner {
35 registered_tools: HashMap::new(),
36 }
37 }
38
39 pub fn register_tool(&mut self, name: String, tool_fn: ToolFn) {
41 self.registered_tools.insert(name, tool_fn);
42 }
43
44 pub fn execute_tool(&self, name: &str, args: &[Value]) -> Result<Value, ExtensionRunnerError> {
46 let tool_fn = self
47 .registered_tools
48 .get(name)
49 .ok_or_else(|| ExtensionRunnerError::ToolNotFound(name.to_string()))?;
50
51 tool_fn(args).map_err(ExtensionRunnerError::ExecutionError)
52 }
53
54 pub fn has_tool(&self, name: &str) -> bool {
56 self.registered_tools.contains_key(name)
57 }
58
59 pub fn list_tools(&self) -> Vec<String> {
61 self.registered_tools.keys().cloned().collect()
62 }
63
64 pub fn on_tool_call(
77 &self,
78 tool_name: &str,
79 arguments: Value,
80 ) -> Result<ToolCallResult, ExtensionRunnerError> {
81 let args = match arguments {
83 Value::Array(arr) => arr,
84 Value::Object(_) => vec![arguments],
85 Value::Null => vec![],
86 other => vec![other],
87 };
88
89 let result = self.execute_tool(tool_name, &args)?;
90
91 Ok(ToolCallResult {
92 tool_name: tool_name.to_string(),
93 output: result,
94 success: true,
95 })
96 }
97}
98
/// Outcome of a tool call dispatched through [`ExtensionRunner::on_tool_call`].
#[derive(Debug, Clone)]
pub struct ToolCallResult {
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// JSON value returned by the tool.
    pub output: Value,
    /// Whether the call succeeded. `on_tool_call` only builds this struct on success,
    /// so values it returns always carry `true` here.
    pub success: bool,
}
109
110impl Default for ExtensionRunner {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
/// Errors returned when dispatching tools through [`ExtensionRunner`].
#[derive(Debug, thiserror::Error)]
pub enum ExtensionRunnerError {
    /// No tool is registered under the requested name (carried in the variant).
    #[error("Tool not found: {0}")]
    ToolNotFound(String),

    /// The tool ran but returned an error, which is wrapped here unchanged.
    #[error("Tool execution error: {0}")]
    ExecutionError(Box<dyn std::error::Error + Send + Sync>),
}
125
/// Final outcome of one agent run, as produced by [`spawn_agent`].
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Task id copied from the [`SpawnConfig`] that started the agent.
    pub task_id: String,
    /// `true` only if the process exit status reported success; `false` if it failed
    /// or if waiting on the process returned an error.
    pub success: bool,
    /// Process exit code, or `None` when none is available (e.g. the process was
    /// terminated by a signal on Unix, or waiting on it failed).
    pub exit_code: Option<i32>,
    /// Captured output: all stdout lines, followed by all stderr lines, each stderr
    /// line prefixed with `[stderr] `. Every line is terminated with `\n`.
    pub output: String,
    /// Wall-clock milliseconds from just before the process was spawned until waiting
    /// on it completed.
    pub duration_ms: u64,
}
144
/// Progress notifications for spawned agents, delivered over a `tokio` mpsc channel.
#[derive(Debug, Clone)]
pub enum AgentEvent {
    /// The agent process was spawned successfully.
    Started { task_id: String },
    /// One line of agent output. `spawn_agent` forwards stdout lines only; stderr is
    /// captured into [`AgentResult::output`] but not streamed as events.
    Output { task_id: String, line: String },
    /// The agent process finished (or waiting on it failed); carries the final result.
    Completed { result: AgentResult },
    /// The agent could not be started.
    ///
    /// NOTE(review): nothing in this module emits this variant. `spawn_agent` reports
    /// spawn failures only through its `Err` return value. Confirm whether other code
    /// sends it.
    SpawnFailed { task_id: String, error: String },
}
157
/// Everything needed to launch one agent process via [`spawn_agent`].
#[derive(Debug, Clone)]
pub struct SpawnConfig {
    /// Identifier of the task. It is exported to the child as the `SCUD_TASK_ID`
    /// environment variable and echoed back in events and results.
    pub task_id: String,
    /// Prompt passed to the harness on its command line.
    pub prompt: String,
    /// Working directory the agent process is started in.
    pub working_dir: PathBuf,
    /// Which agent CLI to launch; it also determines the command-line layout.
    pub harness: Harness,
    /// Optional model name, forwarded as `--model <model>` when set.
    pub model: Option<String>,
}
172
173pub async fn spawn_agent(
177 config: SpawnConfig,
178 event_tx: mpsc::Sender<AgentEvent>,
179) -> Result<tokio::task::JoinHandle<AgentResult>, anyhow::Error> {
180 let binary_path = find_harness_binary(config.harness)?;
181 let task_id = config.task_id.clone();
182
183 let mut cmd = match config.harness {
185 Harness::Claude => {
186 let mut c = Command::new(binary_path);
187 c.arg(&config.prompt);
188 c.arg("--dangerously-skip-permissions");
189 if let Some(ref model) = config.model {
190 c.arg("--model").arg(model);
191 }
192 c
193 }
194 Harness::OpenCode => {
195 let mut c = Command::new(binary_path);
196 c.arg("run");
197 c.arg("--variant").arg("minimal");
198 if let Some(ref model) = config.model {
199 c.arg("--model").arg(model);
200 }
201 c.arg(&config.prompt);
202 c
203 }
204 Harness::Cursor => {
205 let mut c = Command::new(binary_path);
206 c.arg("-p");
207 if let Some(ref model) = config.model {
208 c.arg("--model").arg(model);
209 }
210 c.arg(&config.prompt);
211 c
212 }
213 Harness::Rho => {
214 let mut c = Command::new(binary_path);
215 if let Some(ref model) = config.model {
216 c.arg("--model").arg(model);
217 }
218 c.arg(&config.prompt);
219 c
220 }
221 #[cfg(feature = "direct-api")]
222 Harness::DirectApi => {
223 let mut c = Command::new(binary_path);
224 c.arg("agent-exec");
225 c.arg("--prompt").arg(&config.prompt);
226 if let Some(ref model) = config.model {
227 c.arg("--model").arg(model);
228 }
229 c
230 }
231 };
232
233 cmd.current_dir(&config.working_dir);
235 cmd.env("SCUD_TASK_ID", &config.task_id);
236 cmd.stdout(Stdio::piped());
237 cmd.stderr(Stdio::piped());
238
239 let start_time = std::time::Instant::now();
240
241 let mut child = cmd.spawn().map_err(|e| {
243 anyhow::anyhow!(
244 "Failed to spawn {} for task {}: {}",
245 config.harness.name(),
246 config.task_id,
247 e
248 )
249 })?;
250
251 let _ = event_tx
253 .send(AgentEvent::Started {
254 task_id: task_id.clone(),
255 })
256 .await;
257
258 let stdout = child.stdout.take();
260 let stderr = child.stderr.take();
261 let event_tx_clone = event_tx.clone();
262 let task_id_clone = task_id.clone();
263
264 let handle = tokio::spawn(async move {
265 let mut output_buffer = String::new();
266
267 if let Some(stdout) = stdout {
269 let reader = BufReader::new(stdout);
270 let mut lines = reader.lines();
271 while let Ok(Some(line)) = lines.next_line().await {
272 output_buffer.push_str(&line);
273 output_buffer.push('\n');
274 let _ = event_tx_clone
275 .send(AgentEvent::Output {
276 task_id: task_id_clone.clone(),
277 line: line.clone(),
278 })
279 .await;
280 }
281 }
282
283 if let Some(stderr) = stderr {
285 let reader = BufReader::new(stderr);
286 let mut lines = reader.lines();
287 while let Ok(Some(line)) = lines.next_line().await {
288 output_buffer.push_str("[stderr] ");
289 output_buffer.push_str(&line);
290 output_buffer.push('\n');
291 }
292 }
293
294 let status = child.wait().await;
296 let duration_ms = start_time.elapsed().as_millis() as u64;
297
298 let (success, exit_code) = match status {
299 Ok(s) => (s.success(), s.code()),
300 Err(_) => (false, None),
301 };
302
303 let result = AgentResult {
304 task_id: task_id_clone.clone(),
305 success,
306 exit_code,
307 output: output_buffer,
308 duration_ms,
309 };
310
311 let _ = event_tx_clone
312 .send(AgentEvent::Completed {
313 result: result.clone(),
314 })
315 .await;
316
317 result
318 });
319
320 Ok(handle)
321}
322
323pub fn load_agent_config(
325 agent_type: Option<&str>,
326 default_harness: Harness,
327 default_model: Option<&str>,
328 working_dir: &Path,
329) -> (Harness, Option<String>) {
330 if let Some(agent_name) = agent_type {
331 if let Some(agent_def) = AgentDef::try_load(agent_name, working_dir) {
332 let harness = agent_def.harness().unwrap_or(default_harness);
333 let model = agent_def
334 .model()
335 .map(String::from)
336 .or_else(|| default_model.map(String::from));
337 return (harness, model);
338 }
339 }
340
341 (default_harness, default_model.map(String::from))
343}
344
/// Tracks a set of spawned agents and owns the channel their events are sent on.
///
/// The runner keeps its own sender alive, so receiving from `event_rx` never reports
/// end-of-stream (`None`) while the runner exists.
pub struct AgentRunner {
    /// Sender cloned into every agent spawned through this runner.
    event_tx: mpsc::Sender<AgentEvent>,
    /// Receiving end for events from all agents.
    event_rx: mpsc::Receiver<AgentEvent>,
    /// Join handles of agents spawned since the last `wait_all`.
    handles: Vec<tokio::task::JoinHandle<AgentResult>>,
}
354
355impl AgentRunner {
356 pub fn new(capacity: usize) -> Self {
358 let (event_tx, event_rx) = mpsc::channel(capacity);
359 Self {
360 event_tx,
361 event_rx,
362 handles: Vec::new(),
363 }
364 }
365
366 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
368 self.event_tx.clone()
369 }
370
371 pub async fn spawn(&mut self, config: SpawnConfig) -> anyhow::Result<()> {
373 let handle = spawn_agent(config, self.event_tx.clone()).await?;
374 self.handles.push(handle);
375 Ok(())
376 }
377
378 pub async fn recv_event(&mut self) -> Option<AgentEvent> {
380 self.event_rx.recv().await
381 }
382
383 pub fn try_recv_event(&mut self) -> Option<AgentEvent> {
385 self.event_rx.try_recv().ok()
386 }
387
388 pub async fn wait_all(&mut self) -> Vec<AgentResult> {
390 let handles = std::mem::take(&mut self.handles);
391 let mut results = Vec::new();
392
393 for handle in handles {
394 if let Ok(result) = handle.await {
395 results.push(result);
396 }
397 }
398
399 results
400 }
401
402 pub fn active_count(&self) -> usize {
404 self.handles.iter().filter(|h| !h.is_finished()).count()
405 }
406}
407
408pub async fn map_with_concurrency_limit<T, F, Fut, R>(
440 items: impl IntoIterator<Item = T>,
441 concurrency: usize,
442 f: F,
443) -> Vec<R>
444where
445 F: Fn(T) -> Fut,
446 Fut: Future<Output = R>,
447{
448 stream::iter(items)
449 .map(f)
450 .buffer_unordered(concurrency)
451 .collect()
452 .await
453}
454
455pub async fn map_with_concurrency_limit_ordered<T, F, Fut, R>(
468 items: impl IntoIterator<Item = T>,
469 concurrency: usize,
470 f: F,
471) -> Vec<R>
472where
473 F: Fn(T) -> Fut,
474 Fut: Future<Output = R>,
475{
476 stream::iter(items)
477 .map(f)
478 .buffered(concurrency)
479 .collect()
480 .await
481}
482
483pub async fn spawn_agents_with_limit(
496 configs: impl IntoIterator<Item = SpawnConfig>,
497 concurrency: usize,
498 event_tx: mpsc::Sender<AgentEvent>,
499) -> Vec<Result<AgentResult, anyhow::Error>> {
500 let configs: Vec<_> = configs.into_iter().collect();
501
502 map_with_concurrency_limit(configs, concurrency, |config| {
503 let tx = event_tx.clone();
504 async move {
505 match spawn_agent(config, tx).await {
506 Ok(handle) => handle.await.map_err(|e| anyhow::anyhow!("Join error: {}", e)),
507 Err(e) => Err(e),
508 }
509 }
510 })
511 .await
512}
513
/// Tuning knobs for [`spawn_agents_concurrent`].
#[derive(Debug, Clone)]
pub struct ConcurrentSpawnConfig {
    /// Maximum number of agents running at the same time.
    pub max_concurrent: usize,
    /// Per-agent time limit in milliseconds, covering both spawning and completion.
    /// `0` disables the limit.
    pub timeout_ms: u64,
    /// Presumably meant to stop after the first failure.
    ///
    /// NOTE(review): `spawn_agents_concurrent` does not read this flag; every config is
    /// run regardless. Confirm the intended semantics before relying on it.
    pub fail_fast: bool,
}
524
525impl Default for ConcurrentSpawnConfig {
526 fn default() -> Self {
527 Self {
528 max_concurrent: 5,
529 timeout_ms: 0,
530 fail_fast: false,
531 }
532 }
533}
534
/// Aggregated outcome of [`spawn_agents_concurrent`].
#[derive(Debug)]
pub struct ConcurrentSpawnResult {
    /// Results of agents that were spawned and joined. An agent that ran but exited
    /// unsuccessfully still lands here, with `success == false`.
    pub successes: Vec<AgentResult>,
    /// `(task_id, error message)` pairs. An agent lands here if it could not be
    /// spawned, its task failed to join, or it hit the timeout.
    pub failures: Vec<(String, String)>,
    /// `true` when `failures` is empty. This does not inspect `AgentResult::success`.
    pub all_succeeded: bool,
}
545
546pub async fn spawn_agents_concurrent(
559 configs: Vec<SpawnConfig>,
560 spawn_config: ConcurrentSpawnConfig,
561 event_tx: mpsc::Sender<AgentEvent>,
562) -> ConcurrentSpawnResult {
563 let mut successes = Vec::new();
564 let mut failures = Vec::new();
565
566 let results = if spawn_config.timeout_ms > 0 {
567 let timeout_duration = std::time::Duration::from_millis(spawn_config.timeout_ms);
569
570 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
571 let tx = event_tx.clone();
572 let task_id = config.task_id.clone();
573 async move {
574 let result = tokio::time::timeout(timeout_duration, async {
575 match spawn_agent(config, tx).await {
576 Ok(handle) => handle
577 .await
578 .map_err(|e| anyhow::anyhow!("Join error: {}", e)),
579 Err(e) => Err(e),
580 }
581 })
582 .await;
583
584 match result {
585 Ok(Ok(agent_result)) => Ok(agent_result),
586 Ok(Err(e)) => Err((task_id, e.to_string())),
587 Err(_) => Err((task_id, "Timeout".to_string())),
588 }
589 }
590 })
591 .await
592 } else {
593 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
595 let tx = event_tx.clone();
596 let task_id = config.task_id.clone();
597 async move {
598 match spawn_agent(config, tx).await {
599 Ok(handle) => handle
600 .await
601 .map_err(|e| (task_id, format!("Join error: {}", e))),
602 Err(e) => Err((task_id, e.to_string())),
603 }
604 }
605 })
606 .await
607 };
608
609 for result in results {
610 match result {
611 Ok(agent_result) => successes.push(agent_result),
612 Err((task_id, error)) => failures.push((task_id, error)),
613 }
614 }
615
616 let all_succeeded = failures.is_empty();
617
618 ConcurrentSpawnResult {
619 successes,
620 failures,
621 all_succeeded,
622 }
623}
624
625pub async fn spawn_subagent(
641 task_id: String,
642 prompt: String,
643 working_dir: PathBuf,
644 harness: Harness,
645 model: Option<String>,
646) -> Result<AgentResult, anyhow::Error> {
647 let (tx, _rx) = mpsc::channel(10);
649
650 let config = SpawnConfig {
651 task_id,
652 prompt,
653 working_dir,
654 harness,
655 model,
656 };
657
658 let handle = spawn_agent(config, tx).await?;
659 handle
660 .await
661 .map_err(|e| anyhow::anyhow!("Subagent join error: {}", e))
662}
663
// Unit tests for the tool registry, result/config types, and the concurrency helpers.
// None of these tests launch a real agent process.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh runner has no tools registered.
    #[test]
    fn test_extension_runner_new() {
        let runner = ExtensionRunner::new();
        assert!(runner.list_tools().is_empty());
    }

    // AgentResult fields round-trip as constructed.
    #[test]
    fn test_agent_result_debug() {
        let result = AgentResult {
            task_id: "test:1".to_string(),
            success: true,
            exit_code: Some(0),
            output: "test output".to_string(),
            duration_ms: 1000,
        };

        assert!(result.success);
        assert_eq!(result.exit_code, Some(0));
        assert_eq!(result.task_id, "test:1");
    }

    // SpawnConfig fields round-trip as constructed (no process is spawned).
    #[test]
    fn test_spawn_config_debug() {
        let config = SpawnConfig {
            task_id: "test:1".to_string(),
            prompt: "do something".to_string(),
            working_dir: PathBuf::from("/tmp"),
            harness: Harness::Claude,
            model: Some("opus".to_string()),
        };

        assert_eq!(config.task_id, "test:1");
        assert_eq!(config.harness, Harness::Claude);
    }

    // A new AgentRunner tracks no agents.
    #[tokio::test]
    async fn test_agent_runner_new() {
        let runner = AgentRunner::new(100);
        assert_eq!(runner.active_count(), 0);
    }

    // ToolCallResult exposes its JSON output for indexing.
    #[test]
    fn test_tool_call_result() {
        let result = ToolCallResult {
            tool_name: "my_tool".to_string(),
            output: serde_json::json!({"key": "value"}),
            success: true,
        };

        assert_eq!(result.tool_name, "my_tool");
        assert!(result.success);
        assert_eq!(result.output["key"], "value");
    }

    // Calling an unregistered tool yields ToolNotFound carrying the requested name.
    #[test]
    fn test_on_tool_call_not_found() {
        let runner = ExtensionRunner::new();
        let result = runner.on_tool_call("nonexistent", serde_json::json!({}));

        assert!(result.is_err());
        match result {
            Err(ExtensionRunnerError::ToolNotFound(name)) => {
                assert_eq!(name, "nonexistent");
            }
            _ => panic!("Expected ToolNotFound error"),
        }
    }

    // An object argument is passed through as the single positional argument.
    #[test]
    fn test_on_tool_call_with_registered_tool() {
        let mut runner = ExtensionRunner::new();

        // Returns its first argument, or null when called with none.
        fn echo_tool(
            args: &[Value],
        ) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(args.first().cloned().unwrap_or(Value::Null))
        }

        runner.register_tool("echo".to_string(), echo_tool);

        let result = runner
            .on_tool_call("echo", serde_json::json!({"test": 123}))
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert!(result.success);
        assert_eq!(result.output["test"], 123);
    }

    // Checks how on_tool_call turns one JSON value into positional arguments.
    #[test]
    fn test_on_tool_call_argument_conversion() {
        let mut runner = ExtensionRunner::new();

        // Reports how many positional arguments it received.
        fn count_args(
            args: &[Value],
        ) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(serde_json::json!(args.len()))
        }

        runner.register_tool("count".to_string(), count_args);

        // Arrays are spread into their elements.
        let result = runner
            .on_tool_call("count", serde_json::json!([1, 2, 3]))
            .unwrap();
        assert_eq!(result.output, 3);

        // Objects become a single argument.
        let result = runner
            .on_tool_call("count", serde_json::json!({"a": 1}))
            .unwrap();
        assert_eq!(result.output, 1);

        // Null means no arguments.
        let result = runner.on_tool_call("count", Value::Null).unwrap();
        assert_eq!(result.output, 0);

        // Any other scalar becomes a single argument.
        let result = runner.on_tool_call("count", serde_json::json!(42)).unwrap();
        assert_eq!(result.output, 1);
    }

    // Verifies that at most `concurrency` futures run at once and that every item is
    // processed.
    #[tokio::test]
    async fn test_map_with_concurrency_limit() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let items: Vec<i32> = (0..10).collect();
        // Number of futures currently inside their body.
        let counter = Arc::new(AtomicUsize::new(0));
        // High-water mark of `counter` observed by any future.
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let results = map_with_concurrency_limit(items, 3, |n| {
            let counter = Arc::clone(&counter);
            let max_concurrent = Arc::clone(&max_concurrent);
            async move {
                let current = counter.fetch_add(1, Ordering::SeqCst) + 1;

                // Raise the high-water mark to `current` with a CAS loop; retry if
                // another future updated it concurrently.
                let mut max = max_concurrent.load(Ordering::SeqCst);
                while current > max {
                    match max_concurrent.compare_exchange_weak(
                        max,
                        current,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    ) {
                        Ok(_) => break,
                        Err(new_max) => max = new_max,
                    }
                }

                // Hold the slot briefly so futures overlap and the limit is exercised.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;

                counter.fetch_sub(1, Ordering::SeqCst);

                n * 2
            }
        })
        .await;

        assert_eq!(results.len(), 10);

        // Results arrive in completion order, so sort before comparing.
        let mut sorted: Vec<i32> = results;
        sorted.sort();
        assert_eq!(sorted, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);

        assert!(max_concurrent.load(Ordering::SeqCst) <= 3);
    }

    // Earlier items sleep longer, so they finish later; the ordered variant must still
    // return results in input order.
    #[tokio::test]
    async fn test_map_with_concurrency_limit_ordered() {
        let items: Vec<i32> = vec![1, 2, 3, 4, 5];

        let results = map_with_concurrency_limit_ordered(items, 2, |n| async move {
            tokio::time::sleep(std::time::Duration::from_millis((5 - n) as u64 * 5)).await;
            n * 10
        })
        .await;

        assert_eq!(results, vec![10, 20, 30, 40, 50]);
    }

    // Pins the default concurrency settings.
    #[test]
    fn test_concurrent_spawn_config_default() {
        let config = ConcurrentSpawnConfig::default();

        assert_eq!(config.max_concurrent, 5);
        assert_eq!(config.timeout_ms, 0);
        assert!(!config.fail_fast);
    }

    // A result with only successes.
    #[test]
    fn test_concurrent_spawn_result() {
        let result = ConcurrentSpawnResult {
            successes: vec![AgentResult {
                task_id: "1".to_string(),
                success: true,
                exit_code: Some(0),
                output: "done".to_string(),
                duration_ms: 100,
            }],
            failures: vec![],
            all_succeeded: true,
        };

        assert!(result.all_succeeded);
        assert_eq!(result.successes.len(), 1);
        assert!(result.failures.is_empty());
    }

    // A result with one (task_id, message) failure.
    #[test]
    fn test_concurrent_spawn_result_with_failures() {
        let result = ConcurrentSpawnResult {
            successes: vec![],
            failures: vec![("task1".to_string(), "error msg".to_string())],
            all_succeeded: false,
        };

        assert!(!result.all_succeeded);
        assert!(result.successes.is_empty());
        assert_eq!(result.failures.len(), 1);
        assert_eq!(result.failures[0].0, "task1");
    }
}