1use crate::cfg::TestCfg;
7use crate::types::BlockId;
8use crate::error::{ForgeError, Result};
9
/// A single function (or method) located in a source file, together with the
/// control-flow graph built from its body.
#[derive(Debug, Clone)]
pub struct FunctionInfo {
    /// Function name as written in the source, or `"unknown"` when the
    /// extractor could not locate a name node.
    pub name: String,
    /// Byte offset in the source where the function's syntax node starts.
    pub start_byte: usize,
    /// Byte offset just past the end of the function's syntax node
    /// (`source[start_byte..end_byte]` is the function's text).
    pub end_byte: usize,
    /// Control-flow graph of the function body; a fresh
    /// `TestCfg::new(BlockId(0))` with no edges added when no body was found.
    pub cfg: TestCfg,
}
18
/// Source languages the CFG extractor can parse.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupportedLanguage {
    /// C (`.c` and `.h` files).
    C,
    /// Java (`.java` files).
    Java,
    /// Rust (`.rs` files).
    Rust,
}
26
27pub struct CfgExtractor;
29
30impl CfgExtractor {
31 pub fn extract_c(source: &str) -> Result<Vec<FunctionInfo>> {
33 use tree_sitter::Parser;
34 use tree_sitter_c;
35
36 let mut parser = Parser::new();
37 parser
38 .set_language(&tree_sitter_c::language())
39 .map_err(|e| ForgeError::DatabaseError(format!("Failed to set C language: {:?}", e)))?;
40
41 let tree = parser
42 .parse(source, None)
43 .ok_or_else(|| ForgeError::DatabaseError("Failed to parse C code".to_string()))?;
44
45 let root = tree.root_node();
46 let mut functions = Vec::new();
47
48 Self::extract_c_functions(source, &root, &mut functions)?;
49
50 Ok(functions)
51 }
52
53 pub fn extract_java(source: &str) -> Result<Vec<FunctionInfo>> {
55 use tree_sitter::Parser;
56 use tree_sitter_java;
57
58 let mut parser = Parser::new();
59 parser
60 .set_language(&tree_sitter_java::language())
61 .map_err(|e| ForgeError::DatabaseError(format!("Failed to set Java language: {:?}", e)))?;
62
63 let tree = parser
64 .parse(source, None)
65 .ok_or_else(|| ForgeError::DatabaseError("Failed to parse Java code".to_string()))?;
66
67 let root = tree.root_node();
68 let mut functions = Vec::new();
69
70 Self::extract_java_functions(source, &root, &mut functions)?;
71
72 Ok(functions)
73 }
74
75 pub fn extract_rust(source: &str) -> Result<Vec<FunctionInfo>> {
77 use tree_sitter::Parser;
78 use tree_sitter_rust;
79
80 let mut parser = Parser::new();
81 parser
82 .set_language(&tree_sitter_rust::language())
83 .map_err(|e| ForgeError::DatabaseError(format!("Failed to set Rust language: {:?}", e)))?;
84
85 let tree = parser
86 .parse(source, None)
87 .ok_or_else(|| ForgeError::DatabaseError("Failed to parse Rust code".to_string()))?;
88
89 let root = tree.root_node();
90 let mut functions = Vec::new();
91
92 Self::extract_rust_functions(source, &root, &mut functions)?;
93
94 Ok(functions)
95 }
96
97 pub fn detect_language(path: &std::path::Path) -> Option<SupportedLanguage> {
99 match path.extension()?.to_str()? {
100 "c" | "h" => Some(SupportedLanguage::C),
101 "java" => Some(SupportedLanguage::Java),
102 "rs" => Some(SupportedLanguage::Rust),
103 _ => None,
104 }
105 }
106
107 pub fn extract(source: &str, lang: SupportedLanguage) -> Result<Vec<FunctionInfo>> {
109 match lang {
110 SupportedLanguage::C => Self::extract_c(source),
111 SupportedLanguage::Java => Self::extract_java(source),
112 SupportedLanguage::Rust => Self::extract_rust(source),
113 }
114 }
115
116 fn extract_c_functions(
117 source: &str,
118 node: &tree_sitter::Node,
119 functions: &mut Vec<FunctionInfo>,
120 ) -> Result<()> {
121 let kind = node.kind();
122
123 if kind == "function_definition" {
125 if let Some(func) = Self::parse_c_function(source, node)? {
126 functions.push(func);
127 }
128 }
129
130 let mut cursor = node.walk();
132 for child in node.children(&mut cursor) {
133 Self::extract_c_functions(source, &child, functions)?;
134 }
135
136 Ok(())
137 }
138
139 fn parse_c_function(source: &str, node: &tree_sitter::Node) -> Result<Option<FunctionInfo>> {
140 let start_byte = node.start_byte();
141 let end_byte = node.end_byte();
142
143 let mut name = "unknown".to_string();
145 let mut cursor = node.walk();
146 for child in node.children(&mut cursor) {
147 if child.kind() == "identifier" {
149 name = Self::node_text(source, &child);
150 break;
151 }
152 if child.kind() == "function_declarator" {
154 let mut inner_cursor = child.walk();
155 for inner in child.children(&mut inner_cursor) {
156 if inner.kind() == "identifier" {
157 name = Self::node_text(source, &inner);
158 break;
159 }
160 if inner.kind() == "pointer_declarator" || inner.kind() == "function_declarator" {
162 let mut ptr_cursor = inner.walk();
163 for ptr_child in inner.children(&mut ptr_cursor) {
164 if ptr_child.kind() == "identifier" {
165 name = Self::node_text(source, &ptr_child);
166 break;
167 }
168 }
169 }
170 }
171 break;
172 }
173 if child.kind() == "pointer_declarator" {
175 let mut inner_cursor = child.walk();
176 for inner in child.children(&mut inner_cursor) {
177 if inner.kind() == "function_declarator" {
178 let mut fn_cursor = inner.walk();
179 for fn_child in inner.children(&mut fn_cursor) {
180 if fn_child.kind() == "identifier" {
181 name = Self::node_text(source, &fn_child);
182 break;
183 }
184 }
185 }
186 }
187 break;
188 }
189 }
190
191 let mut body = None;
193 let mut cursor = node.walk();
194 for child in node.children(&mut cursor) {
195 if child.kind() == "compound_statement" {
196 body = Some(child);
197 break;
198 }
199 }
200
201 let cfg = if let Some(body) = body {
202 Self::build_cfg_from_body(source, &body, SupportedLanguage::C)?
203 } else {
204 TestCfg::new(BlockId(0))
206 };
207
208 Ok(Some(FunctionInfo {
209 name,
210 start_byte,
211 end_byte,
212 cfg,
213 }))
214 }
215
216 fn extract_java_functions(
217 source: &str,
218 node: &tree_sitter::Node,
219 functions: &mut Vec<FunctionInfo>,
220 ) -> Result<()> {
221 let kind = node.kind();
222
223 if kind == "method_declaration" {
225 if let Some(func) = Self::parse_java_function(source, node)? {
226 functions.push(func);
227 }
228 }
229
230 let mut cursor = node.walk();
232 for child in node.children(&mut cursor) {
233 Self::extract_java_functions(source, &child, functions)?;
234 }
235
236 Ok(())
237 }
238
239 fn parse_java_function(source: &str, node: &tree_sitter::Node) -> Result<Option<FunctionInfo>> {
240 let start_byte = node.start_byte();
241 let end_byte = node.end_byte();
242
243 let mut name = "unknown".to_string();
245 let mut cursor = node.walk();
246 for child in node.children(&mut cursor) {
247 if child.kind() == "identifier" {
248 name = Self::node_text(source, &child);
249 break;
250 }
251 }
252
253 let mut body = None;
255 let mut cursor = node.walk();
256 for child in node.children(&mut cursor) {
257 if child.kind() == "block" {
258 body = Some(child);
259 break;
260 }
261 }
262
263 let cfg = if let Some(body) = body {
264 Self::build_cfg_from_body(source, &body, SupportedLanguage::Java)?
265 } else {
266 TestCfg::new(BlockId(0))
268 };
269
270 Ok(Some(FunctionInfo {
271 name,
272 start_byte,
273 end_byte,
274 cfg,
275 }))
276 }
277
278 fn extract_rust_functions(
279 source: &str,
280 node: &tree_sitter::Node,
281 functions: &mut Vec<FunctionInfo>,
282 ) -> Result<()> {
283 let kind = node.kind();
284
285 if kind == "function_item" || kind == "method_declaration" {
287 if let Some(func) = Self::parse_rust_function(source, node)? {
288 functions.push(func);
289 }
290 }
291
292 let mut cursor = node.walk();
294 for child in node.children(&mut cursor) {
295 Self::extract_rust_functions(source, &child, functions)?;
296 }
297
298 Ok(())
299 }
300
301 fn parse_rust_function(source: &str, node: &tree_sitter::Node) -> Result<Option<FunctionInfo>> {
302 let start_byte = node.start_byte();
303 let end_byte = node.end_byte();
304
305 let mut name = "unknown".to_string();
307 let mut found_fn = false;
308 let mut cursor = node.walk();
309
310 for child in node.children(&mut cursor) {
311 if child.kind() == "fn" {
312 found_fn = true;
313 continue;
314 }
315 if found_fn && child.kind() == "identifier" {
316 name = Self::node_text(source, &child);
317 break;
318 }
319 }
320
321 let mut body = None;
323 let mut cursor = node.walk();
324 for child in node.children(&mut cursor) {
325 if child.kind() == "block" {
326 body = Some(child);
327 break;
328 }
329 }
330
331 let cfg = if let Some(body) = body {
332 Self::build_cfg_from_body(source, &body, SupportedLanguage::Rust)?
333 } else {
334 TestCfg::new(BlockId(0))
336 };
337
338 Ok(Some(FunctionInfo {
339 name,
340 start_byte,
341 end_byte,
342 cfg,
343 }))
344 }
345
346 fn build_cfg_from_body(
347 source: &str,
348 body_node: &tree_sitter::Node,
349 lang: SupportedLanguage,
350 ) -> Result<TestCfg> {
351 let mut cfg = TestCfg::new(BlockId(0));
352 let mut block_counter = 1i64;
353 let mut block_stack: Vec<BlockId> = vec![BlockId(0)];
354 let mut loop_stack: Vec<BlockId> = Vec::new();
355
356 Self::process_cfg_node(
357 source,
358 body_node,
359 &mut cfg,
360 &mut block_counter,
361 &mut block_stack,
362 &mut loop_stack,
363 lang,
364 )?;
365
366 if let Some(last) = block_stack.last() {
368 cfg.add_exit(*last);
369 }
370
371 Ok(cfg)
372 }
373
374 fn process_cfg_node(
375 source: &str,
376 node: &tree_sitter::Node,
377 cfg: &mut TestCfg,
378 counter: &mut i64,
379 block_stack: &mut Vec<BlockId>,
380 loop_stack: &mut Vec<BlockId>,
381 lang: SupportedLanguage,
382 ) -> Result<()> {
383 let kind = node.kind();
384
385 match kind {
386 "if_statement" | "if_expression" | "if_let_expression" => {
388 Self::process_if_statement(source, node, cfg, counter, block_stack, loop_stack, lang)?;
389 }
390
391 "for_statement" | "while_statement" | "do_statement" => {
393 Self::process_loop(source, node, cfg, counter, block_stack, loop_stack, lang)?;
394 }
395
396 "loop_expression" => {
398 Self::process_rust_loop(source, node, cfg, counter, block_stack, loop_stack, lang)?;
400 }
401
402 "while_expression" | "while_let_expression" => {
403 Self::process_rust_while(source, node, cfg, counter, block_stack, loop_stack, lang)?;
405 }
406
407 "for_expression" => {
408 Self::process_rust_for(source, node, cfg, counter, block_stack, loop_stack, lang)?;
410 }
411
412 "match_expression" | "match_block" => {
414 Self::process_rust_match(source, node, cfg, counter, block_stack, loop_stack, lang)?;
415 }
416
417 "switch_statement" => {
419 Self::process_switch(source, node, cfg, counter, block_stack, loop_stack, lang)?;
420 }
421
422 "return_statement" | "return_expression" => {
424 if let Some(current) = block_stack.last() {
425 cfg.add_exit(*current);
426 }
427 }
428
429 "break_statement" | "break_expression" => {
431 if let Some(loop_header) = loop_stack.last() {
432 if let Some(current) = block_stack.last() {
433 cfg.add_edge(*current, *loop_header);
434 }
435 }
436 }
437
438 "continue_statement" => {
440 if let Some(loop_header) = loop_stack.last() {
441 if let Some(current) = block_stack.last() {
442 cfg.add_edge(*current, *loop_header);
443 }
444 }
445 }
446
447 "compound_statement" | "block" => {
449 let mut cursor = node.walk();
450 for child in node.children(&mut cursor) {
451 Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
452 }
453 }
454
455 "expression_statement" | "declaration" | "local_variable_declaration"
457 | "let_declaration" | "call_expression" => {
458 }
460
461 _ => {
462 let mut cursor = node.walk();
464 for child in node.children(&mut cursor) {
465 Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
466 }
467 }
468 }
469
470 Ok(())
471 }
472
473 fn process_if_statement(
474 source: &str,
475 node: &tree_sitter::Node,
476 cfg: &mut TestCfg,
477 counter: &mut i64,
478 block_stack: &mut Vec<BlockId>,
479 loop_stack: &mut Vec<BlockId>,
480 lang: SupportedLanguage,
481 ) -> Result<()> {
482 let cond_block = block_stack.last().copied().unwrap_or(BlockId(0));
483
484 let then_block = BlockId(*counter);
486 *counter += 1;
487 cfg.add_edge(cond_block, then_block);
488
489 let else_block = BlockId(*counter);
491 *counter += 1;
492 let merge_block = BlockId(*counter);
493 *counter += 1;
494
495 cfg.add_edge(cond_block, else_block);
496
497 let mut then_body = None;
499 let mut else_body = None;
500 let mut cursor = node.walk();
501
502 for child in node.children(&mut cursor) {
503 match child.kind() {
504 "compound_statement" | "block" | "expression_statement" => {
505 if then_body.is_none() {
506 then_body = Some(child);
507 } else {
508 else_body = Some(child);
509 }
510 }
511 "if_statement" => {
512 else_body = Some(child);
514 }
515 _ => {}
516 }
517 }
518
519 block_stack.push(then_block);
521 if let Some(then) = then_body {
522 Self::process_cfg_node(source, &then, cfg, counter, block_stack, loop_stack, lang)?;
523 }
524 if let Some(current) = block_stack.pop() {
525 cfg.add_edge(current, merge_block);
526 }
527
528 block_stack.push(else_block);
530 if let Some(else_) = else_body {
531 Self::process_cfg_node(source, &else_, cfg, counter, block_stack, loop_stack, lang)?;
532 }
533 if let Some(current) = block_stack.pop() {
534 cfg.add_edge(current, merge_block);
535 }
536
537 block_stack.push(merge_block);
539
540 Ok(())
541 }
542
543 fn process_loop(
544 source: &str,
545 node: &tree_sitter::Node,
546 cfg: &mut TestCfg,
547 counter: &mut i64,
548 block_stack: &mut Vec<BlockId>,
549 loop_stack: &mut Vec<BlockId>,
550 lang: SupportedLanguage,
551 ) -> Result<()> {
552 let pre_block = block_stack.last().copied().unwrap_or(BlockId(0));
553
554 let header_block = BlockId(*counter);
556 *counter += 1;
557 cfg.add_edge(pre_block, header_block);
558
559 let body_block = BlockId(*counter);
561 *counter += 1;
562 cfg.add_edge(header_block, body_block);
563
564 let exit_block = BlockId(*counter);
566 *counter += 1;
567 cfg.add_edge(header_block, exit_block);
568
569 loop_stack.push(header_block);
571
572 let mut cursor = node.walk();
574 for child in node.children(&mut cursor) {
575 if child.kind() == "compound_statement" || child.kind() == "block" {
576 block_stack.push(body_block);
577 Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
578 if let Some(current) = block_stack.pop() {
579 cfg.add_edge(current, header_block);
581 }
582 break;
583 }
584 }
585
586 loop_stack.pop();
587
588 block_stack.push(exit_block);
590
591 Ok(())
592 }
593
594 fn process_switch(
595 source: &str,
596 node: &tree_sitter::Node,
597 cfg: &mut TestCfg,
598 counter: &mut i64,
599 block_stack: &mut Vec<BlockId>,
600 loop_stack: &mut Vec<BlockId>,
601 lang: SupportedLanguage,
602 ) -> Result<()> {
603 let switch_block = block_stack.last().copied().unwrap_or(BlockId(0));
604 let merge_block = BlockId(*counter);
605 *counter += 1;
606
607 let mut cursor = node.walk();
609 for child in node.children(&mut cursor) {
610 if child.kind() == "compound_statement" {
611 let mut case_cursor = child.walk();
613 for case in child.children(&mut case_cursor) {
614 if case.kind() == "case_statement" || case.kind() == "labeled_statement" {
615 let case_block = BlockId(*counter);
616 *counter += 1;
617 cfg.add_edge(switch_block, case_block);
618
619 block_stack.push(case_block);
620 Self::process_cfg_node(source, &case, cfg, counter, block_stack, loop_stack, lang)?;
621 if let Some(current) = block_stack.pop() {
622 cfg.add_edge(current, merge_block);
623 }
624 }
625 }
626 }
627 }
628
629 block_stack.push(merge_block);
630 Ok(())
631 }
632
633 fn process_rust_loop(
634 source: &str,
635 node: &tree_sitter::Node,
636 cfg: &mut TestCfg,
637 counter: &mut i64,
638 block_stack: &mut Vec<BlockId>,
639 loop_stack: &mut Vec<BlockId>,
640 lang: SupportedLanguage,
641 ) -> Result<()> {
642 let pre_block = block_stack.last().copied().unwrap_or(BlockId(0));
644
645 let header_block = BlockId(*counter);
647 *counter += 1;
648 cfg.add_edge(pre_block, header_block);
649
650 let body_block = BlockId(*counter);
652 *counter += 1;
653 cfg.add_edge(header_block, body_block);
654
655 let exit_block = BlockId(*counter);
657 *counter += 1;
658
659 loop_stack.push(header_block);
661
662 let mut cursor = node.walk();
664 for child in node.children(&mut cursor) {
665 if child.kind() == "block" {
666 block_stack.push(body_block);
667 Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
668 if let Some(current) = block_stack.pop() {
669 cfg.add_edge(current, header_block);
671 }
672 break;
673 }
674 }
675
676 loop_stack.pop();
677
678 block_stack.push(exit_block);
680 cfg.add_edge(header_block, exit_block);
681
682 Ok(())
683 }
684
685 fn process_rust_while(
686 source: &str,
687 node: &tree_sitter::Node,
688 cfg: &mut TestCfg,
689 counter: &mut i64,
690 block_stack: &mut Vec<BlockId>,
691 loop_stack: &mut Vec<BlockId>,
692 lang: SupportedLanguage,
693 ) -> Result<()> {
694 let pre_block = block_stack.last().copied().unwrap_or(BlockId(0));
696
697 let header_block = BlockId(*counter);
699 *counter += 1;
700 cfg.add_edge(pre_block, header_block);
701
702 let body_block = BlockId(*counter);
704 *counter += 1;
705 cfg.add_edge(header_block, body_block);
706
707 let exit_block = BlockId(*counter);
709 *counter += 1;
710 cfg.add_edge(header_block, exit_block);
711
712 loop_stack.push(header_block);
714
715 let mut cursor = node.walk();
717 for child in node.children(&mut cursor) {
718 if child.kind() == "block" {
719 block_stack.push(body_block);
720 Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
721 if let Some(current) = block_stack.pop() {
722 cfg.add_edge(current, header_block);
724 }
725 break;
726 }
727 }
728
729 loop_stack.pop();
730
731 block_stack.push(exit_block);
733
734 Ok(())
735 }
736
737 fn process_rust_for(
738 source: &str,
739 node: &tree_sitter::Node,
740 cfg: &mut TestCfg,
741 counter: &mut i64,
742 block_stack: &mut Vec<BlockId>,
743 loop_stack: &mut Vec<BlockId>,
744 lang: SupportedLanguage,
745 ) -> Result<()> {
746 let pre_block = block_stack.last().copied().unwrap_or(BlockId(0));
748
749 let header_block = BlockId(*counter);
751 *counter += 1;
752 cfg.add_edge(pre_block, header_block);
753
754 let body_block = BlockId(*counter);
756 *counter += 1;
757 cfg.add_edge(header_block, body_block);
758
759 let exit_block = BlockId(*counter);
761 *counter += 1;
762 cfg.add_edge(header_block, exit_block);
763
764 loop_stack.push(header_block);
766
767 let mut cursor = node.walk();
769 for child in node.children(&mut cursor) {
770 if child.kind() == "block" {
771 block_stack.push(body_block);
772 Self::process_cfg_node(source, &child, cfg, counter, block_stack, loop_stack, lang)?;
773 if let Some(current) = block_stack.pop() {
774 cfg.add_edge(current, header_block);
776 }
777 break;
778 }
779 }
780
781 loop_stack.pop();
782
783 block_stack.push(exit_block);
785
786 Ok(())
787 }
788
789 fn process_rust_match(
790 source: &str,
791 node: &tree_sitter::Node,
792 cfg: &mut TestCfg,
793 counter: &mut i64,
794 block_stack: &mut Vec<BlockId>,
795 loop_stack: &mut Vec<BlockId>,
796 lang: SupportedLanguage,
797 ) -> Result<()> {
798 let match_block = block_stack.last().copied().unwrap_or(BlockId(0));
800 let merge_block = BlockId(*counter);
801 *counter += 1;
802
803 let mut cursor = node.walk();
805 for child in node.children(&mut cursor) {
806 if child.kind() == "block" {
807 let mut arm_cursor = child.walk();
809 for arm in child.children(&mut arm_cursor) {
810 if arm.kind() == "match_arm" {
811 let arm_block = BlockId(*counter);
812 *counter += 1;
813 cfg.add_edge(match_block, arm_block);
814
815 block_stack.push(arm_block);
816 Self::process_cfg_node(source, &arm, cfg, counter, block_stack, loop_stack, lang)?;
817 if let Some(current) = block_stack.pop() {
818 cfg.add_edge(current, merge_block);
819 }
820 }
821 }
822 }
823 }
824
825 block_stack.push(merge_block);
826 Ok(())
827 }
828
829 fn node_text(source: &str, node: &tree_sitter::Node) -> String {
830 source[node.start_byte()..node.end_byte()].to_string()
831 }
832}
833
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Asserts that exactly one function was extracted and hands it back.
    fn only_function(funcs: &[FunctionInfo]) -> &FunctionInfo {
        assert_eq!(funcs.len(), 1);
        &funcs[0]
    }

    #[test]
    fn test_language_detection() {
        let expectations = [
            ("test.c", SupportedLanguage::C),
            ("test.h", SupportedLanguage::C),
            ("Test.java", SupportedLanguage::Java),
            ("test.rs", SupportedLanguage::Rust),
        ];
        for &(file_name, language) in &expectations {
            assert_eq!(
                CfgExtractor::detect_language(Path::new(file_name)),
                Some(language)
            );
        }
    }

    #[test]
    fn test_extract_c_simple_function() {
        let source = r#"
        int add(int a, int b) {
            return a + b;
        }
        "#;

        let extracted = CfgExtractor::extract_c(source).unwrap();
        assert_eq!(only_function(&extracted).name, "add");
    }

    #[test]
    fn test_extract_c_with_if() {
        let source = r#"
        int max(int a, int b) {
            if (a > b) {
                return a;
            } else {
                return b;
            }
        }
        "#;

        let extracted = CfgExtractor::extract_c(source).unwrap();
        let max_fn = only_function(&extracted);
        assert!(max_fn.cfg.successors.len() >= 2);
    }

    #[test]
    fn test_extract_java_simple_method() {
        let source = r#"
        public class Test {
            public int add(int a, int b) {
                return a + b;
            }
        }
        "#;

        let extracted = CfgExtractor::extract_java(source).unwrap();
        assert_eq!(only_function(&extracted).name, "add");
    }

    #[test]
    fn test_extract_java_with_loop() {
        let source = r#"
        public class Test {
            public int sum(int n) {
                int total = 0;
                for (int i = 0; i < n; i++) {
                    total += i;
                }
                return total;
            }
        }
        "#;

        let extracted = CfgExtractor::extract_java(source).unwrap();
        let loops = only_function(&extracted).cfg.detect_loops();
        assert!(!loops.is_empty(), "Should detect at least one loop");
    }

    #[test]
    fn test_extract_rust_simple_function() {
        let source = r#"
        fn add(a: i32, b: i32) -> i32 {
            a + b
        }
        "#;

        let extracted = CfgExtractor::extract_rust(source).unwrap();
        assert_eq!(only_function(&extracted).name, "add");
    }

    #[test]
    fn test_extract_rust_if_expression() {
        let source = r#"
        fn max(a: i32, b: i32) -> i32 {
            if a > b {
                a
            } else {
                b
            }
        }
        "#;

        let extracted = CfgExtractor::extract_rust(source).unwrap();
        let func = only_function(&extracted);
        assert_eq!(func.name, "max");
        assert!(func.cfg.entry == BlockId(0));
    }

    #[test]
    fn test_extract_rust_loop() {
        let source = r#"
        fn countdown(mut n: i32) -> i32 {
            loop {
                if n <= 0 {
                    break;
                }
                n -= 1;
            }
            n
        }
        "#;

        let extracted = CfgExtractor::extract_rust(source).unwrap();
        let func = only_function(&extracted);
        assert_eq!(func.name, "countdown");
        assert!(func.cfg.entry == BlockId(0));
    }

    #[test]
    fn test_extract_rust_for_loop() {
        let source = r#"
        fn sum(n: i32) -> i32 {
            let mut total = 0;
            for i in 0..n {
                total += i;
            }
            total
        }
        "#;

        let extracted = CfgExtractor::extract_rust(source).unwrap();
        let func = only_function(&extracted);
        assert_eq!(func.name, "sum");
        assert!(func.cfg.entry == BlockId(0));
    }

    #[test]
    fn test_extract_rust_match_expression() {
        let source = r#"
        fn classify(n: i32) -> &'static str {
            match n {
                0 => "zero",
                1..=9 => "single digit",
                _ => "other",
            }
        }
        "#;

        let extracted = CfgExtractor::extract_rust(source).unwrap();
        assert_eq!(only_function(&extracted).name, "classify");
    }
}