1use crate::extensions::types::ToolFn;
9use futures::stream::{self, StreamExt};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::future::Future;
13use std::path::{Path, PathBuf};
14use std::process::Stdio;
15use tokio::io::{AsyncBufReadExt, BufReader};
16use tokio::process::Command;
17use tokio::sync::mpsc;
18
19use crate::agents::AgentDef;
20use crate::commands::spawn::terminal::{find_harness_binary, Harness};
21
22pub struct ExtensionRunner {
28 registered_tools: HashMap<String, ToolFn>,
29}
30
31impl ExtensionRunner {
32 pub fn new() -> Self {
34 ExtensionRunner {
35 registered_tools: HashMap::new(),
36 }
37 }
38
39 pub fn register_tool(&mut self, name: String, tool_fn: ToolFn) {
41 self.registered_tools.insert(name, tool_fn);
42 }
43
44 pub fn execute_tool(&self, name: &str, args: &[Value]) -> Result<Value, ExtensionRunnerError> {
46 let tool_fn = self
47 .registered_tools
48 .get(name)
49 .ok_or_else(|| ExtensionRunnerError::ToolNotFound(name.to_string()))?;
50
51 tool_fn(args).map_err(ExtensionRunnerError::ExecutionError)
52 }
53
54 pub fn has_tool(&self, name: &str) -> bool {
56 self.registered_tools.contains_key(name)
57 }
58
59 pub fn list_tools(&self) -> Vec<String> {
61 self.registered_tools.keys().cloned().collect()
62 }
63
64 pub fn on_tool_call(
77 &self,
78 tool_name: &str,
79 arguments: Value,
80 ) -> Result<ToolCallResult, ExtensionRunnerError> {
81 let args = match arguments {
83 Value::Array(arr) => arr,
84 Value::Object(_) => vec![arguments],
85 Value::Null => vec![],
86 other => vec![other],
87 };
88
89 let result = self.execute_tool(tool_name, &args)?;
90
91 Ok(ToolCallResult {
92 tool_name: tool_name.to_string(),
93 output: result,
94 success: true,
95 })
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct ToolCallResult {
102 pub tool_name: String,
104 pub output: Value,
106 pub success: bool,
108}
109
110impl Default for ExtensionRunner {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116#[derive(Debug, thiserror::Error)]
118pub enum ExtensionRunnerError {
119 #[error("Tool not found: {0}")]
120 ToolNotFound(String),
121
122 #[error("Tool execution error: {0}")]
123 ExecutionError(Box<dyn std::error::Error + Send + Sync>),
124}
125
/// Final outcome of one agent process, as produced by the task returned from [`spawn_agent`].
///
/// `PartialEq`/`Eq` are derived so results can be compared directly (e.g. in tests or
/// when deduplicating reports).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentResult {
    /// Task identifier copied from the [`SpawnConfig`].
    pub task_id: String,
    /// `true` when the process exited successfully; `false` on a failing exit or if
    /// waiting on the process failed.
    pub success: bool,
    /// Process exit code; `None` if waiting failed or no code was reported (e.g. the
    /// process was terminated by a signal on Unix).
    pub exit_code: Option<i32>,
    /// Captured stdout lines, followed by stderr lines prefixed with `[stderr] `.
    pub output: String,
    /// Milliseconds from just before the process was spawned until it exited.
    pub duration_ms: u64,
}
144
145#[derive(Debug, Clone)]
147pub enum AgentEvent {
148 Started { task_id: String },
150 Output { task_id: String, line: String },
152 Completed { result: AgentResult },
154 SpawnFailed { task_id: String, error: String },
156}
157
158#[derive(Debug, Clone)]
160pub struct SpawnConfig {
161 pub task_id: String,
163 pub prompt: String,
165 pub working_dir: PathBuf,
167 pub harness: Harness,
169 pub model: Option<String>,
171}
172
173pub async fn spawn_agent(
177 config: SpawnConfig,
178 event_tx: mpsc::Sender<AgentEvent>,
179) -> Result<tokio::task::JoinHandle<AgentResult>, anyhow::Error> {
180 let binary_path = find_harness_binary(config.harness)?;
181 let task_id = config.task_id.clone();
182
183 let mut cmd = match config.harness {
185 Harness::Claude => {
186 let mut c = Command::new(binary_path);
187 c.arg(&config.prompt);
188 c.arg("--dangerously-skip-permissions");
189 if let Some(ref model) = config.model {
190 c.arg("--model").arg(model);
191 }
192 c
193 }
194 Harness::OpenCode => {
195 let mut c = Command::new(binary_path);
196 c.arg("run");
197 c.arg("--variant").arg("minimal");
198 if let Some(ref model) = config.model {
199 c.arg("--model").arg(model);
200 }
201 c.arg(&config.prompt);
202 c
203 }
204 Harness::Cursor => {
205 let mut c = Command::new(binary_path);
206 c.arg("-p");
207 if let Some(ref model) = config.model {
208 c.arg("--model").arg(model);
209 }
210 c.arg(&config.prompt);
211 c
212 }
213 #[cfg(feature = "direct-api")]
214 Harness::DirectApi => {
215 let mut c = Command::new(binary_path);
216 c.arg("agent-exec");
217 c.arg("--prompt").arg(&config.prompt);
218 if let Some(ref model) = config.model {
219 c.arg("--model").arg(model);
220 }
221 c
222 }
223 };
224
225 cmd.current_dir(&config.working_dir);
227 cmd.env("SCUD_TASK_ID", &config.task_id);
228 cmd.stdout(Stdio::piped());
229 cmd.stderr(Stdio::piped());
230
231 let start_time = std::time::Instant::now();
232
233 let mut child = cmd.spawn().map_err(|e| {
235 anyhow::anyhow!(
236 "Failed to spawn {} for task {}: {}",
237 config.harness.name(),
238 config.task_id,
239 e
240 )
241 })?;
242
243 let _ = event_tx
245 .send(AgentEvent::Started {
246 task_id: task_id.clone(),
247 })
248 .await;
249
250 let stdout = child.stdout.take();
252 let stderr = child.stderr.take();
253 let event_tx_clone = event_tx.clone();
254 let task_id_clone = task_id.clone();
255
256 let handle = tokio::spawn(async move {
257 let mut output_buffer = String::new();
258
259 if let Some(stdout) = stdout {
261 let reader = BufReader::new(stdout);
262 let mut lines = reader.lines();
263 while let Ok(Some(line)) = lines.next_line().await {
264 output_buffer.push_str(&line);
265 output_buffer.push('\n');
266 let _ = event_tx_clone
267 .send(AgentEvent::Output {
268 task_id: task_id_clone.clone(),
269 line: line.clone(),
270 })
271 .await;
272 }
273 }
274
275 if let Some(stderr) = stderr {
277 let reader = BufReader::new(stderr);
278 let mut lines = reader.lines();
279 while let Ok(Some(line)) = lines.next_line().await {
280 output_buffer.push_str("[stderr] ");
281 output_buffer.push_str(&line);
282 output_buffer.push('\n');
283 }
284 }
285
286 let status = child.wait().await;
288 let duration_ms = start_time.elapsed().as_millis() as u64;
289
290 let (success, exit_code) = match status {
291 Ok(s) => (s.success(), s.code()),
292 Err(_) => (false, None),
293 };
294
295 let result = AgentResult {
296 task_id: task_id_clone.clone(),
297 success,
298 exit_code,
299 output: output_buffer,
300 duration_ms,
301 };
302
303 let _ = event_tx_clone
304 .send(AgentEvent::Completed {
305 result: result.clone(),
306 })
307 .await;
308
309 result
310 });
311
312 Ok(handle)
313}
314
315pub fn load_agent_config(
317 agent_type: Option<&str>,
318 default_harness: Harness,
319 default_model: Option<&str>,
320 working_dir: &Path,
321) -> (Harness, Option<String>) {
322 if let Some(agent_name) = agent_type {
323 if let Some(agent_def) = AgentDef::try_load(agent_name, working_dir) {
324 let harness = agent_def.harness().unwrap_or(default_harness);
325 let model = agent_def
326 .model()
327 .map(String::from)
328 .or_else(|| default_model.map(String::from));
329 return (harness, model);
330 }
331 }
332
333 (default_harness, default_model.map(String::from))
335}
336
337pub struct AgentRunner {
339 event_tx: mpsc::Sender<AgentEvent>,
341 event_rx: mpsc::Receiver<AgentEvent>,
343 handles: Vec<tokio::task::JoinHandle<AgentResult>>,
345}
346
347impl AgentRunner {
348 pub fn new(capacity: usize) -> Self {
350 let (event_tx, event_rx) = mpsc::channel(capacity);
351 Self {
352 event_tx,
353 event_rx,
354 handles: Vec::new(),
355 }
356 }
357
358 pub fn event_sender(&self) -> mpsc::Sender<AgentEvent> {
360 self.event_tx.clone()
361 }
362
363 pub async fn spawn(&mut self, config: SpawnConfig) -> anyhow::Result<()> {
365 let handle = spawn_agent(config, self.event_tx.clone()).await?;
366 self.handles.push(handle);
367 Ok(())
368 }
369
370 pub async fn recv_event(&mut self) -> Option<AgentEvent> {
372 self.event_rx.recv().await
373 }
374
375 pub fn try_recv_event(&mut self) -> Option<AgentEvent> {
377 self.event_rx.try_recv().ok()
378 }
379
380 pub async fn wait_all(&mut self) -> Vec<AgentResult> {
382 let handles = std::mem::take(&mut self.handles);
383 let mut results = Vec::new();
384
385 for handle in handles {
386 if let Ok(result) = handle.await {
387 results.push(result);
388 }
389 }
390
391 results
392 }
393
394 pub fn active_count(&self) -> usize {
396 self.handles.iter().filter(|h| !h.is_finished()).count()
397 }
398}
399
400pub async fn map_with_concurrency_limit<T, F, Fut, R>(
432 items: impl IntoIterator<Item = T>,
433 concurrency: usize,
434 f: F,
435) -> Vec<R>
436where
437 F: Fn(T) -> Fut,
438 Fut: Future<Output = R>,
439{
440 stream::iter(items)
441 .map(f)
442 .buffer_unordered(concurrency)
443 .collect()
444 .await
445}
446
447pub async fn map_with_concurrency_limit_ordered<T, F, Fut, R>(
460 items: impl IntoIterator<Item = T>,
461 concurrency: usize,
462 f: F,
463) -> Vec<R>
464where
465 F: Fn(T) -> Fut,
466 Fut: Future<Output = R>,
467{
468 stream::iter(items)
469 .map(f)
470 .buffered(concurrency)
471 .collect()
472 .await
473}
474
475pub async fn spawn_agents_with_limit(
488 configs: impl IntoIterator<Item = SpawnConfig>,
489 concurrency: usize,
490 event_tx: mpsc::Sender<AgentEvent>,
491) -> Vec<Result<AgentResult, anyhow::Error>> {
492 let configs: Vec<_> = configs.into_iter().collect();
493
494 map_with_concurrency_limit(configs, concurrency, |config| {
495 let tx = event_tx.clone();
496 async move {
497 match spawn_agent(config, tx).await {
498 Ok(handle) => handle.await.map_err(|e| anyhow::anyhow!("Join error: {}", e)),
499 Err(e) => Err(e),
500 }
501 }
502 })
503 .await
504}
505
/// Tuning knobs for [`spawn_agents_concurrent`].
#[derive(Clone, Debug)]
pub struct ConcurrentSpawnConfig {
    /// Upper bound on the number of agents running at the same time.
    pub max_concurrent: usize,
    /// Per-agent time limit in milliseconds; `0` disables the limit.
    pub timeout_ms: u64,
    /// Presumably meant to stop spawning after the first failure.
    // NOTE(review): `spawn_agents_concurrent` does not currently read this flag.
    pub fail_fast: bool,
}
516
517impl Default for ConcurrentSpawnConfig {
518 fn default() -> Self {
519 Self {
520 max_concurrent: 5,
521 timeout_ms: 0,
522 fail_fast: false,
523 }
524 }
525}
526
527#[derive(Debug)]
529pub struct ConcurrentSpawnResult {
530 pub successes: Vec<AgentResult>,
532 pub failures: Vec<(String, String)>,
534 pub all_succeeded: bool,
536}
537
538pub async fn spawn_agents_concurrent(
551 configs: Vec<SpawnConfig>,
552 spawn_config: ConcurrentSpawnConfig,
553 event_tx: mpsc::Sender<AgentEvent>,
554) -> ConcurrentSpawnResult {
555 let mut successes = Vec::new();
556 let mut failures = Vec::new();
557
558 let results = if spawn_config.timeout_ms > 0 {
559 let timeout_duration = std::time::Duration::from_millis(spawn_config.timeout_ms);
561
562 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
563 let tx = event_tx.clone();
564 let task_id = config.task_id.clone();
565 async move {
566 let result = tokio::time::timeout(timeout_duration, async {
567 match spawn_agent(config, tx).await {
568 Ok(handle) => handle
569 .await
570 .map_err(|e| anyhow::anyhow!("Join error: {}", e)),
571 Err(e) => Err(e),
572 }
573 })
574 .await;
575
576 match result {
577 Ok(Ok(agent_result)) => Ok(agent_result),
578 Ok(Err(e)) => Err((task_id, e.to_string())),
579 Err(_) => Err((task_id, "Timeout".to_string())),
580 }
581 }
582 })
583 .await
584 } else {
585 map_with_concurrency_limit(configs, spawn_config.max_concurrent, |config| {
587 let tx = event_tx.clone();
588 let task_id = config.task_id.clone();
589 async move {
590 match spawn_agent(config, tx).await {
591 Ok(handle) => handle
592 .await
593 .map_err(|e| (task_id, format!("Join error: {}", e))),
594 Err(e) => Err((task_id, e.to_string())),
595 }
596 }
597 })
598 .await
599 };
600
601 for result in results {
602 match result {
603 Ok(agent_result) => successes.push(agent_result),
604 Err((task_id, error)) => failures.push((task_id, error)),
605 }
606 }
607
608 let all_succeeded = failures.is_empty();
609
610 ConcurrentSpawnResult {
611 successes,
612 failures,
613 all_succeeded,
614 }
615}
616
617pub async fn spawn_subagent(
633 task_id: String,
634 prompt: String,
635 working_dir: PathBuf,
636 harness: Harness,
637 model: Option<String>,
638) -> Result<AgentResult, anyhow::Error> {
639 let (tx, _rx) = mpsc::channel(10);
641
642 let config = SpawnConfig {
643 task_id,
644 prompt,
645 working_dir,
646 harness,
647 model,
648 };
649
650 let handle = spawn_agent(config, tx).await?;
651 handle
652 .await
653 .map_err(|e| anyhow::anyhow!("Subagent join error: {}", e))
654}
655
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extension_runner_new() {
        let runner = ExtensionRunner::new();
        assert!(runner.list_tools().is_empty());
    }

    #[test]
    fn test_agent_result_debug() {
        let result = AgentResult {
            task_id: "test:1".to_string(),
            success: true,
            exit_code: Some(0),
            output: "test output".to_string(),
            duration_ms: 1000,
        };

        assert!(result.success);
        assert_eq!(result.exit_code, Some(0));
        assert_eq!(result.task_id, "test:1");

        // The test name promises Debug coverage, so check the derived output as well.
        let rendered = format!("{:?}", result);
        assert!(rendered.contains("task_id: \"test:1\""));
        assert!(rendered.contains("exit_code: Some(0)"));
    }

    #[test]
    fn test_spawn_config_debug() {
        let config = SpawnConfig {
            task_id: "test:1".to_string(),
            prompt: "do something".to_string(),
            working_dir: PathBuf::from("/tmp"),
            harness: Harness::Claude,
            model: Some("opus".to_string()),
        };

        assert_eq!(config.task_id, "test:1");
        assert_eq!(config.harness, Harness::Claude);

        let rendered = format!("{:?}", config);
        assert!(rendered.contains("task_id: \"test:1\""));
        assert!(rendered.contains("model: Some(\"opus\")"));
    }

    #[tokio::test]
    async fn test_agent_runner_new() {
        let runner = AgentRunner::new(100);
        assert_eq!(runner.active_count(), 0);
    }

    #[test]
    fn test_load_agent_config_without_agent_type_uses_defaults() {
        // With no agent type, no agent definition is looked up at all.
        let (harness, model) =
            load_agent_config(None, Harness::Claude, Some("opus"), Path::new("/tmp"));
        assert_eq!(harness, Harness::Claude);
        assert_eq!(model.as_deref(), Some("opus"));

        let (_, model) = load_agent_config(None, Harness::Claude, None, Path::new("/tmp"));
        assert_eq!(model, None);
    }

    #[test]
    fn test_tool_call_result() {
        let result = ToolCallResult {
            tool_name: "my_tool".to_string(),
            output: serde_json::json!({"key": "value"}),
            success: true,
        };

        assert_eq!(result.tool_name, "my_tool");
        assert!(result.success);
        assert_eq!(result.output["key"], "value");
    }

    #[test]
    fn test_on_tool_call_not_found() {
        let runner = ExtensionRunner::new();
        let result = runner.on_tool_call("nonexistent", serde_json::json!({}));

        assert!(result.is_err());
        match result {
            Err(ExtensionRunnerError::ToolNotFound(name)) => {
                assert_eq!(name, "nonexistent");
            }
            _ => panic!("Expected ToolNotFound error"),
        }
    }

    #[test]
    fn test_execute_tool_wraps_tool_errors() {
        let mut runner = ExtensionRunner::new();

        fn failing_tool(
            _args: &[Value],
        ) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Err("boom".into())
        }

        runner.register_tool("fail".to_string(), failing_tool);
        assert!(runner.has_tool("fail"));
        assert!(!runner.has_tool("missing"));

        let err = runner.execute_tool("fail", &[]).unwrap_err();
        assert!(matches!(err, ExtensionRunnerError::ExecutionError(_)));
        assert_eq!(err.to_string(), "Tool execution error: boom");
    }

    #[test]
    fn test_on_tool_call_with_registered_tool() {
        let mut runner = ExtensionRunner::new();

        fn echo_tool(
            args: &[Value],
        ) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(args.first().cloned().unwrap_or(Value::Null))
        }

        runner.register_tool("echo".to_string(), echo_tool);

        let result = runner
            .on_tool_call("echo", serde_json::json!({"test": 123}))
            .unwrap();

        assert_eq!(result.tool_name, "echo");
        assert!(result.success);
        assert_eq!(result.output["test"], 123);
    }

    #[test]
    fn test_on_tool_call_argument_conversion() {
        let mut runner = ExtensionRunner::new();

        fn count_args(
            args: &[Value],
        ) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
            Ok(serde_json::json!(args.len()))
        }

        runner.register_tool("count".to_string(), count_args);

        let result = runner
            .on_tool_call("count", serde_json::json!([1, 2, 3]))
            .unwrap();
        assert_eq!(result.output, 3);

        let result = runner
            .on_tool_call("count", serde_json::json!({"a": 1}))
            .unwrap();
        assert_eq!(result.output, 1);

        let result = runner.on_tool_call("count", Value::Null).unwrap();
        assert_eq!(result.output, 0);

        let result = runner.on_tool_call("count", serde_json::json!(42)).unwrap();
        assert_eq!(result.output, 1);
    }

    #[tokio::test]
    async fn test_map_with_concurrency_limit() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let items: Vec<i32> = (0..10).collect();
        let counter = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let results = map_with_concurrency_limit(items, 3, |n| {
            let counter = Arc::clone(&counter);
            let max_concurrent = Arc::clone(&max_concurrent);
            async move {
                let current = counter.fetch_add(1, Ordering::SeqCst) + 1;
                // Record the highest number of futures observed running at once.
                max_concurrent.fetch_max(current, Ordering::SeqCst);

                tokio::time::sleep(std::time::Duration::from_millis(10)).await;

                counter.fetch_sub(1, Ordering::SeqCst);

                n * 2
            }
        })
        .await;

        assert_eq!(results.len(), 10);

        let mut sorted: Vec<i32> = results;
        sorted.sort();
        assert_eq!(sorted, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);

        assert!(max_concurrent.load(Ordering::SeqCst) <= 3);
    }

    #[tokio::test]
    async fn test_map_with_concurrency_limit_ordered() {
        let items: Vec<i32> = vec![1, 2, 3, 4, 5];

        let results = map_with_concurrency_limit_ordered(items, 2, |n| async move {
            tokio::time::sleep(std::time::Duration::from_millis((5 - n) as u64 * 5)).await;
            n * 10
        })
        .await;

        assert_eq!(results, vec![10, 20, 30, 40, 50]);
    }

    #[test]
    fn test_concurrent_spawn_config_default() {
        let config = ConcurrentSpawnConfig::default();

        assert_eq!(config.max_concurrent, 5);
        assert_eq!(config.timeout_ms, 0);
        assert!(!config.fail_fast);
    }

    #[test]
    fn test_concurrent_spawn_result() {
        let result = ConcurrentSpawnResult {
            successes: vec![AgentResult {
                task_id: "1".to_string(),
                success: true,
                exit_code: Some(0),
                output: "done".to_string(),
                duration_ms: 100,
            }],
            failures: vec![],
            all_succeeded: true,
        };

        assert!(result.all_succeeded);
        assert_eq!(result.successes.len(), 1);
        assert!(result.failures.is_empty());
    }

    #[test]
    fn test_concurrent_spawn_result_with_failures() {
        let result = ConcurrentSpawnResult {
            successes: vec![],
            failures: vec![("task1".to_string(), "error msg".to_string())],
            all_succeeded: false,
        };

        assert!(!result.all_succeeded);
        assert!(result.successes.is_empty());
        assert_eq!(result.failures.len(), 1);
        assert_eq!(result.failures[0].0, "task1");
    }
}