baobao_codegen/pipeline/
runner.rs1use baobao_manifest::Manifest;
4use eyre::Result;
5
6use super::{
7 CompilationContext, Phase, Plugin,
8 phase::PhaseInfo,
9 phases::{AnalyzePhase, LowerPhase, ValidatePhase},
10};
11
12pub struct Pipeline {
28 builtin_phases: Vec<Box<dyn Phase>>,
29 user_phases: Vec<Box<dyn Phase>>,
30 plugins: Vec<Box<dyn Plugin>>,
31}
32
33impl Pipeline {
34 pub fn new() -> Self {
36 Self {
37 builtin_phases: vec![
38 Box::new(ValidatePhase::new()),
39 Box::new(LowerPhase),
40 Box::new(AnalyzePhase),
41 ],
42 user_phases: Vec::new(),
43 plugins: Vec::new(),
44 }
45 }
46
47 pub fn phase(mut self, phase: impl Phase + 'static) -> Self {
49 self.user_phases.push(Box::new(phase));
50 self
51 }
52
53 pub fn plugin(mut self, plugin: impl Plugin + 'static) -> Self {
55 self.plugins.push(Box::new(plugin));
56 self
57 }
58
59 fn all_phases(&self) -> impl Iterator<Item = &Box<dyn Phase>> {
61 self.builtin_phases.iter().chain(self.user_phases.iter())
62 }
63
64 pub fn phase_names(&self) -> Vec<&'static str> {
66 self.all_phases().map(|p| p.name()).collect()
67 }
68
69 pub fn phase_info(&self) -> Vec<PhaseInfo> {
71 self.all_phases().map(|p| p.info()).collect()
72 }
73
74 pub fn run(&self, manifest: Manifest) -> Result<CompilationContext> {
88 let mut ctx = CompilationContext::new(manifest);
89
90 for phase in self.all_phases() {
91 self.run_phase(phase.as_ref(), &mut ctx)?;
92 }
93
94 Ok(ctx)
95 }
96
97 fn run_phase(&self, phase: &dyn Phase, ctx: &mut CompilationContext) -> Result<()> {
99 let phase_name = phase.name();
100
101 for plugin in &self.plugins {
103 plugin.on_before_phase(phase_name, ctx)?;
104 }
105
106 phase.run(ctx)?;
108
109 for plugin in &self.plugins {
111 plugin.on_after_phase(phase_name, ctx)?;
112 }
113
114 Ok(())
115 }
116}
117
118impl Default for Pipeline {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicUsize, Ordering},
    };

    use super::*;

    /// Plugin that counts how many times each hook fires.
    struct CountingPlugin {
        before_count: Arc<AtomicUsize>,
        after_count: Arc<AtomicUsize>,
    }

    impl CountingPlugin {
        /// Returns the plugin plus shared handles to its before/after counters.
        fn new() -> (Self, Arc<AtomicUsize>, Arc<AtomicUsize>) {
            let before = Arc::new(AtomicUsize::new(0));
            let after = Arc::new(AtomicUsize::new(0));
            (
                Self {
                    before_count: Arc::clone(&before),
                    after_count: Arc::clone(&after),
                },
                before,
                after,
            )
        }
    }

    impl Plugin for CountingPlugin {
        fn name(&self) -> &'static str {
            "counting"
        }

        fn on_before_phase(&self, _phase: &str, _ctx: &mut CompilationContext) -> Result<()> {
            self.before_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        fn on_after_phase(&self, _phase: &str, _ctx: &mut CompilationContext) -> Result<()> {
            self.after_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Plugin that logs `"before:<phase>"` / `"after:<phase>"` for every hook
    /// call, so tests can check both the phase names and the call order.
    struct RecordingPlugin {
        events: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingPlugin {
        /// Returns the plugin plus a shared handle to its event log.
        fn new() -> (Self, Arc<Mutex<Vec<String>>>) {
            let events = Arc::new(Mutex::new(Vec::new()));
            (
                Self {
                    events: Arc::clone(&events),
                },
                events,
            )
        }

        fn record(&self, event: String) {
            self.events
                .lock()
                .expect("event log mutex poisoned")
                .push(event);
        }
    }

    impl Plugin for RecordingPlugin {
        fn name(&self) -> &'static str {
            "recording"
        }

        fn on_before_phase(&self, phase: &str, _ctx: &mut CompilationContext) -> Result<()> {
            self.record(format!("before:{phase}"));
            Ok(())
        }

        fn on_after_phase(&self, phase: &str, _ctx: &mut CompilationContext) -> Result<()> {
            self.record(format!("after:{phase}"));
            Ok(())
        }
    }

    fn parse_manifest(content: &str) -> Manifest {
        toml::from_str(content).expect("Failed to parse test manifest")
    }

    fn make_test_manifest() -> Manifest {
        parse_manifest(
            r#"
            [cli]
            name = "test"
            language = "rust"
            "#,
        )
    }

    #[test]
    fn test_pipeline_runs_phases() {
        let manifest = make_test_manifest();
        let pipeline = Pipeline::new();

        let ctx = pipeline.run(manifest).expect("pipeline should succeed");

        assert!(ctx.ir.is_some());
        assert!(ctx.computed.is_some());
    }

    #[test]
    fn test_pipeline_plugin_hooks() {
        let manifest = make_test_manifest();
        let (plugin, before_count, after_count) = CountingPlugin::new();

        let pipeline = Pipeline::new().plugin(plugin);
        // Derive the expectation from the pipeline itself so the test does not
        // break whenever a built-in phase is added or removed.
        let phase_count = pipeline.phase_names().len();
        assert!(phase_count > 0, "pipeline should have built-in phases");

        pipeline.run(manifest).expect("pipeline should succeed");

        assert_eq!(before_count.load(Ordering::SeqCst), phase_count);
        assert_eq!(after_count.load(Ordering::SeqCst), phase_count);
    }

    #[test]
    fn test_pipeline_plugin_hooks_follow_phase_order() {
        let manifest = make_test_manifest();
        let (plugin, events) = RecordingPlugin::new();

        let pipeline = Pipeline::new().plugin(plugin);
        let expected: Vec<String> = pipeline
            .phase_names()
            .into_iter()
            .flat_map(|name| [format!("before:{name}"), format!("after:{name}")])
            .collect();

        pipeline.run(manifest).expect("pipeline should succeed");

        let recorded = events.lock().expect("event log mutex poisoned");
        assert_eq!(*recorded, expected);
    }
}