burn_central_cli/tools/
discovery.rs1use crate::context::CliContext;
2use anyhow::Context;
3use burn_central_workspace::{
4 ProjectContext,
5 tools::{function_discovery::DiscoveryEvent, functions_registry::FunctionRegistry},
6};
7use cliclack::{MultiProgress, ProgressBar};
8use colored::Colorize;
9use std::sync::{Arc, Mutex};
10use std::time::Instant;
11
12struct DiscoveryReporter {
14 multi_progress: MultiProgress,
15 main_progress: ProgressBar,
16 current_package: Mutex<Option<String>>,
17 current_message: Mutex<String>,
18 package_start_time: Mutex<Option<Instant>>,
19}
20
21impl DiscoveryReporter {
22 fn new(terminal: &crate::tools::terminal::Terminal) -> Self {
23 let multi_progress = terminal.multiprogress("Discovering project functions");
24 let main_progress = multi_progress.add(terminal.spinner());
25
26 Self {
27 multi_progress,
28 main_progress,
29 current_package: Mutex::new(None),
30 current_message: Mutex::new("Starting...".to_string()),
31 package_start_time: Mutex::new(None),
32 }
33 }
34
35 fn add_to_history(&self, message: String) {
36 self.multi_progress
37 .println(format!(" {}", message.dimmed()));
38 }
39
40 fn update_display(&self) {
41 let current_package = self.current_package.lock().unwrap();
42 let current_message = self.current_message.lock().unwrap();
43 let package_start_time = self.package_start_time.lock().unwrap();
44
45 if let (Some(package), Some(start_time)) = (current_package.as_ref(), *package_start_time) {
46 let elapsed_time = crate::tools::time::format_elapsed_time(start_time.elapsed());
47
48 self.main_progress.set_message(format!(
49 "{} {} [{}]",
50 package.green().bold(),
51 current_message.trim(),
52 elapsed_time.dimmed()
53 ));
54 }
55 }
56
57 fn start(&self) {
58 self.main_progress.start("Initializing...");
59 }
60
61 fn stop(&self, count: usize) {
62 self.flush_active_package();
63 self.main_progress.stop(format!(
64 "Discovered {} function{}.",
65 count,
66 if count == 1 { "" } else { "s" }
67 ));
68 self.multi_progress.stop();
69 }
70
71 fn error(&self, message: String) {
72 let current_package = self.current_package.lock().unwrap();
73 let current_message = self.current_message.lock().unwrap();
74
75 if let Some(package) = current_package.as_ref() {
76 self.main_progress.set_message(format!(
77 "{} {} [{}]",
78 package.red().bold(),
79 current_message.trim(),
80 "x".red()
81 ));
82 }
83 self.multi_progress.error(message);
84 }
85
86 fn flush_active_package(&self) {
87 let current_package = self.current_package.lock().unwrap();
88 let current_message = self.current_message.lock().unwrap();
89
90 if let Some(package) = current_package.as_ref() {
91 if !current_message.trim().is_empty() && current_message.trim() != "Starting..." {
92 let history_msg = format!("{} - {}", package, current_message.trim());
93 self.add_to_history(history_msg);
94 }
95 }
96 }
97
98 fn report_event(&self, event: DiscoveryEvent) {
99 let message = event.message.unwrap_or_else(|| "Analyzing...".to_string());
100 let package_name = event.package.name.clone();
101
102 let mut current_package = self.current_package.lock().unwrap();
103 let mut package_start_time = self.package_start_time.lock().unwrap();
104
105 let is_new_package = current_package.as_ref() != Some(&package_name);
106
107 if is_new_package {
108 drop(current_package);
109 drop(package_start_time);
110 self.flush_active_package();
111
112 current_package = self.current_package.lock().unwrap();
113 package_start_time = self.package_start_time.lock().unwrap();
114
115 *current_package = Some(package_name.clone());
116 *package_start_time = Some(Instant::now());
117 }
118
119 *self.current_message.lock().unwrap() = message;
121
122 drop(current_package);
123 drop(package_start_time);
124
125 self.update_display();
126 }
127}
128
129pub fn preload_functions(
130 context: &CliContext,
131 project: &ProjectContext,
132) -> anyhow::Result<FunctionRegistry> {
133 let reporter = Arc::new(DiscoveryReporter::new(context.terminal()));
134 reporter.start();
135
136 let reporter_clone = Arc::clone(&reporter);
137 let functions =
138 project
139 .load_functions(Some(Arc::new(move |event: DiscoveryEvent| {
140 reporter_clone.report_event(event);
141 })))
142 .inspect_err(|e| {
143 reporter.error("Failed to discover project functions.".to_string());
144 match e {
145 burn_central_workspace::tools::function_discovery::DiscoveryError::CargoError {
146 package: _,
147 status: _,
148 diagnostics,
149 } => {
150 context.terminal().print_err(&format!("Error: {}", e));
151
152 let header = "=== RUSTC DIAGNOSTICS ===";
153 let footer = "=".repeat(header.len());
154 let message =
155 format!("{}\n\n{}\n{}", header.yellow(), diagnostics, footer.yellow());
156 context.terminal().print_err(&message);
157 }
158 _ => {
159 context.terminal().print_err(&format!("Error: {}", e));
160 }
161 }
162 })
163 .context("Function discovery failed")?;
164
165 reporter.stop(functions.num_functions());
166 Ok(functions)
167}