Skip to main content

database_bootstrap/
runner.rs

//! Bootstrap runner that executes steps in dependency order.
2
use std::collections::{HashMap, HashSet, VecDeque};

use crate::step::{BootstrapStep, StepOutput};
use crate::traits::{DatabaseConnection, TransactionalConnection};
use crate::{BootstrapError, Requirement, VerifyResult};
7
/// Settings that control a single [`BootstrapRunner::run`] invocation.
///
/// The `Default` value has both flags cleared: steps are applied normally.
#[derive(Clone, Debug, Default)]
pub struct RunOptions {
    /// When set, steps are verified and their status is reported, but no step's
    /// `run` is executed, so no changes are made.
    pub dry_run: bool,
    /// When set, requests that the run continue past a required step that fails
    /// verification instead of returning an error.
    pub continue_on_error: bool,
}
16
17/// Summary of bootstrap execution.
18#[derive(Debug)]
19pub struct BootstrapSummary {
20    /// Results for each step.
21    pub results: Vec<(&'static str, VerifyResult)>,
22    /// Outputs (secrets, etc.) from steps.
23    pub outputs: Vec<StepOutput>,
24}
25
26impl BootstrapSummary {
27    /// Returns true if any step had a partial or missing result for required steps.
28    pub fn has_issues(&self) -> bool {
29        self.results
30            .iter()
31            .any(|(_, result)| *result != VerifyResult::Present)
32    }
33}
34
35/// Executes bootstrap steps in dependency order.
36///
37/// The runner uses topological sort to determine execution order,
38/// ensuring that dependencies are satisfied before each step runs.
39/// Each step runs in its own transaction with the bootstrap RLS bypass.
40///
41/// The [`Default`] implementation pre-registers all builtin bootstrap steps:
42/// - `system_tenant`: Creates system and default tenants
43/// - `super_admin`: Creates the super-admin user
44/// - `global_roles`: Creates global default roles
45/// - `agent_service_token`: Creates the agent service token
46///
47/// Use [`BootstrapRunner::new`] for an empty runner if you need custom step configuration.
48pub struct BootstrapRunner<T> {
49    steps: HashMap<&'static str, Box<dyn BootstrapStep<T>>>,
50    execution_order: Vec<&'static str>,
51}
52
53impl<T> BootstrapRunner<T>
54where
55    T: DatabaseConnection,
56    T::Error: Into<BootstrapError<T::Error>>,
57    BootstrapError<T::Error>: From<T::Error>
58{
59    /// Creates an empty runner.
60    pub fn new() -> Self {
61        Self {
62            steps: HashMap::new(),
63            execution_order: Vec::new(),
64        }
65    }
66
67    /// Adds a bootstrap step to the runner.
68    ///
69    /// Steps can be added in any order. The runner will execute them
70    /// in dependency order based on their declared dependencies.
71    pub fn add_step<S: BootstrapStep<T> + 'static>(&mut self, step: S) {
72        let name = step.name();
73
74        if self.steps.contains_key(name) {
75            tracing::warn!("duplicate bootstrap step: {}", name);
76            return;
77        }
78
79        self.steps.insert(name, Box::new(step));
80        self.execution_order.push(name);
81    }
82
83    /// Ensures all dependencies are registered, creating missing ones from factories.
84    fn ensure_dependencies(&mut self) -> Result<(), BootstrapError<T::Error>> {
85        let mut to_add: Vec<Box<dyn BootstrapStep<T>>> = Vec::new();
86
87        // Collect all steps and their dependencies
88        let existing: HashSet<&'static str> = self.steps.keys().copied().collect();
89
90        for step in self.steps.values() {
91            for dep_factory in step.dependencies() {
92                let dep_name = dep_factory.step_name();
93                if !existing.contains(dep_name) {
94                    tracing::info!(
95                        step = dep_name,
96                        required_by = step.name(),
97                        "auto-creating missing dependency step"
98                    );
99                    to_add.push(dep_factory.create());
100                }
101            }
102        }
103
104        // Add the new steps
105        for step in to_add {
106            let name = step.name();
107            if !self.steps.contains_key(name) {
108                self.steps.insert(name, step);
109                self.execution_order.push(name);
110            }
111        }
112
113        Ok(())
114    }
115
116    /// Builds and returns the execution order based on dependencies.
117    fn build_execution_order(&self) -> Result<Vec<&'static str>, BootstrapError<T::Error>> {
118        let steps: Vec<&'static str> = self.execution_order.to_vec();
119        let n = steps.len();
120        let mut in_degree: HashMap<&'static str, usize> = HashMap::new();
121        let mut dependents: HashMap<&'static str, Vec<&'static str>> = HashMap::new();
122
123        for name in &steps {
124            in_degree.insert(*name, 0);
125            dependents.insert(*name, Vec::new());
126        }
127
128        for name in &steps {
129            let step = self.steps.get(name).expect("step must exist");
130            for dep_factory in step.dependencies() {
131                let dep_name = dep_factory.step_name();
132                if !in_degree.contains_key(dep_name) {
133                    return Err(BootstrapError::DependencyNotSatisfied {
134                        step: name,
135                        dependency: dep_name,
136                    });
137                }
138                in_degree.entry(*name).and_modify(|d| *d += 1);
139                dependents.entry(dep_name).or_default().push(*name);
140            }
141        }
142
143        let mut queue: Vec<&'static str> = in_degree
144            .iter()
145            .filter(|(_, d)| **d == 0)
146            .map(|(&name, _)| name)
147            .collect();
148
149        let mut result = Vec::new();
150        while let Some(name) = queue.pop() {
151            result.push(name);
152            if let Some(deps) = dependents.get(name) {
153                for dep in deps {
154                    if let Some(d) = in_degree.get_mut(dep) {
155                        *d -= 1;
156                        if *d == 0 {
157                            queue.push(*dep);
158                        }
159                    }
160                }
161            }
162        }
163
164        if result.len() != n {
165            for name in &steps {
166                if !result.contains(name) {
167                    return Err(BootstrapError::CircularDependency { step: name });
168                }
169            }
170        }
171
172        Ok(result)
173    }
174
175    /// Runs all bootstrap steps in dependency order.
176    pub async fn run(
177        &mut self,
178        db: &T,
179        options: &RunOptions,
180    ) -> Result<BootstrapSummary, BootstrapError<T::Error>> {
181        // Ensure all dependencies are registered
182        self.ensure_dependencies()?;
183
184        let order = self.build_execution_order()?;
185        let mut results = Vec::new();
186        let mut completed: HashSet<&'static str> = HashSet::new();
187        let mut outputs: Vec<StepOutput> = Vec::new();
188
189        tracing::info!("╔═══════════════════════════════════════════════════════════╗");
190        tracing::info!("║                    BOOTSTRAP STARTING                     ║");
191        tracing::info!("╚═══════════════════════════════════════════════════════════╝");
192
193        for name in order {
194            let step = self.steps.get(name).expect("step must exist");
195
196            // Check dependencies
197            for dep_factory in step.dependencies() {
198                let dep_name = dep_factory.step_name();
199                if !completed.contains(dep_name) {
200                    return Err(BootstrapError::DependencyNotSatisfied {
201                        step: name,
202                        dependency: dep_name,
203                    });
204                }
205            }
206
207            tracing::info!("┌───────────────────────────────────────────────────────────┐");
208            tracing::info!("│ Step: {:<51} │", name);
209            tracing::info!("├───────────────────────────────────────────────────────────┤");
210
211            let txn = db.begin().await?;
212
213            let verify_details = step.verify(&txn).await?;
214            let status = match verify_details.result {
215                VerifyResult::Present => "PRESENT ✓",
216                VerifyResult::Missing => "MISSING",
217                VerifyResult::Partial => "PARTIAL ⚠",
218            };
219            tracing::info!("│ Status: {:<49} │", status);
220
221            match verify_details.result {
222                VerifyResult::Present => {
223                    tracing::info!("│ Action: {:<49} │", "Skipping (already exists)");
224                    results.push((name, verify_details.result));
225                    completed.insert(name);
226                    txn.commit().await?;
227                }
228                VerifyResult::Missing => {
229                    if step.skip_if_missing() {
230                        tracing::warn!("│ Action: {:<49} │", "Skipping (optional, not configured)");
231                        results.push((name, VerifyResult::Missing));
232                        completed.insert(name);
233                        txn.commit().await?;
234                    } else if options.dry_run {
235                        tracing::info!("│ Action: {:<49} │", "Would run (dry run)");
236                        results.push((name, VerifyResult::Missing));
237                        txn.commit().await?;
238                    } else {
239                        tracing::info!("│ Action: {:<49} │", "Running...");
240                        let run_result = step.run(&txn, &verify_details).await;
241                        
242                        if let Err(e) = run_result {
243                            // Rollback on error before returning
244                            let _ = txn.rollback().await;
245                            return Err(e);
246                        }
247                        
248                        if let Some(output) = run_result.unwrap() {
249                            outputs.push(output);
250                        }
251                        results.push((name, VerifyResult::Missing));
252
253                        let after_result = step.verify(&txn).await?;
254                        match after_result.result {
255                            VerifyResult::Present => {
256                                tracing::info!("│ Result: {:<49} │", "Success ✓");
257                                completed.insert(name);
258                                txn.commit().await?;
259                            }
260                            VerifyResult::Missing | VerifyResult::Partial => {
261                                if step.requirement() == Requirement::Required {
262                                    tracing::info!("│ Result: {:<49} │", "FAILED ✗");
263                                    // Rollback before returning error
264                                    let _ = txn.rollback().await;
265                                    return Err(BootstrapError::VerificationFailed {
266                                        step: name,
267                                        expected: "Present".to_string(),
268                                        actual: format!("{:?}", after_result.result),
269                                    });
270                                }
271                                tracing::info!("│ Result: {:<49} │", "Incomplete (optional)");
272                                completed.insert(name);
273                                txn.commit().await?;
274                            }
275                        }
276                    }
277                }
278                VerifyResult::Partial => {
279                    tracing::info!("│ Action: {:<49} │", "Fixing partial state...");
280                    if options.dry_run {
281                        tracing::info!("│ Result: {:<49} │", "Would fix (dry run)");
282                        results.push((name, verify_details.result));
283                        txn.commit().await?;
284                    } else {
285                        let run_result = step.run(&txn, &verify_details).await;
286                        
287                        if let Err(e) = run_result {
288                            // Rollback on error before returning
289                            let _ = txn.rollback().await;
290                            return Err(e);
291                        }
292                        
293                        if let Some(output) = run_result.unwrap() {
294                            outputs.push(output);
295                        }
296                        results.push((name, verify_details.result));
297
298                        let after_result = step.verify(&txn).await?;
299                        match after_result.result {
300                            VerifyResult::Present => {
301                                tracing::info!("│ Result: {:<49} │", "Fixed ✓");
302                                completed.insert(name);
303                                txn.commit().await?;
304                            }
305                            VerifyResult::Missing | VerifyResult::Partial => {
306                                if step.requirement() == Requirement::Required {
307                                    tracing::info!("│ Result: {:<49} │", "FAILED ✗");
308                                    // Rollback before returning error
309                                    let _ = txn.rollback().await;
310                                    return Err(BootstrapError::VerificationFailed {
311                                        step: name,
312                                        expected: "Present".to_string(),
313                                        actual: format!("{:?}", after_result.result),
314                                    });
315                                }
316                                tracing::info!("│ Result: {:<49} │", "Incomplete (optional)");
317                                completed.insert(name);
318                                txn.commit().await?;
319                            }
320                        }
321                    }
322                }
323            }
324            tracing::info!("└───────────────────────────────────────────────────────────┘");
325        }
326
327        // Print summary
328        if !outputs.is_empty() {
329            self.print_summary(&outputs);
330        }
331
332        Ok(BootstrapSummary { results, outputs })
333    }
334
335    fn print_summary(&self, outputs: &[StepOutput]) {
336        const MIN_WIDTH: usize = 40;
337        const HEADER: &str = "BOOTSTRAP COMPLETE";
338        const SENSITIVE_MSG: &str = "⚠ SENSITIVE - Save this information now!";
339
340        // Calculate required width based on content (using char count for unicode support)
341        let mut max_content_len = HEADER.chars().count();
342        let mut has_sensitive = false;
343
344        for output in outputs {
345            max_content_len = max_content_len.max(output.title.chars().count() + 2);
346            for (key, value) in &output.fields {
347                max_content_len =
348                    max_content_len.max(key.chars().count() + value.chars().count() + 7);
349            }
350            if output.is_sensitive {
351                max_content_len = max_content_len.max(SENSITIVE_MSG.chars().count() + 4);
352                has_sensitive = true;
353            }
354        }
355
356        let width = MIN_WIDTH.max(max_content_len + 4);
357
358        // Helper to print a content line (width = total chars inside the box borders)
359        let print_line = |content: &str| {
360            let padding = width.saturating_sub(content.chars().count() + 1);
361            tracing::info!("║ {}{}║", content, " ".repeat(padding));
362        };
363
364        // Print header box
365        tracing::info!("");
366        tracing::info!("╔{}╗", "═".repeat(width));
367        let header_pad = width.saturating_sub(HEADER.chars().count()) / 2;
368        tracing::info!(
369            "║{}{}{}║",
370            " ".repeat(header_pad),
371            HEADER,
372            " ".repeat(width - HEADER.chars().count() - header_pad)
373        );
374        tracing::info!("╠{}╣", "═".repeat(width));
375
376        // Print each output section
377        for output in outputs {
378            tracing::info!("║{}║", " ".repeat(width));
379            print_line(&format!("  {}", output.title));
380
381            for (key, value) in &output.fields {
382                print_line(&format!("    {}: {}", key, value));
383            }
384
385            if output.is_sensitive {
386                print_line(&format!("  {}  ", SENSITIVE_MSG));
387            }
388        }
389
390        tracing::info!("╚{}╝", "═".repeat(width));
391
392        if has_sensitive {
393            tracing::warn!(
394                "⚠️  Sensitive data was generated. Save the output above before it scrolls away!"
395            );
396        }
397    }
398}