pub mod transformer;
#[cfg(test)]
mod loop_test;
use anyhow::Result;
use std::fs;
use std::path::Path;
use syn::{ExprAsync, File, ImplItemFn, Item, ItemFn, parse_file, parse_quote, visit::Visit};
use transformer::AsyncInstrumenter;
/// Syntax visitor that records whether a file contains async code: an `async fn`
/// (free function or a method inside an `impl` block) or an `async { ... }` block.
#[derive(Default)]
struct AsyncDetector {
    /// Becomes `true` once any async construct has been visited.
    has_async: bool,
}

impl AsyncDetector {
    /// Creates a detector that has not seen any async code yet.
    fn new() -> Self {
        Self::default()
    }

    /// Walks all of `file` and reports whether any async construct was found.
    fn check(file: &File) -> bool {
        let mut visitor = Self::new();
        visitor.visit_file(file);
        visitor.has_async
    }
}
impl<'ast> Visit<'ast> for AsyncDetector {
    /// An `async { ... }` block marks the file as async. Its body is not walked,
    /// because the answer is already known.
    fn visit_expr_async(&mut self, _: &'ast ExprAsync) {
        self.has_async = true;
    }

    /// An `async fn` inside an `impl` block marks the file as async. A synchronous
    /// method's body is searched for nested async constructs instead.
    fn visit_impl_item_fn(&mut self, func: &'ast ImplItemFn) {
        match func.sig.asyncness {
            Some(_) => self.has_async = true,
            None => syn::visit::visit_impl_item_fn(self, func),
        }
    }

    /// A free `async fn` marks the file as async. A synchronous function's body is
    /// searched for nested async constructs instead.
    fn visit_item_fn(&mut self, func: &'ast ItemFn) {
        match func.sig.asyncness {
            Some(_) => self.has_async = true,
            None => syn::visit::visit_item_fn(self, func),
        }
    }
}
/// Syntax visitor that detects whether a file already carries instrumentation, so
/// the instrumenters can skip it instead of instrumenting it twice.
#[derive(Default)]
struct InstrumentationDetector {
    /// Set when a `__async_profile_guard__` module or a
    /// `__async_profile_guard__::Guard::new` path is encountered.
    already_instrumented: bool,
}

impl InstrumentationDetector {
    /// Walks `file` and reports whether it was already instrumented.
    fn check(file: &File) -> bool {
        let mut visitor = InstrumentationDetector {
            already_instrumented: false,
        };
        visitor.visit_file(file);
        visitor.already_instrumented
    }
}
impl<'ast> Visit<'ast> for InstrumentationDetector {
    /// A module named `__async_profile_guard__` means the guard module was already
    /// injected. Any other module is searched recursively.
    fn visit_item_mod(&mut self, module: &'ast syn::ItemMod) {
        if module.ident != "__async_profile_guard__" {
            syn::visit::visit_item_mod(self, module);
            return;
        }
        self.already_instrumented = true;
    }

    /// A path containing the consecutive segments
    /// `__async_profile_guard__::Guard::new` means the code already calls the guard.
    /// Any leading prefix such as `crate::` is allowed.
    fn visit_path(&mut self, path: &'ast syn::Path) {
        // Nothing more to learn once instrumentation has been found.
        if self.already_instrumented {
            return;
        }
        let names: Vec<String> = path
            .segments
            .iter()
            .map(|segment| segment.ident.to_string())
            .collect();
        let calls_guard_new = names.windows(3).any(|window| {
            matches!(
                window,
                [module, ty, func]
                    if module == "__async_profile_guard__" && ty == "Guard" && func == "new"
            )
        });
        if calls_guard_new {
            self.already_instrumented = true;
        } else {
            syn::visit::visit_path(self, path);
        }
    }
}
/// Builds the `__async_profile_guard__` module that instrumented code calls into.
///
/// `threshold_ms` is baked into the generated module as its `THRESHOLD_MS`
/// constant. Any timed section that takes strictly longer than that is reported
/// through `tracing::warn!`, so the crate receiving this module must have the
/// `tracing` crate available.
///
/// Only `//` comments are used inside the `parse_quote!` template below. They are
/// discarded when the template is tokenized. A `///` comment would instead become a
/// `#[doc]` attribute and end up in the generated code.
fn generate_guard_module(threshold_ms: u64) -> Item {
    parse_quote! {
        #[doc(hidden)]
        #[allow(dead_code)]
        mod __async_profile_guard__ {
            use std::time::{Duration, Instant};
            // Reporting threshold, interpolated from `threshold_ms` at generation time.
            const THRESHOLD_MS: u64 = #threshold_ms;
            // Times consecutive source sections and warns about sections that
            // exceed `THRESHOLD_MS`.
            pub struct Guard {
                // Label given to `Guard::new`; included in every warning.
                name: &'static str,
                // Source file name used in the warning's `span` field.
                file: &'static str,
                // Line at which the currently timed section started.
                from_line: u32,
                // Start time of the open section. `None` after `end_section` until
                // the next `start_section` or `checkpoint`.
                current_start: Option<Instant>,
                // Number of slow sections in a row; reset to 0 by a fast section.
                consecutive_hits: u32,
            }
            impl Guard {
                // Creates a guard whose first section starts now, at `line`.
                pub fn new(name: &'static str, file: &'static str, line: u32) -> Self {
                    Guard {
                        name,
                        file,
                        from_line: line,
                        current_start: Some(Instant::now()),
                        consecutive_hits: 0,
                    }
                }
                // Closes the open section at `new_line`, warning if it was slow, and
                // immediately opens the next section at `new_line`. With no open
                // section, it only opens the new one.
                pub fn checkpoint(&mut self, new_line: u32) {
                    if let Some(start) = self.current_start.take() {
                        let elapsed = start.elapsed();
                        if elapsed > Duration::from_millis(THRESHOLD_MS) {
                            self.consecutive_hits = self.consecutive_hits.saturating_add(1);
                            let span = format!("{file}:{from}-{to}", file = self.file, from = self.from_line, to = new_line);
                            // True when the section ended on an earlier line than it
                            // began, presumably because control jumped backwards
                            // (e.g. a loop back-edge).
                            let wraparound = new_line < self.from_line;
                            // The two branches differ only in the message text.
                            if wraparound {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll (iteration tail wraparound)"
                                );
                            } else {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll (iteration tail)"
                                );
                            }
                        } else {
                            self.consecutive_hits = 0;
                        }
                    }
                    self.from_line = new_line;
                    self.current_start = Some(Instant::now());
                }
                // Closes the open section at `to_line`, warning if it was slow. No
                // section is open afterwards until `start_section` (or `checkpoint`)
                // is called.
                pub fn end_section(&mut self, to_line: u32) {
                    if let Some(start) = self.current_start.take() {
                        let elapsed = start.elapsed();
                        if elapsed > Duration::from_millis(THRESHOLD_MS) {
                            self.consecutive_hits = self.consecutive_hits.saturating_add(1);
                            let span = format!("{file}:{from}-{to}", file = self.file, from = self.from_line, to = to_line);
                            // Same backwards-jump detection as in `checkpoint`.
                            let wraparound = to_line < self.from_line;
                            if wraparound {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll (loop wraparound)"
                                );
                            } else {
                                tracing::warn!(
                                    elapsed_ms = elapsed.as_millis(),
                                    name = %self.name,
                                    span = %span,
                                    hits = self.consecutive_hits,
                                    wraparound = wraparound,
                                    "long poll"
                                );
                            }
                        } else {
                            self.consecutive_hits = 0;
                        }
                    }
                }
                // Opens a new timed section starting at `new_line`.
                pub fn start_section(&mut self, new_line: u32) {
                    self.from_line = new_line;
                    self.current_start = Some(Instant::now());
                }
            }
            // Reports a section that is still open when the guard is dropped. The
            // end line is unknown here, so the span is `from_line-from_line`.
            impl Drop for Guard {
                fn drop(&mut self) {
                    if let Some(start) = self.current_start {
                        let elapsed = start.elapsed();
                        if elapsed > Duration::from_millis(THRESHOLD_MS) {
                            self.consecutive_hits = self.consecutive_hits.saturating_add(1);
                            let span =
                                format!("{file}:{line}-{line}", file = self.file, line = self.from_line);
                            tracing::warn!(
                                elapsed_ms = elapsed.as_millis(),
                                name = %self.name,
                                span = %span,
                                hits = self.consecutive_hits,
                                wraparound = false,
                                "long poll"
                            );
                        }
                    }
                }
            }
        }
    }
}
/// Instruments `source` and prepends the `__async_profile_guard__` runtime module.
///
/// The guard module (built with `threshold_ms`) becomes the first item of the file,
/// whether or not the file contains async code. The file is then passed through
/// `AsyncInstrumenter` and pretty-printed.
///
/// If the file already contains a guard module or a
/// `__async_profile_guard__::Guard::new` path, `source` is returned unchanged. This
/// makes repeated runs a no-op.
///
/// # Errors
///
/// Returns an error if `source` does not parse as a Rust file.
pub fn inject_guard_module(source: &str, threshold_ms: u64) -> Result<String> {
    let mut file = parse_file(source)?;
    let already_done = InstrumentationDetector::check(&file);
    if already_done {
        return Ok(String::from(source));
    }
    file.items.insert(0, generate_guard_module(threshold_ms));
    let mut instrumenter = AsyncInstrumenter::new(threshold_ms);
    instrumenter.instrument_file(&mut file);
    Ok(prettyplease::unparse(&file))
}
/// Instruments the async code in `source` without injecting the guard module.
///
/// The instrumented calls refer to a guard module that must be provided elsewhere
/// in the crate. The tests expect paths of the form
/// `crate::__async_profile_guard__::Guard::new`; see `inject_guard_module`.
///
/// Returns `Ok(None)` when there is nothing to do: the file is already
/// instrumented, or it contains no async code.
///
/// # Errors
///
/// Returns an error if `source` does not parse as a Rust file.
pub fn instrument_async_only(source: &str) -> Result<Option<String>> {
    let mut file = parse_file(source)?;
    // `||` short-circuits, so the async scan only runs on un-instrumented files.
    if InstrumentationDetector::check(&file) || !AsyncDetector::check(&file) {
        return Ok(None);
    }
    // Same default threshold as `instrument_code`.
    let mut instrumenter = AsyncInstrumenter::new(10);
    instrumenter.instrument_file(&mut file);
    Ok(Some(prettyplease::unparse(&file)))
}
/// Instruments `source` with the default reporting threshold of 10 ms.
///
/// # Errors
///
/// Returns an error if `source` does not parse as a Rust file.
pub fn instrument_code(source: &str) -> Result<String> {
    const DEFAULT_THRESHOLD_MS: u64 = 10;
    instrument_code_with_threshold(source, DEFAULT_THRESHOLD_MS)
}
/// Instruments `source` and injects the guard module.
///
/// Sections slower than `threshold_ms` milliseconds are reported at runtime.
///
/// # Errors
///
/// Returns an error if `source` does not parse as a Rust file.
pub fn instrument_code_with_threshold(source: &str, threshold_ms: u64) -> Result<String> {
    let instrumented = inject_guard_module(source, threshold_ms)?;
    Ok(instrumented)
}
pub fn instrument_file(path: &Path) -> Result<String> {
let content = fs::read_to_string(path)?;
instrument_code(&content)
}
pub fn instrument_file_with_threshold(path: &Path, threshold_ms: u64) -> Result<String> {
let content = fs::read_to_string(path)?;
instrument_code_with_threshold(&content, threshold_ms)
}
pub fn instrument_file_in_place(path: &Path) -> Result<()> {
let instrumented = instrument_file(path)?;
let backup_path = path.with_extension("rs.bak");
fs::copy(path, &backup_path)?;
fs::write(path, instrumented)?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Async fixture: two awaits around a synchronous statement.
    const FETCH_DATA: &str = r#"
async fn fetch_data() {
    let response = client.get().await;
    let parsed = parse(response);
    store(parsed).await;
}
"#;

    /// Checks that section-based guard calls were emitted, and that they carry
    /// literal line numbers rather than `line!()` invocations.
    fn assert_section_calls(output: &str) {
        assert!(output.contains("__guard.end_section("));
        assert!(output.contains("__guard.start_section("));
        assert!(!output.contains("line!()"));
    }

    #[test]
    fn test_inject_guard_module() {
        let output = inject_guard_module(FETCH_DATA, 10).unwrap();
        assert!(output.contains("mod __async_profile_guard__"));
        assert!(output.contains("__guard"));
        assert_section_calls(&output);
    }

    #[test]
    fn test_instrument_async_only() {
        let instrumented = instrument_async_only(FETCH_DATA)
            .unwrap()
            .expect("Should instrument async functions");
        assert!(!instrumented.contains("mod __async_profile_guard__"));
        assert!(instrumented.contains("crate::__async_profile_guard__::Guard::new"));
        assert_section_calls(&instrumented);
    }

    #[test]
    fn test_instrument_async_only_idempotent() {
        let source = r#"
async fn fetch_data() {
    let response = client.get().await;
    store(response).await;
}
"#;
        let once = instrument_async_only(source).unwrap().unwrap();
        let twice = instrument_async_only(&once).unwrap();
        assert!(twice.is_none());
    }

    #[test]
    fn test_inject_guard_module_idempotent() {
        let source = r#"
async fn action() {
    do_it().await;
}
"#;
        let once = inject_guard_module(source, 10).unwrap();
        let twice = inject_guard_module(&once, 10).unwrap();
        assert_eq!(once, twice);
    }

    #[test]
    fn test_skip_non_async_file() {
        let source = r#"
fn regular_function() {
    println!("No async here");
}
struct MyStruct {
    field: String,
}
impl MyStruct {
    fn new() -> Self {
        Self { field: String::new() }
    }
}
"#;
        let outcome = instrument_async_only(source).unwrap();
        assert!(outcome.is_none(), "Should skip files without async code");
    }

    #[test]
    fn test_detect_async_in_impl() {
        let source = r#"
struct Service;
impl Service {
    async fn handle_request(&self) {
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}
"#;
        let outcome = instrument_async_only(source).unwrap();
        assert!(
            outcome.is_some(),
            "Should detect async methods in impl blocks"
        );
    }

    #[test]
    fn test_detect_async_block() {
        let source = r#"
fn spawn_task() {
    tokio::spawn(async {
        println!("In async block");
    });
}
"#;
        let outcome = instrument_async_only(source).unwrap();
        assert!(outcome.is_some(), "Should detect async blocks");
    }

    #[test]
    fn test_instrument_code() {
        let output = instrument_code(FETCH_DATA).unwrap();
        assert!(output.contains("__guard"));
        assert_section_calls(&output);
    }

    #[test]
    fn test_no_instrument_sync() {
        let source = r#"
fn sync_function() {
    let x = 42;
    println!("{}", x);
}
"#;
        let output = instrument_code(source).unwrap();
        assert!(!output.contains("__guard"));
    }
}