Skip to main content

dk_runner/
scheduler.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::Arc;
4
5use tokio::sync::mpsc;
6use tracing::info;
7use uuid::Uuid;
8
9use dk_engine::repo::Engine;
10
11use crate::changeset::scope_command_to_changeset;
12use crate::executor::{Executor, StepOutput, StepStatus};
13use crate::findings::{Finding, Suggestion};
14use crate::steps::{agent_review, command, human_approve, semantic};
15use crate::workflow::types::{Stage, Step, StepType, Workflow};
16
17/// Result of running a single step, with metadata for streaming.
18#[derive(Debug, Clone)]
19pub struct StepResult {
20    pub stage_name: String,
21    pub step_name: String,
22    pub status: StepStatus,
23    pub output: String,
24    pub required: bool,
25    pub findings: Vec<Finding>,
26    pub suggestions: Vec<Suggestion>,
27}
28
29/// Run an entire workflow: stages sequentially, steps within parallel stages concurrently.
30/// Sends `StepResult`s to `tx` as each step completes. Returns `true` if all required steps passed.
31///
32/// `engine` and `repo_id` are optional — when provided, the semantic step uses the full
33/// Engine-backed analysis. Pass `None` for both in tests or contexts without an Engine.
34pub async fn run_workflow(
35    workflow: &Workflow,
36    executor: &dyn Executor,
37    work_dir: &Path,
38    changeset_files: &[String],
39    env: &HashMap<String, String>,
40    tx: &mpsc::Sender<StepResult>,
41    engine: Option<&Arc<Engine>>,
42    repo_id: Option<Uuid>,
43    changeset_id: Option<Uuid>,
44) -> bool {
45    let mut all_passed = true;
46
47    for stage in &workflow.stages {
48        info!(stage = %stage.name, parallel = stage.parallel, "running stage");
49
50        let results = if stage.parallel {
51            run_stage_parallel(stage, executor, work_dir, changeset_files, env, engine, repo_id, changeset_id)
52                .await
53        } else {
54            run_stage_sequential(stage, executor, work_dir, changeset_files, env, engine, repo_id, changeset_id)
55                .await
56        };
57
58        for result in results {
59            if result.status != StepStatus::Pass && result.required {
60                all_passed = false;
61            }
62            let _ = tx.send(result).await;
63        }
64    }
65
66    all_passed
67}
68
69async fn run_stage_parallel(
70    stage: &Stage,
71    executor: &dyn Executor,
72    work_dir: &Path,
73    changeset_files: &[String],
74    env: &HashMap<String, String>,
75    engine: Option<&Arc<Engine>>,
76    repo_id: Option<Uuid>,
77    changeset_id: Option<Uuid>,
78) -> Vec<StepResult> {
79    let mut futures = Vec::new();
80    for step in &stage.steps {
81        futures.push(run_single_step(
82            &stage.name,
83            step,
84            executor,
85            work_dir,
86            changeset_files,
87            env,
88            engine,
89            repo_id,
90            changeset_id,
91        ));
92    }
93    futures::future::join_all(futures).await
94}
95
96async fn run_stage_sequential(
97    stage: &Stage,
98    executor: &dyn Executor,
99    work_dir: &Path,
100    changeset_files: &[String],
101    env: &HashMap<String, String>,
102    engine: Option<&Arc<Engine>>,
103    repo_id: Option<Uuid>,
104    changeset_id: Option<Uuid>,
105) -> Vec<StepResult> {
106    let mut results = Vec::new();
107    for step in &stage.steps {
108        let result = run_single_step(
109            &stage.name,
110            step,
111            executor,
112            work_dir,
113            changeset_files,
114            env,
115            engine,
116            repo_id,
117            changeset_id,
118        )
119        .await;
120        let failed_required = step.required && result.status != StepStatus::Pass;
121        results.push(result);
122        // Abort early if a required step failed — no point running subsequent
123        // steps (e.g., cargo test after cargo check fails with compile errors)
124        if failed_required {
125            tracing::warn!(
126                stage = %stage.name,
127                step = %step.name,
128                "required step failed — aborting remaining steps in sequential stage"
129            );
130            break;
131        }
132    }
133    results
134}
135
136async fn run_single_step(
137    stage_name: &str,
138    step: &Step,
139    executor: &dyn Executor,
140    work_dir: &Path,
141    changeset_files: &[String],
142    env: &HashMap<String, String>,
143    engine: Option<&Arc<Engine>>,
144    repo_id: Option<Uuid>,
145    changeset_id: Option<Uuid>,
146) -> StepResult {
147    info!(step = %step.name, "running step");
148
149    match &step.step_type {
150        StepType::Command { run } => {
151            let cmd = if step.changeset_aware {
152                let local_files: Vec<String> = if let Some(sub) = &step.work_dir {
153                    let prefix = format!("{}/", sub.display());
154                    changeset_files
155                        .iter()
156                        .filter_map(|f| f.strip_prefix(&prefix).map(|s| s.to_string()))
157                        .collect()
158                } else {
159                    changeset_files.to_vec()
160                };
161                scope_command_to_changeset(run, &local_files)
162                    .unwrap_or_else(|| run.clone())
163            } else {
164                run.clone()
165            };
166            let step_work_dir = match &step.work_dir {
167                Some(sub) => work_dir.join(sub),
168                None => work_dir.to_path_buf(),
169            };
170            let output =
171                match command::run_command_step(executor, &cmd, &step_work_dir, step.timeout, env).await {
172                    Ok(out) => out,
173                    Err(e) => StepOutput {
174                        status: StepStatus::Fail,
175                        stdout: String::new(),
176                        stderr: e.to_string(),
177                        duration: std::time::Duration::ZERO,
178                    },
179                };
180
181            let combined_output = if output.stderr.is_empty() {
182                output.stdout
183            } else {
184                format!("{}{}", output.stdout, output.stderr)
185            };
186
187            StepResult {
188                stage_name: stage_name.to_string(),
189                step_name: step.name.clone(),
190                status: output.status,
191                output: combined_output,
192                required: step.required,
193                findings: Vec::new(),
194                suggestions: Vec::new(),
195            }
196        }
197        StepType::Semantic { checks } => {
198            if let (Some(eng), Some(rid)) = (engine, repo_id) {
199                // Full Engine-backed semantic analysis
200                let (output, findings, suggestions) = semantic::run_semantic_step(
201                    eng,
202                    rid,
203                    changeset_files,
204                    work_dir,
205                    checks,
206                )
207                .await;
208
209                let combined_output = if output.stderr.is_empty() {
210                    output.stdout
211                } else {
212                    format!("{}{}", output.stdout, output.stderr)
213                };
214
215                StepResult {
216                    stage_name: stage_name.to_string(),
217                    step_name: step.name.clone(),
218                    status: output.status,
219                    output: combined_output,
220                    required: step.required,
221                    findings,
222                    suggestions,
223                }
224            } else {
225                // Fallback to simple shim (no Engine available)
226                let output = semantic::run_semantic_step_simple(checks).await;
227
228                let combined_output = if output.stderr.is_empty() {
229                    output.stdout
230                } else {
231                    format!("{}{}", output.stdout, output.stderr)
232                };
233
234                StepResult {
235                    stage_name: stage_name.to_string(),
236                    step_name: step.name.clone(),
237                    status: output.status,
238                    output: combined_output,
239                    required: step.required,
240                    findings: Vec::new(),
241                    suggestions: Vec::new(),
242                }
243            }
244        }
245        StepType::AgentReview { prompt } => {
246            let provider = agent_review::claude::ClaudeReviewProvider::from_env();
247            if let Some(provider) = provider {
248                let mut diff = String::new();
249                let mut files = Vec::new();
250                for path in changeset_files {
251                    let full_path = work_dir.join(path);
252                    if let Ok(content) = tokio::fs::read_to_string(&full_path).await {
253                        diff.push_str(&format!("--- {path}\n+++ {path}\n{content}\n"));
254                        files.push(agent_review::provider::FileContext {
255                            path: path.clone(),
256                            content,
257                        });
258                    }
259                }
260                let (output, findings, suggestions) =
261                    agent_review::run_agent_review_step_with_provider(
262                        &provider, &diff, files, prompt,
263                    )
264                    .await;
265                return StepResult {
266                    stage_name: stage_name.to_string(),
267                    step_name: step.name.clone(),
268                    status: output.status,
269                    output: if output.stderr.is_empty() {
270                        output.stdout
271                    } else {
272                        format!("{}{}", output.stdout, output.stderr)
273                    },
274                    required: step.required,
275                    findings,
276                    suggestions,
277                };
278            }
279            // No provider: use legacy stub
280            let output = agent_review::run_agent_review_step(prompt).await;
281            StepResult {
282                stage_name: stage_name.to_string(),
283                step_name: step.name.clone(),
284                status: output.status,
285                output: if output.stderr.is_empty() {
286                    output.stdout
287                } else {
288                    format!("{}{}", output.stdout, output.stderr)
289                },
290                required: step.required,
291                findings: Vec::new(),
292                suggestions: Vec::new(),
293            }
294        }
295        StepType::HumanApprove => {
296            if let (Some(eng), Some(cid)) = (engine, changeset_id) {
297                let (output, findings) = human_approve::run_human_approve_step_with_engine(
298                    eng, cid, Some(step.timeout),
299                ).await;
300                return StepResult {
301                    stage_name: stage_name.to_string(),
302                    step_name: step.name.clone(),
303                    status: output.status,
304                    output: if output.stderr.is_empty() { output.stdout } else { format!("{}{}", output.stdout, output.stderr) },
305                    required: step.required,
306                    findings,
307                    suggestions: Vec::new(),
308                };
309            }
310            let output = human_approve::run_human_approve_step().await;
311            StepResult {
312                stage_name: stage_name.to_string(),
313                step_name: step.name.clone(),
314                status: output.status,
315                output: if output.stderr.is_empty() { output.stdout } else { format!("{}{}", output.stdout, output.stderr) },
316                required: step.required,
317                findings: Vec::new(),
318                suggestions: Vec::new(),
319            }
320        }
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::process::ProcessExecutor;
    use crate::workflow::types::*;
    use std::time::Duration;

    /// Build a plain (not changeset-aware) command step with a 5s timeout.
    fn command_step(name: &str, run: &str, required: bool) -> Step {
        Step {
            name: name.into(),
            step_type: StepType::Command { run: run.into() },
            timeout: Duration::from_secs(5),
            required,
            changeset_aware: false,
            work_dir: None,
        }
    }

    /// Wrap a single stage into a workflow with a 30s timeout and an empty allowlist.
    fn single_stage_workflow(stage_name: &str, parallel: bool, steps: Vec<Step>) -> Workflow {
        Workflow {
            name: "test".into(),
            timeout: Duration::from_secs(30),
            stages: vec![Stage {
                name: stage_name.into(),
                parallel,
                steps,
            }],
            allowed_commands: vec![],
        }
    }

    /// Run `wf` in the temp dir without changeset or Engine, returning the overall
    /// verdict and every result streamed on the channel.
    async fn run_collect(wf: &Workflow) -> (bool, Vec<StepResult>) {
        let exec = ProcessExecutor::new();
        let (tx, mut rx) = mpsc::channel(32);
        let dir = std::env::temp_dir();

        let passed =
            run_workflow(wf, &exec, &dir, &[], &HashMap::new(), &tx, None, None, None).await;
        drop(tx);

        let mut results = Vec::new();
        while let Some(r) = rx.recv().await {
            results.push(r);
        }
        (passed, results)
    }

    #[tokio::test]
    async fn test_run_workflow_passes() {
        let wf = single_stage_workflow(
            "checks",
            false,
            vec![command_step("echo-test", "echo hello", true)],
        );

        let (passed, results) = run_collect(&wf).await;
        assert!(passed);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].status, StepStatus::Pass);
    }

    #[tokio::test]
    async fn test_failing_required_step() {
        let wf = single_stage_workflow(
            "checks",
            false,
            vec![command_step("disallowed", "false_cmd_not_in_allowlist", true)],
        );

        let (passed, _results) = run_collect(&wf).await;
        assert!(!passed);
    }

    #[tokio::test]
    async fn test_sequential_stage_aborts_after_required_failure() {
        let wf = single_stage_workflow(
            "checks",
            false,
            vec![
                command_step("disallowed", "false_cmd_not_in_allowlist", true),
                command_step("echo-after", "echo after", true),
            ],
        );

        let (passed, results) = run_collect(&wf).await;
        assert!(!passed);
        // The second step must not run (nor be reported) once the first required step failed.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].step_name, "disallowed");
    }

    #[tokio::test]
    async fn test_optional_failure_does_not_fail_workflow() {
        let wf = single_stage_workflow(
            "checks",
            false,
            vec![
                command_step("optional", "false_cmd_not_in_allowlist", false),
                command_step("echo-after", "echo after", true),
            ],
        );

        let (passed, results) = run_collect(&wf).await;
        assert!(passed);
        // A non-required failure neither aborts the stage nor fails the workflow.
        assert_eq!(results.len(), 2);
        assert_ne!(results[0].status, StepStatus::Pass);
        assert_eq!(results[1].status, StepStatus::Pass);
    }

    #[tokio::test]
    async fn test_parallel_stage() {
        let wf = single_stage_workflow(
            "parallel-checks",
            true,
            vec![
                command_step("echo-a", "echo a", true),
                command_step("echo-b", "echo b", true),
            ],
        );

        let (passed, results) = run_collect(&wf).await;
        assert!(passed);
        assert_eq!(results.len(), 2);
        // join_all keeps declaration order regardless of completion order.
        assert_eq!(results[0].step_name, "echo-a");
        assert_eq!(results[1].step_name, "echo-b");
    }
}