// burn_central_workspace/tools/function_discovery.rs

//! Build-time function discovery using cargo rustc macro expansion
//!
//! Uses `cargo rustc -- -Zunpretty=expanded` to extract `BCFN1|mod_path|fn|builder|routine|proc_type|END` markers from the expanded source code.

use std::collections::HashMap;
use std::io::{BufRead, BufReader};
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;

use quote::ToTokens;
use serde::{Deserialize, Serialize};

use crate::execution::cancellable::{CancellableProcess, CancellableResult, CancellationToken};

14const MAGIC: &str = "BCFN1|";
15const END: &str = "|END";
16const SEP: char = '|';
17
18#[derive(Clone, Debug, Serialize, Deserialize)]
19pub struct FunctionMetadata {
20    pub mod_path: String,
21    pub fn_name: String,
22    pub builder_fn_name: String,
23    pub routine_name: String,
24    pub proc_type: String,
25    pub token_stream: Vec<u8>,
26}
27
28impl FunctionMetadata {
29    pub fn get_function_code(&self) -> String {
30        if self.token_stream.is_empty() {
31            // If no token stream is available, create a placeholder function
32            format!(
33                "fn {}() {{\n    // Function implementation not available\n}}",
34                self.fn_name
35            )
36        } else {
37            // Try to decode as UTF-8 string first (new format with original source)
38            if let Ok(source_code) = std::str::from_utf8(&self.token_stream) {
39                // Check if it looks like Rust source code (not JSON)
40                if !source_code.trim_start().starts_with('{') {
41                    return source_code.to_string();
42                }
43            }
44
45            // Fall back to JSON AST deserialization (old format)
46            match syn_serde::json::from_slice::<syn::ItemFn>(&self.token_stream) {
47                Ok(itemfn) => match syn::parse2(itemfn.into_token_stream()) {
48                    Ok(syn_tree) => prettyplease::unparse(&syn_tree),
49                    Err(_) => format!(
50                        "fn {}() {{\n    // Failed to parse token stream\n}}",
51                        self.fn_name
52                    ),
53                },
54                Err(_) => format!(
55                    "fn {}() {{\n    // Failed to deserialize token stream\n}}",
56                    self.fn_name
57                ),
58            }
59        }
60    }
61}
62
63#[derive(Debug, thiserror::Error)]
64pub enum DiscoveryError {
65    #[error("Failed to spawn cargo rustc process: {0}")]
66    SpawnFailed(String),
67    #[error("Cargo rustc failed for package '{package}' (status: {status})")]
68    CargoError {
69        package: String,
70        status: i32,
71        diagnostics: String,
72    },
73    #[error("Function discovery was cancelled")]
74    Cancelled,
75}
76
77#[derive(Debug, Clone, PartialEq, Eq, Hash)]
78pub struct PkgId {
79    pub name: String,
80    pub version: Option<String>,
81}
82
83#[derive(Debug, Clone)]
84pub struct FunctionDiscovery {
85    project_root: PathBuf,
86}
87
88pub struct DiscoveryConfig {
89    pub packages: Vec<PkgId>,
90    pub target_dir: Option<PathBuf>,
91}
92
93#[derive(Debug)]
94pub struct DiscoveryResult {
95    pub functions: HashMap<PkgId, Vec<FunctionMetadata>>,
96}
97
98pub struct DiscoveryEvent {
99    pub package: PkgId,
100    pub message: Option<String>,
101}
102
103type DiscoveryEventReporter = dyn crate::event::Reporter<DiscoveryEvent>;
104
105impl FunctionDiscovery {
106    pub fn new(project_root: impl Into<PathBuf>) -> Self {
107        Self {
108            project_root: project_root.into(),
109        }
110    }
111
112    /// Expand and extract with cancellation support
113    pub fn discover_functions(
114        &self,
115        discovery_config: &DiscoveryConfig,
116        cancellation_token: &CancellationToken,
117        event_reporter: Option<Arc<DiscoveryEventReporter>>,
118    ) -> Result<DiscoveryResult, DiscoveryError> {
119        let mut package_functions = HashMap::new();
120        for package in &discovery_config.packages {
121            let expanded = self.expand_with_cargo(
122                package,
123                discovery_config.target_dir.as_deref(),
124                cancellation_token,
125                event_reporter.clone(),
126            )?;
127
128            let functions = parse_expanded_output(&expanded);
129            package_functions
130                .entry(package.clone())
131                .or_insert_with(Vec::new)
132                .extend(functions);
133
134            if let Some(reporter) = event_reporter.as_ref() {
135                reporter.report_event(DiscoveryEvent {
136                    package: package.clone(),
137                    message: Some(format!(
138                        "Discovered {} functions",
139                        package_functions.get(package).map_or(0, |fns| fns.len()),
140                    )),
141                });
142            }
143        }
144
145        let result = DiscoveryResult {
146            functions: package_functions,
147        };
148        Ok(result)
149    }
150
151    fn expand_with_cargo(
152        &self,
153        package: &PkgId,
154        target_dir: Option<&Path>,
155        cancellation_token: &CancellationToken,
156        event_reporter: Option<Arc<DiscoveryEventReporter>>,
157    ) -> Result<String, DiscoveryError> {
158        let mut cmd = super::cargo::command();
159        cmd.current_dir(&self.project_root)
160            .arg("rustc")
161            .arg("--lib")
162            .arg("--profile=check")
163            .arg("--message-format=json")
164            .arg("--quiet");
165
166        let spec = if let Some(ref version) = package.version {
167            format!("{}@{}", package.name, version)
168        } else {
169            package.name.to_string()
170        };
171        cmd.arg("-p").arg(spec);
172
173        if let Some(target_dir) = target_dir {
174            cmd.arg("--target-dir").arg(target_dir);
175        }
176
177        cmd.arg("--");
178        cmd.arg("-Zunpretty=expanded");
179        cmd.env("RUSTC_BOOTSTRAP", "1");
180        cmd.env("RUST_LOG", "error");
181
182        let mut child = cmd
183            .stdout(Stdio::piped())
184            .stderr(Stdio::piped())
185            .stdin(Stdio::null())
186            .spawn()
187            .map_err(|e| DiscoveryError::SpawnFailed(e.to_string()))?;
188
189        let (output_tx, output_rx) = std::sync::mpsc::channel();
190        let (errors_tx, errors_rx) = std::sync::mpsc::channel();
191        // Capture and report stdout (cargo messages)
192        if let Some(stdout) = child.stdout.take() {
193            let reader = BufReader::new(stdout);
194            let package = package.clone();
195            let event_reporter = event_reporter.clone();
196            let errors_tx = errors_tx.clone();
197            std::thread::spawn(move || {
198                let stream = cargo_metadata::Message::parse_stream(reader);
199                for message in stream.flatten() {
200                    match message {
201                        cargo_metadata::Message::CompilerMessage(msg) => {
202                            let rendered = msg.message.rendered.unwrap_or_default();
203                            if matches!(
204                                msg.message.level,
205                                cargo_metadata::diagnostic::DiagnosticLevel::Error
206                            ) {
207                                let _ = errors_tx.send(rendered);
208                            }
209                        }
210                        cargo_metadata::Message::CompilerArtifact(_artifact) => {
211                            if let Some(ref reporter) = event_reporter {
212                                reporter.report_event(DiscoveryEvent {
213                                    package: package.clone(),
214                                    message: Some(format!(
215                                        "Compiled artifact: {}",
216                                        _artifact.target.name
217                                    )),
218                                });
219                            }
220                        }
221                        cargo_metadata::Message::TextLine(line) => {
222                            let _ = output_tx.send(line.clone());
223                        }
224                        _ => {}
225                    }
226                }
227            });
228        }
229
230        if let Some(stderr) = child.stderr.take() {
231            let reader = BufReader::new(stderr);
232            let errors_tx = errors_tx.clone();
233            std::thread::spawn(move || {
234                for line in reader.lines().map_while(Result::ok) {
235                    let _ = errors_tx.send(line);
236                }
237            });
238        }
239
240        let cancellable = CancellableProcess::new(child, cancellation_token.clone());
241        let result = cancellable.wait();
242
243        match result {
244            CancellableResult::Completed(status) => {
245                if !status.success() {
246                    return Err(DiscoveryError::CargoError {
247                        package: package.name.clone(),
248                        status: status.code().unwrap_or(-1),
249                        diagnostics: errors_rx.try_iter().collect::<Vec<_>>().join("\n"),
250                    });
251                }
252                let expanded = output_rx.try_iter().collect::<Vec<String>>().join("\n");
253                Ok(expanded)
254            }
255            CancellableResult::Cancelled => Err(DiscoveryError::Cancelled),
256        }
257    }
258}
259
260fn parse_expanded_output(expanded: &str) -> Vec<FunctionMetadata> {
261    let bytes = expanded.as_bytes();
262    let mut i = 0usize;
263    let mut out = Vec::new();
264
265    while let Some(m) = find(bytes, MAGIC.as_bytes(), i) {
266        let start_payload = m + MAGIC.len();
267        if let Some(end) = find(bytes, END.as_bytes(), start_payload) {
268            if let Ok(slice) = std::str::from_utf8(&bytes[m..end + END.len()]) {
269                if let Some(meta) = parse_bcfn_marker(slice) {
270                    out.push(meta);
271                }
272            }
273            i = end + END.len();
274        } else {
275            // no closing sentinel; stop scanning
276            break;
277        }
278    }
279
280    for meta in &mut out {
281        let result = extract_ast_token_stream(expanded, &meta.fn_name);
282        if let Some(token_stream) = result {
283            meta.token_stream = token_stream;
284        }
285    }
286
287    out
288}
289
290/// Expected `BCFN1|mod_path|fn_name|builder|routine|proc_type|END`.
291fn parse_bcfn_marker(marker: &str) -> Option<FunctionMetadata> {
292    if !marker.starts_with(MAGIC) || !marker.ends_with(END) {
293        return None;
294    }
295    let body = &marker[MAGIC.len()..marker.len() - END.len()];
296    let mut it = body.split(SEP);
297
298    let mod_path = it.next()?.to_string();
299    let fn_name = it.next()?.to_string();
300    let builder_fn_name = it.next()?.to_string();
301    let routine_name = it.next()?.to_string();
302    let proc_type = it.next()?.to_string();
303
304    // There must be exactly 5 parts.
305    if it.next().is_some() {
306        return None;
307    }
308
309    Some(FunctionMetadata {
310        mod_path,
311        fn_name,
312        builder_fn_name,
313        routine_name,
314        proc_type,
315        token_stream: Vec::new(),
316    })
317}
318
319/// Naive byte-substring search (no regex).
320fn find(hay: &[u8], needle: &[u8], mut from: usize) -> Option<usize> {
321    while from + needle.len() <= hay.len() {
322        if &hay[from..from + needle.len()] == needle {
323            return Some(from);
324        }
325        from += 1;
326    }
327    None
328}
329
330/// Unescape a Rust byte string literal (without the surrounding b"...")
331/// Handles common escape sequences: \", \\, \n, \r, \t
332fn unescape_byte_string(escaped: &str) -> Vec<u8> {
333    let mut result = Vec::new();
334    let mut chars = escaped.chars();
335
336    while let Some(ch) = chars.next() {
337        if ch == '\\' {
338            // Handle escape sequences
339            if let Some(next) = chars.next() {
340                match next {
341                    '"' => result.push(b'"'),
342                    '\\' => result.push(b'\\'),
343                    'n' => result.push(b'\n'),
344                    'r' => result.push(b'\r'),
345                    't' => result.push(b'\t'),
346                    // For any other escape, just include it as-is
347                    _ => {
348                        result.push(b'\\');
349                        result.extend(next.to_string().as_bytes());
350                    }
351                }
352            } else {
353                // Trailing backslash
354                result.push(b'\\');
355            }
356        } else {
357            // Regular character - convert to bytes
358            result.extend(ch.to_string().as_bytes());
359        }
360    }
361
362    result
363}
364
365/// Extract the JSON AST from a _BURN_FUNCTION_AST_* constant
366/// Pattern: const _BURN_FUNCTION_AST_NAME: &[u8] = b"{...json...}";
367fn extract_ast_token_stream(expanded: &str, fn_name: &str) -> Option<Vec<u8>> {
368    // Derive the AST constant name from the function name
369    let ast_const_name = format!("_BURN_FUNCTION_AST_{}", fn_name.to_uppercase());
370
371    // Search for the constant declaration
372    let const_pattern = format!("const {}: &[u8]", ast_const_name);
373    let const_pos = expanded.find(&const_pattern)?;
374
375    // Find the `b"` after the constant declaration (allowing for whitespace/newlines between = and b")
376    let search_start = const_pos + const_pattern.len();
377    let b_quote_pattern = "b\"";
378    let b_quote_pos = expanded[search_start..].find(b_quote_pattern)?;
379    let content_start = search_start + b_quote_pos + b_quote_pattern.len();
380
381    // Find the closing `";`
382    let chars: Vec<char> = expanded[content_start..].chars().collect();
383    let mut pos = 0;
384
385    while pos < chars.len() {
386        if chars[pos] == '\\' && pos + 1 < chars.len() {
387            // Skip the escaped character
388            pos += 2;
389        } else if chars[pos] == '"' {
390            // Found the closing quote
391            // Check if it's followed by `;`
392            if pos + 1 < chars.len() && chars[pos + 1] == ';' {
393                let escaped_content: String = chars[..pos].iter().collect();
394                return Some(unescape_byte_string(&escaped_content));
395            } else {
396                pos += 1;
397            }
398        } else {
399            pos += 1;
400        }
401    }
402
403    None
404}
405
406#[cfg(test)]
407mod tests {
408    use super::*;
409
410    #[test]
411    fn parses_markers() {
412        let expanded = r#"
413            /* noise */ const X:&str="hello";
414            const BURN_CENTRAL_FUNCTION_TRAIN:&str="BCFN1|my::module|train_fn|__train_fn_builder|train|training|END";
415            const BURN_CENTRAL_FUNCTION_EVAL:&str=
416                "BCFN1|my::module|eval_fn|__eval_fn_builder|evaluate|training|END";
417        "#;
418
419        let v = parse_expanded_output(expanded);
420        assert_eq!(v.len(), 2);
421        assert_eq!(v[0].mod_path, "my::module");
422        assert_eq!(v[0].fn_name, "train_fn");
423        assert_eq!(v[1].fn_name, "eval_fn");
424        assert_eq!(v[1].routine_name, "evaluate");
425    }
426
427    #[test]
428    fn rejects_bad_marker() {
429        let bad = "BCFN1|a|b|c|d|END";
430        assert!(parse_bcfn_marker(bad).is_none());
431    }
432
433    #[test]
434    fn accepts_complex_mod_path() {
435        let ok = "BCFN1|a::b::c::d|f|__builder|r|training|END";
436        let m = parse_bcfn_marker(ok).unwrap();
437        assert_eq!(m.mod_path, "a::b::c::d");
438    }
439
440    #[test]
441    fn unescapes_byte_string() {
442        let escaped = r#"hello \"world\" with \\backslash\\ and \n newline"#;
443        let result = unescape_byte_string(escaped);
444        let expected = b"hello \"world\" with \\backslash\\ and \n newline";
445        assert_eq!(result, expected);
446    }
447
448    #[test]
449    fn extracts_ast_token_stream() {
450        let expanded = r#"
451            const _: () = {
452                const BURN_CENTRAL_FUNCTION_TEST: &str = "BCFN1|my::module|test|__test_builder|test|training|END";
453                const _BURN_FUNCTION_AST_TEST: &[u8] = b"{\"vis\":\"pub\",\"ident\":\"test\"}";
454            };
455        "#;
456
457        let token_stream = extract_ast_token_stream(expanded, "test").unwrap();
458        let expected = b"{\"vis\":\"pub\",\"ident\":\"test\"}";
459        assert_eq!(token_stream, expected);
460    }
461
462    #[test]
463    fn parses_markers_with_ast() {
464        let expanded = r#"
465            const _: () = {
466                const BURN_CENTRAL_FUNCTION_TRAIN:&str="BCFN1|my::module|train_fn|__train_fn_builder|train|training|END";
467                const _BURN_FUNCTION_AST_TRAIN_FN: &[u8] = b"{\"vis\":\"pub\",\"ident\":\"train_fn\"}";
468            };
469            const _: () = {
470                const BURN_CENTRAL_FUNCTION_EVAL:&str="BCFN1|my::module|eval_fn|__eval_fn_builder|evaluate|training|END";
471                const _BURN_FUNCTION_AST_EVAL_FN: &[u8] = b"{\"vis\":\"pub\",\"ident\":\"eval_fn\"}";
472            };
473        "#;
474
475        let v = parse_expanded_output(expanded);
476        assert_eq!(v.len(), 2);
477
478        // Verify metadata
479        assert_eq!(v[0].mod_path, "my::module");
480        assert_eq!(v[0].fn_name, "train_fn");
481
482        // Verify token streams are populated
483        assert!(!v[0].token_stream.is_empty());
484        assert!(!v[1].token_stream.is_empty());
485
486        // Verify token stream content
487        let expected_train = b"{\"vis\":\"pub\",\"ident\":\"train_fn\"}";
488        let expected_eval = b"{\"vis\":\"pub\",\"ident\":\"eval_fn\"}";
489        assert_eq!(v[0].token_stream, expected_train);
490        assert_eq!(v[1].token_stream, expected_eval);
491    }
492
493    #[test]
494    fn handles_missing_ast_gracefully() {
495        let expanded = r#"
496            const BURN_CENTRAL_FUNCTION_TRAIN:&str="BCFN1|my::module|train_fn|__train_fn_builder|train|training|END";
497        "#;
498
499        let v = parse_expanded_output(expanded);
500        assert_eq!(v.len(), 1);
501        assert_eq!(v[0].fn_name, "train_fn");
502        // Token stream should be empty when AST constant is missing
503        assert!(v[0].token_stream.is_empty());
504    }
505
506    #[test]
507    fn extracts_ast_with_newlines() {
508        // This is the actual format from the macro expansion
509        let expanded = r#"
510            #[allow(dead_code)]
511            const BURN_CENTRAL_FUNCTION_TRAINING: &str =
512                "BCFN1|mnist_heat::training|training|__training_builder|mnist|training|END";
513            #[allow(dead_code)]
514            const _BURN_FUNCTION_AST_TRAINING: &[u8] =
515                b"{\"vis\":\"pub\",\"ident\":\"training\"}";
516        "#;
517
518        let token_stream = extract_ast_token_stream(expanded, "training").unwrap();
519        let expected = b"{\"vis\":\"pub\",\"ident\":\"training\"}";
520        assert_eq!(token_stream, expected);
521    }
522
523    #[test]
524    fn extracts_real_world_ast() {
525        // This is the actual full format from the mnist project
526        let expanded = r#"
527            #[allow(dead_code)]
528            const BURN_CENTRAL_FUNCTION_TRAINING: &str =
529                "BCFN1|mnist_heat::training|training|__training_builder|mnist|training|END";
530            #[allow(dead_code)]
531            const _BURN_FUNCTION_AST_TRAINING: &[u8] =
532                b"{\"vis\":\"pub\",\"ident\":\"training\",\"generics\":{\"params\":[{\"type\":{\"ident\":\"B\",\"colon_token\":true,\"bounds\":[{\"trait\":{\"path\":{\"segments\":[{\"ident\":\"AutodiffBackend\"}]}}}]}}]},\"inputs\":[{\"typed\":{\"pat\":{\"ident\":{\"ident\":\"client\"}},\"ty\":{\"reference\":{\"elem\":{\"path\":{\"segments\":[{\"ident\":\"ExperimentRun\"}]}}}}}},{\"typed\":{\"pat\":{\"ident\":{\"ident\":\"config\"}},\"ty\":{\"path\":{\"segments\":[{\"ident\":\"Args\",\"arguments\":{\"angle_bracketed\":{\"args\":[{\"type\":{\"path\":{\"segments\":[{\"ident\":\"ExperimentConfig\"}]}}}]}}}]}}}}],\"output\":{\"path\":{\"segments\":[{\"ident\":\"Result\"}]}}}";
533        "#;
534
535        let token_stream = extract_ast_token_stream(expanded, "training").unwrap();
536
537        // Verify it starts with the expected JSON structure
538        let json_str = std::str::from_utf8(&token_stream).unwrap();
539        assert!(json_str.starts_with("{\"vis\":\"pub\",\"ident\":\"training\""));
540        assert!(json_str.contains("\"ident\":\"AutodiffBackend\""));
541        assert!(json_str.contains("\"ident\":\"client\""));
542        assert!(json_str.contains("\"ident\":\"config\""));
543
544        // Verify it's valid JSON by attempting to parse it
545        let _: serde_json::Value =
546            serde_json::from_slice(&token_stream).expect("Token stream should be valid JSON");
547    }
548
549    #[test]
550    fn get_function_code_returns_source_with_comments() {
551        let meta = FunctionMetadata {
552            mod_path: "my::module".to_string(),
553            fn_name: "test".to_string(),
554            builder_fn_name: "__test_builder".to_string(),
555            routine_name: "test".to_string(),
556            proc_type: "training".to_string(),
557            token_stream: "pub fn test() {\n    // Important comment\n    let value = 42;\n}"
558                .as_bytes()
559                .to_vec(),
560        };
561
562        let code = meta.get_function_code();
563        assert!(code.contains("// Important comment"));
564        assert!(code.contains("let value = 42;"));
565    }
566}