use crate::analyzers::purity_detector::PurityAnalysis;
use crate::data_flow::{DataFlowGraph, MutationInfo};
use crate::priority::call_graph::FunctionId;
mod ast_helpers {
    //! Helpers for locating a function definition inside a parsed `syn` AST.

    use syn::{ImplItemFn, ItemFn};

    /// A function definition found in a parsed file.
    ///
    /// Only holds a reference into the AST, so it is cheap to copy.
    #[derive(Clone, Copy)]
    pub enum FoundFunction<'a> {
        /// A free-standing `fn` item, either at the file root or inside an
        /// inline `mod { ... }` block.
        TopLevel(&'a ItemFn),
        /// A `fn` declared inside an `impl` block.
        ImplMethod(&'a ImplItemFn),
    }

    impl<'a> FoundFunction<'a> {
        /// Returns the function's body.
        ///
        /// The reference borrows from the AST (`'a`) rather than from `self`,
        /// so it stays usable after this `FoundFunction` value is dropped.
        pub fn block(&self) -> &'a syn::Block {
            match *self {
                FoundFunction::TopLevel(f) => &f.block,
                FoundFunction::ImplMethod(m) => &m.block,
            }
        }

        /// Returns an iterator over the function's declared parameters, in
        /// order, including a `self` receiver when present.
        pub fn inputs(&self) -> impl Iterator<Item = &'a syn::FnArg> {
            match *self {
                FoundFunction::TopLevel(f) => f.sig.inputs.iter(),
                FoundFunction::ImplMethod(m) => m.sig.inputs.iter(),
            }
        }
    }

    /// Locates the function named `metric_name` whose identifier starts on
    /// `line` (1-based, as reported by the identifier's span).
    ///
    /// Free functions match only on their exact name. `impl` methods match
    /// either their bare name (`method`) or any `::`-qualified name ending in
    /// it (e.g. `Type::method`). Inline `mod { ... }` blocks are searched
    /// recursively; out-of-line `mod foo;` declarations carry no items in this
    /// AST and are skipped. Returns the first match in source order, or `None`.
    pub fn find_function_in_ast<'a>(
        ast: &'a syn::File,
        metric_name: &str,
        line: usize,
    ) -> Option<FoundFunction<'a>> {
        find_in_items(&ast.items, metric_name, line)
    }

    /// Searches `items` in order, descending into inline module bodies.
    fn find_in_items<'a>(
        items: &'a [syn::Item],
        metric_name: &str,
        line: usize,
    ) -> Option<FoundFunction<'a>> {
        items.iter().find_map(|item| match item {
            // Compare the line first: it is cheap and rules out almost every item.
            syn::Item::Fn(item_fn)
                if ident_line(&item_fn.sig.ident) == line
                    && item_fn.sig.ident == metric_name =>
            {
                Some(FoundFunction::TopLevel(item_fn))
            }
            syn::Item::Impl(item_impl) => find_in_impl(item_impl, metric_name, line),
            syn::Item::Mod(item_mod) => item_mod
                .content
                .as_ref()
                .and_then(|(_, nested)| find_in_items(nested, metric_name, line)),
            _ => None,
        })
    }

    /// Searches the methods of a single `impl` block.
    fn find_in_impl<'a>(
        item_impl: &'a syn::ItemImpl,
        metric_name: &str,
        line: usize,
    ) -> Option<FoundFunction<'a>> {
        item_impl.items.iter().find_map(|impl_item| match impl_item {
            syn::ImplItem::Fn(method)
                if ident_line(&method.sig.ident) == line
                    && method_name_matches(metric_name, &method.sig.ident.to_string()) =>
            {
                Some(FoundFunction::ImplMethod(method))
            }
            _ => None,
        })
    }

    /// True when `metric_name` is `method_name` itself or ends in
    /// `::method_name`. This is equivalent to
    /// `ends_with(&format!("::{method_name}"))` but does not allocate.
    fn method_name_matches(metric_name: &str, method_name: &str) -> bool {
        metric_name == method_name
            || matches!(
                metric_name.strip_suffix(method_name),
                Some(prefix) if prefix.ends_with("::")
            )
    }

    /// Line on which `ident` starts, taken from its span.
    fn ident_line(ident: &syn::Ident) -> usize {
        ident.span().start().line
    }
}
pub use ast_helpers::{find_function_in_ast, FoundFunction};
/// Records the results of `purity` in `data_flow` under `func_id`.
///
/// If `purity.data_flow_info` is present, it is stored twice. It is stored
/// once as the plain CFG analysis, and once wrapped together with
/// `purity.var_names` in a `CfgAnalysisWithContext`.
///
/// Mutation info is always stored. `detected_mutations` lists the `target` of
/// every entry in `purity.live_mutations`. `has_mutations` is true when that
/// list is non-empty or `purity.total_mutations` is greater than zero.
pub fn populate_from_purity_analysis(
    data_flow: &mut DataFlowGraph,
    func_id: &FunctionId,
    purity: &PurityAnalysis,
) {
    use crate::data_flow::CfgAnalysisWithContext;

    if let Some(analysis) = purity.data_flow_info.as_ref() {
        data_flow.set_cfg_analysis(func_id.clone(), analysis.clone());
        let with_context =
            CfgAnalysisWithContext::new(purity.var_names.clone(), analysis.clone());
        data_flow.set_cfg_analysis_with_context(func_id.clone(), with_context);
    }

    let targets: Vec<String> = purity
        .live_mutations
        .iter()
        .map(|mutation| mutation.target.clone())
        .collect();
    // Mutations that are not live still count towards `has_mutations`.
    let has_mutations = purity.total_mutations > 0 || !targets.is_empty();

    data_flow.set_mutation_info(
        func_id.clone(),
        MutationInfo {
            has_mutations,
            detected_mutations: targets,
        },
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::data_flow::{DataFlowAnalysis, ReachingDefinitions};
    use crate::core::PurityLevel;
    use std::path::PathBuf;

    fn create_test_function_id(name: &str) -> FunctionId {
        FunctionId::new(PathBuf::from("test.rs"), name.to_string(), 1)
    }

    fn create_test_purity_analysis() -> PurityAnalysis {
        use crate::analyzers::purity_detector::LocalMutation;
        let data_flow = DataFlowAnalysis {
            reaching_defs: ReachingDefinitions::default(),
        };
        PurityAnalysis {
            is_pure: false,
            purity_level: PurityLevel::LocallyPure,
            reasons: vec![],
            confidence: 0.9,
            data_flow_info: Some(data_flow),
            live_mutations: vec![LocalMutation {
                target: "x".to_string(),
            }],
            total_mutations: 2,
            var_names: vec!["x".to_string(), "y".to_string()],
        }
    }

    /// Returns the 1-based number of the first line of `code` containing
    /// `needle`.
    ///
    /// Deriving line numbers from the snippet avoids hard-coding them, which
    /// silently breaks whenever a snippet's whitespace changes.
    fn line_of(code: &str, needle: &str) -> usize {
        code.lines()
            .position(|text| text.contains(needle))
            .map(|idx| idx + 1)
            .unwrap_or_else(|| panic!("`{needle}` not found in test snippet"))
    }

    #[test]
    fn test_populate_from_purity_analysis() {
        let mut data_flow = DataFlowGraph::new();
        let func_id = create_test_function_id("test_func");
        let purity = create_test_purity_analysis();
        populate_from_purity_analysis(&mut data_flow, &func_id, &purity);
        assert!(data_flow.get_cfg_analysis(&func_id).is_some());
        let mutation_info = data_flow.get_mutation_info(&func_id).unwrap();
        assert!(mutation_info.has_mutations);
        assert_eq!(mutation_info.detected_mutations, vec!["x".to_string()]);
    }

    #[test]
    fn test_populate_counts_non_live_mutations() {
        let mut data_flow = DataFlowGraph::new();
        let func_id = create_test_function_id("dead_writes");
        let mut purity = create_test_purity_analysis();
        purity.live_mutations.clear();
        purity.data_flow_info = None;
        populate_from_purity_analysis(&mut data_flow, &func_id, &purity);
        assert!(
            data_flow.get_cfg_analysis(&func_id).is_none(),
            "No CFG analysis should be stored without data_flow_info"
        );
        let mutation_info = data_flow.get_mutation_info(&func_id).unwrap();
        assert!(
            mutation_info.has_mutations,
            "total_mutations > 0 alone should mark the function as mutating"
        );
        assert!(mutation_info.detected_mutations.is_empty());
    }

    #[test]
    fn test_populate_without_mutations() {
        let mut data_flow = DataFlowGraph::new();
        let func_id = create_test_function_id("pure_func");
        let mut purity = create_test_purity_analysis();
        purity.live_mutations.clear();
        purity.total_mutations = 0;
        populate_from_purity_analysis(&mut data_flow, &func_id, &purity);
        let mutation_info = data_flow.get_mutation_info(&func_id).unwrap();
        assert!(!mutation_info.has_mutations);
        assert!(mutation_info.detected_mutations.is_empty());
    }

    #[test]
    fn test_find_function_in_ast_top_level() {
        let code = r#"
fn top_level_func(x: i32) -> i32 {
    x + 1
}
"#;
        let ast = syn::parse_file(code).unwrap();
        let fn_line = line_of(code, "fn top_level_func");
        let found = find_function_in_ast(&ast, "top_level_func", fn_line);
        assert!(
            matches!(found, Some(FoundFunction::TopLevel(_))),
            "Should find top-level function"
        );
        let not_found = find_function_in_ast(&ast, "top_level_func", fn_line + 1);
        assert!(not_found.is_none(), "Should not find at wrong line");
        let wrong_name = find_function_in_ast(&ast, "other_func", fn_line);
        assert!(wrong_name.is_none(), "Should not find under a different name");
    }

    #[test]
    fn test_find_function_in_ast_impl_method_simple_name() {
        let code = r#"
struct Foo;

impl Foo {
    fn method(&self) -> i32 {
        42
    }
}
"#;
        let ast = syn::parse_file(code).unwrap();
        let found = find_function_in_ast(&ast, "method", line_of(code, "fn method"));
        assert!(
            matches!(found, Some(FoundFunction::ImplMethod(_))),
            "Should find impl method by simple name as ImplMethod"
        );
    }

    #[test]
    fn test_find_function_in_ast_impl_method_qualified_name() {
        let code = r#"
struct Bar;

impl Bar {
    fn do_something(&self, x: i32) -> i32 {
        x * 2
    }
}
"#;
        let ast = syn::parse_file(code).unwrap();
        let method_line = line_of(code, "fn do_something");
        let found = find_function_in_ast(&ast, "Bar::do_something", method_line);
        assert!(
            matches!(found, Some(FoundFunction::ImplMethod(_))),
            "Should find impl method by qualified name as ImplMethod"
        );
        // The qualifier must be separated by `::`, not just be any prefix.
        let not_found = find_function_in_ast(&ast, "Bardo_something", method_line);
        assert!(not_found.is_none(), "Prefix without `::` must not match");
    }

    #[test]
    fn test_found_function_block() {
        let code = r#"
fn top_level() {
    let x = 1;
}

struct Foo;

impl Foo {
    fn method(&self) {
        let y = 2;
    }
}
"#;
        let ast = syn::parse_file(code).unwrap();
        let top_level =
            find_function_in_ast(&ast, "top_level", line_of(code, "fn top_level")).unwrap();
        assert!(
            !top_level.block().stmts.is_empty(),
            "Top-level block should have statements"
        );
        let method = find_function_in_ast(&ast, "method", line_of(code, "fn method")).unwrap();
        assert!(
            !method.block().stmts.is_empty(),
            "Impl method block should have statements"
        );
    }

    #[test]
    fn test_found_function_inputs() {
        let code = r#"
fn top_level(a: i32, b: String) {}

struct Foo;

impl Foo {
    fn method(&self, x: i32, y: bool) {}
}
"#;
        let ast = syn::parse_file(code).unwrap();
        let top_level =
            find_function_in_ast(&ast, "top_level", line_of(code, "fn top_level")).unwrap();
        assert_eq!(top_level.inputs().count(), 2, "Top-level should have 2 inputs");
        let method = find_function_in_ast(&ast, "method", line_of(code, "fn method")).unwrap();
        assert_eq!(
            method.inputs().count(),
            3,
            "Impl method should have 3 inputs (self + 2 params)"
        );
    }
}