use anyhow::Result;
use oxi_agent::Agent;
use std::sync::Arc;
/// How an [`AgentGroup`] distributes a prompt across its member agents.
#[derive(Clone, Debug)]
pub enum GroupStrategy {
    /// Run the agents one after another in insertion order, feeding each
    /// agent's response content to the next agent; stops at the first failure.
    Pipeline,
    /// Send the same prompt to every agent concurrently, with at most
    /// `max_concurrency` agent runs in flight at a time.
    Parallel { max_concurrency: usize },
    /// Run only the agent at index `leader` (insertion order) and return its
    /// output; the other agents are not invoked.
    Orchestrated { leader: usize },
}

impl Default for GroupStrategy {
    /// Defaults to [`GroupStrategy::Parallel`] with a concurrency limit of 4.
    fn default() -> Self {
        Self::Parallel { max_concurrency: 4 }
    }
}
/// Outcome of a single agent's run within an [`AgentGroup`].
#[derive(Clone, Debug)]
pub struct AgentGroupOutput {
    /// Identifier of the agent that produced this output (normally its model id).
    pub name: String,
    /// The agent's response text; empty when the run failed.
    pub content: String,
    /// Whether the agent's run completed successfully.
    pub success: bool,
    /// Failure description when `success` is `false`, otherwise `None`.
    pub error: Option<String>,
}

/// Aggregated outcome of running an [`AgentGroup`].
#[derive(Debug)]
pub struct GroupResult {
    /// Per-agent outputs, ordered as the agents appear in the group; which
    /// agents are present depends on the strategy that was used.
    pub results: Vec<AgentGroupOutput>,
    /// Wall-clock time the whole group run took, in milliseconds.
    pub total_duration_ms: u64,
}

impl GroupResult {
    /// Returns `true` when no output reports a failure.
    ///
    /// An empty result set counts as success.
    pub fn all_succeeded(&self) -> bool {
        !self.results.iter().any(|output| !output.success)
    }

    /// Concatenates every output's `content`, separated by a blank line (`"\n\n"`).
    ///
    /// Every output is included regardless of `success`, so a failed output
    /// (whose content is typically empty) still contributes a separator.
    pub fn combined_content(&self) -> String {
        let mut combined = String::new();
        for (index, output) in self.results.iter().enumerate() {
            if index > 0 {
                combined.push_str("\n\n");
            }
            combined.push_str(&output.content);
        }
        combined
    }
}
pub struct AgentGroup {
agents: Vec<Arc<Agent>>,
strategy: GroupStrategy,
}
impl AgentGroup {
pub fn new(strategy: GroupStrategy) -> Self {
Self {
agents: Vec::new(),
strategy,
}
}
pub fn agent(mut self, agent: Arc<Agent>) -> Self {
self.agents.push(agent);
self
}
pub fn len(&self) -> usize {
self.agents.len()
}
pub fn is_empty(&self) -> bool {
self.agents.is_empty()
}
pub async fn run(&self, prompt: String) -> Result<GroupResult> {
if self.agents.is_empty() {
return Ok(GroupResult {
results: Vec::new(),
total_duration_ms: 0,
});
}
let start = std::time::Instant::now();
let results = match &self.strategy {
GroupStrategy::Pipeline => self.run_pipeline(prompt).await?,
GroupStrategy::Parallel { max_concurrency } => {
self.run_parallel(prompt, *max_concurrency).await?
}
GroupStrategy::Orchestrated { leader } => {
self.run_orchestrated(prompt, *leader).await?
}
};
Ok(GroupResult {
total_duration_ms: start.elapsed().as_millis() as u64,
results,
})
}
async fn run_pipeline(&self, prompt: String) -> Result<Vec<AgentGroupOutput>> {
let mut results = Vec::with_capacity(self.agents.len());
let mut current_input = prompt;
for agent in &self.agents {
match agent.run(current_input.clone()).await {
Ok((response, _events)) => {
results.push(AgentGroupOutput {
name: agent.model_id(),
content: response.content.clone(),
success: true,
error: None,
});
current_input = response.content;
}
Err(e) => {
results.push(AgentGroupOutput {
name: agent.model_id(),
content: String::new(),
success: false,
error: Some(e.to_string()),
});
break;
}
}
}
Ok(results)
}
async fn run_parallel(
&self,
prompt: String,
max_concurrency: usize,
) -> Result<Vec<AgentGroupOutput>> {
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_concurrency));
let mut handles = Vec::with_capacity(self.agents.len());
for agent in self.agents.iter() {
let agent = Arc::clone(agent);
let prompt = prompt.clone();
let sem = Arc::clone(&semaphore);
handles.push(tokio::task::spawn_blocking(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to create runtime");
rt.block_on(async move {
let _permit = sem.acquire().await.expect("semaphore closed");
match agent.run(prompt).await {
Ok((response, _events)) => AgentGroupOutput {
name: agent.model_id(),
content: response.content,
success: true,
error: None,
},
Err(e) => AgentGroupOutput {
name: agent.model_id(),
content: String::new(),
success: false,
error: Some(e.to_string()),
},
}
})
}));
}
let mut results = Vec::with_capacity(handles.len());
for handle in handles {
match handle.await {
Ok(output) => results.push(output),
Err(e) => results.push(AgentGroupOutput {
name: String::new(),
content: String::new(),
success: false,
error: Some(format!("Join error: {}", e)),
}),
}
}
Ok(results)
}
async fn run_orchestrated(
&self,
prompt: String,
leader_idx: usize,
) -> Result<Vec<AgentGroupOutput>> {
if leader_idx >= self.agents.len() {
anyhow::bail!(
"Leader index {} out of range ({} agents)",
leader_idx,
self.agents.len()
);
}
let leader = &self.agents[leader_idx];
let (response, _events) = leader.run(prompt).await?;
Ok(vec![AgentGroupOutput {
name: leader.model_id(),
content: response.content,
success: true,
error: None,
}])
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_group_strategy_default() {
        let strategy = GroupStrategy::default();
        // `matches!` only evaluates to a bool; it must be asserted, or the
        // test can never fail.
        assert!(matches!(
            strategy,
            GroupStrategy::Parallel { max_concurrency: 4 }
        ));
    }

    #[test]
    fn test_empty_group() {
        let group = AgentGroup::new(GroupStrategy::Pipeline);
        assert!(group.is_empty());
        assert_eq!(group.len(), 0);
    }

    #[test]
    fn test_group_result_all_succeeded() {
        let result = GroupResult {
            results: vec![
                AgentGroupOutput {
                    name: "a".into(),
                    content: "ok".into(),
                    success: true,
                    error: None,
                },
                AgentGroupOutput {
                    name: "b".into(),
                    content: "ok".into(),
                    success: true,
                    error: None,
                },
            ],
            total_duration_ms: 100,
        };
        assert!(result.all_succeeded());
    }

    #[test]
    fn test_group_result_with_failure() {
        let result = GroupResult {
            results: vec![
                AgentGroupOutput {
                    name: "a".into(),
                    content: "ok".into(),
                    success: true,
                    error: None,
                },
                AgentGroupOutput {
                    name: "b".into(),
                    content: String::new(),
                    success: false,
                    error: Some("boom".into()),
                },
            ],
            total_duration_ms: 100,
        };
        assert!(!result.all_succeeded());
        // Failed outputs still contribute their (empty) content and a separator.
        assert_eq!(result.combined_content(), "ok\n\n");
    }

    #[test]
    fn test_empty_group_result() {
        let result = GroupResult {
            results: Vec::new(),
            total_duration_ms: 0,
        };
        assert!(result.all_succeeded());
        assert_eq!(result.combined_content(), "");
    }

    #[test]
    fn test_group_result_combined_content() {
        let result = GroupResult {
            results: vec![
                AgentGroupOutput {
                    name: "a".into(),
                    content: "first".into(),
                    success: true,
                    error: None,
                },
                AgentGroupOutput {
                    name: "b".into(),
                    content: "second".into(),
                    success: true,
                    error: None,
                },
            ],
            total_duration_ms: 100,
        };
        assert_eq!(result.combined_content(), "first\n\nsecond");
    }
}