1use std::sync::Arc;
9
10use serde::{Deserialize, Serialize};
11use tokio::task::JoinSet;
12
13use crate::error::Error;
14use crate::llm::LlmProvider;
15use crate::llm::types::TokenUsage;
16
17use super::AgentOutput;
18use super::AgentRunner;
19use super::dag::DagAgent;
20use super::debate::DebateAgent;
21use super::mixture::MixtureOfAgentsAgent;
22use super::voting::VotingAgent;
23
/// Predicate a [`LoopAgent`] evaluates on the text result of each iteration;
/// returning `true` ends the loop early. It is `Send + Sync` so the owning
/// agent can be shared across threads.
type StopCondition = Box<dyn Send + Sync + Fn(&str) -> bool>;
26
27pub struct SequentialAgent<P: LlmProvider> {
35 agents: Vec<AgentRunner<P>>,
36}
37
38impl<P: LlmProvider> std::fmt::Debug for SequentialAgent<P> {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("SequentialAgent")
41 .field("agent_count", &self.agents.len())
42 .finish()
43 }
44}
45
46pub struct SequentialAgentBuilder<P: LlmProvider> {
48 agents: Vec<AgentRunner<P>>,
49}
50
51impl<P: LlmProvider> SequentialAgent<P> {
52 pub fn builder() -> SequentialAgentBuilder<P> {
86 SequentialAgentBuilder { agents: Vec::new() }
87 }
88
89 pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
92 let mut current_input = task.to_string();
93 let mut total_usage = TokenUsage::default();
94 let mut total_tool_calls = 0usize;
95 let mut total_cost: Option<f64> = None;
96 let mut last_output: Option<AgentOutput> = None;
97
98 for agent in &self.agents {
99 let result = agent
100 .execute(¤t_input)
101 .await
102 .map_err(|e| e.accumulate_usage(total_usage))?;
103 result.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
104 current_input = result.result.clone();
105 last_output = Some(result);
106 }
107
108 let mut output = last_output.expect("at least one agent");
110 output.tokens_used = total_usage;
111 output.tool_calls_made = total_tool_calls;
112 output.estimated_cost_usd = total_cost;
113 Ok(output)
114 }
115}
116
117impl<P: LlmProvider> SequentialAgentBuilder<P> {
118 pub fn agent(mut self, agent: AgentRunner<P>) -> Self {
120 self.agents.push(agent);
121 self
122 }
123
124 pub fn agents(mut self, agents: Vec<AgentRunner<P>>) -> Self {
126 self.agents.extend(agents);
127 self
128 }
129
130 pub fn build(self) -> Result<SequentialAgent<P>, Error> {
132 if self.agents.is_empty() {
133 return Err(Error::Config(
134 "SequentialAgent requires at least one agent".into(),
135 ));
136 }
137 Ok(SequentialAgent {
138 agents: self.agents,
139 })
140 }
141}
142
143pub struct ParallelAgent<P: LlmProvider + 'static> {
151 agents: Vec<Arc<AgentRunner<P>>>,
152}
153
154impl<P: LlmProvider + 'static> std::fmt::Debug for ParallelAgent<P> {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 f.debug_struct("ParallelAgent")
157 .field("agent_count", &self.agents.len())
158 .finish()
159 }
160}
161
162pub struct ParallelAgentBuilder<P: LlmProvider + 'static> {
164 agents: Vec<Arc<AgentRunner<P>>>,
165}
166
167impl<P: LlmProvider + 'static> ParallelAgent<P> {
168 pub fn builder() -> ParallelAgentBuilder<P> {
170 ParallelAgentBuilder { agents: Vec::new() }
171 }
172
173 pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
175 let mut set = JoinSet::new();
176
177 for agent in &self.agents {
178 let agent = Arc::clone(agent);
179 let task = task.to_string();
180 set.spawn(async move {
181 let name = agent.name().to_string();
182 let result = agent.execute(&task).await;
183 (name, result)
184 });
185 }
186
187 let mut results: Vec<(String, AgentOutput)> = Vec::with_capacity(self.agents.len());
188 let mut total_usage = TokenUsage::default();
189 let mut total_tool_calls = 0usize;
190 let mut total_cost: Option<f64> = None;
191
192 while let Some(join_result) = set.join_next().await {
193 let (name, agent_result) = join_result
194 .map_err(|e| Error::Agent(format!("parallel agent task panicked: {e}")))?;
195 let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
196 output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
197 results.push((name, output));
198 }
199
200 results.sort_by(|a, b| a.0.cmp(&b.0));
202
203 let merged_text = results
204 .iter()
205 .map(|(name, output)| format!("## {name}\n{}", output.result))
206 .collect::<Vec<_>>()
207 .join("\n\n");
208
209 Ok(AgentOutput {
210 result: merged_text,
211 tool_calls_made: total_tool_calls,
212 tokens_used: total_usage,
213 structured: None,
214 estimated_cost_usd: total_cost,
215 model_name: None,
216 })
217 }
218}
219
220impl<P: LlmProvider + 'static> ParallelAgentBuilder<P> {
221 pub fn agent(mut self, agent: AgentRunner<P>) -> Self {
223 self.agents.push(Arc::new(agent));
224 self
225 }
226
227 pub fn agents(mut self, agents: Vec<AgentRunner<P>>) -> Self {
229 self.agents.extend(agents.into_iter().map(Arc::new));
230 self
231 }
232
233 pub fn build(self) -> Result<ParallelAgent<P>, Error> {
235 if self.agents.is_empty() {
236 return Err(Error::Config(
237 "ParallelAgent requires at least one agent".into(),
238 ));
239 }
240 Ok(ParallelAgent {
241 agents: self.agents,
242 })
243 }
244}
245
246pub struct LoopAgent<P: LlmProvider> {
254 agent: AgentRunner<P>,
255 max_iterations: usize,
256 should_stop: StopCondition,
257}
258
259impl<P: LlmProvider> std::fmt::Debug for LoopAgent<P> {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 f.debug_struct("LoopAgent")
262 .field("max_iterations", &self.max_iterations)
263 .finish()
264 }
265}
266
267pub struct LoopAgentBuilder<P: LlmProvider> {
269 agent: Option<AgentRunner<P>>,
270 max_iterations: Option<usize>,
271 should_stop: Option<StopCondition>,
272}
273
274impl<P: LlmProvider> LoopAgent<P> {
275 pub fn builder() -> LoopAgentBuilder<P> {
277 LoopAgentBuilder {
278 agent: None,
279 max_iterations: None,
280 should_stop: None,
281 }
282 }
283
284 pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
286 let mut current_input = task.to_string();
287 let mut total_usage = TokenUsage::default();
288 let mut total_tool_calls = 0usize;
289 let mut total_cost: Option<f64> = None;
290 let mut last_output: Option<AgentOutput> = None;
291
292 for _ in 0..self.max_iterations {
293 let result = self
294 .agent
295 .execute(¤t_input)
296 .await
297 .map_err(|e| e.accumulate_usage(total_usage))?;
298 result.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
299 current_input = result.result.clone();
300 let should_stop = (self.should_stop)(&result.result);
301 last_output = Some(result);
302 if should_stop {
303 break;
304 }
305 }
306
307 let mut output = last_output.expect("at least one iteration");
309 output.tokens_used = total_usage;
310 output.tool_calls_made = total_tool_calls;
311 output.estimated_cost_usd = total_cost;
312 Ok(output)
313 }
314}
315
316impl<P: LlmProvider> LoopAgentBuilder<P> {
317 pub fn agent(mut self, agent: AgentRunner<P>) -> Self {
319 self.agent = Some(agent);
320 self
321 }
322
323 pub fn max_iterations(mut self, n: usize) -> Self {
325 self.max_iterations = Some(n);
326 self
327 }
328
329 pub fn should_stop(mut self, f: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
332 self.should_stop = Some(Box::new(f));
333 self
334 }
335
336 pub fn build(self) -> Result<LoopAgent<P>, Error> {
338 let agent = self
339 .agent
340 .ok_or_else(|| Error::Config("LoopAgent requires an agent".into()))?;
341 let max_iterations = self
342 .max_iterations
343 .ok_or_else(|| Error::Config("LoopAgent requires max_iterations".into()))?;
344 if max_iterations == 0 {
345 return Err(Error::Config(
346 "LoopAgent max_iterations must be at least 1".into(),
347 ));
348 }
349 let should_stop = self
350 .should_stop
351 .ok_or_else(|| Error::Config("LoopAgent requires a should_stop condition".into()))?;
352 Ok(LoopAgent {
353 agent,
354 max_iterations,
355 should_stop,
356 })
357 }
358}
359
360#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
366#[serde(rename_all = "snake_case")]
367pub enum WorkflowType {
368 Sequential,
370 Parallel,
372 Loop,
374 Dag,
376 Debate,
378 Voting,
380 Mixture,
382}
383
384pub enum WorkflowRouter<P: LlmProvider + 'static> {
391 Sequential(Box<SequentialAgent<P>>),
393 Parallel(Box<ParallelAgent<P>>),
395 Loop(Box<LoopAgent<P>>),
397 Dag(Box<DagAgent<P>>),
399 Debate(Box<DebateAgent<P>>),
401 Voting(Box<VotingAgent<P>>),
403 Mixture(Box<MixtureOfAgentsAgent<P>>),
405}
406
407impl<P: LlmProvider + 'static> WorkflowRouter<P> {
408 pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
414 match self {
415 Self::Sequential(a) => a.execute(task).await,
416 Self::Parallel(a) => a.execute(task).await,
417 Self::Loop(a) => a.execute(task).await,
418 Self::Dag(a) => a.execute(task).await,
419 Self::Debate(a) => a.execute(task).await,
420 Self::Mixture(a) => a.execute(task).await,
421 Self::Voting(a) => a.execute(task).await.map(|vr| vr.output),
422 }
423 }
424
425 pub fn workflow_type(&self) -> WorkflowType {
427 match self {
428 Self::Sequential(_) => WorkflowType::Sequential,
429 Self::Parallel(_) => WorkflowType::Parallel,
430 Self::Loop(_) => WorkflowType::Loop,
431 Self::Dag(_) => WorkflowType::Dag,
432 Self::Debate(_) => WorkflowType::Debate,
433 Self::Voting(_) => WorkflowType::Voting,
434 Self::Mixture(_) => WorkflowType::Mixture,
435 }
436 }
437}
438
439impl<P: LlmProvider + 'static> std::fmt::Debug for WorkflowRouter<P> {
440 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441 f.debug_tuple("WorkflowRouter")
442 .field(&self.workflow_type())
443 .finish()
444 }
445}
446
447impl<P: LlmProvider + 'static> From<SequentialAgent<P>> for WorkflowRouter<P> {
448 fn from(agent: SequentialAgent<P>) -> Self {
449 Self::Sequential(Box::new(agent))
450 }
451}
452
453impl<P: LlmProvider + 'static> From<ParallelAgent<P>> for WorkflowRouter<P> {
454 fn from(agent: ParallelAgent<P>) -> Self {
455 Self::Parallel(Box::new(agent))
456 }
457}
458
459impl<P: LlmProvider + 'static> From<LoopAgent<P>> for WorkflowRouter<P> {
460 fn from(agent: LoopAgent<P>) -> Self {
461 Self::Loop(Box::new(agent))
462 }
463}
464
465impl<P: LlmProvider + 'static> From<DagAgent<P>> for WorkflowRouter<P> {
466 fn from(agent: DagAgent<P>) -> Self {
467 Self::Dag(Box::new(agent))
468 }
469}
470
471impl<P: LlmProvider + 'static> From<DebateAgent<P>> for WorkflowRouter<P> {
472 fn from(agent: DebateAgent<P>) -> Self {
473 Self::Debate(Box::new(agent))
474 }
475}
476
477impl<P: LlmProvider + 'static> From<VotingAgent<P>> for WorkflowRouter<P> {
478 fn from(agent: VotingAgent<P>) -> Self {
479 Self::Voting(Box::new(agent))
480 }
481}
482
483impl<P: LlmProvider + 'static> From<MixtureOfAgentsAgent<P>> for WorkflowRouter<P> {
484 fn from(agent: MixtureOfAgentsAgent<P>) -> Self {
485 Self::Mixture(Box::new(agent))
486 }
487}
488
#[cfg(test)]
mod tests {
    // Tests for SequentialAgent, ParallelAgent, LoopAgent, WorkflowType and
    // WorkflowRouter. `MockProvider::text_response(text, a, b)` scripts a reply
    // whose text is `text` and which, per the assertions below, reports `a`
    // input tokens and `b` output tokens. A provider built with an empty script
    // is used wherever an agent call is expected to fail.
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};

    // ---- SequentialAgent ----

    // Building with no agents is a configuration error.
    #[test]
    fn sequential_builder_rejects_empty_agents() {
        let result = SequentialAgent::<MockProvider>::builder().build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("at least one agent")
        );
    }

    #[test]
    fn sequential_builder_accepts_one_agent() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "done", 10, 5,
        )]));
        let agent = make_agent(provider, "a");
        let seq = SequentialAgent::builder().agent(agent).build();
        assert!(seq.is_ok());
    }

    // A one-agent chain returns that agent's text and usage unchanged.
    #[tokio::test]
    async fn sequential_single_agent() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "hello world",
            100,
            50,
        )]));
        let agent = make_agent(provider, "step1");
        let seq = SequentialAgent::builder().agent(agent).build().unwrap();

        let output = seq.execute("start").await.unwrap();
        assert_eq!(output.result, "hello world");
        assert_eq!(output.tokens_used.input_tokens, 100);
        assert_eq!(output.tokens_used.output_tokens, 50);
    }

    // The final text comes from the last agent, and usage is summed over both.
    // NOTE(review): this does not inspect what agent-b actually received as
    // input; it only checks the final result and totals.
    #[tokio::test]
    async fn sequential_chains_output_as_input() {
        let provider_a = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "step-a-output",
            100,
            50,
        )]));
        let provider_b = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "step-b-output",
            200,
            80,
        )]));

        let agent_a = make_agent(provider_a, "agent-a");
        let agent_b = make_agent(provider_b, "agent-b");

        let seq = SequentialAgent::builder()
            .agent(agent_a)
            .agent(agent_b)
            .build()
            .unwrap();

        let output = seq.execute("initial task").await.unwrap();
        assert_eq!(output.result, "step-b-output");
        // 100 + 200 input tokens and 50 + 80 output tokens.
        assert_eq!(output.tokens_used.input_tokens, 300);
        assert_eq!(output.tokens_used.output_tokens, 130);
    }

    #[tokio::test]
    async fn sequential_three_agents_accumulates_usage() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "out1", 10, 5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "out2", 20, 10,
        )]));
        let p3 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "out3", 30, 15,
        )]));

        let seq = SequentialAgent::builder()
            .agent(make_agent(p1, "a"))
            .agent(make_agent(p2, "b"))
            .agent(make_agent(p3, "c"))
            .build()
            .unwrap();

        let output = seq.execute("go").await.unwrap();
        assert_eq!(output.result, "out3");
        assert_eq!(output.tokens_used.input_tokens, 60);
        assert_eq!(output.tokens_used.output_tokens, 30);
    }

    // When the second agent fails, the error still reports the first agent's
    // usage.
    #[tokio::test]
    async fn sequential_error_carries_partial_usage() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 100, 50,
        )]));
        // Empty script: the "bad" agent's call is expected to fail.
        let p2 = Arc::new(MockProvider::new(vec![]));

        let seq = SequentialAgent::builder()
            .agent(make_agent(p1, "good"))
            .agent(make_agent(p2, "bad"))
            .build()
            .unwrap();

        let err = seq.execute("task").await.unwrap_err();
        let partial = err.partial_usage();
        assert!(partial.input_tokens >= 100);
    }

    // ---- ParallelAgent ----

    #[test]
    fn parallel_builder_rejects_empty_agents() {
        let result = ParallelAgent::<MockProvider>::builder().build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("at least one agent")
        );
    }

    #[test]
    fn parallel_builder_accepts_one_agent() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 10, 5,
        )]));
        let agent = make_agent(provider, "a");
        let par = ParallelAgent::builder().agent(agent).build();
        assert!(par.is_ok());
    }

    // The merged text includes both the agent name (section header) and its
    // result.
    #[tokio::test]
    async fn parallel_single_agent() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "result-a", 100, 50,
        )]));
        let agent = make_agent(provider, "agent-a");
        let par = ParallelAgent::builder().agent(agent).build().unwrap();

        let output = par.execute("task").await.unwrap();
        assert!(output.result.contains("agent-a"));
        assert!(output.result.contains("result-a"));
        assert_eq!(output.tokens_used.input_tokens, 100);
        assert_eq!(output.tokens_used.output_tokens, 50);
    }

    #[tokio::test]
    async fn parallel_multiple_agents_accumulates_usage() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "out-a", 100, 50,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "out-b", 200, 80,
        )]));

        let par = ParallelAgent::builder()
            .agent(make_agent(p1, "alpha"))
            .agent(make_agent(p2, "beta"))
            .build()
            .unwrap();

        let output = par.execute("same task").await.unwrap();
        // Both results appear, each under its own `## <name>` header.
        assert!(output.result.contains("out-a"));
        assert!(output.result.contains("out-b"));
        assert!(output.result.contains("## alpha"));
        assert!(output.result.contains("## beta"));
        // Usage is the sum over both agents.
        assert_eq!(output.tokens_used.input_tokens, 300);
        assert_eq!(output.tokens_used.output_tokens, 130);
    }

    // Sections are ordered by agent name, not by insertion order: "zebra" is
    // added first but "alpha" must appear first.
    #[tokio::test]
    async fn parallel_output_sorted_by_name() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "out-z", 10, 5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "out-a", 10, 5,
        )]));

        let par = ParallelAgent::builder()
            .agent(make_agent(p1, "zebra"))
            .agent(make_agent(p2, "alpha"))
            .build()
            .unwrap();

        let output = par.execute("task").await.unwrap();
        let alpha_pos = output.result.find("## alpha").unwrap();
        let zebra_pos = output.result.find("## zebra").unwrap();
        assert!(alpha_pos < zebra_pos);
    }

    // One failing agent makes the whole parallel run fail.
    #[tokio::test]
    async fn parallel_error_fails_fast() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 100, 50,
        )]));
        // Empty script: the "bad" agent's call is expected to fail.
        let p2 = Arc::new(MockProvider::new(vec![]));

        let par = ParallelAgent::builder()
            .agent(make_agent(p1, "good"))
            .agent(make_agent(p2, "bad"))
            .build()
            .unwrap();

        let result = par.execute("task").await;
        assert!(result.is_err());
    }

    // ---- LoopAgent ----

    #[test]
    fn loop_builder_rejects_missing_agent() {
        let result = LoopAgent::<MockProvider>::builder()
            .max_iterations(3)
            .should_stop(|_| true)
            .build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("requires an agent")
        );
    }

    #[test]
    fn loop_builder_rejects_missing_max_iterations() {
        let provider = Arc::new(MockProvider::new(vec![]));
        let agent = make_agent(provider, "a");
        let result = LoopAgent::builder()
            .agent(agent)
            .should_stop(|_| true)
            .build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("requires max_iterations")
        );
    }

    #[test]
    fn loop_builder_rejects_zero_max_iterations() {
        let provider = Arc::new(MockProvider::new(vec![]));
        let agent = make_agent(provider, "a");
        let result = LoopAgent::builder()
            .agent(agent)
            .max_iterations(0)
            .should_stop(|_| true)
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("at least 1"));
    }

    #[test]
    fn loop_builder_rejects_missing_should_stop() {
        let provider = Arc::new(MockProvider::new(vec![]));
        let agent = make_agent(provider, "a");
        let result = LoopAgent::builder().agent(agent).max_iterations(3).build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("requires a should_stop")
        );
    }

    #[test]
    fn loop_builder_accepts_valid_config() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "x", 1, 1,
        )]));
        let agent = make_agent(provider, "a");
        let result = LoopAgent::builder()
            .agent(agent)
            .max_iterations(5)
            .should_stop(|_| true)
            .build();
        assert!(result.is_ok());
    }

    // The loop ends on the second iteration ("DONE"); the third scripted
    // response is never consumed, so usage covers exactly two calls.
    #[tokio::test]
    async fn loop_stops_on_condition() {
        let provider = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("working...", 10, 5),
            MockProvider::text_response("DONE", 10, 5),
            MockProvider::text_response("should not reach", 10, 5),
        ]));
        let agent = make_agent(provider, "worker");

        let loop_agent = LoopAgent::builder()
            .agent(agent)
            .max_iterations(10)
            .should_stop(|text| text.contains("DONE"))
            .build()
            .unwrap();

        let output = loop_agent.execute("start").await.unwrap();
        assert_eq!(output.result, "DONE");
        // Two iterations x 10 input tokens, two x 5 output tokens.
        assert_eq!(output.tokens_used.input_tokens, 20);
        assert_eq!(output.tokens_used.output_tokens, 10);
    }

    // With a predicate that never fires, the loop runs exactly max_iterations.
    #[tokio::test]
    async fn loop_stops_at_max_iterations() {
        let provider = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("iter1", 10, 5),
            MockProvider::text_response("iter2", 10, 5),
            MockProvider::text_response("iter3", 10, 5),
        ]));
        let agent = make_agent(provider, "worker");

        let loop_agent = LoopAgent::builder()
            .agent(agent)
            .max_iterations(3)
            .should_stop(|_| false)
            .build()
            .unwrap();

        let output = loop_agent.execute("start").await.unwrap();
        assert_eq!(output.result, "iter3");
        assert_eq!(output.tokens_used.input_tokens, 30);
        assert_eq!(output.tokens_used.output_tokens, 15);
    }

    #[tokio::test]
    async fn loop_single_iteration() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "once", 50, 25,
        )]));
        let agent = make_agent(provider, "worker");

        let loop_agent = LoopAgent::builder()
            .agent(agent)
            .max_iterations(1)
            .should_stop(|_| false)
            .build()
            .unwrap();

        let output = loop_agent.execute("go").await.unwrap();
        assert_eq!(output.result, "once");
        assert_eq!(output.tokens_used.input_tokens, 50);
    }

    // Only one response is scripted, so the second iteration is expected to
    // fail; the error must still report the first iteration's usage.
    #[tokio::test]
    async fn loop_error_carries_partial_usage() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 100, 50,
        )]));
        let agent = make_agent(provider, "worker");

        let loop_agent = LoopAgent::builder()
            .agent(agent)
            .max_iterations(5)
            .should_stop(|_| false)
            .build()
            .unwrap();

        let err = loop_agent.execute("go").await.unwrap_err();
        let partial = err.partial_usage();
        assert!(partial.input_tokens >= 100);
    }

    // ---- Builder bulk-add methods and helpers ----

    #[test]
    fn sequential_builder_agents_method() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 1, 1,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "b", 1, 1,
        )]));
        let agents = vec![make_agent(p1, "x"), make_agent(p2, "y")];
        let seq = SequentialAgent::builder().agents(agents).build();
        assert!(seq.is_ok());
    }

    #[test]
    fn parallel_builder_agents_method() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 1, 1,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "b", 1, 1,
        )]));
        let agents = vec![make_agent(p1, "x"), make_agent(p2, "y")];
        let par = ParallelAgent::builder().agents(agents).build();
        assert!(par.is_ok());
    }

    #[test]
    fn agent_runner_name_getter() {
        let provider = Arc::new(MockProvider::new(vec![]));
        let agent = make_agent(provider, "test-agent");
        assert_eq!(agent.name(), "test-agent");
    }

    // ---- WorkflowType ----

    #[test]
    fn workflow_type_serde_roundtrip() {
        for wt in [
            WorkflowType::Sequential,
            WorkflowType::Parallel,
            WorkflowType::Loop,
            WorkflowType::Dag,
            WorkflowType::Debate,
            WorkflowType::Voting,
            WorkflowType::Mixture,
        ] {
            let json = serde_json::to_string(&wt).unwrap();
            let back: WorkflowType = serde_json::from_str(&json).unwrap();
            assert_eq!(wt, back);
        }
    }

    // Pins the exact wire names produced by `rename_all = "snake_case"`.
    #[test]
    fn workflow_type_snake_case() {
        assert_eq!(
            serde_json::to_string(&WorkflowType::Sequential).unwrap(),
            "\"sequential\""
        );
        assert_eq!(
            serde_json::to_string(&WorkflowType::Parallel).unwrap(),
            "\"parallel\""
        );
        assert_eq!(
            serde_json::to_string(&WorkflowType::Loop).unwrap(),
            "\"loop\""
        );
        assert_eq!(
            serde_json::to_string(&WorkflowType::Dag).unwrap(),
            "\"dag\""
        );
        assert_eq!(
            serde_json::to_string(&WorkflowType::Debate).unwrap(),
            "\"debate\""
        );
        assert_eq!(
            serde_json::to_string(&WorkflowType::Voting).unwrap(),
            "\"voting\""
        );
        assert_eq!(
            serde_json::to_string(&WorkflowType::Mixture).unwrap(),
            "\"mixture\""
        );
    }

    // ---- WorkflowRouter ----

    // Each router test checks both the reported type and that execution is
    // forwarded to the wrapped agent.
    #[tokio::test]
    async fn router_sequential() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "seq-out", 10, 5,
        )]));
        let seq = SequentialAgent::builder()
            .agent(make_agent(provider, "s"))
            .build()
            .unwrap();
        let router = WorkflowRouter::Sequential(Box::new(seq));
        assert_eq!(router.workflow_type(), WorkflowType::Sequential);
        let output = router.execute("task").await.unwrap();
        assert_eq!(output.result, "seq-out");
    }

    #[tokio::test]
    async fn router_parallel() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "par-out", 10, 5,
        )]));
        let par = ParallelAgent::builder()
            .agent(make_agent(provider, "p"))
            .build()
            .unwrap();
        let router = WorkflowRouter::Parallel(Box::new(par));
        assert_eq!(router.workflow_type(), WorkflowType::Parallel);
        let output = router.execute("task").await.unwrap();
        // Parallel output is wrapped in a `## p` section, so only containment
        // is checked.
        assert!(output.result.contains("par-out"));
    }

    #[tokio::test]
    async fn router_loop() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "loop-out", 10, 5,
        )]));
        let lp = LoopAgent::builder()
            .agent(make_agent(provider, "l"))
            .max_iterations(1)
            .should_stop(|_| true)
            .build()
            .unwrap();
        let router = WorkflowRouter::Loop(Box::new(lp));
        assert_eq!(router.workflow_type(), WorkflowType::Loop);
        let output = router.execute("task").await.unwrap();
        assert_eq!(output.result, "loop-out");
    }

    #[tokio::test]
    async fn router_dag() {
        use crate::agent::dag::DagAgent;

        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "dag-out", 10, 5,
        )]));
        let dag = DagAgent::builder()
            .node("A", make_agent(provider, "A"))
            .build()
            .unwrap();
        let router = WorkflowRouter::Dag(Box::new(dag));
        assert_eq!(router.workflow_type(), WorkflowType::Dag);
        let output = router.execute("task").await.unwrap();
        assert_eq!(output.result, "dag-out");
    }

    // `From` conversions pick the matching variant.
    #[test]
    fn router_from_sequential() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "x", 1, 1,
        )]));
        let seq = SequentialAgent::builder()
            .agent(make_agent(provider, "s"))
            .build()
            .unwrap();
        let router: WorkflowRouter<MockProvider> = seq.into();
        assert_eq!(router.workflow_type(), WorkflowType::Sequential);
    }

    #[test]
    fn router_from_dag() {
        use crate::agent::dag::DagAgent;

        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "x", 1, 1,
        )]));
        let dag = DagAgent::builder()
            .node("A", make_agent(provider, "A"))
            .build()
            .unwrap();
        let router: WorkflowRouter<MockProvider> = dag.into();
        assert_eq!(router.workflow_type(), WorkflowType::Dag);
    }

    // Debug output names the router and its workflow type.
    #[test]
    fn router_debug() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "x", 1, 1,
        )]));
        let seq = SequentialAgent::builder()
            .agent(make_agent(provider, "s"))
            .build()
            .unwrap();
        let router = WorkflowRouter::Sequential(Box::new(seq));
        let debug = format!("{router:?}");
        assert!(debug.contains("WorkflowRouter"));
        assert!(debug.contains("Sequential"));
    }
}
1106}