1use super::agent::LLMAgent;
28use super::types::{LLMError, LLMResult};
29use std::collections::HashMap;
30use std::sync::Arc;
31
/// Collaboration strategy an [`AgentTeam`] uses when [`AgentTeam::run`] is called.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare configured
/// patterns directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TeamPattern {
    /// Members run in order; each member receives the previous member's output.
    Chain,
    /// Every member answers the same input; answers are concatenated and
    /// optionally aggregated with the team's aggregate prompt.
    Parallel,
    /// Members take turns for a number of rounds, after which the first
    /// member summarizes the discussion.
    Debate {
        /// Number of rounds; every member speaks once per round.
        max_rounds: usize,
    },
    /// Non-supervisor members answer the input, then the supervisor evaluates
    /// their answers.
    Supervised,
    /// Every member answers the input (map), then the first member
    /// synthesizes the answers (reduce).
    MapReduce,
    /// Reserved for custom orchestration; currently executes like
    /// [`TeamPattern::Chain`].
    Custom,
}
51
/// Identity and prompt-shaping metadata for one team member.
#[derive(Debug, Clone)]
pub struct AgentRole {
    /// Identifier used to look the member up inside a team.
    pub id: String,
    /// Human-readable display name (used e.g. in debate transcripts).
    pub name: String,
    /// Free-form description of the role; empty unless set.
    pub description: String,
    /// Optional prompt template; `None` unless set with [`AgentRole::with_template`].
    pub prompt_template: Option<String>,
}

impl AgentRole {
    /// Creates a role with the given id and display name, an empty
    /// description, and no prompt template.
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        let id: String = id.into();
        let name: String = name.into();
        AgentRole {
            id,
            name,
            description: String::new(),
            prompt_template: None,
        }
    }

    /// Returns the role with its description replaced by `desc`.
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        AgentRole {
            description: desc.into(),
            ..self
        }
    }

    /// Returns the role with its prompt template set to `template`.
    pub fn with_template(self, template: impl Into<String>) -> Self {
        AgentRole {
            prompt_template: Some(template.into()),
            ..self
        }
    }
}
85
86pub struct AgentMember {
88 pub role: AgentRole,
90 pub agent: Arc<LLMAgent>,
92}
93
94impl AgentMember {
95 pub fn new(id: impl Into<String>, agent: Arc<LLMAgent>) -> Self {
96 let id = id.into();
97 Self {
98 role: AgentRole::new(&id, &id),
99 agent,
100 }
101 }
102
103 pub fn with_role(mut self, role: AgentRole) -> Self {
104 self.role = role;
105 self
106 }
107
108 pub async fn execute(&self, input: &str, context: Option<&str>) -> LLMResult<String> {
110 let prompt = if let Some(ref template) = self.role.prompt_template {
111 let mut p = template.replace("{input}", input);
112 if let Some(ctx) = context {
113 p = p.replace("{context}", ctx);
114 }
115 p
116 } else if let Some(ctx) = context {
117 format!("Context:\n{}\n\nTask:\n{}", ctx, input)
118 } else {
119 input.to_string()
120 };
121
122 self.agent.ask(&prompt).await
123 }
124}
125
126pub struct AgentTeam {
128 pub id: String,
130 pub name: String,
132 members: Vec<AgentMember>,
134 member_map: HashMap<String, usize>,
136 pattern: TeamPattern,
138 supervisor_id: Option<String>,
140 aggregate_prompt: Option<String>,
142}
143
144impl AgentTeam {
145 pub fn new(id: impl Into<String>) -> AgentTeamBuilder {
147 AgentTeamBuilder::new(id)
148 }
149
150 async fn run_chain(&self, input: &str) -> LLMResult<String> {
152 let mut current_output = input.to_string();
153
154 for member in &self.members {
155 current_output = member.execute(¤t_output, None).await?;
156 }
157
158 Ok(current_output)
159 }
160
161 async fn run_parallel(&self, input: &str) -> LLMResult<String> {
163 let mut results = Vec::new();
164
165 for member in &self.members {
168 let result = member.execute(input, None).await?;
169 results.push((member.role.id.clone(), result));
170 }
171
172 let aggregated = results
174 .iter()
175 .map(|(id, result)| format!("=== {} ===\n{}", id, result))
176 .collect::<Vec<_>>()
177 .join("\n\n");
178
179 if let Some(ref aggregate_prompt) = self.aggregate_prompt
181 && let Some(first_member) = self.members.first()
182 {
183 let prompt = aggregate_prompt
184 .replace("{results}", &aggregated)
185 .replace("{input}", input);
186 return first_member.agent.ask(&prompt).await;
187 }
188
189 Ok(aggregated)
190 }
191
192 async fn run_debate(&self, input: &str, max_rounds: usize) -> LLMResult<String> {
194 if self.members.len() < 2 {
195 return Err(LLMError::Other(
196 "Debate requires at least 2 agents".to_string(),
197 ));
198 }
199
200 let mut context = format!("Initial topic: {}\n\n", input);
201 let mut last_response = String::new();
202
203 for round in 0..max_rounds {
204 for (i, member) in self.members.iter().enumerate() {
205 let prompt = format!(
206 "Round {}, Speaker {}: {}\n\n\
207 Previous discussion:\n{}\n\n\
208 Please provide your perspective. Be constructive and build on previous points.",
209 round + 1,
210 i + 1,
211 member.role.name,
212 context
213 );
214
215 let response = member.execute(&prompt, None).await?;
216 context.push_str(&format!(
217 "[{} - Round {}]:\n{}\n\n",
218 member.role.name,
219 round + 1,
220 response
221 ));
222 last_response = response;
223 }
224 }
225
226 if let Some(first_member) = self.members.first() {
228 let summary_prompt = format!(
229 "Based on the following debate, provide a concise summary of the key points \
230 and conclusions:\n\n{}",
231 context
232 );
233 first_member.agent.ask(&summary_prompt).await
234 } else {
235 Ok(last_response)
236 }
237 }
238
239 async fn run_supervised(&self, input: &str) -> LLMResult<String> {
241 let supervisor_id = self.supervisor_id.as_ref().ok_or_else(|| {
242 LLMError::Other("Supervisor not specified for Supervised pattern".to_string())
243 })?;
244
245 let supervisor_idx = self
246 .member_map
247 .get(supervisor_id)
248 .ok_or_else(|| LLMError::Other(format!("Supervisor '{}' not found", supervisor_id)))?;
249
250 let mut worker_results = Vec::new();
252 for (i, member) in self.members.iter().enumerate() {
253 if i != *supervisor_idx {
254 let result = member.execute(input, None).await?;
255 worker_results.push((member.role.id.clone(), member.role.name.clone(), result));
256 }
257 }
258
259 let results_text = worker_results
261 .iter()
262 .map(|(id, name, result)| format!("=== {} ({}) ===\n{}", name, id, result))
263 .collect::<Vec<_>>()
264 .join("\n\n");
265
266 let supervisor = &self.members[*supervisor_idx];
267 let eval_prompt = format!(
268 "You are the supervisor. Evaluate the following responses to the task: \"{}\"\n\n\
269 Responses:\n{}\n\n\
270 Please provide:\n\
271 1. An evaluation of each response\n\
272 2. The best response or a synthesized improved response\n\
273 3. Suggestions for improvement",
274 input, results_text
275 );
276
277 supervisor.agent.ask(&eval_prompt).await
278 }
279
280 async fn run_map_reduce(&self, input: &str) -> LLMResult<String> {
282 let mut mapped_results = Vec::new();
284 for member in &self.members {
285 let result = member.execute(input, None).await?;
286 mapped_results.push((member.role.id.clone(), result));
287 }
288
289 let reduce_input = mapped_results
291 .iter()
292 .map(|(id, result)| format!("[{}]: {}", id, result))
293 .collect::<Vec<_>>()
294 .join("\n\n");
295
296 let reduce_prompt = if let Some(ref aggregate_prompt) = self.aggregate_prompt {
297 aggregate_prompt
298 .replace("{results}", &reduce_input)
299 .replace("{input}", input)
300 } else {
301 format!(
302 "Synthesize the following results into a coherent response:\n\n{}\n\n\
303 Original task: {}",
304 reduce_input, input
305 )
306 };
307
308 if let Some(first_member) = self.members.first() {
310 first_member.agent.ask(&reduce_prompt).await
311 } else {
312 Ok(reduce_input)
313 }
314 }
315
316 pub async fn run(&self, input: impl Into<String>) -> LLMResult<String> {
318 let input = input.into();
319
320 match &self.pattern {
321 TeamPattern::Chain => self.run_chain(&input).await,
322 TeamPattern::Parallel => self.run_parallel(&input).await,
323 TeamPattern::Debate { max_rounds } => self.run_debate(&input, *max_rounds).await,
324 TeamPattern::Supervised => self.run_supervised(&input).await,
325 TeamPattern::MapReduce => self.run_map_reduce(&input).await,
326 TeamPattern::Custom => {
327 self.run_chain(&input).await
329 }
330 }
331 }
332
333 pub fn get_member(&self, id: &str) -> Option<&AgentMember> {
335 self.member_map.get(id).map(|idx| &self.members[*idx])
336 }
337
338 pub fn member_ids(&self) -> Vec<&str> {
340 self.members.iter().map(|m| m.role.id.as_str()).collect()
341 }
342}
343
344pub struct AgentTeamBuilder {
346 id: String,
347 name: String,
348 members: Vec<AgentMember>,
349 pattern: TeamPattern,
350 supervisor_id: Option<String>,
351 aggregate_prompt: Option<String>,
352}
353
354impl AgentTeamBuilder {
355 pub fn new(id: impl Into<String>) -> Self {
356 let id = id.into();
357 Self {
358 name: id.clone(),
359 id,
360 members: Vec::new(),
361 pattern: TeamPattern::Chain,
362 supervisor_id: None,
363 aggregate_prompt: None,
364 }
365 }
366
367 pub fn with_name(mut self, name: impl Into<String>) -> Self {
369 self.name = name.into();
370 self
371 }
372
373 pub fn add_member(mut self, id: impl Into<String>, agent: Arc<LLMAgent>) -> Self {
375 self.members.push(AgentMember::new(id, agent));
376 self
377 }
378
379 pub fn add_member_with_role(mut self, agent: Arc<LLMAgent>, role: AgentRole) -> Self {
381 let member = AgentMember::new(&role.id, agent).with_role(role);
382 self.members.push(member);
383 self
384 }
385
386 pub fn with_pattern(mut self, pattern: TeamPattern) -> Self {
388 self.pattern = pattern;
389 self
390 }
391
392 pub fn with_supervisor(mut self, supervisor_id: impl Into<String>) -> Self {
394 self.supervisor_id = Some(supervisor_id.into());
395 self.pattern = TeamPattern::Supervised;
396 self
397 }
398
399 pub fn with_aggregate_prompt(mut self, prompt: impl Into<String>) -> Self {
401 self.aggregate_prompt = Some(prompt.into());
402 self
403 }
404
405 pub fn build(self) -> AgentTeam {
407 let member_map: HashMap<String, usize> = self
408 .members
409 .iter()
410 .enumerate()
411 .map(|(i, m)| (m.role.id.clone(), i))
412 .collect();
413
414 AgentTeam {
415 id: self.id,
416 name: self.name,
417 members: self.members,
418 member_map,
419 pattern: self.pattern,
420 supervisor_id: self.supervisor_id,
421 aggregate_prompt: self.aggregate_prompt,
422 }
423 }
424}
425
426pub fn content_creation_team(
434 researcher: Arc<LLMAgent>,
435 writer: Arc<LLMAgent>,
436 editor: Arc<LLMAgent>,
437) -> AgentTeam {
438 AgentTeamBuilder::new("content-creation")
439 .with_name("Content Creation Team")
440 .add_member_with_role(
441 researcher,
442 AgentRole::new("researcher", "Researcher")
443 .with_description("Research and gather information on the topic")
444 .with_template(
445 "Research the following topic thoroughly and provide key findings:\n\n{input}",
446 ),
447 )
448 .add_member_with_role(
449 writer,
450 AgentRole::new("writer", "Writer")
451 .with_description("Write engaging content based on research")
452 .with_template(
453 "Based on the following research, write an engaging article:\n\n{input}",
454 ),
455 )
456 .add_member_with_role(
457 editor,
458 AgentRole::new("editor", "Editor")
459 .with_description("Edit and polish the content")
460 .with_template(
461 "Edit and improve the following article for clarity and engagement:\n\n{input}",
462 ),
463 )
464 .with_pattern(TeamPattern::Chain)
465 .build()
466}
467
468pub fn code_review_team(
472 security_reviewer: Arc<LLMAgent>,
473 performance_reviewer: Arc<LLMAgent>,
474 style_reviewer: Arc<LLMAgent>,
475 supervisor: Arc<LLMAgent>,
476) -> AgentTeam {
477 AgentTeamBuilder::new("code-review")
478 .with_name("Code Review Team")
479 .add_member_with_role(
480 security_reviewer,
481 AgentRole::new("security", "Security Reviewer")
482 .with_description("Review code for security vulnerabilities"),
483 )
484 .add_member_with_role(
485 performance_reviewer,
486 AgentRole::new("performance", "Performance Reviewer")
487 .with_description("Review code for performance issues"),
488 )
489 .add_member_with_role(
490 style_reviewer,
491 AgentRole::new("style", "Style Reviewer")
492 .with_description("Review code for style and best practices"),
493 )
494 .add_member_with_role(
495 supervisor,
496 AgentRole::new("supervisor", "Lead Reviewer")
497 .with_description("Synthesize reviews and provide final feedback"),
498 )
499 .with_supervisor("supervisor")
500 .build()
501}
502
503pub fn debate_team(agent1: Arc<LLMAgent>, agent2: Arc<LLMAgent>, max_rounds: usize) -> AgentTeam {
507 AgentTeamBuilder::new("debate")
508 .with_name("Debate Team")
509 .add_member_with_role(
510 agent1,
511 AgentRole::new("debater1", "Debater 1")
512 .with_description("Present and defend your position"),
513 )
514 .add_member_with_role(
515 agent2,
516 AgentRole::new("debater2", "Debater 2")
517 .with_description("Present an alternative perspective"),
518 )
519 .with_pattern(TeamPattern::Debate { max_rounds })
520 .build()
521}
522
523pub fn analysis_team(analysts: Vec<(impl Into<String>, Arc<LLMAgent>)>) -> AgentTeam {
527 let mut builder = AgentTeamBuilder::new("analysis")
528 .with_name("Analysis Team")
529 .with_pattern(TeamPattern::MapReduce)
530 .with_aggregate_prompt(
531 "Synthesize the following analyses into a comprehensive report:\n\n{results}\n\n\
532 Original question: {input}",
533 );
534
535 for (id, agent) in analysts {
536 builder = builder.add_member(id, agent);
537 }
538
539 builder.build()
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// Explicit builder settings are recorded and unrelated fields keep their defaults.
    #[test]
    fn test_team_builder() {
        let builder = AgentTeamBuilder::new("test-team")
            .with_name("Test Team")
            .with_pattern(TeamPattern::Chain);

        assert_eq!(builder.id, "test-team");
        assert_eq!(builder.name, "Test Team");
        assert!(matches!(builder.pattern, TeamPattern::Chain));
        assert!(builder.members.is_empty());
        assert!(builder.supervisor_id.is_none());
        assert!(builder.aggregate_prompt.is_none());
    }

    /// Without `with_name`, the team name falls back to the id.
    #[test]
    fn test_builder_defaults_name_to_id() {
        let builder = AgentTeamBuilder::new("solo");

        assert_eq!(builder.name, "solo");
        assert!(matches!(builder.pattern, TeamPattern::Chain));
    }

    /// `with_supervisor` records the id and overrides any earlier pattern.
    #[test]
    fn test_with_supervisor_switches_pattern() {
        let builder = AgentTeamBuilder::new("review")
            .with_pattern(TeamPattern::Parallel)
            .with_supervisor("lead");

        assert_eq!(builder.supervisor_id.as_deref(), Some("lead"));
        assert!(matches!(builder.pattern, TeamPattern::Supervised));
    }

    /// `with_aggregate_prompt` stores the template verbatim.
    #[test]
    fn test_with_aggregate_prompt() {
        let builder = AgentTeamBuilder::new("agg").with_aggregate_prompt("{results} / {input}");

        assert_eq!(builder.aggregate_prompt.as_deref(), Some("{results} / {input}"));
    }

    /// All role builder methods store exactly the provided values.
    #[test]
    fn test_agent_role() {
        let role = AgentRole::new("researcher", "Researcher")
            .with_description("Research topics")
            .with_template("{input}");

        assert_eq!(role.id, "researcher");
        assert_eq!(role.name, "Researcher");
        assert_eq!(role.description, "Research topics");
        assert_eq!(role.prompt_template.as_deref(), Some("{input}"));
    }

    /// A freshly created role has an empty description and no template.
    #[test]
    fn test_agent_role_defaults() {
        let role = AgentRole::new("id", "Name");

        assert_eq!(role.description, "");
        assert!(role.prompt_template.is_none());
    }
}