use crate::error::MultiError;
use crate::mailbox::Mailbox;
use crate::runner::AgentRunner;
use crate::shared::SharedInfra;
use crate::types::{AgentOutput, AgentSpec};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use tracing::instrument;
/// Aggregate result of a single [`Fleet::run`] invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FleetResult {
    /// One output per agent, in the same order as the fleet's `agents` list.
    pub outputs: Vec<AgentOutput>,
    /// Elapsed time for the whole fleet run, in milliseconds, measured from the
    /// start of `Fleet::run` until every agent task has been awaited.
    pub duration_ms: f64,
    /// Number of agents whose returned output reported success (`AgentOutput::succeeded`).
    pub succeeded: usize,
    /// Number of agents that did not succeed: unsuccessful outputs, runner errors,
    /// and tasks that could not be joined.
    pub failed: usize,
}
/// A set of agents that [`Fleet::run`] executes concurrently, one tokio task per agent.
pub struct Fleet {
    /// Agent specifications to run; each agent's task text is read from the
    /// `"task"` entry of its spec's metadata.
    pub agents: Vec<AgentSpec>,
    /// Optional per-agent timeout in seconds; `None` means agents run without a time limit.
    pub agent_timeout_secs: Option<u64>,
}
impl Fleet {
pub fn new(agents: Vec<AgentSpec>) -> Self {
Self {
agents,
agent_timeout_secs: None,
}
}
pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
self.agent_timeout_secs = Some(timeout_secs);
self
}
#[instrument(name = "multi.fleet", skip_all)]
pub async fn run(
&self,
runner: &Arc<dyn AgentRunner>,
infra: &SharedInfra,
) -> Result<FleetResult, MultiError> {
let start = Instant::now();
let mailbox = Arc::new(Mailbox::default());
let mut handles = Vec::new();
for spec in &self.agents {
let runner = Arc::clone(runner);
let rt = infra.make_runtime();
let spec = spec.clone();
let mailbox = Arc::clone(&mailbox);
let timeout = self.agent_timeout_secs;
let _rx = mailbox.register(&spec.name).await;
let handle = tokio::spawn(async move {
let task = spec
.metadata
.get("task")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let result = if let Some(secs) = timeout {
match tokio::time::timeout(
tokio::time::Duration::from_secs(secs),
runner.run(&spec, &task, &rt, &mailbox),
)
.await
{
Ok(result) => result,
Err(_) => Ok(AgentOutput {
name: spec.name.clone(),
answer: format!("Timed out after {}s", secs),
turns: 0,
tool_calls: 0,
duration_ms: (secs * 1000) as f64,
error: Some("timeout".to_string()),
outcome: None,
tokens: None,
}),
}
} else {
runner.run(&spec, &task, &rt, &mailbox).await
};
mailbox.unregister(&spec.name).await;
result
});
handles.push(handle);
}
let mut outputs = Vec::new();
let mut succeeded = 0;
let mut failed = 0;
for handle in handles {
match handle.await {
Ok(Ok(output)) => {
if output.succeeded() {
succeeded += 1;
} else {
failed += 1;
}
outputs.push(output);
}
Ok(Err(e)) => {
failed += 1;
outputs.push(AgentOutput {
name: "unknown".to_string(),
answer: String::new(),
turns: 0,
tool_calls: 0,
duration_ms: 0.0,
error: Some(e.to_string()),
outcome: None,
tokens: None,
});
}
Err(e) => {
failed += 1;
outputs.push(AgentOutput {
name: "unknown".to_string(),
answer: String::new(),
turns: 0,
tool_calls: 0,
duration_ms: 0.0,
error: Some(format!("Task join error: {}", e)),
outcome: None,
tokens: None,
});
}
}
}
Ok(FleetResult {
outputs,
duration_ms: start.elapsed().as_millis() as f64,
succeeded,
failed,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::AgentSpec;

    /// A freshly built fleet keeps every spec it was given and records the
    /// timeout configured through the builder.
    #[test]
    fn test_fleet_construction() {
        let specs = vec![
            AgentSpec::new("a", "Agent A system prompt"),
            AgentSpec::new("b", "Agent B system prompt"),
        ];
        let fleet = Fleet::new(specs).with_timeout(60);

        assert_eq!(fleet.agents.len(), 2);
        assert_eq!(fleet.agent_timeout_secs, Some(60));
    }
}