use std::fmt;
use traitclaw_core::agent::Agent;
use traitclaw_core::types::message::{Message, MessageRole};
/// Decides when a [`RoundRobinGroupChat`] should stop handing out turns.
///
/// Implementors must be `Send + Sync`.
pub trait TerminationCondition: Sync + Send {
    /// Returns `true` to end the chat instead of taking another turn.
    ///
    /// [`RoundRobinGroupChat::run`] calls this before every turn. `round` is the
    /// number of agent turns already taken (starting at 0). `messages` is the
    /// transcript so far: the initial user task first, then one assistant message
    /// per completed turn.
    fn should_terminate(&self, round: usize, messages: &[Message]) -> bool;
}
/// A [`TerminationCondition`] that ends the chat after a fixed number of rounds.
#[derive(Clone, Debug)]
pub struct MaxRoundsTermination {
    /// Round count at (and beyond) which termination is reported.
    max_rounds: usize,
}

impl MaxRoundsTermination {
    /// Builds a condition that stops once `max_rounds` rounds have been taken.
    ///
    /// With `max_rounds == 0` the condition is already satisfied at round 0.
    #[must_use]
    pub fn new(max_rounds: usize) -> Self {
        MaxRoundsTermination { max_rounds }
    }
}
impl TerminationCondition for MaxRoundsTermination {
    /// Reports termination once `round` has reached the configured limit.
    /// The transcript is ignored.
    fn should_terminate(&self, round: usize, _messages: &[Message]) -> bool {
        let limit = self.max_rounds;
        limit <= round
    }
}
/// Outcome of [`RoundRobinGroupChat::run`].
#[derive(Clone, Debug)]
pub struct GroupChatResult {
    /// All messages in order: the initial user task, then one assistant message
    /// per agent turn.
    pub transcript: Vec<Message>,
    /// Content of the last message in `transcript`. This is the task itself if
    /// no agent turn was taken.
    pub final_message: String,
}
/// Runs a list of agents in fixed rotation over a shared transcript.
///
/// Turn `i` goes to agent `i % agents.len()`. Before each turn the chat's
/// [`TerminationCondition`] decides whether to stop.
pub struct RoundRobinGroupChat {
    /// Participants, in speaking order.
    agents: Vec<Agent>,
    /// Consulted before every turn to decide whether the chat is over.
    termination: Box<dyn TerminationCondition>,
}

impl fmt::Debug for RoundRobinGroupChat {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the agent count is shown. The boxed termination condition is not
        // `Debug`, so it is elided; `finish_non_exhaustive` renders a trailing `..`
        // so the output does not pretend to list every field.
        f.debug_struct("RoundRobinGroupChat")
            .field("agents", &self.agents.len())
            .finish_non_exhaustive()
    }
}
impl RoundRobinGroupChat {
    /// Creates a chat over `agents` with a default budget of three rounds per agent.
    ///
    /// The default condition is a [`MaxRoundsTermination`] of `agents.len() * 3`
    /// rounds. The multiplication saturates, and the budget is never less than 1.
    /// Override it with [`with_max_rounds`](Self::with_max_rounds) or
    /// [`with_termination`](Self::with_termination).
    #[must_use]
    pub fn new(agents: Vec<Agent>) -> Self {
        let max_rounds = agents.len().saturating_mul(3).max(1);
        Self {
            termination: Box::new(MaxRoundsTermination::new(max_rounds)),
            agents,
        }
    }

    /// Replaces the termination condition with a fixed cap of `n` rounds.
    ///
    /// Any condition set earlier, including one from
    /// [`with_termination`](Self::with_termination), is discarded. With `n == 0`
    /// no agent is ever invoked.
    #[must_use]
    pub fn with_max_rounds(mut self, n: usize) -> Self {
        self.termination = Box::new(MaxRoundsTermination::new(n));
        self
    }

    /// Replaces the termination condition with a custom one.
    ///
    /// This discards the default round cap. If `t` never returns `true`,
    /// [`run`](Self::run) keeps invoking agents indefinitely.
    #[must_use]
    pub fn with_termination(mut self, t: impl TerminationCondition + 'static) -> Self {
        self.termination = Box::new(t);
        self
    }

    /// Number of participating agents.
    #[must_use]
    pub fn len(&self) -> usize {
        self.agents.len()
    }

    /// Returns `true` if the chat has no agents.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.agents.is_empty()
    }

    /// Runs the chat on `task` and returns the full transcript.
    ///
    /// The transcript starts with `task` as a user message. Before each turn the
    /// termination condition receives the number of turns taken so far and the
    /// transcript.
    ///
    /// If the chat continues, the next agent in rotation receives the entire
    /// transcript rendered as text: one `[Role]: content` entry per message,
    /// separated by blank lines. The agent's reply is appended as an assistant
    /// message.
    ///
    /// All replies share the `Assistant` role, so the rendered context does not
    /// say which agent produced which reply.
    ///
    /// # Errors
    ///
    /// Returns `Error::Runtime` if the chat has no agents. Otherwise propagates
    /// the first error returned by an agent's `run`; the partial transcript is
    /// discarded in that case.
    pub async fn run(&self, task: &str) -> traitclaw_core::Result<GroupChatResult> {
        if self.agents.is_empty() {
            return Err(traitclaw_core::Error::Runtime(
                "RoundRobinGroupChat::run() called with no agents".into(),
            ));
        }
        let mut transcript = vec![Message {
            role: MessageRole::User,
            content: task.to_string(),
            tool_call_id: None,
        }];
        let n_agents = self.agents.len();
        let mut round = 0;
        loop {
            if self.termination.should_terminate(round, &transcript) {
                break;
            }
            let agent_idx = round % n_agents;
            let agent = &self.agents[agent_idx];
            let context = Self::format_transcript(&transcript);
            let output = agent.run(&context).await?;
            let response_text = output.text().to_string();
            transcript.push(Message {
                role: MessageRole::Assistant,
                content: response_text,
                tool_call_id: None,
            });
            round += 1;
        }
        // The transcript always holds at least the initial task, so `last()` is
        // `Some`; the default is only a defensive fallback.
        let final_message = transcript
            .last()
            .map(|m| m.content.clone())
            .unwrap_or_default();
        Ok(GroupChatResult {
            transcript,
            final_message,
        })
    }

    /// Renders `messages` as `[Role]: content` entries separated by a blank line.
    ///
    /// This runs once per turn over a transcript that keeps growing. It therefore
    /// writes into one pre-sized buffer instead of allocating a `String` per
    /// message and joining them.
    fn format_transcript(messages: &[Message]) -> String {
        // 16 bytes per message covers the longest label "[Assistant]: " (13)
        // plus the "\n\n" separator (2).
        let capacity: usize = messages.iter().map(|m| m.content.len() + 16).sum();
        let mut out = String::with_capacity(capacity);
        for (i, m) in messages.iter().enumerate() {
            if i > 0 {
                out.push_str("\n\n");
            }
            let role = match m.role {
                MessageRole::User => "User",
                MessageRole::Assistant => "Assistant",
                MessageRole::System => "System",
                MessageRole::Tool => "Tool",
                _ => "Unknown",
            };
            out.push('[');
            out.push_str(role);
            out.push_str("]: ");
            out.push_str(&m.content);
        }
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests_common::EchoProvider;

    /// Builds a plain message with no tool-call id.
    fn message(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            tool_call_id: None,
        }
    }

    #[test]
    fn test_max_rounds_at_boundary() {
        let term = MaxRoundsTermination::new(3);
        assert!(!term.should_terminate(0, &[]));
        assert!(!term.should_terminate(1, &[]));
        assert!(!term.should_terminate(2, &[]));
        assert!(term.should_terminate(3, &[]));
        assert!(term.should_terminate(4, &[]));
    }

    #[test]
    fn test_max_rounds_zero() {
        let term = MaxRoundsTermination::new(0);
        assert!(term.should_terminate(0, &[]));
    }

    #[tokio::test]
    async fn test_group_chat_new_default_rounds() {
        let agents = vec![
            Agent::with_system(EchoProvider::new("A"), "Agent A"),
            Agent::with_system(EchoProvider::new("B"), "Agent B"),
        ];
        let chat = RoundRobinGroupChat::new(agents);
        assert_eq!(chat.len(), 2);
        // Default budget is three rounds per agent: the task plus 2 * 3 replies.
        let result = chat.run("Test").await.unwrap();
        assert_eq!(result.transcript.len(), 1 + 2 * 3);
    }

    #[tokio::test]
    async fn test_group_chat_with_max_rounds() {
        let agents = vec![Agent::with_system(EchoProvider::new("A"), "Agent A")];
        let chat = RoundRobinGroupChat::new(agents).with_max_rounds(10);
        assert_eq!(chat.len(), 1);
        // The explicit cap replaces the single-agent default of 3 rounds.
        let result = chat.run("Test").await.unwrap();
        assert_eq!(result.transcript.len(), 1 + 10);
    }

    #[tokio::test]
    async fn test_group_chat_zero_rounds_returns_task() {
        let agents = vec![Agent::with_system(EchoProvider::new("A"), "Agent A")];
        let chat = RoundRobinGroupChat::new(agents).with_max_rounds(0);
        let result = chat.run("Test").await.unwrap();
        // No agent speaks, so the transcript is just the task.
        assert_eq!(result.transcript.len(), 1);
        assert_eq!(result.final_message, "Test");
    }

    #[tokio::test]
    async fn test_group_chat_run_basic() {
        let agents = vec![
            Agent::with_system(EchoProvider::new("R"), "Researcher"),
            Agent::with_system(EchoProvider::new("W"), "Writer"),
        ];
        let chat = RoundRobinGroupChat::new(agents).with_max_rounds(2);
        let result = chat.run("Discuss Rust").await.unwrap();
        assert_eq!(result.transcript.len(), 3);
        assert!(!result.final_message.is_empty());
    }

    #[tokio::test]
    async fn test_group_chat_round_robin_order() {
        let agents = vec![
            Agent::with_system(EchoProvider::new("FIRST"), "First"),
            Agent::with_system(EchoProvider::new("SECOND"), "Second"),
        ];
        let chat = RoundRobinGroupChat::new(agents).with_max_rounds(4);
        let result = chat.run("Test").await.unwrap();
        assert_eq!(result.transcript.len(), 5);
        assert!(result.transcript[1].content.contains("[FIRST]"));
        assert!(result.transcript[2].content.contains("[SECOND]"));
        assert!(result.transcript[3].content.contains("[FIRST]"));
        assert!(result.transcript[4].content.contains("[SECOND]"));
    }

    #[tokio::test]
    async fn test_group_chat_empty_agents_returns_error() {
        let chat = RoundRobinGroupChat::new(vec![]);
        let result = chat.run("Test").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_group_chat_custom_termination() {
        struct ContainsKeyword;
        impl TerminationCondition for ContainsKeyword {
            fn should_terminate(&self, _round: usize, messages: &[Message]) -> bool {
                messages.iter().any(|m| m.content.contains("DONE"))
            }
        }
        let agents = vec![Agent::with_system(EchoProvider::new("DONE"), "Agent")];
        let chat = RoundRobinGroupChat::new(agents).with_termination(ContainsKeyword);
        let result = chat.run("Test").await.unwrap();
        assert_eq!(result.transcript.len(), 2);
    }

    #[test]
    fn test_format_transcript() {
        let messages = [
            message(MessageRole::User, "hi"),
            message(MessageRole::Assistant, "yo"),
        ];
        assert_eq!(
            RoundRobinGroupChat::format_transcript(&messages),
            "[User]: hi\n\n[Assistant]: yo"
        );
        assert_eq!(RoundRobinGroupChat::format_transcript(&[]), "");
    }

    #[test]
    fn test_group_chat_debug() {
        let chat = RoundRobinGroupChat::new(vec![Agent::with_system(EchoProvider::new("A"), "A")]);
        let debug = format!("{chat:?}");
        assert!(debug.contains("RoundRobinGroupChat"));
    }
}