use std::sync::Arc;
use async_trait::async_trait;
use crate::agent::Agent;
use crate::context::InvocationContext;
use crate::error::AgentError;
/// A composite agent that runs all of its sub-agents concurrently.
///
/// Sub-agents are kept in the order they were supplied to [`ParallelAgent::new`].
pub struct ParallelAgent {
    /// Name returned by `Agent::name`.
    name: String,
    /// Children launched by `run_live`, each on its own Tokio task.
    sub_agents: Vec<Arc<dyn Agent>>,
}

impl ParallelAgent {
    /// Creates a parallel agent called `name` that will run every agent in `sub_agents`.
    pub fn new(name: impl Into<String>, sub_agents: Vec<Arc<dyn Agent>>) -> Self {
        let name: String = name.into();
        Self { sub_agents, name }
    }
}
#[async_trait]
impl Agent for ParallelAgent {
fn name(&self) -> &str {
&self.name
}
async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
let mut handles = Vec::new();
for sub in &self.sub_agents {
let sub = sub.clone();
let agent_session = ctx.agent_session.clone();
let event_tx = ctx.event_tx.clone();
let middleware = ctx.middleware.clone();
handles.push(tokio::spawn(async move {
let mut branch_ctx = InvocationContext {
agent_session,
event_tx,
middleware,
run_config: crate::run_config::RunConfig::default(),
session_id: None,
artifact_service: None,
memory_service: None,
session_service: None,
};
sub.run_live(&mut branch_ctx).await
}));
}
for handle in handles {
handle
.await
.map_err(|e| AgentError::Other(format!("Join error: {}", e)))??;
}
Ok(())
}
fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
self.sub_agents.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent_session::{AgentSession, NoOpSessionWriter};
    use crate::context::InvocationContext;
    use crate::error::AgentError;
    use std::sync::Arc;
    use tokio::sync::broadcast;

    /// Builds a root context backed by a no-op session writer.
    fn test_ctx() -> InvocationContext {
        let (event_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn rs_genai::session::SessionWriter> = Arc::new(NoOpSessionWriter);
        let agent_session = AgentSession::from_writer(writer, event_tx);
        InvocationContext::new(agent_session)
    }

    /// Test agent that stores `value` under `key` in the context state.
    struct StateSetAgent {
        agent_name: String,
        key: String,
        value: String,
    }

    #[async_trait]
    impl Agent for StateSetAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            ctx.state().set(&self.key, &self.value);
            Ok(())
        }
    }

    /// Test agent that always fails with `AgentError::Other("parallel fail")`.
    struct FailAgent {
        agent_name: String,
    }

    #[async_trait]
    impl Agent for FailAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Err(AgentError::Other("parallel fail".to_string()))
        }
    }

    /// Test agent whose task panics when run.
    struct PanicAgent {
        agent_name: String,
    }

    #[async_trait]
    impl Agent for PanicAgent {
        fn name(&self) -> &str {
            &self.agent_name
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            panic!("parallel panic")
        }
    }

    #[tokio::test]
    async fn parallel_runs_all() {
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(StateSetAgent {
                agent_name: "a".into(),
                key: "key_a".into(),
                value: "val_a".into(),
            }),
            Arc::new(StateSetAgent {
                agent_name: "b".into(),
                key: "key_b".into(),
                value: "val_b".into(),
            }),
            Arc::new(StateSetAgent {
                agent_name: "c".into(),
                key: "key_c".into(),
                value: "val_c".into(),
            }),
        ];
        let par = ParallelAgent::new("par", agents);
        let mut ctx = test_ctx();
        par.run_live(&mut ctx).await.unwrap();
        assert_eq!(
            ctx.state().get::<String>("key_a"),
            Some("val_a".to_string())
        );
        assert_eq!(
            ctx.state().get::<String>("key_b"),
            Some("val_b".to_string())
        );
        assert_eq!(
            ctx.state().get::<String>("key_c"),
            Some("val_c".to_string())
        );
    }

    #[tokio::test]
    async fn parallel_fails_if_any_fails() {
        let agents: Vec<Arc<dyn Agent>> = vec![
            Arc::new(StateSetAgent {
                agent_name: "a".into(),
                key: "key_a".into(),
                value: "val_a".into(),
            }),
            Arc::new(FailAgent {
                agent_name: "b".into(),
            }),
            Arc::new(StateSetAgent {
                agent_name: "c".into(),
                key: "key_c".into(),
                value: "val_c".into(),
            }),
        ];
        let par = ParallelAgent::new("par", agents);
        let mut ctx = test_ctx();
        let result = par.run_live(&mut ctx).await;
        // The failing child's own error must be propagated, not some other error.
        assert!(
            matches!(&result, Err(AgentError::Other(msg)) if msg == "parallel fail"),
            "expected the FailAgent's error to be propagated"
        );
    }

    #[tokio::test]
    async fn parallel_reports_panicking_child_as_join_error() {
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(PanicAgent {
            agent_name: "boom".into(),
        })];
        let par = ParallelAgent::new("par", agents);
        let mut ctx = test_ctx();
        let result = par.run_live(&mut ctx).await;
        assert!(
            matches!(&result, Err(AgentError::Other(msg)) if msg.starts_with("Join error:")),
            "expected a panicking child to surface as a join error"
        );
    }

    #[tokio::test]
    async fn parallel_empty_succeeds() {
        let par = ParallelAgent::new("empty", vec![]);
        let mut ctx = test_ctx();
        par.run_live(&mut ctx).await.unwrap();
    }

    #[test]
    fn parallel_sub_agents_returns_children() {
        let agents: Vec<Arc<dyn Agent>> = vec![Arc::new(StateSetAgent {
            agent_name: "child".into(),
            key: "k".into(),
            value: "v".into(),
        })];
        let par = ParallelAgent::new("par", agents);
        assert_eq!(par.sub_agents().len(), 1);
        assert_eq!(par.sub_agents()[0].name(), "child");
    }
}