Skip to main content

bn/commands/run/
wave.rs

1use std::collections::HashSet;
2use std::path::Path;
3use std::process::Command;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6
7use anyhow::Result;
8
9use crate::bean::Status;
10use crate::index::Index;
11use crate::stream::{self, StreamEvent};
12use crate::util::natural_cmp;
13
14use super::plan::SizedBean;
15use super::ready_queue::run_single_direct;
16use super::{AgentResult, BeanAction, SpawnMode};
17
18/// A wave of beans that can be dispatched in parallel.
19pub struct Wave {
20    pub beans: Vec<SizedBean>,
21}
22
23/// Compute waves of beans grouped by dependency order.
24/// Wave 0: no deps. Wave 1: deps all in wave 0. Etc.
25pub(super) fn compute_waves(beans: &[SizedBean], index: &Index) -> Vec<Wave> {
26    let mut waves = Vec::new();
27    let bean_ids: HashSet<String> = beans.iter().map(|b| b.id.clone()).collect();
28
29    // Already-closed beans count as completed
30    let mut completed: HashSet<String> = index
31        .beans
32        .iter()
33        .filter(|e| e.status == Status::Closed)
34        .map(|e| e.id.clone())
35        .collect();
36
37    let mut remaining: Vec<SizedBean> = beans.to_vec();
38
39    while !remaining.is_empty() {
40        let (ready, blocked): (Vec<SizedBean>, Vec<SizedBean>) =
41            remaining.into_iter().partition(|b| {
42                // All explicit deps must be completed or not in our dispatch set
43                let explicit_ok = b
44                    .dependencies
45                    .iter()
46                    .all(|d| completed.contains(d) || !bean_ids.contains(d));
47
48                // All requires must be satisfied (producer completed or not in set)
49                let requires_ok = b.requires.iter().all(|req| {
50                    // Find the sibling producer for this artifact
51                    if let Some(producer) = beans.iter().find(|other| {
52                        other.id != b.id && other.parent == b.parent && other.produces.contains(req)
53                    }) {
54                        completed.contains(&producer.id)
55                    } else {
56                        true // No producer in set, assume satisfied
57                    }
58                });
59
60                explicit_ok && requires_ok
61            });
62
63        if ready.is_empty() {
64            // Remaining beans have unresolvable deps (cycle or missing)
65            // Add them all as a final wave to avoid losing them
66            eprintln!(
67                "Warning: {} bean(s) have unresolvable dependencies, adding to final wave",
68                blocked.len()
69            );
70            waves.push(Wave { beans: blocked });
71            break;
72        }
73
74        for b in &ready {
75            completed.insert(b.id.clone());
76        }
77
78        waves.push(Wave { beans: ready });
79        remaining = blocked;
80    }
81
82    // Sort beans within each wave by priority then ID
83    for wave in &mut waves {
84        wave.beans.sort_by(|a, b| {
85            a.priority
86                .cmp(&b.priority)
87                .then_with(|| natural_cmp(&a.id, &b.id))
88        });
89    }
90
91    waves
92}
93
94// ---------------------------------------------------------------------------
95// Wave execution
96// ---------------------------------------------------------------------------
97
98/// Spawn agents for a wave of beans, respecting max parallelism.
99pub(super) fn run_wave(
100    beans_dir: &Path,
101    beans: &[SizedBean],
102    spawn_mode: &SpawnMode,
103    cfg: &super::RunConfig,
104    wave_number: usize,
105) -> Result<Vec<AgentResult>> {
106    match spawn_mode {
107        SpawnMode::Template {
108            run_template,
109            plan_template,
110        } => run_wave_template(
111            beans,
112            run_template,
113            plan_template.as_deref(),
114            cfg.max_jobs,
115            cfg.timeout_minutes,
116        ),
117        SpawnMode::Direct => run_wave_direct(
118            beans_dir,
119            beans,
120            cfg.max_jobs,
121            cfg.timeout_minutes,
122            cfg.idle_timeout_minutes,
123            cfg.json_stream,
124            wave_number,
125            cfg.file_locking,
126        ),
127    }
128}
129
130/// Template mode: spawn agents via `sh -c <template>` (backward compat).
131fn run_wave_template(
132    beans: &[SizedBean],
133    run_template: &str,
134    plan_template: Option<&str>,
135    max_jobs: usize,
136    _timeout_minutes: u32,
137) -> Result<Vec<AgentResult>> {
138    let mut results = Vec::new();
139    let mut children: Vec<(SizedBean, std::process::Child, Instant)> = Vec::new();
140
141    let mut pending: Vec<&SizedBean> = beans.iter().collect();
142
143    while !pending.is_empty() || !children.is_empty() {
144        // Spawn up to max_jobs
145        while children.len() < max_jobs && !pending.is_empty() {
146            let sb = pending.remove(0);
147            let template = match sb.action {
148                BeanAction::Implement => run_template,
149                BeanAction::Plan => {
150                    if let Some(pt) = plan_template {
151                        pt
152                    } else {
153                        // No plan template — skip with error
154                        results.push(AgentResult {
155                            id: sb.id.clone(),
156                            title: sb.title.clone(),
157                            action: sb.action,
158                            success: false,
159                            duration: Duration::ZERO,
160                            total_tokens: None,
161                            total_cost: None,
162                            error: Some("No plan template configured".to_string()),
163                        });
164                        continue;
165                    }
166                }
167            };
168
169            let cmd = template.replace("{id}", &sb.id);
170            match Command::new("sh").args(["-c", &cmd]).spawn() {
171                Ok(child) => {
172                    children.push((sb.clone(), child, Instant::now()));
173                }
174                Err(e) => {
175                    eprintln!("  Failed to spawn agent for {}: {}", sb.id, e);
176                    results.push(AgentResult {
177                        id: sb.id.clone(),
178                        title: sb.title.clone(),
179                        action: sb.action,
180                        success: false,
181                        duration: Duration::ZERO,
182                        total_tokens: None,
183                        total_cost: None,
184                        error: Some(format!("Failed to spawn: {}", e)),
185                    });
186                }
187            }
188        }
189
190        if children.is_empty() {
191            break;
192        }
193
194        // Poll for completions
195        let mut still_running = Vec::new();
196        for (sb, mut child, started) in children {
197            match child.try_wait() {
198                Ok(Some(status)) => {
199                    results.push(AgentResult {
200                        id: sb.id.clone(),
201                        title: sb.title.clone(),
202                        action: sb.action,
203                        success: status.success(),
204                        duration: started.elapsed(),
205                        total_tokens: None,
206                        total_cost: None,
207                        error: if status.success() {
208                            None
209                        } else {
210                            Some(format!("Exit code {}", status.code().unwrap_or(-1)))
211                        },
212                    });
213                }
214                Ok(None) => {
215                    still_running.push((sb, child, started));
216                }
217                Err(e) => {
218                    eprintln!("  Error checking agent for {}: {}", sb.id, e);
219                    results.push(AgentResult {
220                        id: sb.id.clone(),
221                        title: sb.title.clone(),
222                        action: sb.action,
223                        success: false,
224                        duration: started.elapsed(),
225                        total_tokens: None,
226                        total_cost: None,
227                        error: Some(format!("Error checking process: {}", e)),
228                    });
229                }
230            }
231        }
232        children = still_running;
233
234        if !children.is_empty() {
235            std::thread::sleep(Duration::from_millis(500));
236        }
237    }
238
239    Ok(results)
240}
241
242/// Direct mode: spawn pi directly with JSON output and monitoring.
243fn run_wave_direct(
244    beans_dir: &Path,
245    beans: &[SizedBean],
246    max_jobs: usize,
247    timeout_minutes: u32,
248    idle_timeout_minutes: u32,
249    json_stream: bool,
250    wave_number: usize,
251    file_locking: bool,
252) -> Result<Vec<AgentResult>> {
253    let results = Arc::new(Mutex::new(Vec::new()));
254    let mut pending: Vec<SizedBean> = beans.to_vec();
255    let mut handles: Vec<std::thread::JoinHandle<()>> = Vec::new();
256
257    while !pending.is_empty() || !handles.is_empty() {
258        // Spawn up to max_jobs threads
259        while handles.len() < max_jobs && !pending.is_empty() {
260            let sb = pending.remove(0);
261            let beans_dir = beans_dir.to_path_buf();
262            let results = Arc::clone(&results);
263            let timeout_min = timeout_minutes;
264            let idle_min = idle_timeout_minutes;
265
266            if json_stream {
267                stream::emit(&StreamEvent::BeanStart {
268                    id: sb.id.clone(),
269                    title: sb.title.clone(),
270                    round: wave_number,
271                });
272            }
273
274            let handle = std::thread::spawn(move || {
275                let result = run_single_direct(
276                    &beans_dir,
277                    &sb,
278                    timeout_min,
279                    idle_min,
280                    json_stream,
281                    file_locking,
282                );
283                results.lock().unwrap().push(result);
284            });
285            handles.push(handle);
286        }
287
288        // Wait for at least one thread to finish
289        let prev_count = handles.len();
290        let mut still_running = Vec::new();
291        for handle in handles.drain(..) {
292            if handle.is_finished() {
293                let _ = handle.join();
294            } else {
295                still_running.push(handle);
296            }
297        }
298
299        // If nothing finished, wait briefly before polling again
300        if still_running.len() == prev_count && !still_running.is_empty() {
301            std::thread::sleep(Duration::from_millis(200));
302        }
303
304        handles = still_running;
305    }
306
307    // Wait for any remaining threads
308    for handle in handles {
309        let _ = handle.join();
310    }
311
312    Ok(Arc::try_unwrap(results).unwrap().into_inner().unwrap())
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::commands::run::BeanAction;
    use crate::index::Index;

    /// Build a top-level (`parent: None`) bean with default size and priority
    /// and the given explicit dependencies.
    fn bean(id: &str, action: BeanAction, deps: &[&str]) -> SizedBean {
        SizedBean {
            id: id.to_string(),
            title: format!("Bean {}", id),
            tokens: 100,
            action,
            priority: 2,
            dependencies: deps.iter().map(|&d| d.to_string()).collect(),
            parent: None,
            produces: vec![],
            requires: vec![],
            paths: vec![],
        }
    }

    /// Shorthand for an `Implement` bean.
    fn implement(id: &str, deps: &[&str]) -> SizedBean {
        bean(id, BeanAction::Implement, deps)
    }

    fn empty_index() -> Index {
        Index { beans: vec![] }
    }

    /// Bean IDs of each wave, in wave order.
    fn wave_ids(waves: &[Wave]) -> Vec<Vec<String>> {
        waves
            .iter()
            .map(|w| w.beans.iter().map(|b| b.id.clone()).collect())
            .collect()
    }

    #[test]
    fn compute_waves_no_deps() {
        let beans = vec![implement("1", &[]), implement("2", &[])];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(waves.len(), 1);
        assert_eq!(waves[0].beans.len(), 2);
    }

    #[test]
    fn compute_waves_linear_chain() {
        let beans = vec![
            implement("1", &[]),
            implement("2", &["1"]),
            implement("3", &["2"]),
        ];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(waves.len(), 3);
        assert_eq!(waves[0].beans[0].id, "1");
        assert_eq!(waves[1].beans[0].id, "2");
        assert_eq!(waves[2].beans[0].id, "3");
    }

    #[test]
    fn compute_waves_diamond() {
        // 1 → (2, 3) → 4
        let beans = vec![
            implement("1", &[]),
            implement("2", &["1"]),
            implement("3", &["1"]),
            implement("4", &["2", "3"]),
        ];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(waves.len(), 3);
        assert_eq!(waves[0].beans.len(), 1); // 1
        assert_eq!(waves[1].beans.len(), 2); // 2, 3
        assert_eq!(waves[2].beans.len(), 1); // 4
    }

    #[test]
    fn compute_waves_ignores_deps_outside_dispatch_set() {
        // "99" is neither closed nor being dispatched, so it is not waited on.
        let beans = vec![implement("1", &[]), implement("2", &["99"])];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(waves.len(), 1);
        assert_eq!(waves[0].beans.len(), 2);
    }

    #[test]
    fn compute_waves_cycle_falls_back_to_final_wave() {
        let beans = vec![implement("1", &["2"]), implement("2", &["1"])];
        let waves = compute_waves(&beans, &empty_index());
        assert_eq!(waves.len(), 1);
        assert_eq!(waves[0].beans.len(), 2);
    }

    #[test]
    fn compute_waves_requires_waits_for_sibling_producer() {
        let mut producer = implement("1", &[]);
        producer.produces = vec!["api".to_string()];
        let mut consumer = implement("2", &[]);
        consumer.requires = vec!["api".to_string()];

        let waves = compute_waves(&[consumer, producer], &empty_index());
        assert_eq!(
            wave_ids(&waves),
            vec![vec!["1".to_string()], vec!["2".to_string()]]
        );
    }

    #[test]
    fn compute_waves_sorts_each_wave_by_priority() {
        let mut low = implement("1", &[]);
        low.priority = 3;
        let mut high = implement("2", &[]);
        high.priority = 1;

        let waves = compute_waves(&[low, high], &empty_index());
        assert_eq!(
            wave_ids(&waves),
            vec![vec!["2".to_string(), "1".to_string()]]
        );
    }

    #[test]
    fn template_wave_execution_with_echo() {
        let beans = vec![implement("1", &[])];
        let results = run_wave_template(&beans, "echo {id}", None, 4, 30).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].success);
        assert_eq!(results[0].id, "1");
    }

    #[test]
    fn template_wave_plan_without_template_errors() {
        let beans = vec![bean("1", BeanAction::Plan, &[])];
        let results = run_wave_template(&beans, "echo {id}", None, 4, 30).unwrap();
        assert_eq!(results.len(), 1);
        assert!(!results[0].success);
        assert!(results[0]
            .error
            .as_ref()
            .unwrap()
            .contains("No plan template"));
    }

    #[test]
    fn template_wave_failed_command() {
        let beans = vec![implement("1", &[])];
        let results = run_wave_template(&beans, "false", None, 4, 30).unwrap();
        assert_eq!(results.len(), 1);
        assert!(!results[0].success);
        assert!(results[0].error.is_some());
    }
}
530}