Skip to main content

forge_core/treesitter/
mod.rs

//! Tree-sitter based CFG extraction for C, Java, and Rust.
//!
//! This module parses source code with tree-sitter and builds a control flow
//! graph for every function definition it finds. Branches, loops,
//! `switch`/`match`, and `return`/`break`/`continue` are modelled; other
//! constructs are treated as straight-line code.

6use crate::cfg::TestCfg;
7use crate::types::BlockId;
8use crate::error::{ForgeError, Result};
9
10/// Extracted function information
11#[derive(Debug, Clone)]
12pub struct FunctionInfo {
13    pub name: String,
14    pub start_byte: usize,
15    pub end_byte: usize,
16    pub cfg: TestCfg,
17}
18
/// Source languages understood by the tree-sitter CFG extractor.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SupportedLanguage {
    /// C sources (`.c` and `.h` files).
    C,
    /// Java sources (`.java` files).
    Java,
    /// Rust sources (`.rs` files).
    Rust,
}

27/// CFG extractor using tree-sitter
28pub struct CfgExtractor;
29
30impl CfgExtractor {
31    /// Extract CFG from C source code
32    pub fn extract_c(source: &str) -> Result<Vec<FunctionInfo>> {
33        use tree_sitter::Parser;
34        use tree_sitter_c;
35
36        let mut parser = Parser::new();
37        parser
38            .set_language(&tree_sitter_c::language())
39            .map_err(|e| ForgeError::DatabaseError(format!("Failed to set C language: {:?}", e)))?;
40
41        let tree = parser
42            .parse(source, None)
43            .ok_or_else(|| ForgeError::DatabaseError("Failed to parse C code".to_string()))?;
44
45        let root = tree.root_node();
46        let mut functions = Vec::new();
47
48        Self::extract_c_functions(source, &root, &mut functions)?;
49
50        Ok(functions)
51    }
52
53    /// Extract CFG from Java source code
54    pub fn extract_java(source: &str) -> Result<Vec<FunctionInfo>> {
55        use tree_sitter::Parser;
56        use tree_sitter_java;
57
58        let mut parser = Parser::new();
59        parser
60            .set_language(&tree_sitter_java::language())
61            .map_err(|e| ForgeError::DatabaseError(format!("Failed to set Java language: {:?}", e)))?;
62
63        let tree = parser
64            .parse(source, None)
65            .ok_or_else(|| ForgeError::DatabaseError("Failed to parse Java code".to_string()))?;
66
67        let root = tree.root_node();
68        let mut functions = Vec::new();
69
70        Self::extract_java_functions(source, &root, &mut functions)?;
71
72        Ok(functions)
73    }
74
75    /// Extract CFG from Rust source code
76    pub fn extract_rust(source: &str) -> Result<Vec<FunctionInfo>> {
77        use tree_sitter::Parser;
78        use tree_sitter_rust;
79
80        let mut parser = Parser::new();
81        parser
82            .set_language(&tree_sitter_rust::language())
83            .map_err(|e| ForgeError::DatabaseError(format!("Failed to set Rust language: {:?}", e)))?;
84
85        let tree = parser
86            .parse(source, None)
87            .ok_or_else(|| ForgeError::DatabaseError("Failed to parse Rust code".to_string()))?;
88
89        let root = tree.root_node();
90        let mut functions = Vec::new();
91
92        Self::extract_rust_functions(source, &root, &mut functions)?;
93
94        Ok(functions)
95    }
96    
97    /// Detect language from file extension
98    pub fn detect_language(path: &std::path::Path) -> Option<SupportedLanguage> {
99        match path.extension()?.to_str()? {
100            "c" | "h" => Some(SupportedLanguage::C),
101            "java" => Some(SupportedLanguage::Java),
102            "rs" => Some(SupportedLanguage::Rust),
103            _ => None,
104        }
105    }
106    
107    /// Extract CFG based on language
108    pub fn extract(source: &str, lang: SupportedLanguage) -> Result<Vec<FunctionInfo>> {
109        match lang {
110            SupportedLanguage::C => Self::extract_c(source),
111            SupportedLanguage::Java => Self::extract_java(source),
112            SupportedLanguage::Rust => Self::extract_rust(source),
113        }
114    }
115    
116    fn extract_c_functions(
117        source: &str,
118        node: &tree_sitter::Node,
119        functions: &mut Vec<FunctionInfo>,
120    ) -> Result<()> {
121        let kind = node.kind();
122        
123        // Look for function definitions
124        if kind == "function_definition" {
125            if let Some(func) = Self::parse_c_function(source, node)? {
126                functions.push(func);
127            }
128        }
129        
130        // Recurse into children
131        let mut cursor = node.walk();
132        for child in node.children(&mut cursor) {
133            Self::extract_c_functions(source, &child, functions)?;
134        }
135        
136        Ok(())
137    }
138    
139    fn parse_c_function(source: &str, node: &tree_sitter::Node) -> Result<Option<FunctionInfo>> {
140        let start_byte = node.start_byte();
141        let end_byte = node.end_byte();
142        
143        // Find function name - look for identifier within function_declarator
144        let mut name = "unknown".to_string();
145        let mut cursor = node.walk();
146        for child in node.children(&mut cursor) {
147            // Direct identifier (for simple cases)
148            if child.kind() == "identifier" {
149                name = Self::node_text(source, &child);
150                break;
151            }
152            // For function declarator, look inside for the identifier
153            if child.kind() == "function_declarator" {
154                let mut inner_cursor = child.walk();
155                for inner in child.children(&mut inner_cursor) {
156                    if inner.kind() == "identifier" {
157                        name = Self::node_text(source, &inner);
158                        break;
159                    }
160                    // Handle pointer declarator
161                    if inner.kind() == "pointer_declarator" || inner.kind() == "function_declarator" {
162                        let mut ptr_cursor = inner.walk();
163                        for ptr_child in inner.children(&mut ptr_cursor) {
164                            if ptr_child.kind() == "identifier" {
165                                name = Self::node_text(source, &ptr_child);
166                                break;
167                            }
168                        }
169                    }
170                }
171                break;
172            }
173            // For pointer functions at top level
174            if child.kind() == "pointer_declarator" {
175                let mut inner_cursor = child.walk();
176                for inner in child.children(&mut inner_cursor) {
177                    if inner.kind() == "function_declarator" {
178                        let mut fn_cursor = inner.walk();
179                        for fn_child in inner.children(&mut fn_cursor) {
180                            if fn_child.kind() == "identifier" {
181                                name = Self::node_text(source, &fn_child);
182                                break;
183                            }
184                        }
185                    }
186                }
187                break;
188            }
189        }
190        
191        // Find compound_statement (function body)
192        let mut body = None;
193        let mut cursor = node.walk();
194        for child in node.children(&mut cursor) {
195            if child.kind() == "compound_statement" {
196                body = Some(child);
197                break;
198            }
199        }
200        
201        let cfg = if let Some(body) = body {
202            Self::build_cfg_from_body(source, &body, SupportedLanguage::C)?
203        } else {
204            // Function declaration without body
205            TestCfg::new(BlockId(0))
206        };
207        
208        Ok(Some(FunctionInfo {
209            name,
210            start_byte,
211            end_byte,
212            cfg,
213        }))
214    }
215    
216    fn extract_java_functions(
217        source: &str,
218        node: &tree_sitter::Node,
219        functions: &mut Vec<FunctionInfo>,
220    ) -> Result<()> {
221        let kind = node.kind();
222        
223        // Look for method declarations
224        if kind == "method_declaration" {
225            if let Some(func) = Self::parse_java_function(source, node)? {
226                functions.push(func);
227            }
228        }
229        
230        // Recurse into children
231        let mut cursor = node.walk();
232        for child in node.children(&mut cursor) {
233            Self::extract_java_functions(source, &child, functions)?;
234        }
235        
236        Ok(())
237    }
238    
239    fn parse_java_function(source: &str, node: &tree_sitter::Node) -> Result<Option<FunctionInfo>> {
240        let start_byte = node.start_byte();
241        let end_byte = node.end_byte();
242        
243        // Find method name
244        let mut name = "unknown".to_string();
245        let mut cursor = node.walk();
246        for child in node.children(&mut cursor) {
247            if child.kind() == "identifier" {
248                name = Self::node_text(source, &child);
249                break;
250            }
251        }
252        
253        // Find method body (block)
254        let mut body = None;
255        let mut cursor = node.walk();
256        for child in node.children(&mut cursor) {
257            if child.kind() == "block" {
258                body = Some(child);
259                break;
260            }
261        }
262        
263        let cfg = if let Some(body) = body {
264            Self::build_cfg_from_body(source, &body, SupportedLanguage::Java)?
265        } else {
266            // Abstract method without body
267            TestCfg::new(BlockId(0))
268        };
269        
270        Ok(Some(FunctionInfo {
271            name,
272            start_byte,
273            end_byte,
274            cfg,
275        }))
276    }
277    
278    fn extract_rust_functions(
279        source: &str,
280        node: &tree_sitter::Node,
281        functions: &mut Vec<FunctionInfo>,
282    ) -> Result<()> {
283        let kind = node.kind();
284        
285        // Look for function and method definitions
286        if kind == "function_item" || kind == "method_declaration" {
287            if let Some(func) = Self::parse_rust_function(source, node)? {
288                functions.push(func);
289            }
290        }
291        
292        // Recurse into children
293        let mut cursor = node.walk();
294        for child in node.children(&mut cursor) {
295            Self::extract_rust_functions(source, &child, functions)?;
296        }
297        
298        Ok(())
299    }
300    
301    fn parse_rust_function(source: &str, node: &tree_sitter::Node) -> Result<Option<FunctionInfo>> {
302        let start_byte = node.start_byte();
303        let end_byte = node.end_byte();
304        
305        // Find function name - look for identifier after fn keyword
306        let mut name = "unknown".to_string();
307        let mut found_fn = false;
308        let mut cursor = node.walk();
309        
310        for child in node.children(&mut cursor) {
311            if child.kind() == "fn" {
312                found_fn = true;
313                continue;
314            }
315            if found_fn && child.kind() == "identifier" {
316                name = Self::node_text(source, &child);
317                break;
318            }
319        }
320        
321        // Find function body (block)
322        let mut body = None;
323        let mut cursor = node.walk();
324        for child in node.children(&mut cursor) {
325            if child.kind() == "block" {
326                body = Some(child);
327                break;
328            }
329        }
330        
331        let cfg = if let Some(body) = body {
332            Self::build_cfg_from_body(source, &body, SupportedLanguage::Rust)?
333        } else {
334            // Function without body (trait method)
335            TestCfg::new(BlockId(0))
336        };
337        
338        Ok(Some(FunctionInfo {
339            name,
340            start_byte,
341            end_byte,
342            cfg,
343        }))
344    }
345    
346    fn build_cfg_from_body(
347        source: &str,
348        body_node: &tree_sitter::Node,
349        lang: SupportedLanguage,
350    ) -> Result<TestCfg> {
351        let mut cfg = TestCfg::new(BlockId(0));
352        let mut block_counter = 1i64;
353        let mut block_stack: Vec<BlockId> = vec![BlockId(0)];
354        let mut loop_stack: Vec<BlockId> = Vec::new();
355        
356        Self::process_cfg_node(
357            source,
358            body_node,
359            &mut cfg,
360            &mut block_counter,
361            &mut block_stack,
362            &mut loop_stack,
363            lang,
364        )?;
365        
366        // Mark last block as exit
367        if let Some(last) = block_stack.last() {
368            cfg.add_exit(*last);
369        }
370        
371        Ok(cfg)
372    }
373    
374    fn process_cfg_node(
375        source: &str,
376        node: &tree_sitter::Node,
377        cfg: &mut TestCfg,
378        counter: &mut i64,
379        block_stack: &mut Vec<BlockId>,
380        loop_stack: &mut Vec<BlockId>,
381        lang: SupportedLanguage,
382    ) -> Result<()> {
383        let kind = node.kind();
384        
385        match kind {
386            // If statement (C, Java, Rust)
387            "if_statement" | "if_expression" | "if_let_expression" => {
388                Self::process_if_statement(source, node, cfg, counter, block_stack, loop_stack, lang)?;
389            }
390            
391            // Loops (C, Java style)
392            "for_statement" | "while_statement" | "do_statement" => {
393                Self::process_loop(source, node, cfg, counter, block_stack, loop_stack, lang)?;
394            }
395            
396            // Rust loops
397            "loop_expression" => {
398                // Rust infinite loop: loop { ... }
399                Self::process_rust_loop(source, node, cfg, counter, block_stack, loop_stack, lang)?;
400            }
401            
402            "while_expression" | "while_let_expression" => {
403                // Rust while and while let
404                Self::process_rust_while(source, node, cfg, counter, block_stack, loop_stack, lang)?;
405            }
406            
407            "for_expression" => {
408                // Rust for loop: for x in iter { ... }
409                Self::process_rust_for(source, node, cfg, counter, block_stack, loop_stack, lang)?;
410            }
411            
412            // Match expression (Rust)
413            "match_expression" | "match_block" => {
414                Self::process_rust_match(source, node, cfg, counter, block_stack, loop_stack, lang)?;
415            }
416            
417            // Switch (C)
418            "switch_statement" => {
419                Self::process_switch(source, node, cfg, counter, block_stack, loop_stack, lang)?;
420            }
421            
422            // Return statements (all languages)
423            "return_statement" | "return_expression" => {
424                if let Some(current) = block_stack.last() {
425                    cfg.add_exit(*current);
426                }
427            }
428            
429            // Break statement - jump to loop exit
430            "break_statement" | "break_expression" => {
431                if let Some(loop_header) = loop_stack.last() {
432                    if let Some(current) = block_stack.last() {
433                        cfg.add_edge(*current, *loop_header);
434                    }
435                }
436            }
437            
438            // Continue statement - jump back to loop header
439            "continue_statement" => {
440                if let Some(loop_header) = loop_stack.last() {
441                    if let Some(current) = block_stack.last() {
442                        cfg.add_edge(*current, *loop_header);
443                    }
444                }
445            }
446            
447            // Compound statements / blocks - process children
448            "compound_statement" | "block" => {
449                let mut cursor = node.walk();
450                for child in node.children(&mut cursor) {
451                    Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
452                }
453            }
454            
455            // Sequential flow - no control flow change
456            "expression_statement" | "declaration" | "local_variable_declaration"
457            | "let_declaration" | "call_expression" => {
458                // These are sequential, no control flow change
459            }
460            
461            _ => {
462                // For other nodes, recurse into children
463                let mut cursor = node.walk();
464                for child in node.children(&mut cursor) {
465                    Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
466                }
467            }
468        }
469        
470        Ok(())
471    }
472    
473    fn process_if_statement(
474        source: &str,
475        node: &tree_sitter::Node,
476        cfg: &mut TestCfg,
477        counter: &mut i64,
478        block_stack: &mut Vec<BlockId>,
479        loop_stack: &mut Vec<BlockId>,
480        lang: SupportedLanguage,
481    ) -> Result<()> {
482        let cond_block = block_stack.last().copied().unwrap_or(BlockId(0));
483        
484        // Create then block
485        let then_block = BlockId(*counter);
486        *counter += 1;
487        cfg.add_edge(cond_block, then_block);
488        
489        // Create else block and merge block
490        let else_block = BlockId(*counter);
491        *counter += 1;
492        let merge_block = BlockId(*counter);
493        *counter += 1;
494        
495        cfg.add_edge(cond_block, else_block);
496        
497        // Find then and else branches
498        let mut then_body = None;
499        let mut else_body = None;
500        let mut cursor = node.walk();
501        
502        for child in node.children(&mut cursor) {
503            match child.kind() {
504                "compound_statement" | "block" | "expression_statement" => {
505                    if then_body.is_none() {
506                        then_body = Some(child);
507                    } else {
508                        else_body = Some(child);
509                    }
510                }
511                "if_statement" => {
512                    // else-if
513                    else_body = Some(child);
514                }
515                _ => {}
516            }
517        }
518        
519        // Process then branch
520        block_stack.push(then_block);
521        if let Some(then) = then_body {
522            Self::process_cfg_node(source, &then, cfg, counter, block_stack, loop_stack, lang)?;
523        }
524        if let Some(current) = block_stack.pop() {
525            cfg.add_edge(current, merge_block);
526        }
527        
528        // Process else branch
529        block_stack.push(else_block);
530        if let Some(else_) = else_body {
531            Self::process_cfg_node(source, &else_, cfg, counter, block_stack, loop_stack, lang)?;
532        }
533        if let Some(current) = block_stack.pop() {
534            cfg.add_edge(current, merge_block);
535        }
536        
537        // Continue with merge block
538        block_stack.push(merge_block);
539        
540        Ok(())
541    }
542    
543    fn process_loop(
544        source: &str,
545        node: &tree_sitter::Node,
546        cfg: &mut TestCfg,
547        counter: &mut i64,
548        block_stack: &mut Vec<BlockId>,
549        loop_stack: &mut Vec<BlockId>,
550        lang: SupportedLanguage,
551    ) -> Result<()> {
552        let pre_block = block_stack.last().copied().unwrap_or(BlockId(0));
553        
554        // Create header block (condition check)
555        let header_block = BlockId(*counter);
556        *counter += 1;
557        cfg.add_edge(pre_block, header_block);
558        
559        // Create body block
560        let body_block = BlockId(*counter);
561        *counter += 1;
562        cfg.add_edge(header_block, body_block);
563        
564        // Create exit block
565        let exit_block = BlockId(*counter);
566        *counter += 1;
567        cfg.add_edge(header_block, exit_block);
568        
569        // Push loop context
570        loop_stack.push(header_block);
571        
572        // Find and process body
573        let mut cursor = node.walk();
574        for child in node.children(&mut cursor) {
575            if child.kind() == "compound_statement" || child.kind() == "block" {
576                block_stack.push(body_block);
577                Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
578                if let Some(current) = block_stack.pop() {
579                    // Back edge to header
580                    cfg.add_edge(current, header_block);
581                }
582                break;
583            }
584        }
585        
586        loop_stack.pop();
587        
588        // Continue with exit block
589        block_stack.push(exit_block);
590        
591        Ok(())
592    }
593    
594    fn process_switch(
595        source: &str,
596        node: &tree_sitter::Node,
597        cfg: &mut TestCfg,
598        counter: &mut i64,
599        block_stack: &mut Vec<BlockId>,
600        loop_stack: &mut Vec<BlockId>,
601        lang: SupportedLanguage,
602    ) -> Result<()> {
603        let switch_block = block_stack.last().copied().unwrap_or(BlockId(0));
604        let merge_block = BlockId(*counter);
605        *counter += 1;
606        
607        // Find switch body
608        let mut cursor = node.walk();
609        for child in node.children(&mut cursor) {
610            if child.kind() == "compound_statement" {
611                // Process case statements
612                let mut case_cursor = child.walk();
613                for case in child.children(&mut case_cursor) {
614                    if case.kind() == "case_statement" || case.kind() == "labeled_statement" {
615                        let case_block = BlockId(*counter);
616                        *counter += 1;
617                        cfg.add_edge(switch_block, case_block);
618                        
619                        block_stack.push(case_block);
620                        Self::process_cfg_node(source, &case, cfg, counter, block_stack, loop_stack, lang)?;
621                        if let Some(current) = block_stack.pop() {
622                            cfg.add_edge(current, merge_block);
623                        }
624                    }
625                }
626            }
627        }
628        
629        block_stack.push(merge_block);
630        Ok(())
631    }
632    
633    fn process_rust_loop(
634        source: &str,
635        node: &tree_sitter::Node,
636        cfg: &mut TestCfg,
637        counter: &mut i64,
638        block_stack: &mut Vec<BlockId>,
639        loop_stack: &mut Vec<BlockId>,
640        lang: SupportedLanguage,
641    ) -> Result<()> {
642        // Rust infinite loop: loop { ... }
643        let pre_block = block_stack.last().copied().unwrap_or(BlockId(0));
644        
645        // Create header block (loop entry)
646        let header_block = BlockId(*counter);
647        *counter += 1;
648        cfg.add_edge(pre_block, header_block);
649        
650        // Create body block
651        let body_block = BlockId(*counter);
652        *counter += 1;
653        cfg.add_edge(header_block, body_block);
654        
655        // Create exit block (for break)
656        let exit_block = BlockId(*counter);
657        *counter += 1;
658        
659        // Push loop context (header is also exit target for break)
660        loop_stack.push(header_block);
661        
662        // Find and process body (block)
663        let mut cursor = node.walk();
664        for child in node.children(&mut cursor) {
665            if child.kind() == "block" {
666                block_stack.push(body_block);
667                Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
668                if let Some(current) = block_stack.pop() {
669                    // Back edge to header (infinite loop)
670                    cfg.add_edge(current, header_block);
671                }
672                break;
673            }
674        }
675        
676        loop_stack.pop();
677        
678        // Continue with exit block
679        block_stack.push(exit_block);
680        cfg.add_edge(header_block, exit_block);
681        
682        Ok(())
683    }
684    
685    fn process_rust_while(
686        source: &str,
687        node: &tree_sitter::Node,
688        cfg: &mut TestCfg,
689        counter: &mut i64,
690        block_stack: &mut Vec<BlockId>,
691        loop_stack: &mut Vec<BlockId>,
692        lang: SupportedLanguage,
693    ) -> Result<()> {
694        // Rust while and while let
695        let pre_block = block_stack.last().copied().unwrap_or(BlockId(0));
696        
697        // Create header block (condition check)
698        let header_block = BlockId(*counter);
699        *counter += 1;
700        cfg.add_edge(pre_block, header_block);
701        
702        // Create body block
703        let body_block = BlockId(*counter);
704        *counter += 1;
705        cfg.add_edge(header_block, body_block);
706        
707        // Create exit block
708        let exit_block = BlockId(*counter);
709        *counter += 1;
710        cfg.add_edge(header_block, exit_block);
711        
712        // Push loop context
713        loop_stack.push(header_block);
714        
715        // Find and process body
716        let mut cursor = node.walk();
717        for child in node.children(&mut cursor) {
718            if child.kind() == "block" {
719                block_stack.push(body_block);
720                Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
721                if let Some(current) = block_stack.pop() {
722                    // Back edge to header
723                    cfg.add_edge(current, header_block);
724                }
725                break;
726            }
727        }
728        
729        loop_stack.pop();
730        
731        // Continue with exit block
732        block_stack.push(exit_block);
733        
734        Ok(())
735    }
736    
737    fn process_rust_for(
738        source: &str,
739        node: &tree_sitter::Node,
740        cfg: &mut TestCfg,
741        counter: &mut i64,
742        block_stack: &mut Vec<BlockId>,
743        loop_stack: &mut Vec<BlockId>,
744        lang: SupportedLanguage,
745    ) -> Result<()> {
746        // Rust for loop: for x in iter { ... }
747        let pre_block = block_stack.last().copied().unwrap_or(BlockId(0));
748        
749        // Create header block
750        let header_block = BlockId(*counter);
751        *counter += 1;
752        cfg.add_edge(pre_block, header_block);
753        
754        // Create body block
755        let body_block = BlockId(*counter);
756        *counter += 1;
757        cfg.add_edge(header_block, body_block);
758        
759        // Create exit block
760        let exit_block = BlockId(*counter);
761        *counter += 1;
762        cfg.add_edge(header_block, exit_block);
763        
764        // Push loop context
765        loop_stack.push(header_block);
766        
767        // Find and process body
768        let mut cursor = node.walk();
769        for child in node.children(&mut cursor) {
770            if child.kind() == "block" {
771                block_stack.push(body_block);
772                Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
773                if let Some(current) = block_stack.pop() {
774                    // Back edge to header
775                    cfg.add_edge(current, header_block);
776                }
777                break;
778            }
779        }
780        
781        loop_stack.pop();
782        
783        // Continue with exit block
784        block_stack.push(exit_block);
785        
786        Ok(())
787    }
788    
789    fn process_rust_match(
790        source: &str,
791        node: &tree_sitter::Node,
792        cfg: &mut TestCfg,
793        counter: &mut i64,
794        block_stack: &mut Vec<BlockId>,
795        loop_stack: &mut Vec<BlockId>,
796        lang: SupportedLanguage,
797    ) -> Result<()> {
798        // Rust match expression
799        let match_block = block_stack.last().copied().unwrap_or(BlockId(0));
800        let merge_block = BlockId(*counter);
801        *counter += 1;
802        
803        // Find match body (block)
804        let mut cursor = node.walk();
805        for child in node.children(&mut cursor) {
806            if child.kind() == "block" {
807                // Process match arms
808                let mut arm_cursor = child.walk();
809                for arm in child.children(&mut arm_cursor) {
810                    if arm.kind() == "match_arm" {
811                        let arm_block = BlockId(*counter);
812                        *counter += 1;
813                        cfg.add_edge(match_block, arm_block);
814                        
815                        block_stack.push(arm_block);
816                        Self::process_cfg_node(source, &arm, cfg, counter, block_stack, loop_stack, lang)?;
817                        if let Some(current) = block_stack.pop() {
818                            cfg.add_edge(current, merge_block);
819                        }
820                    }
821                }
822            }
823        }
824        
825        block_stack.push(merge_block);
826        Ok(())
827    }
828    
829    fn node_text(source: &str, node: &tree_sitter::Node) -> String {
830        source[node.start_byte()..node.end_byte()].to_string()
831    }
832}
833
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that extraction yielded exactly one function and returns it so
    /// the caller can inspect its name or CFG.
    fn sole_function(funcs: &[FunctionInfo]) -> &FunctionInfo {
        assert_eq!(funcs.len(), 1);
        &funcs[0]
    }

    #[test]
    fn test_language_detection() {
        use std::path::Path;

        // File name -> language that `detect_language` must report.
        // Both the C source and header extensions map to C.
        let expectations = [
            ("test.c", SupportedLanguage::C),
            ("test.h", SupportedLanguage::C),
            ("Test.java", SupportedLanguage::Java),
            ("test.rs", SupportedLanguage::Rust),
        ];

        for &(file_name, language) in &expectations {
            assert_eq!(
                CfgExtractor::detect_language(Path::new(file_name)),
                Some(language),
                "wrong language detected for {}",
                file_name
            );
        }
    }

    #[test]
    fn test_extract_c_simple_function() {
        let source = r#"
            int add(int a, int b) {
                return a + b;
            }
        "#;

        let funcs = CfgExtractor::extract_c(source).unwrap();
        assert_eq!(sole_function(&funcs).name, "add");
    }

    #[test]
    fn test_extract_c_with_if() {
        let source = r#"
            int max(int a, int b) {
                if (a > b) {
                    return a;
                } else {
                    return b;
                }
            }
        "#;

        let funcs = CfgExtractor::extract_c(source).unwrap();
        let max_fn = sole_function(&funcs);

        // An if/else is expected to fan out into several blocks; this test only
        // requires that the `successors` collection holds at least two entries.
        assert!(max_fn.cfg.successors.len() >= 2);
    }

    #[test]
    fn test_extract_java_simple_method() {
        let source = r#"
            public class Test {
                public int add(int a, int b) {
                    return a + b;
                }
            }
        "#;

        let funcs = CfgExtractor::extract_java(source).unwrap();
        assert_eq!(sole_function(&funcs).name, "add");
    }

    #[test]
    fn test_extract_java_with_loop() {
        let source = r#"
            public class Test {
                public int sum(int n) {
                    int total = 0;
                    for (int i = 0; i < n; i++) {
                        total += i;
                    }
                    return total;
                }
            }
        "#;

        let funcs = CfgExtractor::extract_java(source).unwrap();

        // The `for` statement must surface as a loop in the extracted CFG.
        let loops = sole_function(&funcs).cfg.detect_loops();
        assert!(!loops.is_empty(), "Should detect at least one loop");
    }

    #[test]
    fn test_extract_rust_simple_function() {
        let source = r#"
            fn add(a: i32, b: i32) -> i32 {
                a + b
            }
        "#;

        let funcs = CfgExtractor::extract_rust(source).unwrap();
        assert_eq!(sole_function(&funcs).name, "add");
    }

    #[test]
    fn test_extract_rust_if_expression() {
        let source = r#"
            fn max(a: i32, b: i32) -> i32 {
                if a > b {
                    a
                } else {
                    b
                }
            }
        "#;

        let funcs = CfgExtractor::extract_rust(source).unwrap();
        let max_fn = sole_function(&funcs);
        assert_eq!(max_fn.name, "max");

        // The CFG shape for Rust `if` expressions is still being refined, so
        // only the entry block is pinned down here.
        assert!(max_fn.cfg.entry == BlockId(0));
    }

    #[test]
    fn test_extract_rust_loop() {
        let source = r#"
            fn countdown(mut n: i32) -> i32 {
                loop {
                    if n <= 0 {
                        break;
                    }
                    n -= 1;
                }
                n
            }
        "#;

        let funcs = CfgExtractor::extract_rust(source).unwrap();
        let countdown_fn = sole_function(&funcs);
        assert_eq!(countdown_fn.name, "countdown");

        // Loop detection for Rust `loop` is not asserted yet; only the entry
        // block is checked.
        assert!(countdown_fn.cfg.entry == BlockId(0));
    }

    #[test]
    fn test_extract_rust_for_loop() {
        let source = r#"
            fn sum(n: i32) -> i32 {
                let mut total = 0;
                for i in 0..n {
                    total += i;
                }
                total
            }
        "#;

        let funcs = CfgExtractor::extract_rust(source).unwrap();
        let sum_fn = sole_function(&funcs);
        assert_eq!(sum_fn.name, "sum");

        // Loop detection for Rust `for` is not asserted yet; only the entry
        // block is checked.
        assert!(sum_fn.cfg.entry == BlockId(0));
    }

    #[test]
    fn test_extract_rust_match_expression() {
        let source = r#"
            fn classify(n: i32) -> &'static str {
                match n {
                    0 => "zero",
                    1..=9 => "single digit",
                    _ => "other",
                }
            }
        "#;

        let funcs = CfgExtractor::extract_rust(source).unwrap();
        assert_eq!(sole_function(&funcs).name, "classify");
    }
}