use crate::orchestration::{
agent::{Agent, AgentInput, AgentOutput},
context::ExecutionTrace,
errors::Result,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Input handed to an orchestrator run.
///
/// Build one with [`OrchestratorInput::new`] and the `with_*` builder methods,
/// or deserialize it; only `content` is required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestratorInput {
/// Main textual payload of the request.
pub content: String,
/// Free-form JSON context.
///
/// When the field is absent from deserialized data it takes
/// `serde_json::Value::default()`, i.e. `Value::Null`; note that
/// [`OrchestratorInput::new`] starts from an empty JSON object instead.
#[serde(default)]
pub context: serde_json::Value,
/// String key/value annotations; an empty map when absent from
/// deserialized data.
#[serde(default)]
pub metadata: HashMap<String, String>,
}
impl OrchestratorInput {
    /// Creates an input carrying `content`, with an empty JSON object as
    /// context and no metadata entries.
    pub fn new(content: impl Into<String>) -> Self {
        let content: String = content.into();
        // Same value `serde_json::json!({})` produces: an object with no keys.
        let context = serde_json::Value::Object(serde_json::Map::new());
        let metadata = HashMap::new();
        Self {
            content,
            context,
            metadata,
        }
    }

    /// Returns the input with its context replaced wholesale by `context`.
    pub fn with_context(self, context: serde_json::Value) -> Self {
        Self { context, ..self }
    }

    /// Returns the input with `key` set to `value` in the metadata map.
    ///
    /// An existing entry under the same key is overwritten.
    pub fn with_metadata(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut updated = self;
        let (key, value): (String, String) = (key.into(), value.into());
        updated.metadata.insert(key, value);
        updated
    }
}
/// Outcome of an orchestrator run.
///
/// Construct it with [`OrchestratorOutput::success`] or
/// [`OrchestratorOutput::failure`], which keep `success` and `error`
/// consistent with each other.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestratorOutput {
/// Final result text; the empty string when built via `failure`.
pub result: String,
/// Outputs collected from individual agents; empty when built via `failure`.
pub agent_outputs: Vec<AgentOutput>,
/// Trace recorded for the run.
pub execution_trace: ExecutionTrace,
/// `true` when built via `success`, `false` when built via `failure`.
pub success: bool,
/// Error message; `None` when built via `success`.
pub error: Option<String>,
}
impl OrchestratorOutput {
    /// Builds a successful output: `success` is `true` and `error` is `None`.
    pub fn success(
        result: impl Into<String>,
        agent_outputs: Vec<AgentOutput>,
        execution_trace: ExecutionTrace,
    ) -> Self {
        let result: String = result.into();
        Self {
            success: true,
            error: None,
            result,
            agent_outputs,
            execution_trace,
        }
    }

    /// Builds a failed output carrying `error` as its message.
    ///
    /// `result` is left empty and no agent outputs are attached.
    pub fn failure(error: impl Into<String>, execution_trace: ExecutionTrace) -> Self {
        let message: String = error.into();
        Self {
            success: false,
            error: Some(message),
            result: String::new(),
            agent_outputs: Vec::new(),
            execution_trace,
        }
    }

    /// Reports whether this output was marked successful.
    pub fn is_successful(&self) -> bool {
        let Self { success, .. } = self;
        *success
    }
}
/// Interface for a strategy that runs a set of agents over one input.
///
/// Implementors must be `Send + Sync`. The `async_trait` attribute is applied
/// so that `orchestrate` can be declared as an `async fn`.
#[async_trait::async_trait]
pub trait Orchestrator: Send + Sync {
/// Name of this orchestrator (implementor-defined).
fn name(&self) -> &str;
/// Description of this orchestrator (implementor-defined).
fn description(&self) -> &str;
/// Runs the orchestration.
///
/// Takes ownership of the boxed `agents` and of `input`, and resolves to
/// an [`OrchestratorOutput`] or an error of the crate's `Result` type.
/// How the agents are scheduled is up to the implementor.
async fn orchestrate(
&self,
agents: Vec<Box<dyn Agent>>,
input: OrchestratorInput,
) -> Result<OrchestratorOutput>;
}
/// Holds an orchestrator's name and description and offers shared helpers
/// (agent execution with retry, input conversion) through its inherent
/// methods.
pub struct BaseOrchestrator {
// Returned by `name()`; also written into agent input metadata under the
// "orchestrator" key by `input_to_agent_input`.
name: String,
// Returned by `description()`.
description: String,
}
impl BaseOrchestrator {
    /// Creates a base orchestrator with the given name and description.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
        }
    }

    /// Returns the orchestrator's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the orchestrator's description.
    pub fn description(&self) -> &str {
        &self.description
    }

    /// Executes `agent` on `input`, retrying on error with exponential backoff.
    ///
    /// Up to `max_retries + 1` attempts are made, each on a fresh clone of
    /// `input`. After failed attempt number `attempt` (0-based) the task
    /// sleeps for `100ms * 2^attempt` before trying again; no sleep follows
    /// the final attempt.
    ///
    /// This method never returns an error. The first successful output is
    /// returned as-is; if every attempt fails, the result is an
    /// [`AgentOutput`] whose content describes the failure (agent name, retry
    /// count and the last error's message) and whose confidence is `0.0`.
    pub async fn execute_agent_with_retry(
        &self,
        agent: &dyn Agent,
        input: AgentInput,
        max_retries: usize,
    ) -> AgentOutput {
        let mut last_error = None;
        for attempt in 0..=max_retries {
            match agent.execute(input.clone()).await {
                Ok(output) => return output,
                Err(e) => {
                    last_error = Some(e.to_string());
                    // Back off only when another attempt will follow.
                    if attempt < max_retries {
                        tokio::time::sleep(Self::retry_backoff(attempt)).await;
                    }
                }
            }
        }
        AgentOutput::new(format!(
            "Agent {} failed after {} retries: {}",
            agent.name(),
            max_retries,
            last_error.unwrap_or_else(|| "Unknown error".to_string())
        ))
        .with_confidence(0.0)
    }

    /// Converts an orchestrator input into an agent input.
    ///
    /// The content and a clone of the context are carried over, and the
    /// orchestrator's name is recorded under the `"orchestrator"` metadata
    /// key. Note that `input.metadata` itself is not copied.
    pub fn input_to_agent_input(&self, input: &OrchestratorInput) -> AgentInput {
        AgentInput::new(&input.content)
            .with_context(input.context.clone())
            .with_metadata("orchestrator", self.name())
    }

    /// Computes the delay that follows failed attempt `attempt` (0-based):
    /// `100ms * 2^attempt`, saturating at `u64::MAX` milliseconds rather than
    /// overflowing when `attempt` is large.
    fn retry_backoff(attempt: usize) -> std::time::Duration {
        // Any exponent >= 64 already overflows `u64`, so clamping first makes
        // the narrowing cast lossless without changing the outcome.
        let exponent = attempt.min(64) as u32;
        let factor = 2_u64.checked_pow(exponent).unwrap_or(u64::MAX);
        std::time::Duration::from_millis(100_u64.saturating_mul(factor))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::orchestration::agent::SimpleAgent;

    /// The builder methods populate content, context and metadata.
    // Synchronous: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_orchestrator_input() {
        let input = OrchestratorInput::new("Test content")
            .with_context(serde_json::json!({"key": "value"}))
            .with_metadata("meta1", "value1");
        assert_eq!(input.content, "Test content");
        assert_eq!(input.context["key"], "value");
        assert_eq!(input.metadata["meta1"], "value1");
    }

    /// `success` and `failure` set the flag, message and payload fields
    /// consistently.
    #[test]
    fn test_orchestrator_output() {
        use crate::orchestration::context::ExecutionTrace;
        let trace = ExecutionTrace::new();
        let outputs = vec![AgentOutput::new("result1")];

        let success = OrchestratorOutput::success("Final result", outputs, trace.clone());
        assert!(success.is_successful());
        assert_eq!(success.result, "Final result");
        assert_eq!(success.agent_outputs.len(), 1);
        assert!(success.error.is_none());

        let failure = OrchestratorOutput::failure("Something went wrong", trace);
        assert!(!failure.is_successful());
        assert!(failure.result.is_empty());
        assert!(failure.agent_outputs.is_empty());
        assert_eq!(failure.error, Some("Something went wrong".to_string()));
    }

    /// The accessors return what was passed to `new`.
    #[test]
    fn test_base_orchestrator() {
        let orchestrator = BaseOrchestrator::new("TestOrchestrator", "A test orchestrator");
        assert_eq!(orchestrator.name(), "TestOrchestrator");
        assert_eq!(orchestrator.description(), "A test orchestrator");
    }

    /// A succeeding agent's output is returned unchanged.
    #[tokio::test]
    async fn test_execute_agent_with_retry_success() {
        let orchestrator = BaseOrchestrator::new("Test", "Test");
        let agent = SimpleAgent::new("TestAgent", "Test", |input| {
            Ok(AgentOutput::new(format!("Processed: {}", input.content)))
        });
        let input = AgentInput::new("Hello");
        let output = orchestrator
            .execute_agent_with_retry(&agent, input, 3)
            .await;
        assert!(output.is_successful());
        assert_eq!(output.content, "Processed: Hello");
    }

    /// An always-failing agent yields a zero-confidence output that reports
    /// the configured retry count.
    #[tokio::test]
    async fn test_execute_agent_with_retry_failure() {
        let orchestrator = BaseOrchestrator::new("Test", "Test");
        let agent = SimpleAgent::new("FailingAgent", "Always fails", |_input| {
            Err(anyhow::anyhow!("Always fails").into())
        });
        let input = AgentInput::new("Hello");
        let output = orchestrator
            .execute_agent_with_retry(&agent, input, 2)
            .await;
        assert!(!output.is_successful());
        assert!(output.content.contains("failed after 2 retries"));
        assert_eq!(output.confidence, 0.0);
    }

    /// With zero retries a single attempt is made and the failure output is
    /// produced without any backoff sleep.
    #[tokio::test]
    async fn test_execute_agent_with_retry_zero_retries() {
        let orchestrator = BaseOrchestrator::new("Test", "Test");
        let agent = SimpleAgent::new("FailingAgent", "Always fails", |_input| {
            Err(anyhow::anyhow!("Always fails").into())
        });
        let input = AgentInput::new("Hello");
        let output = orchestrator
            .execute_agent_with_retry(&agent, input, 0)
            .await;
        assert!(!output.is_successful());
        assert!(output.content.contains("failed after 0 retries"));
        assert_eq!(output.confidence, 0.0);
    }
}