use std::collections::{HashMap, HashSet};
/// A single unit of work in a swarm plan.
///
/// `parse_task_json` builds these from model output; the scheduling helpers in this module
/// consume them.
#[derive(Clone, Debug)]
pub struct SwarmTask {
    /// Identifier that other tasks use to refer to this one in their `dependencies`.
    pub id: String,
    /// Instruction text for the agent. `parse_task_json` discards tasks whose prompt is empty.
    pub prompt: String,
    /// Role label. `parse_task_json` falls back to `"worker"` when none is given.
    pub role: String,
    /// Requested agent name. `parse_task_json` trims and lowercases it, and stores `None`
    /// when the field is absent or blank.
    pub agent_name: Option<String>,
    /// Ids of tasks that must complete before this one may start. Ids that are not part of
    /// the current plan are ignored by `TaskScheduler::is_task_ready`.
    pub dependencies: Vec<String>,
    /// Files this task is expected to touch. Used to keep tasks that share a file apart.
    pub target_files: Vec<String>,
}
/// Outcome of a verification step.
///
/// NOTE(review): nothing in this section reads these fields; the descriptions below follow
/// the field names and should be confirmed against the code that constructs this value.
#[derive(Debug)]
pub struct VerificationResult {
    /// Presumably `true` when the verification command succeeded.
    pub passed: bool,
    /// Presumably the captured output of `command`.
    pub output: String,
    /// Presumably the command line that was executed.
    pub command: String,
}
/// Incremental ready-queue scheduler for a swarm plan.
///
/// Tracks which tasks are ready to run, which are in flight, and which target files are held
/// by in-flight tasks. Tasks that share a target file are never handed out together, nor
/// while another task holding that file is still in flight. The completed set, the full id
/// set and the id → task map are owned by the caller and passed into each call.
pub(super) struct TaskScheduler {
    /// Ids believed ready to dispatch. May contain stale or duplicate entries;
    /// `drain_ready` filters them out.
    pub(super) ready_queue: Vec<String>,
    /// Ids recorded by `mark_dispatched` and not yet released by `remove_in_flight`.
    pub(super) in_flight_ids: HashSet<String>,
    /// Target files claimed by `mark_dispatched` and released by `on_completed`.
    pub(super) in_flight_files: HashSet<String>,
    /// Dependency id → ids of the tasks that list it as a dependency.
    reverse_deps: HashMap<String, Vec<String>>,
    /// Target file → ids of the tasks that list it in `target_files`.
    file_to_tasks: HashMap<String, Vec<String>>,
}

impl TaskScheduler {
    /// Builds the dependency and file indexes from `tasks` and seeds the ready queue.
    ///
    /// Every id in `all_ids` that passes [`Self::is_task_ready`] is queued. The queue order
    /// follows `HashSet` iteration and is therefore unspecified.
    ///
    /// # Panics
    /// Panics if an id in `all_ids` that is not yet completed is missing from `task_map`.
    pub(super) fn new(
        tasks: &[SwarmTask],
        completed: &HashSet<String>,
        all_ids: &HashSet<String>,
        task_map: &HashMap<String, SwarmTask>,
    ) -> Self {
        let mut sched = Self {
            ready_queue: Vec::new(),
            in_flight_ids: HashSet::new(),
            in_flight_files: HashSet::new(),
            reverse_deps: HashMap::new(),
            file_to_tasks: HashMap::new(),
        };
        for task in tasks {
            sched.index_task(task);
        }
        sched.ready_queue = all_ids
            .iter()
            .filter(|id| sched.is_task_ready(id, completed, task_map, all_ids))
            .cloned()
            .collect();
        sched
    }

    /// Returns whether task `id` can be dispatched right now.
    ///
    /// A task is ready when all of the following hold:
    /// - it is neither completed nor in flight;
    /// - every dependency that belongs to the plan (`all_ids`) is completed;
    /// - none of its target files is held by an in-flight task.
    ///
    /// # Panics
    /// Panics if `id` is not completed, not in flight, and missing from `task_map`.
    pub(super) fn is_task_ready(
        &self,
        id: &str,
        completed: &HashSet<String>,
        task_map: &HashMap<String, SwarmTask>,
        all_ids: &HashSet<String>,
    ) -> bool {
        if completed.contains(id) || self.in_flight_ids.contains(id) {
            return false;
        }
        let task = &task_map[id];
        // Dependencies outside the plan cannot complete, so they are ignored rather than
        // blocking forever.
        let deps_done = task
            .dependencies
            .iter()
            .filter(|d| all_ids.contains(*d))
            .all(|d| completed.contains(d));
        let files_free = !task
            .target_files
            .iter()
            .any(|f| self.in_flight_files.contains(f));
        deps_done && files_free
    }

    /// Removes up to `max` dispatchable ids from the ready queue and returns them.
    ///
    /// Stale entries (completed or already in flight) are dropped, and duplicate entries
    /// collapse to one. An id is deferred, and stays queued, when either:
    /// - the batch is full;
    /// - one of its target files is already claimed by an earlier id in this batch;
    /// - one of its target files is held by an in-flight task.
    ///
    /// The caller is expected to `mark_dispatched` each returned id.
    ///
    /// # Panics
    /// Panics if a queued id is missing from `task_map`.
    pub(super) fn drain_ready(
        &mut self,
        max: usize,
        task_map: &HashMap<String, SwarmTask>,
        completed: &HashSet<String>,
    ) -> Vec<String> {
        self.ready_queue
            .retain(|id| !completed.contains(id) && !self.in_flight_ids.contains(id));
        let candidates = std::mem::take(&mut self.ready_queue);
        let mut dispatch: Vec<String> = Vec::new();
        let mut batch_files: HashSet<String> = HashSet::new();
        // The queue can receive the same id more than once (e.g. from several
        // `on_completed` calls); handle each id only once per drain.
        let mut seen: HashSet<String> = HashSet::new();
        for id in candidates {
            if !seen.insert(id.clone()) {
                continue;
            }
            if dispatch.len() >= max {
                self.ready_queue.push(id);
                continue;
            }
            let task = &task_map[&id];
            // A deferred task may sit in the queue across drains, so its files must be
            // checked against in-flight work as well as against the current batch.
            let conflicts = task
                .target_files
                .iter()
                .any(|f| batch_files.contains(f) || self.in_flight_files.contains(f));
            if conflicts {
                self.ready_queue.push(id);
                continue;
            }
            batch_files.extend(task.target_files.iter().cloned());
            dispatch.push(id);
        }
        dispatch
    }

    /// Records `id` as in flight and claims its target `files`.
    pub(super) fn mark_dispatched(&mut self, id: &str, files: &[String]) {
        self.in_flight_ids.insert(id.to_string());
        self.in_flight_files.extend(files.iter().cloned());
    }

    /// Clears the in-flight mark for `id`.
    ///
    /// This does not release the task's target files; `on_completed` does that.
    pub(super) fn remove_in_flight(&mut self, id: &str) {
        self.in_flight_ids.remove(id);
    }

    /// Releases the target files of `finished_id` and queues any task that became ready.
    ///
    /// Only tasks that depend on `finished_id`, or that share one of its target files, are
    /// re-evaluated. `completed` should already contain `finished_id`.
    pub(super) fn on_completed(
        &mut self,
        finished_id: &str,
        completed: &HashSet<String>,
        task_map: &HashMap<String, SwarmTask>,
        all_ids: &HashSet<String>,
    ) {
        let freed_files: &[String] = task_map
            .get(finished_id)
            .map(|t| t.target_files.as_slice())
            .unwrap_or_default();
        for file in freed_files {
            self.in_flight_files.remove(file);
        }
        let mut candidates: HashSet<String> = HashSet::new();
        if let Some(dependents) = self.reverse_deps.get(finished_id) {
            candidates.extend(dependents.iter().cloned());
        }
        for file in freed_files {
            if let Some(sharers) = self.file_to_tasks.get(file) {
                candidates.extend(sharers.iter().cloned());
            }
        }
        for cand in candidates {
            if self.is_task_ready(&cand, completed, task_map, all_ids) {
                self.ready_queue.push(cand);
            }
        }
    }

    /// Registers a task added after construction and queues it if it is already ready.
    ///
    /// # Panics
    /// Panics if `task_map` does not contain `task.id` (and the task is neither completed
    /// nor in flight).
    pub(super) fn add_task(
        &mut self,
        task: &SwarmTask,
        completed: &HashSet<String>,
        task_map: &HashMap<String, SwarmTask>,
        all_ids: &HashSet<String>,
    ) {
        self.index_task(task);
        if self.is_task_ready(&task.id, completed, task_map, all_ids) {
            self.ready_queue.push(task.id.clone());
        }
    }

    /// Records `task` in the reverse-dependency and file indexes.
    fn index_task(&mut self, task: &SwarmTask) {
        for dep in &task.dependencies {
            self.reverse_deps
                .entry(dep.clone())
                .or_default()
                .push(task.id.clone());
        }
        for file in &task.target_files {
            self.file_to_tasks
                .entry(file.clone())
                .or_default()
                .push(task.id.clone());
        }
    }
}
/// Makes later tasks that share a target file depend on the first task that claimed it.
///
/// Tasks are scanned in slice order. The first task listing a file becomes that file's
/// owner. Every later task listing the same file gets the owner's id appended to its
/// `dependencies` (once, and only if it is not already there), and each such edge is logged
/// with `tracing::info!`. All decisions use the dependency lists as they were on entry; the
/// new edges are applied after the scan.
///
/// No edge is added when it would deadlock the plan:
/// - the owner has the same id as the task (e.g. a task that lists one file twice), which
///   would make the task depend on itself;
/// - the owner already depends directly on the later task, so the pair is already ordered
///   and the reverse edge would form a two-task cycle.
pub(super) fn enforce_file_disjoint(tasks: &mut [SwarmTask]) {
    let mut file_owner: HashMap<String, usize> = HashMap::new();
    let mut deps_to_add: Vec<(usize, String)> = Vec::new();
    for (i, task) in tasks.iter().enumerate() {
        for file in &task.target_files {
            let owner = match file_owner.get(file) {
                Some(&owner_idx) => &tasks[owner_idx],
                None => {
                    file_owner.insert(file.clone(), i);
                    continue;
                }
            };
            // A self-dependency can never be satisfied, so the task would never run.
            if owner.id == task.id {
                continue;
            }
            // Already ordered: either this task waits on the owner, or the owner explicitly
            // waits on this task (adding the reverse edge would create a cycle).
            if task.dependencies.contains(&owner.id) || owner.dependencies.contains(&task.id) {
                continue;
            }
            tracing::info!(
                "File overlap on `{file}`: forcing {} to depend on {}",
                task.id,
                owner.id
            );
            deps_to_add.push((i, owner.id.clone()));
        }
    }
    for (idx, dep_id) in deps_to_add {
        if !tasks[idx].dependencies.contains(&dep_id) {
            tasks[idx].dependencies.push(dep_id);
        }
    }
}
/// Heuristically decides whether `msg` describes work that should run sequentially instead
/// of being split into parallel swarm tasks.
///
/// Returns `true` when any of these hold:
/// - the message has fewer than 5 whitespace-separated words;
/// - the lowercased message contains an explicit ordering phrase (English, Korean, Chinese
///   or Japanese, e.g. "one by one", "순서대로", "まず");
/// - the message has fewer than 20 words and references exactly one file.
///
/// A file reference is a whitespace-separated token that contains `/` or ends with one of
/// the known source extensions. Surrounding quotes, backticks and punctuation are stripped
/// before the check. Each token counts once, so `src/main.rs` or `app.tsx` is one reference.
///
/// NOTE(review): words are counted with `split_whitespace`. Text written without spaces
/// (e.g. Chinese or Japanese) therefore usually counts as fewer than 5 words and returns
/// `true` early.
pub(super) fn is_trivially_sequential(msg: &str) -> bool {
    const SEQUENTIAL_SIGNALS: [&str; 15] = [
        "먼저",
        "그다음",
        "하나씩",
        "순서대로",
        "차례로",
        "step 1",
        "step 2",
        "first then",
        "sequentially",
        "one by one",
        "in order",
        "第一",
        "第二",
        "まず",
        "次に",
    ];
    const SOURCE_EXTENSIONS: [&str; 8] = [".rs", ".ts", ".tsx", ".py", ".go", ".js", ".java", ".kt"];

    let words = msg.split_whitespace().count();
    if words < 5 {
        return true;
    }
    let lower = msg.to_lowercase();
    if SEQUENTIAL_SIGNALS.iter().any(|signal| lower.contains(*signal)) {
        return true;
    }
    // Count path-like tokens, not marker occurrences. Counting occurrences made
    // "src/main.rs" (a slash plus ".rs") or "app.tsx" (".ts" plus ".tsx") count twice.
    let file_refs = msg
        .split_whitespace()
        .map(|token| {
            token.trim_matches(|c: char| {
                matches!(c, '`' | '"' | '\'' | ',' | ';' | ':' | '(' | ')' | '.')
            })
        })
        .filter(|token| {
            token.contains('/') || SOURCE_EXTENSIONS.iter().any(|ext| token.ends_with(ext))
        })
        .count();
    words < 20 && file_refs == 1
}
/// Extracts a JSON array of task objects from free-form model output and converts it into
/// [`SwarmTask`]s.
///
/// The array is taken to be the span from the first `[` to the last `]` in `text`. If either
/// bracket is missing, or the last `]` comes before the first `[`, the text is treated as an
/// empty array and `Ok(vec![])` is returned.
///
/// Per-object handling:
/// - missing or non-string `id` becomes `t{index}`, where index is the object's position in
///   the array, so unnamed tasks do not collide;
/// - missing `role` becomes `"worker"`;
/// - `agent` is trimmed and lowercased, and becomes `None` if blank;
/// - non-string entries in `dependencies` / `target_files` are skipped;
/// - objects whose `prompt` is missing or empty are discarded.
///
/// # Errors
/// Returns `AgentError::InvalidArgument` when the extracted span is not a JSON array.
pub(super) fn parse_task_json(text: &str) -> crate::common::Result<Vec<SwarmTask>> {
    // Guard the slice: with "]" before "[" the range would be reversed and slicing would
    // panic on untrusted model output.
    let json_str = text
        .find('[')
        .zip(text.rfind(']'))
        .filter(|&(start, end)| start < end)
        .map_or("[]", |(start, end)| &text[start..=end]);
    let parsed: Vec<serde_json::Value> = serde_json::from_str(json_str).map_err(|e| {
        crate::common::AgentError::InvalidArgument(format!("Failed to parse task JSON: {e}"))
    })?;
    // Collects the string elements of a JSON array; anything else yields an empty list.
    let string_list = |value: &serde_json::Value| -> Vec<String> {
        value
            .as_array()
            .map(|items| {
                items
                    .iter()
                    .filter_map(|item| item.as_str().map(String::from))
                    .collect()
            })
            .unwrap_or_default()
    };
    let tasks: Vec<SwarmTask> = parsed
        .into_iter()
        .enumerate()
        .map(|(index, v)| SwarmTask {
            id: v["id"]
                .as_str()
                .map_or_else(|| format!("t{index}"), str::to_string),
            prompt: v["prompt"].as_str().unwrap_or("").to_string(),
            role: v["role"].as_str().unwrap_or("worker").to_string(),
            agent_name: v["agent"]
                .as_str()
                .map(|s| s.trim().to_lowercase())
                .filter(|s| !s.is_empty()),
            dependencies: string_list(&v["dependencies"]),
            target_files: string_list(&v["target_files"]),
        })
        .filter(|t| !t.prompt.is_empty())
        .collect();
    Ok(tasks)
}
/// Shortens `s` to at most `max_len` bytes of content followed by `"..."`.
///
/// If `s` already fits in `max_len` bytes it is returned unchanged. Otherwise the cut is made
/// at the last UTF-8 character boundary at or before `max_len`, so no code point is split.
/// The returned string can therefore be up to `max_len + 3` bytes long.
pub(super) fn truncate(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_owned();
    }
    // Here max_len < s.len(), so every index probed is in range, and index 0 is always a
    // boundary, which ends the walk.
    let mut cut = max_len;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    let mut out = String::with_capacity(cut + 3);
    out.push_str(&s[..cut]);
    out.push_str("...");
    out
}