1use std::collections::{BTreeSet, HashMap};
2
3use streaming_iterator::StreamingIterator;
4use tree_sitter::{Node, Query, QueryCursor};
5
6use crate::rules::RuleId;
7use crate::suppress::parse_suppression;
8
9pub fn count_captures(query: &Query, capture_name: &str, node: Node, source: &[u8]) -> usize {
10 let idx = match query.capture_index_for_name(capture_name) {
11 Some(i) => i,
12 None => return 0,
13 };
14 let mut cursor = QueryCursor::new();
15 let mut matches = cursor.matches(query, node, source);
16 let mut count = 0;
17 while let Some(m) = matches.next() {
18 count += m.captures.iter().filter(|c| c.index == idx).count();
19 }
20 count
21}
22
23pub fn has_any_match(query: &Query, capture_name: &str, node: Node, source: &[u8]) -> bool {
24 let idx = match query.capture_index_for_name(capture_name) {
25 Some(i) => i,
26 None => return false,
27 };
28 let mut cursor = QueryCursor::new();
29 let mut matches = cursor.matches(query, node, source);
30 while let Some(m) = matches.next() {
31 if m.captures.iter().any(|c| c.index == idx) {
32 return true;
33 }
34 }
35 false
36}
37
38pub fn collect_mock_class_names<F>(
39 query: &Query,
40 node: Node,
41 source: &[u8],
42 extract_name: F,
43) -> Vec<String>
44where
45 F: Fn(&str) -> String,
46{
47 let var_idx = match query.capture_index_for_name("var_name") {
48 Some(i) => i,
49 None => return Vec::new(),
50 };
51 let mut cursor = QueryCursor::new();
52 let mut matches = cursor.matches(query, node, source);
53 let mut names = BTreeSet::new();
54 while let Some(m) = matches.next() {
55 for c in m.captures.iter().filter(|c| c.index == var_idx) {
56 if let Ok(var) = c.node.utf8_text(source) {
57 names.insert(extract_name(var));
58 }
59 }
60 }
61 names.into_iter().collect()
62}
63
64fn collect_capture_ranges(
66 query: &Query,
67 capture_name: &str,
68 node: Node,
69 source: &[u8],
70) -> Vec<(usize, usize)> {
71 let idx = match query.capture_index_for_name(capture_name) {
72 Some(i) => i,
73 None => return Vec::new(),
74 };
75 let mut ranges = Vec::new();
76 let mut cursor = QueryCursor::new();
77 let mut matches = cursor.matches(query, node, source);
78 while let Some(m) = matches.next() {
79 for c in m.captures.iter().filter(|c| c.index == idx) {
80 ranges.push((c.node.start_byte(), c.node.end_byte()));
81 }
82 }
83 ranges
84}
85
86pub fn count_captures_within_context(
89 outer_query: &Query,
90 outer_capture: &str,
91 inner_query: &Query,
92 inner_capture: &str,
93 node: Node,
94 source: &[u8],
95) -> usize {
96 let ranges = collect_capture_ranges(outer_query, outer_capture, node, source);
97 if ranges.is_empty() {
98 return 0;
99 }
100
101 let inner_idx = match inner_query.capture_index_for_name(inner_capture) {
102 Some(i) => i,
103 None => return 0,
104 };
105
106 let mut count = 0;
107 let mut cursor = QueryCursor::new();
108 let mut matches = cursor.matches(inner_query, node, source);
109 while let Some(m) = matches.next() {
110 for c in m.captures.iter().filter(|c| c.index == inner_idx) {
111 let start = c.node.start_byte();
112 let end = c.node.end_byte();
113 if ranges.iter().any(|(rs, re)| start >= *rs && end <= *re) {
114 count += 1;
115 }
116 }
117 }
118
119 count
120}
121
/// Literal spellings that are too common to signal a copy-pasted expected value.
///
/// `count_duplicate_literals` ignores any literal whose exact source text appears here.
const TRIVIAL_LITERALS: &[&str] = &[
    // Small integers.
    "0", "1", "2",
    // Boolean spellings (lowercase and capitalized).
    "true", "false", "True", "False",
    // Null-like keywords.
    "None", "null", "undefined", "nil",
    // Empty string literals in double and single quotes.
    "\"\"", "''",
    // Common float constants.
    "0.0", "1.0",
];
141
142pub fn count_duplicate_literals(
149 assertion_query: &Query,
150 node: Node,
151 source: &[u8],
152 literal_kinds: &[&str],
153) -> usize {
154 let ranges = collect_capture_ranges(assertion_query, "assertion", node, source);
155 if ranges.is_empty() {
156 return 0;
157 }
158
159 let mut counts: HashMap<String, usize> = HashMap::new();
161 let mut stack = vec![node];
162 while let Some(n) = stack.pop() {
163 let start = n.start_byte();
164 let end = n.end_byte();
165
166 let overlaps_any = ranges.iter().any(|(rs, re)| end > *rs && start < *re);
168 if !overlaps_any {
169 continue;
170 }
171
172 if literal_kinds.contains(&n.kind()) {
173 let in_assertion = ranges.iter().any(|(rs, re)| start >= *rs && end <= *re);
174 if in_assertion {
175 if let Ok(text) = n.utf8_text(source) {
176 if !TRIVIAL_LITERALS.contains(&text) {
177 *counts.entry(text.to_string()).or_insert(0) += 1;
178 }
179 }
180 }
181 }
182
183 for i in 0..n.child_count() {
184 if let Some(child) = n.child(i) {
185 stack.push(child);
186 }
187 }
188 }
189
190 counts.values().copied().max().unwrap_or(0)
191}
192
/// Counts the lines in `source_lines` that contain at least one of `patterns` as a plain
/// substring.
///
/// A line with several matches is counted once, and empty pattern strings are ignored.
/// The check is purely textual, so occurrences inside comments or string literals also count.
pub fn count_custom_assertion_lines(source_lines: &[&str], patterns: &[String]) -> usize {
    let active: Vec<&str> = patterns
        .iter()
        .map(String::as_str)
        .filter(|p| !p.is_empty())
        .collect();
    if active.is_empty() {
        return 0;
    }

    let mut total = 0;
    for line in source_lines {
        if active.iter().any(|p| line.contains(*p)) {
            total += 1;
        }
    }
    total
}
209
210pub fn apply_custom_assertion_fallback(
213 analysis: &mut crate::extractor::FileAnalysis,
214 source: &str,
215 patterns: &[String],
216) {
217 if patterns.is_empty() {
218 return;
219 }
220 let lines: Vec<&str> = source.lines().collect();
221 for func in &mut analysis.functions {
222 if func.analysis.assertion_count > 0 {
223 continue;
224 }
225 let start = func.line.saturating_sub(1);
227 let end = func.end_line.min(lines.len());
228 if start >= end {
229 continue;
230 }
231 let body_lines = &lines[start..end];
232 let count = count_custom_assertion_lines(body_lines, patterns);
233 func.analysis.assertion_count += count;
234 }
235}
236
237pub fn extract_suppression_from_previous_line(source: &str, start_row: usize) -> Vec<RuleId> {
238 if start_row == 0 {
239 return Vec::new();
240 }
241 let lines: Vec<&str> = source.lines().collect();
242 let prev_line = lines.get(start_row - 1).unwrap_or(&"");
243 parse_suppression(prev_line)
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extractor::{FileAnalysis, TestAnalysis, TestFunction};

    // Captures single-underscore attribute names such as `obj._count` (not `__dunder`).
    const PRIVATE_ACCESS_PATTERN: &str =
        "(attribute attribute: (identifier) @private_access (#match? @private_access \"^_[^_]\"))";

    // Literal node kinds used by the duplicate-literal tests.
    const PY_LITERAL_KINDS: &[&str] = &["integer", "float", "string"];

    fn python_language() -> tree_sitter::Language {
        tree_sitter_python::LANGUAGE.into()
    }

    /// Parses `source` with the Python grammar, panicking if setup or parsing fails.
    fn parse_python(source: &str) -> tree_sitter::Tree {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&python_language()).unwrap();
        parser.parse(source, None).unwrap()
    }

    /// Compiles `pattern` against the Python grammar, panicking on a query error.
    fn python_query(pattern: &str) -> Query {
        Query::new(&python_language(), pattern).unwrap()
    }

    /// Query that captures each `assert` statement as `@assertion`.
    fn assertion_query() -> Query {
        python_query("(assert_statement) @assertion")
    }

    /// Counts `inner` private-access captures lying inside `outer` assertion captures.
    fn private_access_in_assertions(source: &str, outer: &str, inner: &str) -> usize {
        let tree = parse_python(source);
        count_captures_within_context(
            &assertion_query(),
            outer,
            &python_query(PRIVATE_ACCESS_PATTERN),
            inner,
            tree.root_node(),
            source.as_bytes(),
        )
    }

    /// Runs `count_duplicate_literals` over a freshly parsed Python `source`.
    fn duplicate_literals(source: &str, query: &Query, kinds: &[&str]) -> usize {
        let tree = parse_python(source);
        count_duplicate_literals(query, tree.root_node(), source.as_bytes(), kinds)
    }

    /// Builds a `test.py` analysis holding a single `test_foo` that starts on line 1.
    fn single_test_analysis(end_line: usize, assertion_count: usize) -> FileAnalysis {
        FileAnalysis {
            file: "test.py".to_string(),
            functions: vec![TestFunction {
                name: "test_foo".to_string(),
                file: "test.py".to_string(),
                line: 1,
                end_line,
                analysis: TestAnalysis {
                    assertion_count,
                    ..Default::default()
                },
            }],
            has_pbt_import: false,
            has_contract_import: false,
            has_error_test: false,
            has_relational_assertion: false,
            parameterized_count: 0,
        }
    }

    // --- suppression comments ---------------------------------------------------------

    #[test]
    fn suppression_from_first_line_returns_empty() {
        let rules = extract_suppression_from_previous_line("any source", 0);
        assert!(rules.is_empty());
    }

    #[test]
    fn suppression_from_previous_line_parses_comment() {
        let rules =
            extract_suppression_from_previous_line("// exspec-ignore: T001\nfn test_foo() {}", 1);
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].0, "T001");
    }

    #[test]
    fn suppression_from_previous_line_no_comment() {
        let rules = extract_suppression_from_previous_line("// normal comment\nfn test_foo() {}", 1);
        assert!(rules.is_empty());
    }

    #[test]
    fn suppression_out_of_bounds_returns_empty() {
        let rules = extract_suppression_from_previous_line("single line", 5);
        assert!(rules.is_empty());
    }

    // --- captures within context ------------------------------------------------------

    #[test]
    fn count_captures_within_context_basic() {
        let count = private_access_in_assertions(
            "def test_foo():\n assert obj._count == 1\n",
            "assertion",
            "private_access",
        );
        assert_eq!(count, 1, "should detect _count inside assert statement");
    }

    #[test]
    fn count_captures_within_context_outside() {
        let count = private_access_in_assertions(
            "def test_foo():\n x = obj._count\n assert x == 1\n",
            "assertion",
            "private_access",
        );
        assert_eq!(count, 0, "_count is outside assert, should not count");
    }

    #[test]
    fn count_captures_within_context_no_outer() {
        let count = private_access_in_assertions(
            "def test_foo():\n x = obj._count\n",
            "assertion",
            "private_access",
        );
        assert_eq!(count, 0, "no assertions, should return 0");
    }

    #[test]
    fn count_captures_missing_capture_returns_zero() {
        let source = "def test_foo():\n assert True\n";
        let tree = parse_python(source);
        let count = count_captures(
            &assertion_query(),
            "nonexistent",
            tree.root_node(),
            source.as_bytes(),
        );
        assert_eq!(count, 0, "missing capture name should return 0, not panic");
    }

    #[test]
    fn collect_mock_class_names_missing_capture_returns_empty() {
        let source = "def test_foo():\n assert True\n";
        let tree = parse_python(source);
        let names = collect_mock_class_names(
            &assertion_query(),
            tree.root_node(),
            source.as_bytes(),
            |name| name.to_owned(),
        );
        assert!(
            names.is_empty(),
            "missing @var_name capture should return empty vec, not panic"
        );
    }

    #[test]
    fn count_captures_within_context_missing_capture() {
        let source = "def test_foo():\n assert obj._count == 1\n";
        assert_eq!(
            private_access_in_assertions(source, "nonexistent", "private_access"),
            0,
            "missing outer capture should return 0"
        );
        assert_eq!(
            private_access_in_assertions(source, "assertion", "nonexistent"),
            0,
            "missing inner capture should return 0"
        );
    }

    // --- duplicate literals -----------------------------------------------------------

    #[test]
    fn count_duplicate_literals_detects_repeated_value() {
        let source = "def test_foo():\n assert calc(1) == 42\n assert calc(2) == 42\n assert calc(3) == 42\n";
        let count = duplicate_literals(source, &assertion_query(), PY_LITERAL_KINDS);
        assert_eq!(count, 3, "42 appears 3 times in assertions");
    }

    #[test]
    fn count_duplicate_literals_trivial_excluded() {
        let source =
            "def test_foo():\n assert calc(1) == 0\n assert calc(2) == 0\n assert calc(1) == 0\n";
        let count = duplicate_literals(source, &assertion_query(), PY_LITERAL_KINDS);
        assert_eq!(count, 0, "0, 1, 2 are all trivial and should be excluded");
    }

    #[test]
    fn count_duplicate_literals_no_assertions() {
        let source = "def test_foo():\n x = 42\n y = 42\n z = 42\n";
        let count = duplicate_literals(source, &assertion_query(), PY_LITERAL_KINDS);
        assert_eq!(count, 0, "no assertions, should return 0");
    }

    #[test]
    fn count_duplicate_literals_missing_capture() {
        let source = "def test_foo():\n assert 42 == 42\n";
        let query = python_query("(assert_statement) @something_else");
        let count = duplicate_literals(source, &query, &["integer"]);
        assert_eq!(count, 0, "missing @assertion capture should return 0");
    }

    // --- custom assertion line matching -----------------------------------------------

    #[test]
    fn count_custom_assertion_lines_empty_patterns() {
        let lines = ["util.assertEqual(x, 1)", "assert True"];
        assert_eq!(count_custom_assertion_lines(&lines, &[]), 0);
    }

    #[test]
    fn count_custom_assertion_lines_matching() {
        let lines = [
            " util.assertEqual(x, 1)",
            " util.assertEqual(y, 2)",
            " print(result)",
        ];
        let patterns = ["util.assertEqual(".to_string()];
        assert_eq!(count_custom_assertion_lines(&lines, &patterns), 2);
    }

    #[test]
    fn count_custom_assertion_lines_in_comment() {
        let lines = [" # util.assertEqual(x, 1)", " pass"];
        let patterns = ["util.assertEqual(".to_string()];
        assert_eq!(count_custom_assertion_lines(&lines, &patterns), 1);
    }

    #[test]
    fn count_custom_assertion_lines_no_match() {
        let lines = [" result = compute(42)", " print(result)"];
        let patterns = ["util.assertEqual(".to_string()];
        assert_eq!(count_custom_assertion_lines(&lines, &patterns), 0);
    }

    #[test]
    fn count_custom_assertion_lines_multiple_occurrences() {
        // Two hits on one line still count that line once.
        let lines = [" myAssert(a) and myAssert(b)", " myAssert(c)"];
        let patterns = ["myAssert(".to_string()];
        assert_eq!(count_custom_assertion_lines(&lines, &patterns), 2);
    }

    #[test]
    fn count_custom_assertion_lines_multiple_patterns() {
        let lines = [" customCheck(x)"];
        let patterns = ["util.assertEqual(".to_string(), "customCheck(".to_string()];
        assert_eq!(count_custom_assertion_lines(&lines, &patterns), 1);
    }

    #[test]
    fn empty_string_pattern_ignored() {
        let lines = ["assert True", "x = 1", "print(result)"];
        let patterns = ["".to_string()];
        assert_eq!(
            count_custom_assertion_lines(&lines, &patterns),
            0,
            "empty string pattern should not match any line"
        );
    }

    #[test]
    fn mixed_empty_and_valid_patterns() {
        let lines = [" assert_custom(x)", " print(result)"];
        let patterns = ["".to_string(), "assert_custom".to_string()];
        assert_eq!(
            count_custom_assertion_lines(&lines, &patterns),
            1,
            "only valid patterns should match"
        );
    }

    #[test]
    fn whitespace_only_pattern_matches() {
        let patterns = [" ".to_string()];
        let lines = ["assert_true", "no_space_here"];
        assert_eq!(
            count_custom_assertion_lines(&lines, &patterns),
            0,
            "whitespace pattern should not match lines without spaces"
        );
        let lines_with_space = ["assert true", "nospace"];
        assert_eq!(
            count_custom_assertion_lines(&lines_with_space, &patterns),
            1,
            "whitespace pattern should match lines containing spaces"
        );
    }

    // --- fallback application ---------------------------------------------------------

    #[test]
    fn apply_fallback_skips_functions_with_assertions() {
        let source = "def test_foo():\n util.assertEqual(x, 1)\n assert True\n";
        let mut analysis = single_test_analysis(3, 1);
        apply_custom_assertion_fallback(&mut analysis, source, &["util.assertEqual(".to_string()]);
        assert_eq!(analysis.functions[0].analysis.assertion_count, 1);
    }

    #[test]
    fn apply_fallback_increments_assertion_count() {
        let source = "def test_foo():\n util.assertEqual(x, 1)\n util.assertEqual(y, 2)\n";
        let mut analysis = single_test_analysis(3, 0);
        apply_custom_assertion_fallback(&mut analysis, source, &["util.assertEqual(".to_string()]);
        assert_eq!(analysis.functions[0].analysis.assertion_count, 2);
    }

    #[test]
    fn apply_fallback_empty_patterns_noop() {
        let source = "def test_foo():\n util.assertEqual(x, 1)\n";
        let mut analysis = single_test_analysis(2, 0);
        apply_custom_assertion_fallback(&mut analysis, source, &[]);
        assert_eq!(analysis.functions[0].analysis.assertion_count, 0);
    }

    #[test]
    fn apply_fallback_end_line_exceeds_source() {
        let source = "def test_foo():\n custom_assert(x)\n";
        let mut analysis = single_test_analysis(12, 0);
        apply_custom_assertion_fallback(&mut analysis, source, &["custom_assert".to_string()]);
        assert_eq!(
            analysis.functions[0].analysis.assertion_count, 1,
            "should handle end_line > source length without panic"
        );
    }

    #[test]
    fn apply_fallback_empty_string_pattern_noop() {
        let source = "def test_foo():\n some_call(x)\n another_call(y)\n";
        let mut analysis = single_test_analysis(3, 0);
        apply_custom_assertion_fallback(&mut analysis, source, &["".to_string()]);
        assert_eq!(
            analysis.functions[0].analysis.assertion_count, 0,
            "empty-string-only patterns should not increment assertion_count"
        );
    }
}