Skip to main content

heartbit_core/agent/
workflow.rs

1//! Deterministic workflow agent primitives.
2//!
3//! These composable agents orchestrate sub-agents without LLM calls:
4//! - [`SequentialAgent`]: runs agents in order, piping output as input
5//! - [`ParallelAgent`]: runs agents concurrently via `tokio::JoinSet`
6//! - [`LoopAgent`]: repeats a single agent until a condition is met
7
8use std::sync::Arc;
9
10use serde::{Deserialize, Serialize};
11use tokio::task::JoinSet;
12
13use crate::error::Error;
14use crate::llm::LlmProvider;
15use crate::llm::types::TokenUsage;
16
17use super::AgentOutput;
18use super::AgentRunner;
19use super::dag::DagAgent;
20use super::debate::DebateAgent;
21use super::mixture::MixtureOfAgentsAgent;
22use super::voting::VotingAgent;
23
/// Predicate that decides when a [`LoopAgent`] should finish.
///
/// It is handed the text produced by the latest iteration and answers `true`
/// when the loop should end.
type StopCondition = Box<dyn Fn(&str) -> bool + Send + Sync + 'static>;
26
27// ---------------------------------------------------------------------------
28// SequentialAgent
29// ---------------------------------------------------------------------------
30
31/// Runs a list of sub-agents in order. Each agent receives the previous
32/// agent's text output as its task input. Returns the final agent's output
33/// with accumulated `TokenUsage`.
34pub struct SequentialAgent<P: LlmProvider> {
35    agents: Vec<AgentRunner<P>>,
36}
37
38impl<P: LlmProvider> std::fmt::Debug for SequentialAgent<P> {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("SequentialAgent")
41            .field("agent_count", &self.agents.len())
42            .finish()
43    }
44}
45
46/// Builder for [`SequentialAgent`].
47pub struct SequentialAgentBuilder<P: LlmProvider> {
48    agents: Vec<AgentRunner<P>>,
49}
50
51impl<P: LlmProvider> SequentialAgent<P> {
52    /// Create a new [`SequentialAgentBuilder`].
53    ///
54    /// Add agents with `.agent(...)` in execution order; each agent receives
55    /// the previous agent's text output as its task input.
56    ///
57    /// # Example
58    ///
59    /// ```rust,no_run
60    /// use std::sync::Arc;
61    /// use heartbit_core::{
62    ///     AgentRunner, AnthropicProvider, BoxedProvider, SequentialAgent,
63    /// };
64    ///
65    /// # async fn run() -> Result<(), heartbit_core::Error> {
66    /// let provider = Arc::new(BoxedProvider::new(AnthropicProvider::new(
67    ///     "sk-...",
68    ///     "claude-sonnet-4-20250514",
69    /// )));
70    /// let researcher = AgentRunner::builder(provider.clone())
71    ///     .system_prompt("Summarize the topic in 3 bullet points.")
72    ///     .build()?;
73    /// let writer = AgentRunner::builder(provider)
74    ///     .system_prompt("Rewrite as a single engaging paragraph.")
75    ///     .build()?;
76    ///
77    /// let pipeline = SequentialAgent::builder()
78    ///     .agent(researcher)
79    ///     .agent(writer)
80    ///     .build()?;
81    /// let output = pipeline.execute("History of Rust").await?;
82    /// println!("{}", output.result);
83    /// # Ok(()) }
84    /// ```
85    pub fn builder() -> SequentialAgentBuilder<P> {
86        SequentialAgentBuilder { agents: Vec::new() }
87    }
88
89    /// Execute the sequential pipeline, feeding each agent's output as the
90    /// next agent's input.
91    pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
92        let mut current_input = task.to_string();
93        let mut total_usage = TokenUsage::default();
94        let mut total_tool_calls = 0usize;
95        let mut total_cost: Option<f64> = None;
96        let mut last_output: Option<AgentOutput> = None;
97
98        for agent in &self.agents {
99            let result = agent
100                .execute(&current_input)
101                .await
102                .map_err(|e| e.accumulate_usage(total_usage))?;
103            result.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
104            current_input = result.result.clone();
105            last_output = Some(result);
106        }
107
108        // Safety: builder guarantees at least one agent
109        let mut output = last_output.expect("at least one agent");
110        output.tokens_used = total_usage;
111        output.tool_calls_made = total_tool_calls;
112        output.estimated_cost_usd = total_cost;
113        Ok(output)
114    }
115}
116
117impl<P: LlmProvider> SequentialAgentBuilder<P> {
118    /// Add an agent to the sequential pipeline.
119    pub fn agent(mut self, agent: AgentRunner<P>) -> Self {
120        self.agents.push(agent);
121        self
122    }
123
124    /// Add multiple agents to the sequential pipeline.
125    pub fn agents(mut self, agents: Vec<AgentRunner<P>>) -> Self {
126        self.agents.extend(agents);
127        self
128    }
129
130    /// Build the [`SequentialAgent`]. Requires at least one agent.
131    pub fn build(self) -> Result<SequentialAgent<P>, Error> {
132        if self.agents.is_empty() {
133            return Err(Error::Config(
134                "SequentialAgent requires at least one agent".into(),
135            ));
136        }
137        Ok(SequentialAgent {
138            agents: self.agents,
139        })
140    }
141}
142
143// ---------------------------------------------------------------------------
144// ParallelAgent
145// ---------------------------------------------------------------------------
146
147/// Runs multiple sub-agents concurrently via `tokio::JoinSet`. All agents
148/// receive the same input task. Returns merged results with accumulated
149/// `TokenUsage`.
150pub struct ParallelAgent<P: LlmProvider + 'static> {
151    agents: Vec<Arc<AgentRunner<P>>>,
152}
153
154impl<P: LlmProvider + 'static> std::fmt::Debug for ParallelAgent<P> {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        f.debug_struct("ParallelAgent")
157            .field("agent_count", &self.agents.len())
158            .finish()
159    }
160}
161
162/// Builder for [`ParallelAgent`].
163pub struct ParallelAgentBuilder<P: LlmProvider + 'static> {
164    agents: Vec<Arc<AgentRunner<P>>>,
165}
166
167impl<P: LlmProvider + 'static> ParallelAgent<P> {
168    /// Create a new [`ParallelAgentBuilder`].
169    pub fn builder() -> ParallelAgentBuilder<P> {
170        ParallelAgentBuilder { agents: Vec::new() }
171    }
172
173    /// Execute all agents concurrently. Fails fast on first error.
174    pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
175        let mut set = JoinSet::new();
176
177        for agent in &self.agents {
178            let agent = Arc::clone(agent);
179            let task = task.to_string();
180            set.spawn(async move {
181                let name = agent.name().to_string();
182                let result = agent.execute(&task).await;
183                (name, result)
184            });
185        }
186
187        let mut results: Vec<(String, AgentOutput)> = Vec::with_capacity(self.agents.len());
188        let mut total_usage = TokenUsage::default();
189        let mut total_tool_calls = 0usize;
190        let mut total_cost: Option<f64> = None;
191
192        while let Some(join_result) = set.join_next().await {
193            let (name, agent_result) = join_result
194                .map_err(|e| Error::Agent(format!("parallel agent task panicked: {e}")))?;
195            let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
196            output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
197            results.push((name, output));
198        }
199
200        // Sort by agent name for deterministic output ordering
201        results.sort_by(|a, b| a.0.cmp(&b.0));
202
203        let merged_text = results
204            .iter()
205            .map(|(name, output)| format!("## {name}\n{}", output.result))
206            .collect::<Vec<_>>()
207            .join("\n\n");
208
209        Ok(AgentOutput {
210            result: merged_text,
211            tool_calls_made: total_tool_calls,
212            tokens_used: total_usage,
213            structured: None,
214            estimated_cost_usd: total_cost,
215            model_name: None,
216        })
217    }
218}
219
220impl<P: LlmProvider + 'static> ParallelAgentBuilder<P> {
221    /// Add an agent. Wraps it in `Arc` for concurrent sharing.
222    pub fn agent(mut self, agent: AgentRunner<P>) -> Self {
223        self.agents.push(Arc::new(agent));
224        self
225    }
226
227    /// Add multiple agents.
228    pub fn agents(mut self, agents: Vec<AgentRunner<P>>) -> Self {
229        self.agents.extend(agents.into_iter().map(Arc::new));
230        self
231    }
232
233    /// Build the [`ParallelAgent`]. Requires at least one agent.
234    pub fn build(self) -> Result<ParallelAgent<P>, Error> {
235        if self.agents.is_empty() {
236            return Err(Error::Config(
237                "ParallelAgent requires at least one agent".into(),
238            ));
239        }
240        Ok(ParallelAgent {
241            agents: self.agents,
242        })
243    }
244}
245
246// ---------------------------------------------------------------------------
247// LoopAgent
248// ---------------------------------------------------------------------------
249
250/// Runs a single agent in a loop. Stops when `should_stop` returns `true`
251/// on the output text, or when `max_iterations` is reached. Returns the
252/// final iteration's output with accumulated `TokenUsage`.
253pub struct LoopAgent<P: LlmProvider> {
254    agent: AgentRunner<P>,
255    max_iterations: usize,
256    should_stop: StopCondition,
257}
258
259impl<P: LlmProvider> std::fmt::Debug for LoopAgent<P> {
260    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261        f.debug_struct("LoopAgent")
262            .field("max_iterations", &self.max_iterations)
263            .finish()
264    }
265}
266
267/// Builder for [`LoopAgent`].
268pub struct LoopAgentBuilder<P: LlmProvider> {
269    agent: Option<AgentRunner<P>>,
270    max_iterations: Option<usize>,
271    should_stop: Option<StopCondition>,
272}
273
274impl<P: LlmProvider> LoopAgent<P> {
275    /// Create a new [`LoopAgentBuilder`].
276    pub fn builder() -> LoopAgentBuilder<P> {
277        LoopAgentBuilder {
278            agent: None,
279            max_iterations: None,
280            should_stop: None,
281        }
282    }
283
284    /// Execute the loop, feeding each iteration's output as the next input.
285    pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
286        let mut current_input = task.to_string();
287        let mut total_usage = TokenUsage::default();
288        let mut total_tool_calls = 0usize;
289        let mut total_cost: Option<f64> = None;
290        let mut last_output: Option<AgentOutput> = None;
291
292        for _ in 0..self.max_iterations {
293            let result = self
294                .agent
295                .execute(&current_input)
296                .await
297                .map_err(|e| e.accumulate_usage(total_usage))?;
298            result.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
299            current_input = result.result.clone();
300            let should_stop = (self.should_stop)(&result.result);
301            last_output = Some(result);
302            if should_stop {
303                break;
304            }
305        }
306
307        // Safety: max_iterations >= 1 guarantees at least one iteration
308        let mut output = last_output.expect("at least one iteration");
309        output.tokens_used = total_usage;
310        output.tool_calls_made = total_tool_calls;
311        output.estimated_cost_usd = total_cost;
312        Ok(output)
313    }
314}
315
316impl<P: LlmProvider> LoopAgentBuilder<P> {
317    /// Set the agent to loop.
318    pub fn agent(mut self, agent: AgentRunner<P>) -> Self {
319        self.agent = Some(agent);
320        self
321    }
322
323    /// Set the maximum number of iterations (must be >= 1).
324    pub fn max_iterations(mut self, n: usize) -> Self {
325        self.max_iterations = Some(n);
326        self
327    }
328
329    /// Set the termination condition. The closure receives the agent's output
330    /// text and returns `true` to stop the loop.
331    pub fn should_stop(mut self, f: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
332        self.should_stop = Some(Box::new(f));
333        self
334    }
335
336    /// Build the [`LoopAgent`].
337    pub fn build(self) -> Result<LoopAgent<P>, Error> {
338        let agent = self
339            .agent
340            .ok_or_else(|| Error::Config("LoopAgent requires an agent".into()))?;
341        let max_iterations = self
342            .max_iterations
343            .ok_or_else(|| Error::Config("LoopAgent requires max_iterations".into()))?;
344        if max_iterations == 0 {
345            return Err(Error::Config(
346                "LoopAgent max_iterations must be at least 1".into(),
347            ));
348        }
349        let should_stop = self
350            .should_stop
351            .ok_or_else(|| Error::Config("LoopAgent requires a should_stop condition".into()))?;
352        Ok(LoopAgent {
353            agent,
354            max_iterations,
355            should_stop,
356        })
357    }
358}
359
360// ---------------------------------------------------------------------------
361// WorkflowType
362// ---------------------------------------------------------------------------
363
364/// Identifies which workflow pattern to use.
365#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
366#[serde(rename_all = "snake_case")]
367pub enum WorkflowType {
368    /// Run agents one after another, piping output to input.
369    Sequential,
370    /// Run agents concurrently, merging results.
371    Parallel,
372    /// Repeat an agent until a stop condition is met.
373    Loop,
374    /// Run agents according to a dependency graph.
375    Dag,
376    /// Run agents as debaters with an optional judge.
377    Debate,
378    /// Run agents as voters and aggregate the result.
379    Voting,
380    /// Run proposer agents and synthesize with a final agent.
381    Mixture,
382}
383
384// ---------------------------------------------------------------------------
385// WorkflowRouter
386// ---------------------------------------------------------------------------
387
388/// Routes execution to one of the workflow agent types.
389/// Allows config-driven workflow selection without hardcoding the type.
390pub enum WorkflowRouter<P: LlmProvider + 'static> {
391    /// A sequential pipeline workflow agent.
392    Sequential(Box<SequentialAgent<P>>),
393    /// A parallel (concurrent) workflow agent.
394    Parallel(Box<ParallelAgent<P>>),
395    /// A looping workflow agent.
396    Loop(Box<LoopAgent<P>>),
397    /// A DAG-based workflow agent.
398    Dag(Box<DagAgent<P>>),
399    /// A debate workflow agent.
400    Debate(Box<DebateAgent<P>>),
401    /// A voting workflow agent.
402    Voting(Box<VotingAgent<P>>),
403    /// A mixture-of-agents workflow agent.
404    Mixture(Box<MixtureOfAgentsAgent<P>>),
405}
406
407impl<P: LlmProvider + 'static> WorkflowRouter<P> {
408    /// Execute the contained workflow agent.
409    ///
410    /// Note: `Voting` returns only the winning voter's `AgentOutput`; the
411    /// `VoteResult` metadata (winner string, tally) is discarded. Use
412    /// `VotingAgent::execute()` directly when you need the full result.
413    pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
414        match self {
415            Self::Sequential(a) => a.execute(task).await,
416            Self::Parallel(a) => a.execute(task).await,
417            Self::Loop(a) => a.execute(task).await,
418            Self::Dag(a) => a.execute(task).await,
419            Self::Debate(a) => a.execute(task).await,
420            Self::Mixture(a) => a.execute(task).await,
421            Self::Voting(a) => a.execute(task).await.map(|vr| vr.output),
422        }
423    }
424
425    /// Returns which workflow type this router contains.
426    pub fn workflow_type(&self) -> WorkflowType {
427        match self {
428            Self::Sequential(_) => WorkflowType::Sequential,
429            Self::Parallel(_) => WorkflowType::Parallel,
430            Self::Loop(_) => WorkflowType::Loop,
431            Self::Dag(_) => WorkflowType::Dag,
432            Self::Debate(_) => WorkflowType::Debate,
433            Self::Voting(_) => WorkflowType::Voting,
434            Self::Mixture(_) => WorkflowType::Mixture,
435        }
436    }
437}
438
439impl<P: LlmProvider + 'static> std::fmt::Debug for WorkflowRouter<P> {
440    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441        f.debug_tuple("WorkflowRouter")
442            .field(&self.workflow_type())
443            .finish()
444    }
445}
446
447impl<P: LlmProvider + 'static> From<SequentialAgent<P>> for WorkflowRouter<P> {
448    fn from(agent: SequentialAgent<P>) -> Self {
449        Self::Sequential(Box::new(agent))
450    }
451}
452
453impl<P: LlmProvider + 'static> From<ParallelAgent<P>> for WorkflowRouter<P> {
454    fn from(agent: ParallelAgent<P>) -> Self {
455        Self::Parallel(Box::new(agent))
456    }
457}
458
459impl<P: LlmProvider + 'static> From<LoopAgent<P>> for WorkflowRouter<P> {
460    fn from(agent: LoopAgent<P>) -> Self {
461        Self::Loop(Box::new(agent))
462    }
463}
464
465impl<P: LlmProvider + 'static> From<DagAgent<P>> for WorkflowRouter<P> {
466    fn from(agent: DagAgent<P>) -> Self {
467        Self::Dag(Box::new(agent))
468    }
469}
470
471impl<P: LlmProvider + 'static> From<DebateAgent<P>> for WorkflowRouter<P> {
472    fn from(agent: DebateAgent<P>) -> Self {
473        Self::Debate(Box::new(agent))
474    }
475}
476
477impl<P: LlmProvider + 'static> From<VotingAgent<P>> for WorkflowRouter<P> {
478    fn from(agent: VotingAgent<P>) -> Self {
479        Self::Voting(Box::new(agent))
480    }
481}
482
483impl<P: LlmProvider + 'static> From<MixtureOfAgentsAgent<P>> for WorkflowRouter<P> {
484    fn from(agent: MixtureOfAgentsAgent<P>) -> Self {
485        Self::Mixture(Box::new(agent))
486    }
487}
488
489// ===========================================================================
490// Tests
491// ===========================================================================
492
493#[cfg(test)]
494mod tests {
495    use super::*;
496    use crate::agent::test_helpers::{MockProvider, make_agent};
497
498    // -----------------------------------------------------------------------
499    // SequentialAgent builder tests
500    // -----------------------------------------------------------------------
501
502    #[test]
503    fn sequential_builder_rejects_empty_agents() {
504        let result = SequentialAgent::<MockProvider>::builder().build();
505        assert!(result.is_err());
506        assert!(
507            result
508                .unwrap_err()
509                .to_string()
510                .contains("at least one agent")
511        );
512    }
513
514    #[test]
515    fn sequential_builder_accepts_one_agent() {
516        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
517            "done", 10, 5,
518        )]));
519        let agent = make_agent(provider, "a");
520        let seq = SequentialAgent::builder().agent(agent).build();
521        assert!(seq.is_ok());
522    }
523
524    // -----------------------------------------------------------------------
525    // SequentialAgent execution tests
526    // -----------------------------------------------------------------------
527
528    #[tokio::test]
529    async fn sequential_single_agent() {
530        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
531            "hello world",
532            100,
533            50,
534        )]));
535        let agent = make_agent(provider, "step1");
536        let seq = SequentialAgent::builder().agent(agent).build().unwrap();
537
538        let output = seq.execute("start").await.unwrap();
539        assert_eq!(output.result, "hello world");
540        assert_eq!(output.tokens_used.input_tokens, 100);
541        assert_eq!(output.tokens_used.output_tokens, 50);
542    }
543
544    #[tokio::test]
545    async fn sequential_chains_output_as_input() {
546        // Agent A responds with "step-a-output", Agent B responds with "step-b-output".
547        // We verify the second agent runs (and its output is final).
548        let provider_a = Arc::new(MockProvider::new(vec![MockProvider::text_response(
549            "step-a-output",
550            100,
551            50,
552        )]));
553        let provider_b = Arc::new(MockProvider::new(vec![MockProvider::text_response(
554            "step-b-output",
555            200,
556            80,
557        )]));
558
559        let agent_a = make_agent(provider_a, "agent-a");
560        let agent_b = make_agent(provider_b, "agent-b");
561
562        let seq = SequentialAgent::builder()
563            .agent(agent_a)
564            .agent(agent_b)
565            .build()
566            .unwrap();
567
568        let output = seq.execute("initial task").await.unwrap();
569        assert_eq!(output.result, "step-b-output");
570        // Usage should be accumulated
571        assert_eq!(output.tokens_used.input_tokens, 300);
572        assert_eq!(output.tokens_used.output_tokens, 130);
573    }
574
575    #[tokio::test]
576    async fn sequential_three_agents_accumulates_usage() {
577        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
578            "out1", 10, 5,
579        )]));
580        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
581            "out2", 20, 10,
582        )]));
583        let p3 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
584            "out3", 30, 15,
585        )]));
586
587        let seq = SequentialAgent::builder()
588            .agent(make_agent(p1, "a"))
589            .agent(make_agent(p2, "b"))
590            .agent(make_agent(p3, "c"))
591            .build()
592            .unwrap();
593
594        let output = seq.execute("go").await.unwrap();
595        assert_eq!(output.result, "out3");
596        assert_eq!(output.tokens_used.input_tokens, 60);
597        assert_eq!(output.tokens_used.output_tokens, 30);
598    }
599
600    #[tokio::test]
601    async fn sequential_error_carries_partial_usage() {
602        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
603            "ok", 100, 50,
604        )]));
605        // Second provider has no responses -> will error
606        let p2 = Arc::new(MockProvider::new(vec![]));
607
608        let seq = SequentialAgent::builder()
609            .agent(make_agent(p1, "good"))
610            .agent(make_agent(p2, "bad"))
611            .build()
612            .unwrap();
613
614        let err = seq.execute("task").await.unwrap_err();
615        let partial = err.partial_usage();
616        // Should include the first agent's usage
617        assert!(partial.input_tokens >= 100);
618    }
619
620    // -----------------------------------------------------------------------
621    // ParallelAgent builder tests
622    // -----------------------------------------------------------------------
623
624    #[test]
625    fn parallel_builder_rejects_empty_agents() {
626        let result = ParallelAgent::<MockProvider>::builder().build();
627        assert!(result.is_err());
628        assert!(
629            result
630                .unwrap_err()
631                .to_string()
632                .contains("at least one agent")
633        );
634    }
635
636    #[test]
637    fn parallel_builder_accepts_one_agent() {
638        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
639            "ok", 10, 5,
640        )]));
641        let agent = make_agent(provider, "a");
642        let par = ParallelAgent::builder().agent(agent).build();
643        assert!(par.is_ok());
644    }
645
646    // -----------------------------------------------------------------------
647    // ParallelAgent execution tests
648    // -----------------------------------------------------------------------
649
650    #[tokio::test]
651    async fn parallel_single_agent() {
652        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
653            "result-a", 100, 50,
654        )]));
655        let agent = make_agent(provider, "agent-a");
656        let par = ParallelAgent::builder().agent(agent).build().unwrap();
657
658        let output = par.execute("task").await.unwrap();
659        assert!(output.result.contains("agent-a"));
660        assert!(output.result.contains("result-a"));
661        assert_eq!(output.tokens_used.input_tokens, 100);
662        assert_eq!(output.tokens_used.output_tokens, 50);
663    }
664
665    #[tokio::test]
666    async fn parallel_multiple_agents_accumulates_usage() {
667        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
668            "out-a", 100, 50,
669        )]));
670        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
671            "out-b", 200, 80,
672        )]));
673
674        let par = ParallelAgent::builder()
675            .agent(make_agent(p1, "alpha"))
676            .agent(make_agent(p2, "beta"))
677            .build()
678            .unwrap();
679
680        let output = par.execute("same task").await.unwrap();
681        // Both agent outputs should appear
682        assert!(output.result.contains("out-a"));
683        assert!(output.result.contains("out-b"));
684        // Both headers should appear
685        assert!(output.result.contains("## alpha"));
686        assert!(output.result.contains("## beta"));
687        // Usage accumulated
688        assert_eq!(output.tokens_used.input_tokens, 300);
689        assert_eq!(output.tokens_used.output_tokens, 130);
690    }
691
692    #[tokio::test]
693    async fn parallel_output_sorted_by_name() {
694        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
695            "out-z", 10, 5,
696        )]));
697        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
698            "out-a", 10, 5,
699        )]));
700
701        let par = ParallelAgent::builder()
702            .agent(make_agent(p1, "zebra"))
703            .agent(make_agent(p2, "alpha"))
704            .build()
705            .unwrap();
706
707        let output = par.execute("task").await.unwrap();
708        // "alpha" should come before "zebra" in the output
709        let alpha_pos = output.result.find("## alpha").unwrap();
710        let zebra_pos = output.result.find("## zebra").unwrap();
711        assert!(alpha_pos < zebra_pos);
712    }
713
714    #[tokio::test]
715    async fn parallel_error_fails_fast() {
716        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
717            "ok", 100, 50,
718        )]));
719        // Second provider will error
720        let p2 = Arc::new(MockProvider::new(vec![]));
721
722        let par = ParallelAgent::builder()
723            .agent(make_agent(p1, "good"))
724            .agent(make_agent(p2, "bad"))
725            .build()
726            .unwrap();
727
728        let result = par.execute("task").await;
729        assert!(result.is_err());
730    }
731
732    // -----------------------------------------------------------------------
733    // LoopAgent builder tests
734    // -----------------------------------------------------------------------
735
736    #[test]
737    fn loop_builder_rejects_missing_agent() {
738        let result = LoopAgent::<MockProvider>::builder()
739            .max_iterations(3)
740            .should_stop(|_| true)
741            .build();
742        assert!(result.is_err());
743        assert!(
744            result
745                .unwrap_err()
746                .to_string()
747                .contains("requires an agent")
748        );
749    }
750
751    #[test]
752    fn loop_builder_rejects_missing_max_iterations() {
753        let provider = Arc::new(MockProvider::new(vec![]));
754        let agent = make_agent(provider, "a");
755        let result = LoopAgent::builder()
756            .agent(agent)
757            .should_stop(|_| true)
758            .build();
759        assert!(result.is_err());
760        assert!(
761            result
762                .unwrap_err()
763                .to_string()
764                .contains("requires max_iterations")
765        );
766    }
767
768    #[test]
769    fn loop_builder_rejects_zero_max_iterations() {
770        let provider = Arc::new(MockProvider::new(vec![]));
771        let agent = make_agent(provider, "a");
772        let result = LoopAgent::builder()
773            .agent(agent)
774            .max_iterations(0)
775            .should_stop(|_| true)
776            .build();
777        assert!(result.is_err());
778        assert!(result.unwrap_err().to_string().contains("at least 1"));
779    }
780
781    #[test]
782    fn loop_builder_rejects_missing_should_stop() {
783        let provider = Arc::new(MockProvider::new(vec![]));
784        let agent = make_agent(provider, "a");
785        let result = LoopAgent::builder().agent(agent).max_iterations(3).build();
786        assert!(result.is_err());
787        assert!(
788            result
789                .unwrap_err()
790                .to_string()
791                .contains("requires a should_stop")
792        );
793    }
794
795    #[test]
796    fn loop_builder_accepts_valid_config() {
797        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
798            "x", 1, 1,
799        )]));
800        let agent = make_agent(provider, "a");
801        let result = LoopAgent::builder()
802            .agent(agent)
803            .max_iterations(5)
804            .should_stop(|_| true)
805            .build();
806        assert!(result.is_ok());
807    }
808
809    // -----------------------------------------------------------------------
810    // LoopAgent execution tests
811    // -----------------------------------------------------------------------
812
813    #[tokio::test]
814    async fn loop_stops_on_condition() {
815        // Provide 3 responses: the second one contains "DONE"
816        let provider = Arc::new(MockProvider::new(vec![
817            MockProvider::text_response("working...", 10, 5),
818            MockProvider::text_response("DONE", 10, 5),
819            MockProvider::text_response("should not reach", 10, 5),
820        ]));
821        let agent = make_agent(provider, "worker");
822
823        let loop_agent = LoopAgent::builder()
824            .agent(agent)
825            .max_iterations(10)
826            .should_stop(|text| text.contains("DONE"))
827            .build()
828            .unwrap();
829
830        let output = loop_agent.execute("start").await.unwrap();
831        assert_eq!(output.result, "DONE");
832        // Only 2 iterations ran
833        assert_eq!(output.tokens_used.input_tokens, 20);
834        assert_eq!(output.tokens_used.output_tokens, 10);
835    }
836
837    #[tokio::test]
838    async fn loop_stops_at_max_iterations() {
839        let provider = Arc::new(MockProvider::new(vec![
840            MockProvider::text_response("iter1", 10, 5),
841            MockProvider::text_response("iter2", 10, 5),
842            MockProvider::text_response("iter3", 10, 5),
843        ]));
844        let agent = make_agent(provider, "worker");
845
846        let loop_agent = LoopAgent::builder()
847            .agent(agent)
848            .max_iterations(3)
849            .should_stop(|_| false) // never stop
850            .build()
851            .unwrap();
852
853        let output = loop_agent.execute("start").await.unwrap();
854        assert_eq!(output.result, "iter3");
855        assert_eq!(output.tokens_used.input_tokens, 30);
856        assert_eq!(output.tokens_used.output_tokens, 15);
857    }
858
859    #[tokio::test]
860    async fn loop_single_iteration() {
861        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
862            "once", 50, 25,
863        )]));
864        let agent = make_agent(provider, "worker");
865
866        let loop_agent = LoopAgent::builder()
867            .agent(agent)
868            .max_iterations(1)
869            .should_stop(|_| false)
870            .build()
871            .unwrap();
872
873        let output = loop_agent.execute("go").await.unwrap();
874        assert_eq!(output.result, "once");
875        assert_eq!(output.tokens_used.input_tokens, 50);
876    }
877
878    #[tokio::test]
879    async fn loop_error_carries_partial_usage() {
880        // First response succeeds, second errors
881        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
882            "ok", 100, 50,
883        )]));
884        let agent = make_agent(provider, "worker");
885
886        let loop_agent = LoopAgent::builder()
887            .agent(agent)
888            .max_iterations(5)
889            .should_stop(|_| false) // never stop, will error on 2nd iteration
890            .build()
891            .unwrap();
892
893        let err = loop_agent.execute("go").await.unwrap_err();
894        let partial = err.partial_usage();
895        assert!(partial.input_tokens >= 100);
896    }
897
898    // -----------------------------------------------------------------------
899    // SequentialAgent builder .agents() method
900    // -----------------------------------------------------------------------
901
/// `SequentialAgentBuilder::agents` accepts a whole batch of agents at once
/// and the resulting builder produces a valid `SequentialAgent`.
#[test]
fn sequential_builder_agents_method() {
    let first = make_agent(
        Arc::new(MockProvider::new(vec![MockProvider::text_response("a", 1, 1)])),
        "x",
    );
    let second = make_agent(
        Arc::new(MockProvider::new(vec![MockProvider::text_response("b", 1, 1)])),
        "y",
    );

    let built = SequentialAgent::builder()
        .agents(vec![first, second])
        .build();

    assert!(built.is_ok());
}
914
915    // -----------------------------------------------------------------------
916    // ParallelAgent builder .agents() method
917    // -----------------------------------------------------------------------
918
/// `ParallelAgentBuilder::agents` accepts a whole batch of agents at once
/// and the resulting builder produces a valid `ParallelAgent`.
#[test]
fn parallel_builder_agents_method() {
    let first = make_agent(
        Arc::new(MockProvider::new(vec![MockProvider::text_response("a", 1, 1)])),
        "x",
    );
    let second = make_agent(
        Arc::new(MockProvider::new(vec![MockProvider::text_response("b", 1, 1)])),
        "y",
    );

    let built = ParallelAgent::builder()
        .agents(vec![first, second])
        .build();

    assert!(built.is_ok());
}
931
932    // -----------------------------------------------------------------------
933    // AgentRunner::name() getter test
934    // -----------------------------------------------------------------------
935
/// `AgentRunner::name` returns the name the agent was constructed with.
#[test]
fn agent_runner_name_getter() {
    // No responses are queued: the agent is only inspected, never executed.
    let idle_provider = Arc::new(MockProvider::new(Vec::new()));
    let runner = make_agent(idle_provider, "test-agent");

    assert_eq!(runner.name(), "test-agent");
}
942
943    // -----------------------------------------------------------------------
944    // WorkflowType tests
945    // -----------------------------------------------------------------------
946
/// Every `WorkflowType` variant listed here survives a JSON
/// serialize/deserialize round trip unchanged.
#[test]
fn workflow_type_serde_roundtrip() {
    let all_variants = [
        WorkflowType::Sequential,
        WorkflowType::Parallel,
        WorkflowType::Loop,
        WorkflowType::Dag,
        WorkflowType::Debate,
        WorkflowType::Voting,
        WorkflowType::Mixture,
    ];

    for original in all_variants {
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: WorkflowType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}
963
/// Pins the exact JSON wire form of each `WorkflowType` variant: a quoted,
/// lowercase string.
#[test]
fn workflow_type_snake_case() {
    // (variant, expected JSON text including the surrounding quotes)
    let expectations = [
        (WorkflowType::Sequential, "\"sequential\""),
        (WorkflowType::Parallel, "\"parallel\""),
        (WorkflowType::Loop, "\"loop\""),
        (WorkflowType::Dag, "\"dag\""),
        (WorkflowType::Debate, "\"debate\""),
        (WorkflowType::Voting, "\"voting\""),
        (WorkflowType::Mixture, "\"mixture\""),
    ];

    for (variant, wire_form) in expectations {
        assert_eq!(serde_json::to_string(&variant).unwrap(), wire_form);
    }
}
995
996    // -----------------------------------------------------------------------
997    // WorkflowRouter tests
998    // -----------------------------------------------------------------------
999
1000    #[tokio::test]
1001    async fn router_sequential() {
1002        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
1003            "seq-out", 10, 5,
1004        )]));
1005        let seq = SequentialAgent::builder()
1006            .agent(make_agent(provider, "s"))
1007            .build()
1008            .unwrap();
1009        let router = WorkflowRouter::Sequential(Box::new(seq));
1010        assert_eq!(router.workflow_type(), WorkflowType::Sequential);
1011        let output = router.execute("task").await.unwrap();
1012        assert_eq!(output.result, "seq-out");
1013    }
1014
1015    #[tokio::test]
1016    async fn router_parallel() {
1017        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
1018            "par-out", 10, 5,
1019        )]));
1020        let par = ParallelAgent::builder()
1021            .agent(make_agent(provider, "p"))
1022            .build()
1023            .unwrap();
1024        let router = WorkflowRouter::Parallel(Box::new(par));
1025        assert_eq!(router.workflow_type(), WorkflowType::Parallel);
1026        let output = router.execute("task").await.unwrap();
1027        assert!(output.result.contains("par-out"));
1028    }
1029
1030    #[tokio::test]
1031    async fn router_loop() {
1032        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
1033            "loop-out", 10, 5,
1034        )]));
1035        let lp = LoopAgent::builder()
1036            .agent(make_agent(provider, "l"))
1037            .max_iterations(1)
1038            .should_stop(|_| true)
1039            .build()
1040            .unwrap();
1041        let router = WorkflowRouter::Loop(Box::new(lp));
1042        assert_eq!(router.workflow_type(), WorkflowType::Loop);
1043        let output = router.execute("task").await.unwrap();
1044        assert_eq!(output.result, "loop-out");
1045    }
1046
1047    #[tokio::test]
1048    async fn router_dag() {
1049        use crate::agent::dag::DagAgent;
1050
1051        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
1052            "dag-out", 10, 5,
1053        )]));
1054        let dag = DagAgent::builder()
1055            .node("A", make_agent(provider, "A"))
1056            .build()
1057            .unwrap();
1058        let router = WorkflowRouter::Dag(Box::new(dag));
1059        assert_eq!(router.workflow_type(), WorkflowType::Dag);
1060        let output = router.execute("task").await.unwrap();
1061        assert_eq!(output.result, "dag-out");
1062    }
1063
/// Converting a `SequentialAgent` with `.into()` produces a router whose
/// workflow type is `Sequential`.
#[test]
fn router_from_sequential() {
    let reply = MockProvider::text_response("x", 1, 1);
    let stage = make_agent(Arc::new(MockProvider::new(vec![reply])), "s");
    let sequential = SequentialAgent::builder().agent(stage).build().unwrap();

    // The annotation selects the target type of the conversion.
    let router: WorkflowRouter<MockProvider> = sequential.into();

    assert_eq!(router.workflow_type(), WorkflowType::Sequential);
}
1076
/// Converting a `DagAgent` with `.into()` produces a router whose workflow
/// type is `Dag`.
#[test]
fn router_from_dag() {
    use crate::agent::dag::DagAgent;

    let reply = MockProvider::text_response("x", 1, 1);
    let node_agent = make_agent(Arc::new(MockProvider::new(vec![reply])), "A");
    let graph = DagAgent::builder().node("A", node_agent).build().unwrap();

    // The annotation selects the target type of the conversion.
    let router: WorkflowRouter<MockProvider> = graph.into();

    assert_eq!(router.workflow_type(), WorkflowType::Dag);
}
1091
/// The `Debug` rendering of a router names both the enum and the active
/// variant.
#[test]
fn router_debug() {
    let reply = MockProvider::text_response("x", 1, 1);
    let stage = make_agent(Arc::new(MockProvider::new(vec![reply])), "s");
    let sequential = SequentialAgent::builder().agent(stage).build().unwrap();
    let router = WorkflowRouter::Sequential(Box::new(sequential));

    let rendered = format!("{:?}", router);

    for needle in ["WorkflowRouter", "Sequential"] {
        assert!(rendered.contains(needle));
    }
}
1106}