// tsift-graph 0.1.74
//
// Language-aware graph extraction — call sites, routes, community detection, and path finding for tsift.
// Documentation
use anyhow::Result;
use std::collections::HashMap;
use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator};

use crate::lang::Lang;

/// Structural complexity counters gathered for one piece of source code.
///
/// Every counter defaults to zero. Two values compare equal only when all
/// five counters match, which lets callers and tests use `==` / `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ComplexityMetrics {
    /// Number of branching constructs counted.
    pub branches: i64,
    /// Number of loop constructs counted.
    pub loops: i64,
    /// Number of return constructs counted.
    pub returns: i64,
    /// Deepest nesting level observed.
    pub max_nesting: i64,
    /// Number of `unsafe` blocks counted.
    pub unsafe_blocks: i64,
}

impl ComplexityMetrics {
    /// Returns `branches + loops + returns`.
    ///
    /// `max_nesting` and `unsafe_blocks` do not contribute to the total.
    pub fn total_complexity(&self) -> i64 {
        self.branches + self.loops + self.returns
    }

    /// Builds a value from already-known counters.
    ///
    /// The arguments are stored as given, in the same order as the struct's
    /// fields; no validation is performed.
    pub fn from_raw_fields(
        branches: i64,
        loops: i64,
        returns: i64,
        max_nesting: i64,
        unsafe_blocks: i64,
    ) -> Self {
        Self {
            branches,
            loops,
            returns,
            max_nesting,
            unsafe_blocks,
        }
    }
}

/// Interface for computing [`ComplexityMetrics`] for source code of one language.
///
/// Implementors must be `Send + Sync`.
pub trait LanguageExtractor: Send + Sync {
    /// Returns the [`Lang`] associated with this extractor.
    fn lang(&self) -> Lang;
    /// Computes complexity metrics for `source`, supplied as raw bytes.
    ///
    /// # Errors
    ///
    /// Returns an error when extraction fails; the exact failure conditions
    /// are defined by the implementation.
    fn extract_complexity(&self, source: &[u8]) -> Result<ComplexityMetrics>;
}

/// Extractor that uses the tree-sitter grammar and the built-in query of a
/// single [`Lang`].
struct BuiltinExtractor {
    /// Language whose grammar and complexity query are used.
    lang: Lang,
}

impl BuiltinExtractor {
    /// Returns the tree-sitter query used to count complexity constructs for
    /// this extractor's language.
    ///
    /// The query tags nodes with the capture names `branch`, `loop`, `return`
    /// and (Rust only) `unsafe`. Each arm is compiled only when the matching
    /// `lang-*` cargo feature is enabled; any language without an enabled arm
    /// yields `None`.
    fn complexity_query(&self) -> Option<&'static str> {
        match self.lang {
            #[cfg(feature = "lang-rust")]
            Lang::Rust => Some(
                r#"
                (if_expression) @branch
                (match_expression) @branch
                (for_expression) @loop
                (while_expression) @loop
                (loop_expression) @loop
                (return_expression) @return
                (unsafe_block) @unsafe
            "#,
            ),
            #[cfg(feature = "lang-python")]
            Lang::Python => Some(
                r#"
                (if_statement) @branch
                (elif_clause) @branch
                (for_statement) @loop
                (while_statement) @loop
                (return_statement) @return
            "#,
            ),
            #[cfg(feature = "lang-typescript")]
            Lang::TypeScript | Lang::Tsx => Some(
                r#"
                (if_statement) @branch
                (switch_statement) @branch
                (ternary_expression) @branch
                (for_statement) @loop
                (for_in_statement) @loop
                (while_statement) @loop
                (do_statement) @loop
                (return_statement) @return
            "#,
            ),
            #[cfg(feature = "lang-javascript")]
            Lang::JavaScript | Lang::Jsx => Some(
                r#"
                (if_statement) @branch
                (switch_statement) @branch
                (ternary_expression) @branch
                (for_statement) @loop
                (for_in_statement) @loop
                (while_statement) @loop
                (do_statement) @loop
                (return_statement) @return
            "#,
            ),
            #[cfg(feature = "lang-kotlin")]
            Lang::Kotlin => Some(
                r#"
                (if_expression) @branch
                (when_expression) @branch
                (for_statement) @loop
                (while_statement) @loop
                (do_while_statement) @loop
                (return_expression) @return
            "#,
            ),
            _ => None,
        }
    }

    /// Reports whether a node of the given kind adds one nesting level.
    ///
    /// NOTE(review): both the control-flow kinds and `block` are listed, so a
    /// construct whose body is a `block` child presumably contributes two
    /// levels; `expression_list` is also listed. Confirm this is intended.
    fn is_nesting_scope(kind: &str) -> bool {
        matches!(
            kind,
            "function_item"
                | "function_definition"
                | "function_declaration"
                | "class_definition"
                | "class_declaration"
                | "impl_item"
                | "if_expression"
                | "if_statement"
                | "for_expression"
                | "for_statement"
                | "while_expression"
                | "while_statement"
                | "loop_expression"
                | "match_expression"
                | "switch_statement"
                | "when_expression"
                | "block"
                | "expression_list"
        )
    }

    /// Parses `source` and returns the deepest nesting level found in its
    /// syntax tree.
    ///
    /// The root starts at depth 0 and every node whose kind satisfies
    /// [`Self::is_nesting_scope`] raises the depth by one for itself and its
    /// descendants. Returns 0 when the grammar cannot be loaded into the
    /// parser or when parsing produces no tree.
    fn compute_max_nesting(&self, source: &[u8]) -> i64 {
        let ts_lang = self.lang.tree_sitter_language();
        let mut parser = Parser::new();
        if parser.set_language(&ts_lang).is_err() {
            return 0;
        }
        let tree = match parser.parse(source, None) {
            Some(t) => t,
            None => return 0,
        };

        // Walk the tree with an explicit work stack rather than recursion so
        // that a very deep syntax tree cannot exhaust the call stack. The
        // maximum is independent of visiting order.
        let mut max_depth: i64 = 0;
        let mut pending = vec![(tree.root_node(), 0_i64)];
        while let Some((node, depth)) = pending.pop() {
            let child_depth = if Self::is_nesting_scope(node.kind()) {
                depth + 1
            } else {
                depth
            };
            max_depth = max_depth.max(child_depth);
            let mut cursor = node.walk();
            pending.extend(node.children(&mut cursor).map(|child| (child, child_depth)));
        }
        max_depth
    }
}

impl LanguageExtractor for BuiltinExtractor {
    /// Returns the language this extractor was created for.
    fn lang(&self) -> Lang {
        self.lang
    }

    /// Counts branches, loops, returns and unsafe blocks in `source` and
    /// records its maximum nesting depth.
    ///
    /// When the language has no complexity query, all-zero metrics are
    /// returned without parsing.
    ///
    /// # Errors
    ///
    /// Fails when the grammar cannot be loaded into the parser, when parsing
    /// yields no tree (`"parse failed"`), or when the query does not compile
    /// against the grammar.
    fn extract_complexity(&self, source: &[u8]) -> Result<ComplexityMetrics> {
        let query_str = match self.complexity_query() {
            Some(q) => q,
            None => return Ok(ComplexityMetrics::default()),
        };
        let ts_lang = self.lang.tree_sitter_language();
        let mut parser = Parser::new();
        parser.set_language(&ts_lang)?;
        let tree = parser
            .parse(source, None)
            .ok_or_else(|| anyhow::anyhow!("parse failed"))?;
        let query = Query::new(&ts_lang, query_str)?;
        // Borrow the capture names from the query instead of copying each one
        // into an owned `String`; they are only read while `query` is alive.
        let capture_names = query.capture_names();
        let mut cursor = QueryCursor::new();
        let mut metrics = ComplexityMetrics::default();

        // The match stream is advanced through `StreamingIterator::next`,
        // hence the `while let` rather than a `for` loop.
        let mut matches = cursor.matches(&query, tree.root_node(), source);
        while let Some(m) = matches.next() {
            for capture in m.captures {
                match capture_names[capture.index as usize] {
                    "branch" => metrics.branches += 1,
                    "loop" => metrics.loops += 1,
                    "return" => metrics.returns += 1,
                    "unsafe" => metrics.unsafe_blocks += 1,
                    _ => {}
                }
            }
        }

        // `compute_max_nesting` parses `source` again on its own parser.
        metrics.max_nesting = self.compute_max_nesting(source);
        Ok(metrics)
    }
}

/// Collection of [`LanguageExtractor`] implementations looked up by name.
pub struct LanguageRegistry {
    /// Registered extractors keyed by name; built-ins are keyed by `Lang::name()`.
    extractors: HashMap<String, Box<dyn LanguageExtractor>>,
}

impl LanguageRegistry {
    /// Creates a registry that already contains a built-in extractor for
    /// every language returned by `Lang::all()`.
    pub fn new() -> Self {
        let mut this = Self {
            extractors: HashMap::new(),
        };
        this.register_builtins();
        this
    }

    /// Adds one `BuiltinExtractor` per language, keyed by the language name.
    fn register_builtins(&mut self) {
        for lang in Lang::all() {
            self.extractors
                .insert(lang.name().to_string(), Box::new(BuiltinExtractor { lang }));
        }
    }

    /// Stores `extractor` under `name`, replacing any extractor already
    /// registered under that name.
    pub fn register(&mut self, name: String, extractor: Box<dyn LanguageExtractor>) {
        self.extractors.insert(name, extractor);
    }

    /// Looks up the extractor registered under `lang_name`, if any.
    pub fn get(&self, lang_name: &str) -> Option<&dyn LanguageExtractor> {
        let boxed = self.extractors.get(lang_name)?;
        Some(boxed.as_ref())
    }

    /// Resolves `ext` to a language via `Lang::from_extension` and returns
    /// the extractor registered under that language's name.
    ///
    /// Yields `None` when the extension is unknown or no extractor is
    /// registered for the resolved language.
    pub fn extractor_for_extension(&self, ext: &str) -> Option<&dyn LanguageExtractor> {
        Lang::from_extension(ext).and_then(|lang| self.get(lang.name()))
    }

    /// Runs the extractor registered for `lang` over `source`.
    ///
    /// # Errors
    ///
    /// Fails when no extractor is registered under `lang.name()`, or when the
    /// extractor itself reports an error.
    pub fn complexity_for_source(&self, lang: Lang, source: &[u8]) -> Result<ComplexityMetrics> {
        match self.get(lang.name()) {
            Some(extractor) => extractor.extract_complexity(source),
            None => Err(anyhow::anyhow!(
                "no extractor registered for language: {}",
                lang.name()
            )),
        }
    }

    /// Returns the names of all registered extractors in ascending order.
    pub fn registered_languages(&self) -> Vec<&str> {
        let mut sorted: Vec<&str> = self.extractors.keys().map(String::as_str).collect();
        sorted.sort();
        sorted
    }
}

impl Default for LanguageRegistry {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every language reported by `Lang::all()` must have a registry entry.
    #[test]
    fn registry_has_all_builtin_languages() {
        let registry = LanguageRegistry::new();
        let languages = registry.registered_languages();
        for lang in Lang::all() {
            assert!(
                languages.contains(&lang.name()),
                "missing builtin language: {}",
                lang.name()
            );
        }
    }

    #[cfg(feature = "lang-rust")]
    #[test]
    fn rust_complexity_counting() {
        let registry = LanguageRegistry::new();
        let source = br#"fn example(x: i32) -> i32 {
    if x > 0 {
        return x;
    }
    for i in 0..x {
        if i % 2 == 0 {
            continue;
        }
    }
    0
}
"#;
        let metrics = registry.complexity_for_source(Lang::Rust, source).unwrap();
        assert!(
            metrics.branches >= 2,
            "expected >=2 branches, got {}",
            metrics.branches
        );
        assert!(
            metrics.loops >= 1,
            "expected >=1 loop, got {}",
            metrics.loops
        );
        assert!(
            metrics.returns >= 1,
            "expected >=1 return, got {}",
            metrics.returns
        );
    }

    #[cfg(feature = "lang-python")]
    #[test]
    fn python_complexity_counting() {
        let registry = LanguageRegistry::new();
        let source = br#"def example(x):
    if x > 0:
        return x
    for i in range(x):
        if i % 2 == 0:
            continue
    return 0
"#;
        let metrics = registry
            .complexity_for_source(Lang::Python, source)
            .unwrap();
        assert!(
            metrics.branches >= 2,
            "expected >=2 branches, got {}",
            metrics.branches
        );
        assert!(
            metrics.loops >= 1,
            "expected >=1 loop, got {}",
            metrics.loops
        );
        assert!(
            metrics.returns >= 2,
            "expected >=2 returns, got {}",
            metrics.returns
        );
    }

    #[cfg(feature = "lang-typescript")]
    #[test]
    fn typescript_complexity_counting() {
        let registry = LanguageRegistry::new();
        let source = br#"function example(x: number): number {
    if (x > 0) {
        return x;
    }
    for (let i = 0; i < x; i++) {
        if (i % 2 === 0) continue;
    }
    return 0;
}
"#;
        let metrics = registry
            .complexity_for_source(Lang::TypeScript, source)
            .unwrap();
        assert!(
            metrics.branches >= 2,
            "expected >=2 branches, got {}",
            metrics.branches
        );
        assert!(
            metrics.loops >= 1,
            "expected >=1 loop, got {}",
            metrics.loops
        );
        assert!(
            metrics.returns >= 2,
            "expected >=2 returns, got {}",
            metrics.returns
        );
    }

    /// `total_complexity` adds branches, loops and returns only.
    #[test]
    fn total_complexity_sums_metrics() {
        let metrics = ComplexityMetrics::from_raw_fields(3, 2, 1, 4, 0);
        assert_eq!(metrics.total_complexity(), 6);
    }

    /// Known extensions resolve to an extractor; unknown ones do not.
    ///
    /// The language-specific assertions are gated on the same `lang-*`
    /// features as the other per-language tests in this module, so the test
    /// does not depend on a language whose feature is switched off.
    #[test]
    fn extractor_for_extension_works() {
        let registry = LanguageRegistry::new();
        #[cfg(feature = "lang-rust")]
        assert!(registry.extractor_for_extension("rs").is_some());
        #[cfg(feature = "lang-python")]
        assert!(registry.extractor_for_extension("py").is_some());
        assert!(registry.extractor_for_extension("xyz").is_none());
    }
}