Skip to main content

bn/commands/run/
wave.rs

1use std::collections::HashSet;
2use std::path::Path;
3use std::process::Command;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8
9use crate::bean::Status;
10use crate::index::Index;
11use crate::stream::{self, StreamEvent};
12use crate::util::natural_cmp;
13
14use super::plan::SizedBean;
15use super::ready_queue::run_single_direct;
16use super::{AgentResult, BeanAction, SpawnMode};
17
18/// A wave of beans that can be dispatched in parallel.
19pub struct Wave {
20    pub beans: Vec<SizedBean>,
21}
22
23/// Compute waves of beans grouped by dependency order.
24/// Wave 0: no deps. Wave 1: deps all in wave 0. Etc.
25pub(super) fn compute_waves(beans: &[SizedBean], index: &Index) -> Vec<Wave> {
26    let mut waves = Vec::new();
27    let bean_ids: HashSet<String> = beans.iter().map(|b| b.id.clone()).collect();
28
29    // Already-closed beans count as completed
30    let mut completed: HashSet<String> = index
31        .beans
32        .iter()
33        .filter(|e| e.status == Status::Closed)
34        .map(|e| e.id.clone())
35        .collect();
36
37    let mut remaining: Vec<SizedBean> = beans.to_vec();
38
39    while !remaining.is_empty() {
40        let (ready, blocked): (Vec<SizedBean>, Vec<SizedBean>) =
41            remaining.into_iter().partition(|b| {
42                // All explicit deps must be completed or not in our dispatch set
43                let explicit_ok = b
44                    .dependencies
45                    .iter()
46                    .all(|d| completed.contains(d) || !bean_ids.contains(d));
47
48                // All requires must be satisfied (producer completed or not in set)
49                let requires_ok = b.requires.iter().all(|req| {
50                    // Find the sibling producer for this artifact
51                    if let Some(producer) = beans.iter().find(|other| {
52                        other.id != b.id && other.parent == b.parent && other.produces.contains(req)
53                    }) {
54                        completed.contains(&producer.id)
55                    } else {
56                        true // No producer in set, assume satisfied
57                    }
58                });
59
60                explicit_ok && requires_ok
61            });
62
63        if ready.is_empty() {
64            // Remaining beans have unresolvable deps (cycle or missing)
65            // Add them all as a final wave to avoid losing them
66            eprintln!(
67                "Warning: {} bean(s) have unresolvable dependencies, adding to final wave",
68                blocked.len()
69            );
70            waves.push(Wave { beans: blocked });
71            break;
72        }
73
74        for b in &ready {
75            completed.insert(b.id.clone());
76        }
77
78        waves.push(Wave { beans: ready });
79        remaining = blocked;
80    }
81
82    // Sort beans within each wave by priority then ID
83    for wave in &mut waves {
84        wave.beans.sort_by(|a, b| {
85            a.priority
86                .cmp(&b.priority)
87                .then_with(|| natural_cmp(&a.id, &b.id))
88        });
89    }
90
91    waves
92}
93
94// ---------------------------------------------------------------------------
95// Wave execution
96// ---------------------------------------------------------------------------
97
98/// Spawn agents for a wave of beans, respecting max parallelism.
99pub(super) fn run_wave(
100    beans_dir: &Path,
101    beans: &[SizedBean],
102    spawn_mode: &SpawnMode,
103    cfg: &super::RunConfig,
104    wave_number: usize,
105) -> Result<Vec<AgentResult>> {
106    match spawn_mode {
107        SpawnMode::Template {
108            run_template,
109            plan_template,
110        } => run_wave_template(
111            beans,
112            run_template,
113            plan_template.as_deref(),
114            cfg.max_jobs,
115            cfg.timeout_minutes,
116        ),
117        SpawnMode::Direct => run_wave_direct(
118            beans_dir,
119            beans,
120            cfg.max_jobs,
121            cfg.timeout_minutes,
122            cfg.idle_timeout_minutes,
123            cfg.json_stream,
124            wave_number,
125            cfg.file_locking,
126        ),
127    }
128}
129
130/// Template mode: spawn agents via `sh -c <template>` (backward compat).
131fn run_wave_template(
132    beans: &[SizedBean],
133    run_template: &str,
134    _plan_template: Option<&str>,
135    max_jobs: usize,
136    _timeout_minutes: u32,
137) -> Result<Vec<AgentResult>> {
138    let mut results = Vec::new();
139    let mut children: Vec<(SizedBean, std::process::Child, Instant)> = Vec::new();
140
141    let mut pending: Vec<&SizedBean> = beans.iter().collect();
142
143    while !pending.is_empty() || !children.is_empty() {
144        // Spawn up to max_jobs
145        while children.len() < max_jobs && !pending.is_empty() {
146            let sb = pending.remove(0);
147            let template = match sb.action {
148                BeanAction::Implement => run_template,
149            };
150
151            let cmd = template.replace("{id}", &sb.id);
152            match Command::new("sh").args(["-c", &cmd]).spawn() {
153                Ok(child) => {
154                    children.push((sb.clone(), child, Instant::now()));
155                }
156                Err(e) => {
157                    eprintln!("  Failed to spawn agent for {}: {}", sb.id, e);
158                    results.push(AgentResult {
159                        id: sb.id.clone(),
160                        title: sb.title.clone(),
161                        action: sb.action,
162                        success: false,
163                        duration: Duration::ZERO,
164                        total_tokens: None,
165                        total_cost: None,
166                        error: Some(format!("Failed to spawn: {}", e)),
167                        tool_count: 0,
168                        turns: 0,
169                        failure_summary: None,
170                    });
171                }
172            }
173        }
174
175        if children.is_empty() {
176            break;
177        }
178
179        // Poll for completions
180        let mut still_running = Vec::new();
181        for (sb, mut child, started) in children {
182            match child.try_wait() {
183                Ok(Some(status)) => {
184                    let err = if status.success() {
185                        None
186                    } else {
187                        Some(format!("Exit code {}", status.code().unwrap_or(-1)))
188                    };
189                    results.push(AgentResult {
190                        id: sb.id.clone(),
191                        title: sb.title.clone(),
192                        action: sb.action,
193                        success: status.success(),
194                        duration: started.elapsed(),
195                        total_tokens: None,
196                        total_cost: None,
197                        error: err,
198                        tool_count: 0,
199                        turns: 0,
200                        failure_summary: None,
201                    });
202                }
203                Ok(None) => {
204                    still_running.push((sb, child, started));
205                }
206                Err(e) => {
207                    eprintln!("  Error checking agent for {}: {}", sb.id, e);
208                    results.push(AgentResult {
209                        id: sb.id.clone(),
210                        title: sb.title.clone(),
211                        action: sb.action,
212                        success: false,
213                        duration: started.elapsed(),
214                        total_tokens: None,
215                        total_cost: None,
216                        error: Some(format!("Error checking process: {}", e)),
217                        tool_count: 0,
218                        turns: 0,
219                        failure_summary: None,
220                    });
221                }
222            }
223        }
224        children = still_running;
225
226        if !children.is_empty() {
227            std::thread::sleep(Duration::from_millis(500));
228        }
229    }
230
231    Ok(results)
232}
233
234/// Direct mode: spawn pi directly with JSON output and monitoring.
235#[allow(clippy::too_many_arguments)]
236fn run_wave_direct(
237    beans_dir: &Path,
238    beans: &[SizedBean],
239    max_jobs: usize,
240    timeout_minutes: u32,
241    idle_timeout_minutes: u32,
242    json_stream: bool,
243    wave_number: usize,
244    file_locking: bool,
245) -> Result<Vec<AgentResult>> {
246    let results = Arc::new(Mutex::new(Vec::new()));
247    let mut pending: Vec<SizedBean> = beans.to_vec();
248    let mut handles: Vec<std::thread::JoinHandle<()>> = Vec::new();
249
250    while !pending.is_empty() || !handles.is_empty() {
251        // Spawn up to max_jobs threads
252        while handles.len() < max_jobs && !pending.is_empty() {
253            let sb = pending.remove(0);
254            let beans_dir = beans_dir.to_path_buf();
255            let results = Arc::clone(&results);
256            let timeout_min = timeout_minutes;
257            let idle_min = idle_timeout_minutes;
258
259            if json_stream {
260                stream::emit(&StreamEvent::BeanStart {
261                    id: sb.id.clone(),
262                    title: sb.title.clone(),
263                    round: wave_number,
264                    file_overlaps: None,
265                    attempt: None,
266                    priority: None,
267                });
268            }
269
270            let handle = std::thread::spawn(move || {
271                let result = run_single_direct(
272                    &beans_dir,
273                    &sb,
274                    timeout_min,
275                    idle_min,
276                    json_stream,
277                    file_locking,
278                );
279                results.lock().unwrap().push(result);
280            });
281            handles.push(handle);
282        }
283
284        // Wait for at least one thread to finish
285        let prev_count = handles.len();
286        let mut still_running = Vec::new();
287        for handle in handles.drain(..) {
288            if handle.is_finished() {
289                let _ = handle.join();
290            } else {
291                still_running.push(handle);
292            }
293        }
294
295        // If nothing finished, wait briefly before polling again
296        if still_running.len() == prev_count && !still_running.is_empty() {
297            std::thread::sleep(Duration::from_millis(200));
298        }
299
300        handles = still_running;
301    }
302
303    // Wait for any remaining threads
304    for handle in handles {
305        let _ = handle.join();
306    }
307
308    Ok(Arc::try_unwrap(results).unwrap().into_inner().unwrap())
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commands::run::BeanAction;
    use crate::index::Index;

    /// Build an `Implement` bean with priority 2, no parent, and no
    /// artifacts, depending explicitly on the given bean IDs.
    fn sized_bean(id: &str, deps: &[&str]) -> SizedBean {
        SizedBean {
            id: id.to_string(),
            title: format!("Bean {}", id),
            action: BeanAction::Implement,
            priority: 2,
            dependencies: deps.iter().map(|d| d.to_string()).collect(),
            parent: None,
            produces: vec![],
            requires: vec![],
            paths: vec![],
        }
    }

    /// An index with no beans, so nothing counts as already closed.
    fn empty_index() -> Index {
        Index { beans: vec![] }
    }

    /// IDs of each wave, in wave order, for compact assertions.
    fn wave_ids(waves: &[Wave]) -> Vec<Vec<String>> {
        waves
            .iter()
            .map(|w| w.beans.iter().map(|b| b.id.clone()).collect())
            .collect()
    }

    #[test]
    fn compute_waves_no_deps() {
        let beans = vec![sized_bean("1", &[]), sized_bean("2", &[])];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(waves.len(), 1);
        assert_eq!(waves[0].beans.len(), 2);
    }

    #[test]
    fn compute_waves_linear_chain() {
        let beans = vec![
            sized_bean("1", &[]),
            sized_bean("2", &["1"]),
            sized_bean("3", &["2"]),
        ];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(waves.len(), 3);
        assert_eq!(waves[0].beans[0].id, "1");
        assert_eq!(waves[1].beans[0].id, "2");
        assert_eq!(waves[2].beans[0].id, "3");
    }

    #[test]
    fn compute_waves_diamond() {
        // 1 → (2, 3) → 4
        let beans = vec![
            sized_bean("1", &[]),
            sized_bean("2", &["1"]),
            sized_bean("3", &["1"]),
            sized_bean("4", &["2", "3"]),
        ];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(waves.len(), 3);
        assert_eq!(waves[0].beans.len(), 1); // 1
        assert_eq!(waves[1].beans.len(), 2); // 2, 3
        assert_eq!(waves[2].beans.len(), 1); // 4
    }

    #[test]
    fn compute_waves_dependency_outside_batch_is_satisfied() {
        let beans = vec![sized_bean("2", &["99"])];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(wave_ids(&waves), vec![vec!["2".to_string()]]);
    }

    #[test]
    fn compute_waves_cycle_lands_in_final_wave() {
        let beans = vec![sized_bean("1", &["2"]), sized_bean("2", &["1"])];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(waves.len(), 1);
        assert_eq!(waves[0].beans.len(), 2);
    }

    #[test]
    fn compute_waves_requires_waits_for_sibling_producer() {
        // Consumer listed first to show ordering comes from the artifact,
        // not from input order.
        let consumer = SizedBean {
            requires: vec!["schema".to_string()],
            ..sized_bean("2", &[])
        };
        let producer = SizedBean {
            produces: vec!["schema".to_string()],
            ..sized_bean("1", &[])
        };
        let waves = compute_waves(&[consumer, producer], &empty_index());
        assert_eq!(
            wave_ids(&waves),
            vec![vec!["1".to_string()], vec!["2".to_string()]]
        );
    }

    #[test]
    fn compute_waves_orders_wave_by_priority_then_natural_id() {
        let beans = vec![
            sized_bean("10", &[]),
            sized_bean("2", &[]),
            SizedBean {
                priority: 1,
                ..sized_bean("3", &[])
            },
        ];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(
            wave_ids(&waves),
            vec![vec!["3".to_string(), "2".to_string(), "10".to_string()]]
        );
    }

    #[test]
    fn template_wave_execution_with_echo() {
        let beans = vec![sized_bean("1", &[])];
        let results = run_wave_template(&beans, "echo {id}", None, 4, 30).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].success);
        assert_eq!(results[0].id, "1");
    }

    #[test]
    fn template_wave_runs_implement_action() {
        let beans = vec![sized_bean("1", &[])];
        let results = run_wave_template(&beans, "echo {id}", None, 4, 30).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].success);
        assert!(matches!(results[0].action, BeanAction::Implement));
    }

    #[test]
    fn template_wave_serializes_with_single_job() {
        let beans = vec![
            sized_bean("1", &[]),
            sized_bean("2", &[]),
            sized_bean("3", &[]),
        ];
        let results = run_wave_template(&beans, "echo {id}", None, 1, 30).unwrap();
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.success));
        let mut ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec!["1", "2", "3"]);
    }

    #[test]
    fn template_wave_failed_command() {
        let beans = vec![sized_bean("1", &[])];
        let results = run_wave_template(&beans, "false", None, 4, 30).unwrap();
        assert_eq!(results.len(), 1);
        assert!(!results[0].success);
        assert!(results[0].error.is_some());
    }
}
510}