use std::fmt::Write as _;
use zeph_llm::provider::{LlmProvider, Message, Role};
use super::error::OrchestrationError;
use super::graph::{TaskGraph, TaskStatus};
use crate::config::OrchestrationConfig;
use crate::sanitizer::{
ContentIsolationConfig, ContentSanitizer, ContentSource, ContentSourceKind,
};
// NOTE(review): this silences rustc's `async_fn_in_trait` lint, which warns that
// callers of a public trait with `async fn` cannot add bounds (such as `Send`)
// to the returned future. Presumably accepted because every implementor and
// caller lives in this workspace — confirm before exposing the trait further.
/// Strategy for turning the task results held in a [`TaskGraph`] into one
/// final response string.
///
/// Implementors must be `Send + Sync`.
#[allow(async_fn_in_trait)]
pub trait Aggregator: Send + Sync {
/// Produces a single response string from the tasks recorded in `graph`.
///
/// # Errors
///
/// Returns an [`OrchestrationError`] when the implementation cannot produce
/// a response for the given graph.
async fn aggregate(&self, graph: &TaskGraph) -> Result<String, OrchestrationError>;
}
/// [`Aggregator`] that asks an LLM provider to synthesize the outputs of a
/// task graph into one response, and falls back to a plain concatenation of
/// the task outputs when the provider call fails.
pub struct LlmAggregator<P: LlmProvider> {
// Provider whose `chat` method receives the synthesis prompt.
provider: P,
// Total number of characters of task output allowed in one prompt; it is
// divided evenly between the completed tasks when the prompt is built.
aggregation_char_budget: usize,
// Applied to every (already truncated) task output before it is embedded in
// the prompt or in the fallback text.
sanitizer: ContentSanitizer,
}
impl<P: LlmProvider> LlmAggregator<P> {
    /// Creates an aggregator that sends synthesis prompts to `provider`.
    ///
    /// The character budget for task outputs is derived from
    /// `config.aggregator_max_tokens` using a fixed characters-per-token
    /// ratio. The multiplication saturates at `usize::MAX`, so an oversized
    /// token limit yields the largest possible budget instead of panicking
    /// (debug builds) or wrapping around to a tiny one (release builds).
    ///
    /// Task outputs are sanitized with a [`ContentSanitizer`] built from the
    /// default [`ContentIsolationConfig`].
    #[must_use]
    pub fn new(provider: P, config: &OrchestrationConfig) -> Self {
        // Ratio used to convert the configured token limit into characters.
        const CHARS_PER_TOKEN: usize = 4;

        let aggregation_char_budget =
            (config.aggregator_max_tokens as usize).saturating_mul(CHARS_PER_TOKEN);

        Self {
            provider,
            aggregation_char_budget,
            sanitizer: ContentSanitizer::new(&ContentIsolationConfig::default()),
        }
    }
}
impl<P: LlmProvider + Send + Sync> Aggregator for LlmAggregator<P> {
    /// Asks the provider to synthesize the graph's task outputs into one
    /// response.
    ///
    /// Only tasks that are `Completed` *and* carry a result contribute output;
    /// `Skipped` tasks are listed by title and description. Each output is
    /// truncated to an equal share of the character budget and then sanitized
    /// before being placed in the prompt.
    ///
    /// If the provider call fails, the failure is logged and the result of
    /// [`build_fallback`] is returned instead, so a provider failure never
    /// surfaces as an error.
    ///
    /// # Errors
    ///
    /// Returns [`OrchestrationError::AggregationFailed`] when the graph has
    /// neither a completed task with a result nor a skipped task.
    async fn aggregate(&self, graph: &TaskGraph) -> Result<String, OrchestrationError> {
        // Pair each completed task with its result; completed tasks without a
        // result are dropped here.
        let finished: Vec<_> = graph
            .tasks
            .iter()
            .filter(|node| node.status == TaskStatus::Completed)
            .filter_map(|node| node.result.as_ref().map(|res| (node, res)))
            .collect();
        let skipped: Vec<_> = graph
            .tasks
            .iter()
            .filter(|node| node.status == TaskStatus::Skipped)
            .collect();

        if finished.is_empty() && skipped.is_empty() {
            return Err(OrchestrationError::AggregationFailed(
                "no completed or skipped tasks to aggregate".into(),
            ));
        }

        // Equal share of the budget per completed task; `max(1)` avoids a
        // division by zero when only skipped tasks exist.
        let per_task = self.aggregation_char_budget / finished.len().max(1);

        let mut task_sections = String::new();
        for (node, res) in &finished {
            // Truncate first, then sanitize the truncated text.
            let clipped = truncate_chars(res.output.as_str(), per_task);
            let clean = self
                .sanitizer
                .sanitize(clipped, ContentSource::new(ContentSourceKind::ToolResult));
            let _ = write!(task_sections, "### Task: {}\n{}\n\n", node.title, clean.body);
        }

        if !skipped.is_empty() {
            task_sections.push_str("### Skipped tasks (no output available):\n");
            for node in &skipped {
                let _ = writeln!(task_sections, "- {} ({})", node.title, node.description);
            }
            task_sections.push('\n');
        }

        let system = "You are a result synthesizer. Given the outputs from a set of completed \
                      sub-tasks, produce a single coherent summary that directly addresses \
                      the original goal. Be concise. If tasks were skipped, acknowledge them briefly.";
        let user = format!(
            "Goal: {}\n\n\
             Task results:\n\n\
             {}\
             Synthesize the above into a single coherent response for the user.",
            graph.goal, task_sections,
        );
        let messages = vec![
            Message::from_legacy(Role::System, system),
            Message::from_legacy(Role::User, user),
        ];

        // A provider failure is downgraded to the raw-concatenation fallback.
        self.provider.chat(&messages).await.or_else(|err| {
            tracing::error!(
                graph_id = %graph.id,
                error = %err,
                "aggregation LLM call failed; falling back to raw concatenation"
            );
            Ok(build_fallback(graph, &self.sanitizer, per_task))
        })
    }
}
/// Returns the longest prefix of `s` that contains at most `max_chars`
/// Unicode scalar values.
///
/// The cut always lands on a character boundary, so multi-byte text is never
/// split mid-character. A budget of zero yields the empty string, and input
/// shorter than the budget is returned unchanged.
fn truncate_chars(s: &str, max_chars: usize) -> &str {
    // Byte offset where character number `max_chars` starts; when the string
    // has fewer characters than that, the whole string fits.
    let cut = s
        .char_indices()
        .map(|(offset, _)| offset)
        .nth(max_chars)
        .unwrap_or(s.len());
    &s[..cut]
}
/// Builds the plain-text response used when the LLM synthesis call fails.
///
/// The text starts with the graph's goal, followed by one `##` section per
/// task in graph order:
/// - completed tasks that have a result contribute their output, truncated to
///   `per_task_chars` characters and then passed through `sanitizer`;
/// - skipped tasks contribute a heading carrying their description;
/// - every other task is omitted.
fn build_fallback(
    graph: &TaskGraph,
    sanitizer: &ContentSanitizer,
    per_task_chars: usize,
) -> String {
    let mut report = format!("Goal: {}\n\n", graph.goal);

    for node in &graph.tasks {
        if node.status == TaskStatus::Skipped {
            let _ = write!(
                report,
                "## {} (skipped — {})\n\n",
                node.title, node.description
            );
            continue;
        }
        if node.status != TaskStatus::Completed {
            continue;
        }
        // A completed task without a result has nothing to show.
        let Some(result) = node.result.as_ref() else {
            continue;
        };

        let clipped = truncate_chars(&result.output, per_task_chars);
        let clean = sanitizer.sanitize(clipped, ContentSource::new(ContentSourceKind::ToolResult));
        let _ = write!(report, "## {}\n{}\n\n", node.title, clean.body);
    }

    report
}
// Unit tests for the truncation helper, the raw-concatenation fallback, and
// (in `mock_tests`) the LLM-backed aggregator driven by `MockProvider`.
#[cfg(test)]
mod tests {
use super::*;
use crate::orchestration::graph::{GraphStatus, TaskGraph, TaskNode, TaskResult, TaskStatus};
/// Builds a graph with goal "test goal" holding one node per `statuses` entry.
///
/// Node `i` is created with id `i` and the strings `task-{i}` and "desc"
/// (the tests below rely on `task-{i}` being the title; "desc" is presumably
/// the description — confirm against `TaskNode::new`). When an output string
/// is supplied the node gets a `TaskResult` carrying it. The graph itself is
/// marked `GraphStatus::Completed`.
fn make_graph_with_tasks(statuses: &[(TaskStatus, Option<&str>)]) -> TaskGraph {
let mut graph = TaskGraph::new("test goal");
for (i, (status, output)) in statuses.iter().enumerate() {
let mut node = TaskNode::new(u32::try_from(i).unwrap(), format!("task-{i}"), "desc");
node.status = *status;
node.result = output.map(|o| TaskResult {
output: o.to_string(),
artifacts: vec![],
duration_ms: 100,
agent_id: None,
agent_def: None,
});
graph.tasks.push(node);
}
graph.status = GraphStatus::Completed;
graph
}
/// Empty input stays empty regardless of the budget.
#[test]
fn test_truncate_empty() {
assert_eq!(truncate_chars("", 10), "");
}
/// A zero budget always yields the empty string.
#[test]
fn test_truncate_zero_budget() {
assert_eq!(truncate_chars("hello", 0), "");
}
/// Input shorter than the budget is returned unchanged.
#[test]
fn test_truncate_within_budget() {
assert_eq!(truncate_chars("hello", 10), "hello");
}
/// Input longer than the budget is cut to exactly `max_chars` characters.
#[test]
fn test_truncate_at_limit() {
assert_eq!(truncate_chars("hello world", 5), "hello");
}
/// The budget counts characters, not bytes: six Cyrillic letters (two bytes
/// each in UTF-8) survive a budget of 6.
#[test]
fn test_truncate_multibyte() {
let s = "привет мир";
let truncated = truncate_chars(s, 6);
assert_eq!(truncated, "привет");
}
/// Sanitizer configured the same way `LlmAggregator::new` configures its own.
fn make_sanitizer() -> ContentSanitizer {
ContentSanitizer::new(&ContentIsolationConfig::default())
}
/// The fallback contains both the output and the title of a completed task.
#[test]
fn test_build_fallback_includes_completed() {
let graph = make_graph_with_tasks(&[(TaskStatus::Completed, Some("output-a"))]);
let out = build_fallback(&graph, &make_sanitizer(), 4096);
assert!(
out.contains("output-a"),
"fallback should include task output"
);
assert!(out.contains("task-0"));
}
/// A skipped task is marked with the word "skipped" in the fallback.
#[test]
fn test_build_fallback_includes_skipped_marker() {
let graph = make_graph_with_tasks(&[(TaskStatus::Skipped, None)]);
let out = build_fallback(&graph, &make_sanitizer(), 4096);
assert!(
out.contains("skipped"),
"fallback should mark skipped tasks"
);
}
/// The fallback echoes the graph's goal.
#[test]
fn test_build_fallback_goal_included() {
let graph = make_graph_with_tasks(&[(TaskStatus::Completed, Some("x"))]);
let out = build_fallback(&graph, &make_sanitizer(), 4096);
assert!(out.contains("test goal"));
}
// Tests that exercise `LlmAggregator::aggregate` end to end against a mock
// provider.
mod mock_tests {
use super::*;
use zeph_llm::mock::MockProvider;
/// Default orchestration config with a 1024-token aggregator limit.
fn make_config() -> OrchestrationConfig {
OrchestrationConfig {
aggregator_max_tokens: 1024,
..OrchestrationConfig::default()
}
}
/// On success the provider's response is returned verbatim.
#[tokio::test]
async fn test_aggregate_calls_llm_and_returns_synthesis() {
let provider = MockProvider::with_responses(vec!["synthesized result".to_string()]);
let agg = LlmAggregator::new(provider, &make_config());
let graph = make_graph_with_tasks(&[(TaskStatus::Completed, Some("task output"))]);
let result = agg.aggregate(&graph).await.unwrap();
assert_eq!(result, "synthesized result");
}
/// With a failing provider (presumably `MockProvider::failing()` makes the
/// chat call return an error — confirm), `aggregate` still succeeds and the
/// returned text contains the raw task output.
#[tokio::test]
async fn test_aggregate_fallback_on_llm_failure() {
let provider = MockProvider::failing();
let agg = LlmAggregator::new(provider, &make_config());
let graph = make_graph_with_tasks(&[(TaskStatus::Completed, Some("raw output"))]);
let result = agg.aggregate(&graph).await.unwrap();
assert!(
result.contains("raw output"),
"fallback should have raw output"
);
}
/// A graph without any tasks is rejected with `AggregationFailed`.
#[tokio::test]
async fn test_aggregate_error_when_no_tasks() {
let provider = MockProvider::default();
let agg = LlmAggregator::new(provider, &make_config());
let graph = TaskGraph::new("empty goal");
let err = agg.aggregate(&graph).await.unwrap_err();
assert!(matches!(err, OrchestrationError::AggregationFailed(_)));
}
/// Skipped tasks show up with their description.
///
/// NOTE(review): despite the name, this uses a failing provider, so the
/// assertions inspect the fallback text rather than the prompt sent to the
/// LLM; the prompt itself is not observed by this test.
#[tokio::test]
async fn test_aggregate_includes_skipped_in_prompt() {
let provider = MockProvider::failing();
let agg = LlmAggregator::new(provider, &make_config());
let mut graph = make_graph_with_tasks(&[
(TaskStatus::Completed, Some("ok")),
(TaskStatus::Skipped, None),
]);
graph.tasks[1].description = "unique-skipped-description".to_string();
let result = agg.aggregate(&graph).await.unwrap();
assert!(
result.contains("task-1") || result.contains("skipped"),
"fallback must include skipped task info; got: {result}"
);
assert!(
result.contains("unique-skipped-description"),
"fallback must include skipped task description; got: {result}"
);
}
/// A 1-token limit becomes a 4-character budget; with a single completed
/// task the whole budget goes to that task, so its 1000-character output
/// must be cut to at most four characters in the fallback text.
#[tokio::test]
async fn test_aggregate_per_task_budget_truncation() {
let config = OrchestrationConfig {
aggregator_max_tokens: 1, ..OrchestrationConfig::default()
};
let provider = MockProvider::failing();
let agg = LlmAggregator::new(provider, &config);
let long_output = "a".repeat(1000);
let graph = make_graph_with_tasks(&[(TaskStatus::Completed, Some(&long_output))]);
let result = agg.aggregate(&graph).await.unwrap();
assert!(
!result.contains(&"a".repeat(5)),
"with 4-char budget, no sequence of >=5 'a' chars should appear; \
result len={}, result={result:?}",
result.len()
);
}
}
}