1use std::collections::{BTreeSet, HashMap};
2
3use streaming_iterator::StreamingIterator;
4use tree_sitter::{Node, Query, QueryCursor};
5
6use crate::rules::RuleId;
7use crate::suppress::parse_suppression;
8
9pub fn count_captures(query: &Query, capture_name: &str, node: Node, source: &[u8]) -> usize {
10 let idx = match query.capture_index_for_name(capture_name) {
11 Some(i) => i,
12 None => return 0,
13 };
14 let mut cursor = QueryCursor::new();
15 let mut matches = cursor.matches(query, node, source);
16 let mut count = 0;
17 while let Some(m) = matches.next() {
18 count += m.captures.iter().filter(|c| c.index == idx).count();
19 }
20 count
21}
22
23pub fn has_any_match(query: &Query, capture_name: &str, node: Node, source: &[u8]) -> bool {
24 let idx = match query.capture_index_for_name(capture_name) {
25 Some(i) => i,
26 None => return false,
27 };
28 let mut cursor = QueryCursor::new();
29 let mut matches = cursor.matches(query, node, source);
30 while let Some(m) = matches.next() {
31 if m.captures.iter().any(|c| c.index == idx) {
32 return true;
33 }
34 }
35 false
36}
37
38pub fn collect_mock_class_names<F>(
39 query: &Query,
40 node: Node,
41 source: &[u8],
42 extract_name: F,
43) -> Vec<String>
44where
45 F: Fn(&str) -> String,
46{
47 let var_idx = match query.capture_index_for_name("var_name") {
48 Some(i) => i,
49 None => return Vec::new(),
50 };
51 let mut cursor = QueryCursor::new();
52 let mut matches = cursor.matches(query, node, source);
53 let mut names = BTreeSet::new();
54 while let Some(m) = matches.next() {
55 for c in m.captures.iter().filter(|c| c.index == var_idx) {
56 if let Ok(var) = c.node.utf8_text(source) {
57 names.insert(extract_name(var));
58 }
59 }
60 }
61 names.into_iter().collect()
62}
63
64fn collect_capture_ranges(
66 query: &Query,
67 capture_name: &str,
68 node: Node,
69 source: &[u8],
70) -> Vec<(usize, usize)> {
71 let idx = match query.capture_index_for_name(capture_name) {
72 Some(i) => i,
73 None => return Vec::new(),
74 };
75 let mut ranges = Vec::new();
76 let mut cursor = QueryCursor::new();
77 let mut matches = cursor.matches(query, node, source);
78 while let Some(m) = matches.next() {
79 for c in m.captures.iter().filter(|c| c.index == idx) {
80 ranges.push((c.node.start_byte(), c.node.end_byte()));
81 }
82 }
83 ranges
84}
85
86pub fn count_captures_within_context(
89 outer_query: &Query,
90 outer_capture: &str,
91 inner_query: &Query,
92 inner_capture: &str,
93 node: Node,
94 source: &[u8],
95) -> usize {
96 let ranges = collect_capture_ranges(outer_query, outer_capture, node, source);
97 if ranges.is_empty() {
98 return 0;
99 }
100
101 let inner_idx = match inner_query.capture_index_for_name(inner_capture) {
102 Some(i) => i,
103 None => return 0,
104 };
105
106 let mut count = 0;
107 let mut cursor = QueryCursor::new();
108 let mut matches = cursor.matches(inner_query, node, source);
109 while let Some(m) = matches.next() {
110 for c in m.captures.iter().filter(|c| c.index == inner_idx) {
111 let start = c.node.start_byte();
112 let end = c.node.end_byte();
113 if ranges.iter().any(|(rs, re)| start >= *rs && end <= *re) {
114 count += 1;
115 }
116 }
117 }
118
119 count
120}
121
/// Literal source texts that `count_duplicate_literals` ignores, because
/// repeating them across assertions is not treated as a duplicated value.
///
/// Entries are compared against the literal's exact source text, quotes included.
const TRIVIAL_LITERALS: &[&str] = &[
    // Small integers.
    "0", "1", "2",
    // Boolean spellings, lowercase and capitalised.
    "true", "false", "True", "False",
    // Null-like keywords.
    "None", "null", "undefined", "nil",
    // Empty strings in both quote styles.
    "\"\"", "''",
    // Small floats.
    "0.0", "1.0",
];
141
142pub fn count_duplicate_literals(
149 assertion_query: &Query,
150 node: Node,
151 source: &[u8],
152 literal_kinds: &[&str],
153) -> usize {
154 let ranges = collect_capture_ranges(assertion_query, "assertion", node, source);
155 if ranges.is_empty() {
156 return 0;
157 }
158
159 let mut counts: HashMap<String, usize> = HashMap::new();
161 let mut stack = vec![node];
162 while let Some(n) = stack.pop() {
163 let start = n.start_byte();
164 let end = n.end_byte();
165
166 let overlaps_any = ranges.iter().any(|(rs, re)| end > *rs && start < *re);
168 if !overlaps_any {
169 continue;
170 }
171
172 if literal_kinds.contains(&n.kind()) {
173 let in_assertion = ranges.iter().any(|(rs, re)| start >= *rs && end <= *re);
174 if in_assertion {
175 if let Ok(text) = n.utf8_text(source) {
176 if !TRIVIAL_LITERALS.contains(&text) {
177 *counts.entry(text.to_string()).or_insert(0) += 1;
178 }
179 }
180 }
181 }
182
183 for i in 0..n.child_count() {
184 if let Some(child) = n.child(i) {
185 stack.push(child);
186 }
187 }
188 }
189
190 counts.values().copied().max().unwrap_or(0)
191}
192
/// Counts the lines in `source_lines` that contain at least one of `patterns`
/// as a plain substring.
///
/// Each line counts at most once, however many patterns or occurrences it
/// holds. No syntax awareness is applied, so a match inside a comment still
/// counts. Returns 0 when `patterns` is empty.
pub fn count_custom_assertion_lines(source_lines: &[&str], patterns: &[String]) -> usize {
    if patterns.is_empty() {
        return 0;
    }
    let mut matched = 0;
    for line in source_lines {
        let hit = patterns.iter().map(String::as_str).any(|pat| line.contains(pat));
        if hit {
            matched += 1;
        }
    }
    matched
}
205
206pub fn apply_custom_assertion_fallback(
209 analysis: &mut crate::extractor::FileAnalysis,
210 source: &str,
211 patterns: &[String],
212) {
213 if patterns.is_empty() {
214 return;
215 }
216 let lines: Vec<&str> = source.lines().collect();
217 for func in &mut analysis.functions {
218 if func.analysis.assertion_count > 0 {
219 continue;
220 }
221 let start = func.line.saturating_sub(1);
223 let end = func.end_line.min(lines.len());
224 if start >= end {
225 continue;
226 }
227 let body_lines = &lines[start..end];
228 let count = count_custom_assertion_lines(body_lines, patterns);
229 func.analysis.assertion_count += count;
230 }
231}
232
233pub fn extract_suppression_from_previous_line(source: &str, start_row: usize) -> Vec<RuleId> {
234 if start_row == 0 {
235 return Vec::new();
236 }
237 let lines: Vec<&str> = source.lines().collect();
238 let prev_line = lines.get(start_row - 1).unwrap_or(&"");
239 parse_suppression(prev_line)
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extractor::{FileAnalysis, TestAnalysis, TestFunction};

    /// Captures every `assert` statement as `@assertion`.
    const ASSERTION_PATTERN: &str = "(assert_statement) @assertion";

    /// Captures attribute names starting with exactly one underscore
    /// (e.g. `obj._count`, but not `obj.__dict__`) as `@private_access`.
    const PRIVATE_ACCESS_PATTERN: &str =
        "(attribute attribute: (identifier) @private_access (#match? @private_access \"^_[^_]\"))";

    fn python_language() -> tree_sitter::Language {
        tree_sitter_python::LANGUAGE.into()
    }

    /// Parses `source` with the Python grammar.
    fn parse_python(source: &str) -> tree_sitter::Tree {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&python_language()).unwrap();
        parser.parse(source, None).unwrap()
    }

    /// Compiles `pattern` as a Python tree-sitter query.
    fn python_query(pattern: &str) -> Query {
        Query::new(&python_language(), pattern).unwrap()
    }

    /// Runs `count_captures_within_context` with assert statements as the outer
    /// context and single-underscore attribute access as the inner capture.
    fn count_private_in_asserts(source: &str, outer_capture: &str, inner_capture: &str) -> usize {
        let tree = parse_python(source);
        count_captures_within_context(
            &python_query(ASSERTION_PATTERN),
            outer_capture,
            &python_query(PRIVATE_ACCESS_PATTERN),
            inner_capture,
            tree.root_node(),
            source.as_bytes(),
        )
    }

    /// Runs `count_duplicate_literals` over `source` using `pattern` as the assertion query.
    fn duplicate_literals_in(source: &str, pattern: &str, kinds: &[&str]) -> usize {
        let tree = parse_python(source);
        count_duplicate_literals(&python_query(pattern), tree.root_node(), source.as_bytes(), kinds)
    }

    /// Builds a `FileAnalysis` holding one `test_foo` function starting on line 1.
    fn single_function_analysis(end_line: usize, assertion_count: usize) -> FileAnalysis {
        FileAnalysis {
            file: "test.py".to_string(),
            functions: vec![TestFunction {
                name: "test_foo".to_string(),
                file: "test.py".to_string(),
                line: 1,
                end_line,
                analysis: TestAnalysis {
                    assertion_count,
                    ..Default::default()
                },
            }],
            has_pbt_import: false,
            has_contract_import: false,
            has_error_test: false,
            has_relational_assertion: false,
            parameterized_count: 0,
        }
    }

    #[test]
    fn suppression_from_first_line_returns_empty() {
        assert!(extract_suppression_from_previous_line("any source", 0).is_empty());
    }

    #[test]
    fn suppression_from_previous_line_parses_comment() {
        let source = "// exspec-ignore: T001\nfn test_foo() {}";
        let result = extract_suppression_from_previous_line(source, 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "T001");
    }

    #[test]
    fn suppression_from_previous_line_no_comment() {
        let source = "// normal comment\nfn test_foo() {}";
        let result = extract_suppression_from_previous_line(source, 1);
        assert!(result.is_empty());
    }

    #[test]
    fn suppression_out_of_bounds_returns_empty() {
        let source = "single line";
        let result = extract_suppression_from_previous_line(source, 5);
        assert!(result.is_empty());
    }

    #[test]
    fn count_captures_within_context_basic() {
        let count = count_private_in_asserts(
            "def test_foo():\n assert obj._count == 1\n",
            "assertion",
            "private_access",
        );
        assert_eq!(count, 1, "should detect _count inside assert statement");
    }

    #[test]
    fn count_captures_within_context_outside() {
        let count = count_private_in_asserts(
            "def test_foo():\n x = obj._count\n assert x == 1\n",
            "assertion",
            "private_access",
        );
        assert_eq!(count, 0, "_count is outside assert, should not count");
    }

    #[test]
    fn count_captures_within_context_no_outer() {
        let count = count_private_in_asserts(
            "def test_foo():\n x = obj._count\n",
            "assertion",
            "private_access",
        );
        assert_eq!(count, 0, "no assertions, should return 0");
    }

    #[test]
    fn count_captures_missing_capture_returns_zero() {
        let source = "def test_foo():\n assert True\n";
        let tree = parse_python(source);
        let query = python_query(ASSERTION_PATTERN);
        let count = count_captures(&query, "nonexistent", tree.root_node(), source.as_bytes());
        assert_eq!(count, 0, "missing capture name should return 0, not panic");
    }

    #[test]
    fn count_captures_counts_every_capture() {
        let source = "def test_foo():\n assert True\n assert False\n";
        let tree = parse_python(source);
        let query = python_query(ASSERTION_PATTERN);
        let count = count_captures(&query, "assertion", tree.root_node(), source.as_bytes());
        assert_eq!(count, 2, "each assert statement is one capture");
    }

    #[test]
    fn has_any_match_reports_presence() {
        let query = python_query(ASSERTION_PATTERN);

        let with_assert = "def test_foo():\n assert True\n";
        let tree = parse_python(with_assert);
        assert!(has_any_match(&query, "assertion", tree.root_node(), with_assert.as_bytes()));
        assert!(!has_any_match(&query, "nonexistent", tree.root_node(), with_assert.as_bytes()));

        let without_assert = "def test_foo():\n x = 1\n";
        let tree = parse_python(without_assert);
        assert!(!has_any_match(&query, "assertion", tree.root_node(), without_assert.as_bytes()));
    }

    #[test]
    fn collect_mock_class_names_missing_capture_returns_empty() {
        let source = "def test_foo():\n assert True\n";
        let tree = parse_python(source);
        let query = python_query(ASSERTION_PATTERN);
        let names =
            collect_mock_class_names(&query, tree.root_node(), source.as_bytes(), |s| s.to_string());
        assert!(
            names.is_empty(),
            "missing @var_name capture should return empty vec, not panic"
        );
    }

    #[test]
    fn count_captures_within_context_missing_capture() {
        let source = "def test_foo():\n assert obj._count == 1\n";

        let count = count_private_in_asserts(source, "nonexistent", "private_access");
        assert_eq!(count, 0, "missing outer capture should return 0");

        let count = count_private_in_asserts(source, "assertion", "nonexistent");
        assert_eq!(count, 0, "missing inner capture should return 0");
    }

    #[test]
    fn count_duplicate_literals_detects_repeated_value() {
        let count = duplicate_literals_in(
            "def test_foo():\n assert calc(1) == 42\n assert calc(2) == 42\n assert calc(3) == 42\n",
            ASSERTION_PATTERN,
            &["integer", "float", "string"],
        );
        assert_eq!(count, 3, "42 appears 3 times in assertions");
    }

    #[test]
    fn count_duplicate_literals_trivial_excluded() {
        let count = duplicate_literals_in(
            "def test_foo():\n assert calc(1) == 0\n assert calc(2) == 0\n assert calc(1) == 0\n",
            ASSERTION_PATTERN,
            &["integer", "float", "string"],
        );
        assert_eq!(count, 0, "0, 1, 2 are all trivial and should be excluded");
    }

    #[test]
    fn count_duplicate_literals_no_assertions() {
        let count = duplicate_literals_in(
            "def test_foo():\n x = 42\n y = 42\n z = 42\n",
            ASSERTION_PATTERN,
            &["integer", "float", "string"],
        );
        assert_eq!(count, 0, "no assertions, should return 0");
    }

    #[test]
    fn count_custom_assertion_lines_empty_patterns() {
        let lines = ["util.assertEqual(x, 1)", "assert True"];
        assert_eq!(count_custom_assertion_lines(&lines, &[]), 0);
    }

    #[test]
    fn count_custom_assertion_lines_matching() {
        let lines = [
            " util.assertEqual(x, 1)",
            " util.assertEqual(y, 2)",
            " print(result)",
        ];
        let patterns = ["util.assertEqual(".to_string()];
        assert_eq!(count_custom_assertion_lines(&lines, &patterns), 2);
    }

    #[test]
    fn count_custom_assertion_lines_in_comment() {
        let lines = [" # util.assertEqual(x, 1)", " pass"];
        let patterns = ["util.assertEqual(".to_string()];
        assert_eq!(count_custom_assertion_lines(&lines, &patterns), 1);
    }

    #[test]
    fn count_custom_assertion_lines_no_match() {
        let lines = [" result = compute(42)", " print(result)"];
        let patterns = ["util.assertEqual(".to_string()];
        assert_eq!(count_custom_assertion_lines(&lines, &patterns), 0);
    }

    #[test]
    fn count_custom_assertion_lines_multiple_occurrences() {
        // Lines are counted, not occurrences: the first line holds two calls but counts once.
        let lines = [" myAssert(a) and myAssert(b)", " myAssert(c)"];
        let patterns = ["myAssert(".to_string()];
        assert_eq!(count_custom_assertion_lines(&lines, &patterns), 2);
    }

    #[test]
    fn count_custom_assertion_lines_multiple_patterns() {
        let lines = [" customCheck(x)"];
        let patterns = ["util.assertEqual(".to_string(), "customCheck(".to_string()];
        assert_eq!(count_custom_assertion_lines(&lines, &patterns), 1);
    }

    #[test]
    fn apply_fallback_skips_functions_with_assertions() {
        let source = "def test_foo():\n util.assertEqual(x, 1)\n assert True\n";
        let mut analysis = single_function_analysis(3, 1);
        let patterns = ["util.assertEqual(".to_string()];
        apply_custom_assertion_fallback(&mut analysis, source, &patterns);
        assert_eq!(analysis.functions[0].analysis.assertion_count, 1);
    }

    #[test]
    fn apply_fallback_increments_assertion_count() {
        let source = "def test_foo():\n util.assertEqual(x, 1)\n util.assertEqual(y, 2)\n";
        let mut analysis = single_function_analysis(3, 0);
        let patterns = ["util.assertEqual(".to_string()];
        apply_custom_assertion_fallback(&mut analysis, source, &patterns);
        assert_eq!(analysis.functions[0].analysis.assertion_count, 2);
    }

    #[test]
    fn apply_fallback_empty_patterns_noop() {
        let source = "def test_foo():\n util.assertEqual(x, 1)\n";
        let mut analysis = single_function_analysis(2, 0);
        apply_custom_assertion_fallback(&mut analysis, source, &[]);
        assert_eq!(analysis.functions[0].analysis.assertion_count, 0);
    }

    #[test]
    fn count_duplicate_literals_missing_capture() {
        let count = duplicate_literals_in(
            "def test_foo():\n assert 42 == 42\n",
            "(assert_statement) @something_else",
            &["integer"],
        );
        assert_eq!(count, 0, "missing @assertion capture should return 0");
    }
}