baobao_codegen/pipeline/
runner.rs

1//! Pipeline orchestrator.
2
3use baobao_manifest::Manifest;
4use eyre::Result;
5
6use super::{
7    CompilationContext, Phase, Plugin,
8    phase::PhaseInfo,
9    phases::{AnalyzePhase, LowerPhase, ValidatePhase},
10};
11
12/// The compilation pipeline orchestrator.
13///
14/// The pipeline manages the execution of compilation phases and plugin hooks.
15/// It runs built-in phases (validate, lower, analyze) followed by any user
16/// phases, calling plugin hooks before and after each phase.
17///
18/// # Example
19///
20/// ```ignore
21/// let pipeline = Pipeline::new()
22///     .plugin(MyPlugin::new())
23///     .phase(MyCustomPhase);
24///
25/// let ctx = pipeline.run(manifest)?;
26/// ```
27pub struct Pipeline {
28    builtin_phases: Vec<Box<dyn Phase>>,
29    user_phases: Vec<Box<dyn Phase>>,
30    plugins: Vec<Box<dyn Plugin>>,
31}
32
33impl Pipeline {
34    /// Create a new pipeline with default built-in phases.
35    pub fn new() -> Self {
36        Self {
37            builtin_phases: vec![
38                Box::new(ValidatePhase::new()),
39                Box::new(LowerPhase),
40                Box::new(AnalyzePhase),
41            ],
42            user_phases: Vec::new(),
43            plugins: Vec::new(),
44        }
45    }
46
47    /// Add a phase to run after the built-in phases.
48    pub fn phase(mut self, phase: impl Phase + 'static) -> Self {
49        self.user_phases.push(Box::new(phase));
50        self
51    }
52
53    /// Add a plugin to receive phase lifecycle hooks.
54    pub fn plugin(mut self, plugin: impl Plugin + 'static) -> Self {
55        self.plugins.push(Box::new(plugin));
56        self
57    }
58
59    /// Iterate over all phases (builtin + user).
60    fn all_phases(&self) -> impl Iterator<Item = &Box<dyn Phase>> {
61        self.builtin_phases.iter().chain(self.user_phases.iter())
62    }
63
64    /// Get the names of all phases that will be executed.
65    pub fn phase_names(&self) -> Vec<&'static str> {
66        self.all_phases().map(|p| p.name()).collect()
67    }
68
69    /// Get information about all phases that will be executed.
70    pub fn phase_info(&self) -> Vec<PhaseInfo> {
71        self.all_phases().map(|p| p.info()).collect()
72    }
73
74    /// Run the pipeline on a manifest.
75    ///
76    /// Executes all phases in order:
77    /// 1. ValidatePhase - validates manifest, collects diagnostics
78    /// 2. LowerPhase - transforms manifest to IR
79    /// 3. AnalyzePhase - computes shared data
80    /// 4. User phases (if any)
81    ///
82    /// Plugin hooks are called before and after each phase.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if any phase fails fatally.
87    pub fn run(&self, manifest: Manifest) -> Result<CompilationContext> {
88        let mut ctx = CompilationContext::new(manifest);
89
90        for phase in self.all_phases() {
91            self.run_phase(phase.as_ref(), &mut ctx)?;
92        }
93
94        Ok(ctx)
95    }
96
97    /// Run a single phase with plugin hooks.
98    fn run_phase(&self, phase: &dyn Phase, ctx: &mut CompilationContext) -> Result<()> {
99        let phase_name = phase.name();
100
101        // Call before hooks
102        for plugin in &self.plugins {
103            plugin.on_before_phase(phase_name, ctx)?;
104        }
105
106        // Run the phase
107        phase.run(ctx)?;
108
109        // Call after hooks
110        for plugin in &self.plugins {
111            plugin.on_after_phase(phase_name, ctx)?;
112        }
113
114        Ok(())
115    }
116}
117
118impl Default for Pipeline {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    use super::*;

    /// Plugin that records how many times each lifecycle hook fired.
    struct CountingPlugin {
        before_count: Arc<AtomicUsize>,
        after_count: Arc<AtomicUsize>,
    }

    impl CountingPlugin {
        /// Build the plugin together with shared handles to its two counters,
        /// returned as `(plugin, before_count, after_count)`.
        fn new() -> (Self, Arc<AtomicUsize>, Arc<AtomicUsize>) {
            let before_count = Arc::new(AtomicUsize::new(0));
            let after_count = Arc::new(AtomicUsize::new(0));

            let plugin = Self {
                before_count: Arc::clone(&before_count),
                after_count: Arc::clone(&after_count),
            };

            (plugin, before_count, after_count)
        }
    }

    impl Plugin for CountingPlugin {
        fn name(&self) -> &'static str {
            "counting"
        }

        fn on_before_phase(&self, _phase: &str, _ctx: &mut CompilationContext) -> Result<()> {
            // Only the number of invocations matters; the phase is ignored.
            self.before_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        fn on_after_phase(&self, _phase: &str, _ctx: &mut CompilationContext) -> Result<()> {
            self.after_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Deserialize a manifest from TOML text, panicking on malformed input.
    fn parse_manifest(content: &str) -> Manifest {
        let parsed = toml::from_str(content);
        parsed.expect("Failed to parse test manifest")
    }

    /// Smallest manifest the tests need: a CLI name and a target language.
    fn make_test_manifest() -> Manifest {
        parse_manifest(
            r#"
            [cli]
            name = "test"
            language = "rust"
        "#,
        )
    }

    #[test]
    fn test_pipeline_runs_phases() {
        let manifest = make_test_manifest();
        let pipeline = Pipeline::new();

        let outcome = pipeline.run(manifest);
        let ctx = outcome.expect("pipeline should succeed");

        // The built-in phases must have filled in both the IR and the computed data.
        assert!(ctx.ir.is_some());
        assert!(ctx.computed.is_some());
    }

    #[test]
    fn test_pipeline_plugin_hooks() {
        let manifest = make_test_manifest();
        let (plugin, before_count, after_count) = CountingPlugin::new();

        let pipeline = Pipeline::new().plugin(plugin);
        let _ = pipeline.run(manifest).expect("pipeline should succeed");

        // Three built-in phases, so each hook fires exactly three times.
        let before_hits = before_count.load(Ordering::SeqCst);
        let after_hits = after_count.load(Ordering::SeqCst);
        assert_eq!(before_hits, 3);
        assert_eq!(after_hits, 3);
    }
}