use std::collections::HashSet;
use quote::quote;
use syn::visit_mut::VisitMut;
/// Parses `source` as a Rust file and inserts a `piano_runtime::enter` guard at
/// the top of every function whose name is in `targets`.
///
/// Names are matched bare for free functions and as `Type::method` /
/// `Trait::method` inside `impl` / `trait` blocks. The rewritten file is
/// pretty-printed with `prettyplease`.
///
/// # Errors
/// Returns the `syn::Error` produced when `source` does not parse as a file.
pub fn instrument_source(source: &str, targets: &HashSet<String>) -> Result<String, syn::Error> {
    let mut ast = syn::parse_str::<syn::File>(source)?;
    let mut visitor = Instrumenter {
        current_trait: None,
        current_impl: None,
        targets: targets.clone(),
    };
    visitor.visit_file_mut(&mut ast);
    let rendered = prettyplease::unparse(&ast);
    Ok(rendered)
}
/// Method names that start a parallel-iterator chain. Closures passed to the
/// methods that follow such a call in the same chain are treated as running on
/// other threads.
const PARALLEL_ITER_METHODS: &[&str] = &[
    "par_iter",
    "par_iter_mut",
    "into_par_iter",
    "par_bridge",
    "par_chunks",
    "par_chunks_mut",
    "par_windows",
];

/// Names compared against a method name or against the last path segment of a
/// function call to recognise calls that start concurrent work.
const SPAWN_FUNCTIONS: &[&str] = &[
    "spawn",
    "scope",
    "scope_fifo",
    "join",
];
/// Visitor state used by `instrument_source`.
struct Instrumenter {
    /// Function names (bare, or `Scope::method`) that should be instrumented.
    targets: HashSet<String>,
    /// Self-type name of the `impl` block currently being visited, if any.
    current_impl: Option<String>,
    /// Name of the `trait` currently being visited, if any.
    current_trait: Option<String>,
}

impl Instrumenter {
    /// Instruments `block` as the body of the function called `name`.
    ///
    /// Does nothing unless `name` is a target. For a target, the body gets
    /// `let _piano_guard = piano_runtime::enter(name);` as its first statement.
    /// When a top-level statement contains a concurrency call, the body also
    /// gets `let _piano_ctx = piano_runtime::fork();` right after the guard,
    /// and the closures in the original statements that are handed to
    /// parallel or spawn calls are rewritten to adopt that context.
    fn inject_guard(&self, block: &mut syn::Block, name: &str) {
        if !self.targets.contains(name) {
            return;
        }

        // Decide on forking from the user's statements alone. The prologue
        // built below contains no spawn or parallel-iterator calls, so scanning
        // before inserting it gives the same answer.
        let needs_fork = block_contains_concurrency(block);
        if needs_fork {
            for stmt in block.stmts.iter_mut() {
                let target = match stmt {
                    syn::Stmt::Expr(expr, _) => Some(expr),
                    syn::Stmt::Local(local) => local.init.as_mut().map(|init| &mut *init.expr),
                    _ => None,
                };
                if let Some(expr) = target {
                    inject_adopt_in_concurrency_closures(expr, false);
                }
            }
        }

        let mut prologue: Vec<syn::Stmt> = vec![syn::parse_quote! {
            let _piano_guard = piano_runtime::enter(#name);
        }];
        if needs_fork {
            prologue.push(syn::parse_quote! {
                let _piano_ctx = piano_runtime::fork();
            });
        }
        prologue.append(&mut block.stmts);
        block.stmts = prologue;
    }
}
fn block_contains_concurrency(block: &syn::Block) -> bool {
block.stmts.iter().any(stmt_contains_concurrency)
}
fn stmt_contains_concurrency(stmt: &syn::Stmt) -> bool {
match stmt {
syn::Stmt::Expr(e, _) => contains_concurrency_call(e),
syn::Stmt::Local(local) => local
.init
.as_ref()
.is_some_and(|init| contains_concurrency_call(&init.expr)),
_ => false,
}
}
/// Returns `true` if `expr` contains a call that starts concurrent work. The
/// search covers method-call receivers, call arguments, plain `{ .. }` blocks
/// and closure bodies. A match is either:
///
/// * a method call whose name is in `PARALLEL_ITER_METHODS`, or
/// * a method call, or a path call whose last segment is in `SPAWN_FUNCTIONS`,
///   that receives at least one closure-literal argument.
///
/// The closure requirement for spawn-like names keeps everyday calls that share
/// those names from triggering a `fork()`: `parts.join(", ")`,
/// `path.join("x")`, `handle.join()` and `Command::spawn()`. Adopt statements
/// are only ever injected into closure literals, so such calls could never
/// receive one anyway.
fn contains_concurrency_call(expr: &syn::Expr) -> bool {
    match expr {
        syn::Expr::MethodCall(mc) => {
            let method = mc.method.to_string();
            PARALLEL_ITER_METHODS.contains(&method.as_str())
                || (SPAWN_FUNCTIONS.contains(&method.as_str())
                    && has_closure_arg(mc.args.iter()))
                || contains_concurrency_call(&mc.receiver)
                || mc.args.iter().any(contains_concurrency_call)
        }
        syn::Expr::Call(call) => {
            let spawn_like = match &*call.func {
                syn::Expr::Path(path) => path
                    .path
                    .segments
                    .last()
                    .is_some_and(|seg| SPAWN_FUNCTIONS.contains(&seg.ident.to_string().as_str())),
                _ => false,
            };
            (spawn_like && has_closure_arg(call.args.iter()))
                || call.args.iter().any(contains_concurrency_call)
        }
        syn::Expr::Block(b) => b.block.stmts.iter().any(stmt_contains_concurrency),
        syn::Expr::Closure(c) => contains_concurrency_call(&c.body),
        _ => false,
    }
}

/// Returns `true` if any of `args` is written directly as a closure literal.
fn has_closure_arg<'a>(mut args: impl Iterator<Item = &'a syn::Expr>) -> bool {
    args.any(|arg| matches!(arg, syn::Expr::Closure(_)))
}
/// Inserts `let _piano_adopt = _piano_ctx.as_ref().map(..);` at the start of
/// closure literals in `expr` that are handed to concurrent code.
///
/// A closure-literal argument gets the adopt statement when it is passed to:
/// * a method whose name is in `SPAWN_FUNCTIONS` (e.g. `s.spawn(|_| ..)`),
/// * a path call whose last segment is in `SPAWN_FUNCTIONS`
///   (e.g. `std::thread::spawn(|| ..)`, `rayon::scope(|s| ..)`), or
/// * a method that is not itself in `PARALLEL_ITER_METHODS` but either follows
///   one in the same receiver chain (e.g. the `map` in
///   `items.par_iter().map(|x| ..)`) or is visited with `in_parallel_chain` set.
///
/// The walk descends into method receivers, call arguments and plain
/// `{ .. }` blocks. Inside blocks it visits both expression statements and
/// `let` initializers. It does not enter closure bodies, including those that
/// just received an adopt statement, or other expression kinds
/// (`if`, `match`, loops, ...).
fn inject_adopt_in_concurrency_closures(expr: &mut syn::Expr, in_parallel_chain: bool) {
    match expr {
        syn::Expr::MethodCall(mc) => {
            let method = mc.method.to_string();
            let is_par = PARALLEL_ITER_METHODS.contains(&method.as_str());
            let is_spawn = SPAWN_FUNCTIONS.contains(&method.as_str());
            inject_adopt_in_concurrency_closures(&mut mc.receiver, in_parallel_chain);
            let chain_active =
                in_parallel_chain || is_par || receiver_has_parallel_method(&mc.receiver);
            // The parallel-iterator method itself (e.g. `par_chunks(n)`) is not
            // a closure-taking stage; its arguments are walked with the chain
            // flag instead.
            if (chain_active && !is_par) || is_spawn {
                for arg in &mut mc.args {
                    if let syn::Expr::Closure(closure) = arg {
                        inject_adopt_at_closure_start(closure);
                    } else {
                        inject_adopt_in_concurrency_closures(arg, false);
                    }
                }
            } else {
                for arg in &mut mc.args {
                    inject_adopt_in_concurrency_closures(arg, chain_active);
                }
            }
        }
        syn::Expr::Call(call) => {
            let is_spawn = if let syn::Expr::Path(path) = &*call.func {
                path.path
                    .segments
                    .last()
                    .map(|s| SPAWN_FUNCTIONS.contains(&s.ident.to_string().as_str()))
                    .unwrap_or(false)
            } else {
                false
            };
            if is_spawn {
                for arg in &mut call.args {
                    if let syn::Expr::Closure(closure) = arg {
                        inject_adopt_at_closure_start(closure);
                    } else {
                        inject_adopt_in_concurrency_closures(arg, false);
                    }
                }
            } else {
                for arg in &mut call.args {
                    inject_adopt_in_concurrency_closures(arg, false);
                }
            }
        }
        syn::Expr::Block(b) => {
            // Visit the same statement kinds that `stmt_contains_concurrency`
            // inspects. That way a parallel chain in a `let` initializer inside
            // a nested block is adopted just like one at the top level of the
            // function body, instead of being forked with no adopt.
            for stmt in &mut b.block.stmts {
                match stmt {
                    syn::Stmt::Expr(e, _) => inject_adopt_in_concurrency_closures(e, false),
                    syn::Stmt::Local(local) => {
                        if let Some(init) = &mut local.init {
                            inject_adopt_in_concurrency_closures(&mut init.expr, false);
                        }
                    }
                    _ => {}
                }
            }
        }
        _ => {}
    }
}
/// Returns `true` if some call in the method-call chain `expr` is a
/// parallel-iterator method. The chain includes `expr` itself and every call
/// reached by repeatedly following `.receiver`. The walk stops at the first
/// receiver that is not a method call.
fn receiver_has_parallel_method(expr: &syn::Expr) -> bool {
    let mut link = expr;
    while let syn::Expr::MethodCall(call) = link {
        let name = call.method.to_string();
        if PARALLEL_ITER_METHODS.contains(&name.as_str()) {
            return true;
        }
        link = &*call.receiver;
    }
    false
}
/// Makes `let _piano_adopt = _piano_ctx.as_ref().map(|c| piano_runtime::adopt(c));`
/// the first statement executed by `closure`.
///
/// A block body gets the statement prepended. Any other body expression `e` is
/// replaced by the block `{ <adopt>; e }`, so the closure's value is unchanged.
///
/// NOTE(review): the injected statement makes the closure capture
/// `_piano_ctx`. Non-`move` closures borrow it, and `move` closures take it by
/// value. For `'static` spawns such as `std::thread::spawn`, or for several
/// `move` closures in one function, this only compiles if the type returned by
/// `piano_runtime::fork()` allows it. Confirm against the runtime.
fn inject_adopt_at_closure_start(closure: &mut syn::ExprClosure) {
    let adopt_stmt: syn::Stmt = syn::parse_quote! {
        let _piano_adopt = _piano_ctx.as_ref().map(|c| piano_runtime::adopt(c));
    };

    if let syn::Expr::Block(body) = &mut *closure.body {
        body.block.stmts.insert(0, adopt_stmt);
        return;
    }

    let original_body = (*closure.body).clone();
    *closure.body = syn::parse_quote! {
        {
            #adopt_stmt
            #original_body
        }
    };
}
impl VisitMut for Instrumenter {
    /// Free functions (at any nesting depth) are looked up by their bare name.
    fn visit_item_fn_mut(&mut self, node: &mut syn::ItemFn) {
        let fn_name = node.sig.ident.to_string();
        self.inject_guard(&mut node.block, &fn_name);
        syn::visit_mut::visit_item_fn_mut(self, node);
    }

    /// Records the `impl` self-type name for the duration of the block, then
    /// restores the previous value so nested impls do not leak.
    fn visit_item_impl_mut(&mut self, node: &mut syn::ItemImpl) {
        let saved = std::mem::replace(&mut self.current_impl, Some(type_ident(&node.self_ty)));
        syn::visit_mut::visit_item_impl_mut(self, node);
        self.current_impl = saved;
    }

    /// Methods inside an `impl` are looked up as `Type::method`.
    fn visit_impl_item_fn_mut(&mut self, node: &mut syn::ImplItemFn) {
        let qualified = qualified_fn_name(self.current_impl.as_deref(), &node.sig.ident);
        self.inject_guard(&mut node.block, &qualified);
        syn::visit_mut::visit_impl_item_fn_mut(self, node);
    }

    /// Records the trait name for the duration of the trait body, then
    /// restores the previous value.
    fn visit_item_trait_mut(&mut self, node: &mut syn::ItemTrait) {
        let saved = std::mem::replace(&mut self.current_trait, Some(node.ident.to_string()));
        syn::visit_mut::visit_item_trait_mut(self, node);
        self.current_trait = saved;
    }

    /// Only trait methods with a default body can be instrumented. They are
    /// looked up as `Trait::method`.
    fn visit_trait_item_fn_mut(&mut self, node: &mut syn::TraitItemFn) {
        if let Some(body) = node.default.as_mut() {
            let qualified = qualified_fn_name(self.current_trait.as_deref(), &node.sig.ident);
            self.inject_guard(body, &qualified);
        }
        syn::visit_mut::visit_trait_item_fn_mut(self, node);
    }
}

/// Builds the target-lookup key for a method: `Scope::method` when an
/// `impl`/`trait` scope is active, otherwise just `method`.
fn qualified_fn_name(scope: Option<&str>, method: &syn::Ident) -> String {
    match scope {
        Some(scope) => format!("{scope}::{method}"),
        None => method.to_string(),
    }
}
/// Prepends one `piano_runtime::register(name);` statement per entry of
/// `names`, in the order given, to the body of every free function named
/// `main` in `source`.
///
/// # Errors
/// Returns the `syn::Error` produced when `source` does not parse as a file.
pub fn inject_registrations(source: &str, names: &[String]) -> Result<String, syn::Error> {
    let mut ast = syn::parse_str::<syn::File>(source)?;
    let mut visitor = RegistrationInjector {
        names: names.to_vec(),
    };
    visitor.visit_file_mut(&mut ast);
    Ok(prettyplease::unparse(&ast))
}

/// Visitor that prepends registration calls to `main`.
struct RegistrationInjector {
    /// Function names to register, in output order.
    names: Vec<String>,
}

impl VisitMut for RegistrationInjector {
    fn visit_item_fn_mut(&mut self, node: &mut syn::ItemFn) {
        if node.sig.ident == "main" {
            let mut prologue: Vec<syn::Stmt> = self
                .names
                .iter()
                .map(|name| {
                    syn::parse_quote! {
                        piano_runtime::register(#name);
                    }
                })
                .collect();
            prologue.append(&mut node.block.stmts);
            node.block.stmts = prologue;
        }
        syn::visit_mut::visit_item_fn_mut(self, node);
    }
}
/// Makes `piano_runtime::PianoAllocator` the global allocator of `source`.
///
/// * `None`: a new `#[global_allocator] static _PIANO_ALLOC` wrapping
///   `std::alloc::System` is inserted as the first item of the file.
/// * `Some(_)`: the first top-level `static` carrying `#[global_allocator]`
///   has its type and initializer wrapped in `PianoAllocator`. The given type
///   name is not consulted. If no such top-level static exists, the file is
///   re-printed unchanged.
///
/// # Errors
/// Returns the `syn::Error` produced when `source` does not parse as a file.
pub fn inject_global_allocator(
    source: &str,
    existing_allocator_type: Option<&str>,
) -> Result<String, syn::Error> {
    let mut file: syn::File = syn::parse_str(source)?;

    if existing_allocator_type.is_none() {
        let allocator: syn::Item = syn::parse_quote! {
            #[global_allocator]
            static _PIANO_ALLOC: piano_runtime::PianoAllocator<std::alloc::System>
                = piano_runtime::PianoAllocator::new(std::alloc::System);
        };
        file.items.insert(0, allocator);
    } else {
        let annotated = file
            .items
            .iter_mut()
            .filter_map(|item| match item {
                syn::Item::Static(static_item) => Some(static_item),
                _ => None,
            })
            .find(|static_item| {
                static_item
                    .attrs
                    .iter()
                    .any(|attr| attr.path().is_ident("global_allocator"))
            });

        if let Some(static_item) = annotated {
            let inner_ty = &static_item.ty;
            let wrapped_ty: syn::Type = syn::parse_quote! {
                piano_runtime::PianoAllocator<#inner_ty>
            };
            let inner_expr = &static_item.expr;
            let wrapped_expr: syn::Expr = syn::parse_quote! {
                piano_runtime::PianoAllocator::new(#inner_expr)
            };
            *static_item.ty = wrapped_ty;
            *static_item.expr = wrapped_expr;
        }
    }

    Ok(prettyplease::unparse(&file))
}
/// Rewrites every free function named `main` in `source` so that
/// `piano_runtime::shutdown()` runs after the original body has finished, and
/// then `main` returns the body's value.
///
/// # Errors
/// Returns the `syn::Error` produced when `source` does not parse as a file.
pub fn inject_shutdown(source: &str) -> Result<String, syn::Error> {
    let mut file: syn::File = syn::parse_str(source)?;
    let mut injector = ShutdownInjector;
    injector.visit_file_mut(&mut file);
    Ok(prettyplease::unparse(&file))
}

/// Visitor that appends a `piano_runtime::shutdown()` call to `main`.
///
/// For a non-`async` `main`, the original statements move into an
/// immediately-invoked closure, annotated with `main`'s return type when there
/// is one. An early `return` or a `?` in the body then only leaves that
/// closure, so `shutdown()` still runs before `main` returns the same value.
/// With a plain block, those paths would exit `main` directly and skip it.
///
/// An `async fn main` keeps the plain-block wrapping, because `.await` is not
/// allowed inside a non-async closure. There, an early `return` or `?` still
/// bypasses `shutdown()`. In every case, a panic or `std::process::exit`
/// bypasses it too.
struct ShutdownInjector;

impl VisitMut for ShutdownInjector {
    fn visit_item_fn_mut(&mut self, node: &mut syn::ItemFn) {
        if node.sig.ident == "main" {
            let existing_stmts = std::mem::take(&mut node.block.stmts);
            let is_async = node.sig.asyncness.is_some();
            let shutdown_stmt: syn::Stmt = syn::parse_quote! {
                piano_runtime::shutdown();
            };

            node.block.stmts = match &node.sig.output {
                syn::ReturnType::Type(_, ret_ty) => {
                    let body: syn::Expr = if is_async {
                        syn::parse_quote! {
                            {
                                #(#existing_stmts)*
                            }
                        }
                    } else {
                        // The explicit return type lets `?` inside the body
                        // convert errors exactly as it did in `main`.
                        syn::parse_quote! {
                            (|| -> #ret_ty {
                                #(#existing_stmts)*
                            })()
                        }
                    };
                    let combined: syn::Stmt = syn::parse_quote! {
                        let __piano_result = #body;
                    };
                    // Tail expression (no semicolon) so `main` returns the value.
                    let return_expr: syn::Stmt =
                        syn::Stmt::Expr(syn::parse_quote! { __piano_result }, None);
                    vec![combined, shutdown_stmt, return_expr]
                }
                syn::ReturnType::Default => {
                    let body: syn::Stmt = if is_async {
                        syn::parse_quote! {
                            {
                                #(#existing_stmts)*
                            }
                        }
                    } else {
                        syn::parse_quote! {
                            (|| {
                                #(#existing_stmts)*
                            })();
                        }
                    };
                    vec![body, shutdown_stmt]
                }
            };
        }
        syn::visit_mut::visit_item_fn_mut(self, node);
    }
}
/// Returns the short name used to qualify methods of an `impl` block.
///
/// For path types this is the last segment's identifier without generics
/// (`a::b::Foo<T>` gives `Foo`). Any other type, and a path with no segments,
/// is rendered from its `quote!` token stream.
fn type_ident(ty: &syn::Type) -> String {
    if let syn::Type::Path(type_path) = ty {
        if let Some(last) = type_path.path.segments.last() {
            return last.ident.to_string();
        }
    }
    quote!(#ty).to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `instrument_source` with the given target names and unwraps it.
    fn instrument(source: &str, names: &[&str]) -> String {
        let targets: HashSet<String> = names.iter().map(|name| (*name).to_owned()).collect();
        instrument_source(source, &targets).unwrap()
    }

    /// Returns the first top-level `fn main` of `file`, if any.
    fn top_level_main(file: &syn::File) -> Option<&syn::ItemFn> {
        file.items.iter().find_map(|item| match item {
            syn::Item::Fn(func) if func.sig.ident == "main" => Some(func),
            _ => None,
        })
    }

    #[test]
    fn instruments_top_level_function() {
        let source = r#"
fn walk() {
do_stuff();
}
fn other() {
do_other();
}
"#;
        let result = instrument(source, &["walk"]);
        assert!(
            result.contains("piano_runtime::enter(\"walk\")"),
            "walk should be instrumented"
        );
        assert!(
            !result.contains("piano_runtime::enter(\"other\")"),
            "other should not be instrumented",
        );
    }

    #[test]
    fn instruments_impl_method() {
        let source = r#"
struct Walker;
impl Walker {
fn walk(&self) {
self.step();
}
}
"#;
        let result = instrument(source, &["Walker::walk"]);
        assert!(
            result.contains("piano_runtime::enter(\"Walker::walk\")"),
            "Walker::walk should be instrumented. Got:\n{result}",
        );
    }

    #[test]
    fn preserves_function_signature_and_body() {
        let source = r#"
fn compute(x: i32, y: i32) -> i32 {
x + y
}
"#;
        let result = instrument(source, &["compute"]);
        assert!(
            result.contains("fn compute(x: i32, y: i32) -> i32"),
            "signature preserved"
        );
        assert!(result.contains("x + y"), "body preserved");
        assert!(
            result.contains("piano_runtime::enter(\"compute\")"),
            "guard injected"
        );
    }

    #[test]
    fn multiple_functions_instrumented() {
        let source = r#"
fn a() {}
fn b() {}
fn c() {}
"#;
        let result = instrument(source, &["a", "c"]);
        assert!(
            result.contains("piano_runtime::enter(\"a\")"),
            "a should be instrumented"
        );
        assert!(
            !result.contains("piano_runtime::enter(\"b\")"),
            "b should NOT be instrumented",
        );
        assert!(
            result.contains("piano_runtime::enter(\"c\")"),
            "c should be instrumented"
        );
    }

    #[test]
    fn injects_register_calls_in_main() {
        let source = r#"
fn main() {
do_stuff();
}
"#;
        let names = ["walk", "parse"].map(String::from);
        let result = inject_registrations(source, &names).unwrap();
        assert!(
            result.contains("piano_runtime::register(\"walk\")"),
            "Got:\n{result}"
        );
        assert!(
            result.contains("piano_runtime::register(\"parse\")"),
            "Got:\n{result}"
        );
    }

    #[test]
    fn injects_global_allocator() {
        let source = r#"
fn main() {
println!("hello");
}
"#;
        let result = inject_global_allocator(source, None).unwrap();
        assert!(
            result.contains("#[global_allocator]"),
            "should inject global_allocator attribute. Got:\n{result}"
        );
        assert!(
            result.contains("PianoAllocator"),
            "should use PianoAllocator. Got:\n{result}"
        );
        assert!(
            result.contains("std::alloc::System"),
            "should wrap System allocator. Got:\n{result}"
        );
    }

    #[test]
    fn wraps_existing_global_allocator() {
        let source = r#"
use std::alloc::System;
#[global_allocator]
static ALLOC: System = System;
fn main() {}
"#;
        let result = inject_global_allocator(source, Some("System")).unwrap();
        assert!(
            result.contains("PianoAllocator"),
            "should wrap existing allocator. Got:\n{result}"
        );
    }

    #[test]
    fn does_not_inject_init() {
        let source = r#"
fn main() {
println!("hello");
}
"#;
        let result = instrument_source(source, &HashSet::new()).unwrap();
        assert!(
            !result.contains("piano_runtime::init()"),
            "should NOT inject init (init is a no-op)"
        );
    }

    #[test]
    fn injects_shutdown_at_end_of_main() {
        let source = r#"
fn main() {
do_stuff();
}
"#;
        let result = inject_shutdown(source).unwrap();
        assert!(
            result.contains("piano_runtime::shutdown()"),
            "should inject shutdown. Got:\n{result}"
        );
        let shutdown_at = result.find("piano_runtime::shutdown()").unwrap();
        let work_at = result.find("do_stuff()").unwrap();
        assert!(
            shutdown_at > work_at,
            "shutdown should come after existing code"
        );
    }

    #[test]
    fn injects_shutdown_preserves_main_return_type() {
        let source = r#"
use std::process::ExitCode;
fn main() -> ExitCode {
do_stuff();
ExitCode::SUCCESS
}
"#;
        let result = inject_shutdown(source).unwrap();
        assert!(
            result.contains("piano_runtime::shutdown()"),
            "should inject shutdown. Got:\n{result}"
        );
        let parsed: syn::File = syn::parse_str(&result)
            .unwrap_or_else(|e| panic!("rewritten code should parse: {e}\n\n{result}"));
        let main_fn = top_level_main(&parsed).expect("should have main fn");
        let last = main_fn
            .block
            .stmts
            .last()
            .expect("main should have statements");
        assert!(
            matches!(last, syn::Stmt::Expr(_, None)),
            "last statement should be a tail expression (no semicolon) for the return value. Got:\n{result}"
        );
    }

    #[test]
    fn injects_fork_and_adopt_for_par_iter() {
        let source = r#"
fn process_all(items: &[Item]) -> Vec<Result> {
items.par_iter()
.map(|item| transform(item))
.collect()
}
"#;
        let result = instrument(source, &["process_all"]);
        assert!(
            result.contains("piano_runtime::enter(\"process_all\")"),
            "should have guard. Got:\n{result}"
        );
        assert!(
            result.contains("piano_runtime::fork()"),
            "should inject fork. Got:\n{result}"
        );
        assert!(
            result.contains("piano_runtime::adopt"),
            "should inject adopt in closure. Got:\n{result}"
        );
    }

    #[test]
    fn injects_fork_and_adopt_for_thread_spawn() {
        let source = r#"
fn do_work() {
std::thread::spawn(|| {
heavy_computation();
});
}
"#;
        let result = instrument(source, &["do_work"]);
        assert!(
            result.contains("piano_runtime::fork()"),
            "should inject fork. Got:\n{result}"
        );
        assert!(
            result.contains("piano_runtime::adopt"),
            "should inject adopt in spawn closure. Got:\n{result}"
        );
    }

    #[test]
    fn injects_adopt_in_rayon_scope_spawn() {
        let source = r#"
fn parallel_work() {
rayon::scope(|s| {
s.spawn(|_| { work_a(); });
s.spawn(|_| { work_b(); });
});
}
"#;
        let result = instrument(source, &["parallel_work"]);
        assert!(
            result.contains("piano_runtime::fork()"),
            "should inject fork. Got:\n{result}"
        );
        assert!(
            result.contains("piano_runtime::adopt"),
            "should inject adopt. Got:\n{result}"
        );
    }

    #[test]
    fn no_fork_inject_in_non_target_function() {
        let source = r#"
fn not_targeted() {
items.par_iter().map(|x| x).collect()
}
"#;
        let result = instrument_source(source, &HashSet::new()).unwrap();
        assert!(
            !result.contains("piano_runtime::fork()"),
            "should NOT inject fork in non-target function. Got:\n{result}"
        );
    }
}