// mir_extractor/dataflow/closure.rs

1use crate::MirFunction;
2/// Closure analysis module for tracking taint through closures and higher-order functions
3///
4/// This module provides:
5/// - Detection of closure creation, invocation, and bodies in MIR
6/// - Tracking of captured variables and their taint states
7/// - Mapping between parent function variables and closure environment fields
8/// - Analysis of taint propagation through closure captures
9use std::collections::HashMap;
10
/// How a closure captures a variable from its parent function.
///
/// The mode is read from the prefix of the capture value in the MIR creation statement
/// (`move _6`, `&_3`, `&mut _4`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CaptureMode {
    /// Captured by value (`move`).
    ByValue,
    /// Captured by shared reference (`&`).
    ByRef,
    /// Captured by mutable reference (`&mut`).
    ByMutRef,
}
21
/// Taint state tracked for values flowing into closures.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TaintState {
    /// No taint has been observed for the value.
    Clean,
    /// The value is derived from an untrusted source.
    Tainted {
        /// Kind of source the taint came from (e.g. `"environment"` or `"propagated"`).
        source_type: String,
        /// Free-form description of where the taint was introduced.
        source_location: String,
    },
    /// The value passed through a sanitizer.
    ///
    /// NOTE(review): nothing in this module produces this state; presumably set elsewhere.
    Sanitized {
        /// Name of the sanitizer that was applied.
        sanitizer: String,
    },
}
34
35/// A variable captured by a closure
36#[derive(Debug, Clone)]
37pub struct CapturedVariable {
38    /// Field index in closure environment (.0, .1, .2, etc.)
39    pub field_index: usize,
40
41    /// Original variable name in parent function
42    pub parent_var: String,
43
44    /// How the variable is captured
45    pub capture_mode: CaptureMode,
46
47    /// Taint state of the captured variable
48    pub taint_state: TaintState,
49}
50
51/// Information about a closure definition
52#[derive(Debug, Clone)]
53pub struct ClosureInfo {
54    /// Closure name (e.g., "test_func::{closure#0}")
55    pub name: String,
56
57    /// Parent function name (e.g., "test_func")
58    pub parent_function: String,
59
60    /// Closure number (e.g., 0 for {closure#0})
61    pub closure_index: usize,
62
63    /// Variables captured by this closure
64    pub captured_vars: Vec<CapturedVariable>,
65
66    /// Location in source code
67    pub source_location: Option<String>,
68}
69
70impl ClosureInfo {
71    /// Create a new closure info
72    pub fn new(name: String, parent: String, index: usize) -> Self {
73        ClosureInfo {
74            name,
75            parent_function: parent,
76            closure_index: index,
77            captured_vars: Vec::new(),
78            source_location: None,
79        }
80    }
81
82    /// Check if this closure captures any tainted variables
83    pub fn has_tainted_captures(&self) -> bool {
84        self.captured_vars
85            .iter()
86            .any(|cap| matches!(cap.taint_state, TaintState::Tainted { .. }))
87    }
88}
89
90/// Registry for tracking closures across a codebase
91pub struct ClosureRegistry {
92    /// Maps closure names to their info
93    closures: HashMap<String, ClosureInfo>,
94
95    /// Maps parent function names to their closures
96    parent_to_closures: HashMap<String, Vec<String>>,
97
98    /// Maps closure creation sites to closure names
99    /// Key: (parent_function, closure_variable)
100    closure_bindings: HashMap<(String, String), String>,
101}
102
103impl ClosureRegistry {
104    /// Create a new empty closure registry
105    pub fn new() -> Self {
106        ClosureRegistry {
107            closures: HashMap::new(),
108            parent_to_closures: HashMap::new(),
109            closure_bindings: HashMap::new(),
110        }
111    }
112
113    /// Register a closure
114    pub fn register_closure(&mut self, info: ClosureInfo) {
115        let name = info.name.clone();
116        let parent = info.parent_function.clone();
117
118        // Add to closures map
119        self.closures.insert(name.clone(), info);
120
121        // Add to parent mapping
122        self.parent_to_closures
123            .entry(parent)
124            .or_insert_with(Vec::new)
125            .push(name);
126    }
127
128    /// Get closure info by name
129    pub fn get_closure(&self, name: &str) -> Option<&ClosureInfo> {
130        self.closures.get(name)
131    }
132
133    /// Get all closures for a parent function
134    pub fn get_closures_for_parent(&self, parent: &str) -> Vec<&ClosureInfo> {
135        if let Some(closure_names) = self.parent_to_closures.get(parent) {
136            closure_names
137                .iter()
138                .filter_map(|name| self.closures.get(name))
139                .collect()
140        } else {
141            Vec::new()
142        }
143    }
144
145    /// Bind a closure variable to a closure
146    pub fn bind_closure(&mut self, parent: String, var: String, closure_name: String) {
147        self.closure_bindings.insert((parent, var), closure_name);
148    }
149
150    /// Look up which closure a variable refers to
151    pub fn get_closure_binding(&self, parent: &str, var: &str) -> Option<&String> {
152        self.closure_bindings
153            .get(&(parent.to_string(), var.to_string()))
154    }
155
156    /// Get all parent function names that have closures
157    pub fn get_all_parents(&self) -> Vec<String> {
158        self.parent_to_closures.keys().cloned().collect()
159    }
160
161    /// Get all closures in the registry
162    pub fn get_all_closures(&self) -> Vec<&ClosureInfo> {
163        self.closures.values().collect()
164    }
165}
166
167impl Default for ClosureRegistry {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173/// Builder for constructing a closure registry from a MIR package
174pub struct ClosureRegistryBuilder {
175    registry: ClosureRegistry,
176}
177
178impl ClosureRegistryBuilder {
179    /// Create a new builder
180    pub fn new() -> Self {
181        ClosureRegistryBuilder {
182            registry: ClosureRegistry::new(),
183        }
184    }
185
186    /// Build closure registry from a slice of MIR functions
187    pub fn build(functions: &[MirFunction]) -> ClosureRegistry {
188        let mut builder = Self::new();
189
190        // First pass: identify all closures and their parents
191        for function in functions {
192            if let Some((parent, index)) = parse_closure_name(&function.name) {
193                let info = ClosureInfo::new(function.name.clone(), parent.clone(), index);
194                builder.registry.register_closure(info);
195            }
196        }
197
198        // Second pass: extract captures from parent functions
199        for function in functions {
200            builder.process_function(function);
201        }
202
203        // Third pass: analyze taint in parent functions and propagate to closures
204        for function in functions {
205            builder.analyze_taint_for_function(function);
206        }
207
208        builder.registry
209    }
210
211    /// Build closure registry from a MIR package
212    pub fn build_from_package(package: &crate::MirPackage) -> ClosureRegistry {
213        Self::build(&package.functions)
214    }
215
216    /// Analyze taint in a function and propagate to its closures
217    fn analyze_taint_for_function(&mut self, function: &MirFunction) {
218        // Build a simple taint map for this function
219        let mut taint_map: std::collections::HashMap<String, TaintState> =
220            std::collections::HashMap::new();
221        let mut var_aliases: std::collections::HashMap<String, String> =
222            std::collections::HashMap::new();
223
224        for line in &function.body {
225            let trimmed = line.trim();
226
227            // Parse assignments
228            if let Some(eq_pos) = trimmed.find(" = ") {
229                let lhs = trimmed[..eq_pos].trim();
230                let rhs = trimmed[eq_pos + 3..].trim().trim_end_matches(';');
231
232                // Check if RHS is a source (env::args, args(), etc.)
233                if rhs.contains("args()") || rhs.contains("env::args") || rhs.contains("env::var") {
234                    taint_map.insert(
235                        lhs.to_string(),
236                        TaintState::Tainted {
237                            source_type: "environment".to_string(),
238                            source_location: rhs.to_string(),
239                        },
240                    );
241                }
242                // Check if RHS is a function call - propagate taint from arguments
243                else if rhs.contains("(") && rhs.contains("move ") {
244                    // Extract variables from "move _X" patterns in the RHS
245                    let mut tainted_in_args = false;
246                    for word in rhs.split_whitespace() {
247                        if word.starts_with('_') {
248                            let var = word.trim_end_matches(|c: char| !c.is_numeric() && c != '_');
249                            if taint_map.contains_key(var) {
250                                tainted_in_args = true;
251                                break;
252                            }
253                        }
254                    }
255
256                    if tainted_in_args {
257                        // Propagate taint to LHS
258                        taint_map.insert(
259                            lhs.to_string(),
260                            TaintState::Tainted {
261                                source_type: "propagated".to_string(),
262                                source_location: "function_call".to_string(),
263                            },
264                        );
265                    }
266                }
267                // Check if RHS is a reference or copy
268                else if rhs.starts_with("&")
269                    || rhs.starts_with("copy ")
270                    || rhs.starts_with("move ")
271                {
272                    // Extract the source variable
273                    let source_var = if rhs.starts_with("&mut ") {
274                        rhs[5..].trim()
275                    } else if rhs.starts_with("&") {
276                        rhs[1..].trim()
277                    } else if rhs.starts_with("copy ") {
278                        rhs[5..].trim()
279                    } else if rhs.starts_with("move ") {
280                        rhs[5..].trim()
281                    } else {
282                        rhs
283                    };
284
285                    // Extract just the variable name (e.g., "_1" from "_1;")
286                    let source_var = source_var
287                        .split(|c: char| !c.is_numeric() && c != '_')
288                        .next()
289                        .unwrap_or(source_var);
290
291                    // Create alias mapping
292                    var_aliases.insert(lhs.to_string(), source_var.to_string());
293
294                    // Propagate taint
295                    if let Some(taint) = taint_map.get(source_var) {
296                        taint_map.insert(lhs.to_string(), taint.clone());
297                    }
298                }
299            }
300        }
301
302        // Propagate taint through aliases transitively
303        let mut changed = true;
304        while changed {
305            changed = false;
306            for (var, alias) in &var_aliases {
307                if taint_map.contains_key(var) {
308                    continue;
309                }
310                if let Some(taint) = taint_map.get(alias) {
311                    taint_map.insert(var.clone(), taint.clone());
312                    changed = true;
313                }
314            }
315        }
316
317        // Update closures with taint information
318        let closures_for_this_function = self.registry.get_closures_for_parent(&function.name);
319        let closure_names: Vec<String> = closures_for_this_function
320            .iter()
321            .map(|c| c.name.clone())
322            .collect();
323
324        for closure_name in closure_names {
325            if let Some(info) = self.registry.closures.get_mut(&closure_name) {
326                for capture in &mut info.captured_vars {
327                    // Resolve the parent var through aliases if needed
328                    let mut resolved_var = capture.parent_var.clone();
329
330                    // Follow alias chain
331                    for _ in 0..10 {
332                        // Limit iterations to prevent infinite loop
333                        if let Some(alias) = var_aliases.get(&resolved_var) {
334                            resolved_var = alias.clone();
335                        } else {
336                            break;
337                        }
338                    }
339
340                    // Check if the resolved variable is tainted
341                    if let Some(taint) = taint_map.get(&resolved_var) {
342                        capture.taint_state = taint.clone();
343                    }
344                }
345            }
346        }
347    }
348
349    /// Process a single function to find closure creations
350    fn process_function(&mut self, function: &MirFunction) {
351        if function.name.contains("execute_async") {
352            // eprintln!("[DEBUG] Processing function: {}", function.name);
353        }
354        for line in &function.body {
355            if function.name.contains("execute_async") {
356                // eprintln!("[DEBUG] Line: {}", line);
357            }
358            // Look for closure creation
359            if let Some((closure_var, location, captures)) = parse_closure_creation(line) {
360                if function.name.contains("execute_async") {
361                    // eprintln!("[DEBUG] Found closure creation: var={}, loc={}, captures={:?}", closure_var, location, captures);
362                }
363                // Try to find which closure this refers to based on parent function
364                // The location string contains file:line:col, which we can use to match
365                // For now, we'll use a simpler approach: look for closures with this parent
366                let closure_name = self.find_closure_for_parent(&function.name, &location);
367
368                if let Some(closure_name) = closure_name {
369                    // Bind this variable to the closure
370                    self.registry.bind_closure(
371                        function.name.clone(),
372                        closure_var.clone(),
373                        closure_name.clone(),
374                    );
375
376                    // Process captures
377                    if let Some(info) = self.registry.closures.get_mut(&closure_name) {
378                        // Add source location
379                        info.source_location = Some(location.clone());
380
381                        // Process each captured variable
382                        for (field_index, (_capture_name, capture_value)) in
383                            captures.iter().enumerate()
384                        {
385                            // Determine capture mode
386                            let capture_mode = if capture_value.starts_with("move ") {
387                                CaptureMode::ByValue
388                            } else if capture_value.starts_with("&mut ") {
389                                CaptureMode::ByMutRef
390                            } else if capture_value.starts_with('&') {
391                                CaptureMode::ByRef
392                            } else {
393                                CaptureMode::ByValue // Default
394                            };
395
396                            // Extract the actual variable from capture_value
397                            let parent_var = Self::extract_var_from_capture(capture_value);
398
399                            // Create captured variable (taint state will be filled in later)
400                            let captured = CapturedVariable {
401                                field_index,
402                                parent_var: parent_var.clone(),
403                                capture_mode,
404                                taint_state: TaintState::Clean, // Default, will be updated
405                            };
406
407                            info.captured_vars.push(captured);
408                        }
409                    }
410                }
411            }
412        }
413    }
414
415    /// Find the closure name for a given parent function and location
416    /// When a parent function has multiple closures, we match by location or order
417    fn find_closure_for_parent(&self, parent: &str, _location: &str) -> Option<String> {
418        // Get all closures for this parent
419        let closures_for_parent: Vec<_> = self
420            .registry
421            .closures
422            .values()
423            .filter(|info| info.parent_function == parent)
424            .collect();
425
426        if closures_for_parent.is_empty() {
427            return None;
428        }
429
430        // If there's only one closure for this parent, return it
431        if closures_for_parent.len() == 1 {
432            return Some(closures_for_parent[0].name.clone());
433        }
434
435        // For multiple closures, we'd need to match by location
436        // For now, find the first one that doesn't have a source_location set yet
437        for info in &closures_for_parent {
438            if info.source_location.is_none() {
439                return Some(info.name.clone());
440            }
441        }
442
443        // If all have locations, we need to parse and match - for now return first
444        Some(closures_for_parent[0].name.clone())
445    }
446
447    /// Find closure by its source location (old implementation, keeping for reference)
448    #[allow(dead_code)]
449    fn find_closure_by_location(&self, location: &str) -> Option<String> {
450        // Location format: {closure@examples/interprocedural/src/lib.rs:278:19: 278:21}
451        // We need to match this against registered closures
452        // For now, we'll use a simple heuristic: extract parent from current analysis context
453        // and match by index if the location matches
454
455        // This is a simplified approach - in production, we'd parse the location more carefully
456        for (name, info) in &self.registry.closures {
457            if let Some(ref loc) = info.source_location {
458                if loc == location {
459                    return Some(name.clone());
460                }
461            }
462            // Also try to match if we haven't set source_location yet
463            // Extract numbers from location
464            if let Some(_start) = location.rfind(':') {
465                if let Some(_line_start) = location[.._start].rfind(':') {
466                    // This is a new closure, try to match by parent function name
467                    // which should be in the current context
468                }
469            }
470        }
471
472        // If no exact match, try to infer from the closures we know about
473        // For a more robust implementation, we could extract line numbers and match
474        None
475    }
476
477    /// Extract variable name from capture value
478    /// "move _6" -> "_6"
479    /// "&_3" -> "_3"
480    /// "&mut _4" -> "_4"
481    fn extract_var_from_capture(capture_value: &str) -> String {
482        let trimmed = capture_value.trim();
483
484        if trimmed.starts_with("move ") {
485            trimmed[5..].trim().to_string()
486        } else if trimmed.starts_with("&mut ") {
487            trimmed[5..].trim().to_string()
488        } else if trimmed.starts_with('&') {
489            trimmed[1..].trim().to_string()
490        } else {
491            // Handle "copy _X" or just "_X"
492            trimmed
493                .split_whitespace()
494                .last()
495                .unwrap_or(trimmed)
496                .to_string()
497        }
498    }
499}
500
501impl Default for ClosureRegistryBuilder {
502    fn default() -> Self {
503        Self::new()
504    }
505}
506
/// Returns `true` when `name` contains a `::{closure#` path segment, i.e. it names a
/// closure body rather than an ordinary function.
pub fn is_closure_function(name: &str) -> bool {
    const CLOSURE_SEGMENT: &str = "::{closure#";
    name.find(CLOSURE_SEGMENT).is_some()
}
511
/// Parse closure name to extract parent and index
///
/// The *last* `::{closure#N}` segment is used. A nested closure therefore resolves to its
/// immediate parent (the enclosing closure), not to the outermost function.
///
/// Returns `None` when `name` has no `::{closure#` segment or `N` is not a valid `usize`.
///
/// # Examples
/// ```
/// use mir_extractor::dataflow::closure::parse_closure_name;
/// let (parent, index) = parse_closure_name("test_func::{closure#0}").unwrap();
/// assert_eq!(parent, "test_func");
/// assert_eq!(index, 0);
///
/// let (parent, index) = parse_closure_name("test_func::{closure#0}::{closure#2}").unwrap();
/// assert_eq!(parent, "test_func::{closure#0}");
/// assert_eq!(index, 2);
/// ```
pub fn parse_closure_name(name: &str) -> Option<(String, usize)> {
    const MARKER: &str = "::{closure#";

    let pos = name.rfind(MARKER)?;
    let rest = &name[pos + MARKER.len()..];

    // The index runs up to the closing '}' of this segment.
    let end = rest.find('}')?;
    let index = rest[..end].parse::<usize>().ok()?;

    Some((name[..pos].to_string(), index))
}
535
/// Extract closure creation from MIR statement
///
/// Looks for patterns like:
/// `_5 = {closure@examples/interprocedural/src/lib.rs:278:19: 278:21} { tainted: move _6 };`
///
/// Also accepts `{coroutine@...}` aggregates and zero-capture closures printed as a constant
/// (`_5 = const ZeroSized: {closure@...};`).
///
/// The closure/coroutine aggregate must be the whole right-hand side. Statements that only
/// *mention* a closure type are rejected, e.g. `<{closure@..} as Fn<()>>::call(..)` or a
/// call with `{closure@..}` in its turbofish.
///
/// Returns `(lhs, location, captures)`:
/// - `location` is the `{closure@...}` text;
/// - `captures` holds `(field_name, value)` pairs in order, e.g. `("tainted", "move _6")`.
pub fn parse_closure_creation(statement: &str) -> Option<(String, String, Vec<(String, String)>)> {
    const MARKERS: [&str; 2] = ["{closure@", "{coroutine@"];

    let (lhs, rhs) = statement.split_once(" = ")?;
    let rhs = rhs.trim_start();

    // Zero-sized closures may be printed as a constant of the closure type.
    let head = match rhs.strip_prefix("const ") {
        Some(rest) => rest.strip_prefix("ZeroSized: ").unwrap_or(rest),
        None => rhs,
    };
    if !MARKERS.iter().any(|marker| head.starts_with(marker)) {
        return None;
    }

    // The location contains no '}' of its own, so the first '}' closes it.
    let location_len = head.find('}')? + 1;
    let location = &head[..location_len];
    let captures = parse_capture_list(&head[location_len..]);

    Some((lhs.trim().to_string(), location.to_string(), captures))
}

/// Parse an optional `{ name: value, ... }` capture list that follows a closure location.
///
/// Commas nested inside `()`, `[]`, `{}` or `<>` do not split captures, so values like
/// `copy (_1.0: HashMap<K, V>)` stay intact. Returns an empty list when there is no
/// list or it is never closed.
fn parse_capture_list(text: &str) -> Vec<(String, String)> {
    let body = match text.trim_start().strip_prefix('{') {
        Some(body) => body,
        None => return Vec::new(),
    };

    let mut pieces = Vec::new();
    let mut depth = 0usize;
    let mut piece_start = 0;
    let mut closed = false;
    let mut prev = ' ';
    for (i, ch) in body.char_indices() {
        match ch {
            '(' | '[' | '{' | '<' => depth += 1,
            // `->` in a function-pointer type is not a closing angle bracket.
            '>' if prev == '-' => {}
            ')' | ']' | '>' => depth = depth.saturating_sub(1),
            '}' if depth == 0 => {
                pieces.push(&body[piece_start..i]);
                closed = true;
                break;
            }
            '}' => depth -= 1,
            ',' if depth == 0 => {
                pieces.push(&body[piece_start..i]);
                piece_start = i + 1;
            }
            _ => {}
        }
        prev = ch;
    }

    if !closed {
        return Vec::new();
    }

    pieces
        .into_iter()
        .filter_map(|piece| piece.split_once(": "))
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
        .collect()
}
593
/// Detect closure invocation in MIR
///
/// Returns `true` when the statement contains a matching trait/method pair: `Fn`/`call`,
/// `FnMut`/`call_mut` or `FnOnce`/`call_once`.
///
/// Looks for patterns like:
/// `_7 = <{closure@...} as Fn<()>>::call(move _8, const ());`
pub fn is_closure_call(statement: &str) -> bool {
    const CALL_SHAPES: [(&str, &str); 3] = [
        (" as Fn<", ">::call("),
        (" as FnMut<", ">::call_mut("),
        (" as FnOnce<", ">::call_once("),
    ];

    CALL_SHAPES
        .iter()
        .any(|&(trait_marker, method)| statement.contains(trait_marker) && statement.contains(method))
}
603
604/// Extract closure variable from invocation
605pub fn parse_closure_call(statement: &str) -> Option<(String, String)> {
606    if !is_closure_call(statement) {
607        return None;
608    }
609
610    // Extract result variable
611    let result_var = if let Some(eq_pos) = statement.find(" = ") {
612        statement[..eq_pos].trim().to_string()
613    } else {
614        return None;
615    };
616
617    // Extract closure variable from call(move _X, ...)
618    if let Some(call_start) = statement
619        .find("::call(")
620        .or_else(|| statement.find("::call_mut("))
621        .or_else(|| statement.find("::call_once("))
622    {
623        if let Some(paren_end) = statement[call_start..].find(')') {
624            let args = &statement[call_start + 7..call_start + paren_end];
625
626            // First argument is the closure
627            if let Some(comma_pos) = args.find(',') {
628                let closure_arg = args[..comma_pos].trim();
629                // Remove "move " prefix if present
630                let closure_var = if closure_arg.starts_with("move ") {
631                    closure_arg[5..].trim().to_string()
632                } else {
633                    closure_arg.to_string()
634                };
635                return Some((result_var, closure_var));
636            }
637        }
638    }
639
640    None
641}
642
/// Detect environment field access in closure body
///
/// Matches a dereferenced field projection of the form `((*_X).N: <type>)` and returns
/// the assigned local together with the field number `N`.
///
/// Example: `_7 = deref_copy ((*_1).0: &std::string::String);` -> `("_7", 0)`.
pub fn parse_env_field_access(statement: &str) -> Option<(String, usize)> {
    if !statement.contains("(*_") {
        return None;
    }

    let (lhs, _) = statement.split_once(" = ")?;

    // Locate `((*_X)` and the `).` that ends the dereference.
    let deref_at = statement.find("((*_")?;
    let projection = &statement[deref_at..];
    let close_paren = projection.find(").")?;

    let digits: String = projection[close_paren + 2..]
        .chars()
        .take_while(|ch| ch.is_numeric())
        .collect();
    let field_index = digits.parse::<usize>().ok()?;

    Some((lhs.trim().to_string(), field_index))
}
680
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_closure_function() {
        for name in ["test_func::{closure#0}", "module::test::{closure#5}"].iter() {
            assert!(is_closure_function(name));
        }
        assert!(!is_closure_function("regular_function"));
    }

    #[test]
    fn test_parse_closure_name() {
        assert_eq!(
            parse_closure_name("test_func::{closure#0}"),
            Some((String::from("test_func"), 0))
        );
        assert_eq!(
            parse_closure_name("module::nested::func::{closure#3}"),
            Some((String::from("module::nested::func"), 3))
        );
        assert!(parse_closure_name("not_a_closure").is_none());
    }

    #[test]
    fn test_parse_closure_creation() {
        let statement = "_5 = {closure@examples/interprocedural/src/lib.rs:278:19: 278:21} { tainted: move _6 };";
        let (lhs, location, captures) = parse_closure_creation(statement).unwrap();

        assert_eq!(lhs, "_5");
        assert!(location.starts_with("{closure@"));
        assert_eq!(captures.len(), 1);
        let (capture_name, capture_value) = &captures[0];
        assert_eq!(capture_name, "tainted");
        assert_eq!(capture_value, "move _6");
    }

    #[test]
    fn test_is_closure_call() {
        let closure_calls = [
            "<{closure@...} as Fn<()>>::call(move _8, const ())",
            "<{closure@...} as FnMut<()>>::call_mut(move _8, const ())",
            "<{closure@...} as FnOnce<()>>::call_once(move _8, const ())",
        ];
        for statement in closure_calls.iter() {
            assert!(is_closure_call(statement));
        }
        assert!(!is_closure_call("regular_function_call()"));
    }

    #[test]
    fn test_parse_closure_call() {
        let statement = "_7 = <{closure@...} as Fn<()>>::call(move _8, const ()) -> [return: bb5, unwind: bb7];";
        let parsed = parse_closure_call(statement).unwrap();

        assert_eq!(parsed.0, "_7");
        assert_eq!(parsed.1, "_8");
    }

    #[test]
    fn test_parse_env_field_access() {
        let statement = "_7 = deref_copy ((*_1).0: &std::string::String);";
        let parsed = parse_env_field_access(statement).unwrap();

        assert_eq!(parsed.0, "_7");
        assert_eq!(parsed.1, 0);
    }
}