burn_central_workspace/tools/
function_discovery.rs1use crate::execution::cancellable::{CancellableProcess, CancellableResult, CancellationToken};
6use quote::ToTokens;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::io::{BufRead, BufReader};
10use std::path::{Path, PathBuf};
11use std::process::Stdio;
12use std::sync::Arc;
13
/// Prefix of an encoded function marker (`BCFN1|…|END`) in expanded source.
const MAGIC: &str = "BCFN1|";
/// Suffix terminating an encoded function marker.
const END: &str = "|END";
/// Separator between the payload fields inside a marker.
const SEP: char = '|';
17
/// Metadata describing one discovered function, decoded from a `BCFN1` marker.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionMetadata {
    /// Module path where the function is defined (e.g. `my::module`).
    pub mod_path: String,
    /// Name of the function itself.
    pub fn_name: String,
    /// Name of the associated builder function (e.g. `__train_fn_builder`).
    pub builder_fn_name: String,
    /// Routine name recorded in the marker.
    pub routine_name: String,
    /// Procedure type recorded in the marker (e.g. `training`).
    pub proc_type: String,
    /// Either raw UTF-8 source text or `syn_serde` JSON bytes for the
    /// function; left empty when no matching AST constant is found.
    pub token_stream: Vec<u8>,
}
27
28impl FunctionMetadata {
29 pub fn get_function_code(&self) -> String {
30 if self.token_stream.is_empty() {
31 format!(
33 "fn {}() {{\n // Function implementation not available\n}}",
34 self.fn_name
35 )
36 } else {
37 if let Ok(source_code) = std::str::from_utf8(&self.token_stream) {
39 if !source_code.trim_start().starts_with('{') {
41 return source_code.to_string();
42 }
43 }
44
45 match syn_serde::json::from_slice::<syn::ItemFn>(&self.token_stream) {
47 Ok(itemfn) => match syn::parse2(itemfn.into_token_stream()) {
48 Ok(syn_tree) => prettyplease::unparse(&syn_tree),
49 Err(_) => format!(
50 "fn {}() {{\n // Failed to parse token stream\n}}",
51 self.fn_name
52 ),
53 },
54 Err(_) => format!(
55 "fn {}() {{\n // Failed to deserialize token stream\n}}",
56 self.fn_name
57 ),
58 }
59 }
60 }
61}
62
/// Errors produced by the function-discovery pipeline.
#[derive(Debug, thiserror::Error)]
pub enum DiscoveryError {
    /// The `cargo` child process could not be started.
    #[error("Failed to spawn cargo rustc process: {0}")]
    SpawnFailed(String),
    /// `cargo rustc` ran but exited unsuccessfully.
    #[error("Cargo rustc failed for package '{package}' (status: {status})")]
    CargoError {
        /// Name of the package whose expansion failed.
        package: String,
        /// Process exit code (-1 when the process terminated without one).
        status: i32,
        /// Collected compiler error output, newline-joined.
        diagnostics: String,
    },
    /// The cancellation token fired before expansion completed.
    #[error("Function discovery was cancelled")]
    Cancelled,
}
76
/// Identifies a cargo package, optionally pinned to a specific version.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PkgId {
    /// Package name as passed to `cargo -p`.
    pub name: String,
    /// Optional version; when present the package spec becomes `name@version`.
    pub version: Option<String>,
}
82
/// Discovers annotated functions by macro-expanding packages with cargo.
#[derive(Debug, Clone)]
pub struct FunctionDiscovery {
    /// Directory the `cargo` commands are run from.
    project_root: PathBuf,
}
87
88pub struct DiscoveryConfig {
89 pub packages: Vec<PkgId>,
90 pub target_dir: Option<PathBuf>,
91}
92
/// Functions discovered during a run, grouped by package.
#[derive(Debug)]
pub struct DiscoveryResult {
    /// Map from package to the function metadata found in its expansion.
    pub functions: HashMap<PkgId, Vec<FunctionMetadata>>,
}
97
/// Progress event emitted while discovering functions in a package.
pub struct DiscoveryEvent {
    /// Package the event refers to.
    pub package: PkgId,
    /// Optional human-readable progress message.
    pub message: Option<String>,
}

/// Trait object used to surface `DiscoveryEvent`s to interested callers.
type DiscoveryEventReporter = dyn crate::event::Reporter<DiscoveryEvent>;
104
105impl FunctionDiscovery {
106 pub fn new(project_root: impl Into<PathBuf>) -> Self {
107 Self {
108 project_root: project_root.into(),
109 }
110 }
111
112 pub fn discover_functions(
114 &self,
115 discovery_config: &DiscoveryConfig,
116 cancellation_token: &CancellationToken,
117 event_reporter: Option<Arc<DiscoveryEventReporter>>,
118 ) -> Result<DiscoveryResult, DiscoveryError> {
119 let mut package_functions = HashMap::new();
120 for package in &discovery_config.packages {
121 let expanded = self.expand_with_cargo(
122 package,
123 discovery_config.target_dir.as_deref(),
124 cancellation_token,
125 event_reporter.clone(),
126 )?;
127
128 let functions = parse_expanded_output(&expanded);
129 package_functions
130 .entry(package.clone())
131 .or_insert_with(Vec::new)
132 .extend(functions);
133
134 if let Some(reporter) = event_reporter.as_ref() {
135 reporter.report_event(DiscoveryEvent {
136 package: package.clone(),
137 message: Some(format!(
138 "Discovered {} functions",
139 package_functions.get(package).map_or(0, |fns| fns.len()),
140 )),
141 });
142 }
143 }
144
145 let result = DiscoveryResult {
146 functions: package_functions,
147 };
148 Ok(result)
149 }
150
151 fn expand_with_cargo(
152 &self,
153 package: &PkgId,
154 target_dir: Option<&Path>,
155 cancellation_token: &CancellationToken,
156 event_reporter: Option<Arc<DiscoveryEventReporter>>,
157 ) -> Result<String, DiscoveryError> {
158 let mut cmd = super::cargo::command();
159 cmd.current_dir(&self.project_root)
160 .arg("rustc")
161 .arg("--lib")
162 .arg("--profile=check")
163 .arg("--message-format=json")
164 .arg("--quiet");
165
166 let spec = if let Some(ref version) = package.version {
167 format!("{}@{}", package.name, version)
168 } else {
169 package.name.to_string()
170 };
171 cmd.arg("-p").arg(spec);
172
173 if let Some(target_dir) = target_dir {
174 cmd.arg("--target-dir").arg(target_dir);
175 }
176
177 cmd.arg("--");
178 cmd.arg("-Zunpretty=expanded");
179 cmd.env("RUSTC_BOOTSTRAP", "1");
180 cmd.env("RUST_LOG", "error");
181
182 let mut child = cmd
183 .stdout(Stdio::piped())
184 .stderr(Stdio::piped())
185 .stdin(Stdio::null())
186 .spawn()
187 .map_err(|e| DiscoveryError::SpawnFailed(e.to_string()))?;
188
189 let (output_tx, output_rx) = std::sync::mpsc::channel();
190 let (errors_tx, errors_rx) = std::sync::mpsc::channel();
191 if let Some(stdout) = child.stdout.take() {
193 let reader = BufReader::new(stdout);
194 let package = package.clone();
195 let event_reporter = event_reporter.clone();
196 let errors_tx = errors_tx.clone();
197 std::thread::spawn(move || {
198 let stream = cargo_metadata::Message::parse_stream(reader);
199 for message in stream.flatten() {
200 match message {
201 cargo_metadata::Message::CompilerMessage(msg) => {
202 let rendered = msg.message.rendered.unwrap_or_default();
203 if matches!(
204 msg.message.level,
205 cargo_metadata::diagnostic::DiagnosticLevel::Error
206 ) {
207 let _ = errors_tx.send(rendered);
208 }
209 }
210 cargo_metadata::Message::CompilerArtifact(_artifact) => {
211 if let Some(ref reporter) = event_reporter {
212 reporter.report_event(DiscoveryEvent {
213 package: package.clone(),
214 message: Some(format!(
215 "Compiled artifact: {}",
216 _artifact.target.name
217 )),
218 });
219 }
220 }
221 cargo_metadata::Message::TextLine(line) => {
222 let _ = output_tx.send(line.clone());
223 }
224 _ => {}
225 }
226 }
227 });
228 }
229
230 if let Some(stderr) = child.stderr.take() {
231 let reader = BufReader::new(stderr);
232 let errors_tx = errors_tx.clone();
233 std::thread::spawn(move || {
234 for line in reader.lines().map_while(Result::ok) {
235 let _ = errors_tx.send(line);
236 }
237 });
238 }
239
240 let cancellable = CancellableProcess::new(child, cancellation_token.clone());
241 let result = cancellable.wait();
242
243 match result {
244 CancellableResult::Completed(status) => {
245 if !status.success() {
246 return Err(DiscoveryError::CargoError {
247 package: package.name.clone(),
248 status: status.code().unwrap_or(-1),
249 diagnostics: errors_rx.try_iter().collect::<Vec<_>>().join("\n"),
250 });
251 }
252 let expanded = output_rx.try_iter().collect::<Vec<String>>().join("\n");
253 Ok(expanded)
254 }
255 CancellableResult::Cancelled => Err(DiscoveryError::Cancelled),
256 }
257 }
258}
259
260fn parse_expanded_output(expanded: &str) -> Vec<FunctionMetadata> {
261 let bytes = expanded.as_bytes();
262 let mut i = 0usize;
263 let mut out = Vec::new();
264
265 while let Some(m) = find(bytes, MAGIC.as_bytes(), i) {
266 let start_payload = m + MAGIC.len();
267 if let Some(end) = find(bytes, END.as_bytes(), start_payload) {
268 if let Ok(slice) = std::str::from_utf8(&bytes[m..end + END.len()]) {
269 if let Some(meta) = parse_bcfn_marker(slice) {
270 out.push(meta);
271 }
272 }
273 i = end + END.len();
274 } else {
275 break;
277 }
278 }
279
280 for meta in &mut out {
281 let result = extract_ast_token_stream(expanded, &meta.fn_name);
282 if let Some(token_stream) = result {
283 meta.token_stream = token_stream;
284 }
285 }
286
287 out
288}
289
290fn parse_bcfn_marker(marker: &str) -> Option<FunctionMetadata> {
292 if !marker.starts_with(MAGIC) || !marker.ends_with(END) {
293 return None;
294 }
295 let body = &marker[MAGIC.len()..marker.len() - END.len()];
296 let mut it = body.split(SEP);
297
298 let mod_path = it.next()?.to_string();
299 let fn_name = it.next()?.to_string();
300 let builder_fn_name = it.next()?.to_string();
301 let routine_name = it.next()?.to_string();
302 let proc_type = it.next()?.to_string();
303
304 if it.next().is_some() {
306 return None;
307 }
308
309 Some(FunctionMetadata {
310 mod_path,
311 fn_name,
312 builder_fn_name,
313 routine_name,
314 proc_type,
315 token_stream: Vec::new(),
316 })
317}
318
/// Returns the index of the first occurrence of `needle` in `hay` at or after
/// `from`, or `None` when no full-length match fits.
///
/// An empty needle matches immediately at `from` (provided `from` is within
/// bounds), mirroring the usual substring-search convention.
fn find(hay: &[u8], needle: &[u8], from: usize) -> Option<usize> {
    // Last index where a full-length match could still start; `None` when the
    // needle is longer than the haystack.
    let last_start = hay.len().checked_sub(needle.len())?;
    (from..=last_start).find(|&start| &hay[start..start + needle.len()] == needle)
}
329
/// Decodes the escape sequences of a Rust byte-string literal body
/// (`\"`, `\\`, `\n`, `\r`, `\t`) back into raw bytes.
///
/// Unknown escapes are passed through verbatim (backslash plus the following
/// character), as is a trailing lone backslash.
///
/// Perf: uses `char::encode_utf8` with a stack buffer instead of the original
/// per-character `to_string()`, avoiding a heap allocation for every
/// non-escape character, and preallocates the output (the decoded form is
/// never longer than the escaped input).
fn unescape_byte_string(escaped: &str) -> Vec<u8> {
    let mut result = Vec::with_capacity(escaped.len());
    let mut chars = escaped.chars();
    let mut buf = [0u8; 4];

    while let Some(ch) = chars.next() {
        if ch == '\\' {
            match chars.next() {
                Some('"') => result.push(b'"'),
                Some('\\') => result.push(b'\\'),
                Some('n') => result.push(b'\n'),
                Some('r') => result.push(b'\r'),
                Some('t') => result.push(b'\t'),
                // Unknown escape: keep it untouched.
                Some(other) => {
                    result.push(b'\\');
                    result.extend_from_slice(other.encode_utf8(&mut buf).as_bytes());
                }
                // Trailing backslash with nothing after it.
                None => result.push(b'\\'),
            }
        } else {
            result.extend_from_slice(ch.encode_utf8(&mut buf).as_bytes());
        }
    }

    result
}
364
365fn extract_ast_token_stream(expanded: &str, fn_name: &str) -> Option<Vec<u8>> {
368 let ast_const_name = format!("_BURN_FUNCTION_AST_{}", fn_name.to_uppercase());
370
371 let const_pattern = format!("const {}: &[u8]", ast_const_name);
373 let const_pos = expanded.find(&const_pattern)?;
374
375 let search_start = const_pos + const_pattern.len();
377 let b_quote_pattern = "b\"";
378 let b_quote_pos = expanded[search_start..].find(b_quote_pattern)?;
379 let content_start = search_start + b_quote_pos + b_quote_pattern.len();
380
381 let chars: Vec<char> = expanded[content_start..].chars().collect();
383 let mut pos = 0;
384
385 while pos < chars.len() {
386 if chars[pos] == '\\' && pos + 1 < chars.len() {
387 pos += 2;
389 } else if chars[pos] == '"' {
390 if pos + 1 < chars.len() && chars[pos + 1] == ';' {
393 let escaped_content: String = chars[..pos].iter().collect();
394 return Some(unescape_byte_string(&escaped_content));
395 } else {
396 pos += 1;
397 }
398 } else {
399 pos += 1;
400 }
401 }
402
403 None
404}
405
406#[cfg(test)]
407mod tests {
408 use super::*;
409
    // Two markers embedded among unrelated expanded items are both found.
    #[test]
    fn parses_markers() {
        let expanded = r#"
        /* noise */ const X:&str="hello";
        const BURN_CENTRAL_FUNCTION_TRAIN:&str="BCFN1|my::module|train_fn|__train_fn_builder|train|training|END";
        const BURN_CENTRAL_FUNCTION_EVAL:&str=
        "BCFN1|my::module|eval_fn|__eval_fn_builder|evaluate|training|END";
        "#;

        let v = parse_expanded_output(expanded);
        assert_eq!(v.len(), 2);
        assert_eq!(v[0].mod_path, "my::module");
        assert_eq!(v[0].fn_name, "train_fn");
        assert_eq!(v[1].fn_name, "eval_fn");
        assert_eq!(v[1].routine_name, "evaluate");
    }
426
    // A marker with only four payload fields (five are required) is rejected.
    #[test]
    fn rejects_bad_marker() {
        let bad = "BCFN1|a|b|c|d|END";
        assert!(parse_bcfn_marker(bad).is_none());
    }
432
    // Module paths containing `::` separators survive parsing intact.
    #[test]
    fn accepts_complex_mod_path() {
        let ok = "BCFN1|a::b::c::d|f|__builder|r|training|END";
        let m = parse_bcfn_marker(ok).unwrap();
        assert_eq!(m.mod_path, "a::b::c::d");
    }
439
    // Escaped quotes, backslashes, and newlines decode to their raw bytes.
    #[test]
    fn unescapes_byte_string() {
        let escaped = r#"hello \"world\" with \\backslash\\ and \n newline"#;
        let result = unescape_byte_string(escaped);
        let expected = b"hello \"world\" with \\backslash\\ and \n newline";
        assert_eq!(result, expected);
    }
447
    // The `_BURN_FUNCTION_AST_*` byte-string constant is located and unescaped.
    #[test]
    fn extracts_ast_token_stream() {
        let expanded = r#"
        const _: () = {
            const BURN_CENTRAL_FUNCTION_TEST: &str = "BCFN1|my::module|test|__test_builder|test|training|END";
            const _BURN_FUNCTION_AST_TEST: &[u8] = b"{\"vis\":\"pub\",\"ident\":\"test\"}";
        };
        "#;

        let token_stream = extract_ast_token_stream(expanded, "test").unwrap();
        let expected = b"{\"vis\":\"pub\",\"ident\":\"test\"}";
        assert_eq!(token_stream, expected);
    }
461
    // Each marker gets matched with its own AST payload by function name.
    #[test]
    fn parses_markers_with_ast() {
        let expanded = r#"
        const _: () = {
            const BURN_CENTRAL_FUNCTION_TRAIN:&str="BCFN1|my::module|train_fn|__train_fn_builder|train|training|END";
            const _BURN_FUNCTION_AST_TRAIN_FN: &[u8] = b"{\"vis\":\"pub\",\"ident\":\"train_fn\"}";
        };
        const _: () = {
            const BURN_CENTRAL_FUNCTION_EVAL:&str="BCFN1|my::module|eval_fn|__eval_fn_builder|evaluate|training|END";
            const _BURN_FUNCTION_AST_EVAL_FN: &[u8] = b"{\"vis\":\"pub\",\"ident\":\"eval_fn\"}";
        };
        "#;

        let v = parse_expanded_output(expanded);
        assert_eq!(v.len(), 2);

        assert_eq!(v[0].mod_path, "my::module");
        assert_eq!(v[0].fn_name, "train_fn");

        // Both functions should have their token streams attached.
        assert!(!v[0].token_stream.is_empty());
        assert!(!v[1].token_stream.is_empty());

        let expected_train = b"{\"vis\":\"pub\",\"ident\":\"train_fn\"}";
        let expected_eval = b"{\"vis\":\"pub\",\"ident\":\"eval_fn\"}";
        assert_eq!(v[0].token_stream, expected_train);
        assert_eq!(v[1].token_stream, expected_eval);
    }
492
    // A marker without a matching AST constant still parses; the token
    // stream simply stays empty.
    #[test]
    fn handles_missing_ast_gracefully() {
        let expanded = r#"
        const BURN_CENTRAL_FUNCTION_TRAIN:&str="BCFN1|my::module|train_fn|__train_fn_builder|train|training|END";
        "#;

        let v = parse_expanded_output(expanded);
        assert_eq!(v.len(), 1);
        assert_eq!(v[0].fn_name, "train_fn");
        assert!(v[0].token_stream.is_empty());
    }
505
    // Extraction still works when the constant's initializer is wrapped onto
    // the line after the declaration.
    #[test]
    fn extracts_ast_with_newlines() {
        let expanded = r#"
        #[allow(dead_code)]
        const BURN_CENTRAL_FUNCTION_TRAINING: &str =
            "BCFN1|mnist_heat::training|training|__training_builder|mnist|training|END";
        #[allow(dead_code)]
        const _BURN_FUNCTION_AST_TRAINING: &[u8] =
            b"{\"vis\":\"pub\",\"ident\":\"training\"}";
        "#;

        let token_stream = extract_ast_token_stream(expanded, "training").unwrap();
        let expected = b"{\"vis\":\"pub\",\"ident\":\"training\"}";
        assert_eq!(token_stream, expected);
    }
522
    // A realistic, deeply escaped AST payload extracts to valid JSON.
    #[test]
    fn extracts_real_world_ast() {
        let expanded = r#"
        #[allow(dead_code)]
        const BURN_CENTRAL_FUNCTION_TRAINING: &str =
            "BCFN1|mnist_heat::training|training|__training_builder|mnist|training|END";
        #[allow(dead_code)]
        const _BURN_FUNCTION_AST_TRAINING: &[u8] =
            b"{\"vis\":\"pub\",\"ident\":\"training\",\"generics\":{\"params\":[{\"type\":{\"ident\":\"B\",\"colon_token\":true,\"bounds\":[{\"trait\":{\"path\":{\"segments\":[{\"ident\":\"AutodiffBackend\"}]}}}]}}]},\"inputs\":[{\"typed\":{\"pat\":{\"ident\":{\"ident\":\"client\"}},\"ty\":{\"reference\":{\"elem\":{\"path\":{\"segments\":[{\"ident\":\"ExperimentRun\"}]}}}}}},{\"typed\":{\"pat\":{\"ident\":{\"ident\":\"config\"}},\"ty\":{\"path\":{\"segments\":[{\"ident\":\"Args\",\"arguments\":{\"angle_bracketed\":{\"args\":[{\"type\":{\"path\":{\"segments\":[{\"ident\":\"ExperimentConfig\"}]}}}]}}}]}}}}],\"output\":{\"path\":{\"segments\":[{\"ident\":\"Result\"}]}}}";
        "#;

        let token_stream = extract_ast_token_stream(expanded, "training").unwrap();

        let json_str = std::str::from_utf8(&token_stream).unwrap();
        assert!(json_str.starts_with("{\"vis\":\"pub\",\"ident\":\"training\""));
        assert!(json_str.contains("\"ident\":\"AutodiffBackend\""));
        assert!(json_str.contains("\"ident\":\"client\""));
        assert!(json_str.contains("\"ident\":\"config\""));

        // The unescaped bytes must round-trip through a JSON parser.
        let _: serde_json::Value =
            serde_json::from_slice(&token_stream).expect("Token stream should be valid JSON");
    }
548
    // Raw source stored in `token_stream` is returned verbatim, comments
    // included (no re-parse through syn, which would strip them).
    #[test]
    fn get_function_code_returns_source_with_comments() {
        let meta = FunctionMetadata {
            mod_path: "my::module".to_string(),
            fn_name: "test".to_string(),
            builder_fn_name: "__test_builder".to_string(),
            routine_name: "test".to_string(),
            proc_type: "training".to_string(),
            token_stream: "pub fn test() {\n // Important comment\n let value = 42;\n}"
                .as_bytes()
                .to_vec(),
        };

        let code = meta.get_function_code();
        assert!(code.contains("// Important comment"));
        assert!(code.contains("let value = 42;"));
    }
566}