// exspec_core/query_utils.rs

1use std::collections::{BTreeSet, HashMap};
2
3use streaming_iterator::StreamingIterator;
4use tree_sitter::{Node, Query, QueryCursor};
5
6use crate::rules::RuleId;
7use crate::suppress::parse_suppression;
8
9pub fn count_captures(query: &Query, capture_name: &str, node: Node, source: &[u8]) -> usize {
10    let idx = match query.capture_index_for_name(capture_name) {
11        Some(i) => i,
12        None => return 0,
13    };
14    let mut cursor = QueryCursor::new();
15    let mut matches = cursor.matches(query, node, source);
16    let mut count = 0;
17    while let Some(m) = matches.next() {
18        count += m.captures.iter().filter(|c| c.index == idx).count();
19    }
20    count
21}
22
23pub fn has_any_match(query: &Query, capture_name: &str, node: Node, source: &[u8]) -> bool {
24    let idx = match query.capture_index_for_name(capture_name) {
25        Some(i) => i,
26        None => return false,
27    };
28    let mut cursor = QueryCursor::new();
29    let mut matches = cursor.matches(query, node, source);
30    while let Some(m) = matches.next() {
31        if m.captures.iter().any(|c| c.index == idx) {
32            return true;
33        }
34    }
35    false
36}
37
38pub fn collect_mock_class_names<F>(
39    query: &Query,
40    node: Node,
41    source: &[u8],
42    extract_name: F,
43) -> Vec<String>
44where
45    F: Fn(&str) -> String,
46{
47    let var_idx = match query.capture_index_for_name("var_name") {
48        Some(i) => i,
49        None => return Vec::new(),
50    };
51    let mut cursor = QueryCursor::new();
52    let mut matches = cursor.matches(query, node, source);
53    let mut names = BTreeSet::new();
54    while let Some(m) = matches.next() {
55        for c in m.captures.iter().filter(|c| c.index == var_idx) {
56            if let Ok(var) = c.node.utf8_text(source) {
57                names.insert(extract_name(var));
58            }
59        }
60    }
61    names.into_iter().collect()
62}
63
64/// Collect byte ranges of all captures matching `capture_name` in `query`.
65fn collect_capture_ranges(
66    query: &Query,
67    capture_name: &str,
68    node: Node,
69    source: &[u8],
70) -> Vec<(usize, usize)> {
71    let idx = match query.capture_index_for_name(capture_name) {
72        Some(i) => i,
73        None => return Vec::new(),
74    };
75    let mut ranges = Vec::new();
76    let mut cursor = QueryCursor::new();
77    let mut matches = cursor.matches(query, node, source);
78    while let Some(m) = matches.next() {
79        for c in m.captures.iter().filter(|c| c.index == idx) {
80            ranges.push((c.node.start_byte(), c.node.end_byte()));
81        }
82    }
83    ranges
84}
85
86/// Count captures of `inner_capture` from `inner_query` that fall within
87/// byte ranges of `outer_capture` from `outer_query`.
88pub fn count_captures_within_context(
89    outer_query: &Query,
90    outer_capture: &str,
91    inner_query: &Query,
92    inner_capture: &str,
93    node: Node,
94    source: &[u8],
95) -> usize {
96    let ranges = collect_capture_ranges(outer_query, outer_capture, node, source);
97    if ranges.is_empty() {
98        return 0;
99    }
100
101    let inner_idx = match inner_query.capture_index_for_name(inner_capture) {
102        Some(i) => i,
103        None => return 0,
104    };
105
106    let mut count = 0;
107    let mut cursor = QueryCursor::new();
108    let mut matches = cursor.matches(inner_query, node, source);
109    while let Some(m) = matches.next() {
110        for c in m.captures.iter().filter(|c| c.index == inner_idx) {
111            let start = c.node.start_byte();
112            let end = c.node.end_byte();
113            if ranges.iter().any(|(rs, re)| start >= *rs && end <= *re) {
114                count += 1;
115            }
116        }
117    }
118
119    count
120}
121
/// Literal spellings considered too common to flag as duplicates.
///
/// This is a single cross-language superset rather than a per-language list:
/// Python (`True`/`False`/`None`), JS (`null`/`undefined`), PHP/Ruby (`nil`).
/// Entries are compared against the exact source text of a literal node, so
/// the empty-string entries include their quote characters.
const TRIVIAL_LITERALS: &[&str] = &[
    "0",
    "1",
    "2",
    "true",
    "false",
    "True",
    "False",
    "None",
    "null",
    "undefined",
    "nil",
    "\"\"",
    "''",
    "0.0",
    "1.0",
];
141
142/// Count the maximum number of times any non-trivial literal appears
143/// within assertion nodes of the given function node.
144///
145/// `assertion_query` must have an `@assertion` capture.
146/// `literal_kinds` lists the tree-sitter node kind names that represent literals
147/// for the target language (e.g., `["integer", "float", "string"]` for Python).
148pub fn count_duplicate_literals(
149    assertion_query: &Query,
150    node: Node,
151    source: &[u8],
152    literal_kinds: &[&str],
153) -> usize {
154    let ranges = collect_capture_ranges(assertion_query, "assertion", node, source);
155    if ranges.is_empty() {
156        return 0;
157    }
158
159    // Walk tree, collect literals within assertion ranges
160    let mut counts: HashMap<String, usize> = HashMap::new();
161    let mut stack = vec![node];
162    while let Some(n) = stack.pop() {
163        let start = n.start_byte();
164        let end = n.end_byte();
165
166        // Prune subtrees that don't overlap with any assertion range
167        let overlaps_any = ranges.iter().any(|(rs, re)| end > *rs && start < *re);
168        if !overlaps_any {
169            continue;
170        }
171
172        if literal_kinds.contains(&n.kind()) {
173            let in_assertion = ranges.iter().any(|(rs, re)| start >= *rs && end <= *re);
174            if in_assertion {
175                if let Ok(text) = n.utf8_text(source) {
176                    if !TRIVIAL_LITERALS.contains(&text) {
177                        *counts.entry(text.to_string()).or_insert(0) += 1;
178                    }
179                }
180            }
181        }
182
183        for i in 0..n.child_count() {
184            if let Some(child) = n.child(i) {
185                stack.push(child);
186            }
187        }
188    }
189
190    counts.values().copied().max().unwrap_or(0)
191}
192
/// Text-based fallback for T001 escape hatch. Patterns are literal substrings, not regex.
/// Matches in comments, strings, and imports are included by design.
///
/// Returns the number of source lines that contain at least one pattern as a
/// substring. A line is counted once no matter how many patterns, or how many
/// occurrences of a pattern, it contains.
///
/// Empty patterns are ignored: the empty string is a substring of every line,
/// so a stray `""` entry would otherwise count every line as an assertion.
/// Returns `0` when no non-empty pattern is supplied.
pub fn count_custom_assertion_lines(source_lines: &[&str], patterns: &[String]) -> usize {
    let needles: Vec<&str> = patterns
        .iter()
        .map(String::as_str)
        .filter(|pattern| !pattern.is_empty())
        .collect();
    if needles.is_empty() {
        return 0;
    }

    source_lines
        .iter()
        .filter(|line| needles.iter().any(|&needle| line.contains(needle)))
        .count()
}
205
206/// Apply custom assertion pattern fallback to functions with assertion_count == 0.
207/// Only functions with no detected assertions are augmented; others are untouched.
208pub fn apply_custom_assertion_fallback(
209    analysis: &mut crate::extractor::FileAnalysis,
210    source: &str,
211    patterns: &[String],
212) {
213    if patterns.is_empty() {
214        return;
215    }
216    let lines: Vec<&str> = source.lines().collect();
217    for func in &mut analysis.functions {
218        if func.analysis.assertion_count > 0 {
219            continue;
220        }
221        // line/end_line are 1-based
222        let start = func.line.saturating_sub(1);
223        let end = func.end_line.min(lines.len());
224        if start >= end {
225            continue;
226        }
227        let body_lines = &lines[start..end];
228        let count = count_custom_assertion_lines(body_lines, patterns);
229        func.analysis.assertion_count += count;
230    }
231}
232
233pub fn extract_suppression_from_previous_line(source: &str, start_row: usize) -> Vec<RuleId> {
234    if start_row == 0 {
235        return Vec::new();
236    }
237    let lines: Vec<&str> = source.lines().collect();
238    let prev_line = lines.get(start_row - 1).unwrap_or(&"");
239    parse_suppression(prev_line)
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extractor::{FileAnalysis, TestAnalysis, TestFunction};

    // --- shared fixtures ---

    const ASSERTION_PATTERN: &str = "(assert_statement) @assertion";
    const PRIVATE_ACCESS_PATTERN: &str =
        "(attribute attribute: (identifier) @private_access (#match? @private_access \"^_[^_]\"))";
    const PYTHON_LITERAL_KINDS: &[&str] = &["integer", "float", "string"];

    fn python_language() -> tree_sitter::Language {
        tree_sitter_python::LANGUAGE.into()
    }

    /// Parses `source` as Python and returns the resulting syntax tree.
    fn parse_python(source: &str) -> tree_sitter::Tree {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&python_language()).unwrap();
        parser.parse(source, None).unwrap()
    }

    /// Compiles `pattern` as a query for the Python grammar.
    fn python_query(pattern: &str) -> Query {
        Query::new(&python_language(), pattern).unwrap()
    }

    /// Runs `count_captures_within_context` with the assertion query as the
    /// outer query and the private-access query as the inner query.
    fn private_accesses_in_assertions(
        source: &str,
        outer_capture: &str,
        inner_capture: &str,
    ) -> usize {
        let tree = parse_python(source);
        let assertion_query = python_query(ASSERTION_PATTERN);
        let private_query = python_query(PRIVATE_ACCESS_PATTERN);
        count_captures_within_context(
            &assertion_query,
            outer_capture,
            &private_query,
            inner_capture,
            tree.root_node(),
            source.as_bytes(),
        )
    }

    /// Runs `count_duplicate_literals` over `source` with the given query pattern.
    fn duplicate_literals(source: &str, pattern: &str, literal_kinds: &[&str]) -> usize {
        let tree = parse_python(source);
        let query = python_query(pattern);
        count_duplicate_literals(&query, tree.root_node(), source.as_bytes(), literal_kinds)
    }

    /// Builds owned pattern strings from string slices.
    fn patterns(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds a `FileAnalysis` holding a single `test_foo` function that starts
    /// on line 1, ends on `end_line`, and has the given `assertion_count`.
    fn single_function_analysis(end_line: usize, assertion_count: usize) -> FileAnalysis {
        FileAnalysis {
            file: "test.py".to_string(),
            functions: vec![TestFunction {
                name: "test_foo".to_string(),
                file: "test.py".to_string(),
                line: 1,
                end_line,
                analysis: TestAnalysis {
                    assertion_count,
                    ..Default::default()
                },
            }],
            has_pbt_import: false,
            has_contract_import: false,
            has_error_test: false,
            has_relational_assertion: false,
            parameterized_count: 0,
        }
    }

    // --- extract_suppression_from_previous_line ---

    #[test]
    fn suppression_from_first_line_returns_empty() {
        assert!(extract_suppression_from_previous_line("any source", 0).is_empty());
    }

    #[test]
    fn suppression_from_previous_line_parses_comment() {
        let source = "// exspec-ignore: T001\nfn test_foo() {}";
        let result = extract_suppression_from_previous_line(source, 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "T001");
    }

    #[test]
    fn suppression_from_previous_line_no_comment() {
        let source = "// normal comment\nfn test_foo() {}";
        let result = extract_suppression_from_previous_line(source, 1);
        assert!(result.is_empty());
    }

    #[test]
    fn suppression_out_of_bounds_returns_empty() {
        let result = extract_suppression_from_previous_line("single line", 5);
        assert!(result.is_empty());
    }

    // --- count_captures_within_context ---

    #[test]
    fn count_captures_within_context_basic() {
        // assert obj._count == 1 -> _count is inside assert_statement (@assertion)
        let count = private_accesses_in_assertions(
            "def test_foo():\n    assert obj._count == 1\n",
            "assertion",
            "private_access",
        );
        assert_eq!(count, 1, "should detect _count inside assert statement");
    }

    #[test]
    fn count_captures_within_context_outside() {
        // _count is outside assert -> should not count
        let count = private_accesses_in_assertions(
            "def test_foo():\n    x = obj._count\n    assert x == 1\n",
            "assertion",
            "private_access",
        );
        assert_eq!(count, 0, "_count is outside assert, should not count");
    }

    #[test]
    fn count_captures_within_context_no_outer() {
        // No assert statement at all
        let count = private_accesses_in_assertions(
            "def test_foo():\n    x = obj._count\n",
            "assertion",
            "private_access",
        );
        assert_eq!(count, 0, "no assertions, should return 0");
    }

    #[test]
    fn count_captures_within_context_missing_capture() {
        // Capture name doesn't exist in query -> defensive 0
        let source = "def test_foo():\n    assert obj._count == 1\n";

        // Wrong capture name for outer
        let count = private_accesses_in_assertions(source, "nonexistent", "private_access");
        assert_eq!(count, 0, "missing outer capture should return 0");

        // Wrong capture name for inner
        let count = private_accesses_in_assertions(source, "assertion", "nonexistent");
        assert_eq!(count, 0, "missing inner capture should return 0");
    }

    // --- count_captures / collect_mock_class_names ---

    #[test]
    fn count_captures_missing_capture_returns_zero() {
        // Query with capture @assertion, but we ask for nonexistent name
        let source = "def test_foo():\n    assert True\n";
        let tree = parse_python(source);
        let query = python_query(ASSERTION_PATTERN);

        let count = count_captures(&query, "nonexistent", tree.root_node(), source.as_bytes());
        assert_eq!(count, 0, "missing capture name should return 0, not panic");
    }

    #[test]
    fn collect_mock_class_names_missing_capture_returns_empty() {
        // Query without @var_name capture
        let source = "def test_foo():\n    assert True\n";
        let tree = parse_python(source);
        let query = python_query(ASSERTION_PATTERN);

        let names =
            collect_mock_class_names(&query, tree.root_node(), source.as_bytes(), |s| {
                s.to_string()
            });
        assert!(
            names.is_empty(),
            "missing @var_name capture should return empty vec, not panic"
        );
    }

    // --- count_duplicate_literals ---

    #[test]
    fn count_duplicate_literals_detects_repeated_value() {
        let count = duplicate_literals(
            "def test_foo():\n    assert calc(1) == 42\n    assert calc(2) == 42\n    assert calc(3) == 42\n",
            ASSERTION_PATTERN,
            PYTHON_LITERAL_KINDS,
        );
        assert_eq!(count, 3, "42 appears 3 times in assertions");
    }

    #[test]
    fn count_duplicate_literals_trivial_excluded() {
        // All literals are trivial (0, 1, 2) - should return 0
        let count = duplicate_literals(
            "def test_foo():\n    assert calc(1) == 0\n    assert calc(2) == 0\n    assert calc(1) == 0\n",
            ASSERTION_PATTERN,
            PYTHON_LITERAL_KINDS,
        );
        assert_eq!(count, 0, "0, 1, 2 are all trivial and should be excluded");
    }

    #[test]
    fn count_duplicate_literals_no_assertions() {
        let count = duplicate_literals(
            "def test_foo():\n    x = 42\n    y = 42\n    z = 42\n",
            ASSERTION_PATTERN,
            PYTHON_LITERAL_KINDS,
        );
        assert_eq!(count, 0, "no assertions, should return 0");
    }

    #[test]
    fn count_duplicate_literals_missing_capture() {
        // Query without @assertion capture
        let count = duplicate_literals(
            "def test_foo():\n    assert 42 == 42\n",
            "(assert_statement) @something_else",
            &["integer"],
        );
        assert_eq!(count, 0, "missing @assertion capture should return 0");
    }

    // --- count_custom_assertion_lines ---

    // TC-04: empty patterns -> 0
    #[test]
    fn count_custom_assertion_lines_empty_patterns() {
        let lines = ["util.assertEqual(x, 1)", "assert True"];
        assert_eq!(count_custom_assertion_lines(&lines, &[]), 0);
    }

    // TC-05: matching pattern returns correct count
    #[test]
    fn count_custom_assertion_lines_matching() {
        let lines = [
            "    util.assertEqual(x, 1)",
            "    util.assertEqual(y, 2)",
            "    print(result)",
        ];
        let found = count_custom_assertion_lines(&lines, &patterns(&["util.assertEqual("]));
        assert_eq!(found, 2);
    }

    // TC-06: pattern in comment still counts (by design)
    #[test]
    fn count_custom_assertion_lines_in_comment() {
        let lines = ["    # util.assertEqual(x, 1)", "    pass"];
        let found = count_custom_assertion_lines(&lines, &patterns(&["util.assertEqual("]));
        assert_eq!(found, 1);
    }

    // TC-07: no matches -> 0
    #[test]
    fn count_custom_assertion_lines_no_match() {
        let lines = ["    result = compute(42)", "    print(result)"];
        let found = count_custom_assertion_lines(&lines, &patterns(&["util.assertEqual("]));
        assert_eq!(found, 0);
    }

    // TC-08: same pattern on multiple lines returns line count
    #[test]
    fn count_custom_assertion_lines_multiple_occurrences() {
        let lines = ["    myAssert(a) and myAssert(b)", "    myAssert(c)"];
        // Line count, not occurrence count: line 1 has 2 but counts as 1
        let found = count_custom_assertion_lines(&lines, &patterns(&["myAssert("]));
        assert_eq!(found, 2);
    }

    // TC-16: multiple patterns, one matches
    #[test]
    fn count_custom_assertion_lines_multiple_patterns() {
        let lines = ["    customCheck(x)"];
        let found = count_custom_assertion_lines(
            &lines,
            &patterns(&["util.assertEqual(", "customCheck("]),
        );
        assert_eq!(found, 1);
    }

    // --- apply_custom_assertion_fallback ---

    // TC-09: assertion_count > 0 -> unchanged
    #[test]
    fn apply_fallback_skips_functions_with_assertions() {
        let source = "def test_foo():\n    util.assertEqual(x, 1)\n    assert True\n";
        let mut analysis = single_function_analysis(3, 1);
        apply_custom_assertion_fallback(&mut analysis, source, &patterns(&["util.assertEqual("]));
        assert_eq!(analysis.functions[0].analysis.assertion_count, 1);
    }

    // TC-10: assertion_count == 0 + custom match -> incremented
    #[test]
    fn apply_fallback_increments_assertion_count() {
        let source = "def test_foo():\n    util.assertEqual(x, 1)\n    util.assertEqual(y, 2)\n";
        let mut analysis = single_function_analysis(3, 0);
        apply_custom_assertion_fallback(&mut analysis, source, &patterns(&["util.assertEqual("]));
        assert_eq!(analysis.functions[0].analysis.assertion_count, 2);
    }

    // Empty patterns -> no-op
    #[test]
    fn apply_fallback_empty_patterns_noop() {
        let source = "def test_foo():\n    util.assertEqual(x, 1)\n";
        let mut analysis = single_function_analysis(2, 0);
        apply_custom_assertion_fallback(&mut analysis, source, &[]);
        assert_eq!(analysis.functions[0].analysis.assertion_count, 0);
    }
}
652}