Skip to main content

mir_extractor/rules/
utils.rs

1//! Shared utilities for rule implementations.
2//!
3//! This module contains helper functions and types used across multiple rules,
4//! particularly for source code analysis that needs to handle string literals correctly.
5
6use walkdir::DirEntry;
7
8/// Filter function for WalkDir to skip common non-source directories.
9///
10/// Excludes: target, .git, .cola-cache, out, node_modules
11///
12/// # Example
13/// ```ignore
14/// for entry in WalkDir::new(crate_root)
15///     .into_iter()
16///     .filter_entry(|e| filter_entry(e))
17/// {
18///     // ...
19/// }
20/// ```
21pub fn filter_entry(entry: &DirEntry) -> bool {
22    if entry.depth() == 0 {
23        return true;
24    }
25
26    let name = entry.file_name().to_string_lossy();
27    if entry.file_type().is_dir()
28        && matches!(
29            name.as_ref(),
30            "target" | ".git" | ".cola-cache" | "out" | "node_modules"
31        )
32    {
33        return false;
34    }
35    true
36}
37
/// Scanner state carried from one line to the next by [`strip_string_literals`].
///
/// A string literal may span several source lines, so the caller feeds the
/// state returned for one line back in when processing the following line.
/// The default value means "not inside any literal".
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct StringLiteralState {
    /// Currently inside a regular `"..."` string
    pub in_normal_string: bool,
    /// Currently inside a raw string `r#"..."#` with this many hashes
    pub raw_hashes: Option<usize>,
}

/// Returns the index of the first unescaped `'` at or after `start`, if any.
fn closing_quote_index(bytes: &[u8], start: usize) -> Option<usize> {
    let mut escaped = false;
    for (offset, &byte) in bytes.get(start..)?.iter().enumerate() {
        if escaped {
            escaped = false;
        } else if byte == b'\\' {
            escaped = true;
        } else if byte == b'\'' {
            return Some(start + offset);
        }
    }
    None
}

/// Replaces string literal content with spaces while preserving line length.
///
/// Source-level rules use this to search for patterns without matching text
/// that merely appears inside a literal. Every byte that belongs to a literal
/// (delimiters included) is replaced by one ASCII space; every other byte is
/// copied unchanged, so the result has exactly the same byte length as `line`
/// and non-ASCII text outside literals is preserved verbatim. It handles:
/// - Regular strings: `"..."` (with backslash escapes)
/// - Raw strings: `r#"..."#` with any number of hashes
/// - Character literals: `'x'`
/// - Lifetimes: `'a` (preserved, not stripped)
///
/// # Arguments
/// * `state` - Current parsing state from previous line
/// * `line` - The source line to process
///
/// # Returns
/// A tuple of (sanitized line with string content replaced by spaces, new state)
///
/// # Example
/// ```ignore
/// let (sanitized, state) = strip_string_literals(StringLiteralState::default(), r#"let x = "hello world";"#);
/// assert!(sanitized.contains("let x ="));
/// assert!(!sanitized.contains("hello"));
/// ```
pub fn strip_string_literals(
    mut state: StringLiteralState,
    line: &str,
) -> (String, StringLiteralState) {
    let bytes = line.as_bytes();
    // The output has exactly one byte per input byte, so this capacity is exact.
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0usize;

    while i < bytes.len() {
        // Inside a raw string: blank everything until `"` followed by the
        // required number of `#`.
        if let Some(hashes) = state.raw_hashes {
            out.push(b' ');
            let tail = &bytes[i + 1..];
            let closes = bytes[i] == b'"'
                && tail.len() >= hashes
                && tail[..hashes].iter().all(|&b| b == b'#');
            if closes {
                out.resize(out.len() + hashes, b' ');
                state.raw_hashes = None;
                i += 1 + hashes;
            } else {
                i += 1;
            }
            continue;
        }

        // Inside a regular string: blank everything, honouring `\` escapes so
        // that an escaped quote does not terminate the literal.
        if state.in_normal_string {
            out.push(b' ');
            match bytes[i] {
                b'\\' => {
                    if i + 1 < bytes.len() {
                        out.push(b' ');
                        i += 2;
                    } else {
                        // Trailing backslash: the string continues on the next line.
                        break;
                    }
                }
                b'"' => {
                    state.in_normal_string = false;
                    i += 1;
                }
                _ => i += 1,
            }
            continue;
        }

        let ch = bytes[i];

        // Start of regular string
        if ch == b'"' {
            state.in_normal_string = true;
            out.push(b' ');
            i += 1;
            continue;
        }

        // Raw string start: `r`, zero or more `#`, then `"`.
        if ch == b'r' {
            let hashes = bytes[i + 1..].iter().take_while(|&&b| b == b'#').count();
            let quote_at = i + 1 + hashes;
            if bytes.get(quote_at) == Some(&b'"') {
                state.raw_hashes = Some(hashes);
                // Blank the `r`, the hashes, and the opening quote.
                out.resize(out.len() + hashes + 2, b' ');
                i = quote_at + 1;
                continue;
            }
        }

        // Character literal vs lifetime.
        if ch == b'\'' {
            // `'` + identifier start that is not immediately closed by another
            // `'` is treated as a lifetime (`'a`, `'static`, `'_`) and kept.
            if let Some(&next) = bytes.get(i + 1) {
                let looks_like_lifetime = next == b'_' || next.is_ascii_alphabetic();
                if looks_like_lifetime && bytes.get(i + 2) != Some(&b'\'') {
                    out.push(b'\'');
                    i += 1;
                    continue;
                }
            }

            match closing_quote_index(bytes, i + 1) {
                Some(close) => {
                    // Blank the whole literal, both quotes included.
                    out.resize(out.len() + (close - i + 1), b' ');
                    i = close + 1;
                }
                None => {
                    // No closing quote on this line: keep the quote as-is.
                    out.push(b'\'');
                    i += 1;
                }
            }
            continue;
        }

        // Ordinary byte: copy verbatim. Multi-byte UTF-8 sequences pass through
        // here one byte at a time and are reassembled unchanged in `out`.
        out.push(ch);
        i += 1;
    }

    // Literal states are only entered and left on ASCII delimiters, and every
    // other byte is either copied in order or replaced by a space, so `out` is
    // valid UTF-8. The lossy fallback only exists to avoid a panic path.
    let sanitized = String::from_utf8(out)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned());

    (sanitized, state)
}
214
/// Strip comments from a line of code.
///
/// Handles both single-line (`//`) and block (`/* */`) comments, tracking
/// multi-line block comment state across calls. Text outside comments is
/// copied byte-for-byte, so non-ASCII characters are preserved unchanged.
///
/// Limitations:
/// - Comment markers are recognised purely textually; a `//` or `/*` that
///   appears inside a string literal is treated as a comment start.
/// - The state is a single flag, so nested block comments are not tracked:
///   the first `*/` ends the comment.
///
/// # Arguments
/// * `line` - The source line to process
/// * `in_block_comment` - Mutable state tracking if we're inside a block comment
///
/// # Returns
/// The line with comment content removed
pub fn strip_comments(line: &str, in_block_comment: &mut bool) -> String {
    let bytes = line.as_bytes();
    let mut kept: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut idx = 0usize;

    while idx < bytes.len() {
        // Current byte together with a one-byte lookahead (None at end of line).
        let pair = (bytes[idx], bytes.get(idx + 1).copied());

        if *in_block_comment {
            if pair == (b'*', Some(b'/')) {
                *in_block_comment = false;
                idx += 2;
            } else {
                idx += 1;
            }
            continue;
        }

        match pair {
            // Line comment: the rest of the line is discarded.
            (b'/', Some(b'/')) => break,
            (b'/', Some(b'*')) => {
                *in_block_comment = true;
                idx += 2;
            }
            (byte, _) => {
                kept.push(byte);
                idx += 1;
            }
        }
    }

    // Comments start and end on ASCII markers and everything else is copied in
    // order, so `kept` is valid UTF-8. The lossy fallback only exists to avoid
    // a panic path.
    String::from_utf8(kept)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}
259
260/// Collect lines that match any of the given patterns after sanitizing string literals.
261///
262/// This is useful for rules that need to find code patterns while ignoring
263/// matches that occur inside string literals.
264#[allow(dead_code)]
265pub fn collect_sanitized_matches(lines: &[String], patterns: &[&str]) -> Vec<String> {
266    let mut state = StringLiteralState::default();
267
268    lines
269        .iter()
270        .filter_map(|line| {
271            let (sanitized, next_state) = strip_string_literals(state, line);
272            state = next_state;
273
274            if patterns.iter().any(|needle| sanitized.contains(needle)) {
275                Some(line.trim().to_string())
276            } else {
277                None
278            }
279        })
280        .collect()
281}
282
283/// Check if a function should be skipped for command injection rules.
284///
285/// This excludes mir-extractor's own build infrastructure functions that
286/// legitimately spawn external processes (rustc, cargo, etc.).
287pub fn command_rule_should_skip(
288    function: &crate::MirFunction,
289    package: &crate::MirPackage,
290) -> bool {
291    if package.crate_name == "mir-extractor" {
292        matches!(
293            function.name.as_str(),
294            "detect_rustc_version"
295                | "run_cargo_rustc"
296                | "discover_rustc_targets"
297                | "detect_crate_name"
298        )
299    } else {
300        false
301    }
302}
303
/// Substrings that identify logging/printing sinks in MIR text
/// (the desugared forms of print/log macros).
///
/// Shared by `CleartextLoggingRule` and `LogInjectionRule`.
pub const LOG_SINK_PATTERNS: &[&str] = &[
    // println!, print! desugaring
    "_print(",
    // eprintln!, eprint!
    "eprint",
    // format! and Debug/Display impl calls
    "::fmt(",
    // format_args! macro
    "Arguments::new",
    // panic! with formatting
    "panic_fmt",
    // older panic desugaring
    "begin_panic",
    // log crate macros
    "::log(",
    "::info(",
    "::warn(",
    "::error(",
    "::debug(",
    "::trace(",
];
320
/// Substrings that identify where untrusted data enters a program.
///
/// Shared by several injection rules.
pub const INPUT_SOURCE_PATTERNS: &[&str] = &[
    // env::var::<T> - generic call (MIR format)
    "= var::<",
    // env::var - standard call
    "= var(",
    // env::var_os
    "var_os(",
    // env::args
    "::args(",
    // env::args_os
    "args_os(",
    // iterator nth (often on args)
    "::nth(",
    // stdin
    "read_line(",
    // file/stdin reads
    "read_to_string(",
];
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Strips a single line starting from the default (outside any literal) state.
    fn strip_fresh(line: &str) -> (String, StringLiteralState) {
        strip_string_literals(StringLiteralState::default(), line)
    }

    /// Runs `strip_comments` from the given block-comment state and returns
    /// the stripped text together with the resulting state.
    fn strip_comments_from(line: &str, starts_in_block: bool) -> (String, bool) {
        let mut in_block = starts_in_block;
        let stripped = strip_comments(line, &mut in_block);
        (stripped, in_block)
    }

    /// A plain string literal is blanked while the surrounding code survives.
    #[test]
    fn test_strip_string_literals_basic() {
        let (out, _) = strip_fresh(r#"let x = "hello";"#);
        assert!(out.contains("let x ="));
        assert!(!out.contains("hello"));
        assert!(out.ends_with(';'));
    }

    /// Lifetimes must not be mistaken for character literals.
    #[test]
    fn test_strip_string_literals_preserves_lifetimes() {
        let (out, _) = strip_fresh("fn foo<'a>(x: &'a str) -> &'a str");
        assert!(out.contains("'a"));
        assert_eq!(out.matches("'a").count(), 3);
    }

    /// Raw strings with hashes are blanked as well.
    #[test]
    fn test_strip_string_literals_raw_string() {
        let (out, _) = strip_fresh(r##"let x = r#"raw string"#;"##);
        assert!(out.contains("let x ="));
        assert!(!out.contains("raw string"));
    }

    /// An unterminated string leaves the state "in string" for the next line.
    #[test]
    fn test_strip_string_literals_multiline_state() {
        let (_, after_first) = strip_fresh(r#"let x = "start"#);
        assert!(after_first.in_normal_string);

        let (out, after_second) = strip_string_literals(after_first, r#"end of string";"#);
        assert!(!after_second.in_normal_string);
        assert!(!out.contains("end of string"));
    }

    /// Character literals are blanked like strings.
    #[test]
    fn test_strip_char_literal() {
        let (out, _) = strip_fresh("let c = 'x';");
        assert!(out.contains("let c ="));
        assert!(!out.contains("'x'"));
    }

    /// Everything from `//` to the end of the line is dropped.
    #[test]
    fn test_strip_comments_line_comment() {
        let (stripped, in_block) = strip_comments_from("let x = 1; // comment", false);
        assert_eq!(stripped, "let x = 1; ");
        assert!(!in_block);
    }

    /// A block comment that opens and closes on the same line is cut out.
    #[test]
    fn test_strip_comments_block_comment() {
        let (stripped, in_block) = strip_comments_from("let x = /* inline */ 2;", false);
        assert_eq!(stripped, "let x =  2;");
        assert!(!in_block);
    }

    /// Block-comment state is carried from one line to the next.
    #[test]
    fn test_strip_comments_multiline_block() {
        let (first, in_block) = strip_comments_from("let x = /* start", false);
        assert_eq!(first, "let x = ");
        assert!(in_block);

        let (second, in_block) = strip_comments_from("still in comment */ done;", in_block);
        assert_eq!(second, " done;");
        assert!(!in_block);
    }
}