no_block_pls/
lib.rs

1pub mod transformer;
2
3#[cfg(test)]
4mod loop_test;
5
use anyhow::{Context, Result};
use std::fs;
use std::path::Path;
use syn::{ExprAsync, File, ImplItemFn, Item, ItemFn, parse_file, parse_quote, visit::Visit};
use transformer::AsyncInstrumenter;
11
12/// Visitor to check if a file contains any async code
13struct AsyncDetector {
14    has_async: bool,
15}
16
17impl AsyncDetector {
18    fn new() -> Self {
19        Self { has_async: false }
20    }
21
22    fn check(file: &File) -> bool {
23        let mut detector = Self::new();
24        detector.visit_file(file);
25        detector.has_async
26    }
27}
28
29impl<'ast> Visit<'ast> for AsyncDetector {
30    fn visit_expr_async(&mut self, _: &'ast ExprAsync) {
31        self.has_async = true;
32    }
33
34    fn visit_impl_item_fn(&mut self, func: &'ast ImplItemFn) {
35        if func.sig.asyncness.is_some() {
36            self.has_async = true;
37            return;
38        }
39        syn::visit::visit_impl_item_fn(self, func);
40    }
41
42    fn visit_item_fn(&mut self, func: &'ast ItemFn) {
43        if func.sig.asyncness.is_some() {
44            self.has_async = true;
45            return;
46        }
47        syn::visit::visit_item_fn(self, func);
48    }
49}
50
51#[derive(Default)]
52struct InstrumentationDetector {
53    already_instrumented: bool,
54}
55
56impl InstrumentationDetector {
57    fn check(file: &File) -> bool {
58        let mut detector = Self::default();
59        detector.visit_file(file);
60        detector.already_instrumented
61    }
62}
63
64impl<'ast> Visit<'ast> for InstrumentationDetector {
65    fn visit_item_mod(&mut self, module: &'ast syn::ItemMod) {
66        if module.ident == "__async_profile_guard__" {
67            self.already_instrumented = true;
68            return;
69        }
70        syn::visit::visit_item_mod(self, module);
71    }
72
73    fn visit_path(&mut self, path: &'ast syn::Path) {
74        if self.already_instrumented {
75            return;
76        }
77
78        let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
79
80        if segments.windows(3).any(|window| {
81            window[0] == "__async_profile_guard__" && window[1] == "Guard" && window[2] == "new"
82        }) {
83            self.already_instrumented = true;
84            return;
85        }
86
87        syn::visit::visit_path(self, path);
88    }
89}
90
/// Build the `__async_profile_guard__` module that instrumented call sites use at runtime.
///
/// `threshold_ms` is interpolated into the generated `THRESHOLD_MS` constant. The generated
/// `Guard` times sections of code between instrumentation points. Whenever a section takes
/// strictly longer than `THRESHOLD_MS` milliseconds, it emits a `tracing::warn!` event with:
/// - the elapsed milliseconds,
/// - the guard's `name`,
/// - a `file:from-to` span, and
/// - the count of consecutive slow sections.
///
/// The crate receiving this module must therefore depend on `tracing`.
///
/// Everything inside `parse_quote!` below is the emitted code itself. Plain `//` comments
/// there are dropped by the tokenizer, but any other edit changes the generated output.
fn generate_guard_module(threshold_ms: u64) -> Item {
    parse_quote! {
        #[doc(hidden)]
        #[allow(dead_code)]
        mod __async_profile_guard__ {
            use std::time::{Duration, Instant};

            // Value of `threshold_ms`, fixed when the module is generated.
            const THRESHOLD_MS: u64 = #threshold_ms;

            pub struct Guard {
                // Label included in every warning.
                name: &'static str,
                // File and first line of the section currently being timed.
                file: &'static str,
                from_line: u32,
                // When the open section started; `None` after `end_section` until the next
                // `start_section` or `checkpoint`.
                current_start: Option<Instant>,
                // Slow sections in a row; reset when a section finishes within the threshold.
                consecutive_hits: u32,
            }

            impl Guard {
                // Creates a guard whose first section starts now, at `line`.
                pub fn new(name: &'static str, file: &'static str, line: u32) -> Self {
                    Guard {
                        name,
                        file,
                        from_line: line,
                        current_start: Some(Instant::now()),
                        consecutive_hits: 0,
                    }
                }

                // Closes the open section (if any) at `new_line`, warning when it was slow,
                // then immediately opens the next section at `new_line`.
                pub fn checkpoint(&mut self, new_line: u32) {
                    if let Some(start) = self.current_start.take() {
                        let elapsed = start.elapsed();
                        if elapsed > Duration::from_millis(THRESHOLD_MS) {
                            self.consecutive_hits = self.consecutive_hits.saturating_add(1);
                            let span = format!("{file}:{from}-{to}", file = self.file, from = self.from_line, to = new_line);
                            // An end line before the start line means control jumped backwards
                            // in the file (presumably a loop back-edge).
                            let wraparound = new_line < self.from_line;
                            if wraparound {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll (iteration tail wraparound)"
                                );
                            } else {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll (iteration tail)"
                                );
                            }
                        } else {
                            self.consecutive_hits = 0;
                        }
                    }
                    self.from_line = new_line;
                    self.current_start = Some(Instant::now());
                }

                // Closes the open section (if any) at `to_line`, warning when it was slow.
                // No new section is opened: `current_start` stays `None` afterwards.
                pub fn end_section(&mut self, to_line: u32) {
                    if let Some(start) = self.current_start.take() {
                        let elapsed = start.elapsed();
                        if elapsed > Duration::from_millis(THRESHOLD_MS) {
                            self.consecutive_hits = self.consecutive_hits.saturating_add(1);
                            let span = format!("{file}:{from}-{to}", file = self.file, from = self.from_line, to = to_line);
                            // Same backwards-jump detection as in `checkpoint`.
                            let wraparound = to_line < self.from_line;
                            if wraparound {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll (loop wraparound)"
                                );
                            } else {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll"
                                );
                            }
                        } else {
                            self.consecutive_hits = 0;
                        }
                    }
                }

                // Opens a new section at `new_line`, starting now.
                pub fn start_section(&mut self, new_line: u32) {
                    self.from_line = new_line;
                    self.current_start = Some(Instant::now());
                }
            }

            impl Drop for Guard {
                fn drop(&mut self) {
                    // Check final section if still timing
                    // The end line is unknown here, so the span repeats `from_line`.
                    if let Some(start) = self.current_start {
                        let elapsed = start.elapsed();
                        if elapsed > Duration::from_millis(THRESHOLD_MS) {
                            self.consecutive_hits = self.consecutive_hits.saturating_add(1);
                            let span =
                                format!("{file}:{line}-{line}", file = self.file, line = self.from_line);
                            tracing::warn!(
                                elapsed_ms = elapsed.as_millis(),
                                name = %self.name,
                                span = %span,
                                hits = self.consecutive_hits,
                                wraparound = false,
                                "long poll"
                            );
                        }
                    }
                }
            }
        }
    }
}
216
217/// Inject the guard module into a root file (lib.rs or main.rs)
218/// This should only be called once per crate
219pub fn inject_guard_module(source: &str, threshold_ms: u64) -> Result<String> {
220    let mut syntax_tree = parse_file(source)?;
221
222    if InstrumentationDetector::check(&syntax_tree) {
223        return Ok(source.to_owned());
224    }
225
226    // Insert guard module at the beginning
227    let guard_module = generate_guard_module(threshold_ms);
228    syntax_tree.items.insert(0, guard_module);
229
230    // Also instrument any async functions in this file
231    let mut instrumenter = AsyncInstrumenter::new(threshold_ms);
232    instrumenter.instrument_file(&mut syntax_tree);
233
234    let formatted = prettyplease::unparse(&syntax_tree);
235    Ok(formatted)
236}
237
238/// Instrument async functions without injecting the guard module
239/// Use this for all non-root files (assumes guard module exists in crate root)
240/// Returns None if the file has no async code
241pub fn instrument_async_only(source: &str) -> Result<Option<String>> {
242    let mut syntax_tree = parse_file(source)?;
243
244    if InstrumentationDetector::check(&syntax_tree) {
245        return Ok(None);
246    }
247
248    // Check if file has any async code
249    if !AsyncDetector::check(&syntax_tree) {
250        return Ok(None);
251    }
252
253    // Only instrument async functions, don't inject guard module
254    let mut instrumenter = AsyncInstrumenter::new(10);
255    instrumenter.instrument_file(&mut syntax_tree);
256
257    let formatted = prettyplease::unparse(&syntax_tree);
258    Ok(Some(formatted))
259}
260
261/// Process a single Rust source file and return the instrumented code with default threshold
262/// This injects the guard module AND instruments async functions (backward compatibility)
263pub fn instrument_code(source: &str) -> Result<String> {
264    instrument_code_with_threshold(source, 10)
265}
266
267/// Process a single Rust source file and return the instrumented code with specified threshold
268/// This injects the guard module AND instruments async functions (backward compatibility)
269pub fn instrument_code_with_threshold(source: &str, threshold_ms: u64) -> Result<String> {
270    inject_guard_module(source, threshold_ms)
271}
272
273/// Process a file at the given path and return the instrumented code with default threshold
274pub fn instrument_file(path: &Path) -> Result<String> {
275    let content = fs::read_to_string(path)?;
276    instrument_code(&content)
277}
278
279/// Process a file at the given path and return the instrumented code with specified threshold
280pub fn instrument_file_with_threshold(path: &Path, threshold_ms: u64) -> Result<String> {
281    let content = fs::read_to_string(path)?;
282    instrument_code_with_threshold(&content, threshold_ms)
283}
284
285/// Process a file and write the instrumented version back (with backup)
286pub fn instrument_file_in_place(path: &Path) -> Result<()> {
287    let instrumented = instrument_file(path)?;
288
289    // Create backup
290    let backup_path = path.with_extension("rs.bak");
291    fs::copy(path, &backup_path)?;
292
293    // Write instrumented version
294    fs::write(path, instrumented)?;
295
296    Ok(())
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Async fn with synchronous work between two awaits; shared by several tests below.
    const FETCH_DATA_SOURCE: &str = r#"
async fn fetch_data() {
    let response = client.get().await;
    let parsed = parse(response);
    store(parsed).await;
}
"#;

    /// Checks that per-section guard calls were emitted with literal line numbers
    /// (no leftover `line!()` invocations).
    fn assert_section_calls(code: &str) {
        assert!(code.contains("__guard.end_section("));
        assert!(code.contains("__guard.start_section("));
        assert!(!code.contains("line!()"));
    }

    #[test]
    fn test_inject_guard_module() {
        let output = inject_guard_module(FETCH_DATA_SOURCE, 10).unwrap();
        assert!(output.contains("mod __async_profile_guard__"));
        assert!(output.contains("__guard"));
        assert_section_calls(&output);
    }

    #[test]
    fn test_instrument_async_only() {
        let output = instrument_async_only(FETCH_DATA_SOURCE)
            .unwrap()
            .expect("Should instrument async functions");
        // Non-root files reference the guard module but must not define it.
        assert!(!output.contains("mod __async_profile_guard__"));
        assert!(output.contains("crate::__async_profile_guard__::Guard::new"));
        assert_section_calls(&output);
    }

    #[test]
    fn test_instrument_async_only_idempotent() {
        let source = r#"
async fn fetch_data() {
    let response = client.get().await;
    store(response).await;
}
"#;

        let once = instrument_async_only(source).unwrap().unwrap();
        let again = instrument_async_only(&once).unwrap();
        assert!(again.is_none());
    }

    #[test]
    fn test_inject_guard_module_idempotent() {
        let source = r#"
async fn action() {
    do_it().await;
}
"#;

        let once = inject_guard_module(source, 10).unwrap();
        let again = inject_guard_module(&once, 10).unwrap();
        assert_eq!(once, again);
    }

    #[test]
    fn test_skip_non_async_file() {
        let source = r#"
fn regular_function() {
    println!("No async here");
}

struct MyStruct {
    field: String,
}

impl MyStruct {
    fn new() -> Self {
        Self { field: String::new() }
    }
}
"#;

        let output = instrument_async_only(source).unwrap();
        assert!(output.is_none(), "Should skip files without async code");
    }

    #[test]
    fn test_detect_async_in_impl() {
        let source = r#"
struct Service;

impl Service {
    async fn handle_request(&self) {
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}
"#;

        let output = instrument_async_only(source).unwrap();
        assert!(output.is_some(), "Should detect async methods in impl blocks");
    }

    #[test]
    fn test_detect_async_block() {
        let source = r#"
fn spawn_task() {
    tokio::spawn(async {
        println!("In async block");
    });
}
"#;

        let output = instrument_async_only(source).unwrap();
        assert!(output.is_some(), "Should detect async blocks");
    }

    #[test]
    fn test_instrument_code() {
        let output = instrument_code(FETCH_DATA_SOURCE).unwrap();
        assert!(output.contains("__guard"));
        assert_section_calls(&output);
    }

    #[test]
    fn test_no_instrument_sync() {
        let source = r#"
fn sync_function() {
    let x = 42;
    println!("{}", x);
}
"#;

        let output = instrument_code(source).unwrap();
        assert!(!output.contains("__guard"));
    }
}