Skip to main content

mir_extractor/
prototypes.rs

1use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
2
3use crate::{dataflow::extract_variables, MirDataflow, MirFunction};
4
/// Configurable marker lists consumed by the prototype MIR detectors in this module.
///
/// Every list holds plain substrings that the detectors search for in MIR text. Start from
/// `PrototypeOptions::default()` and override individual lists as needed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PrototypeOptions {
    /// Fragments indicating a value is clamped, bounded, or checked against a size limit.
    /// NOTE(review): some defaults are mixed/upper-case; they only match if the comparison
    /// lower-cases the marker as well as the searched text.
    pub guard_markers: Vec<String>,
    /// Lower-case type-path fragments for `Rc`, `RefCell` and `Cell` payload types.
    pub unsync_markers: Vec<String>,
    /// ` as <int>` cast suffixes naming integer types of 32 bits or fewer.
    pub narrow_cast_targets: Vec<String>,
    /// Generic-argument fragments (`<u32>`, ...) marking narrowing `try_into` conversions.
    pub try_into_targets: Vec<String>,
    /// Call fragments treated as sinks for a narrowed length value.
    pub serialization_sinks: Vec<String>,
    /// Name fragments that mark a value as a length or size.
    pub length_identifiers: Vec<String>,
}

impl Default for PrototypeOptions {
    fn default() -> Self {
        // Turns a list of literals into the owned `Vec<String>` the fields expect.
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| String::from(*item)).collect()
        }

        let guard_markers = owned(&[
            // Numeric clamping/bounding
            "::min",
            ".min(",
            "cmp::min",
            "::clamp",
            ".clamp(",
            "::saturating_sub",
            "::checked_sub",
            "::min_by",
            "::min_by_key",
            // v1.0.1: HTTP/request size guards
            "max_request_bytes",
            "max_request_size",
            "max_body_size",
            "max_http_request_size",
            "body_limit",
            "DefaultBodyLimit",
            "RequestBodyLimit",
            "content_length_limit",
            "PayloadConfig",
            "ContentLengthLimit",
            // v1.0.1: Generic limit patterns
            "MAX_SIZE",
            "MAX_LEN",
            "MAX_LENGTH",
            "SIZE_LIMIT",
            "max_size",
            "size_limit",
            "max_capacity",
            "limit_bytes",
        ]);

        let unsync_markers = owned(&[
            "::rc<",
            "::refcell<",
            "::<rc<", // Turbofish syntax in generics
            "::<refcell<",
            "std::rc::rc<",
            "alloc::rc::rc<",
            "std::cell::refcell<",
            "core::cell::refcell<",
            "std::cell::cell<",
            "core::cell::cell<",
        ]);

        let narrow_cast_targets = owned(&[
            " as i32", " as u32", " as i16", " as u16", " as i8", " as u8",
        ]);

        let try_into_targets = owned(&["<i32>", "<u32>", "<i16>", "<u16>", "<i8>", "<u8>"]);

        let serialization_sinks = owned(&[
            "put_i32",
            "put_u32",
            "put_u64",
            "put_i16",
            "put_u16",
            "put_u8",
            "write_i32",
            "write_u32",
            "write_u64",
            "write_i16",
            "write_u16",
            "write_u8",
            "unwrap(",
            "::unwrap(",
            "expect(",
            "::expect(",
        ]);

        let length_identifiers = owned(&["len", "length", "payload", "size"]);

        Self {
            guard_markers,
            unsync_markers,
            narrow_cast_targets,
            try_into_targets,
            serialization_sinks,
            length_identifiers,
        }
    }
}
105
/// An allocation (`with_capacity` / `reserve` / `reserve_exact`) whose capacity argument is
/// derived from an untrusted length and is not bounded by a recognised guard.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContentLengthAllocation {
    /// The trimmed MIR line performing the allocation.
    pub allocation_line: String,
    /// The MIR local passed as the capacity argument.
    pub capacity_var: String,
    /// All locals tainted by the length source within the function.
    pub tainted_vars: HashSet<String>,
}

/// A narrowing integer conversion (`as` cast or `try_into`) applied to a length value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LengthTruncationCast {
    /// The MIR line containing the conversion.
    pub cast_line: String,
    /// The local receiving the narrowed value.
    pub target_var: String,
    /// The operand locals of the conversion.
    pub source_vars: Vec<String>,
    /// Trimmed sink lines (serialization writes, `unwrap`/`expect`) reached by the narrowed value.
    pub sink_lines: Vec<String>,
}

/// A trimmed MIR line that creates or uses a tokio broadcast channel carrying an `Rc`,
/// `RefCell` or `Cell` payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BroadcastUnsyncUsage {
    pub line: String,
}

/// A process `Command` construction together with the environment/argv-derived locals that flow
/// into its `arg`/`args`/`env` calls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandInvocation {
    /// The trimmed line constructing the command.
    pub command_line: String,
    /// Sorted, de-duplicated tainted argument locals (empty when all arguments are constant).
    pub tainted_args: Vec<String>,
}

/// An OpenSSL `set_verify` / `set_verify_callback` call that disables peer verification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpensslVerifyNoneInvocation {
    /// The trimmed configuration call.
    pub call_line: String,
    /// Trimmed assignment lines that produced the disabling verify mode.
    pub supporting_lines: Vec<String>,
}
137
138pub fn detect_content_length_allocations(function: &MirFunction) -> Vec<ContentLengthAllocation> {
139    detect_content_length_allocations_with_options(function, &PrototypeOptions::default())
140}
141
142pub fn detect_content_length_allocations_with_options(
143    function: &MirFunction,
144    options: &PrototypeOptions,
145) -> Vec<ContentLengthAllocation> {
146    let dataflow = MirDataflow::new(function);
147    let tainted = dataflow.taint_from(|assignment| rhs_mentions_content_length(&assignment.rhs));
148
149    if tainted.is_empty() {
150        return Vec::new();
151    }
152
153    let mut findings = Vec::new();
154
155    for line in &function.body {
156        if let Some(capacity_var) = extract_capacity_variable(line) {
157            if tainted.contains(&capacity_var)
158                && !is_guarded_capacity(function, &dataflow, &capacity_var, options)
159            {
160                findings.push(ContentLengthAllocation {
161                    allocation_line: line.trim().to_string(),
162                    capacity_var,
163                    tainted_vars: tainted.clone(),
164                });
165            }
166        }
167    }
168
169    findings
170}
171
172pub fn detect_unbounded_allocations(function: &MirFunction) -> Vec<ContentLengthAllocation> {
173    detect_unbounded_allocations_with_options(function, &PrototypeOptions::default())
174}
175
176pub fn detect_unbounded_allocations_with_options(
177    function: &MirFunction,
178    options: &PrototypeOptions,
179) -> Vec<ContentLengthAllocation> {
180    let dataflow = MirDataflow::new(function);
181    let seeds = collect_length_seed_vars(function, &dataflow, options);
182    if seeds.is_empty() {
183        return Vec::new();
184    }
185
186    let tainted = propagate_length_seeds(&dataflow, seeds);
187    if tainted.is_empty() {
188        return Vec::new();
189    }
190
191    let mut findings = Vec::new();
192
193    for line in &function.body {
194        if let Some(capacity_var) = extract_capacity_variable(line) {
195            if tainted.contains(&capacity_var)
196                && !is_guarded_capacity(function, &dataflow, &capacity_var, options)
197            {
198                findings.push(ContentLengthAllocation {
199                    allocation_line: line.trim().to_string(),
200                    capacity_var,
201                    tainted_vars: tainted.clone(),
202                });
203            }
204        }
205    }
206
207    findings
208}
209
210fn extract_capacity_variable(line: &str) -> Option<String> {
211    let lowered = line.to_lowercase();
212    let is_reserve_method = lowered.contains("reserve_exact") || lowered.contains("reserve");
213    let keyword = if lowered.contains("with_capacity") {
214        "with_capacity"
215    } else if lowered.contains("reserve_exact") {
216        "reserve_exact"
217    } else if lowered.contains("reserve") {
218        "reserve"
219    } else {
220        return None;
221    };
222
223    let start = line.find(keyword)? + keyword.len();
224    let remainder = line[start..].trim_start();
225    if !remainder.starts_with('(') {
226        return None;
227    }
228
229    let closing = remainder.find(')')?;
230    let inside = &remainder[1..closing];
231    let vars = extract_variables(inside);
232
233    // For Vec::reserve and Vec::reserve_exact, the capacity is the second argument
234    // (first is self). For with_capacity, it's the first (and only) argument.
235    if is_reserve_method {
236        vars.into_iter().nth(1)
237    } else {
238        vars.into_iter().next()
239    }
240}
241
/// Returns `true` when an assignment's right-hand side reads a `Content-Length` value.
///
/// Matching is case-insensitive. It covers `content_length` accessors, the `CONTENT_LENGTH`
/// header constant, and the `"content-length"` header name used as a string or byte literal.
fn rhs_mentions_content_length(rhs: &str) -> bool {
    const MARKERS: [&str; 6] = [
        "content_length",
        "\"content-length\"",
        "header::content-length",
        "headername::from_static(\"content-length\"",
        "headervalue::from_static(\"content-length\"",
        "from_bytes(b\"content-length\")",
    ];

    // An upper-case `CONTENT_LENGTH` lower-cases to `content_length`, so the first marker
    // already covers the constant.
    let lowered = rhs.to_lowercase();
    MARKERS.iter().any(|marker| lowered.contains(*marker))
}
256
257fn is_guarded_capacity(
258    function: &MirFunction,
259    dataflow: &MirDataflow,
260    capacity_var: &str,
261    options: &PrototypeOptions,
262) -> bool {
263    let mut queue = vec![capacity_var.to_string()];
264    let mut visited = HashSet::new();
265
266    while let Some(var) = queue.pop() {
267        if !visited.insert(var.clone()) {
268            continue;
269        }
270
271        // Check for assert! with comparison operators
272        if assert_mentions_var(function, &var) {
273            return true;
274        }
275
276        // Check for comparison operations (Le, Lt, Ge, Gt) in MIR
277        if has_comparison_guard(function, &var) {
278            return true;
279        }
280
281        // Check if variable is used as argument to guarding functions
282        if used_in_guard_check(function, &var, options) {
283            return true;
284        }
285
286        // Walk backward through assignments
287        for assignment in dataflow.assignments() {
288            if assignment.target != var {
289                continue;
290            }
291
292            if rhs_contains_upper_bound_guard(&assignment.rhs, options) {
293                return true;
294            }
295
296            for source in &assignment.sources {
297                queue.push(source.clone());
298            }
299        }
300    }
301
302    false
303}
304
305/// Check if variable is used in a comparison operation (Le, Lt, Ge, Gt)
306/// followed by conditional branching (indicating an assertion or validation)
307fn has_comparison_guard(function: &MirFunction, var: &str) -> bool {
308    for (i, line) in function.body.iter().enumerate() {
309        // Look for comparison operations
310        if (line.contains("= Le(")
311            || line.contains("= Lt(")
312            || line.contains("= Ge(")
313            || line.contains("= Gt("))
314            && line.contains(var)
315        {
316            // Check if next line is a switchInt (conditional branch)
317            if i + 1 < function.body.len() {
318                let next_line = &function.body[i + 1];
319                if next_line.contains("switchInt") {
320                    return true;
321                }
322            }
323        }
324    }
325    false
326}
327
328/// Check if variable is used as argument to a guard function like checked_sub
329/// where the result is checked before use
330fn used_in_guard_check(function: &MirFunction, var: &str, options: &PrototypeOptions) -> bool {
331    for (i, line) in function.body.iter().enumerate() {
332        // Check if line contains guard function and our variable as argument
333        let lowered = line.to_lowercase();
334        let has_guard = options
335            .guard_markers
336            .iter()
337            .any(|marker| lowered.contains(marker));
338
339        if has_guard && line.contains(var) {
340            // Extract the result variable (target of assignment)
341            if let Some(eq_pos) = line.find('=') {
342                if let Some(result_var) = line[..eq_pos].trim().split_whitespace().last() {
343                    // Look ahead for discriminant check on result
344                    for j in (i + 1)..function.body.len().min(i + 10) {
345                        let future_line = &function.body[j];
346                        if future_line.contains("discriminant") && future_line.contains(result_var)
347                        {
348                            return true;
349                        }
350                        // Also check for direct switchInt on the result
351                        if future_line.contains("switchInt") && future_line.contains(result_var) {
352                            return true;
353                        }
354                    }
355                }
356            }
357        }
358    }
359    false
360}
361
362fn rhs_contains_upper_bound_guard(rhs: &str, options: &PrototypeOptions) -> bool {
363    let lowered = rhs.to_lowercase();
364    options
365        .guard_markers
366        .iter()
367        .any(|pattern| lowered.contains(pattern))
368}
369
370fn assert_mentions_var(function: &MirFunction, var: &str) -> bool {
371    function.body.iter().any(|line| {
372        if !line.contains("assert") || !line.contains(var) {
373            return false;
374        }
375        let lowered = line.to_lowercase();
376        lowered.contains(" <= ")
377            || lowered.contains(" < ")
378            || lowered.contains(" >= ")
379            || lowered.contains(" > ")
380    })
381}
382
383pub fn detect_truncating_len_casts(function: &MirFunction) -> Vec<LengthTruncationCast> {
384    detect_truncating_len_casts_with_options(function, &PrototypeOptions::default())
385}
386
387pub fn detect_truncating_len_casts_with_options(
388    function: &MirFunction,
389    options: &PrototypeOptions,
390) -> Vec<LengthTruncationCast> {
391    let dataflow = MirDataflow::new(function);
392    let seeds = collect_length_seed_vars(function, &dataflow, options);
393    let tainted = propagate_length_seeds(&dataflow, seeds);
394
395    if tainted.is_empty() {
396        return Vec::new();
397    }
398
399    let mut findings = Vec::new();
400
401    for assignment in dataflow.assignments() {
402        let rhs_lower = assignment.rhs.to_lowercase();
403        if !is_narrow_cast(&rhs_lower, options) && !is_try_into_narrow(&rhs_lower, options) {
404            continue;
405        }
406
407        if assignment
408            .sources
409            .iter()
410            .any(|source| tainted.contains(source))
411        {
412            let sink_lines = collect_sink_lines(function, &dataflow, &assignment.target, options);
413            findings.push(LengthTruncationCast {
414                cast_line: assignment.line.clone(),
415                target_var: assignment.target.clone(),
416                source_vars: assignment.sources.clone(),
417                sink_lines,
418            });
419        }
420    }
421
422    findings
423}
424
425pub fn detect_broadcast_unsync_payloads(function: &MirFunction) -> Vec<BroadcastUnsyncUsage> {
426    detect_broadcast_unsync_payloads_with_options(function, &PrototypeOptions::default())
427}
428
429pub fn detect_broadcast_unsync_payloads_with_options(
430    function: &MirFunction,
431    options: &PrototypeOptions,
432) -> Vec<BroadcastUnsyncUsage> {
433    let dataflow = MirDataflow::new(function);
434    let mut seed_lines = Vec::new();
435
436    let unsync_vars = dataflow.taint_from(|assignment| {
437        if is_broadcast_constructor(&assignment.rhs)
438            && payload_looks_unsync(&assignment.rhs, options)
439        {
440            seed_lines.push(assignment.line.trim().to_string());
441            return true;
442        }
443
444        false
445    });
446
447    if unsync_vars.is_empty() {
448        return seed_lines
449            .into_iter()
450            .map(|line| BroadcastUnsyncUsage { line })
451            .collect();
452    }
453
454    let mut lines: BTreeSet<String> = seed_lines.into_iter().collect();
455
456    for raw_line in &function.body {
457        let trimmed = raw_line.trim();
458        if trimmed.is_empty() {
459            continue;
460        }
461
462        let lower = trimmed.to_lowercase();
463        let references_unsync_var = unsync_vars.iter().any(|var| trimmed.contains(var));
464
465        if payload_looks_unsync(trimmed, options) && lower.contains("tokio::sync::broadcast") {
466            lines.insert(trimmed.to_string());
467            continue;
468        }
469
470        if references_unsync_var && line_mentions_broadcast_usage(&lower) {
471            lines.insert(trimmed.to_string());
472        }
473    }
474
475    lines
476        .into_iter()
477        .map(|line| BroadcastUnsyncUsage { line })
478        .collect()
479}
480
481pub fn detect_command_invocations(function: &MirFunction) -> Vec<CommandInvocation> {
482    let dataflow = MirDataflow::new(function);
483    let taint_sources = dataflow.taint_from(|assignment| {
484        let lowered = assignment.rhs.to_lowercase();
485        lowered.contains("env::var")
486            || lowered.contains("env::args")
487            || lowered.contains("env::var_os")
488            || lowered.contains("env::args_os")
489            || lowered.contains("std::env::args")
490            || lowered.contains("std::env::vars")
491    });
492
493    let mut findings = Vec::new();
494
495    fn pattern_outside_quotes(text: &str, idx: usize) -> bool {
496        let bytes = text.as_bytes();
497        let mut in_quotes = false;
498        let mut escaped = false;
499
500        for (pos, byte) in bytes.iter().enumerate() {
501            if pos >= idx {
502                break;
503            }
504
505            if escaped {
506                escaped = false;
507                continue;
508            }
509
510            match byte {
511                b'\\' => escaped = true,
512                b'"' => in_quotes = !in_quotes,
513                _ => {}
514            }
515        }
516
517        !in_quotes
518    }
519
520    for assignment in dataflow.assignments() {
521        let lowered = assignment.rhs.to_lowercase();
522        let first_paren = lowered.find('(').unwrap_or(lowered.len());
523
524        let is_process_command = [
525            "::std::process::command::new",
526            "std::process::command::new",
527            "::tokio::process::command::new",
528            "tokio::process::command::new",
529            "::async_process::command::new",
530            "async_process::command::new",
531        ]
532        .into_iter()
533        .any(|pattern| {
534            lowered.find(pattern).map_or(false, |idx| {
535                idx < first_paren && pattern_outside_quotes(&assignment.rhs, idx)
536            })
537        });
538
539        if !is_process_command {
540            continue;
541        }
542
543        let mut tainted_args = Vec::new();
544        let mut visited = HashSet::new();
545        let mut queue = VecDeque::new();
546        queue.push_back(assignment.target.clone());
547
548        while let Some(current) = queue.pop_front() {
549            if !visited.insert(current.clone()) {
550                continue;
551            }
552
553            let mut arg_sites = Vec::new();
554
555            for other in dataflow.assignments() {
556                if other.target == current {
557                    for src in &other.sources {
558                        queue.push_back(src.clone());
559                    }
560                }
561
562                let rhs_lower = other.rhs.to_lowercase();
563                if other.sources.iter().any(|src| src == &current)
564                    && (rhs_lower.contains("command::arg(")
565                        || rhs_lower.contains("command::args(")
566                        || rhs_lower.contains("command::env(")
567                        || rhs_lower.contains("command::arg_os(")
568                        || rhs_lower.contains("command::args_os("))
569                {
570                    arg_sites.push(other.clone());
571                }
572            }
573
574            for site in arg_sites {
575                let mut queue_inputs = Vec::new();
576                for src in &site.sources {
577                    queue_inputs.push(src.clone());
578                    if taint_sources.contains(src) {
579                        tainted_args.push(src.clone());
580                    }
581                }
582                for src in queue_inputs {
583                    queue.push_back(src);
584                }
585            }
586        }
587
588        tainted_args.sort();
589        tainted_args.dedup();
590
591        findings.push(CommandInvocation {
592            command_line: assignment.line.trim().to_string(),
593            tainted_args,
594        });
595    }
596
597    findings
598}
599
600pub fn detect_openssl_verify_none(function: &MirFunction) -> Vec<OpensslVerifyNoneInvocation> {
601    let dataflow = MirDataflow::new(function);
602    let mut var_to_lines: HashMap<String, Vec<String>> = HashMap::new();
603
604    for assignment in dataflow.assignments() {
605        var_to_lines
606            .entry(assignment.target.clone())
607            .or_default()
608            .push(assignment.line.trim().to_string());
609    }
610
611    let tainted_modes =
612        dataflow.taint_from(|assignment| rhs_disables_verification(&assignment.rhs));
613    let mut findings: HashMap<String, Vec<String>> = HashMap::new();
614
615    for assignment in dataflow.assignments() {
616        if !is_verify_configuration_call(&assignment.rhs) {
617            continue;
618        }
619
620        let mut supporting = Vec::new();
621        let mut disables = rhs_disables_verification(&assignment.rhs);
622
623        for source in &assignment.sources {
624            if tainted_modes.contains(source) {
625                disables = true;
626                if let Some(lines) = var_to_lines.get(source) {
627                    for line in lines {
628                        if !supporting.contains(line) {
629                            supporting.push(line.clone());
630                        }
631                    }
632                }
633            }
634        }
635
636        if !disables {
637            continue;
638        }
639
640        let entry = findings
641            .entry(assignment.line.trim().to_string())
642            .or_insert_with(Vec::new);
643        for line in supporting {
644            if !entry.contains(&line) {
645                entry.push(line);
646            }
647        }
648    }
649
650    for raw_line in &function.body {
651        let trimmed = raw_line.trim().to_string();
652        if !is_verify_configuration_call(&trimmed) {
653            continue;
654        }
655        if !rhs_disables_verification(&trimmed) {
656            continue;
657        }
658        findings.entry(trimmed).or_insert_with(Vec::new);
659    }
660
661    let mut result: Vec<_> = findings
662        .into_iter()
663        .map(
664            |(call_line, supporting_lines)| OpensslVerifyNoneInvocation {
665                call_line,
666                supporting_lines,
667            },
668        )
669        .collect();
670
671    result.sort_by(|a, b| a.call_line.cmp(&b.call_line));
672    result
673}
674
/// Returns `true` if `text` contains a `set_verify(` or `set_verify_callback(` call, compared
/// case-insensitively.
fn is_verify_configuration_call(text: &str) -> bool {
    const VERIFY_CALLS: [&str; 2] = ["set_verify(", "set_verify_callback("];
    let haystack = text.to_lowercase();
    VERIFY_CALLS.iter().any(|call| haystack.contains(*call))
}
679
/// Returns `true` when `rhs` (compared case-insensitively) produces an OpenSSL verify mode that
/// disables peer verification.
///
/// Recognised forms:
/// * `SslVerifyMode::NONE` / `SSL_VERIFY_NONE` / `VERIFY_NONE`;
/// * `SslVerifyMode::empty()`;
/// * `SslVerifyMode::from_bits[_truncate](0)` and `SslVerifyMode::bits(0)`, optionally written
///   with a MIR `const ` prefix.
///
/// The bits argument must be a literal zero. `0x1`, `01` and similar non-zero literals that
/// merely start with `0` are not treated as disabling.
fn rhs_disables_verification(rhs: &str) -> bool {
    // Fragments that name a disabling mode directly.
    const DIRECT_MARKERS: [&str; 5] = [
        "sslverifymode::none",
        "ssl_verify_none",
        "verify_none",
        "sslverifymode::empty(",
        "verify_mode::empty(",
    ];
    // Calls that build a mode from raw bits; they disable verification only for a zero argument.
    const BITS_CALLS: [&str; 3] = [
        "sslverifymode::from_bits_truncate(",
        "sslverifymode::from_bits(",
        "sslverifymode::bits(",
    ];

    // `true` when `argument` begins with the literal `0` (e.g. `0`, `0)`, `0_i32`, `0u32`),
    // rejecting longer numbers and radix/float forms such as `01`, `0x1`, `0b1`, `0.5`.
    fn starts_with_zero_literal(argument: &str) -> bool {
        let argument = argument.strip_prefix("const ").unwrap_or(argument);
        let mut chars = argument.chars();
        if chars.next() != Some('0') {
            return false;
        }
        !matches!(
            chars.next(),
            Some(next) if next.is_ascii_digit() || matches!(next, 'x' | 'o' | 'b' | '.')
        )
    }

    let rhs_lower = rhs.to_lowercase();

    DIRECT_MARKERS
        .iter()
        .any(|marker| rhs_lower.contains(*marker))
        || BITS_CALLS.iter().any(|call| {
            rhs_lower
                .match_indices(*call)
                .any(|(start, _)| starts_with_zero_literal(&rhs_lower[start + call.len()..]))
        })
}
694
695fn collect_length_seed_vars(
696    function: &MirFunction,
697    dataflow: &MirDataflow,
698    options: &PrototypeOptions,
699) -> HashSet<String> {
700    let mut seeds = HashSet::new();
701
702    for line in &function.body {
703        let trimmed = line.trim();
704        if let Some(rest) = trimmed.strip_prefix("debug ") {
705            if let Some((name_part, var_part)) = rest.split_once("=>") {
706                let name = name_part.trim();
707                let var = var_part.trim().trim_end_matches(';');
708                if is_length_identifier(name, options) && var.starts_with('_') {
709                    seeds.insert(var.to_string());
710                }
711            }
712        }
713    }
714
715    for assignment in dataflow.assignments() {
716        let lowered = assignment.rhs.to_lowercase();
717        if lowered.contains(".len(")
718            || lowered.contains("::len(")
719            || lowered.contains("len()")
720            || options
721                .length_identifiers
722                .iter()
723                .any(|marker| lowered.contains(marker))
724        {
725            seeds.insert(assignment.target.clone());
726        }
727    }
728
729    seeds
730}
731
732fn propagate_length_seeds(dataflow: &MirDataflow, seeds: HashSet<String>) -> HashSet<String> {
733    let mut tainted = seeds;
734    if tainted.is_empty() {
735        return tainted;
736    }
737
738    let mut changed = true;
739    while changed {
740        changed = false;
741        for assignment in dataflow.assignments() {
742            if tainted.contains(&assignment.target) {
743                continue;
744            }
745
746            if assignment
747                .sources
748                .iter()
749                .any(|source| tainted.contains(source))
750            {
751                tainted.insert(assignment.target.clone());
752                changed = true;
753            }
754        }
755    }
756
757    tainted
758}
759
760fn collect_sink_lines(
761    function: &MirFunction,
762    dataflow: &MirDataflow,
763    root_var: &str,
764    options: &PrototypeOptions,
765) -> Vec<String> {
766    let related_vars = collect_related_vars(dataflow, root_var);
767    let mut sinks = Vec::new();
768
769    for line in &function.body {
770        let lowered = line.to_lowercase();
771        if !options
772            .serialization_sinks
773            .iter()
774            .any(|marker| lowered.contains(marker))
775        {
776            continue;
777        }
778
779        if related_vars.iter().any(|var| line.contains(var)) {
780            let trimmed = line.trim().to_string();
781            if !sinks.contains(&trimmed) {
782                sinks.push(trimmed);
783            }
784        }
785    }
786
787    sinks
788}
789
790fn collect_related_vars(dataflow: &MirDataflow, root_var: &str) -> HashSet<String> {
791    let mut related = HashSet::new();
792    related.insert(root_var.to_string());
793    let mut changed = true;
794
795    while changed {
796        changed = false;
797        for assignment in dataflow.assignments() {
798            if related.contains(&assignment.target) {
799                continue;
800            }
801
802            if assignment
803                .sources
804                .iter()
805                .any(|source| related.contains(source))
806            {
807                related.insert(assignment.target.clone());
808                changed = true;
809            }
810        }
811    }
812
813    related
814}
815
816fn is_length_identifier(name: &str, options: &PrototypeOptions) -> bool {
817    let lowered = name.to_lowercase();
818    options
819        .length_identifiers
820        .iter()
821        .any(|marker| lowered.contains(marker))
822}
823
824fn is_narrow_cast(rhs_lower: &str, options: &PrototypeOptions) -> bool {
825    rhs_lower.contains("inttoint")
826        && options
827            .narrow_cast_targets
828            .iter()
829            .any(|target| rhs_lower.contains(target))
830}
831
832fn is_try_into_narrow(rhs_lower: &str, options: &PrototypeOptions) -> bool {
833    if !rhs_lower.contains("try_into") {
834        return false;
835    }
836
837    options
838        .try_into_targets
839        .iter()
840        .any(|target| rhs_lower.contains(target))
841}
842
/// Returns `true` when `rhs` (compared case-insensitively) calls `tokio::sync::broadcast::channel`
/// or an associated function of the broadcast `Sender`/`Receiver` types.
fn is_broadcast_constructor(rhs: &str) -> bool {
    const CONSTRUCTOR_PATHS: [&str; 3] = [
        "tokio::sync::broadcast::channel",
        "tokio::sync::broadcast::sender::",
        "tokio::sync::broadcast::receiver::",
    ];
    let lowered = rhs.to_lowercase();
    CONSTRUCTOR_PATHS.iter().any(|path| lowered.contains(*path))
}
849
/// Returns `true` when the already lower-cased line names a broadcast path or a
/// send/subscribe-style call.
fn line_mentions_broadcast_usage(lower: &str) -> bool {
    const USAGE_MARKERS: [&str; 6] = [
        "tokio::sync::broadcast::",
        ".send(",
        "::send(",
        "::send_ref(",
        ".subscribe(",
        "::subscribe(",
    ];
    USAGE_MARKERS
        .iter()
        .copied()
        .any(|marker| lower.contains(marker))
}
862
863fn payload_looks_unsync(line: &str, options: &PrototypeOptions) -> bool {
864    let lowered = line.to_lowercase();
865    options
866        .unsync_markers
867        .iter()
868        .any(|marker| lowered.contains(marker))
869}
870
871#[cfg(test)]
872mod tests {
873    use super::*;
874
875    fn function_from_lines(lines: &[&str]) -> MirFunction {
876        MirFunction {
877            name: "demo".to_string(),
878            signature: "fn demo()".to_string(),
879            body: lines.iter().map(|l| l.to_string()).collect(),
880            span: None,
881            ..Default::default()
882        }
883    }
884
885    #[test]
886    fn detects_simple_content_length_allocation() {
887        let function = function_from_lines(&[
888            "    _5 = reqwest::Response::content_length(move _1);",
889            "    _6 = copy _5;",
890            "    _7 = Vec::<u8>::with_capacity(move _6);",
891        ]);
892
893        let findings = detect_content_length_allocations(&function);
894        assert_eq!(findings.len(), 1);
895        let finding = &findings[0];
896        assert_eq!(finding.capacity_var, "_6");
897        assert!(finding.tainted_vars.contains("_5"));
898    }
899
900    #[test]
901    fn ignores_min_guard() {
902        let function = function_from_lines(&[
903            "    _2 = reqwest::Response::content_length(move _1);",
904            "    _3 = copy _2;",
905            "    _4 = core::cmp::min(move _3, const 1048576_usize);",
906            "    _5 = Vec::<u8>::with_capacity(move _4);",
907        ]);
908
909        let findings = detect_content_length_allocations(&function);
910        assert!(findings.is_empty());
911    }
912
913    #[test]
914    fn ignores_clamp_guard() {
915        let function = function_from_lines(&[
916            "    _2 = reqwest::Response::content_length(move _1);",
917            "    _3 = copy _2;",
918            "    _4 = core::cmp::Ord::clamp(copy _3, const 0_usize, const 65536_usize);",
919            "    _5 = Vec::<u8>::with_capacity(move _4);",
920        ]);
921
922        let findings = detect_content_length_allocations(&function);
923        assert!(findings.is_empty());
924    }
925
926    #[test]
927    fn ignores_assert_guard() {
928        let function = function_from_lines(&[
929            "    _2 = reqwest::Response::content_length(move _1);",
930            "    assert(move _2 <= const 1048576_usize, ...);",
931            "    _3 = copy _2;",
932            "    _4 = Vec::<u8>::with_capacity(move _3);",
933        ]);
934
935        let findings = detect_content_length_allocations(&function);
936        assert!(findings.is_empty());
937    }
938
939    #[test]
940    fn detects_unbounded_allocation_from_len() {
941        let function = function_from_lines(&[
942            "    debug body_len => _1;",
943            "    _2 = copy _1;",
944            "    _3 = Vec::<u8>::with_capacity(move _2);",
945        ]);
946
947        let findings = detect_unbounded_allocations(&function);
948        assert_eq!(findings.len(), 1);
949        assert_eq!(findings[0].capacity_var, "_2");
950    }
951
952    #[test]
953    fn unbounded_allocation_respects_guard() {
954        let function = function_from_lines(&[
955            "    debug payload_size => _1;",
956            "    _2 = core::cmp::min(move _1, const 65536_usize);",
957            "    _3 = Vec::<u8>::with_capacity(move _2);",
958        ]);
959
960        let findings = detect_unbounded_allocations(&function);
961        assert!(findings.is_empty());
962    }
963
964    #[test]
965    fn detects_tainted_command_arg() {
966        let function = function_from_lines(&[
967            "    _1 = std::env::var(const \"USER\");",
968            "    _2 = std::process::Command::new(const \"/bin/echo\");",
969            "    _3 = std::process::Command::arg(move _2, move _1);",
970        ]);
971
972        let invocations = detect_command_invocations(&function);
973        assert_eq!(invocations.len(), 1);
974        assert!(invocations[0].tainted_args.contains(&"_1".to_string()));
975    }
976
977    #[test]
978    fn ignores_constant_command_args() {
979        let function = function_from_lines(&[
980            "    _1 = std::process::Command::new(const \"git\");",
981            "    _2 = std::process::Command::arg(move _1, const \"status\");",
982        ]);
983
984        let invocations = detect_command_invocations(&function);
985        assert_eq!(invocations.len(), 1);
986        assert!(invocations[0].tainted_args.is_empty());
987    }
988
989    #[test]
990    fn ignores_clap_command_builder() {
991        let function = function_from_lines(&[
992            "    _1 = clap::Command::new(const \"cargo-cola\");",
993            "    _2 = clap::Command::arg(move _1, const \"--help\");",
994        ]);
995
996        let invocations = detect_command_invocations(&function);
997        assert!(invocations.is_empty());
998    }
999
1000    #[test]
1001    fn ignores_command_string_literal_checks() {
1002        let function = function_from_lines(&[
1003            "    _1 = core::str::<impl str>::contains::<&str>(copy _0, const \"std::process::Command::new\");",
1004        ]);
1005
1006        let invocations = detect_command_invocations(&function);
1007        assert!(invocations.is_empty());
1008    }
1009
1010    #[test]
1011    fn detects_openssl_verify_none_inline() {
1012        let function = function_from_lines(&[
1013            "    _1 = openssl::ssl::SslContextBuilder::set_verify(move _0, openssl::ssl::SslVerifyMode::NONE);",
1014        ]);
1015
1016        let findings = detect_openssl_verify_none(&function);
1017        assert_eq!(findings.len(), 1);
1018        assert!(findings[0].call_line.contains("set_verify"));
1019        assert!(findings[0].supporting_lines.is_empty());
1020    }
1021
1022    #[test]
1023    fn detects_openssl_verify_none_via_empty_mode() {
1024        let function = function_from_lines(&[
1025            "    _1 = openssl::ssl::SslVerifyMode::empty();",
1026            "    _2 = openssl::ssl::SslContextBuilder::set_verify(move _0, move _1);",
1027        ]);
1028
1029        let findings = detect_openssl_verify_none(&function);
1030        assert_eq!(findings.len(), 1);
1031        assert_eq!(findings[0].supporting_lines.len(), 1);
1032        assert!(findings[0].supporting_lines[0].contains("SslVerifyMode::empty"));
1033    }
1034
1035    #[test]
1036    fn detects_openssl_verify_none_callback() {
1037        let function = function_from_lines(&[
1038            "    _2 = openssl::ssl::SslContextBuilder::set_verify_callback(move _0, openssl::ssl::SslVerifyMode::NONE, move _1);",
1039        ]);
1040
1041        let findings = detect_openssl_verify_none(&function);
1042        assert_eq!(findings.len(), 1);
1043        assert!(findings[0].call_line.contains("set_verify_callback"));
1044    }
1045
1046    #[test]
1047    fn detects_truncating_len_cast() {
1048        let function = function_from_lines(&[
1049            "    debug payload_len => _1;",
1050            "    _2 = copy _1;",
1051            "    _3 = move _2 as i32 (IntToInt);",
1052            "    _4 = Vec::<u8>::with_capacity(move _3);",
1053        ]);
1054
1055        let casts = detect_truncating_len_casts(&function);
1056        assert_eq!(casts.len(), 1);
1057        assert_eq!(casts[0].target_var, "_3");
1058        assert_eq!(casts[0].source_vars, vec!["_2".to_string()]);
1059        assert!(casts[0].sink_lines.is_empty());
1060    }
1061
1062    #[test]
1063    fn detects_try_into_len_cast() {
1064        let function = function_from_lines(&[
1065            "    debug payload_len => _1;",
1066            "    _2 = copy _1;",
1067            "    _3 = core::convert::TryInto::try_into::<i16>(move _2);",
1068        ]);
1069
1070        let casts = detect_truncating_len_casts(&function);
1071        assert_eq!(casts.len(), 1);
1072        assert_eq!(casts[0].target_var, "_3");
1073    }
1074
1075    #[test]
1076    fn captures_serialization_sinks() {
1077        let function = function_from_lines(&[
1078            "    debug payload_len => _1;",
1079            "    _2 = copy _1;",
1080            "    _3 = move _2 as u16 (IntToInt);",
1081            "    _4 = copy _3;",
1082            "    _5 = byteorder::WriteBytesExt::write_u16::<byteorder::BigEndian>(move _0, move _4);",
1083        ]);
1084
1085        let casts = detect_truncating_len_casts(&function);
1086        assert_eq!(casts.len(), 1);
1087        assert_eq!(casts[0].sink_lines.len(), 1);
1088        assert!(casts[0].sink_lines[0].contains("WriteBytesExt::write_u16"));
1089    }
1090
1091    #[test]
1092    fn ignores_wide_len_cast() {
1093        let function = function_from_lines(&[
1094            "    debug payload_len => _1;",
1095            "    _2 = copy _1;",
1096            "    _3 = move _2 as i64 (IntToInt);",
1097            "    _4 = Vec::<u8>::with_capacity(move _3);",
1098        ]);
1099
1100        let casts = detect_truncating_len_casts(&function);
1101        assert!(casts.is_empty());
1102    }
1103
1104    #[test]
1105    fn respects_custom_guard_markers() {
1106        let function = function_from_lines(&[
1107            "    _2 = reqwest::Response::content_length(move _1);",
1108            "    _3 = copy _2;",
1109            "    _4 = my_crate::ensure_capacity(move _3);",
1110            "    _5 = Vec::<u8>::with_capacity(move _4);",
1111        ]);
1112
1113        let mut options = PrototypeOptions::default();
1114        options.guard_markers.push("ensure_capacity".to_string());
1115
1116        let findings = detect_content_length_allocations_with_options(&function, &options);
1117        assert!(findings.is_empty());
1118    }
1119
1120    #[test]
1121    fn detects_broadcast_rc_payload() {
1122        let function = function_from_lines(&[
1123            "    _5 = tokio::sync::broadcast::channel::<std::rc::Rc<String>>(const 16_usize);",
1124        ]);
1125
1126        let findings = detect_broadcast_unsync_payloads(&function);
1127        assert_eq!(findings.len(), 1);
1128        assert!(findings[0].line.contains("std::rc::Rc"));
1129    }
1130
1131    #[test]
1132    fn ignores_broadcast_arc_payload() {
1133        let function = function_from_lines(&[
1134            "    _5 = tokio::sync::broadcast::channel::<std::sync::Arc<String>>(const 16_usize);",
1135        ]);
1136
1137        let findings = detect_broadcast_unsync_payloads(&function);
1138        assert!(findings.is_empty());
1139    }
1140
1141    #[test]
1142    fn no_findings_without_taint() {
1143        let function =
1144            function_from_lines(&["    _3 = Vec::<u8>::with_capacity(const 4096_usize);"]);
1145
1146        let findings = detect_content_length_allocations(&function);
1147        assert!(findings.is_empty());
1148    }
1149}