Skip to main content

mir_extractor/
dataflow.rs

1use std::collections::HashSet;
2
3use crate::MirFunction;
4
5pub mod cfg;
6pub mod closure;
7pub mod field;
8pub mod path_sensitive;
9pub mod taint;
10
/// One way taint can flow through a function; a list of these is stored in
/// [`DataflowSummary::propagation`].
///
/// NOTE(review): nothing in this file constructs or matches on these
/// variants. The per-variant descriptions below are read off the names and
/// field types and should be confirmed against the code that builds
/// summaries (presumably the `taint` submodule).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TaintPropagation {
    /// Presumably: taint on the parameter at this position reaches the return value.
    ParamToReturn(usize),
    /// Presumably: taint on parameter `from` flows into parameter `to`.
    ParamToParam { from: usize, to: usize },
    /// Presumably: parameter `param` reaches a sink whose kind is named by `sink_type`.
    ParamToSink { param: usize, sink_type: String },
    /// Presumably: the parameter at this position is sanitized inside the function.
    ParamSanitized(usize),
}
18
/// Per-function taint summary.
///
/// NOTE(review): this struct is not built anywhere in this file; the field
/// descriptions below are inferred from names/types — confirm with the
/// producer of summaries.
#[derive(Debug, Clone)]
pub struct DataflowSummary {
    /// Name of the summarized function (presumably matches `MirFunction::name`).
    pub name: String,
    /// Parameter-level taint flows recorded for the function.
    pub propagation: Vec<TaintPropagation>,
    /// Presumably: whether the function's return value may carry taint.
    pub returns_tainted: bool,
}
25
/// A single `target <- sources` dataflow edge parsed from one MIR body line.
///
/// A line whose left-hand side mentions several locals (e.g.
/// `(_1, _2) = move _3;`) yields one `Assignment` per distinct local; those
/// entries share the same `sources`, `rhs` and `line`.
#[derive(Debug, Clone)]
pub struct Assignment {
    /// A MIR local (`_N`) appearing on the left-hand side, e.g. `"_1"`.
    pub target: String,
    /// MIR locals (`_N`) appearing on the right-hand side, in order of
    /// appearance; duplicates are kept.
    pub sources: Vec<String>,
    /// Right-hand side text, trimmed, with trailing `;` characters removed.
    pub rhs: String,
    /// The whole source line, trimmed.
    pub line: String,
}
33
34pub struct MirDataflow {
35    assignments: Vec<Assignment>,
36}
37
38impl MirDataflow {
39    pub fn new(function: &MirFunction) -> Self {
40        let assignments = function
41            .body
42            .iter()
43            .flat_map(|line| parse_assignment_line(line))
44            .collect();
45        Self { assignments }
46    }
47
48    pub fn assignments(&self) -> &[Assignment] {
49        &self.assignments
50    }
51
52    pub fn taint_from<F>(&self, mut predicate: F) -> HashSet<String>
53    where
54        F: FnMut(&Assignment) -> bool,
55    {
56        let mut tainted: HashSet<String> = HashSet::new();
57
58        for assignment in &self.assignments {
59            if predicate(assignment) {
60                tainted.insert(assignment.target.clone());
61            }
62        }
63
64        let mut changed = true;
65        while changed {
66            changed = false;
67            for assignment in &self.assignments {
68                if tainted.contains(&assignment.target) {
69                    continue;
70                }
71
72                if assignment
73                    .sources
74                    .iter()
75                    .any(|source| tainted.contains(source))
76                {
77                    tainted.insert(assignment.target.clone());
78                    changed = true;
79                }
80            }
81        }
82
83        tainted
84    }
85}
86
87fn parse_assignment_line(line: &str) -> Vec<Assignment> {
88    let trimmed = line.trim();
89    if trimmed.is_empty() {
90        return Vec::new();
91    }
92
93    let mut parts = trimmed.splitn(2, '=');
94    let lhs = match parts.next() {
95        Some(value) => value.trim(),
96        None => return Vec::new(),
97    };
98    let rhs = match parts.next() {
99        Some(value) => value.trim(),
100        None => return Vec::new(),
101    };
102
103    let rhs = rhs.trim_end_matches(';').trim();
104    if rhs.is_empty() {
105        return Vec::new();
106    }
107
108    let mut targets = extract_variables(lhs);
109    if targets.is_empty() {
110        return Vec::new();
111    }
112
113    targets.sort();
114    targets.dedup();
115
116    let sources = extract_variables(rhs);
117
118    targets
119        .into_iter()
120        .map(|target| Assignment {
121            target,
122            sources: sources.clone(),
123            rhs: rhs.to_string(),
124            line: trimmed.to_string(),
125        })
126        .collect()
127}
128
/// Returns every MIR local mentioned in `input`, in order of appearance,
/// duplicates included. A MIR local is `_` followed by one or more ASCII
/// digits, e.g. `_0` or `_12`.
///
/// A match only counts when it stands alone as a token:
/// - the `_` must not be preceded by an identifier character;
/// - the digit run must not be followed by an identifier character.
///
/// This stops literal suffixes (the `_000` in `1_000_usize`) and identifiers
/// (the `_2` in `bar_2`, the `_1` in `x_1`) from being mistaken for locals.
pub(crate) fn extract_variables(input: &str) -> Vec<String> {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let mut vars = Vec::new();
    let mut prev: Option<char> = None;
    let mut chars = input.char_indices().peekable();

    while let Some((start, ch)) = chars.next() {
        let starts_token = !matches!(prev, Some(p) if is_ident_char(p));
        prev = Some(ch);
        if ch != '_' || !starts_token {
            continue;
        }

        // `_` is one byte wide, so the digit run begins right after it.
        let mut end = start + 1;
        while let Some(&(idx, digit)) = chars.peek() {
            if !digit.is_ascii_digit() {
                break;
            }
            chars.next();
            // ASCII digits are one byte wide.
            end = idx + 1;
            prev = Some(digit);
        }

        let ends_token = !matches!(chars.peek(), Some(&(_, next)) if is_ident_char(next));
        if end > start + 1 && ends_token {
            vars.push(input[start..end].to_string());
        }
    }

    vars
}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `MirFunction` named `demo` whose body is `lines`,
    /// one body entry per slice element.
    fn make_function(lines: &[&str]) -> MirFunction {
        let body: Vec<String> = lines.iter().map(|&line| String::from(line)).collect();
        MirFunction {
            name: "demo".to_string(),
            signature: "fn demo()".to_string(),
            body,
            span: None,
            ..Default::default()
        }
    }

    #[test]
    fn parses_simple_assignment() {
        let dataflow = MirDataflow::new(&make_function(&["    _1 = copy _2;"]));
        let parsed = dataflow.assignments();
        assert_eq!(parsed.len(), 1);

        let first = &parsed[0];
        assert_eq!(first.target, "_1");
        assert_eq!(first.sources, vec!["_2".to_string()]);
    }

    #[test]
    fn taint_propagates_transitively() {
        let body = [
            "    _1 = std::http::HeaderMap::get(move _0);",
            "    _2 = copy _1;",
            "    _3 = Vec::<u8>::with_capacity(move _2);",
        ];
        let dataflow = MirDataflow::new(&make_function(&body));

        let tainted =
            dataflow.taint_from(|assignment| assignment.rhs.contains("HeaderMap::get"));
        for local in ["_1", "_2", "_3"] {
            assert!(tainted.contains(local), "expected {} to be tainted", local);
        }
    }

    #[test]
    fn skip_non_assignments() {
        let function = make_function(&[
            "    assert(!const false) -> [success: bb1, unwind: bb2];",
            "    _1 = Vec::<u8>::with_capacity(const 1024_usize);",
        ]);
        assert_eq!(MirDataflow::new(&function).assignments().len(), 1);
    }

    #[test]
    fn tuple_destructuring_creates_assignments_for_each_slot() {
        let body = ["    (_1, _2) = move _3;", "    _4 = copy _2;"];
        let dataflow = MirDataflow::new(&make_function(&body));
        assert_eq!(dataflow.assignments().len(), 3);

        let tainted = dataflow.taint_from(|assignment| assignment.rhs.contains("_3"));
        for local in ["_1", "_2", "_4"] {
            assert!(tainted.contains(local), "expected {} to be tainted", local);
        }
    }

    #[test]
    fn option_projections_propagate_through_fields() {
        let body = [
            "    _4 = reqwest::Response::content_length(move _1);",
            "    (_5.0: core::option::Option<usize>) = move _4;",
            "    _6 = move (_5.0: core::option::Option<usize>);",
            "    _7 = Vec::<u8>::with_capacity(move _6);",
        ];
        let dataflow = MirDataflow::new(&make_function(&body));

        let tainted =
            dataflow.taint_from(|assignment| assignment.rhs.contains("content_length"));
        for local in ["_4", "_5", "_6", "_7"] {
            assert!(tainted.contains(local), "expected {} to be tainted", local);
        }
    }
}