1use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
2
3use crate::{dataflow::extract_variables, MirDataflow, MirFunction};
4
/// Tunable string markers consumed by the MIR prototype detectors in this module.
///
/// Every list holds substrings that are searched for in MIR text; see the
/// individual detectors for exactly how each list is compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrototypeOptions {
    /// Fragments indicating that a value has been bounded (e.g. `min`/`clamp`
    /// calls, saturating arithmetic, or request-size limit configuration).
    pub guard_markers: Vec<String>,
    /// Lowercase type-path fragments of single-threaded payload types
    /// (`Rc`, `RefCell`, `Cell`).
    pub unsync_markers: Vec<String>,
    /// `as` cast suffixes whose target integer type is 32 bits or narrower.
    pub narrow_cast_targets: Vec<String>,
    /// Turbofish fragments naming narrow integer targets of `try_into`.
    pub try_into_targets: Vec<String>,
    /// Call fragments treated as sinks for a narrowed length value
    /// (integer writers and `unwrap`/`expect`).
    pub serialization_sinks: Vec<String>,
    /// Name fragments that mark a value as a length.
    pub length_identifiers: Vec<String>,
}

impl Default for PrototypeOptions {
    /// Builds the built-in marker lists used by the non-`_with_options` detectors.
    fn default() -> Self {
        // Converts a borrowed list of literals into owned strings, preserving order.
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| (*item).to_owned()).collect()
        }

        let guard_markers = owned(&[
            // Explicit clamping and saturating/checked arithmetic.
            "::min",
            ".min(",
            "cmp::min",
            "::clamp",
            ".clamp(",
            "::saturating_sub",
            "::checked_sub",
            "::min_by",
            "::min_by_key",
            // Framework and configuration request-size limits.
            "max_request_bytes",
            "max_request_size",
            "max_body_size",
            "max_http_request_size",
            "body_limit",
            "DefaultBodyLimit",
            "RequestBodyLimit",
            "content_length_limit",
            "PayloadConfig",
            "ContentLengthLimit",
            // Conventional constant / field names for caps.
            "MAX_SIZE",
            "MAX_LEN",
            "MAX_LENGTH",
            "SIZE_LIMIT",
            "max_size",
            "size_limit",
            "max_capacity",
            "limit_bytes",
        ]);

        let unsync_markers = owned(&[
            "::rc<",
            "::refcell<",
            "::<rc<",
            "::<refcell<",
            "std::rc::rc<",
            "alloc::rc::rc<",
            "std::cell::refcell<",
            "core::cell::refcell<",
            "std::cell::cell<",
            "core::cell::cell<",
        ]);

        let narrow_cast_targets = owned(&[
            " as i32", " as u32", " as i16", " as u16", " as i8", " as u8",
        ]);

        let try_into_targets = owned(&["<i32>", "<u32>", "<i16>", "<u16>", "<i8>", "<u8>"]);

        let serialization_sinks = owned(&[
            "put_i32",
            "put_u32",
            "put_u64",
            "put_i16",
            "put_u16",
            "put_u8",
            "write_i32",
            "write_u32",
            "write_u64",
            "write_i16",
            "write_u16",
            "write_u8",
            "unwrap(",
            "::unwrap(",
            "expect(",
            "::expect(",
        ]);

        let length_identifiers = owned(&["len", "length", "payload", "size"]);

        PrototypeOptions {
            guard_markers,
            unsync_markers,
            narrow_cast_targets,
            try_into_targets,
            serialization_sinks,
            length_identifiers,
        }
    }
}
105
/// An allocation whose requested capacity derives from a tainted length
/// (for example a `Content-Length` value) and is not bounded by a recognised guard.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContentLengthAllocation {
    /// Trimmed MIR line containing the `with_capacity` / `reserve` call.
    pub allocation_line: String,
    /// MIR local passed as the capacity argument.
    pub capacity_var: String,
    /// Every local the detector considered tainted in the enclosing function.
    pub tainted_vars: HashSet<String>,
}
112
/// A narrowing integer conversion applied to a length-derived value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LengthTruncationCast {
    /// The MIR line performing the cast or `try_into` call (as recorded by the dataflow).
    pub cast_line: String,
    /// Local receiving the narrowed value.
    pub target_var: String,
    /// Locals read by the cast assignment.
    pub source_vars: Vec<String>,
    /// Trimmed lines where the narrowed value (or anything derived from it) reaches a sink.
    pub sink_lines: Vec<String>,
}
120
/// A trimmed MIR line that creates or uses a `tokio::sync::broadcast` channel
/// carrying a single-threaded (`Rc`/`RefCell`/`Cell`) payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BroadcastUnsyncUsage {
    /// The offending MIR line, trimmed.
    pub line: String,
}
125
/// A process `Command` construction together with the environment-derived
/// locals that reach its argument-setting calls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandInvocation {
    /// Trimmed MIR line calling `Command::new`.
    pub command_line: String,
    /// Sorted, deduplicated locals tainted by `std::env` reads.
    pub tainted_args: Vec<String>,
}
131
/// An OpenSSL `set_verify`/`set_verify_callback` call that disables peer verification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpensslVerifyNoneInvocation {
    /// Trimmed MIR line of the verification-configuration call.
    pub call_line: String,
    /// Trimmed assignment lines that produced the disabled verification mode.
    pub supporting_lines: Vec<String>,
}
137
138pub fn detect_content_length_allocations(function: &MirFunction) -> Vec<ContentLengthAllocation> {
139 detect_content_length_allocations_with_options(function, &PrototypeOptions::default())
140}
141
142pub fn detect_content_length_allocations_with_options(
143 function: &MirFunction,
144 options: &PrototypeOptions,
145) -> Vec<ContentLengthAllocation> {
146 let dataflow = MirDataflow::new(function);
147 let tainted = dataflow.taint_from(|assignment| rhs_mentions_content_length(&assignment.rhs));
148
149 if tainted.is_empty() {
150 return Vec::new();
151 }
152
153 let mut findings = Vec::new();
154
155 for line in &function.body {
156 if let Some(capacity_var) = extract_capacity_variable(line) {
157 if tainted.contains(&capacity_var)
158 && !is_guarded_capacity(function, &dataflow, &capacity_var, options)
159 {
160 findings.push(ContentLengthAllocation {
161 allocation_line: line.trim().to_string(),
162 capacity_var,
163 tainted_vars: tainted.clone(),
164 });
165 }
166 }
167 }
168
169 findings
170}
171
172pub fn detect_unbounded_allocations(function: &MirFunction) -> Vec<ContentLengthAllocation> {
173 detect_unbounded_allocations_with_options(function, &PrototypeOptions::default())
174}
175
176pub fn detect_unbounded_allocations_with_options(
177 function: &MirFunction,
178 options: &PrototypeOptions,
179) -> Vec<ContentLengthAllocation> {
180 let dataflow = MirDataflow::new(function);
181 let seeds = collect_length_seed_vars(function, &dataflow, options);
182 if seeds.is_empty() {
183 return Vec::new();
184 }
185
186 let tainted = propagate_length_seeds(&dataflow, seeds);
187 if tainted.is_empty() {
188 return Vec::new();
189 }
190
191 let mut findings = Vec::new();
192
193 for line in &function.body {
194 if let Some(capacity_var) = extract_capacity_variable(line) {
195 if tainted.contains(&capacity_var)
196 && !is_guarded_capacity(function, &dataflow, &capacity_var, options)
197 {
198 findings.push(ContentLengthAllocation {
199 allocation_line: line.trim().to_string(),
200 capacity_var,
201 tainted_vars: tainted.clone(),
202 });
203 }
204 }
205 }
206
207 findings
208}
209
210fn extract_capacity_variable(line: &str) -> Option<String> {
211 let lowered = line.to_lowercase();
212 let is_reserve_method = lowered.contains("reserve_exact") || lowered.contains("reserve");
213 let keyword = if lowered.contains("with_capacity") {
214 "with_capacity"
215 } else if lowered.contains("reserve_exact") {
216 "reserve_exact"
217 } else if lowered.contains("reserve") {
218 "reserve"
219 } else {
220 return None;
221 };
222
223 let start = line.find(keyword)? + keyword.len();
224 let remainder = line[start..].trim_start();
225 if !remainder.starts_with('(') {
226 return None;
227 }
228
229 let closing = remainder.find(')')?;
230 let inside = &remainder[1..closing];
231 let vars = extract_variables(inside);
232
233 if is_reserve_method {
236 vars.into_iter().nth(1)
237 } else {
238 vars.into_iter().next()
239 }
240}
241
/// Returns true when an assignment right-hand side reads a `Content-Length`
/// value, by accessor name, header-name constant, or literal header name
/// (matched case-insensitively).
fn rhs_mentions_content_length(rhs: &str) -> bool {
    const LOWERCASE_MARKERS: [&str; 6] = [
        "content_length",
        "\"content-length\"",
        "header::content-length",
        "headername::from_static(\"content-length\"",
        "headervalue::from_static(\"content-length\"",
        "from_bytes(b\"content-length\")",
    ];

    let folded = rhs.to_lowercase();
    let mentioned = LOWERCASE_MARKERS.iter().any(|marker| folded.contains(marker));
    mentioned || rhs.contains("CONTENT_LENGTH")
}
256
257fn is_guarded_capacity(
258 function: &MirFunction,
259 dataflow: &MirDataflow,
260 capacity_var: &str,
261 options: &PrototypeOptions,
262) -> bool {
263 let mut queue = vec![capacity_var.to_string()];
264 let mut visited = HashSet::new();
265
266 while let Some(var) = queue.pop() {
267 if !visited.insert(var.clone()) {
268 continue;
269 }
270
271 if assert_mentions_var(function, &var) {
273 return true;
274 }
275
276 if has_comparison_guard(function, &var) {
278 return true;
279 }
280
281 if used_in_guard_check(function, &var, options) {
283 return true;
284 }
285
286 for assignment in dataflow.assignments() {
288 if assignment.target != var {
289 continue;
290 }
291
292 if rhs_contains_upper_bound_guard(&assignment.rhs, options) {
293 return true;
294 }
295
296 for source in &assignment.sources {
297 queue.push(source.clone());
298 }
299 }
300 }
301
302 false
303}
304
305fn has_comparison_guard(function: &MirFunction, var: &str) -> bool {
308 for (i, line) in function.body.iter().enumerate() {
309 if (line.contains("= Le(")
311 || line.contains("= Lt(")
312 || line.contains("= Ge(")
313 || line.contains("= Gt("))
314 && line.contains(var)
315 {
316 if i + 1 < function.body.len() {
318 let next_line = &function.body[i + 1];
319 if next_line.contains("switchInt") {
320 return true;
321 }
322 }
323 }
324 }
325 false
326}
327
328fn used_in_guard_check(function: &MirFunction, var: &str, options: &PrototypeOptions) -> bool {
331 for (i, line) in function.body.iter().enumerate() {
332 let lowered = line.to_lowercase();
334 let has_guard = options
335 .guard_markers
336 .iter()
337 .any(|marker| lowered.contains(marker));
338
339 if has_guard && line.contains(var) {
340 if let Some(eq_pos) = line.find('=') {
342 if let Some(result_var) = line[..eq_pos].trim().split_whitespace().last() {
343 for j in (i + 1)..function.body.len().min(i + 10) {
345 let future_line = &function.body[j];
346 if future_line.contains("discriminant") && future_line.contains(result_var)
347 {
348 return true;
349 }
350 if future_line.contains("switchInt") && future_line.contains(result_var) {
352 return true;
353 }
354 }
355 }
356 }
357 }
358 }
359 false
360}
361
362fn rhs_contains_upper_bound_guard(rhs: &str, options: &PrototypeOptions) -> bool {
363 let lowered = rhs.to_lowercase();
364 options
365 .guard_markers
366 .iter()
367 .any(|pattern| lowered.contains(pattern))
368}
369
370fn assert_mentions_var(function: &MirFunction, var: &str) -> bool {
371 function.body.iter().any(|line| {
372 if !line.contains("assert") || !line.contains(var) {
373 return false;
374 }
375 let lowered = line.to_lowercase();
376 lowered.contains(" <= ")
377 || lowered.contains(" < ")
378 || lowered.contains(" >= ")
379 || lowered.contains(" > ")
380 })
381}
382
383pub fn detect_truncating_len_casts(function: &MirFunction) -> Vec<LengthTruncationCast> {
384 detect_truncating_len_casts_with_options(function, &PrototypeOptions::default())
385}
386
387pub fn detect_truncating_len_casts_with_options(
388 function: &MirFunction,
389 options: &PrototypeOptions,
390) -> Vec<LengthTruncationCast> {
391 let dataflow = MirDataflow::new(function);
392 let seeds = collect_length_seed_vars(function, &dataflow, options);
393 let tainted = propagate_length_seeds(&dataflow, seeds);
394
395 if tainted.is_empty() {
396 return Vec::new();
397 }
398
399 let mut findings = Vec::new();
400
401 for assignment in dataflow.assignments() {
402 let rhs_lower = assignment.rhs.to_lowercase();
403 if !is_narrow_cast(&rhs_lower, options) && !is_try_into_narrow(&rhs_lower, options) {
404 continue;
405 }
406
407 if assignment
408 .sources
409 .iter()
410 .any(|source| tainted.contains(source))
411 {
412 let sink_lines = collect_sink_lines(function, &dataflow, &assignment.target, options);
413 findings.push(LengthTruncationCast {
414 cast_line: assignment.line.clone(),
415 target_var: assignment.target.clone(),
416 source_vars: assignment.sources.clone(),
417 sink_lines,
418 });
419 }
420 }
421
422 findings
423}
424
425pub fn detect_broadcast_unsync_payloads(function: &MirFunction) -> Vec<BroadcastUnsyncUsage> {
426 detect_broadcast_unsync_payloads_with_options(function, &PrototypeOptions::default())
427}
428
429pub fn detect_broadcast_unsync_payloads_with_options(
430 function: &MirFunction,
431 options: &PrototypeOptions,
432) -> Vec<BroadcastUnsyncUsage> {
433 let dataflow = MirDataflow::new(function);
434 let mut seed_lines = Vec::new();
435
436 let unsync_vars = dataflow.taint_from(|assignment| {
437 if is_broadcast_constructor(&assignment.rhs)
438 && payload_looks_unsync(&assignment.rhs, options)
439 {
440 seed_lines.push(assignment.line.trim().to_string());
441 return true;
442 }
443
444 false
445 });
446
447 if unsync_vars.is_empty() {
448 return seed_lines
449 .into_iter()
450 .map(|line| BroadcastUnsyncUsage { line })
451 .collect();
452 }
453
454 let mut lines: BTreeSet<String> = seed_lines.into_iter().collect();
455
456 for raw_line in &function.body {
457 let trimmed = raw_line.trim();
458 if trimmed.is_empty() {
459 continue;
460 }
461
462 let lower = trimmed.to_lowercase();
463 let references_unsync_var = unsync_vars.iter().any(|var| trimmed.contains(var));
464
465 if payload_looks_unsync(trimmed, options) && lower.contains("tokio::sync::broadcast") {
466 lines.insert(trimmed.to_string());
467 continue;
468 }
469
470 if references_unsync_var && line_mentions_broadcast_usage(&lower) {
471 lines.insert(trimmed.to_string());
472 }
473 }
474
475 lines
476 .into_iter()
477 .map(|line| BroadcastUnsyncUsage { line })
478 .collect()
479}
480
481pub fn detect_command_invocations(function: &MirFunction) -> Vec<CommandInvocation> {
482 let dataflow = MirDataflow::new(function);
483 let taint_sources = dataflow.taint_from(|assignment| {
484 let lowered = assignment.rhs.to_lowercase();
485 lowered.contains("env::var")
486 || lowered.contains("env::args")
487 || lowered.contains("env::var_os")
488 || lowered.contains("env::args_os")
489 || lowered.contains("std::env::args")
490 || lowered.contains("std::env::vars")
491 });
492
493 let mut findings = Vec::new();
494
495 fn pattern_outside_quotes(text: &str, idx: usize) -> bool {
496 let bytes = text.as_bytes();
497 let mut in_quotes = false;
498 let mut escaped = false;
499
500 for (pos, byte) in bytes.iter().enumerate() {
501 if pos >= idx {
502 break;
503 }
504
505 if escaped {
506 escaped = false;
507 continue;
508 }
509
510 match byte {
511 b'\\' => escaped = true,
512 b'"' => in_quotes = !in_quotes,
513 _ => {}
514 }
515 }
516
517 !in_quotes
518 }
519
520 for assignment in dataflow.assignments() {
521 let lowered = assignment.rhs.to_lowercase();
522 let first_paren = lowered.find('(').unwrap_or(lowered.len());
523
524 let is_process_command = [
525 "::std::process::command::new",
526 "std::process::command::new",
527 "::tokio::process::command::new",
528 "tokio::process::command::new",
529 "::async_process::command::new",
530 "async_process::command::new",
531 ]
532 .into_iter()
533 .any(|pattern| {
534 lowered.find(pattern).map_or(false, |idx| {
535 idx < first_paren && pattern_outside_quotes(&assignment.rhs, idx)
536 })
537 });
538
539 if !is_process_command {
540 continue;
541 }
542
543 let mut tainted_args = Vec::new();
544 let mut visited = HashSet::new();
545 let mut queue = VecDeque::new();
546 queue.push_back(assignment.target.clone());
547
548 while let Some(current) = queue.pop_front() {
549 if !visited.insert(current.clone()) {
550 continue;
551 }
552
553 let mut arg_sites = Vec::new();
554
555 for other in dataflow.assignments() {
556 if other.target == current {
557 for src in &other.sources {
558 queue.push_back(src.clone());
559 }
560 }
561
562 let rhs_lower = other.rhs.to_lowercase();
563 if other.sources.iter().any(|src| src == ¤t)
564 && (rhs_lower.contains("command::arg(")
565 || rhs_lower.contains("command::args(")
566 || rhs_lower.contains("command::env(")
567 || rhs_lower.contains("command::arg_os(")
568 || rhs_lower.contains("command::args_os("))
569 {
570 arg_sites.push(other.clone());
571 }
572 }
573
574 for site in arg_sites {
575 let mut queue_inputs = Vec::new();
576 for src in &site.sources {
577 queue_inputs.push(src.clone());
578 if taint_sources.contains(src) {
579 tainted_args.push(src.clone());
580 }
581 }
582 for src in queue_inputs {
583 queue.push_back(src);
584 }
585 }
586 }
587
588 tainted_args.sort();
589 tainted_args.dedup();
590
591 findings.push(CommandInvocation {
592 command_line: assignment.line.trim().to_string(),
593 tainted_args,
594 });
595 }
596
597 findings
598}
599
600pub fn detect_openssl_verify_none(function: &MirFunction) -> Vec<OpensslVerifyNoneInvocation> {
601 let dataflow = MirDataflow::new(function);
602 let mut var_to_lines: HashMap<String, Vec<String>> = HashMap::new();
603
604 for assignment in dataflow.assignments() {
605 var_to_lines
606 .entry(assignment.target.clone())
607 .or_default()
608 .push(assignment.line.trim().to_string());
609 }
610
611 let tainted_modes =
612 dataflow.taint_from(|assignment| rhs_disables_verification(&assignment.rhs));
613 let mut findings: HashMap<String, Vec<String>> = HashMap::new();
614
615 for assignment in dataflow.assignments() {
616 if !is_verify_configuration_call(&assignment.rhs) {
617 continue;
618 }
619
620 let mut supporting = Vec::new();
621 let mut disables = rhs_disables_verification(&assignment.rhs);
622
623 for source in &assignment.sources {
624 if tainted_modes.contains(source) {
625 disables = true;
626 if let Some(lines) = var_to_lines.get(source) {
627 for line in lines {
628 if !supporting.contains(line) {
629 supporting.push(line.clone());
630 }
631 }
632 }
633 }
634 }
635
636 if !disables {
637 continue;
638 }
639
640 let entry = findings
641 .entry(assignment.line.trim().to_string())
642 .or_insert_with(Vec::new);
643 for line in supporting {
644 if !entry.contains(&line) {
645 entry.push(line);
646 }
647 }
648 }
649
650 for raw_line in &function.body {
651 let trimmed = raw_line.trim().to_string();
652 if !is_verify_configuration_call(&trimmed) {
653 continue;
654 }
655 if !rhs_disables_verification(&trimmed) {
656 continue;
657 }
658 findings.entry(trimmed).or_insert_with(Vec::new);
659 }
660
661 let mut result: Vec<_> = findings
662 .into_iter()
663 .map(
664 |(call_line, supporting_lines)| OpensslVerifyNoneInvocation {
665 call_line,
666 supporting_lines,
667 },
668 )
669 .collect();
670
671 result.sort_by(|a, b| a.call_line.cmp(&b.call_line));
672 result
673}
674
/// Returns true when `text` calls `set_verify(` or `set_verify_callback(`,
/// ignoring case.
fn is_verify_configuration_call(text: &str) -> bool {
    let folded = text.to_lowercase();
    ["set_verify(", "set_verify_callback("]
        .iter()
        .any(|call| folded.contains(call))
}
679
/// Returns true when `rhs` spells out a disabled OpenSSL verification mode,
/// whether through the `NONE` constant, `empty()`, or a zero bit pattern.
/// Matching ignores case.
fn rhs_disables_verification(rhs: &str) -> bool {
    const DISABLED_MODE_MARKERS: [&str; 10] = [
        "sslverifymode::none",
        "ssl_verify_none",
        "verify_none",
        "sslverifymode::empty(",
        "verify_mode::empty(",
        "sslverifymode::from_bits_truncate(0",
        "sslverifymode::from_bits(0",
        "sslverifymode::from_bits_truncate(const 0",
        "sslverifymode::from_bits(const 0",
        "sslverifymode::bits(0",
    ];

    let folded = rhs.to_lowercase();
    for marker in DISABLED_MODE_MARKERS.iter() {
        if folded.contains(marker) {
            return true;
        }
    }
    false
}
694
695fn collect_length_seed_vars(
696 function: &MirFunction,
697 dataflow: &MirDataflow,
698 options: &PrototypeOptions,
699) -> HashSet<String> {
700 let mut seeds = HashSet::new();
701
702 for line in &function.body {
703 let trimmed = line.trim();
704 if let Some(rest) = trimmed.strip_prefix("debug ") {
705 if let Some((name_part, var_part)) = rest.split_once("=>") {
706 let name = name_part.trim();
707 let var = var_part.trim().trim_end_matches(';');
708 if is_length_identifier(name, options) && var.starts_with('_') {
709 seeds.insert(var.to_string());
710 }
711 }
712 }
713 }
714
715 for assignment in dataflow.assignments() {
716 let lowered = assignment.rhs.to_lowercase();
717 if lowered.contains(".len(")
718 || lowered.contains("::len(")
719 || lowered.contains("len()")
720 || options
721 .length_identifiers
722 .iter()
723 .any(|marker| lowered.contains(marker))
724 {
725 seeds.insert(assignment.target.clone());
726 }
727 }
728
729 seeds
730}
731
732fn propagate_length_seeds(dataflow: &MirDataflow, seeds: HashSet<String>) -> HashSet<String> {
733 let mut tainted = seeds;
734 if tainted.is_empty() {
735 return tainted;
736 }
737
738 let mut changed = true;
739 while changed {
740 changed = false;
741 for assignment in dataflow.assignments() {
742 if tainted.contains(&assignment.target) {
743 continue;
744 }
745
746 if assignment
747 .sources
748 .iter()
749 .any(|source| tainted.contains(source))
750 {
751 tainted.insert(assignment.target.clone());
752 changed = true;
753 }
754 }
755 }
756
757 tainted
758}
759
760fn collect_sink_lines(
761 function: &MirFunction,
762 dataflow: &MirDataflow,
763 root_var: &str,
764 options: &PrototypeOptions,
765) -> Vec<String> {
766 let related_vars = collect_related_vars(dataflow, root_var);
767 let mut sinks = Vec::new();
768
769 for line in &function.body {
770 let lowered = line.to_lowercase();
771 if !options
772 .serialization_sinks
773 .iter()
774 .any(|marker| lowered.contains(marker))
775 {
776 continue;
777 }
778
779 if related_vars.iter().any(|var| line.contains(var)) {
780 let trimmed = line.trim().to_string();
781 if !sinks.contains(&trimmed) {
782 sinks.push(trimmed);
783 }
784 }
785 }
786
787 sinks
788}
789
790fn collect_related_vars(dataflow: &MirDataflow, root_var: &str) -> HashSet<String> {
791 let mut related = HashSet::new();
792 related.insert(root_var.to_string());
793 let mut changed = true;
794
795 while changed {
796 changed = false;
797 for assignment in dataflow.assignments() {
798 if related.contains(&assignment.target) {
799 continue;
800 }
801
802 if assignment
803 .sources
804 .iter()
805 .any(|source| related.contains(source))
806 {
807 related.insert(assignment.target.clone());
808 changed = true;
809 }
810 }
811 }
812
813 related
814}
815
816fn is_length_identifier(name: &str, options: &PrototypeOptions) -> bool {
817 let lowered = name.to_lowercase();
818 options
819 .length_identifiers
820 .iter()
821 .any(|marker| lowered.contains(marker))
822}
823
824fn is_narrow_cast(rhs_lower: &str, options: &PrototypeOptions) -> bool {
825 rhs_lower.contains("inttoint")
826 && options
827 .narrow_cast_targets
828 .iter()
829 .any(|target| rhs_lower.contains(target))
830}
831
832fn is_try_into_narrow(rhs_lower: &str, options: &PrototypeOptions) -> bool {
833 if !rhs_lower.contains("try_into") {
834 return false;
835 }
836
837 options
838 .try_into_targets
839 .iter()
840 .any(|target| rhs_lower.contains(target))
841}
842
/// Returns true when `rhs` calls `tokio::sync::broadcast::channel` or a
/// `Sender`/`Receiver` associated function. Matching ignores case.
fn is_broadcast_constructor(rhs: &str) -> bool {
    const BROADCAST_PATHS: [&str; 3] = [
        "tokio::sync::broadcast::channel",
        "tokio::sync::broadcast::sender::",
        "tokio::sync::broadcast::receiver::",
    ];

    let folded = rhs.to_lowercase();
    BROADCAST_PATHS.iter().any(|path| folded.contains(path))
}
849
/// Returns true when an already-lowercased MIR line names the broadcast module,
/// or calls a `send`, `send_ref` or `subscribe` method in path or method form.
fn line_mentions_broadcast_usage(lower: &str) -> bool {
    let usage_markers = [
        "tokio::sync::broadcast::",
        ".send(",
        "::send(",
        "::send_ref(",
        ".subscribe(",
        "::subscribe(",
    ];

    for marker in usage_markers.iter() {
        if lower.contains(marker) {
            return true;
        }
    }
    false
}
862
863fn payload_looks_unsync(line: &str, options: &PrototypeOptions) -> bool {
864 let lowered = line.to_lowercase();
865 options
866 .unsync_markers
867 .iter()
868 .any(|marker| lowered.contains(marker))
869}
870
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps raw MIR body lines in a `MirFunction` named `demo`.
    fn mir(lines: &[&str]) -> MirFunction {
        MirFunction {
            name: String::from("demo"),
            signature: String::from("fn demo()"),
            body: lines.iter().map(|line| String::from(*line)).collect(),
            span: None,
            ..Default::default()
        }
    }

    // --- Content-Length driven allocations ---

    #[test]
    fn detects_simple_content_length_allocation() {
        let findings = detect_content_length_allocations(&mir(&[
            " _5 = reqwest::Response::content_length(move _1);",
            " _6 = copy _5;",
            " _7 = Vec::<u8>::with_capacity(move _6);",
        ]));

        match findings.as_slice() {
            [finding] => {
                assert_eq!(finding.capacity_var, "_6");
                assert!(finding.tainted_vars.contains("_5"));
            }
            other => panic!("expected exactly one finding, got {:?}", other),
        }
    }

    #[test]
    fn ignores_min_guard() {
        let function = mir(&[
            " _2 = reqwest::Response::content_length(move _1);",
            " _3 = copy _2;",
            " _4 = core::cmp::min(move _3, const 1048576_usize);",
            " _5 = Vec::<u8>::with_capacity(move _4);",
        ]);
        assert!(detect_content_length_allocations(&function).is_empty());
    }

    #[test]
    fn ignores_clamp_guard() {
        let function = mir(&[
            " _2 = reqwest::Response::content_length(move _1);",
            " _3 = copy _2;",
            " _4 = core::cmp::Ord::clamp(copy _3, const 0_usize, const 65536_usize);",
            " _5 = Vec::<u8>::with_capacity(move _4);",
        ]);
        assert!(detect_content_length_allocations(&function).is_empty());
    }

    #[test]
    fn ignores_assert_guard() {
        let function = mir(&[
            " _2 = reqwest::Response::content_length(move _1);",
            " assert(move _2 <= const 1048576_usize, ...);",
            " _3 = copy _2;",
            " _4 = Vec::<u8>::with_capacity(move _3);",
        ]);
        assert!(detect_content_length_allocations(&function).is_empty());
    }

    // --- Length-driven (unbounded) allocations ---

    #[test]
    fn detects_unbounded_allocation_from_len() {
        let findings = detect_unbounded_allocations(&mir(&[
            " debug body_len => _1;",
            " _2 = copy _1;",
            " _3 = Vec::<u8>::with_capacity(move _2);",
        ]));

        let capacity_vars: Vec<&str> = findings
            .iter()
            .map(|finding| finding.capacity_var.as_str())
            .collect();
        assert_eq!(capacity_vars, vec!["_2"]);
    }

    #[test]
    fn unbounded_allocation_respects_guard() {
        let function = mir(&[
            " debug payload_size => _1;",
            " _2 = core::cmp::min(move _1, const 65536_usize);",
            " _3 = Vec::<u8>::with_capacity(move _2);",
        ]);
        assert!(detect_unbounded_allocations(&function).is_empty());
    }

    // --- Process command invocations ---

    #[test]
    fn detects_tainted_command_arg() {
        let invocations = detect_command_invocations(&mir(&[
            " _1 = std::env::var(const \"USER\");",
            " _2 = std::process::Command::new(const \"/bin/echo\");",
            " _3 = std::process::Command::arg(move _2, move _1);",
        ]));

        assert_eq!(invocations.len(), 1);
        assert!(invocations[0].tainted_args.iter().any(|arg| arg == "_1"));
    }

    #[test]
    fn ignores_constant_command_args() {
        let invocations = detect_command_invocations(&mir(&[
            " _1 = std::process::Command::new(const \"git\");",
            " _2 = std::process::Command::arg(move _1, const \"status\");",
        ]));

        match invocations.as_slice() {
            [only] => assert!(only.tainted_args.is_empty()),
            other => panic!("expected exactly one invocation, got {:?}", other),
        }
    }

    #[test]
    fn ignores_clap_command_builder() {
        let function = mir(&[
            " _1 = clap::Command::new(const \"cargo-cola\");",
            " _2 = clap::Command::arg(move _1, const \"--help\");",
        ]);
        assert!(detect_command_invocations(&function).is_empty());
    }

    #[test]
    fn ignores_command_string_literal_checks() {
        let function = mir(&[
            " _1 = core::str::<impl str>::contains::<&str>(copy _0, const \"std::process::Command::new\");",
        ]);
        assert!(detect_command_invocations(&function).is_empty());
    }

    // --- OpenSSL verification ---

    #[test]
    fn detects_openssl_verify_none_inline() {
        let findings = detect_openssl_verify_none(&mir(&[
            " _1 = openssl::ssl::SslContextBuilder::set_verify(move _0, openssl::ssl::SslVerifyMode::NONE);",
        ]));

        assert_eq!(findings.len(), 1);
        let finding = &findings[0];
        assert!(finding.call_line.contains("set_verify"));
        assert!(finding.supporting_lines.is_empty());
    }

    #[test]
    fn detects_openssl_verify_none_via_empty_mode() {
        let findings = detect_openssl_verify_none(&mir(&[
            " _1 = openssl::ssl::SslVerifyMode::empty();",
            " _2 = openssl::ssl::SslContextBuilder::set_verify(move _0, move _1);",
        ]));

        match findings.as_slice() {
            [finding] => {
                assert_eq!(finding.supporting_lines.len(), 1);
                assert!(finding.supporting_lines[0].contains("SslVerifyMode::empty"));
            }
            other => panic!("expected exactly one finding, got {:?}", other),
        }
    }

    #[test]
    fn detects_openssl_verify_none_callback() {
        let findings = detect_openssl_verify_none(&mir(&[
            " _2 = openssl::ssl::SslContextBuilder::set_verify_callback(move _0, openssl::ssl::SslVerifyMode::NONE, move _1);",
        ]));

        assert_eq!(findings.len(), 1);
        assert!(findings[0].call_line.contains("set_verify_callback"));
    }

    // --- Truncating length casts ---

    #[test]
    fn detects_truncating_len_cast() {
        let casts = detect_truncating_len_casts(&mir(&[
            " debug payload_len => _1;",
            " _2 = copy _1;",
            " _3 = move _2 as i32 (IntToInt);",
            " _4 = Vec::<u8>::with_capacity(move _3);",
        ]));

        assert_eq!(casts.len(), 1);
        let cast = &casts[0];
        assert_eq!(cast.target_var, "_3");
        assert_eq!(cast.source_vars, vec![String::from("_2")]);
        assert!(cast.sink_lines.is_empty());
    }

    #[test]
    fn detects_try_into_len_cast() {
        let casts = detect_truncating_len_casts(&mir(&[
            " debug payload_len => _1;",
            " _2 = copy _1;",
            " _3 = core::convert::TryInto::try_into::<i16>(move _2);",
        ]));

        match casts.as_slice() {
            [cast] => assert_eq!(cast.target_var, "_3"),
            other => panic!("expected exactly one cast, got {:?}", other),
        }
    }

    #[test]
    fn captures_serialization_sinks() {
        let casts = detect_truncating_len_casts(&mir(&[
            " debug payload_len => _1;",
            " _2 = copy _1;",
            " _3 = move _2 as u16 (IntToInt);",
            " _4 = copy _3;",
            " _5 = byteorder::WriteBytesExt::write_u16::<byteorder::BigEndian>(move _0, move _4);",
        ]));

        assert_eq!(casts.len(), 1);
        let sinks = &casts[0].sink_lines;
        assert_eq!(sinks.len(), 1);
        assert!(sinks[0].contains("WriteBytesExt::write_u16"));
    }

    #[test]
    fn ignores_wide_len_cast() {
        let function = mir(&[
            " debug payload_len => _1;",
            " _2 = copy _1;",
            " _3 = move _2 as i64 (IntToInt);",
            " _4 = Vec::<u8>::with_capacity(move _3);",
        ]);
        assert!(detect_truncating_len_casts(&function).is_empty());
    }

    // --- Option handling ---

    #[test]
    fn respects_custom_guard_markers() {
        let function = mir(&[
            " _2 = reqwest::Response::content_length(move _1);",
            " _3 = copy _2;",
            " _4 = my_crate::ensure_capacity(move _3);",
            " _5 = Vec::<u8>::with_capacity(move _4);",
        ]);

        let options = PrototypeOptions {
            guard_markers: {
                let mut markers = PrototypeOptions::default().guard_markers;
                markers.push(String::from("ensure_capacity"));
                markers
            },
            ..PrototypeOptions::default()
        };

        assert!(detect_content_length_allocations_with_options(&function, &options).is_empty());
    }

    // --- Broadcast payloads ---

    #[test]
    fn detects_broadcast_rc_payload() {
        let findings = detect_broadcast_unsync_payloads(&mir(&[
            " _5 = tokio::sync::broadcast::channel::<std::rc::Rc<String>>(const 16_usize);",
        ]));

        assert_eq!(findings.len(), 1);
        assert!(findings[0].line.contains("std::rc::Rc"));
    }

    #[test]
    fn ignores_broadcast_arc_payload() {
        let function = mir(&[
            " _5 = tokio::sync::broadcast::channel::<std::sync::Arc<String>>(const 16_usize);",
        ]);
        assert!(detect_broadcast_unsync_payloads(&function).is_empty());
    }

    #[test]
    fn no_findings_without_taint() {
        let function = mir(&[" _3 = Vec::<u8>::with_capacity(const 4096_usize);"]);
        assert!(detect_content_length_allocations(&function).is_empty());
    }
}