use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::collections::HashMap;
use swarm_engine_core::actions::ActionDef;
use swarm_engine_core::agent::{BatchDecisionRequest, DecisionResponse, WorkerDecisionRequest};
use swarm_engine_core::exploration::{DependencyGraph, SelectResult};
use swarm_engine_core::types::{LoraConfig, WorkerId};
use crate::decider::{LlmDecider, LlmError};
/// One `(worker, decision-or-error)` pair per request in the processed batch.
pub type BatchProcessResult = Vec<(WorkerId, Result<DecisionResponse, BatchProcessError>)>;
/// Error produced while processing a single batch decision request.
///
/// Split into `Transient` (worth retrying) and `Permanent` (not worth
/// retrying) so callers can implement retry policies.
#[derive(Debug, Clone, thiserror::Error)]
pub enum BatchProcessError {
/// Retryable failure (e.g. timeout or temporary backend overload).
#[error("Batch process error (transient): {0}")]
Transient(String),
/// Non-retryable failure (e.g. malformed request or unknown model).
#[error("Batch process error: {0}")]
Permanent(String),
}
impl BatchProcessError {
pub fn transient(message: impl Into<String>) -> Self {
Self::Transient(message.into())
}
pub fn permanent(message: impl Into<String>) -> Self {
Self::Permanent(message.into())
}
pub fn is_transient(&self) -> bool {
matches!(self, Self::Transient(_))
}
pub fn message(&self) -> &str {
match self {
Self::Transient(msg) => msg,
Self::Permanent(msg) => msg,
}
}
}
impl From<LlmError> for BatchProcessError {
    /// Maps a decider-level error onto the batch error type, preserving its
    /// transient/permanent classification.
    fn from(e: LlmError) -> Self {
        // Query classification before consuming the message, matching the
        // original evaluation order.
        let retryable = e.is_transient();
        let text = e.message().to_string();
        match retryable {
            true => Self::Transient(text),
            false => Self::Permanent(text),
        }
    }
}
impl From<swarm_engine_core::error::SwarmError> for BatchProcessError {
    /// Maps an engine-wide error onto the batch error type, preserving its
    /// transient/permanent classification.
    fn from(err: swarm_engine_core::error::SwarmError) -> Self {
        // Pick the variant constructor first, then apply it to the message.
        let ctor = if err.is_transient() {
            Self::Transient as fn(String) -> Self
        } else {
            Self::Permanent as fn(String) -> Self
        };
        ctor(err.message())
    }
}
impl From<BatchProcessError> for swarm_engine_core::error::SwarmError {
fn from(err: BatchProcessError) -> Self {
match err {
BatchProcessError::Transient(message) => {
swarm_engine_core::error::SwarmError::LlmTransient { message }
}
BatchProcessError::Permanent(message) => {
swarm_engine_core::error::SwarmError::LlmPermanent { message }
}
}
}
}
/// Processes batches of worker decision requests against some backend
/// (typically an LLM decider).
pub trait BatchProcessor: Send + Sync {
/// Resolves every request in the batch, returning one result per worker.
fn process(
&self,
request: BatchDecisionRequest,
) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>>;
/// Optionally infers an action dependency graph for the given task.
/// The default implementation opts out by resolving to `None`.
fn plan_dependencies(
&self,
_task: &str,
_actions: &[ActionDef],
_hint: Option<&SelectResult>,
) -> Pin<Box<dyn Future<Output = Option<DependencyGraph>> + Send + '_>> {
Box::pin(async { None })
}
/// Reports whether the backing service is currently usable.
fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;
/// Short identifier for this processor (implementations here return the model name).
fn name(&self) -> &str;
}
/// Tuning knobs for [`LlmBatchProcessor`].
#[derive(Debug, Clone)]
pub struct LlmBatchProcessorConfig {
// Process requests concurrently (grouped by LoRA) instead of one at a time.
pub parallel: bool,
// Upper bound on in-flight decider calls when `parallel` is true; the
// decider's own `max_concurrency()` takes precedence when available.
pub max_concurrency: usize,
// NOTE(review): declared but never read anywhere in this file — confirm
// whether retry handling lives in the decider or this is dead config.
pub max_retries: Option<usize>,
}
impl Default for LlmBatchProcessorConfig {
    /// Defaults: parallel processing, at most 4 concurrent calls, 5 retries.
    fn default() -> Self {
        let parallel = true;
        let max_concurrency = 4;
        let max_retries = Some(5);
        Self {
            parallel,
            max_concurrency,
            max_retries,
        }
    }
}
/// [`BatchProcessor`] implementation that delegates each decision to an
/// [`LlmDecider`], optionally fanning requests out concurrently.
pub struct LlmBatchProcessor<D: LlmDecider> {
// Shared so each spawned per-request future can hold its own handle.
decider: Arc<D>,
// Concurrency tuning; see `LlmBatchProcessorConfig`.
config: LlmBatchProcessorConfig,
}
impl<D: LlmDecider> LlmBatchProcessor<D> {
    /// Wraps an owned decider with the default configuration.
    pub fn new(decider: D) -> Self {
        Self::from_arc(Arc::new(decider))
    }

    /// Wraps an already-shared decider with the default configuration.
    pub fn from_arc(decider: Arc<D>) -> Self {
        Self {
            decider,
            config: LlmBatchProcessorConfig::default(),
        }
    }

    /// Builder-style override of the processor configuration.
    pub fn with_config(self, config: LlmBatchProcessorConfig) -> Self {
        Self { config, ..self }
    }
}
impl<D: LlmDecider + 'static> BatchProcessor for LlmBatchProcessor<D> {
/// Runs every request in the batch through the decider, either grouped by
/// LoRA and processed concurrently, or strictly sequentially, depending on
/// `config.parallel`.
fn process(
&self,
request: BatchDecisionRequest,
) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>> {
Box::pin(async move {
if request.requests.is_empty() {
return vec![];
}
// Pair each request with its worker id so results can be attributed.
let requests: Vec<(WorkerId, WorkerDecisionRequest)> = request
.requests
.into_iter()
.map(|r| (r.worker_id, r))
.collect();
if self.config.parallel {
self.process_parallel(requests).await
} else {
self.process_sequential(requests).await
}
})
}
/// Infers a dependency graph over `actions` for `task`:
/// 1. split actions into NodeExpand ("discover") vs NodeStateChange,
/// 2. order each half via LLM pairwise comparison (`binary_sort_actions`),
/// 3. chain them into a linear graph (discover first, then state changes),
/// 4. emit a learning event recording the inference.
///
/// `hint` may supply a LoRA adapter and vote count from exploration;
/// otherwise defaults to no LoRA and 3 votes per comparison.
fn plan_dependencies(
&self,
task: &str,
actions: &[ActionDef],
hint: Option<&SelectResult>,
) -> Pin<Box<dyn Future<Output = Option<DependencyGraph>> + Send + '_>> {
// Clone owned copies up front so the returned future is 'static-safe
// with respect to the borrowed arguments.
let task = task.to_string();
let actions: Vec<ActionDef> = actions.to_vec();
let decider = Arc::clone(&self.decider);
let (lora, vote_count) = match hint {
Some(SelectResult::UseLlm {
lora,
vote_count,
match_rate,
..
}) => {
tracing::debug!(
match_rate = match_rate,
vote_count = vote_count,
has_lora = lora.is_some(),
"Using SelectResult hint for plan_dependencies"
);
(lora.clone(), *vote_count)
}
_ => {
tracing::debug!("No SelectResult hint, using defaults (lora=None, vote_count=3)");
(None, 3)
}
};
Box::pin(async move {
use std::time::Instant;
use swarm_engine_core::actions::ActionCategory;
use swarm_engine_core::exploration::DependencyGraphBuilder;
let start_time = Instant::now();
let action_names: Vec<String> = actions.iter().map(|a| a.name.clone()).collect();
// Partition by category; actions of other categories (if any) are
// excluded from ordering but remain in `available_actions`.
let discover: Vec<&ActionDef> = actions
.iter()
.filter(|a| a.category == ActionCategory::NodeExpand)
.collect();
let not_discover: Vec<&ActionDef> = actions
.iter()
.filter(|a| a.category == ActionCategory::NodeStateChange)
.collect();
tracing::debug!(
discover = ?discover.iter().map(|a| &a.name).collect::<Vec<_>>(),
not_discover = ?not_discover.iter().map(|a| &a.name).collect::<Vec<_>>(),
"Separated actions by category"
);
// Singleton/empty halves need no LLM calls.
let discover_sort_start = Instant::now();
let sorted_discover = if discover.len() <= 1 {
discover.iter().map(|a| a.name.clone()).collect()
} else {
binary_sort_actions(&task, &discover, decider.as_ref(), lora.as_ref(), vote_count).await
};
let discover_sort_ms = discover_sort_start.elapsed().as_millis();
tracing::debug!(
sorted = ?sorted_discover,
elapsed_ms = discover_sort_ms,
vote_count = vote_count,
has_lora = lora.is_some(),
"Sorted Discover actions via binary comparison"
);
let not_discover_sort_start = Instant::now();
let sorted_not_discover = if not_discover.len() <= 1 {
not_discover.iter().map(|a| a.name.clone()).collect()
} else {
binary_sort_actions(&task, &not_discover, decider.as_ref(), lora.as_ref(), vote_count).await
};
let not_discover_sort_ms = not_discover_sort_start.elapsed().as_millis();
tracing::debug!(
sorted = ?sorted_not_discover,
elapsed_ms = not_discover_sort_ms,
"Sorted NotDiscover actions via binary comparison"
);
// Build a linear chain: discover[0] -> ... -> discover[last]
//   -> not_discover[0] -> ... -> not_discover[last] (terminal).
let mut builder = DependencyGraphBuilder::new()
.task(&task)
.available_actions(action_names.clone());
if !sorted_discover.is_empty() {
builder = builder.start_node(&sorted_discover[0]);
} else if !sorted_not_discover.is_empty() {
builder = builder.start_node(&sorted_not_discover[0]);
}
if let Some(last) = sorted_not_discover.last() {
builder = builder.terminal_node(last);
} else if !sorted_discover.is_empty() {
// unwrap is safe: the branch guard proves non-emptiness.
builder = builder.terminal_node(sorted_discover.last().unwrap());
}
for window in sorted_discover.windows(2) {
builder = builder.edge(&window[0], &window[1], 0.9);
}
// Bridge edge from the last discover action to the first state change.
if !sorted_discover.is_empty() && !sorted_not_discover.is_empty() {
builder = builder.edge(
sorted_discover.last().unwrap(),
&sorted_not_discover[0],
0.9,
);
}
for window in sorted_not_discover.windows(2) {
builder = builder.edge(&window[0], &window[1], 0.9);
}
let mut graph = builder.build();
let total_ms = start_time.elapsed().as_millis();
graph.set_action_order(sorted_discover.clone(), sorted_not_discover.clone());
// Record this inference as a learning event so it can be replayed or
// used as training data, and attach the record to the graph.
{
use swarm_engine_core::events::{LearningEvent, LearningEventChannel};
use swarm_engine_core::learn::DependencyGraphRecord;
let prompt = format!(
"Task: {}\n\nAvailable Actions:\n{}",
task,
action_names
.iter()
.map(|n| format!("- {}", n))
.collect::<Vec<_>>()
.join("\n")
);
let response = format!(
"discover_order: {:?}\nnot_discover_order: {:?}",
sorted_discover, sorted_not_discover
);
let event = LearningEvent::dependency_graph_inference(decider.model_name())
.prompt(&prompt)
.response(&response)
.available_actions(action_names)
.discover_order(sorted_discover.clone())
.not_discover_order(sorted_not_discover.clone())
.endpoint(decider.endpoint())
.latency_ms(total_ms as u64)
.success()
.build();
LearningEventChannel::global().emit(event.clone());
let record = DependencyGraphRecord::from(&event);
graph.set_learn_record(record);
}
tracing::info!(
discover_order = ?sorted_discover,
not_discover_order = ?sorted_not_discover,
edges = graph.edges().len(),
discover_sort_ms = discover_sort_ms,
not_discover_sort_ms = not_discover_sort_ms,
total_ms = total_ms,
"DependencyGraph generated via LLM binary sort"
);
Some(graph)
})
}
/// Delegates the health probe to the underlying decider.
fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
let decider = Arc::clone(&self.decider);
Box::pin(async move { decider.is_healthy().await })
}
/// Reports the decider's model name as this processor's identifier.
fn name(&self) -> &str {
self.decider.model_name()
}
}
impl<D: LlmDecider + 'static> LlmBatchProcessor<D> {
/// Splits requests into LoRA groups and processes each group in turn.
/// Groups run one after another; requests *within* a group run
/// concurrently (see `process_group`).
async fn process_parallel(
&self,
requests: Vec<(WorkerId, WorkerDecisionRequest)>,
) -> BatchProcessResult {
let grouped = group_by_lora(requests);
let group_count = grouped.len();
if group_count > 1 {
tracing::debug!(
groups = group_count,
"Processing requests in {} LoRA groups",
group_count
);
}
let mut all_results = Vec::new();
for (lora_config, group_requests) in grouped {
if group_count > 1 {
tracing::trace!(
lora = ?lora_config,
count = group_requests.len(),
"Processing LoRA group"
);
}
let results = self.process_group(group_requests).await;
all_results.extend(results);
}
all_results
}
/// Processes one LoRA group concurrently, bounded by a semaphore sized
/// from the decider's `max_concurrency()` (falling back to the config
/// value when the decider does not report one).
/// NOTE(review): `config.max_retries` is not consulted here — confirm
/// whether retry handling is the decider's responsibility.
async fn process_group(
&self,
requests: Vec<(WorkerId, WorkerDecisionRequest)>,
) -> BatchProcessResult {
use futures::future::join_all;
use tokio::sync::Semaphore;
let max_concurrency = self
.decider
.max_concurrency()
.await
.unwrap_or(self.config.max_concurrency);
let semaphore = Arc::new(Semaphore::new(max_concurrency));
let futures: Vec<_> = requests
.into_iter()
.map(|(worker_id, req)| {
let decider = Arc::clone(&self.decider);
let sem = Arc::clone(&semaphore);
async move {
// Permit is held for the duration of the decider call and
// released on drop. The semaphore is never closed here, so
// acquire() failing would be a bug.
let _permit = sem.acquire().await.expect("Semaphore closed");
let result = decider.decide(req).await;
(worker_id, result)
}
})
.collect();
// join_all preserves input order, so results line up with requests.
let results = join_all(futures).await;
results
.into_iter()
.map(|(worker_id, result)| {
let mapped = result.map_err(BatchProcessError::from);
(worker_id, mapped)
})
.collect()
}
/// Processes requests one at a time, in the order given.
async fn process_sequential(
&self,
requests: Vec<(WorkerId, WorkerDecisionRequest)>,
) -> BatchProcessResult {
let mut results = Vec::with_capacity(requests.len());
for (worker_id, req) in requests {
let result = self.decider.decide(req).await;
let mapped = result.map_err(BatchProcessError::from);
results.push((worker_id, mapped));
}
results
}
}
/// Buckets requests by their (optional) LoRA configuration, preserving the
/// original request order inside each bucket.
fn group_by_lora(
    requests: Vec<(WorkerId, WorkerDecisionRequest)>,
) -> HashMap<Option<LoraConfig>, Vec<(WorkerId, WorkerDecisionRequest)>> {
    requests
        .into_iter()
        .fold(HashMap::new(), |mut buckets, (worker_id, req)| {
            let key = req.lora.clone();
            buckets.entry(key).or_default().push((worker_id, req));
            buckets
        })
}
/// Orders `actions` for `task` via pairwise LLM comparison.
///
/// Every unordered pair of actions is compared `vote_count` times by asking
/// the decider which should come first; the majority answer per pair records
/// one "loss" against the later action (ties within a pair favor the
/// first-listed action). Actions are returned sorted by ascending loss count
/// (fewest losses = earliest).
///
/// Fix: ties in loss count are now broken by the action's position in the
/// input slice. Previously the final sort keyed on loss count alone over a
/// `HashMap`'s unspecified iteration order, so tied actions could come out in
/// a different order on every run.
async fn binary_sort_actions<D: LlmDecider>(
    task: &str,
    actions: &[&ActionDef],
    decider: &D,
    lora: Option<&LoraConfig>,
    vote_count: u8,
) -> Vec<String> {
    use futures::future::join_all;
    use std::collections::HashMap;
    if actions.len() <= 1 {
        return actions.iter().map(|a| a.name.clone()).collect();
    }
    // One prompt per ordered pair, replicated `vote_count` times for voting.
    // Tuple layout: (pair index, vote index, prompt, name of A, name of B).
    let mut requests: Vec<(usize, usize, String, String, String)> = Vec::new();
    let mut pair_index = 0;
    for i in 0..actions.len() {
        for j in (i + 1)..actions.len() {
            let a = actions[i];
            let b = actions[j];
            let prompt = format!(
                "Goal: {}\n- {}: {}\n- {}: {}\nWhich comes first: {} or {}?\nAnswer (one word):",
                task, a.name, a.description, b.name, b.description, a.name, b.name
            );
            for vote_idx in 0..vote_count as usize {
                requests.push((
                    pair_index,
                    vote_idx,
                    prompt.clone(),
                    a.name.clone(),
                    b.name.clone(),
                ));
            }
            pair_index += 1;
        }
    }
    let total_requests = requests.len();
    tracing::debug!(
        pairs = pair_index,
        total_requests = total_requests,
        "Binary sort: sending batch requests"
    );
    // Fire all comparison calls concurrently; join_all preserves order but
    // the pair index in each tuple makes the tally order-independent anyway.
    let futures: Vec<_> = requests
        .into_iter()
        .map(|(pair_idx, vote_idx, prompt, a_name, b_name)| {
            let decider_ref = decider;
            async move {
                let result = decider_ref.call_raw(&prompt, lora).await;
                (pair_idx, vote_idx, result, a_name, b_name)
            }
        })
        .collect();
    let results = join_all(futures).await;
    // Tally votes per pair: (votes for A, votes for B, name of A, name of B).
    // Responses mentioning neither name (or failed calls) are ignored.
    let mut pair_votes: HashMap<usize, (usize, usize, String, String)> = HashMap::new();
    for (pair_idx, _vote_idx, result, a_name, b_name) in results {
        let entry = pair_votes
            .entry(pair_idx)
            .or_insert_with(|| (0, 0, a_name.clone(), b_name.clone()));
        if let Ok(response) = result {
            // Case-insensitive containment check; A is checked first, so a
            // response naming both actions counts as a vote for A.
            let response_upper = response.to_uppercase();
            let a_upper = a_name.to_uppercase();
            let b_upper = b_name.to_uppercase();
            if response_upper.contains(&a_upper) {
                entry.0 += 1;
            } else if response_upper.contains(&b_upper) {
                entry.1 += 1;
            }
        }
    }
    // Convert per-pair majorities into per-action loss counts: the pair's
    // loser takes one loss (ties go to A, i.e. B is the loser).
    let mut wins: HashMap<String, usize> = HashMap::new();
    for a in actions {
        wins.insert(a.name.clone(), 0);
    }
    for (_pair_idx, (a_count, b_count, a_name, b_name)) in pair_votes {
        if a_count >= b_count {
            *wins.get_mut(&b_name).unwrap() += 1;
        } else {
            *wins.get_mut(&a_name).unwrap() += 1;
        }
    }
    // Deterministic ranking: ascending loss count, then original input
    // position as tie-break (HashMap iteration order is unspecified).
    let position: HashMap<&str, usize> = actions
        .iter()
        .enumerate()
        .map(|(idx, a)| (a.name.as_str(), idx))
        .collect();
    let mut sorted: Vec<_> = wins.into_iter().collect();
    sorted.sort_by_key(|(name, count)| {
        (*count, position.get(name.as_str()).copied().unwrap_or(usize::MAX))
    });
    tracing::debug!(
        sorted = ?sorted.iter().map(|(n, c)| format!("{}:{}", n, c)).collect::<Vec<_>>(),
        "Binary sort completed"
    );
    sorted.into_iter().map(|(name, _)| name).collect()
}
#[cfg(test)]
mod tests {
// Unit tests for: error construction/mapping, the pure binary-sort ranking
// logic (via a synchronous analogue), winner extraction and majority voting,
// and LoRA-based request grouping.
use super::*;
#[test]
fn test_batch_process_error_transient() {
let err = BatchProcessError::transient("connection timeout");
assert!(err.is_transient());
assert_eq!(err.message(), "connection timeout");
}
#[test]
fn test_batch_process_error_permanent() {
let err = BatchProcessError::permanent("invalid model");
assert!(!err.is_transient());
assert_eq!(err.message(), "invalid model");
}
#[test]
fn test_batch_process_error_from_llm_error() {
let llm_err = LlmError::transient("timeout");
let batch_err: BatchProcessError = llm_err.into();
assert!(batch_err.is_transient());
assert_eq!(batch_err.message(), "timeout");
}
#[test]
fn test_ollama_batch_processor_config_default() {
let config = LlmBatchProcessorConfig::default();
assert!(config.parallel);
assert_eq!(config.max_concurrency, 4);
}
use std::collections::HashMap;
// Synchronous analogue of `binary_sort_actions`: same pairwise loss-count
// ranking, but with a caller-supplied comparator instead of LLM calls.
// NOTE(review): like the async version, actions with tied loss counts come
// out in HashMap iteration order; the tests below avoid ties.
fn binary_sort_sync(
actions: &[&str],
comparator: impl Fn(&str, &str) -> String,
) -> Vec<String> {
if actions.len() <= 1 {
return actions.iter().map(|s| s.to_string()).collect();
}
let mut wins: HashMap<String, usize> = HashMap::new();
for &a in actions {
wins.insert(a.to_string(), 0);
}
for i in 0..actions.len() {
for j in (i + 1)..actions.len() {
let a = actions[i];
let b = actions[j];
let winner = comparator(a, b);
// The loser of each pair accumulates one loss.
if winner == a {
*wins.get_mut(b).unwrap() += 1;
} else {
*wins.get_mut(a).unwrap() += 1;
}
}
}
// Fewest losses first.
let mut sorted: Vec<_> = wins.into_iter().collect();
sorted.sort_by_key(|(_, count)| *count);
sorted.into_iter().map(|(name, _)| name).collect()
}
#[test]
fn test_binary_sort_two_actions() {
let result = binary_sort_sync(
&["Fetch", "Summarize"],
|a, _b| a.to_string(), );
assert_eq!(result, vec!["Fetch", "Summarize"]);
let result = binary_sort_sync(
&["Fetch", "Summarize"],
|_a, b| b.to_string(), );
assert_eq!(result, vec!["Summarize", "Fetch"]);
}
#[test]
fn test_binary_sort_three_actions() {
// Comparator encodes a total order; the sort must recover it.
let result = binary_sort_sync(&["Test", "Deploy", "Build"], |a, b| {
let order = ["Build", "Test", "Deploy"];
let a_idx = order.iter().position(|&x| x == a).unwrap();
let b_idx = order.iter().position(|&x| x == b).unwrap();
if a_idx < b_idx {
a.to_string()
} else {
b.to_string()
}
});
assert_eq!(result, vec!["Build", "Test", "Deploy"]);
}
#[test]
fn test_binary_sort_wins_calculation() {
// Verifies the loss-count-then-sort step in isolation.
let mut wins: HashMap<String, usize> = HashMap::new();
wins.insert("A".to_string(), 0);
wins.insert("B".to_string(), 0);
wins.insert("C".to_string(), 0);
*wins.get_mut("B").unwrap() += 1;
*wins.get_mut("C").unwrap() += 1;
*wins.get_mut("C").unwrap() += 1;
assert_eq!(wins["A"], 0);
assert_eq!(wins["B"], 1);
assert_eq!(wins["C"], 2);
let mut sorted: Vec<_> = wins.into_iter().collect();
sorted.sort_by_key(|(_, count)| *count);
let result: Vec<_> = sorted.into_iter().map(|(name, _)| name).collect();
assert_eq!(result, vec!["A", "B", "C"]);
}
// Mirrors the response-parsing logic in `binary_sort_actions`:
// case-insensitive containment, with `a` checked before `b`.
fn extract_winner(response: &str, a: &str, b: &str) -> Option<String> {
let response_upper = response.to_uppercase();
let a_upper = a.to_uppercase();
let b_upper = b.to_uppercase();
if response_upper.contains(&a_upper) {
Some(a.to_string())
} else if response_upper.contains(&b_upper) {
Some(b.to_string())
} else {
None
}
}
#[test]
fn test_extract_winner() {
assert_eq!(
extract_winner("Fetch", "Fetch", "Summarize"),
Some("Fetch".to_string())
);
assert_eq!(
extract_winner("Summarize", "Fetch", "Summarize"),
Some("Summarize".to_string())
);
assert_eq!(
extract_winner(" Fetch", "Fetch", "Summarize"),
Some("Fetch".to_string())
);
assert_eq!(
extract_winner("fetch", "Fetch", "Summarize"),
Some("Fetch".to_string())
);
assert_eq!(
extract_winner("FETCH", "Fetch", "Summarize"),
Some("Fetch".to_string())
);
assert_eq!(
extract_winner("The answer is Fetch.", "Fetch", "Summarize"),
Some("Fetch".to_string())
);
assert_eq!(extract_winner("Unknown", "Fetch", "Summarize"), None);
// When both names appear, the first argument wins (checked first).
assert_eq!(
extract_winner("Fetch then Summarize", "Fetch", "Summarize"),
Some("Fetch".to_string())
);
}
#[test]
fn test_vote_majority() {
// Ties (including all-unparseable responses) favor `a`.
fn vote_majority(responses: &[&str], a: &str, b: &str) -> String {
let mut a_count = 0;
let mut b_count = 0;
for response in responses {
if let Some(winner) = extract_winner(response, a, b) {
if winner == a {
a_count += 1;
} else {
b_count += 1;
}
}
}
if a_count >= b_count {
a.to_string()
} else {
b.to_string()
}
}
assert_eq!(
vote_majority(&["Fetch", "Fetch", "Fetch"], "Fetch", "Summarize"),
"Fetch"
);
assert_eq!(
vote_majority(&["Fetch", "Summarize", "Fetch"], "Fetch", "Summarize"),
"Fetch"
);
assert_eq!(
vote_majority(&["Summarize", "Summarize", "Fetch"], "Fetch", "Summarize"),
"Summarize"
);
assert_eq!(
vote_majority(&["Fetch", "Summarize", "Unknown"], "Fetch", "Summarize"),
"Fetch"
);
}
use swarm_engine_core::context::{ContextTarget, GlobalContext, ResolvedContext};
// Builds a minimal (worker id, request) pair for grouping tests.
fn create_test_request(
worker_id: usize,
lora: Option<LoraConfig>,
) -> (WorkerId, WorkerDecisionRequest) {
let global = GlobalContext {
tick: 0,
max_ticks: 100,
progress: 0.0,
success_rate: 0.0,
task_description: Some("test".to_string()),
hint: None,
};
let context = ResolvedContext::new(global, ContextTarget::Worker(WorkerId(worker_id)));
(
WorkerId(worker_id),
WorkerDecisionRequest {
worker_id: WorkerId(worker_id),
query: format!("query_{}", worker_id),
context,
lora,
},
)
}
#[test]
fn test_group_by_lora_single_group_no_lora() {
let requests = vec![
create_test_request(0, None),
create_test_request(1, None),
create_test_request(2, None),
];
let groups = group_by_lora(requests);
assert_eq!(groups.len(), 1);
assert!(groups.contains_key(&None));
assert_eq!(groups[&None].len(), 3);
}
#[test]
fn test_group_by_lora_single_group_with_lora() {
let lora = LoraConfig::with_id(0);
let requests = vec![
create_test_request(0, Some(lora.clone())),
create_test_request(1, Some(lora.clone())),
];
let groups = group_by_lora(requests);
assert_eq!(groups.len(), 1);
assert!(groups.contains_key(&Some(lora)));
}
#[test]
fn test_group_by_lora_multiple_groups() {
let lora_a = LoraConfig::with_id(0);
let lora_b = LoraConfig::with_id(1);
let requests = vec![
create_test_request(0, Some(lora_a.clone())),
create_test_request(1, Some(lora_b.clone())),
create_test_request(2, Some(lora_a.clone())),
create_test_request(3, None),
create_test_request(4, Some(lora_b.clone())),
];
let groups = group_by_lora(requests);
assert_eq!(groups.len(), 3);
assert_eq!(groups[&Some(lora_a)].len(), 2);
assert_eq!(groups[&Some(lora_b)].len(), 2);
assert_eq!(groups[&None].len(), 1);
}
#[test]
fn test_group_by_lora_preserves_order_within_group() {
let lora = LoraConfig::with_id(0);
let requests = vec![
create_test_request(5, Some(lora.clone())),
create_test_request(3, Some(lora.clone())),
create_test_request(7, Some(lora.clone())),
];
let groups = group_by_lora(requests);
let group = &groups[&Some(lora)];
assert_eq!(group[0].0, WorkerId(5));
assert_eq!(group[1].0, WorkerId(3));
assert_eq!(group[2].0, WorkerId(7));
}
#[test]
fn test_group_by_lora_different_scales() {
// LoRA configs with the same id but different scales must not merge.
let lora_full = LoraConfig::new(0, 1.0);
let lora_half = LoraConfig::new(0, 0.5);
let requests = vec![
create_test_request(0, Some(lora_full.clone())),
create_test_request(1, Some(lora_half.clone())),
create_test_request(2, Some(lora_full.clone())),
];
let groups = group_by_lora(requests);
assert_eq!(groups.len(), 2);
assert_eq!(groups[&Some(lora_full)].len(), 2);
assert_eq!(groups[&Some(lora_half)].len(), 1);
}
#[test]
fn test_group_by_lora_empty() {
let requests: Vec<(WorkerId, WorkerDecisionRequest)> = vec![];
let groups = group_by_lora(requests);
assert!(groups.is_empty());
}
}