use crate::context::{detect_function_role, FileType, FrameworkPattern, FunctionContext};
use syn::{
visit::Visit, Attribute, Block, Expr, ExprCall, ExprMethodCall, ImplItem, ItemFn, ItemImpl,
Path,
};
/// Collects a [`FunctionContext`] for each function it analyzes while walking
/// a parsed Rust file.
///
/// Every context is recorded twice: by bare function name in `contexts` and by
/// source line range in `line_contexts`, both in analysis order.
pub struct ContextDetector {
    /// Module path stamped onto every context this detector builds.
    module_path: Vec<String>,
    /// `(function name, context)` pairs in analysis order. Names are bare
    /// identifiers, so methods of different types can share a name.
    pub contexts: Vec<(String, FunctionContext)>,
    /// `(start line, end line, context)` per analyzed function. The range runs
    /// from the line of the function's name to the line of its closing brace.
    pub line_contexts: Vec<(usize, usize, FunctionContext)>,
    /// File classification applied to every context.
    file_type: FileType,
}

impl ContextDetector {
    /// Creates a detector with nothing recorded yet; every context it builds is
    /// tagged with `file_type`.
    pub fn new(file_type: FileType) -> Self {
        Self {
            module_path: Vec::new(),
            contexts: Vec::new(),
            line_contexts: Vec::new(),
            file_type,
        }
    }

    /// Builds the [`FunctionContext`] for a free function, records it in both
    /// `contexts` and `line_contexts`, and returns it.
    ///
    /// The role comes from [`detect_function_role`] (function name plus whether
    /// a test attribute is present); the optional framework pattern comes from
    /// the function's name, attributes and body.
    pub fn analyze_function(&mut self, func: &ItemFn) -> FunctionContext {
        let func_name = func.sig.ident.to_string();
        let is_test = has_test_attribute(&func.attrs);
        let context = FunctionContext::new()
            .with_role(detect_function_role(&func_name, is_test))
            .with_file_type(self.file_type)
            .with_async(func.sig.asyncness.is_some())
            .with_function_name(func_name.clone())
            .with_module_path(self.module_path.clone());
        let context = match detect_framework_pattern(&func_name, &func.attrs, &func.block) {
            Some(pattern) => context.with_framework_pattern(pattern),
            None => context,
        };

        // Line range: from the function's identifier to its body's closing brace.
        let start_line = func.sig.ident.span().start().line;
        let end_line = func.block.brace_token.span.join().end().line;

        self.contexts.push((func_name, context.clone()));
        self.line_contexts
            .push((start_line, end_line, context.clone()));
        context
    }

    /// Returns the context of the first analyzed function named `func_name`.
    ///
    /// Only the bare name is compared, so when several functions share a name
    /// the one analyzed first wins.
    pub fn get_context(&self, func_name: &str) -> Option<&FunctionContext> {
        self.contexts
            .iter()
            .find(|(name, _)| name == func_name)
            .map(|(_, context)| context)
    }

    /// Returns the context of the innermost analyzed function whose line range
    /// contains `line` (inclusive on both ends).
    ///
    /// Outer functions are recorded before the functions nested in their
    /// bodies, so "first match" would always pick the enclosing function; the
    /// narrowest range is chosen instead. Equal-width ranges resolve to the one
    /// recorded first.
    pub fn get_context_for_line(&self, line: usize) -> Option<&FunctionContext> {
        self.line_contexts
            .iter()
            .filter(|(start, end, _)| line >= *start && line <= *end)
            .min_by_key(|(start, end, _)| end.saturating_sub(*start))
            .map(|(_, _, context)| context)
    }

    /// Returns `true` when `block` both reads an input source (file or
    /// environment) and parses it (TOML, JSON or a config type).
    pub fn detect_config_loader_from_body(block: &Block) -> bool {
        let mut detector = ConfigLoaderDetector::default();
        detector.visit_block(block);
        detector.is_config_loader()
    }
}
/// Evidence gathered while walking a function body that the function loads
/// configuration. Each flag is set by the `Visit` impl and never cleared.
#[derive(Default)]
struct ConfigLoaderDetector {
    /// A file-reading call was seen.
    has_file_read: bool,
    /// An environment-variable read was seen.
    has_env_read: bool,
    /// A TOML parsing call was seen.
    has_toml_parse: bool,
    /// A JSON parsing call was seen.
    has_json_parse: bool,
    /// A parse/deserialize into something named like a config was seen.
    has_config_type: bool,
}

impl ConfigLoaderDetector {
    /// A config loader needs an input source (file or env) AND a parsing step
    /// (TOML, JSON or a config type); either half alone is not enough.
    fn is_config_loader(&self) -> bool {
        let reads_input = self.has_file_read || self.has_env_read;
        let parses_input = self.has_toml_parse || self.has_json_parse || self.has_config_type;
        reads_input && parses_input
    }
}
impl<'ast> Visit<'ast> for ConfigLoaderDetector {
    /// Classifies path-call expressions (`a::b(..)`) by their callee path.
    fn visit_expr_call(&mut self, node: &'ast ExprCall) {
        if let Expr::Path(path) = &*node.func {
            let path_str = path_to_string(&path.path);
            if path_str.contains("read_to_string")
                || path_str.contains("File::open")
                || path_str.contains("fs::read")
            {
                self.has_file_read = true;
            }
            if path_str.contains("env::var") || path_str.contains("std::env") {
                self.has_env_read = true;
            }
            if path_str.contains("toml::from_str") || path_str.contains("toml::parse") {
                self.has_toml_parse = true;
            }
            if path_str.contains("serde_json::from_str") || path_str.contains("json::parse") {
                self.has_json_parse = true;
            }
            // Associated-function form of a config conversion, e.g.
            // `Config::from_str(&text)` or `AppConfig::deserialize(d)`: the last
            // segment is the conversion, an earlier segment names the config type.
            let mut segments = path.path.segments.iter().rev();
            if let Some(last) = segments.next() {
                if is_config_conversion(&last.ident.to_string())
                    && segments.any(|segment| mentions_config(&segment.ident.to_string()))
                {
                    self.has_config_type = true;
                }
            }
        }
        syn::visit::visit_expr_call(self, node);
    }

    /// Detects method-call conversions into a config, either on a receiver
    /// named like a config (`config.parse()`) or with a config turbofish
    /// target (`text.parse::<Config>()`).
    fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
        let method_name = node.method.to_string();
        if is_config_conversion(&method_name) {
            let receiver_is_config = match &*node.receiver {
                Expr::Path(path) => mentions_config(&path_to_string(&path.path)),
                _ => false,
            };
            let target_is_config = match &node.turbofish {
                Some(generics) => generics.args.iter().any(|arg| match arg {
                    syn::GenericArgument::Type(syn::Type::Path(ty)) => {
                        mentions_config(&path_to_string(&ty.path))
                    }
                    _ => false,
                }),
                None => false,
            };
            if receiver_is_config || target_is_config {
                self.has_config_type = true;
            }
        }
        syn::visit::visit_expr_method_call(self, node);
    }
}

/// Names treated as converting text/data into a config value.
fn is_config_conversion(name: &str) -> bool {
    matches!(name, "parse" | "from_str" | "deserialize")
}

/// Case-sensitive check for `config` or `Config` anywhere in `text`.
fn mentions_config(text: &str) -> bool {
    text.contains("config") || text.contains("Config")
}
impl<'ast> Visit<'ast> for ContextDetector {
fn visit_item_fn(&mut self, node: &'ast ItemFn) {
self.analyze_function(node);
syn::visit::visit_item_fn(self, node);
}
fn visit_item_impl(&mut self, node: &'ast ItemImpl) {
for item in &node.items {
if let ImplItem::Fn(method) = item {
let func_name = method.sig.ident.to_string();
let is_test = has_test_attribute(&method.attrs);
let role = detect_function_role(&func_name, is_test);
let is_async = method.sig.asyncness.is_some();
let context = FunctionContext::new()
.with_role(role)
.with_file_type(self.file_type)
.with_async(is_async)
.with_function_name(func_name.clone())
.with_module_path(self.module_path.clone());
self.contexts.push((func_name.clone(), context.clone()));
let start_span = method.sig.ident.span();
let end_span = method.block.brace_token.span.join();
let start_line = start_span.start().line;
let end_line = end_span.end().line;
self.line_contexts.push((start_line, end_line, context));
}
}
syn::visit::visit_item_impl(self, node);
}
}
/// Returns `true` if any attribute path contains a segment named `test`,
/// `tokio_test` or `async_std_test` (so `#[test]` and `#[tokio::test]` both count).
fn has_test_attribute(attrs: &[Attribute]) -> bool {
    const TEST_MARKERS: [&str; 3] = ["test", "tokio_test", "async_std_test"];
    for attr in attrs {
        for segment in &attr.path().segments {
            let name = segment.ident.to_string();
            if TEST_MARKERS.contains(&name.as_str()) {
                return true;
            }
        }
    }
    false
}
/// Attribute path segments that mark a function as a web/HTTP handler.
/// Compared against whole segments, so `#[target_feature]` or `#[getter]`
/// no longer match merely because they contain `get`.
const WEB_ATTRIBUTE_SEGMENTS: &[&str] = &[
    "get", "post", "put", "delete", "patch", "route", "routes", "handler", "endpoint", "api",
    "openapi", "web", "actix_web", "rocket", "warp", "axum", "tide",
];

/// Classifies a function into a framework pattern from its name, attributes
/// and body, or returns `None` when no heuristic applies.
///
/// Checks are applied in this order, first match wins:
/// 1. `main`: `AsyncRuntime` if it carries a runtime attribute such as
///    `#[tokio::main]` or its body drives a runtime, otherwise `RustMain`.
/// 2. Any test attribute: `TestFramework`. This runs before the web check so
///    `#[actix_web::test]` is a test, not a handler.
/// 3. A web-framework attribute: `WebHandler`.
/// 4. A name containing `cmd`/`command`, or the whole word `cli`: `CliHandler`.
/// 5. A body that reads and parses configuration: `ConfigInit`.
fn detect_framework_pattern(
    name: &str,
    attrs: &[Attribute],
    block: &Block,
) -> Option<FrameworkPattern> {
    if name == "main" {
        // `#[tokio::main]` & co. hide the runtime inside the macro expansion,
        // so the body alone does not reveal it.
        if attrs.iter().any(is_async_main_attribute) || block_contains_async_runtime(block) {
            return Some(FrameworkPattern::AsyncRuntime);
        }
        return Some(FrameworkPattern::RustMain);
    }
    if has_test_attribute(attrs) {
        return Some(FrameworkPattern::TestFramework);
    }
    if attrs.iter().any(is_web_handler_attribute) {
        return Some(FrameworkPattern::WebHandler);
    }
    if is_cli_handler_name(name) {
        return Some(FrameworkPattern::CliHandler);
    }
    if ContextDetector::detect_config_loader_from_body(block) {
        return Some(FrameworkPattern::ConfigInit);
    }
    None
}

/// `#[x::main]` with at least two segments, e.g. `#[tokio::main]`.
fn is_async_main_attribute(attr: &Attribute) -> bool {
    let segments = &attr.path().segments;
    match segments.last() {
        Some(last) => segments.len() > 1 && last.ident == "main",
        None => false,
    }
}

/// An attribute with a path segment that names a web framework or an HTTP
/// route macro, or that ends in `_handler` (e.g. axum's `debug_handler`).
fn is_web_handler_attribute(attr: &Attribute) -> bool {
    attr.path().segments.iter().any(|segment| {
        let ident = segment.ident.to_string();
        WEB_ATTRIBUTE_SEGMENTS.contains(&ident.as_str()) || ident.ends_with("_handler")
    })
}

/// Name heuristic for CLI entry points. `cli` must be a whole `_`-separated
/// word so names like `http_client` or `click` do not qualify.
fn is_cli_handler_name(name: &str) -> bool {
    name.split('_')
        .any(|word| word == "cli" || word.contains("cmd") || word.contains("command"))
}
fn block_contains_async_runtime(block: &Block) -> bool {
let mut detector = AsyncRuntimeDetector::default();
detector.visit_block(block);
detector.has_runtime
}
/// Visitor that records whether a block calls into an async runtime.
#[derive(Default)]
struct AsyncRuntimeDetector {
    /// Set once any runtime-related call is seen; never reset.
    has_runtime: bool,
}

impl<'ast> Visit<'ast> for AsyncRuntimeDetector {
    /// Method-call form: `rt.block_on(..)`, `rt.spawn(..)`, `h.spawn_blocking(..)`.
    ///
    /// NOTE(review): this matches on the method name alone, so a `.spawn()` on a
    /// non-runtime receiver (e.g. `std::process::Command`) also counts.
    fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
        let method = node.method.to_string();
        if method == "block_on" || method == "spawn" || method == "spawn_blocking" {
            self.has_runtime = true;
        }
        syn::visit::visit_expr_method_call(self, node);
    }

    /// Path-call form: runtime construction (`Runtime::new()`,
    /// `tokio::runtime::Builder::..`) plus free-function executors.
    fn visit_expr_call(&mut self, node: &'ast ExprCall) {
        if let Expr::Path(path) = &*node.func {
            let path_str = path_to_string(&path.path);
            // `futures::executor::block_on(fut)`, `smol::block_on(fut)`, ...: the
            // free-function counterpart of the `.block_on()` method above.
            let calls_block_on =
                matches!(path.path.segments.last(), Some(last) if last.ident == "block_on");
            if path_str.contains("tokio::runtime")
                || path_str.contains("tokio::spawn")
                || path_str.contains("tokio::task")
                || path_str.contains("async_std::task")
                || path_str.contains("Runtime::new")
                || calls_block_on
            {
                self.has_runtime = true;
            }
        }
        syn::visit::visit_expr_call(self, node);
    }
}
/// Renders `path` as its segment identifiers joined by `::`.
///
/// Only identifiers are kept: generic arguments (as in `from_str::<T>`) and any
/// leading `::` are not included in the result.
fn path_to_string(path: &Path) -> String {
    let mut rendered = String::new();
    for (index, segment) in path.segments.iter().enumerate() {
        if index > 0 {
            rendered.push_str("::");
        }
        rendered.push_str(&segment.ident.to_string());
    }
    rendered
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::FunctionRole;

    /// Parses a file with a test, `main`, an async handler and a config loader,
    /// then checks the role and flags recorded for each top-level function.
    #[test]
    fn test_context_detection() {
        let code = r#"
#[test]
fn test_something() {
assert_eq!(1, 1);
}
fn main() {
println!("Hello");
}
async fn handle_request() {
// handler code
}
fn load_config() -> Config {
let content = fs::read_to_string("config.toml")?;
toml::from_str(&content)?
}
"#;
        let file = syn::parse_file(code).unwrap();
        let mut detector = ContextDetector::new(FileType::Production);

        let top_level_fns = file.items.iter().filter_map(|item| match item {
            syn::Item::Fn(func) => Some(func),
            _ => None,
        });
        for func in top_level_fns {
            detector.visit_item_fn(func);
        }

        assert_eq!(detector.contexts.len(), 4);

        let test_ctx = detector.get_context("test_something").unwrap();
        assert_eq!(test_ctx.role, FunctionRole::TestFunction);
        assert!(test_ctx.is_test());

        let main_ctx = detector.get_context("main").unwrap();
        assert_eq!(main_ctx.role, FunctionRole::Main);
        assert!(main_ctx.is_entry_point());

        let handler_ctx = detector.get_context("handle_request").unwrap();
        assert!(handler_ctx.is_async);
        assert_eq!(handler_ctx.role, FunctionRole::Handler);

        let config_ctx = detector.get_context("load_config").unwrap();
        assert_eq!(config_ctx.role, FunctionRole::ConfigLoader);
        assert!(config_ctx.allows_blocking_io());
    }
}