Skip to main content

burn_central_cli/tools/
discovery.rs

1use crate::context::CliContext;
2use anyhow::Context;
3use burn_central_workspace::{
4    ProjectContext,
5    tools::{function_discovery::DiscoveryEvent, functions_registry::FunctionRegistry},
6};
7use cliclack::{MultiProgress, ProgressBar};
8use colored::Colorize;
9use std::sync::{Arc, Mutex};
10use std::time::Instant;
11
12/// Reporter for function discovery progress
13struct DiscoveryReporter {
14    multi_progress: MultiProgress,
15    main_progress: ProgressBar,
16    current_package: Mutex<Option<String>>,
17    current_message: Mutex<String>,
18    package_start_time: Mutex<Option<Instant>>,
19}
20
21impl DiscoveryReporter {
22    fn new(terminal: &crate::tools::terminal::Terminal) -> Self {
23        let multi_progress = terminal.multiprogress("Discovering project functions");
24        let main_progress = multi_progress.add(terminal.spinner());
25
26        Self {
27            multi_progress,
28            main_progress,
29            current_package: Mutex::new(None),
30            current_message: Mutex::new("Starting...".to_string()),
31            package_start_time: Mutex::new(None),
32        }
33    }
34
35    fn add_to_history(&self, message: String) {
36        self.multi_progress
37            .println(format!("  {}", message.dimmed()));
38    }
39
40    fn update_display(&self) {
41        let current_package = self.current_package.lock().unwrap();
42        let current_message = self.current_message.lock().unwrap();
43        let package_start_time = self.package_start_time.lock().unwrap();
44
45        if let (Some(package), Some(start_time)) = (current_package.as_ref(), *package_start_time) {
46            let elapsed_time = crate::tools::time::format_elapsed_time(start_time.elapsed());
47
48            self.main_progress.set_message(format!(
49                "{} {} [{}]",
50                package.green().bold(),
51                current_message.trim(),
52                elapsed_time.dimmed()
53            ));
54        }
55    }
56
57    fn start(&self) {
58        self.main_progress.start("Initializing...");
59    }
60
61    fn stop(&self, count: usize) {
62        self.flush_active_package();
63        self.main_progress.stop(format!(
64            "Discovered {} function{}.",
65            count,
66            if count == 1 { "" } else { "s" }
67        ));
68        self.multi_progress.stop();
69    }
70
71    fn error(&self, message: String) {
72        let current_package = self.current_package.lock().unwrap();
73        let current_message = self.current_message.lock().unwrap();
74
75        if let Some(package) = current_package.as_ref() {
76            self.main_progress.set_message(format!(
77                "{} {} [{}]",
78                package.red().bold(),
79                current_message.trim(),
80                "x".red()
81            ));
82        }
83        self.multi_progress.error(message);
84    }
85
86    fn flush_active_package(&self) {
87        let current_package = self.current_package.lock().unwrap();
88        let current_message = self.current_message.lock().unwrap();
89
90        if let Some(package) = current_package.as_ref() {
91            if !current_message.trim().is_empty() && current_message.trim() != "Starting..." {
92                let history_msg = format!("{} - {}", package, current_message.trim());
93                self.add_to_history(history_msg);
94            }
95        }
96    }
97
98    fn report_event(&self, event: DiscoveryEvent) {
99        let message = event.message.unwrap_or_else(|| "Analyzing...".to_string());
100        let package_name = event.package.name.clone();
101
102        let mut current_package = self.current_package.lock().unwrap();
103        let mut package_start_time = self.package_start_time.lock().unwrap();
104
105        let is_new_package = current_package.as_ref() != Some(&package_name);
106
107        if is_new_package {
108            drop(current_package);
109            drop(package_start_time);
110            self.flush_active_package();
111
112            current_package = self.current_package.lock().unwrap();
113            package_start_time = self.package_start_time.lock().unwrap();
114
115            *current_package = Some(package_name.clone());
116            *package_start_time = Some(Instant::now());
117        }
118
119        // Update current message
120        *self.current_message.lock().unwrap() = message;
121
122        drop(current_package);
123        drop(package_start_time);
124
125        self.update_display();
126    }
127}
128
129pub fn preload_functions(
130    context: &CliContext,
131    project: &ProjectContext,
132) -> anyhow::Result<FunctionRegistry> {
133    let reporter = Arc::new(DiscoveryReporter::new(context.terminal()));
134    reporter.start();
135
136    let reporter_clone = Arc::clone(&reporter);
137    let functions =
138        project
139            .load_functions(Some(Arc::new(move |event: DiscoveryEvent| {
140                reporter_clone.report_event(event);
141            })))
142            .inspect_err(|e| {
143                reporter.error("Failed to discover project functions.".to_string());
144                match e {
145                burn_central_workspace::tools::function_discovery::DiscoveryError::CargoError {
146                    package: _,
147                    status: _,
148                    diagnostics,
149                } => {
150                    context.terminal().print_err(&format!("Error: {}", e));
151
152                    let header = "=== RUSTC DIAGNOSTICS ===";
153                    let footer = "=".repeat(header.len());
154                    let message =
155                        format!("{}\n\n{}\n{}", header.yellow(), diagnostics, footer.yellow());
156                    context.terminal().print_err(&message);
157                }
158                _ => {
159                    context.terminal().print_err(&format!("Error: {}", e));
160                }
161            }
162            })
163            .context("Function discovery failed")?;
164
165    reporter.stop(functions.num_functions());
166    Ok(functions)
167}