use super::config::*;
use super::core::{agent_loop, agent_loop_continue};
use super::helpers::derive_config_segment;
use crate::types::*;
use chrono::Utc;
use futures::future::join_all;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Drives one agent loop per `(context, config)` pair and gathers the results.
///
/// All branch futures are built up front and then awaited together with
/// `join_all`, so the returned outcomes are in the same order as the input
/// pairs. Pairs are formed with `zip`: if `contexts` and `configs` differ in
/// length, the surplus entries are ignored.
///
/// Per branch:
/// - `loop_id` is taken from the context. If it is missing, a warning is logged
///   and an empty string is used instead.
/// - a private unbounded channel receives the branch's events. A spawned
///   forwarder task relays every event onto `tx` (send failures ignored) and
///   remembers the `usage` of the most recent `AgentEvent::AgentEnd`.
/// - empty `prompts` selects `agent_loop_continue`; otherwise `agent_loop` is run
///   with a copy of the prompts.
/// - the branch then waits for the forwarder to report the captured usage. This
///   falls back to `Usage::default()` if the forwarder task dies before sending.
///
/// NOTE(review): the forwarder only finishes once every sender of the branch
/// channel is dropped. The sender is moved into the loop call, so this assumes
/// the loop does not keep a clone alive past its return.
async fn run_parallel_branches(
    prompts: Vec<AgentMessage>,
    contexts: Vec<AgentContext>,
    configs: Vec<AgentLoopConfig>,
    tx: &mpsc::UnboundedSender<AgentEvent>,
    cancel: &tokio_util::sync::CancellationToken,
) -> Vec<ParallelLoopOutcome> {
    let mut pending = Vec::new();

    for (config_index, (mut branch_ctx, branch_config)) in
        contexts.into_iter().zip(configs).enumerate()
    {
        // Resolved eagerly (before any branch runs), so missing ids are reported in order.
        let loop_id = match branch_ctx.loop_id.clone() {
            Some(id) => id,
            None => {
                tracing::warn!(
                    "run_parallel_branches: branch context {} missing loop_id; this should \
                     have been set by the dispatcher. Falling back to empty string.",
                    config_index
                );
                String::new()
            }
        };

        // Each branch owns its own handles so the async block can be `move`.
        let branch_prompts = prompts.clone();
        let forward_tx = tx.clone();
        let branch_cancel = cancel.clone();

        pending.push(async move {
            let (branch_tx, branch_rx) = mpsc::unbounded_channel::<AgentEvent>();
            let (usage_tx, usage_rx) = tokio::sync::oneshot::channel::<Usage>();
            // Captured before the loop appends anything to the context.
            let original_context_len = branch_ctx.messages.len();

            // Relay branch events to the shared sender and remember the final usage.
            tokio::spawn(async move {
                let mut events = branch_rx;
                let mut final_usage = Usage::default();
                while let Some(event) = events.recv().await {
                    if let AgentEvent::AgentEnd { usage, .. } = &event {
                        final_usage = usage.clone();
                    }
                    let _ = forward_tx.send(event);
                }
                let _ = usage_tx.send(final_usage);
            });

            let new_messages = if !branch_prompts.is_empty() {
                agent_loop(
                    branch_prompts,
                    &mut branch_ctx,
                    &branch_config,
                    branch_tx,
                    branch_cancel,
                )
                .await
            } else {
                agent_loop_continue(&mut branch_ctx, &branch_config, branch_tx, branch_cancel)
                    .await
            };

            let usage = usage_rx.await.unwrap_or_default();

            ParallelLoopOutcome {
                config_index,
                loop_id,
                context: branch_ctx,
                new_messages,
                usage,
                original_context_len,
            }
        });
    }

    join_all(pending).await
}
pub async fn agent_loop_parallel(
prompts: Vec<AgentMessage>,
mut base_context: AgentContext,
configs: Vec<AgentLoopConfig>,
strategy: Arc<dyn EvaluationStrategy>,
tx: mpsc::UnboundedSender<AgentEvent>,
cancel: tokio_util::sync::CancellationToken,
) -> ParallelLoopResult {
assert!(
!configs.is_empty(),
"agent_loop_parallel requires at least one config"
);
if prompts.is_empty() {
assert!(
!base_context.messages.is_empty(),
"agent_loop_parallel with empty prompts requires non-empty base_context.messages \
(agent_loop_continue mode)"
);
assert!(
base_context.messages.last().map(|m| m.role()) != Some("assistant"),
"agent_loop_parallel with empty prompts requires context NOT ending on an \
assistant message (agent_loop_continue mode)"
);
}
base_context
.agent_id
.get_or_insert_with(|| uuid::Uuid::new_v4().to_string());
let session_id = base_context
.session_id
.get_or_insert_with(|| uuid::Uuid::new_v4().to_string())
.clone();
let loop_ids: Vec<String> = configs
.iter()
.enumerate()
.map(|(i, cfg)| format!("{}.{}.{}", session_id, derive_config_segment(cfg), i + 1))
.collect();
tx.send(AgentEvent::ParallelLoopStart {
session_id: session_id.clone(),
loop_ids: loop_ids.clone(),
timestamp: Utc::now(),
})
.ok();
let branch_contexts: Vec<AgentContext> = loop_ids
.iter()
.map(|lid| {
let mut ctx = base_context.clone();
ctx.loop_id = Some(lid.clone());
ctx
})
.collect();
let outcomes =
run_parallel_branches(prompts.clone(), branch_contexts, configs, &tx, &cancel).await;
let (decision, eval_usage) = strategy.evaluate(&prompts, &outcomes, &tx, cancel).await;
let selected_index = match decision {
EvaluationDecision::Select(i) => i.min(outcomes.len() - 1),
};
tx.send(AgentEvent::ParallelLoopEnd {
session_id,
selected_loop_id: outcomes[selected_index].loop_id.clone(),
selected_config_index: selected_index,
evaluation_usage: eval_usage.clone(),
timestamp: Utc::now(),
})
.ok();
let total_usage = outcomes
.iter()
.fold(Usage::default(), |acc, o| acc.combine(&o.usage))
.combine(&eval_usage);
let mut all_outcomes = outcomes;
let selected = all_outcomes.remove(selected_index);
ParallelLoopResult {
selected_context: selected.context,
selected_messages: selected.new_messages,
selected_index,
all_outcomes,
total_usage,
}
}