1use std::sync::Arc;
10
11use tokio::task::JoinSet;
12
13use crate::error::Error;
14use crate::llm::LlmProvider;
15use crate::llm::types::TokenUsage;
16
17use super::{AgentOutput, AgentRunner};
18
19pub struct MixtureOfAgentsAgent<P: LlmProvider + 'static> {
28 proposers: Vec<Arc<AgentRunner<P>>>,
29 synthesizer: Arc<AgentRunner<P>>,
30 layers: usize,
31}
32
33impl<P: LlmProvider + 'static> std::fmt::Debug for MixtureOfAgentsAgent<P> {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("MixtureOfAgentsAgent")
36 .field("proposer_count", &self.proposers.len())
37 .field("layers", &self.layers)
38 .finish()
39 }
40}
41
42pub struct MixtureOfAgentsAgentBuilder<P: LlmProvider + 'static> {
44 proposers: Vec<Arc<AgentRunner<P>>>,
45 synthesizer: Option<Arc<AgentRunner<P>>>,
46 layers: Option<usize>,
47}
48
49impl<P: LlmProvider + 'static> MixtureOfAgentsAgent<P> {
50 pub fn builder() -> MixtureOfAgentsAgentBuilder<P> {
52 MixtureOfAgentsAgentBuilder {
53 proposers: Vec::new(),
54 synthesizer: None,
55 layers: None,
56 }
57 }
58
59 pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
66 let mut current_input = task.to_string();
67 let mut total_usage = TokenUsage::default();
68 let mut total_tool_calls = 0usize;
69 let mut total_cost: Option<f64> = None;
70 let mut last_structured: Option<serde_json::Value> = None;
71 let mut last_model_name: Option<String> = None;
72
73 for _ in 0..self.layers {
74 let mut set = JoinSet::new();
76 for proposer in &self.proposers {
77 let proposer = Arc::clone(proposer);
78 let input = current_input.clone();
79 set.spawn(async move {
80 let name = proposer.name().to_string();
81 let result = proposer.execute(&input).await;
82 (name, result)
83 });
84 }
85
86 let mut proposals: Vec<(String, AgentOutput)> =
87 Vec::with_capacity(self.proposers.len());
88
89 while let Some(join_result) = set.join_next().await {
90 let (name, agent_result) = join_result
91 .map_err(|e| Error::Agent(format!("proposer task panicked: {e}")))?;
92 let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
93 output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
94 proposals.push((name, output));
95 }
96
97 proposals.sort_by(|a, b| a.0.cmp(&b.0));
99
100 let proposal_text = proposals
101 .iter()
102 .map(|(name, output)| format!("## {name}\n{}", output.result))
103 .collect::<Vec<_>>()
104 .join("\n\n");
105
106 let synth_output = self
108 .synthesizer
109 .execute(&proposal_text)
110 .await
111 .map_err(|e| e.accumulate_usage(total_usage))?;
112
113 synth_output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
114
115 last_structured = synth_output.structured;
116 last_model_name = synth_output.model_name;
117 current_input = synth_output.result;
118 }
119
120 Ok(AgentOutput {
121 result: current_input,
122 tool_calls_made: total_tool_calls,
123 tokens_used: total_usage,
124 structured: last_structured,
125 estimated_cost_usd: total_cost,
126 model_name: last_model_name,
127 })
128 }
129}
130
131impl<P: LlmProvider + 'static> MixtureOfAgentsAgentBuilder<P> {
132 pub fn proposer(mut self, agent: AgentRunner<P>) -> Self {
134 self.proposers.push(Arc::new(agent));
135 self
136 }
137
138 pub fn proposers(mut self, agents: Vec<AgentRunner<P>>) -> Self {
140 self.proposers.extend(agents.into_iter().map(Arc::new));
141 self
142 }
143
144 pub fn synthesizer(mut self, agent: AgentRunner<P>) -> Self {
146 self.synthesizer = Some(Arc::new(agent));
147 self
148 }
149
150 pub fn layers(mut self, n: usize) -> Self {
152 self.layers = Some(n);
153 self
154 }
155
156 pub fn build(self) -> Result<MixtureOfAgentsAgent<P>, Error> {
158 if self.proposers.len() < 2 {
159 return Err(Error::Config(
160 "MixtureOfAgentsAgent requires at least 2 proposers".into(),
161 ));
162 }
163 let synthesizer = self
164 .synthesizer
165 .ok_or_else(|| Error::Config("MixtureOfAgentsAgent requires a synthesizer".into()))?;
166 let layers = self.layers.unwrap_or(1);
167 if layers == 0 {
168 return Err(Error::Config(
169 "MixtureOfAgentsAgent layers must be at least 1".into(),
170 ));
171 }
172 Ok(MixtureOfAgentsAgent {
173 proposers: self.proposers,
174 synthesizer,
175 layers,
176 })
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};

    /// Builds an `Arc<MockProvider>` that replays the listed
    /// `(text, input_tokens, output_tokens)` responses in order.
    /// `mock!()` yields a provider with no scripted responses.
    macro_rules! mock {
        ($(($text:expr, $input:expr, $output:expr)),* $(,)?) => {
            Arc::new(MockProvider::new(vec![
                $(MockProvider::text_response($text, $input, $output)),*
            ]))
        };
    }

    /// A provider with one throwaway response, for tests that only exercise the builder.
    fn stub() -> Arc<MockProvider> {
        mock!(("x", 1, 1))
    }

    /// A builder with proposers "a" and "b" plus a synthesizer, all backed by stub providers.
    fn valid_builder() -> MixtureOfAgentsAgentBuilder<MockProvider> {
        MixtureOfAgentsAgent::builder()
            .proposer(make_agent(stub(), "a"))
            .proposer(make_agent(stub(), "b"))
            .synthesizer(make_agent(stub(), "synth"))
    }

    /// Asserts that `result` is an error whose message contains `needle`.
    fn assert_err_contains<T>(result: Result<T, Error>, needle: &str) {
        let message = match result {
            Ok(_) => panic!("expected an error containing {needle:?}"),
            Err(err) => err.to_string(),
        };
        assert!(
            message.contains(needle),
            "error {message:?} should contain {needle:?}"
        );
    }

    #[test]
    fn builder_rejects_fewer_than_two_proposers() {
        let provider = stub();

        let synth = make_agent(Arc::clone(&provider), "synth");
        let no_proposers = MixtureOfAgentsAgent::<MockProvider>::builder()
            .synthesizer(synth)
            .build();
        assert_err_contains(no_proposers, "at least 2 proposers");

        let synth2 = make_agent(Arc::clone(&provider), "synth2");
        let p1 = make_agent(provider, "p1");
        let one_proposer = MixtureOfAgentsAgent::builder()
            .proposer(p1)
            .synthesizer(synth2)
            .build();
        assert_err_contains(one_proposer, "at least 2 proposers");
    }

    #[test]
    fn builder_rejects_missing_synthesizer() {
        let result = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(stub(), "a"))
            .proposer(make_agent(stub(), "b"))
            .build();
        assert_err_contains(result, "requires a synthesizer");
    }

    #[test]
    fn builder_rejects_zero_layers() {
        assert_err_contains(
            valid_builder().layers(0).build(),
            "layers must be at least 1",
        );
    }

    #[test]
    fn builder_accepts_valid_config_default_layers() {
        assert!(valid_builder().build().is_ok());
    }

    #[test]
    fn builder_accepts_valid_config_explicit_layers() {
        assert!(valid_builder().layers(3).build().is_ok());
    }

    #[test]
    fn debug_impl_shows_proposer_count_and_layers() {
        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(stub(), "a"))
            .proposer(make_agent(stub(), "b"))
            .proposer(make_agent(stub(), "c"))
            .synthesizer(make_agent(stub(), "synth"))
            .layers(2)
            .build()
            .unwrap();

        let rendered = format!("{moa:?}");
        for expected in ["MixtureOfAgentsAgent", "proposer_count: 3", "layers: 2"] {
            assert!(
                rendered.contains(expected),
                "missing {expected:?} in {rendered}"
            );
        }
    }

    #[test]
    fn builder_proposers_bulk_method() {
        let result = MixtureOfAgentsAgent::builder()
            .proposers(vec![make_agent(stub(), "a"), make_agent(stub(), "b")])
            .synthesizer(make_agent(stub(), "synth"))
            .build();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn single_layer_execution() {
        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(mock!(("proposal from alpha", 100, 50)), "alpha"))
            .proposer(make_agent(mock!(("proposal from beta", 120, 60)), "beta"))
            .synthesizer(make_agent(mock!(("synthesized result", 200, 100)), "synth"))
            .build()
            .unwrap();

        let output = moa.execute("analyze this").await.unwrap();
        assert_eq!(output.result, "synthesized result");
    }

    #[tokio::test]
    async fn token_usage_accumulated() {
        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(mock!(("p1-out", 100, 50)), "a"))
            .proposer(make_agent(mock!(("p2-out", 120, 60)), "b"))
            .synthesizer(make_agent(mock!(("final", 200, 100)), "synth"))
            .build()
            .unwrap();

        let output = moa.execute("task").await.unwrap();
        // 100 + 120 + 200 input tokens; 50 + 60 + 100 output tokens.
        assert_eq!(output.tokens_used.input_tokens, 420);
        assert_eq!(output.tokens_used.output_tokens, 210);
    }

    #[tokio::test]
    async fn multi_layer_execution() {
        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(
                mock!(("p1-layer1", 10, 5), ("p1-layer2", 10, 5)),
                "a",
            ))
            .proposer(make_agent(
                mock!(("p2-layer1", 10, 5), ("p2-layer2", 10, 5)),
                "b",
            ))
            .synthesizer(make_agent(
                mock!(("synth-layer1", 20, 10), ("synth-layer2-final", 20, 10)),
                "synth",
            ))
            .layers(2)
            .build()
            .unwrap();

        let output = moa.execute("task").await.unwrap();
        assert_eq!(output.result, "synth-layer2-final");
        // Two layers of (10 + 10 + 20) input and (5 + 5 + 10) output tokens.
        assert_eq!(output.tokens_used.input_tokens, 80);
        assert_eq!(output.tokens_used.output_tokens, 40);
    }

    #[tokio::test]
    async fn proposer_error_carries_partial_usage() {
        // "bad" has no scripted responses; the test relies on that making its call fail.
        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(mock!(("ok", 100, 50)), "good"))
            .proposer(make_agent(mock!(), "bad"))
            .synthesizer(make_agent(mock!(("final", 10, 5)), "synth"))
            .build()
            .unwrap();

        let err = moa.execute("task").await.unwrap_err();
        let partial = err.partial_usage();
        // Completion order is not fixed: the failure may be observed before or after
        // "good" finishes, so either no usage or at least its usage is acceptable.
        assert!(
            partial.input_tokens == 0 || partial.input_tokens >= 100,
            "partial usage should be zero or include completed proposer"
        );
    }

    #[tokio::test]
    async fn synthesizer_error_carries_partial_usage_from_proposers() {
        // The synthesizer has no scripted responses, so only proposer usage can be reported.
        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(mock!(("ok1", 100, 50)), "a"))
            .proposer(make_agent(mock!(("ok2", 120, 60)), "b"))
            .synthesizer(make_agent(mock!(), "synth"))
            .build()
            .unwrap();

        let err = moa.execute("task").await.unwrap_err();
        let partial = err.partial_usage();
        assert!(partial.input_tokens >= 220);
    }

    #[tokio::test]
    async fn synthesizer_receives_sorted_proposal_document() {
        let synth_p = mock!(("final-synthesis", 10, 5));

        // Proposers are deliberately registered out of alphabetical order.
        let moa = MixtureOfAgentsAgent::builder()
            .proposer(make_agent(mock!(("output-c", 10, 5)), "charlie"))
            .proposer(make_agent(mock!(("output-a", 10, 5)), "alpha"))
            .proposer(make_agent(mock!(("output-b", 10, 5)), "beta"))
            .synthesizer(make_agent(Arc::clone(&synth_p), "synth"))
            .build()
            .unwrap();

        let output = moa.execute("task").await.unwrap();
        assert_eq!(output.result, "final-synthesis");

        let synth_requests = synth_p.captured_requests.lock().unwrap();
        assert_eq!(synth_requests.len(), 1);
        let input_text = match &synth_requests[0].messages[0].content[0] {
            crate::llm::types::ContentBlock::Text { text } => text.as_str(),
            _ => panic!("expected text content"),
        };

        let position = |heading: &str| {
            input_text
                .find(heading)
                .unwrap_or_else(|| panic!("should contain {heading}"))
        };
        let alpha_pos = position("## alpha");
        let beta_pos = position("## beta");
        let charlie_pos = position("## charlie");
        assert!(alpha_pos < beta_pos, "alpha should come before beta");
        assert!(beta_pos < charlie_pos, "beta should come before charlie");

        for body in ["output-a", "output-b", "output-c"] {
            assert!(input_text.contains(body));
        }
    }
}