Skip to main content

heartbit_core/agent/
mixture.rs

1//! Mixture-of-Agents (MoA) workflow agent.
2//!
3//! Orchestrates ensemble reasoning with refinement:
4//! 1. N proposer agents run in parallel on the same task
5//! 2. Their outputs are collected and formatted as a combined proposal document
6//! 3. A synthesizer agent refines the combined proposals into a final output
7//! 4. Optionally repeats for multiple layers (rounds of refinement)
8
9use std::sync::Arc;
10
11use tokio::task::JoinSet;
12
13use crate::error::Error;
14use crate::llm::LlmProvider;
15use crate::llm::types::TokenUsage;
16
17use super::{AgentOutput, AgentRunner};
18
19// ---------------------------------------------------------------------------
20// MixtureOfAgentsAgent
21// ---------------------------------------------------------------------------
22
23/// Runs N proposer agents in parallel, collects their outputs, and feeds
24/// a combined proposal document to a synthesizer agent for refinement.
25/// Supports multiple layers where each layer's synthesis feeds the next
26/// layer's proposers.
27pub struct MixtureOfAgentsAgent<P: LlmProvider + 'static> {
28    proposers: Vec<Arc<AgentRunner<P>>>,
29    synthesizer: Arc<AgentRunner<P>>,
30    layers: usize,
31}
32
33impl<P: LlmProvider + 'static> std::fmt::Debug for MixtureOfAgentsAgent<P> {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("MixtureOfAgentsAgent")
36            .field("proposer_count", &self.proposers.len())
37            .field("layers", &self.layers)
38            .finish()
39    }
40}
41
42/// Builder for [`MixtureOfAgentsAgent`].
43pub struct MixtureOfAgentsAgentBuilder<P: LlmProvider + 'static> {
44    proposers: Vec<Arc<AgentRunner<P>>>,
45    synthesizer: Option<Arc<AgentRunner<P>>>,
46    layers: Option<usize>,
47}
48
49impl<P: LlmProvider + 'static> MixtureOfAgentsAgent<P> {
50    /// Create a new [`MixtureOfAgentsAgentBuilder`].
51    pub fn builder() -> MixtureOfAgentsAgentBuilder<P> {
52        MixtureOfAgentsAgentBuilder {
53            proposers: Vec::new(),
54            synthesizer: None,
55            layers: None,
56        }
57    }
58
59    /// Execute the mixture-of-agents pipeline.
60    ///
61    /// For each layer, all proposers run in parallel on the current input,
62    /// their outputs are merged into a proposal document, and the synthesizer
63    /// produces a refined output. Multi-layer mode feeds each synthesis back
64    /// as input to the next layer's proposers.
65    pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
66        let mut current_input = task.to_string();
67        let mut total_usage = TokenUsage::default();
68        let mut total_tool_calls = 0usize;
69        let mut total_cost: Option<f64> = None;
70        let mut last_structured: Option<serde_json::Value> = None;
71        let mut last_model_name: Option<String> = None;
72
73        for _ in 0..self.layers {
74            // --- Proposer phase: run all proposers in parallel ---
75            let mut set = JoinSet::new();
76            for proposer in &self.proposers {
77                let proposer = Arc::clone(proposer);
78                let input = current_input.clone();
79                set.spawn(async move {
80                    let name = proposer.name().to_string();
81                    let result = proposer.execute(&input).await;
82                    (name, result)
83                });
84            }
85
86            let mut proposals: Vec<(String, AgentOutput)> =
87                Vec::with_capacity(self.proposers.len());
88
89            while let Some(join_result) = set.join_next().await {
90                let (name, agent_result) = join_result
91                    .map_err(|e| Error::Agent(format!("proposer task panicked: {e}")))?;
92                let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
93                output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
94                proposals.push((name, output));
95            }
96
97            // Sort by proposer name for deterministic output
98            proposals.sort_by(|a, b| a.0.cmp(&b.0));
99
100            let proposal_text = proposals
101                .iter()
102                .map(|(name, output)| format!("## {name}\n{}", output.result))
103                .collect::<Vec<_>>()
104                .join("\n\n");
105
106            // --- Synthesizer phase ---
107            let synth_output = self
108                .synthesizer
109                .execute(&proposal_text)
110                .await
111                .map_err(|e| e.accumulate_usage(total_usage))?;
112
113            synth_output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
114
115            last_structured = synth_output.structured;
116            last_model_name = synth_output.model_name;
117            current_input = synth_output.result;
118        }
119
120        Ok(AgentOutput {
121            result: current_input,
122            tool_calls_made: total_tool_calls,
123            tokens_used: total_usage,
124            structured: last_structured,
125            estimated_cost_usd: total_cost,
126            model_name: last_model_name,
127        })
128    }
129}
130
131impl<P: LlmProvider + 'static> MixtureOfAgentsAgentBuilder<P> {
132    /// Add a proposer agent. Wraps it in `Arc` for concurrent sharing.
133    pub fn proposer(mut self, agent: AgentRunner<P>) -> Self {
134        self.proposers.push(Arc::new(agent));
135        self
136    }
137
138    /// Add multiple proposer agents.
139    pub fn proposers(mut self, agents: Vec<AgentRunner<P>>) -> Self {
140        self.proposers.extend(agents.into_iter().map(Arc::new));
141        self
142    }
143
144    /// Set the synthesizer agent.
145    pub fn synthesizer(mut self, agent: AgentRunner<P>) -> Self {
146        self.synthesizer = Some(Arc::new(agent));
147        self
148    }
149
150    /// Set the number of layers (rounds of proposal + synthesis). Defaults to 1.
151    pub fn layers(mut self, n: usize) -> Self {
152        self.layers = Some(n);
153        self
154    }
155
156    /// Build the [`MixtureOfAgentsAgent`].
157    pub fn build(self) -> Result<MixtureOfAgentsAgent<P>, Error> {
158        if self.proposers.len() < 2 {
159            return Err(Error::Config(
160                "MixtureOfAgentsAgent requires at least 2 proposers".into(),
161            ));
162        }
163        let synthesizer = self
164            .synthesizer
165            .ok_or_else(|| Error::Config("MixtureOfAgentsAgent requires a synthesizer".into()))?;
166        let layers = self.layers.unwrap_or(1);
167        if layers == 0 {
168            return Err(Error::Config(
169                "MixtureOfAgentsAgent layers must be at least 1".into(),
170            ));
171        }
172        Ok(MixtureOfAgentsAgent {
173            proposers: self.proposers,
174            synthesizer,
175            layers,
176        })
177    }
178}
179
180// ===========================================================================
181// Tests
182// ===========================================================================
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};

    /// Builds an `Arc<MockProvider>` scripted with one text response per
    /// `(text, input_tokens, output_tokens)` tuple. With no tuples the
    /// provider is created with an empty response list.
    ///
    /// This is a macro rather than a function so the token-count literals
    /// keep whatever integer type `MockProvider::text_response` expects.
    macro_rules! scripted {
        ($(($text:expr, $input:expr, $output:expr)),* $(,)?) => {
            Arc::new(MockProvider::new(vec![
                $(MockProvider::text_response($text, $input, $output)),*
            ]))
        };
    }

    /// Asserts that `result` is an error whose message contains `expected`.
    fn assert_build_error<T: std::fmt::Debug>(result: Result<T, Error>, expected: &str) {
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains(expected));
    }

    // -----------------------------------------------------------------------
    // Builder tests
    // -----------------------------------------------------------------------

    #[test]
    fn builder_rejects_fewer_than_two_proposers() {
        let shared = scripted![("x", 1, 1)];

        // No proposers registered at all.
        let synth = make_agent(Arc::clone(&shared), "synth");
        let without_proposers = MixtureOfAgentsAgent::<MockProvider>::builder()
            .synthesizer(synth)
            .build();
        assert_build_error(without_proposers, "at least 2 proposers");

        // A single proposer is still not enough.
        let synth2 = make_agent(Arc::clone(&shared), "synth2");
        let only = make_agent(shared, "p1");
        let with_one_proposer = MixtureOfAgentsAgent::builder()
            .proposer(only)
            .synthesizer(synth2)
            .build();
        assert_build_error(with_one_proposer, "at least 2 proposers");
    }

    #[test]
    fn builder_rejects_missing_synthesizer() {
        let first = scripted![("x", 1, 1)];
        let second = scripted![("x", 1, 1)];

        let outcome = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(first, "a"))
            .proposer(make_agent(second, "b"))
            .build();

        assert_build_error(outcome, "requires a synthesizer");
    }

    #[test]
    fn builder_rejects_zero_layers() {
        let first = scripted![("x", 1, 1)];
        let second = scripted![("x", 1, 1)];
        let synth = scripted![("x", 1, 1)];

        let outcome = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(first, "a"))
            .proposer(make_agent(second, "b"))
            .synthesizer(make_agent(synth, "synth"))
            .layers(0)
            .build();

        assert_build_error(outcome, "layers must be at least 1");
    }

    #[test]
    fn builder_accepts_valid_config_default_layers() {
        let first = scripted![("x", 1, 1)];
        let second = scripted![("x", 1, 1)];
        let synth = scripted![("x", 1, 1)];

        let outcome = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(first, "a"))
            .proposer(make_agent(second, "b"))
            .synthesizer(make_agent(synth, "synth"))
            .build();

        assert!(outcome.is_ok());
    }

    #[test]
    fn builder_accepts_valid_config_explicit_layers() {
        let first = scripted![("x", 1, 1)];
        let second = scripted![("x", 1, 1)];
        let synth = scripted![("x", 1, 1)];

        let outcome = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(first, "a"))
            .proposer(make_agent(second, "b"))
            .synthesizer(make_agent(synth, "synth"))
            .layers(3)
            .build();

        assert!(outcome.is_ok());
    }

    #[test]
    fn builder_proposers_bulk_method() {
        let first = scripted![("x", 1, 1)];
        let second = scripted![("x", 1, 1)];
        let synth = scripted![("x", 1, 1)];

        let batch = vec![make_agent(first, "a"), make_agent(second, "b")];
        let outcome = MixtureOfAgentsAgent::builder()
            .proposers(batch)
            .synthesizer(make_agent(synth, "synth"))
            .build();

        assert!(outcome.is_ok());
    }

    // -----------------------------------------------------------------------
    // Debug impl test
    // -----------------------------------------------------------------------

    #[test]
    fn debug_impl_shows_proposer_count_and_layers() {
        let first = scripted![("x", 1, 1)];
        let second = scripted![("x", 1, 1)];
        let third = scripted![("x", 1, 1)];
        let synth = scripted![("x", 1, 1)];

        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(first, "a"))
            .proposer(make_agent(second, "b"))
            .proposer(make_agent(third, "c"))
            .synthesizer(make_agent(synth, "synth"))
            .layers(2)
            .build()
            .unwrap();

        let rendered = format!("{moa:?}");
        for fragment in ["MixtureOfAgentsAgent", "proposer_count: 3", "layers: 2"] {
            assert!(rendered.contains(fragment));
        }
    }

    // -----------------------------------------------------------------------
    // Execution tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn single_layer_execution() {
        // Two proposers feeding one synthesizer.
        let alpha = scripted![("proposal from alpha", 100, 50)];
        let beta = scripted![("proposal from beta", 120, 60)];
        let synth = scripted![("synthesized result", 200, 100)];

        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(alpha, "alpha"))
            .proposer(make_agent(beta, "beta"))
            .synthesizer(make_agent(synth, "synth"))
            .build()
            .unwrap();

        let output = moa.execute("analyze this").await.unwrap();
        assert_eq!(output.result, "synthesized result");
    }

    #[tokio::test]
    async fn token_usage_accumulated() {
        let first = scripted![("p1-out", 100, 50)];
        let second = scripted![("p2-out", 120, 60)];
        let synth = scripted![("final", 200, 100)];

        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(first, "a"))
            .proposer(make_agent(second, "b"))
            .synthesizer(make_agent(synth, "synth"))
            .build()
            .unwrap();

        let usage = moa.execute("task").await.unwrap().tokens_used;
        // Input: 100 + 120 + 200 = 420.
        assert_eq!(usage.input_tokens, 420);
        // Output: 50 + 60 + 100 = 210.
        assert_eq!(usage.output_tokens, 210);
    }

    #[tokio::test]
    async fn multi_layer_execution() {
        // The same two proposers and synthesizer run once per layer, so each
        // provider is scripted with two responses.
        let first = scripted![("p1-layer1", 10, 5), ("p1-layer2", 10, 5)];
        let second = scripted![("p2-layer1", 10, 5), ("p2-layer2", 10, 5)];
        let synth = scripted![("synth-layer1", 20, 10), ("synth-layer2-final", 20, 10)];

        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(first, "a"))
            .proposer(make_agent(second, "b"))
            .synthesizer(make_agent(synth, "synth"))
            .layers(2)
            .build()
            .unwrap();

        let output = moa.execute("task").await.unwrap();
        assert_eq!(output.result, "synth-layer2-final");
        // Input: 2 layers * (10 + 10 proposers + 20 synth) = 80.
        assert_eq!(output.tokens_used.input_tokens, 80);
        // Output: 2 layers * (5 + 5 proposers + 10 synth) = 40.
        assert_eq!(output.tokens_used.output_tokens, 40);
    }

    #[tokio::test]
    async fn proposer_error_carries_partial_usage() {
        let good = scripted![("ok", 100, 50)];
        // No scripted responses, so this proposer fails.
        let bad = scripted![];
        let synth = scripted![("final", 10, 5)];

        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(good, "good"))
            .proposer(make_agent(bad, "bad"))
            .synthesizer(make_agent(synth, "synth"))
            .build()
            .unwrap();

        let partial = moa.execute("task").await.unwrap_err().partial_usage();
        // JoinSet completion order is not deterministic: the healthy proposer
        // may or may not have been collected before the failure surfaced.
        assert!(
            partial.input_tokens == 0 || partial.input_tokens >= 100,
            "partial usage should be zero or include completed proposer"
        );
    }

    #[tokio::test]
    async fn synthesizer_error_carries_partial_usage_from_proposers() {
        let first = scripted![("ok1", 100, 50)];
        let second = scripted![("ok2", 120, 60)];
        // No scripted responses, so the synthesizer fails.
        let synth = scripted![];

        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(first, "a"))
            .proposer(make_agent(second, "b"))
            .synthesizer(make_agent(synth, "synth"))
            .build()
            .unwrap();

        let partial = moa.execute("task").await.unwrap_err().partial_usage();
        // Both proposers completed first: 100 + 120 = 220.
        assert!(partial.input_tokens >= 220);
    }

    #[tokio::test]
    async fn synthesizer_receives_sorted_proposal_document() {
        // Proposers are registered out of alphabetical order on purpose.
        let charlie = scripted![("output-c", 10, 5)];
        let alpha = scripted![("output-a", 10, 5)];
        let beta = scripted![("output-b", 10, 5)];
        let synth_p = scripted![("final-synthesis", 10, 5)];

        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(charlie, "charlie"))
            .proposer(make_agent(alpha, "alpha"))
            .proposer(make_agent(beta, "beta"))
            .synthesizer(make_agent(Arc::clone(&synth_p), "synth"))
            .build()
            .unwrap();

        let output = moa.execute("task").await.unwrap();
        assert_eq!(output.result, "final-synthesis");

        // Look at the request the synthesizer's provider recorded.
        let synth_requests = synth_p.captured_requests.lock().unwrap();
        assert_eq!(synth_requests.len(), 1);
        let first_block = &synth_requests[0].messages[0].content[0];
        let input_text = match first_block {
            crate::llm::types::ContentBlock::Text { text } => text.as_str(),
            _ => panic!("expected text content"),
        };

        // Sections must appear alphabetically: alpha, beta, charlie.
        let alpha_pos = input_text
            .find("## alpha")
            .expect("should contain ## alpha");
        let beta_pos = input_text.find("## beta").expect("should contain ## beta");
        let charlie_pos = input_text
            .find("## charlie")
            .expect("should contain ## charlie");
        assert!(alpha_pos < beta_pos, "alpha should come before beta");
        assert!(beta_pos < charlie_pos, "beta should come before charlie");

        // Every proposer's output must be present in the document.
        for expected in ["output-a", "output-b", "output-c"] {
            assert!(input_text.contains(expected));
        }
    }
}