rust_covfix/
rule.rs

1use proc_macro2::TokenTree;
2use regex::Regex;
3use std::fs;
4use std::marker::PhantomData;
5use std::path::Path;
6use syn::visit::Visit;
7use syn::{
8    ExprForLoop, ExprMacro, Fields, File, ItemEnum, ItemFn, ItemMod, ItemStruct, ItemUnion,
9    MacroDelimiter,
10};
11
12use crate::error::*;
13use crate::{BranchCoverage, FileCoverage, LineCoverage};
14
15pub struct SourceCode {
16    pub content: String,
17    pub ast: File,
18}
19
20impl SourceCode {
21    pub fn new(filename: &Path) -> Result<SourceCode, Error> {
22        let content = fs::read_to_string(filename)
23            .chain_err(|| ErrorKind::SourceFileNotFound(filename.to_owned()))?;
24        let ast =
25            syn::parse_file(&content).chain_err(|| format!("Failed to parse {:?}", filename))?;
26        Ok(SourceCode { content, ast })
27    }
28}
29
30pub trait Rule {
31    fn fix_file_coverage(&self, source: &SourceCode, file_cov: &mut FileCoverage);
32}
33
34pub struct CloseBlockRule {
35    reg: Regex,
36}
37
38impl CloseBlockRule {
39    pub fn new() -> Self {
40        Self {
41            reg: Regex::new(
42                r"^(?:\s*\}(?:\s*\))*(?:\s*;)?|\s*(?:\}\s*)?else(?:\s*\{)?)?\s*(?://.*)?$",
43            )
44            .unwrap(),
45        }
46    }
47}
48
49impl Rule for CloseBlockRule {
50    fn fix_file_coverage(&self, source: &SourceCode, file_cov: &mut FileCoverage) {
51        for entry in PerLineIterator::new(&source.content, file_cov) {
52            if entry.line_cov.is_none() && entry.branch_covs.is_empty() {
53                continue;
54            }
55
56            if self.reg.is_match(entry.line) {
57                if let Some(line_cov) = entry.line_cov {
58                    line_cov.count = None;
59                }
60
61                entry.branch_covs.iter_mut().for_each(|v| v.taken = None);
62            }
63        }
64    }
65}
66
67pub struct TestRule;
68
69impl TestRule {
70    pub fn new() -> Self {
71        Self
72    }
73}
74
75impl Rule for TestRule {
76    fn fix_file_coverage(&self, source: &SourceCode, file_cov: &mut FileCoverage) {
77        let mut inner = TestRuleInner { file_cov };
78        inner.visit_file(&source.ast);
79    }
80}
81
82struct TestRuleInner<'a> {
83    file_cov: &'a mut FileCoverage,
84}
85
86impl<'a> TestRuleInner<'a> {
87    fn ignore_range(&mut self, start: usize, end: usize) {
88        for line_cov in self
89            .file_cov
90            .line_coverages
91            .iter_mut()
92            .skip_while(|e| e.line_number < start)
93            .take_while(|e| e.line_number <= end)
94        {
95            line_cov.count = None;
96        }
97
98        for branch_cov in self
99            .file_cov
100            .branch_coverages
101            .iter_mut()
102            .skip_while(|e| e.line_number < start)
103            .take_while(|e| e.line_number <= end)
104        {
105            branch_cov.taken = None;
106        }
107    }
108}
109
110impl<'ast, 'a> Visit<'ast> for TestRuleInner<'a> {
111    fn visit_item_fn(&mut self, item: &'ast ItemFn) {
112        let start = match item.attrs.get(0) {
113            Some(attr) => attr.pound_token.spans[0].start().line,
114            None => return,
115        };
116        let end = item.block.brace_token.span.end().line;
117
118        for attr in item.attrs.iter() {
119            if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "test" {
120                self.ignore_range(start, end);
121                return;
122            }
123        }
124
125        syn::visit::visit_item_fn(self, item);
126    }
127
128    fn visit_item_mod(&mut self, item: &'ast ItemMod) {
129        let span = match item.content {
130            Some((ref brace, _)) => brace.span,
131            None => return,
132        };
133
134        for attr in item.attrs.iter() {
135            if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "cfg" {
136                for token in attr.tokens.clone() {
137                    if let TokenTree::Group(g) = token {
138                        let token = match g.stream().into_iter().next() {
139                            Some(t) => t,
140                            None => continue,
141                        };
142
143                        if let TokenTree::Ident(ident) = token {
144                            if ident == "test" {
145                                self.ignore_range(span.start().line, span.end().line);
146                                return;
147                            }
148                        }
149                    }
150                }
151            }
152        }
153
154        syn::visit::visit_item_mod(self, item);
155    }
156}
157
158pub struct LoopRule;
159
160impl LoopRule {
161    pub fn new() -> Self {
162        Self
163    }
164}
165
166impl Rule for LoopRule {
167    fn fix_file_coverage(&self, source: &SourceCode, file_cov: &mut FileCoverage) {
168        let mut inner = LoopRuleInner {
169            it: PerLineIterator::new(&source.content, file_cov),
170            current_line: 0,
171        };
172        inner.visit_file(&source.ast);
173    }
174}
175
176struct LoopRuleInner<'a, 'b> {
177    it: PerLineIterator<'a, 'b>,
178    current_line: usize,
179}
180
181impl<'ast, 'a, 'b> Visit<'ast> for LoopRuleInner<'a, 'b> {
182    fn visit_expr_for_loop(&mut self, expr: &'ast ExprForLoop) {
183        let line = expr.for_token.span.start().line;
184
185        if let Some(entry) = self.it.nth(line - self.current_line - 1) {
186            self.current_line = line;
187
188            let should_be_fixed = entry
189                .line_cov
190                .map_or(false, |v| v.count.map_or(false, |c| c > 0));
191
192            if should_be_fixed {
193                for branch_cov in entry.branch_covs {
194                    if branch_cov.taken == Some(false) {
195                        branch_cov.taken = None;
196                        break;
197                    }
198                }
199            }
200        }
201
202        syn::visit::visit_expr_for_loop(self, expr);
203    }
204}
205
206pub struct DeriveRule;
207
208impl DeriveRule {
209    pub fn new() -> Self {
210        Self
211    }
212}
213
214impl Rule for DeriveRule {
215    fn fix_file_coverage(&self, source: &SourceCode, file_cov: &mut FileCoverage) {
216        let mut inner = DeriveLoopInner { file_cov };
217        inner.visit_file(&source.ast);
218    }
219}
220
221struct DeriveLoopInner<'a> {
222    file_cov: &'a mut FileCoverage,
223}
224
225impl<'a> DeriveLoopInner<'a> {
226    fn ignore_range(&mut self, start: usize, end: usize) {
227        for line_cov in self
228            .file_cov
229            .line_coverages
230            .iter_mut()
231            .skip_while(|e| e.line_number < start)
232            .take_while(|e| e.line_number <= end)
233        {
234            line_cov.count = None;
235        }
236
237        for branch_cov in self
238            .file_cov
239            .branch_coverages
240            .iter_mut()
241            .skip_while(|e| e.line_number < start)
242            .take_while(|e| e.line_number <= end)
243        {
244            branch_cov.taken = None;
245        }
246    }
247}
248
249impl<'ast, 'a> Visit<'ast> for DeriveLoopInner<'a> {
250    fn visit_item_struct(&mut self, item: &'ast ItemStruct) {
251        let start = match item.attrs.get(0) {
252            Some(attr) => attr.pound_token.spans[0].start().line,
253            None => return,
254        };
255        let end = match item.fields {
256            Fields::Named(ref f) => f.brace_token.span.end().line,
257            Fields::Unnamed(ref f) => f.paren_token.span.end().line,
258            Fields::Unit => item.ident.span().end().line,
259        };
260
261        for attr in item.attrs.iter() {
262            if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "derive" {
263                self.ignore_range(start, end);
264                return;
265            }
266        }
267    }
268
269    fn visit_item_enum(&mut self, item: &'ast ItemEnum) {
270        let start = match item.attrs.get(0) {
271            Some(attr) => attr.pound_token.spans[0].start().line,
272            None => return,
273        };
274        let end = item.brace_token.span.end().line;
275
276        for attr in item.attrs.iter() {
277            if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "derive" {
278                self.ignore_range(start, end);
279                break;
280            }
281        }
282    }
283
284    fn visit_item_union(&mut self, item: &'ast ItemUnion) {
285        let start = match item.attrs.get(0) {
286            Some(attr) => attr.pound_token.spans[0].start().line,
287            None => return,
288        };
289        let end = item.fields.brace_token.span.end().line;
290
291        for attr in item.attrs.iter() {
292            if attr.path.segments.len() == 1 && attr.path.segments[0].ident == "derive" {
293                self.ignore_range(start, end);
294                break;
295            }
296        }
297    }
298}
299
300pub struct UnreachableRule;
301
302impl UnreachableRule {
303    pub fn new() -> Self {
304        Self
305    }
306}
307
308impl Rule for UnreachableRule {
309    fn fix_file_coverage(&self, source: &SourceCode, file_cov: &mut FileCoverage) {
310        let mut inner = UnreachableRuleInner { file_cov };
311        inner.visit_file(&source.ast);
312    }
313}
314
315struct UnreachableRuleInner<'a> {
316    file_cov: &'a mut FileCoverage,
317}
318
319impl<'a> UnreachableRuleInner<'a> {
320    fn ignore_range(&mut self, start: usize, end: usize) {
321        for line_cov in self
322            .file_cov
323            .line_coverages
324            .iter_mut()
325            .skip_while(|e| e.line_number < start)
326            .take_while(|e| e.line_number <= end)
327        {
328            line_cov.count = None;
329        }
330
331        for branch_cov in self
332            .file_cov
333            .branch_coverages
334            .iter_mut()
335            .skip_while(|e| e.line_number < start)
336            .take_while(|e| e.line_number <= end)
337        {
338            branch_cov.taken = None;
339        }
340    }
341}
342
343impl<'ast, 'a> Visit<'ast> for UnreachableRuleInner<'a> {
344    fn visit_expr_macro(&mut self, expr: &'ast ExprMacro) {
345        if let Some(ident) = expr.mac.path.get_ident() {
346            if ident == "unreachable" {
347                let start = ident.span().start().line;
348                let end = match expr.mac.delimiter {
349                    MacroDelimiter::Paren(ref p) => p.span.end().line,
350                    MacroDelimiter::Brace(ref b) => b.span.end().line,
351                    MacroDelimiter::Bracket(ref b) => b.span.end().line,
352                };
353                self.ignore_range(start, end);
354                return;
355            }
356        }
357
358        syn::visit::visit_expr_macro(self, expr);
359    }
360}
361
/// Applies the `cov:` comment markers found in the source text
/// (see `CommentMarker` for the recognised directives).
pub struct CommentRule;

impl CommentRule {
    /// Creates the rule.
    ///
    /// Public like the `new` of every other rule in this module; it was
    /// previously private, so external callers could not use it even though
    /// the type itself is public.
    pub fn new() -> Self {
        Self
    }
}
369
370impl Rule for CommentRule {
371    fn fix_file_coverage(&self, source: &SourceCode, file_cov: &mut FileCoverage) {
372        fn ignore_line(entry: &mut CoverageEntry) {
373            if let Some(&mut ref mut line_cov) = entry.line_cov {
374                line_cov.count = None;
375            }
376        }
377
378        fn ignore_branch(entry: &mut CoverageEntry) {
379            entry.branch_covs.iter_mut().for_each(|v| v.taken = None);
380        }
381
382        fn ignore_both(entry: &mut CoverageEntry) {
383            ignore_line(entry);
384            ignore_branch(entry);
385        }
386
387        let mut inside_ignore_line = false;
388        let mut inside_ignore_branch = false;
389        let mut inside_ignore_both = false;
390
391        for mut entry in PerLineIterator::new(&source.content, file_cov) {
392            use CommentMarker::*;
393
394            let marker = extract_marker(entry.line);
395
396            if inside_ignore_line {
397                ignore_line(&mut entry);
398
399                if marker == Some(EndIgnoreLine) {
400                    inside_ignore_line = false;
401                }
402
403                continue;
404            }
405
406            if inside_ignore_branch {
407                ignore_branch(&mut entry);
408
409                if marker == Some(EndIgnoreBranch) {
410                    inside_ignore_branch = false;
411                }
412
413                continue;
414            }
415
416            if inside_ignore_both {
417                ignore_both(&mut entry);
418
419                if marker == Some(EndIgnoreBoth) {
420                    inside_ignore_both = false;
421                }
422
423                continue;
424            }
425
426            match marker {
427                Some(IgnoreLine) => ignore_line(&mut entry),
428                Some(IgnoreBranch) => ignore_branch(&mut entry),
429                Some(IgnoreBoth) => ignore_both(&mut entry),
430                Some(BeginIgnoreLine) => {
431                    ignore_line(&mut entry);
432                    inside_ignore_line = true;
433                }
434                Some(BeginIgnoreBranch) => {
435                    ignore_branch(&mut entry);
436                    inside_ignore_branch = true;
437                }
438                Some(BeginIgnoreBoth) => {
439                    ignore_both(&mut entry);
440                    inside_ignore_both = true;
441                }
442                _ => {}
443            }
444        }
445    }
446}
447
448pub fn default_rules() -> Vec<Box<dyn Rule>> {
449    vec![
450        Box::new(CloseBlockRule::new()),
451        Box::new(TestRule::new()),
452        Box::new(LoopRule::new()),
453        Box::new(DeriveRule::new()),
454        Box::new(UnreachableRule::new()),
455        Box::new(CommentRule::new()),
456    ]
457}
458
459pub fn from_str(s: &str) -> Result<Box<dyn Rule>, Error> {
460    if s == "close" {
461        return Ok(Box::new(CloseBlockRule::new()));
462    }
463    if s == "test" {
464        return Ok(Box::new(TestRule::new()));
465    }
466    if s == "loop" {
467        return Ok(Box::new(LoopRule::new()));
468    }
469    if s == "derive" {
470        return Ok(Box::new(DeriveRule::new()));
471    }
472    if s == "unreachable" {
473        return Ok(Box::new(UnreachableRule::new()));
474    }
475    if s == "comment" {
476        return Ok(Box::new(CommentRule::new()));
477    }
478
479    Err(ErrorKind::InvalidRuleName(s.to_owned()).into())
480}
481
482// ---------- Utilities ----------
483
484struct CoverageEntry<'a, 'b> {
485    line: &'a str,
486    line_cov: Option<&'b mut LineCoverage>,
487    branch_covs: &'b mut [BranchCoverage],
488}
489
490/// A coverage iterator over the lines of a source files.
491#[derive(Clone)]
492struct PerLineIterator<'a, 'b> {
493    line_number: usize,
494    lines: Vec<&'a str>,
495    lp: *mut LineCoverage,
496    lp_end: *mut LineCoverage,
497    bp: *mut BranchCoverage,
498    bp_end: *mut BranchCoverage,
499    _borrow: PhantomData<&'b FileCoverage>,
500}
501
502impl<'a, 'b> PerLineIterator<'a, 'b> {
503    fn new(source: &'a str, file_cov: &'b mut FileCoverage) -> PerLineIterator<'a, 'b> {
504        let lp = file_cov.line_coverages.as_mut_ptr();
505        let bp = file_cov.branch_coverages.as_mut_ptr();
506        let lp_end = unsafe { lp.add(file_cov.line_coverages.len()) };
507        let bp_end = unsafe { bp.add(file_cov.branch_coverages.len()) };
508
509        Self {
510            line_number: 1,
511            lines: source.lines().collect(),
512            lp,
513            bp,
514            lp_end,
515            bp_end,
516            _borrow: PhantomData,
517        }
518    }
519}
520
521impl<'a, 'b> Iterator for PerLineIterator<'a, 'b> {
522    type Item = CoverageEntry<'a, 'b>;
523
524    fn next(&mut self) -> Option<Self::Item> {
525        if self.line_number > self.lines.len() {
526            return None;
527        }
528
529        unsafe {
530            let line = self.lines.get_unchecked_mut(self.line_number - 1);
531
532            // line coverage at current line
533            let line_cov = if self.lp < self.lp_end && (*self.lp).line_number == self.line_number {
534                let val = Some(&mut *self.lp);
535                self.lp = self.lp.add(1);
536                val
537            } else {
538                None
539            };
540
541            // branch coverages at current line
542            let branch_covs = if self.bp < self.bp_end && (*self.bp).line_number == self.line_number
543            {
544                let start = self.bp;
545                self.bp = self.bp.add(1);
546                let mut count = 1;
547                while self.bp < self.bp_end && (*self.bp).line_number == self.line_number {
548                    self.bp = self.bp.add(1);
549                    count += 1;
550                }
551                ::std::slice::from_raw_parts_mut(start, count)
552            } else {
553                &mut []
554            };
555
556            self.line_number += 1;
557
558            Some(CoverageEntry {
559                line,
560                line_cov,
561                branch_covs,
562            })
563        }
564    }
565}
566
/// A directive parsed from a `cov:` comment.
///
/// `Ignore*` markers apply to the line they are on; `Begin*` / `End*` markers
/// open and close a region. `Line` variants drop line counts, `Branch`
/// variants drop branch results, and the unsuffixed variants drop both.
///
/// Fieldless, so it is also `Copy` and `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CommentMarker {
    IgnoreLine,
    IgnoreBranch,
    IgnoreBoth,
    BeginIgnoreLine,
    BeginIgnoreBranch,
    BeginIgnoreBoth,
    EndIgnoreLine,
    EndIgnoreBranch,
    EndIgnoreBoth,
}
579
580fn extract_marker(line: &str) -> Option<CommentMarker> {
581    fn is_character(byte: u8) -> bool {
582        (0x41 <= byte && byte <= 0x5a)
583            || (0x61 <= byte && byte <= 0x7a)
584            || byte == b'_'
585            || byte == b'-'
586    }
587
588    let bytes = line.as_bytes();
589    let imax = bytes.len().saturating_sub(9);
590
591    for i in 0..imax {
592        unsafe {
593            if !bytes.get_unchecked(i..).starts_with(b"cov:") {
594                continue;
595            }
596
597            if i != 0 && !b" \t".contains(bytes.get_unchecked(i - 1)) {
598                continue;
599            }
600
601            let mut pos = i + 4;
602            while pos < bytes.len() && b" \t".contains(bytes.get_unchecked(pos)) {
603                pos += 1;
604            }
605
606            let mut end_pos = pos + 1;
607            while end_pos < bytes.len() && is_character(*bytes.get_unchecked(end_pos)) {
608                end_pos += 1;
609            }
610
611            let key = std::str::from_utf8_unchecked(bytes.get_unchecked(pos..end_pos));
612
613            return parse_marker(key);
614        }
615    }
616
617    None
618}
619
620fn parse_marker(key: &str) -> Option<CommentMarker> {
621    use CommentMarker::*;
622
623    let mut splits = key.split(|v| v == '-' || v == '_');
624    let mut segments = [""; 3];
625
626    segments[0] = splits.next().unwrap_or("");
627    segments[1] = splits.next().unwrap_or("");
628    segments[2] = splits.next().unwrap_or("");
629
630    match segments {
631        ["ignore", "line", ""] => Some(IgnoreLine),
632        ["ignore", "branch", ""] => Some(IgnoreBranch),
633        ["ignore", "", ""] => Some(IgnoreBoth),
634        ["begin", "ignore", "line"] => Some(BeginIgnoreLine),
635        ["begin", "ignore", "branch"] => Some(BeginIgnoreBranch),
636        ["begin", "ignore", ""] => Some(BeginIgnoreBoth),
637        ["end", "ignore", "line"] => Some(EndIgnoreLine),
638        ["end", "ignore", "branch"] => Some(EndIgnoreBranch),
639        ["end", "ignore", ""] => Some(EndIgnoreBoth),
640        _ => {
641            warnln!("Warning: Invalid marker detected: {:?}", key);
642            None
643        }
644    }
645}
646
647// cov:begin-ignore
648macro_rules! impl_default {
649    ($name:ident) => {
650        impl Default for $name {
651            fn default() -> Self {
652                Self::new()
653            }
654        }
655    };
656}
657
658impl_default!(CloseBlockRule);
659impl_default!(TestRule);
660impl_default!(LoopRule);
661impl_default!(DeriveRule);
662impl_default!(CommentRule);
663// cov:end-ignore
664
#[cfg(test)]
mod tests {
    #[test]
    fn extract_marker() {
        use super::CommentMarker::*;

        // (input line, expected marker)
        let cases = [
            ("ccov:ignore", None),
            ("cov:ignore-linee", None),
            ("cov:ignore--branch", None),
            ("cov:ignore", Some(IgnoreBoth)),
            ("cov:ignore-line-begin", None),
            ("cov:ignore-end", None),
            ("//\tcov:ignore-line", Some(IgnoreLine)),
            ("// cov:ignore_branch", Some(IgnoreBranch)),
            ("cov:ignore-branch", Some(IgnoreBranch)),
            ("cov:begin-ignore", Some(BeginIgnoreBoth)),
            ("cov: begin-ignore-line", Some(BeginIgnoreLine)),
            ("cov: \tbegin-ignore-branch", Some(BeginIgnoreBranch)),
            ("cov:\t end_ignore-line", Some(EndIgnoreLine)),
            ("cov:end-ignore_branch\t", Some(EndIgnoreBranch)),
            ("cov:end-ignore ", Some(EndIgnoreBoth)),
        ];

        for (line, expected) in cases.iter() {
            assert_eq!(super::extract_marker(line), *expected, "input: {:?}", line);
        }
    }

    #[test]
    fn from_str() {
        let known = ["close", "test", "loop", "derive", "unreachable", "comment"];
        for name in known.iter() {
            assert!(super::from_str(name).is_ok(), "rule {:?} should exist", name);
        }

        let unknown = ["", "derives", "forloop"];
        for name in unknown.iter() {
            assert!(super::from_str(name).is_err(), "rule {:?} should not exist", name);
        }
    }
}