1use crate::MirFunction;
2use std::collections::HashMap;
10
/// How a closure captures a variable from its parent function.
///
/// The builder derives this from the prefix of the capture operand in the
/// closure-creation aggregate: `&mut …` → [`CaptureMode::ByMutRef`],
/// any other `&…` → [`CaptureMode::ByRef`], everything else
/// (`move …`, `copy …`, …) → [`CaptureMode::ByValue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CaptureMode {
    /// The value is moved or copied into the closure.
    ByValue,
    /// The closure holds a shared reference (`&`).
    ByRef,
    /// The closure holds a mutable reference (`&mut`).
    ByMutRef,
}
21
/// Taint status attached to a captured variable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TaintState {
    /// No taint source was found for the variable.
    Clean,
    /// The variable (transitively) derives from a taint source.
    Tainted {
        /// Kind of source, e.g. `"environment"` for a direct source or
        /// `"propagated"` when taint flowed through a call.
        source_type: String,
        /// Textual description of where the taint came from (for direct
        /// sources, the right-hand side of the MIR assignment).
        source_location: String,
    },
    /// The value passed through a sanitizer. Not produced anywhere in this file.
    Sanitized {
        /// Name of the sanitizer that was applied.
        sanitizer: String,
    },
}
34
35#[derive(Debug, Clone)]
37pub struct CapturedVariable {
38 pub field_index: usize,
40
41 pub parent_var: String,
43
44 pub capture_mode: CaptureMode,
46
47 pub taint_state: TaintState,
49}
50
51#[derive(Debug, Clone)]
53pub struct ClosureInfo {
54 pub name: String,
56
57 pub parent_function: String,
59
60 pub closure_index: usize,
62
63 pub captured_vars: Vec<CapturedVariable>,
65
66 pub source_location: Option<String>,
68}
69
70impl ClosureInfo {
71 pub fn new(name: String, parent: String, index: usize) -> Self {
73 ClosureInfo {
74 name,
75 parent_function: parent,
76 closure_index: index,
77 captured_vars: Vec::new(),
78 source_location: None,
79 }
80 }
81
82 pub fn has_tainted_captures(&self) -> bool {
84 self.captured_vars
85 .iter()
86 .any(|cap| matches!(cap.taint_state, TaintState::Tainted { .. }))
87 }
88}
89
90pub struct ClosureRegistry {
92 closures: HashMap<String, ClosureInfo>,
94
95 parent_to_closures: HashMap<String, Vec<String>>,
97
98 closure_bindings: HashMap<(String, String), String>,
101}
102
103impl ClosureRegistry {
104 pub fn new() -> Self {
106 ClosureRegistry {
107 closures: HashMap::new(),
108 parent_to_closures: HashMap::new(),
109 closure_bindings: HashMap::new(),
110 }
111 }
112
113 pub fn register_closure(&mut self, info: ClosureInfo) {
115 let name = info.name.clone();
116 let parent = info.parent_function.clone();
117
118 self.closures.insert(name.clone(), info);
120
121 self.parent_to_closures
123 .entry(parent)
124 .or_insert_with(Vec::new)
125 .push(name);
126 }
127
128 pub fn get_closure(&self, name: &str) -> Option<&ClosureInfo> {
130 self.closures.get(name)
131 }
132
133 pub fn get_closures_for_parent(&self, parent: &str) -> Vec<&ClosureInfo> {
135 if let Some(closure_names) = self.parent_to_closures.get(parent) {
136 closure_names
137 .iter()
138 .filter_map(|name| self.closures.get(name))
139 .collect()
140 } else {
141 Vec::new()
142 }
143 }
144
145 pub fn bind_closure(&mut self, parent: String, var: String, closure_name: String) {
147 self.closure_bindings.insert((parent, var), closure_name);
148 }
149
150 pub fn get_closure_binding(&self, parent: &str, var: &str) -> Option<&String> {
152 self.closure_bindings
153 .get(&(parent.to_string(), var.to_string()))
154 }
155
156 pub fn get_all_parents(&self) -> Vec<String> {
158 self.parent_to_closures.keys().cloned().collect()
159 }
160
161 pub fn get_all_closures(&self) -> Vec<&ClosureInfo> {
163 self.closures.values().collect()
164 }
165}
166
167impl Default for ClosureRegistry {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173pub struct ClosureRegistryBuilder {
175 registry: ClosureRegistry,
176}
177
178impl ClosureRegistryBuilder {
179 pub fn new() -> Self {
181 ClosureRegistryBuilder {
182 registry: ClosureRegistry::new(),
183 }
184 }
185
186 pub fn build(functions: &[MirFunction]) -> ClosureRegistry {
188 let mut builder = Self::new();
189
190 for function in functions {
192 if let Some((parent, index)) = parse_closure_name(&function.name) {
193 let info = ClosureInfo::new(function.name.clone(), parent.clone(), index);
194 builder.registry.register_closure(info);
195 }
196 }
197
198 for function in functions {
200 builder.process_function(function);
201 }
202
203 for function in functions {
205 builder.analyze_taint_for_function(function);
206 }
207
208 builder.registry
209 }
210
211 pub fn build_from_package(package: &crate::MirPackage) -> ClosureRegistry {
213 Self::build(&package.functions)
214 }
215
216 fn analyze_taint_for_function(&mut self, function: &MirFunction) {
218 let mut taint_map: std::collections::HashMap<String, TaintState> =
220 std::collections::HashMap::new();
221 let mut var_aliases: std::collections::HashMap<String, String> =
222 std::collections::HashMap::new();
223
224 for line in &function.body {
225 let trimmed = line.trim();
226
227 if let Some(eq_pos) = trimmed.find(" = ") {
229 let lhs = trimmed[..eq_pos].trim();
230 let rhs = trimmed[eq_pos + 3..].trim().trim_end_matches(';');
231
232 if rhs.contains("args()") || rhs.contains("env::args") || rhs.contains("env::var") {
234 taint_map.insert(
235 lhs.to_string(),
236 TaintState::Tainted {
237 source_type: "environment".to_string(),
238 source_location: rhs.to_string(),
239 },
240 );
241 }
242 else if rhs.contains("(") && rhs.contains("move ") {
244 let mut tainted_in_args = false;
246 for word in rhs.split_whitespace() {
247 if word.starts_with('_') {
248 let var = word.trim_end_matches(|c: char| !c.is_numeric() && c != '_');
249 if taint_map.contains_key(var) {
250 tainted_in_args = true;
251 break;
252 }
253 }
254 }
255
256 if tainted_in_args {
257 taint_map.insert(
259 lhs.to_string(),
260 TaintState::Tainted {
261 source_type: "propagated".to_string(),
262 source_location: "function_call".to_string(),
263 },
264 );
265 }
266 }
267 else if rhs.starts_with("&")
269 || rhs.starts_with("copy ")
270 || rhs.starts_with("move ")
271 {
272 let source_var = if rhs.starts_with("&mut ") {
274 rhs[5..].trim()
275 } else if rhs.starts_with("&") {
276 rhs[1..].trim()
277 } else if rhs.starts_with("copy ") {
278 rhs[5..].trim()
279 } else if rhs.starts_with("move ") {
280 rhs[5..].trim()
281 } else {
282 rhs
283 };
284
285 let source_var = source_var
287 .split(|c: char| !c.is_numeric() && c != '_')
288 .next()
289 .unwrap_or(source_var);
290
291 var_aliases.insert(lhs.to_string(), source_var.to_string());
293
294 if let Some(taint) = taint_map.get(source_var) {
296 taint_map.insert(lhs.to_string(), taint.clone());
297 }
298 }
299 }
300 }
301
302 let mut changed = true;
304 while changed {
305 changed = false;
306 for (var, alias) in &var_aliases {
307 if taint_map.contains_key(var) {
308 continue;
309 }
310 if let Some(taint) = taint_map.get(alias) {
311 taint_map.insert(var.clone(), taint.clone());
312 changed = true;
313 }
314 }
315 }
316
317 let closures_for_this_function = self.registry.get_closures_for_parent(&function.name);
319 let closure_names: Vec<String> = closures_for_this_function
320 .iter()
321 .map(|c| c.name.clone())
322 .collect();
323
324 for closure_name in closure_names {
325 if let Some(info) = self.registry.closures.get_mut(&closure_name) {
326 for capture in &mut info.captured_vars {
327 let mut resolved_var = capture.parent_var.clone();
329
330 for _ in 0..10 {
332 if let Some(alias) = var_aliases.get(&resolved_var) {
334 resolved_var = alias.clone();
335 } else {
336 break;
337 }
338 }
339
340 if let Some(taint) = taint_map.get(&resolved_var) {
342 capture.taint_state = taint.clone();
343 }
344 }
345 }
346 }
347 }
348
349 fn process_function(&mut self, function: &MirFunction) {
351 if function.name.contains("execute_async") {
352 }
354 for line in &function.body {
355 if function.name.contains("execute_async") {
356 }
358 if let Some((closure_var, location, captures)) = parse_closure_creation(line) {
360 if function.name.contains("execute_async") {
361 }
363 let closure_name = self.find_closure_for_parent(&function.name, &location);
367
368 if let Some(closure_name) = closure_name {
369 self.registry.bind_closure(
371 function.name.clone(),
372 closure_var.clone(),
373 closure_name.clone(),
374 );
375
376 if let Some(info) = self.registry.closures.get_mut(&closure_name) {
378 info.source_location = Some(location.clone());
380
381 for (field_index, (_capture_name, capture_value)) in
383 captures.iter().enumerate()
384 {
385 let capture_mode = if capture_value.starts_with("move ") {
387 CaptureMode::ByValue
388 } else if capture_value.starts_with("&mut ") {
389 CaptureMode::ByMutRef
390 } else if capture_value.starts_with('&') {
391 CaptureMode::ByRef
392 } else {
393 CaptureMode::ByValue };
395
396 let parent_var = Self::extract_var_from_capture(capture_value);
398
399 let captured = CapturedVariable {
401 field_index,
402 parent_var: parent_var.clone(),
403 capture_mode,
404 taint_state: TaintState::Clean, };
406
407 info.captured_vars.push(captured);
408 }
409 }
410 }
411 }
412 }
413 }
414
415 fn find_closure_for_parent(&self, parent: &str, _location: &str) -> Option<String> {
418 let closures_for_parent: Vec<_> = self
420 .registry
421 .closures
422 .values()
423 .filter(|info| info.parent_function == parent)
424 .collect();
425
426 if closures_for_parent.is_empty() {
427 return None;
428 }
429
430 if closures_for_parent.len() == 1 {
432 return Some(closures_for_parent[0].name.clone());
433 }
434
435 for info in &closures_for_parent {
438 if info.source_location.is_none() {
439 return Some(info.name.clone());
440 }
441 }
442
443 Some(closures_for_parent[0].name.clone())
445 }
446
447 #[allow(dead_code)]
449 fn find_closure_by_location(&self, location: &str) -> Option<String> {
450 for (name, info) in &self.registry.closures {
457 if let Some(ref loc) = info.source_location {
458 if loc == location {
459 return Some(name.clone());
460 }
461 }
462 if let Some(_start) = location.rfind(':') {
465 if let Some(_line_start) = location[.._start].rfind(':') {
466 }
469 }
470 }
471
472 None
475 }
476
477 fn extract_var_from_capture(capture_value: &str) -> String {
482 let trimmed = capture_value.trim();
483
484 if trimmed.starts_with("move ") {
485 trimmed[5..].trim().to_string()
486 } else if trimmed.starts_with("&mut ") {
487 trimmed[5..].trim().to_string()
488 } else if trimmed.starts_with('&') {
489 trimmed[1..].trim().to_string()
490 } else {
491 trimmed
493 .split_whitespace()
494 .last()
495 .unwrap_or(trimmed)
496 .to_string()
497 }
498 }
499}
500
501impl Default for ClosureRegistryBuilder {
502 fn default() -> Self {
503 Self::new()
504 }
505}
506
/// Returns `true` if `name` looks like a MIR closure body name, i.e. it
/// contains a `::{closure#` path segment (e.g. `my_fn::{closure#0}`).
pub fn is_closure_function(name: &str) -> bool {
    const CLOSURE_SEGMENT: &str = "::{closure#";
    name.contains(CLOSURE_SEGMENT)
}
511
/// Splits a MIR closure body name into `(parent, closure index)`.
///
/// `module::func::{closure#3}` yields `("module::func", 3)`. The *last*
/// `::{closure#` segment is used, so a closure nested inside another closure
/// (`f::{closure#0}::{closure#1}`) is attributed to its immediate parent
/// (`f::{closure#0}`) with index 1. The previous first-match approach
/// attributed it to `f` with index 0.
///
/// Returns `None` when there is no such segment, it is not closed by `}`, or
/// the index is not a valid `usize`.
pub fn parse_closure_name(name: &str) -> Option<(String, usize)> {
    const MARKER: &str = "::{closure#";

    let pos = name.rfind(MARKER)?;
    let rest = &name[pos + MARKER.len()..];
    let end = rest.find('}')?;
    let index = rest[..end].parse::<usize>().ok()?;
    Some((name[..pos].to_string(), index))
}
535
/// Parses a MIR closure/coroutine construction statement such as
/// `_5 = {closure@src/lib.rs:278:19: 278:21} { tainted: move _6 };`.
///
/// Returns `(lhs, location, captures)`:
/// - `lhs`: trimmed text before the first ` = `;
/// - `location`: the `{closure@…}` marker up to and including its first `}`.
///   The `{coroutine@…}` marker is used only when no closure marker is present;
/// - `captures`: `(field name, operand)` pairs from the last ` { … }` group in
///   the statement. The group is split on `,`, and each entry on its first
///   `: `. Entries without `: ` are skipped, and the list is empty when there
///   is no such group.
///
/// Returns `None` when the statement has no marker or no ` = `, or when the
/// chosen marker is never closed by `}`.
pub fn parse_closure_creation(statement: &str) -> Option<(String, String, Vec<(String, String)>)> {
    const MARKERS: [&str; 2] = ["{closure@", "{coroutine@"];

    if !MARKERS.iter().any(|&marker| statement.contains(marker)) {
        return None;
    }

    let (lhs, _) = statement.split_once(" = ")?;

    // `find_map` respects the array order, so a closure marker wins over a
    // coroutine marker when both are present.
    let start = MARKERS.iter().find_map(|&marker| statement.find(marker))?;
    let close = statement[start..].find('}')?;
    let location = &statement[start..=start + close];

    let captures = match statement.rfind(" { ") {
        Some(open) => match statement[open..].rfind('}') {
            Some(end) => statement[open + 3..open + end]
                .split(',')
                .filter_map(|entry| entry.trim().split_once(": "))
                .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
                .collect(),
            None => Vec::new(),
        },
        None => Vec::new(),
    };

    Some((lhs.trim().to_string(), location.to_string(), captures))
}
593
/// Returns `true` if `statement` is a call through one of the closure traits.
///
/// That is `<… as Fn<…>>::call(`, `<… as FnMut<…>>::call_mut(` or
/// `<… as FnOnce<…>>::call_once(`; both the trait and the matching method
/// must appear.
pub fn is_closure_call(statement: &str) -> bool {
    const CALL_FORMS: [(&str, &str); 3] = [
        (" as Fn<", ">::call("),
        (" as FnMut<", ">::call_mut("),
        (" as FnOnce<", ">::call_once("),
    ];
    CALL_FORMS
        .iter()
        .any(|&(trait_marker, method_marker)| {
            statement.contains(trait_marker) && statement.contains(method_marker)
        })
}

/// Parses a closure-trait call into `(result local, closure local)`.
///
/// For `_7 = <{closure@…} as FnMut<()>>::call_mut(move _8, const ()) -> …`
/// this returns `("_7", "_8")`. The closure operand is the first argument; a
/// leading `move ` or `copy ` is stripped.
///
/// The argument list starts after the matched method marker, whose length
/// differs between `call`, `call_mut` and `call_once`. A fixed offset that
/// only fits `::call(` would mis-slice the other two.
///
/// Returns `None` when the statement is not a closure call, has no ` = `, or
/// has fewer than two arguments before the first `)`.
pub fn parse_closure_call(statement: &str) -> Option<(String, String)> {
    const CALL_METHODS: [&str; 3] = ["::call(", "::call_mut(", "::call_once("];

    if !is_closure_call(statement) {
        return None;
    }

    let (lhs, _) = statement.split_once(" = ")?;
    let result_var = lhs.trim().to_string();

    let (call_start, marker) = CALL_METHODS
        .iter()
        .find_map(|&marker| statement.find(marker).map(|pos| (pos, marker)))?;
    let args_start = call_start + marker.len();
    let args_len = statement[args_start..].find(')')?;
    let args = &statement[args_start..args_start + args_len];

    let (closure_arg, _) = args.split_once(',')?;
    let closure_arg = closure_arg.trim();
    let closure_var = closure_arg
        .strip_prefix("move ")
        .or_else(|| closure_arg.strip_prefix("copy "))
        .map_or(closure_arg, str::trim);

    Some((result_var, closure_var.to_string()))
}
642
/// Parses a closure-environment field access such as
/// `_7 = deref_copy ((*_1).0: &std::string::String);`.
///
/// Returns the statement's trimmed left-hand side and the field index that
/// follows a dereferenced MIR local (`(*_N).F`). Both the type-ascribed form
/// `((*_1).0: T)` and the bare form `(*_1).0` are accepted.
///
/// The field index must immediately follow the local's closing parenthesis.
/// A `).` belonging to some later, unrelated projection in the statement is
/// therefore never mistaken for the environment field. Every `(*_` occurrence
/// is tried in order, and the first one that parses wins.
///
/// Returns `None` if the statement has no ` = ` or no such projection.
pub fn parse_env_field_access(statement: &str) -> Option<(String, usize)> {
    const DEREF_LOCAL: &str = "(*_";

    if !statement.contains(DEREF_LOCAL) {
        return None;
    }

    let (lhs, _) = statement.split_once(" = ")?;
    let lhs = lhs.trim();

    statement
        .match_indices(DEREF_LOCAL)
        .find_map(|(pos, _)| {
            let after_local = &statement[pos + DEREF_LOCAL.len()..];
            let local_digits = after_local.bytes().take_while(u8::is_ascii_digit).count();
            if local_digits == 0 {
                return None;
            }
            let field = after_local[local_digits..].strip_prefix(").")?;
            let field_digits = field.bytes().take_while(u8::is_ascii_digit).count();
            field[..field_digits].parse::<usize>().ok()
        })
        .map(|field_index| (lhs.to_string(), field_index))
}
680
#[cfg(test)]
mod tests {
    //! Unit tests for the MIR text parsers in this module.
    use super::*;

    #[test]
    fn test_is_closure_function() {
        for closure in &["test_func::{closure#0}", "module::test::{closure#5}"] {
            assert!(is_closure_function(closure));
        }
        assert!(!is_closure_function("regular_function"));
    }

    #[test]
    fn test_parse_closure_name() {
        let expectations = [
            ("test_func::{closure#0}", "test_func", 0),
            ("module::nested::func::{closure#3}", "module::nested::func", 3),
        ];
        for &(name, expected_parent, expected_index) in &expectations {
            let (parent, index) = parse_closure_name(name).unwrap();
            assert_eq!(parent, expected_parent);
            assert_eq!(index, expected_index);
        }
        assert_eq!(parse_closure_name("not_a_closure"), None);
    }

    #[test]
    fn test_parse_closure_creation() {
        let statement = "_5 = {closure@examples/interprocedural/src/lib.rs:278:19: 278:21} { tainted: move _6 };";
        let (lhs, location, captures) = parse_closure_creation(statement).unwrap();

        assert_eq!(lhs, "_5");
        assert!(location.starts_with("{closure@"));
        assert_eq!(
            captures,
            vec![("tainted".to_string(), "move _6".to_string())]
        );
    }

    #[test]
    fn test_is_closure_call() {
        let calls = [
            "<{closure@...} as Fn<()>>::call(move _8, const ())",
            "<{closure@...} as FnMut<()>>::call_mut(move _8, const ())",
            "<{closure@...} as FnOnce<()>>::call_once(move _8, const ())",
        ];
        assert!(calls.iter().all(|call| is_closure_call(call)));
        assert!(!is_closure_call("regular_function_call()"));
    }

    #[test]
    fn test_parse_closure_call() {
        let statement = "_7 = <{closure@...} as Fn<()>>::call(move _8, const ()) -> [return: bb5, unwind: bb7];";
        assert_eq!(
            parse_closure_call(statement),
            Some(("_7".to_string(), "_8".to_string()))
        );
    }

    #[test]
    fn test_parse_env_field_access() {
        let statement = "_7 = deref_copy ((*_1).0: &std::string::String);";
        assert_eq!(
            parse_env_field_access(statement),
            Some(("_7".to_string(), 0))
        );
    }
}
748}