use super::AgentChain;
use super::AgentInvocation;
use super::ExecutionPlan;
use super::ExecutionStep;
use super::InvocationRequest;
use super::SubagentContext;
use super::SubagentError;
use super::SubagentExecution;
use super::SubagentRegistry;
use super::SubagentStatus;
use crate::modes::OperatingMode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::time::SystemTime;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::Semaphore;
use tokio::sync::mpsc;
use tokio::time::sleep;
use tokio::time::timeout;
use tracing::debug;
use tracing::info;
use tracing::warn;
use uuid::Uuid;

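// Default tuning values; each can be overridden through `OrchestratorConfig`.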
const DEFAULT_MAX_CONCURRENCY: usize = 8;

const DEFAULT_AGENT_TIMEOUT: Duration = Duration::from_secs(300);

const MAX_RETRIES: u32 = 3;

const RETRY_BACKOFF: Duration = Duration::from_secs(2);

const CIRCUIT_BREAKER_THRESHOLD: u32 = 5;

const CIRCUIT_BREAKER_RESET: Duration = Duration::from_secs(60);

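/// Tunable settings controlling concurrency, timeouts, retries, circuit
/// breaking, and memory monitoring for the orchestrator.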
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestratorConfig {
    pub max_concurrency: usize,

    pub agent_timeout: Duration,

    pub enable_retries: bool,

    pub max_retries: u32,

    pub retry_backoff: Duration,

    pub enable_circuit_breaker: bool,

    pub circuit_breaker_threshold: u32,

    pub circuit_breaker_reset: Duration,

    pub monitor_memory: bool,

    pub memory_threshold_mb: usize,
}

impl Default for OrchestratorConfig {
    fn default() -> Self {
        Self {
            max_concurrency: DEFAULT_MAX_CONCURRENCY,
            agent_timeout: DEFAULT_AGENT_TIMEOUT,
            enable_retries: true,
            max_retries: MAX_RETRIES,
            retry_backoff: RETRY_BACKOFF,
            enable_circuit_breaker: true,
            circuit_breaker_threshold: CIRCUIT_BREAKER_THRESHOLD,
            circuit_breaker_reset: CIRCUIT_BREAKER_RESET,
            monitor_memory: true,
            memory_threshold_mb: 2048,
        }
    }
}

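/// State shared by all agents in one orchestrated run. Cloning is cheap:
/// every field lives behind an `Arc`, so clones observe the same data.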
#[derive(Debug, Clone)]
pub struct SharedContext {
    data: Arc<RwLock<HashMap<String, serde_json::Value>>>,

    previous_outputs: Arc<RwLock<Vec<String>>>,

    modified_files: Arc<RwLock<Vec<PathBuf>>>,

    errors: Arc<RwLock<Vec<String>>>,
}

impl Default for SharedContext {
    fn default() -> Self {
        Self::new()
    }
}

impl SharedContext {
    pub fn new() -> Self {
        Self {
            data: Arc::new(RwLock::new(HashMap::new())),
            previous_outputs: Arc::new(RwLock::new(Vec::new())),
            modified_files: Arc::new(RwLock::new(Vec::new())),
            errors: Arc::new(RwLock::new(Vec::new())),
        }
    }

    pub async fn set(&self, key: String, value: serde_json::Value) {
        self.data.write().await.insert(key, value);
    }

    pub async fn get(&self, key: &str) -> Option<serde_json::Value> {
        self.data.read().await.get(key).cloned()
    }

    pub async fn add_output(&self, output: String) {
        self.previous_outputs.write().await.push(output);
    }

    pub async fn last_output(&self) -> Option<String> {
        self.previous_outputs.read().await.last().cloned()
    }

    pub async fn all_outputs(&self) -> Vec<String> {
        self.previous_outputs.read().await.clone()
    }

    pub async fn add_modified_files(&self, files: Vec<PathBuf>) {
        self.modified_files.write().await.extend(files);
    }

    pub async fn modified_files(&self) -> Vec<PathBuf> {
        self.modified_files.read().await.clone()
    }

    pub async fn add_error(&self, error: String) {
        self.errors.write().await.push(error);
    }

    pub async fn errors(&self) -> Vec<String> {
        self.errors.read().await.clone()
    }

    pub async fn snapshot(&self) -> ContextSnapshot {
        ContextSnapshot {
            data: self.data.read().await.clone(),
            previous_outputs: self.previous_outputs.read().await.clone(),
            modified_files: self.modified_files.read().await.clone(),
            errors: self.errors.read().await.clone(),
            timestamp: SystemTime::now(),
        }
    }

    pub async fn restore(&self, snapshot: ContextSnapshot) {
        *self.data.write().await = snapshot.data;
        *self.previous_outputs.write().await = snapshot.previous_outputs;
        *self.modified_files.write().await = snapshot.modified_files;
        *self.errors.write().await = snapshot.errors;
    }

    pub async fn merge(&self, other: &SharedContext) {
        // Clone the other context's state first so its locks are released
        // before ours are acquired (the original held `other`'s read lock
        // while taking `self`'s write lock, which deadlocks if the two
        // contexts share the same `Arc`s), and take each write lock once
        // instead of once per entry.
        let other_data = other.data.read().await.clone();
        let other_outputs = other.previous_outputs.read().await.clone();
        let other_files = other.modified_files.read().await.clone();
        let other_errors = other.errors.read().await.clone();

        self.data.write().await.extend(other_data);
        self.previous_outputs.write().await.extend(other_outputs);
        self.modified_files.write().await.extend(other_files);
        self.errors.write().await.extend(other_errors);
    }
}

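/// Point-in-time copy of a `SharedContext`, used to roll back via `restore`.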
#[derive(Debug, Clone)]
pub struct ContextSnapshot {
    pub data: HashMap<String, serde_json::Value>,
    pub previous_outputs: Vec<String>,
    pub modified_files: Vec<PathBuf>,
    pub errors: Vec<String>,
    pub timestamp: SystemTime,
}

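/// Progress event emitted over the orchestrator's progress channel.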
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
    pub execution_id: Uuid,
    pub agent_name: String,
    pub status: SubagentStatus,
    pub message: Option<String>,
    pub progress_percentage: Option<u8>,
    pub timestamp: SystemTime,
}

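/// Outcome of a full plan execution: `success` means every agent completed,
/// `partial_success` means at least one did.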
#[derive(Debug)]
pub struct OrchestratorResult {
    pub request: InvocationRequest,

    pub executions: Vec<SubagentExecution>,

    pub context: SharedContext,

    pub total_duration: Duration,

    pub success: bool,

    pub partial_success: bool,
}

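/// Per-agent circuit breaker: opens after `threshold` consecutive failures
/// and resets once `reset_duration` has elapsed since the last failure.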
#[derive(Debug)]
struct CircuitBreaker {
    failure_count: AtomicU32,
    is_open: AtomicBool,
    last_failure_time: Arc<Mutex<Option<SystemTime>>>,
    threshold: u32,
    reset_duration: Duration,
}

impl CircuitBreaker {
    fn new(threshold: u32, reset_duration: Duration) -> Self {
        Self {
            failure_count: AtomicU32::new(0),
            is_open: AtomicBool::new(false),
            last_failure_time: Arc::new(Mutex::new(None)),
            threshold,
            reset_duration,
        }
    }

    async fn record_success(&self) {
        self.failure_count.store(0, Ordering::SeqCst);
        self.is_open.store(false, Ordering::SeqCst);
        *self.last_failure_time.lock().await = None;
    }

    async fn record_failure(&self) {
        let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
        *self.last_failure_time.lock().await = Some(SystemTime::now());

        if count >= self.threshold {
            self.is_open.store(true, Ordering::SeqCst);
            warn!(
                "Circuit breaker opened after {} consecutive failures",
                count
            );
        }
    }

    async fn is_open(&self) -> bool {
        if !self.is_open.load(Ordering::SeqCst) {
            return false;
        }

        if let Some(last_failure) = *self.last_failure_time.lock().await
            && let Ok(elapsed) = SystemTime::now().duration_since(last_failure)
            && elapsed > self.reset_duration
        {
            self.is_open.store(false, Ordering::SeqCst);
            self.failure_count.store(0, Ordering::SeqCst);
            info!("Circuit breaker reset after {:?}", elapsed);
            return false;
        }

        true
    }
}

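/// Coordinates subagent execution: enforces a concurrency limit, applies
/// timeouts and retries, and trips per-agent circuit breakers.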
#[derive(Debug)]
pub struct AgentOrchestrator {
    config: OrchestratorConfig,

    registry: Arc<SubagentRegistry>,

    concurrency_limiter: Arc<Semaphore>,

    circuit_breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,

    progress_tx: mpsc::UnboundedSender<ProgressUpdate>,

    progress_rx: Arc<Mutex<mpsc::UnboundedReceiver<ProgressUpdate>>>,

    cancelled: Arc<AtomicBool>,

    active_executions: Arc<AtomicUsize>,

    operating_mode: OperatingMode,
}

impl AgentOrchestrator {
    pub fn new(
        registry: Arc<SubagentRegistry>,
        config: OrchestratorConfig,
        operating_mode: OperatingMode,
    ) -> Self {
        let (progress_tx, progress_rx) = mpsc::unbounded_channel();

        Self {
            config: config.clone(),
            registry,
            concurrency_limiter: Arc::new(Semaphore::new(config.max_concurrency)),
            circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
            progress_tx,
            progress_rx: Arc::new(Mutex::new(progress_rx)),
            cancelled: Arc::new(AtomicBool::new(false)),
            active_executions: Arc::new(AtomicUsize::new(0)),
            operating_mode,
        }
    }

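    /// Execute every step of the request's `ExecutionPlan` and aggregate the
    /// results into an `OrchestratorResult`.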
    pub async fn execute_plan(
        &self,
        request: InvocationRequest,
    ) -> Result<OrchestratorResult, SubagentError> {
        let start_time = SystemTime::now();
        info!("Starting orchestrator execution for request {}", request.id);

        if self.cancelled.load(Ordering::SeqCst) {
            return Err(SubagentError::ExecutionFailed(
                "Execution cancelled".to_string(),
            ));
        }

        if self.config.monitor_memory {
            self.check_memory_pressure().await?;
        }

        let context = SharedContext::new();

        let executions = match &request.execution_plan {
            ExecutionPlan::Single(invocation) => {
                vec![self.execute_single(invocation.clone(), &context).await?]
            }
            ExecutionPlan::Sequential(chain) => {
                self.execute_sequential(chain.clone(), &context).await?
            }
            ExecutionPlan::Parallel(invocations) => {
                self.execute_parallel(invocations.clone(), &context).await?
            }
            ExecutionPlan::Mixed(steps) => self.execute_mixed(steps.clone(), &context).await?,
            ExecutionPlan::Conditional(cond) => {
                let mut results = Vec::new();
                for agent in &cond.agents {
                    results.push(self.execute_single(agent.clone(), &context).await?);
                }
                results
            }
        };

        let total_duration = SystemTime::now()
            .duration_since(start_time)
            .unwrap_or_default();

        let success = executions
            .iter()
            .all(|e| e.status == SubagentStatus::Completed);
        let partial_success = executions
            .iter()
            .any(|e| e.status == SubagentStatus::Completed);

        info!(
            "Orchestrator execution completed for request {} in {:?} (success: {}, partial: {})",
            request.id, total_duration, success, partial_success
        );

        Ok(OrchestratorResult {
            request,
            executions,
            context,
            total_duration,
            success,
            partial_success,
        })
    }

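    /// Run a single agent invocation, guarded by the circuit breaker and the
    /// concurrency semaphore, with retry handling.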
    pub async fn execute_single(
        &self,
        invocation: AgentInvocation,
        shared_context: &SharedContext,
    ) -> Result<SubagentExecution, SubagentError> {
        if self.config.enable_circuit_breaker {
            let breaker = self
                .get_or_create_circuit_breaker(&invocation.agent_name)
                .await;
            if breaker.is_open().await {
                return Err(SubagentError::ExecutionFailed(format!(
                    "Circuit breaker open for agent {}",
                    invocation.agent_name
                )));
            }
        }

        let _permit = self.concurrency_limiter.acquire().await.map_err(|e| {
            SubagentError::ExecutionFailed(format!("Failed to acquire permit: {}", e))
        })?;

        // Keep the agent name so failures are recorded against the right
        // breaker; deriving it from the result maps every `Err` to an
        // empty-string breaker key.
        let agent_name = invocation.agent_name.clone();

        self.active_executions.fetch_add(1, Ordering::SeqCst);
        let result = self.execute_with_retry(invocation, shared_context).await;
        self.active_executions.fetch_sub(1, Ordering::SeqCst);

        if self.config.enable_circuit_breaker {
            let breaker = self.get_or_create_circuit_breaker(&agent_name).await;

            match &result {
                Ok(execution) if execution.status == SubagentStatus::Completed => {
                    breaker.record_success().await;
                }
                _ => {
                    breaker.record_failure().await;
                }
            }
        }

        result
    }

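    /// Run a chain of agents in order, optionally feeding each agent's output
    /// into the shared context for the next one.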
    pub async fn execute_sequential(
        &self,
        chain: AgentChain,
        shared_context: &SharedContext,
    ) -> Result<Vec<SubagentExecution>, SubagentError> {
        let mut executions = Vec::new();

        for invocation in chain.agents {
            if self.cancelled.load(Ordering::SeqCst) {
                warn!("Sequential execution cancelled");
                break;
            }

            let execution = self.execute_single(invocation, shared_context).await?;

            if chain.pass_output
                && let Some(output) = &execution.output
            {
                shared_context.add_output(output.clone()).await;
            }

            shared_context
                .add_modified_files(execution.modified_files.clone())
                .await;

            executions.push(execution);
        }

        Ok(executions)
    }

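    /// Run a set of invocations concurrently on separate tasks; succeeds if
    /// at least one execution completes.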
    pub async fn execute_parallel(
        &self,
        invocations: Vec<AgentInvocation>,
        shared_context: &SharedContext,
    ) -> Result<Vec<SubagentExecution>, SubagentError> {
        let mut tasks = Vec::new();

        for invocation in invocations {
            let self_clone = self.clone_for_task();
            let context_clone = shared_context.clone();

            let task = tokio::spawn(async move {
                self_clone.execute_single(invocation, &context_clone).await
            });

            tasks.push(task);
        }

        let mut executions = Vec::new();
        let mut errors = Vec::new();

        for task in tasks {
            match task.await {
                Ok(Ok(execution)) => {
                    shared_context
                        .add_modified_files(execution.modified_files.clone())
                        .await;
                    executions.push(execution);
                }
                Ok(Err(e)) => {
                    errors.push(e.to_string());
                    shared_context.add_error(e.to_string()).await;
                }
                Err(e) => {
                    let error = format!("Task join error: {}", e);
                    errors.push(error.clone());
                    shared_context.add_error(error).await;
                }
            }
        }

        if !executions.is_empty() {
            Ok(executions)
        } else if !errors.is_empty() {
            Err(SubagentError::ExecutionFailed(format!(
                "All parallel executions failed: {}",
                errors.join(", ")
            )))
        } else {
            Err(SubagentError::ExecutionFailed(
                "No executions completed".to_string(),
            ))
        }
    }

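    /// Walk a mixed plan step by step; a `Barrier` step waits until all
    /// in-flight executions have drained before continuing.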
    pub async fn execute_mixed(
        &self,
        steps: Vec<ExecutionStep>,
        shared_context: &SharedContext,
    ) -> Result<Vec<SubagentExecution>, SubagentError> {
        let mut all_executions = Vec::new();

        for step in steps {
            if self.cancelled.load(Ordering::SeqCst) {
                warn!("Mixed execution cancelled");
                break;
            }

            match step {
                ExecutionStep::Single(invocation) => {
                    let execution = self.execute_single(invocation, shared_context).await?;
                    all_executions.push(execution);
                }
                ExecutionStep::Parallel(invocations) => {
                    let executions = self.execute_parallel(invocations, shared_context).await?;
                    all_executions.extend(executions);
                }
                ExecutionStep::Conditional(cond) => {
                    for agent in cond.agents {
                        let execution = self.execute_single(agent, shared_context).await?;
                        all_executions.push(execution);
                    }
                }
                ExecutionStep::Barrier => {
                    while self.active_executions.load(Ordering::SeqCst) > 0 {
                        sleep(Duration::from_millis(100)).await;
                    }
                    debug!("Barrier: All previous executions completed");
                }
            }
        }

        Ok(all_executions)
    }

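    /// Drive one invocation through the timeout/retry loop, emitting progress
    /// updates along the way. Backoff grows linearly with the attempt count.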
    async fn execute_with_retry(
        &self,
        invocation: AgentInvocation,
        shared_context: &SharedContext,
    ) -> Result<SubagentExecution, SubagentError> {
        let mut attempts = 0;
        let max_attempts = if self.config.enable_retries {
            self.config.max_retries + 1
        } else {
            1
        };

        loop {
            attempts += 1;

            let mut execution = SubagentExecution::new(invocation.agent_name.clone());

            self.send_progress(ProgressUpdate {
                execution_id: execution.id,
                agent_name: invocation.agent_name.clone(),
                status: SubagentStatus::Running,
                message: Some(format!(
                    "Starting execution (attempt {}/{})",
                    attempts, max_attempts
                )),
                progress_percentage: Some(0),
                timestamp: SystemTime::now(),
            })
            .await;

            let result = timeout(
                self.config.agent_timeout,
                self.execute_agent_internal(&invocation, shared_context, &mut execution),
            )
            .await;

            match result {
                Ok(Ok(())) => {
                    self.send_progress(ProgressUpdate {
                        execution_id: execution.id,
                        agent_name: invocation.agent_name.clone(),
                        status: SubagentStatus::Completed,
                        message: Some("Execution completed successfully".to_string()),
                        progress_percentage: Some(100),
                        timestamp: SystemTime::now(),
                    })
                    .await;

                    return Ok(execution);
                }
                Ok(Err(e)) if attempts < max_attempts && self.is_retriable_error(&e) => {
                    warn!(
                        "Agent {} failed with retriable error (attempt {}/{}): {}",
                        invocation.agent_name, attempts, max_attempts, e
                    );

                    self.send_progress(ProgressUpdate {
                        execution_id: execution.id,
                        agent_name: invocation.agent_name.clone(),
                        status: SubagentStatus::Running,
                        message: Some(format!("Retrying after error: {}", e)),
                        progress_percentage: None,
                        timestamp: SystemTime::now(),
                    })
                    .await;

                    sleep(self.config.retry_backoff * attempts).await;
                    continue;
                }
                Ok(Err(e)) => {
                    execution.fail(e.to_string());

                    self.send_progress(ProgressUpdate {
                        execution_id: execution.id,
                        agent_name: invocation.agent_name.clone(),
                        status: SubagentStatus::Failed(e.to_string()),
                        message: Some(format!("Execution failed: {}", e)),
                        progress_percentage: None,
                        timestamp: SystemTime::now(),
                    })
                    .await;

                    return Err(e);
                }
                Err(_) => {
                    let error = format!(
                        "Agent {} timed out after {:?}",
                        invocation.agent_name, self.config.agent_timeout
                    );
                    execution.fail(error.clone());

                    self.send_progress(ProgressUpdate {
                        execution_id: execution.id,
                        agent_name: invocation.agent_name.clone(),
                        status: SubagentStatus::Failed(error.clone()),
                        message: Some(error.clone()),
                        progress_percentage: None,
                        timestamp: SystemTime::now(),
                    })
                    .await;

                    if attempts < max_attempts {
                        warn!(
                            "Agent {} timed out (attempt {}/{})",
                            invocation.agent_name, attempts, max_attempts
                        );
                        sleep(self.config.retry_backoff * attempts).await;
                        continue;
                    }

                    return Err(SubagentError::Timeout {
                        name: invocation.agent_name,
                    });
                }
            }
        }
    }

    async fn execute_agent_internal(
        &self,
        invocation: &AgentInvocation,
        shared_context: &SharedContext,
        execution: &mut SubagentExecution,
    ) -> Result<(), SubagentError> {
        execution.start();

        let _agent = self
            .registry
            .get_agent(&invocation.agent_name)
            .ok_or_else(|| SubagentError::AgentNotFound {
                name: invocation.agent_name.clone(),
            })?;

        let agent_context = SubagentContext {
            execution_id: execution.id,
            mode: self.operating_mode,
            available_tools: vec![],
            conversation_context: shared_context.last_output().await.unwrap_or_default(),
            working_directory: std::env::current_dir().unwrap_or_default(),
            parameters: invocation.parameters.clone(),
            metadata: HashMap::new(),
        };

        info!(
            "Executing agent {} with context: {:?}",
            invocation.agent_name, agent_context
        );

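        // Placeholder execution: the resolved agent (`_agent`) is not invoked
        // yet; instead we simulate ten units of work so progress reporting,
        // cancellation, and timeout paths can be exercised end to end.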
        for i in 1..=10 {
            if self.cancelled.load(Ordering::SeqCst) {
                return Err(SubagentError::ExecutionFailed(
                    "Execution cancelled".to_string(),
                ));
            }

            self.send_progress(ProgressUpdate {
                execution_id: execution.id,
                agent_name: invocation.agent_name.clone(),
                status: SubagentStatus::Running,
                message: Some(format!("Processing step {}/10", i)),
                progress_percentage: Some((i * 10) as u8),
                timestamp: SystemTime::now(),
            })
            .await;

            sleep(Duration::from_millis(100)).await;
        }

        let output = format!(
            "Agent {} completed successfully with parameters: {:?}",
            invocation.agent_name, invocation.parameters
        );

        execution.complete(output, vec![]);

        Ok(())
    }

    const fn is_retriable_error(&self, error: &SubagentError) -> bool {
        matches!(
            error,
            SubagentError::Timeout { .. }
                | SubagentError::ExecutionFailed(_)
                | SubagentError::Io(_)
        )
    }

    async fn get_or_create_circuit_breaker(&self, agent_name: &str) -> Arc<CircuitBreaker> {
        let mut breakers = self.circuit_breakers.write().await;

        breakers
            .entry(agent_name.to_string())
            .or_insert_with(|| {
                Arc::new(CircuitBreaker::new(
                    self.config.circuit_breaker_threshold,
                    self.config.circuit_breaker_reset,
                ))
            })
            .clone()
    }

    async fn check_memory_pressure(&self) -> Result<(), SubagentError> {
        // Placeholder estimate; a real implementation would query the
        // process's resident set size.
        let memory_usage_mb = 500;

        if memory_usage_mb > self.config.memory_threshold_mb {
            Err(SubagentError::ExecutionFailed(format!(
                "Memory pressure too high: {}MB > {}MB threshold",
                memory_usage_mb, self.config.memory_threshold_mb
            )))
        } else {
            Ok(())
        }
    }

    async fn send_progress(&self, update: ProgressUpdate) {
        if let Err(e) = self.progress_tx.send(update) {
            warn!("Failed to send progress update: {}", e);
        }
    }

    /// Hand out the shared receiver for progress updates. The receiver is
    /// created once in `new`; the previous version returned a freshly created,
    /// immediately disconnected channel, which silently dropped every update.
    pub fn progress_receiver(&self) -> Arc<Mutex<mpsc::UnboundedReceiver<ProgressUpdate>>> {
        self.progress_rx.clone()
    }

    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::SeqCst);
        info!("Orchestrator execution cancelled");
    }

    pub fn reset_cancellation(&self) {
        self.cancelled.store(false, Ordering::SeqCst);
    }

    pub fn active_count(&self) -> usize {
        self.active_executions.load(Ordering::SeqCst)
    }

    fn clone_for_task(&self) -> Arc<Self> {
        Arc::new(Self {
            config: self.config.clone(),
            registry: self.registry.clone(),
            concurrency_limiter: self.concurrency_limiter.clone(),
            circuit_breakers: self.circuit_breakers.clone(),
            progress_tx: self.progress_tx.clone(),
            progress_rx: self.progress_rx.clone(),
            cancelled: self.cancelled.clone(),
            active_executions: self.active_executions.clone(),
            operating_mode: self.operating_mode,
        })
    }
}

impl AgentOrchestrator {
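    /// Run the invocation only if `condition` evaluates to `true` against the
    /// shared context; returns `None` when the agent is skipped.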
    pub async fn execute_conditional<F>(
        &self,
        invocation: AgentInvocation,
        shared_context: &SharedContext,
        condition: F,
    ) -> Result<Option<SubagentExecution>, SubagentError>
    where
        F: Fn(&SharedContext) -> futures::future::BoxFuture<'_, bool> + Send + Sync,
    {
        if condition(shared_context).await {
            Ok(Some(self.execute_single(invocation, shared_context).await?))
        } else {
            info!("Skipping agent {} due to condition", invocation.agent_name);
            Ok(None)
        }
    }

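    /// Run the invocation only after verifying that every named dependency
    /// appears in an earlier agent's output.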
    pub async fn execute_with_dependencies(
        &self,
        invocation: AgentInvocation,
        dependencies: Vec<String>,
        shared_context: &SharedContext,
    ) -> Result<SubagentExecution, SubagentError> {
        let outputs = shared_context.all_outputs().await;

        for dep in dependencies {
            if !outputs.iter().any(|o| o.contains(&dep)) {
                return Err(SubagentError::ExecutionFailed(format!(
                    "Dependency {} not satisfied for agent {}",
                    dep, invocation.agent_name
                )));
            }
        }

        self.execute_single(invocation, shared_context).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::subagents::config::SubagentConfig;

    async fn create_test_orchestrator() -> AgentOrchestrator {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let home_dir = temp_dir.path().join(".agcodex");
        std::fs::create_dir_all(&home_dir).unwrap();

        // Point HOME at the temp dir so the registry loads agents from the
        // throwaway `.agcodex` tree instead of the real home directory.
        unsafe {
            std::env::set_var("HOME", temp_dir.path());
        }

        let registry = Arc::new(SubagentRegistry::new().unwrap());

        let test_agent = SubagentConfig {
            name: "test-agent".to_string(),
            description: "Test agent".to_string(),
            mode_override: None,
            intelligence: crate::subagents::config::IntelligenceLevel::Medium,
            tools: std::collections::HashMap::new(),
            prompt: "Test prompt".to_string(),
            parameters: vec![],
            template: None,
            timeout_seconds: 10,
            chainable: true,
            parallelizable: true,
            metadata: std::collections::HashMap::new(),
            file_patterns: vec![],
            tags: vec![],
        };

        let global_agents_dir = temp_dir
            .path()
            .join(".agcodex")
            .join("agents")
            .join("global");
        std::fs::create_dir_all(&global_agents_dir).unwrap();
        let config_path = global_agents_dir.join("test-agent.toml");
        test_agent.to_file(&config_path).unwrap();

        registry.load_all().unwrap();

        AgentOrchestrator::new(
            registry,
            OrchestratorConfig::default(),
            OperatingMode::Build,
        )
    }

    #[tokio::test]
    async fn test_single_agent_execution() {
        let orchestrator = create_test_orchestrator().await;

        let invocation = AgentInvocation {
            agent_name: "test-agent".to_string(),
            parameters: HashMap::new(),
            raw_parameters: String::new(),
            position: 0,
            intelligence_override: None,
            mode_override: None,
        };

        let context = SharedContext::new();
        let result = orchestrator.execute_single(invocation, &context).await;

        assert!(result.is_ok());
        let execution = result.unwrap();
        assert_eq!(execution.status, SubagentStatus::Completed);
    }

    #[tokio::test]
    async fn test_sequential_execution() {
        let orchestrator = create_test_orchestrator().await;

        let chain = AgentChain {
            agents: vec![
                AgentInvocation {
                    agent_name: "test-agent".to_string(),
                    parameters: HashMap::new(),
                    raw_parameters: String::new(),
                    position: 0,
                    intelligence_override: None,
                    mode_override: None,
                },
                AgentInvocation {
                    agent_name: "test-agent".to_string(),
                    parameters: HashMap::new(),
                    raw_parameters: String::new(),
                    position: 1,
                    intelligence_override: None,
                    mode_override: None,
                },
            ],
            pass_output: true,
        };

        let context = SharedContext::new();
        let result = orchestrator.execute_sequential(chain, &context).await;

        assert!(result.is_ok());
        let executions = result.unwrap();
        assert_eq!(executions.len(), 2);
        assert!(
            executions
                .iter()
                .all(|e| e.status == SubagentStatus::Completed)
        );
    }

    #[tokio::test]
    async fn test_parallel_execution() {
        let orchestrator = create_test_orchestrator().await;

        let invocations = vec![
            AgentInvocation {
                agent_name: "test-agent".to_string(),
                parameters: HashMap::new(),
                raw_parameters: String::new(),
                position: 0,
                intelligence_override: None,
                mode_override: None,
            },
            AgentInvocation {
                agent_name: "test-agent".to_string(),
                parameters: HashMap::new(),
                raw_parameters: String::new(),
                position: 1,
                intelligence_override: None,
                mode_override: None,
            },
        ];

        let context = SharedContext::new();
        let result = orchestrator.execute_parallel(invocations, &context).await;

        assert!(result.is_ok());
        let executions = result.unwrap();
        assert_eq!(executions.len(), 2);
    }

    #[tokio::test]
    async fn test_context_sharing() {
        let context = SharedContext::new();

        context
            .set("key1".to_string(), serde_json::json!("value1"))
            .await;
        let value = context.get("key1").await;
        assert_eq!(value, Some(serde_json::json!("value1")));

        context.add_output("output1".to_string()).await;
        context.add_output("output2".to_string()).await;
        assert_eq!(context.last_output().await, Some("output2".to_string()));
        assert_eq!(context.all_outputs().await.len(), 2);

        let snapshot = context.snapshot().await;
        context
            .set("key2".to_string(), serde_json::json!("value2"))
            .await;
        assert!(context.get("key2").await.is_some());

        context.restore(snapshot).await;
        assert!(context.get("key2").await.is_none());
        assert_eq!(context.get("key1").await, Some(serde_json::json!("value1")));
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(1));

        for _ in 0..3 {
            breaker.record_failure().await;
        }

        assert!(breaker.is_open().await);

        tokio::time::sleep(Duration::from_secs(2)).await;

        assert!(!breaker.is_open().await);

        breaker.record_success().await;
        assert!(!breaker.is_open().await);
    }

    #[tokio::test]
    async fn test_cancellation() {
        let orchestrator = create_test_orchestrator().await;

        orchestrator.cancel();

        let invocation = AgentInvocation {
            agent_name: "test-agent".to_string(),
            parameters: HashMap::new(),
            raw_parameters: String::new(),
            position: 0,
            intelligence_override: None,
            mode_override: None,
        };

        let context = SharedContext::new();
        let result = orchestrator.execute_single(invocation, &context).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            SubagentError::ExecutionFailed(_)
        ));
    }
}