debtmap 0.19.1

Code complexity and technical debt analyzer
Documentation
use crate::analyzers::go::parser::node_text;
use tree_sitter::Node;

/// Go-specific risk signals collected from a single function body.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct GoAdvancedSignals {
    /// Names of the detected patterns (e.g. `"defer-in-loop"`, `"swallowed-error"`).
    /// [`detect_advanced_signals`] returns them sorted and deduplicated.
    pub patterns: Vec<String>,
    /// How many assignments were flagged as discarding a value into the blank identifier `_`.
    /// Counts are summed when signals are merged, not deduplicated.
    pub error_swallowing_count: u32,
    /// Names of the error-swallowing variants that were seen (sorted and deduplicated on return).
    pub error_swallowing_patterns: Vec<String>,
}

/// Scans a Go function body for concurrency, error-handling and mutation risk patterns.
///
/// * `body` – the syntax node to scan (the tests pass a function's `body` field).
/// * `source` – the full source text the tree was parsed from.
/// * `function_name` – used to recognise methods (a name containing `.`) and `main`.
/// * `is_test` – when `true`, `panic(...)` calls are not reported.
/// * `package_variables` – package-level variable names; assigning to one is reported.
///
/// The returned pattern lists are sorted and deduplicated.
pub fn detect_advanced_signals(
    body: Node,
    source: &str,
    function_name: &str,
    is_test: bool,
    package_variables: &[String],
) -> GoAdvancedSignals {
    // First pass: per-node patterns for the whole body (nested callables are skipped).
    let mut found = collect_signals(body, source, function_name, is_test, package_variables);
    // Second pass: the repetition check needs a count over the whole body.
    add_repetitive_error_handling(body, source, &mut found);
    // Stable, duplicate-free output for callers.
    normalize_signals(found)
}

fn collect_signals(
    node: Node,
    source: &str,
    function_name: &str,
    is_test: bool,
    package_variables: &[String],
) -> GoAdvancedSignals {
    if is_nested_callable(node) {
        return GoAdvancedSignals::default();
    }

    children(node)
        .into_iter()
        .map(|child| collect_signals(child, source, function_name, is_test, package_variables))
        .fold(
            node_signals(node, source, function_name, is_test, package_variables),
            merge_signals,
        )
}

/// Returns the signals produced by `node` itself, without looking at its children.
///
/// * `go_statement` → `goroutine-without-synchronization`
/// * `send_statement`, or a `unary_expression` whose text starts with `<-` → `channel-operation`
/// * `defer_statement` with a `for` ancestor (inside the same callable) → `defer-in-loop`
/// * `call_expression` → delegated to [`call_signals`]
/// * `assignment_statement` → delegated to [`assignment_signals`]
fn node_signals(
    node: Node,
    source: &str,
    function_name: &str,
    is_test: bool,
    package_variables: &[String],
) -> GoAdvancedSignals {
    // A receive is spelled `<-ch`, so the leading arrow distinguishes it from other unary ops.
    let is_channel_receive = || node_text(&node, source).trim_start().starts_with("<-");

    match node.kind() {
        "go_statement" => pattern("goroutine-without-synchronization"),
        "send_statement" => pattern("channel-operation"),
        "unary_expression" if is_channel_receive() => pattern("channel-operation"),
        "defer_statement" if has_loop_ancestor(node) => pattern("defer-in-loop"),
        "call_expression" => call_signals(node, source, function_name, is_test),
        "assignment_statement" => {
            assignment_signals(node, source, function_name, package_variables)
        }
        _ => GoAdvancedSignals::default(),
    }
}

fn assignment_signals(
    node: Node,
    source: &str,
    function_name: &str,
    package_variables: &[String],
) -> GoAdvancedSignals {
    let text = node_text(&node, source);
    let mut signals = GoAdvancedSignals::default();

    if swallows_error(text) {
        signals.patterns.push("swallowed-error".to_string());
        signals.error_swallowing_count = 1;
        signals
            .error_swallowing_patterns
            .push("blank-identifier-error".to_string());
    }

    if mutates_pointer_receiver(text, function_name) {
        signals
            .patterns
            .push("pointer-receiver-mutation".to_string());
    }

    if mutates_indexed_value(text) {
        signals.patterns.push("collection-mutation".to_string());
    }

    if mutates_package_variable(text, package_variables) {
        signals.patterns.push("package-global-mutation".to_string());
    }

    signals
}

/// Classifies a `call_expression` by the text of its `function` field.
///
/// `recover()` always yields `recover-without-handling`. `panic(...)` yields
/// `panic-in-production` unless `is_test` is set or the caller is named `main`.
fn call_signals(node: Node, source: &str, function_name: &str, is_test: bool) -> GoAdvancedSignals {
    let in_production_code = !is_test && function_name != "main";

    node.child_by_field_name("function")
        .map(|callee| match node_text(&callee, source) {
            "recover" => pattern("recover-without-handling"),
            "panic" if in_production_code => pattern("panic-in-production"),
            _ => GoAdvancedSignals::default(),
        })
        .unwrap_or_default()
}

/// Appends `repetitive-error-handling` when `body` contains at least three branches
/// recognised by [`count_error_return_branches`].
fn add_repetitive_error_handling(body: Node, source: &str, signals: &mut GoAdvancedSignals) {
    const MIN_ERROR_RETURNS: u32 = 3;

    let error_returns = count_error_return_branches(body, source);
    if error_returns < MIN_ERROR_RETURNS {
        return;
    }
    signals
        .patterns
        .push(String::from("repetitive-error-handling"));
}

/// Counts the nodes under (and including) `node` that satisfy [`is_error_return_branch`].
///
/// Subtrees rooted at nested callables are not searched.
fn count_error_return_branches(node: Node, source: &str) -> u32 {
    // Explicit work stack instead of recursion; visiting order does not affect a sum.
    let mut pending = vec![node];
    let mut total: u32 = 0;

    while let Some(current) = pending.pop() {
        if is_nested_callable(current) {
            continue;
        }
        if is_error_return_branch(current, source) {
            total += 1;
        }
        pending.extend(children(current));
    }

    total
}

/// Returns `true` when `node` is an `if` statement shaped like `if ... err != nil { ... return ... }`.
///
/// Only the statement's own `condition` and `consequence` fields are inspected. A substring
/// search over the whole statement text also matched any `if` that merely *encloses* an error
/// check, such as an outer `if ok { if err != nil { return err } }` or the head of an
/// `if a {} else if err != nil { return err }` chain. That counted the same branch once per
/// enclosing `if` and inflated the repetitive-error-handling count.
///
/// A statement missing either field (e.g. after a parse error) is not counted.
fn is_error_return_branch(node: Node, source: &str) -> bool {
    if node.kind() != "if_statement" {
        return false;
    }

    let (Some(condition), Some(consequence)) = (
        node.child_by_field_name("condition"),
        node.child_by_field_name("consequence"),
    ) else {
        return false;
    };

    node_text(&condition, source).contains("err != nil")
        && node_text(&consequence, source).contains("return")
}

/// Returns `true` when a comma-separated target on the left of the first `=` is the
/// blank identifier `_` (or ends in ` _`).
///
/// Text with no `=` at all is inspected as a whole.
fn swallows_error(text: &str) -> bool {
    let targets = match text.split_once('=') {
        Some((lhs, _)) => lhs,
        None => text,
    };

    targets.split(',').any(|raw| {
        let target = raw.trim();
        target == "_" || target.ends_with(" _")
    })
}

/// Heuristic: inside a method (a `function_name` containing `.`), an assignment whose
/// left-hand side contains a `.` selector is treated as mutating the receiver.
///
/// NOTE(review): any selector target matches (e.g. `other.field = x`), not only the receiver.
fn mutates_pointer_receiver(text: &str, function_name: &str) -> bool {
    if !function_name.contains('.') {
        return false;
    }
    match text.split_once('=') {
        Some((lhs, _)) => lhs.contains('.'),
        None => false,
    }
}

/// Returns `true` when the text before the first `=` contains both `[` and `]`,
/// i.e. the assignment writes through an index expression such as `items[key] = v`.
fn mutates_indexed_value(text: &str) -> bool {
    matches!(
        text.split_once('='),
        Some((lhs, _)) if lhs.contains('[') && lhs.contains(']')
    )
}

/// Returns `true` when any comma-separated target before the first `=`, after
/// [`clean_assignment_target`], exactly equals one of `package_variables`.
///
/// Text with no `=` never matches.
fn mutates_package_variable(text: &str, package_variables: &[String]) -> bool {
    text.split_once('=').is_some_and(|(lhs, _)| {
        lhs.split(',')
            .map(clean_assignment_target)
            .any(|target| {
                package_variables
                    .iter()
                    .any(|known| known.as_str() == target)
            })
    })
}

/// Reduces one comma-separated piece of an assignment's left-hand side to the bare target.
///
/// Callers split the statement on its first `=`, so the operator prefix of `:=`, `+=`,
/// `-=`, `*=`, `/=`, `%=`, `&=`, `|=`, `^=`, `<<=`, `>>=` and `&^=` remains attached to the
/// last target, usually after a space (`"total +"`). The previous version stripped only
/// `:`, `+` and `-` and never re-trimmed. That left `"total "` with a trailing space, so
/// compound assignments to package variables were never matched.
fn clean_assignment_target(target: &str) -> &str {
    // Every character that can precede `=` in a Go assignment operator. A real target
    // (identifier, selector, index or parenthesised expression) never ends in one of them.
    const OPERATOR_CHARS: &[char] = &[':', '+', '-', '*', '/', '%', '&', '|', '^', '<', '>'];

    target
        .trim()
        .trim_end_matches(OPERATOR_CHARS)
        // Drop the space that usually separates the target from its operator.
        .trim_end()
}

/// Returns `true` when some ancestor of `node` is a loop, stopping the search at the
/// first enclosing nested callable (a loop outside that callable does not count).
fn has_loop_ancestor(node: Node) -> bool {
    let mut next = node.parent();
    while let Some(ancestor) = next {
        if is_loop_node(ancestor) {
            return true;
        }
        if is_nested_callable(ancestor) {
            return false;
        }
        next = ancestor.parent();
    }
    false
}

/// Returns `true` for `for_statement` nodes, the only node kind treated as a loop here.
fn is_loop_node(node: Node) -> bool {
    matches!(node.kind(), "for_statement")
}

/// Returns `true` for node kinds that introduce their own function body:
/// function literals, function declarations and method declarations.
fn is_nested_callable(node: Node) -> bool {
    const CALLABLE_KINDS: [&str; 3] = ["func_literal", "function_declaration", "method_declaration"];
    CALLABLE_KINDS.contains(&node.kind())
}

fn pattern(name: &str) -> GoAdvancedSignals {
    GoAdvancedSignals {
        patterns: vec![name.to_string()],
        ..Default::default()
    }
}

fn merge_signals(mut left: GoAdvancedSignals, right: GoAdvancedSignals) -> GoAdvancedSignals {
    left.patterns.extend(right.patterns);
    left.error_swallowing_count += right.error_swallowing_count;
    left.error_swallowing_patterns
        .extend(right.error_swallowing_patterns);
    left
}

fn normalize_signals(mut signals: GoAdvancedSignals) -> GoAdvancedSignals {
    signals.patterns.sort();
    signals.patterns.dedup();
    signals.error_swallowing_patterns.sort();
    signals.error_swallowing_patterns.dedup();
    signals
}

/// Returns all direct children of `node` (named and anonymous) in source order.
fn children(node: Node) -> Vec<Node> {
    let mut walker = node.walk();
    node.children(&mut walker).collect::<Vec<_>>()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzers::go::parser::parse_source;
    use std::path::PathBuf;

    /// Parses `source`, takes the body of the first top-level function or method and runs
    /// [`detect_advanced_signals`] on it as non-test code with the given package variables.
    fn signals_with_globals(
        source: &str,
        name: &str,
        package_variables: &[String],
    ) -> GoAdvancedSignals {
        let ast = parse_source(source, &PathBuf::from("service.go")).unwrap();
        let root = ast.tree.root_node();
        let mut walker = root.walk();
        let body = root
            .children(&mut walker)
            .find(|node| matches!(node.kind(), "function_declaration" | "method_declaration"))
            .and_then(|node| node.child_by_field_name("body"))
            .unwrap();
        detect_advanced_signals(body, source, name, false, package_variables)
    }

    /// Same as [`signals_with_globals`] with no package variables.
    fn signals(source: &str, name: &str) -> GoAdvancedSignals {
        signals_with_globals(source, name, &[])
    }

    /// Whether `found` reports the pattern `name`.
    fn has_pattern(found: &GoAdvancedSignals, name: &str) -> bool {
        found.patterns.iter().any(|p| p == name)
    }

    #[test]
    fn detects_repetitive_error_handling() {
        let source = r#"package service

func Load() error {
    if err != nil { return err }
    if err != nil { return err }
    if err != nil { return err }
    return nil
}
"#;
        let found = signals(source, "Load");

        assert!(has_pattern(&found, "repetitive-error-handling"));
    }

    #[test]
    fn detects_swallowed_error() {
        let source = r#"package service

func Load() {
    value, _ := parse()
    _ = value
}
"#;
        let found = signals(source, "Load");

        assert!(has_pattern(&found, "swallowed-error"));
        assert_eq!(found.error_swallowing_count, 1);
    }

    #[test]
    fn detects_panic_recover_and_concurrency_risks() {
        let source = r#"package service

func Run(ch chan int) {
    go worker()
    ch <- 1
    defer recover()
    panic("failed")
}
"#;
        let found = signals(source, "Run");

        for expected in [
            "goroutine-without-synchronization",
            "channel-operation",
            "recover-without-handling",
            "panic-in-production",
        ] {
            assert!(has_pattern(&found, expected), "missing pattern {expected}");
        }
    }

    #[test]
    fn detects_defer_in_loop_and_mutation() {
        let source = r#"package service

func (s *State) Run(items map[string]int) {
    for key := range items {
        defer close()
        s.count = 1
        items[key] = 2
    }
}
"#;
        let found = signals(source, "State.Run");

        for expected in [
            "defer-in-loop",
            "pointer-receiver-mutation",
            "collection-mutation",
        ] {
            assert!(has_pattern(&found, expected), "missing pattern {expected}");
        }
    }

    #[test]
    fn detects_package_global_mutation() {
        let source = r#"package service

var shared int

func Update() {
    shared = 1
}
"#;
        let found = signals_with_globals(source, "Update", &["shared".into()]);

        assert!(has_pattern(&found, "package-global-mutation"));
    }
}