1pub mod transformer;
2
3#[cfg(test)]
4mod loop_test;
5
6use anyhow::Result;
7use std::fs;
8use std::path::Path;
9use syn::{ExprAsync, File, ImplItemFn, Item, ItemFn, parse_file, parse_quote, visit::Visit};
10use transformer::AsyncInstrumenter;
11
12struct AsyncDetector {
14 has_async: bool,
15}
16
17impl AsyncDetector {
18 fn new() -> Self {
19 Self { has_async: false }
20 }
21
22 fn check(file: &File) -> bool {
23 let mut detector = Self::new();
24 detector.visit_file(file);
25 detector.has_async
26 }
27}
28
29impl<'ast> Visit<'ast> for AsyncDetector {
30 fn visit_expr_async(&mut self, _: &'ast ExprAsync) {
31 self.has_async = true;
32 }
33
34 fn visit_impl_item_fn(&mut self, func: &'ast ImplItemFn) {
35 if func.sig.asyncness.is_some() {
36 self.has_async = true;
37 return;
38 }
39 syn::visit::visit_impl_item_fn(self, func);
40 }
41
42 fn visit_item_fn(&mut self, func: &'ast ItemFn) {
43 if func.sig.asyncness.is_some() {
44 self.has_async = true;
45 return;
46 }
47 syn::visit::visit_item_fn(self, func);
48 }
49}
50
51#[derive(Default)]
52struct InstrumentationDetector {
53 already_instrumented: bool,
54}
55
56impl InstrumentationDetector {
57 fn check(file: &File) -> bool {
58 let mut detector = Self::default();
59 detector.visit_file(file);
60 detector.already_instrumented
61 }
62}
63
64impl<'ast> Visit<'ast> for InstrumentationDetector {
65 fn visit_item_mod(&mut self, module: &'ast syn::ItemMod) {
66 if module.ident == "__async_profile_guard__" {
67 self.already_instrumented = true;
68 return;
69 }
70 syn::visit::visit_item_mod(self, module);
71 }
72
73 fn visit_path(&mut self, path: &'ast syn::Path) {
74 if self.already_instrumented {
75 return;
76 }
77
78 let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
79
80 if segments.windows(3).any(|window| {
81 window[0] == "__async_profile_guard__" && window[1] == "Guard" && window[2] == "new"
82 }) {
83 self.already_instrumented = true;
84 return;
85 }
86
87 syn::visit::visit_path(self, path);
88 }
89}
90
/// Builds the `__async_profile_guard__` runtime module that instrumented code calls into.
///
/// `threshold_ms` is spliced into the generated `THRESHOLD_MS` constant. Any timed
/// section lasting strictly longer than that many milliseconds is reported through
/// `tracing::warn!`, so the crate receiving this module must depend on `tracing`.
///
/// Instrumented code refers to the guard by the absolute path
/// `crate::__async_profile_guard__::Guard`, so the module presumably has to be
/// inserted into the crate root (TODO confirm against `AsyncInstrumenter`).
///
/// Generated `Guard` API:
/// - `new(name, file, line)` starts timing a section that begins at `line`.
/// - `checkpoint(new_line)` closes the current section and immediately opens a new
///   one at `new_line`.
/// - `end_section(to_line)` closes the current section without opening another.
/// - `start_section(new_line)` opens a section.
/// - `Drop` reports a section that is still open.
fn generate_guard_module(threshold_ms: u64) -> Item {
    // NOTE: only `//` comments may be added inside this macro. `///` would become
    // `#[doc]` attributes and change the generated tokens.
    parse_quote! {
        #[doc(hidden)]
        #[allow(dead_code)]
        mod __async_profile_guard__ {
            use std::time::{Duration, Instant};

            const THRESHOLD_MS: u64 = #threshold_ms;

            pub struct Guard {
                name: &'static str,
                file: &'static str,
                // Source line at which the currently timed section started.
                from_line: u32,
                // `None` while no section is being timed (after `end_section`).
                current_start: Option<Instant>,
                // Number of back-to-back sections that exceeded the threshold.
                // Reset to 0 by `checkpoint`/`end_section` when a section is fast.
                consecutive_hits: u32,
            }

            impl Guard {
                pub fn new(name: &'static str, file: &'static str, line: u32) -> Self {
                    Guard {
                        name,
                        file,
                        from_line: line,
                        current_start: Some(Instant::now()),
                        consecutive_hits: 0,
                    }
                }

                // Close the running section (if any), report it when slow, then
                // start timing a new section at `new_line`.
                pub fn checkpoint(&mut self, new_line: u32) {
                    if let Some(start) = self.current_start.take() {
                        let elapsed = start.elapsed();
                        if elapsed > Duration::from_millis(THRESHOLD_MS) {
                            self.consecutive_hits = self.consecutive_hits.saturating_add(1);
                            let span = format!("{file}:{from}-{to}", file = self.file, from = self.from_line, to = new_line);
                            // The span runs "backwards" when control jumps to an earlier
                            // line, presumably a loop back-edge.
                            let wraparound = new_line < self.from_line;
                            if wraparound {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll (iteration tail wraparound)"
                                );
                            } else {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll (iteration tail)"
                                );
                            }
                        } else {
                            self.consecutive_hits = 0;
                        }
                    }
                    self.from_line = new_line;
                    self.current_start = Some(Instant::now());
                }

                // Close the running section (if any) and report it when slow. Timing
                // stays stopped until `start_section` is called.
                pub fn end_section(&mut self, to_line: u32) {
                    if let Some(start) = self.current_start.take() {
                        let elapsed = start.elapsed();
                        if elapsed > Duration::from_millis(THRESHOLD_MS) {
                            self.consecutive_hits = self.consecutive_hits.saturating_add(1);
                            let span = format!("{file}:{from}-{to}", file = self.file, from = self.from_line, to = to_line);
                            let wraparound = to_line < self.from_line;
                            if wraparound {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll (loop wraparound)"
                                );
                            } else {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll"
                                );
                            }
                        } else {
                            self.consecutive_hits = 0;
                        }
                    }
                }

                // Begin timing a new section at `new_line`.
                pub fn start_section(&mut self, new_line: u32) {
                    self.from_line = new_line;
                    self.current_start = Some(Instant::now());
                }
            }

            impl Drop for Guard {
                // Report a section still open when the guard goes out of scope.
                // Its end line is unknown here, so the span repeats `from_line`.
                fn drop(&mut self) {
                    if let Some(start) = self.current_start {
                        let elapsed = start.elapsed();
                        if elapsed > Duration::from_millis(THRESHOLD_MS) {
                            self.consecutive_hits = self.consecutive_hits.saturating_add(1);
                            let span =
                                format!("{file}:{line}-{line}", file = self.file, line = self.from_line);
                            tracing::warn!(
                                elapsed_ms = elapsed.as_millis(),
                                name = %self.name,
                                span = %span,
                                hits = self.consecutive_hits,
                                wraparound = false,
                                "long poll"
                            );
                        }
                    }
                }
            }
        }
    }
}
216
217pub fn inject_guard_module(source: &str, threshold_ms: u64) -> Result<String> {
220 let mut syntax_tree = parse_file(source)?;
221
222 if InstrumentationDetector::check(&syntax_tree) {
223 return Ok(source.to_owned());
224 }
225
226 let guard_module = generate_guard_module(threshold_ms);
228 syntax_tree.items.insert(0, guard_module);
229
230 let mut instrumenter = AsyncInstrumenter::new(threshold_ms);
232 instrumenter.instrument_file(&mut syntax_tree);
233
234 let formatted = prettyplease::unparse(&syntax_tree);
235 Ok(formatted)
236}
237
238pub fn instrument_async_only(source: &str) -> Result<Option<String>> {
242 let mut syntax_tree = parse_file(source)?;
243
244 if InstrumentationDetector::check(&syntax_tree) {
245 return Ok(None);
246 }
247
248 if !AsyncDetector::check(&syntax_tree) {
250 return Ok(None);
251 }
252
253 let mut instrumenter = AsyncInstrumenter::new(10);
255 instrumenter.instrument_file(&mut syntax_tree);
256
257 let formatted = prettyplease::unparse(&syntax_tree);
258 Ok(Some(formatted))
259}
260
261pub fn instrument_code(source: &str) -> Result<String> {
264 instrument_code_with_threshold(source, 10)
265}
266
267pub fn instrument_code_with_threshold(source: &str, threshold_ms: u64) -> Result<String> {
270 inject_guard_module(source, threshold_ms)
271}
272
273pub fn instrument_file(path: &Path) -> Result<String> {
275 let content = fs::read_to_string(path)?;
276 instrument_code(&content)
277}
278
279pub fn instrument_file_with_threshold(path: &Path, threshold_ms: u64) -> Result<String> {
281 let content = fs::read_to_string(path)?;
282 instrument_code_with_threshold(&content, threshold_ms)
283}
284
285pub fn instrument_file_in_place(path: &Path) -> Result<()> {
287 let instrumented = instrument_file(path)?;
288
289 let backup_path = path.with_extension("rs.bak");
291 fs::copy(path, &backup_path)?;
292
293 fs::write(path, instrumented)?;
295
296 Ok(())
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Async fn with synchronous work sandwiched between two `.await`s.
    const FETCH_DATA: &str = r#"
async fn fetch_data() {
    let response = client.get().await;
    let parsed = parse(response);
    store(parsed).await;
}
"#;

    /// Checks the markers every instrumented `FETCH_DATA` must carry: guard
    /// section calls are present and no `line!()` invocation survives.
    fn assert_sectioned(output: &str) {
        assert!(output.contains("__guard.end_section("));
        assert!(output.contains("__guard.start_section("));
        assert!(!output.contains("line!()"));
    }

    #[test]
    fn test_inject_guard_module() {
        let result = inject_guard_module(FETCH_DATA, 10).unwrap();

        assert!(result.contains("mod __async_profile_guard__"));
        assert!(result.contains("__guard"));
        assert_sectioned(&result);
    }

    #[test]
    fn test_instrument_async_only() {
        let instrumented = instrument_async_only(FETCH_DATA)
            .unwrap()
            .expect("Should instrument async functions");

        // The guard module is injected elsewhere; this file only references it.
        assert!(!instrumented.contains("mod __async_profile_guard__"));
        assert!(instrumented.contains("crate::__async_profile_guard__::Guard::new"));
        assert_sectioned(&instrumented);
    }

    #[test]
    fn test_instrument_async_only_idempotent() {
        let source = r#"
async fn fetch_data() {
    let response = client.get().await;
    store(response).await;
}
"#;

        let once = instrument_async_only(source).unwrap().unwrap();
        let twice = instrument_async_only(&once).unwrap();
        assert!(twice.is_none());
    }

    #[test]
    fn test_inject_guard_module_idempotent() {
        let source = r#"
async fn action() {
    do_it().await;
}
"#;

        let once = inject_guard_module(source, 10).unwrap();
        let twice = inject_guard_module(&once, 10).unwrap();
        assert_eq!(once, twice);
    }

    #[test]
    fn test_skip_non_async_file() {
        let source = r#"
fn regular_function() {
    println!("No async here");
}

struct MyStruct {
    field: String,
}

impl MyStruct {
    fn new() -> Self {
        Self { field: String::new() }
    }
}
"#;

        let outcome = instrument_async_only(source).unwrap();
        assert!(outcome.is_none(), "Should skip files without async code");
    }

    #[test]
    fn test_detect_async_in_impl() {
        let source = r#"
struct Service;

impl Service {
    async fn handle_request(&self) {
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}
"#;

        let outcome = instrument_async_only(source).unwrap();
        assert!(
            outcome.is_some(),
            "Should detect async methods in impl blocks"
        );
    }

    #[test]
    fn test_detect_async_block() {
        let source = r#"
fn spawn_task() {
    tokio::spawn(async {
        println!("In async block");
    });
}
"#;

        let outcome = instrument_async_only(source).unwrap();
        assert!(outcome.is_some(), "Should detect async blocks");
    }

    #[test]
    fn test_instrument_code() {
        let result = instrument_code(FETCH_DATA).unwrap();

        assert!(result.contains("__guard"));
        assert_sectioned(&result);
    }

    #[test]
    fn test_no_instrument_sync() {
        let source = r#"
fn sync_function() {
    let x = 42;
    println!("{}", x);
}
"#;

        let result = instrument_code(source).unwrap();
        assert!(!result.contains("__guard"));
    }
}