Skip to main content

mdx_rust_analysis/
hardening.rs

1//! Conservative Rust hardening analysis for ordinary Rust modules.
2//!
3//! This module intentionally starts with high-confidence static patterns. It
4//! can inspect normal Rust crates without requiring agent registration.
5
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use std::path::{Path, PathBuf};
9
// Result of one hardening pass, as returned by `analyze_hardening`.
//
// Plain `//` comments are used on purpose: schemars copies `///` docs into
// the derived JSON schema, which would change the emitted schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct HardeningAnalysis {
    // Root directory the analysis was run against.
    pub root: PathBuf,
    // Target path supplied by the caller to narrow the scan, if any.
    pub target: Option<PathBuf>,
    // Number of files read: the collected file count capped at `max_files`.
    pub files_scanned: usize,
    // Every finding reported, review-only and patchable alike.
    pub findings: Vec<HardeningFinding>,
    // Proposed whole-file rewrites; at most one entry per scanned file.
    pub changes: Vec<HardeningFileChange>,
}
18
// A single observation produced by the analysis.
//
// Plain `//` comments are used on purpose: schemars copies `///` docs into
// the derived JSON schema, which would change the emitted schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct HardeningFinding {
    // Identifier of the form `<kind>:<relative path>:<line>`.
    pub id: String,
    // Short human-readable headline.
    pub title: String,
    // Longer explanation of why the location was flagged.
    pub description: String,
    // File the finding refers to, relative to the analysis root when possible.
    pub file: PathBuf,
    // 1-based line number inside `file`.
    pub line: usize,
    // Category of the finding.
    pub strategy: HardeningStrategy,
    // `true` for findings produced by a rewrite recipe, `false` for the
    // review-only findings emitted by the line scan.
    pub patchable: bool,
}
29
// Category attached to findings and file changes.
//
// Plain `//` comments are used on purpose: schemars copies `///` docs into
// the derived JSON schema, which would change the emitted schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub enum HardeningStrategy {
    // `&String` / `&Vec<T>` parameters tightened in private functions; also
    // used for the `&vec![..]` to `&[..]` literal rewrite.
    BorrowParameterTightening,
    // `.context(..)` added before `?` on filesystem / environment calls.
    ErrorContextPropagation,
    // `.map(|x| x.clone())` simplified to `.cloned()` or `.to_vec()`.
    IteratorCloned,
    // Strategy of the combined per-file change that bundles all recipes.
    MechanicalTier1Cleanup,
    // `#[must_use]` inserted above public value-returning functions.
    MustUsePublicReturn,
    // `.unwrap()` / `.expect(..)` replaced by `.context(..)?`.
    ResultUnwrapContext,
    // Review-only: `Command::new(` / `std::process::Command` usage.
    ProcessExecutionReview,
    // Review-only: `unsafe` usage.
    UnsafeReview,
    // Review-only: `env::var(` access.
    EnvAccessReview,
    // Review-only: filesystem call with no error handling visible on the line.
    FileIoReview,
    // Review-only: router construction or route attribute.
    HttpSurfaceReview,
}
44
// A proposed rewrite of one whole file.
//
// Plain `//` comments are used on purpose: schemars copies `///` docs into
// the derived JSON schema, which would change the emitted schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct HardeningFileChange {
    // File to rewrite, relative to the analysis root when possible.
    pub file: PathBuf,
    // Content the file had when it was analysed.
    pub old_content: String,
    // Content after all applicable recipes were applied.
    pub new_content: String,
    // Category of the change.
    pub strategy: HardeningStrategy,
    // Ids of the patchable findings that this change addresses.
    pub finding_ids: Vec<String>,
    // Human-readable summary of the change.
    pub description: String,
}
54
/// Options controlling a call to [`analyze_hardening`].
#[derive(Debug, Clone, Copy)]
pub struct HardeningAnalyzeConfig<'a> {
    /// Optional file or directory to restrict the scan to. Relative paths are
    /// joined onto the analysis root; absolute paths are used as given.
    pub target: Option<&'a Path>,
    /// Upper bound on the number of files that are read and analysed.
    pub max_files: usize,
}
60
61pub fn analyze_hardening(
62    root: &Path,
63    config: HardeningAnalyzeConfig<'_>,
64) -> anyhow::Result<HardeningAnalysis> {
65    let files = collect_rust_files(root, config.target)?;
66    let mut findings = Vec::new();
67    let mut changes = Vec::new();
68
69    for file in files.iter().take(config.max_files) {
70        let content = std::fs::read_to_string(file)?;
71        let rel = relative_path(root, file);
72        let function_ranges = find_function_ranges(&content);
73
74        for (index, line) in content.lines().enumerate() {
75            let line_no = index + 1;
76            let pattern_line = line_without_comments_or_strings(line);
77            let trimmed = pattern_line.trim();
78
79            if trimmed.contains("Command::new(") || trimmed.contains("std::process::Command") {
80                findings.push(HardeningFinding {
81                    id: format!("process-execution:{}:{line_no}", rel.display()),
82                    title: "Process execution surface".to_string(),
83                    description:
84                        "External process execution should have explicit input validation or allowlisting."
85                            .to_string(),
86                    file: rel.clone(),
87                    line: line_no,
88                    strategy: HardeningStrategy::ProcessExecutionReview,
89                    patchable: false,
90                });
91            }
92
93            if trimmed.contains("unsafe ") || trimmed == "unsafe" || trimmed.contains("unsafe{") {
94                findings.push(HardeningFinding {
95                    id: format!("unsafe-rust:{}:{line_no}", rel.display()),
96                    title: "Unsafe Rust requires review".to_string(),
97                    description:
98                        "Unsafe code should be isolated and documented before automated edits touch it."
99                            .to_string(),
100                    file: rel.clone(),
101                    line: line_no,
102                    strategy: HardeningStrategy::UnsafeReview,
103                    patchable: false,
104                });
105            }
106
107            if trimmed.contains("std::env::var(") || trimmed.contains("env::var(") {
108                findings.push(HardeningFinding {
109                    id: format!("env-access:{}:{line_no}", rel.display()),
110                    title: "Environment variable access".to_string(),
111                    description:
112                        "Environment-derived configuration should return contextual errors at boundaries."
113                            .to_string(),
114                    file: rel.clone(),
115                    line: line_no,
116                    strategy: HardeningStrategy::EnvAccessReview,
117                    patchable: false,
118                });
119            }
120
121            let filesystem_call = trimmed.contains("std::fs::read_to_string(")
122                || trimmed.contains("fs::read_to_string(")
123                || trimmed.contains("std::fs::write(")
124                || trimmed.contains("fs::write(");
125            let has_visible_error_handling = trimmed.contains('?')
126                || trimmed.contains(".unwrap(")
127                || trimmed.contains(".expect(");
128            if filesystem_call && !has_visible_error_handling {
129                findings.push(HardeningFinding {
130                    id: format!("file-io:{}:{line_no}", rel.display()),
131                    title: "Filesystem boundary".to_string(),
132                    description:
133                        "Filesystem access should preserve contextual errors and validated paths."
134                            .to_string(),
135                    file: rel.clone(),
136                    line: line_no,
137                    strategy: HardeningStrategy::FileIoReview,
138                    patchable: false,
139                });
140            }
141
142            if trimmed.contains("Router::new(")
143                || trimmed.contains(".route(")
144                || trimmed.contains("#[get(")
145                || trimmed.contains("#[post(")
146            {
147                findings.push(HardeningFinding {
148                    id: format!("http-surface:{}:{line_no}", rel.display()),
149                    title: "HTTP or route surface".to_string(),
150                    description:
151                        "HTTP-facing surfaces should validate inputs and preserve typed errors."
152                            .to_string(),
153                    file: rel.clone(),
154                    line: line_no,
155                    strategy: HardeningStrategy::HttpSurfaceReview,
156                    patchable: false,
157                });
158            }
159        }
160
161        if let Some(change) = build_tier1_mechanical_change(root, file, &content, &function_ranges)?
162        {
163            findings.extend(change.findings);
164            changes.push(change.change);
165        }
166    }
167
168    Ok(HardeningAnalysis {
169        root: root.to_path_buf(),
170        target: config.target.map(Path::to_path_buf),
171        files_scanned: files.len().min(config.max_files),
172        findings,
173        changes,
174    })
175}
176
/// Output of the Tier 1 recipes for one file: the proposed rewrite together
/// with the patchable findings that justify it.
struct Tier1MechanicalChange {
    /// Whole-file rewrite combining every recipe that fired.
    change: HardeningFileChange,
    /// One finding per rewritten location.
    findings: Vec<HardeningFinding>,
}
181
182fn build_tier1_mechanical_change(
183    root: &Path,
184    file: &Path,
185    content: &str,
186    function_ranges: &[FunctionRange],
187) -> anyhow::Result<Option<Tier1MechanicalChange>> {
188    let rel = relative_path(root, file);
189    let mut lines: Vec<String> = content.lines().map(ToString::to_string).collect();
190    let mut finding_ids = Vec::new();
191    let mut findings = Vec::new();
192
193    apply_result_context_recipe(
194        &rel,
195        &mut lines,
196        function_ranges,
197        &mut finding_ids,
198        &mut findings,
199    );
200    apply_error_context_recipe(
201        &rel,
202        &mut lines,
203        function_ranges,
204        &mut finding_ids,
205        &mut findings,
206    );
207    apply_borrow_parameter_recipe(
208        &rel,
209        &mut lines,
210        function_ranges,
211        &mut finding_ids,
212        &mut findings,
213    );
214    apply_borrowed_vec_literal_recipe(&rel, &mut lines, &mut finding_ids, &mut findings);
215    apply_iterator_cloned_recipe(&rel, &mut lines, &mut finding_ids, &mut findings);
216    apply_must_use_recipe(
217        &rel,
218        &mut lines,
219        function_ranges,
220        &mut finding_ids,
221        &mut findings,
222    );
223
224    if finding_ids.is_empty() {
225        return Ok(None);
226    }
227
228    let mut new_content = lines.join("\n");
229    if content.ends_with('\n') {
230        new_content.push('\n');
231    }
232    if findings.iter().any(|finding| {
233        matches!(
234            finding.strategy,
235            HardeningStrategy::ErrorContextPropagation | HardeningStrategy::ResultUnwrapContext
236        )
237    }) {
238        new_content = ensure_anyhow_context_import(&new_content);
239    }
240    if syn::parse_file(&new_content).is_err() {
241        return Ok(None);
242    }
243
244    Ok(Some(Tier1MechanicalChange {
245        change: HardeningFileChange {
246            file: rel,
247            old_content: content.to_string(),
248            new_content,
249            strategy: HardeningStrategy::MechanicalTier1Cleanup,
250            finding_ids,
251            description:
252                "Apply Tier 1 mechanical hardening recipes under compile and clippy validation."
253                    .to_string(),
254        },
255        findings,
256    }))
257}
258
259fn apply_result_context_recipe(
260    rel: &Path,
261    lines: &mut [String],
262    function_ranges: &[FunctionRange],
263    finding_ids: &mut Vec<String>,
264    findings: &mut Vec<HardeningFinding>,
265) {
266    for range in function_ranges {
267        if !range.returns_anyhow_result {
268            continue;
269        }
270
271        for line_index in range.start_line.saturating_sub(1)..range.end_line.min(lines.len()) {
272            let original = lines[line_index].clone();
273            if original.trim_start().starts_with("//") {
274                continue;
275            }
276
277            let mut rewritten = original.clone();
278            if rewritten.contains(".unwrap()") {
279                rewritten = rewritten.replace(
280                    ".unwrap()",
281                    &format!(".context(\"{} failed instead of panicking\")?", range.name),
282                );
283            }
284            rewritten = replace_expect_calls(&rewritten);
285
286            if rewritten != original {
287                lines[line_index] = rewritten;
288                let line = line_index + 1;
289                let id = format!("unwrap-in-result:{}:{line}", rel.display());
290                finding_ids.push(id.clone());
291                findings.push(HardeningFinding {
292                    id,
293                    title: "Panic-prone unwrap in anyhow Result function".to_string(),
294                    description: "Replace unwrap/expect with anyhow Context and ? so failure is reported instead of panicking.".to_string(),
295                    file: rel.to_path_buf(),
296                    line,
297                    strategy: HardeningStrategy::ResultUnwrapContext,
298                    patchable: true,
299                });
300            }
301        }
302    }
303}
304
305fn apply_error_context_recipe(
306    rel: &Path,
307    lines: &mut [String],
308    function_ranges: &[FunctionRange],
309    finding_ids: &mut Vec<String>,
310    findings: &mut Vec<HardeningFinding>,
311) {
312    for range in function_ranges {
313        if !range.returns_anyhow_result {
314            continue;
315        }
316
317        for line_index in range.start_line.saturating_sub(1)..range.end_line.min(lines.len()) {
318            let original = lines[line_index].clone();
319            if original.trim_start().starts_with("//")
320                || original.contains(".context(")
321                || original.contains(".with_context(")
322            {
323                continue;
324            }
325
326            let pattern_line = line_without_comments_or_strings(&original);
327            let Some(boundary) = boundary_call_kind(&pattern_line) else {
328                continue;
329            };
330            if !pattern_line.contains('?') {
331                continue;
332            }
333
334            let Some(rewritten) = add_context_before_question_mark(
335                &original,
336                &format!("{} failed at {boundary} boundary", range.name),
337            ) else {
338                continue;
339            };
340            if rewritten == original {
341                continue;
342            }
343
344            lines[line_index] = rewritten;
345            let line = line_index + 1;
346            let id = format!("error-context-propagation:{}:{line}", rel.display());
347            finding_ids.push(id.clone());
348            findings.push(HardeningFinding {
349                id,
350                title: "Propagate boundary errors with context".to_string(),
351                description: "Add anyhow Context to fallible boundary calls that already use ? so failures explain where they came from.".to_string(),
352                file: rel.to_path_buf(),
353                line,
354                strategy: HardeningStrategy::ErrorContextPropagation,
355                patchable: true,
356            });
357        }
358    }
359}
360
/// Classifies `line` by the kind of I/O boundary it touches.
///
/// Returns `"filesystem"` when a filesystem marker is present, otherwise
/// `"environment"` when an environment-variable marker is present, otherwise
/// `None`. Filesystem wins when both kinds appear on the same line.
fn boundary_call_kind(line: &str) -> Option<&'static str> {
    const FILESYSTEM_MARKERS: [&str; 4] = ["std::fs::", "fs::read", "fs::write", "File::open("];
    const ENVIRONMENT_MARKERS: [&str; 2] = ["std::env::var(", "env::var("];

    let mentions_any = |markers: &[&str]| markers.iter().any(|marker| line.contains(*marker));

    if mentions_any(&FILESYSTEM_MARKERS) {
        return Some("filesystem");
    }
    if mentions_any(&ENVIRONMENT_MARKERS) {
        return Some("environment");
    }
    None
}
374
375fn add_context_before_question_mark(line: &str, message: &str) -> Option<String> {
376    let question = line.find('?')?;
377    let (before, after) = line.split_at(question);
378    Some(format!(
379        "{}.context(\"{}\"){}",
380        before,
381        escape_string(message),
382        after
383    ))
384}
385
386fn apply_borrow_parameter_recipe(
387    rel: &Path,
388    lines: &mut [String],
389    function_ranges: &[FunctionRange],
390    finding_ids: &mut Vec<String>,
391    findings: &mut Vec<HardeningFinding>,
392) {
393    for range in function_ranges {
394        if range.is_public {
395            continue;
396        }
397
398        let start = range.signature_start_line.saturating_sub(1);
399        let end = range.signature_end_line.min(lines.len());
400        let mut changed = false;
401        for line in &mut lines[start..end] {
402            let original = line.clone();
403            let tightened = tighten_borrow_parameters(&original);
404            if tightened != original {
405                *line = tightened;
406                changed = true;
407            }
408        }
409
410        if changed {
411            let id = format!(
412                "borrow-parameter-tightening:{}:{}",
413                rel.display(),
414                range.signature_start_line
415            );
416            finding_ids.push(id.clone());
417            findings.push(HardeningFinding {
418                id,
419                title: "Tighten private borrowed parameter type".to_string(),
420                description: "Prefer &str and slices over borrowed owned containers in private functions when compile gates prove the change.".to_string(),
421                file: rel.to_path_buf(),
422                line: range.signature_start_line,
423                strategy: HardeningStrategy::BorrowParameterTightening,
424                patchable: true,
425            });
426        }
427    }
428}
429
430fn apply_must_use_recipe(
431    rel: &Path,
432    lines: &mut Vec<String>,
433    function_ranges: &[FunctionRange],
434    finding_ids: &mut Vec<String>,
435    findings: &mut Vec<HardeningFinding>,
436) {
437    let mut inserted = 0usize;
438    for range in function_ranges {
439        if !range.is_public || !range.returns_value || range.returns_common_must_use {
440            continue;
441        }
442        if has_nearby_must_use(lines, range.signature_start_line + inserted) {
443            continue;
444        }
445
446        let insert_at = range.signature_start_line.saturating_sub(1) + inserted;
447        let indent: String = lines
448            .get(insert_at)
449            .map(|line| line.chars().take_while(|ch| ch.is_whitespace()).collect())
450            .unwrap_or_default();
451        lines.insert(insert_at, format!("{indent}#[must_use]"));
452        inserted += 1;
453
454        let id = format!(
455            "must-use-public-return:{}:{}",
456            rel.display(),
457            range.signature_start_line
458        );
459        finding_ids.push(id.clone());
460        findings.push(HardeningFinding {
461            id,
462            title: "Public return value should be marked must_use".to_string(),
463            description: "Add #[must_use] to public value-returning functions so ignored results are visible to callers.".to_string(),
464            file: rel.to_path_buf(),
465            line: range.signature_start_line,
466            strategy: HardeningStrategy::MustUsePublicReturn,
467            patchable: true,
468        });
469    }
470}
471
472fn apply_iterator_cloned_recipe(
473    rel: &Path,
474    lines: &mut [String],
475    finding_ids: &mut Vec<String>,
476    findings: &mut Vec<HardeningFinding>,
477) {
478    for (line_index, line) in lines.iter_mut().enumerate() {
479        if line.trim_start().starts_with("//") {
480            continue;
481        }
482        let original = line.clone();
483        let rewritten = replace_map_clone_calls(&original);
484        if rewritten == original {
485            continue;
486        }
487
488        *line = rewritten;
489        let line_no = line_index + 1;
490        let id = format!("iterator-cloned:{}:{line_no}", rel.display());
491        finding_ids.push(id.clone());
492        findings.push(HardeningFinding {
493            id,
494            title: "Simplify iterator clone collection".to_string(),
495            description: "Replace clone-mapping collection with a simpler form when compile gates prove the iterator item type.".to_string(),
496            file: rel.to_path_buf(),
497            line: line_no,
498            strategy: HardeningStrategy::IteratorCloned,
499            patchable: true,
500        });
501    }
502}
503
504fn apply_borrowed_vec_literal_recipe(
505    rel: &Path,
506    lines: &mut [String],
507    finding_ids: &mut Vec<String>,
508    findings: &mut Vec<HardeningFinding>,
509) {
510    for (line_index, line) in lines.iter_mut().enumerate() {
511        if line.trim_start().starts_with("//") || !line.contains("&vec![") {
512            continue;
513        }
514
515        *line = line.replace("&vec![", "&[");
516        let line_no = line_index + 1;
517        let id = format!("borrowed-vec-literal:{}:{line_no}", rel.display());
518        finding_ids.push(id.clone());
519        findings.push(HardeningFinding {
520            id,
521            title: "Use a borrowed slice literal".to_string(),
522            description: "Replace &vec![..] with a borrowed slice literal when validation proves the callsite.".to_string(),
523            file: rel.to_path_buf(),
524            line: line_no,
525            strategy: HardeningStrategy::BorrowParameterTightening,
526            patchable: true,
527        });
528    }
529}
530
531fn replace_map_clone_calls(line: &str) -> String {
532    let mut output = String::new();
533    let mut rest = line;
534    while let Some(start) = rest.find(".map(|") {
535        let (before, after_start) = rest.split_at(start);
536        output.push_str(before);
537        let Some((variable, after_variable)) = after_start[".map(|".len()..].split_once('|') else {
538            output.push_str(after_start);
539            return output;
540        };
541        let variable = variable.trim();
542        if variable.is_empty()
543            || !variable
544                .chars()
545                .all(|ch| ch.is_ascii_alphanumeric() || ch == '_')
546        {
547            output.push_str(after_start);
548            return output;
549        }
550
551        let expected = format!(" {}.clone())", variable);
552        let trimmed_expected = format!("{}.clone())", variable);
553        if let Some(next) = after_variable.strip_prefix(&expected) {
554            rest = push_clone_replacement(&mut output, next);
555        } else if let Some(next) = after_variable.strip_prefix(&trimmed_expected) {
556            rest = push_clone_replacement(&mut output, next);
557        } else {
558            output.push_str(".map(|");
559            rest = &after_start[".map(|".len()..];
560        }
561    }
562    output.push_str(rest);
563    output
564}
565
/// Appends the replacement for a removed `.map(|x| x.clone())` call to
/// `output` and returns the part of `next` that still has to be processed.
///
/// When the call sat between `.iter()` and `.collect()`, the whole
/// `.iter() .. .collect()` chain collapses into `.to_vec()`; in every other
/// case `.cloned()` is appended and `next` is returned untouched.
fn push_clone_replacement<'a>(output: &mut String, next: &'a str) -> &'a str {
    const ITER_CALL: &str = ".iter()";
    const COLLECT_CALL: &str = ".collect()";

    match next.strip_prefix(COLLECT_CALL) {
        Some(after_collect) if output.ends_with(ITER_CALL) => {
            let kept = output.len() - ITER_CALL.len();
            output.truncate(kept);
            output.push_str(".to_vec()");
            after_collect
        }
        _ => {
            output.push_str(".cloned()");
            next
        }
    }
}
576
577fn tighten_borrow_parameters(line: &str) -> String {
578    replace_borrowed_vec(&line.replace("&String", "&str"))
579}
580
581fn replace_borrowed_vec(line: &str) -> String {
582    let mut output = String::new();
583    let mut index = 0usize;
584    while let Some(relative_start) = line[index..].find("&Vec<") {
585        let start = index + relative_start;
586        output.push_str(&line[index..start]);
587        let generic_start = start + "&Vec<".len();
588        let Some(generic_end) = matching_angle_end(line, generic_start) else {
589            output.push_str(&line[start..]);
590            return output;
591        };
592        output.push_str("&[");
593        output.push_str(&line[generic_start..generic_end]);
594        output.push(']');
595        index = generic_end + 1;
596    }
597    output.push_str(&line[index..]);
598    output
599}
600
/// Finds the `>` that closes an angle bracket opened just before `start`.
///
/// `start` is the byte offset of the first character after the opening `<`.
/// Nested `<..>` pairs are skipped. The `>` of a `->` return arrow is not a
/// bracket and is ignored, so types such as `Box<dyn Fn(u8) -> u8>` are
/// matched correctly.
///
/// Returns the byte offset of the closing `>` within `value`, or `None` when
/// the bracket is never closed.
fn matching_angle_end(value: &str, start: usize) -> Option<usize> {
    let mut depth = 1isize;
    let mut previous: Option<char> = None;
    for (offset, ch) in value[start..].char_indices() {
        match ch {
            '<' => depth += 1,
            '>' if previous != Some('-') => {
                depth -= 1;
                if depth == 0 {
                    return Some(start + offset);
                }
            }
            _ => {}
        }
        previous = Some(ch);
    }
    None
}
617
/// Reports whether any of the up to four lines directly above the signature
/// mention `must_use`.
///
/// `signature_line` is the 1-based line number of the signature. Both ends
/// of the inspected window are clamped to the available lines, so a
/// signature line beyond the end of `lines` yields `false` instead of a
/// slice-index panic.
fn has_nearby_must_use(lines: &[String], signature_line: usize) -> bool {
    let signature_index = signature_line.saturating_sub(1);
    let end = signature_index.min(lines.len());
    let start = signature_index.saturating_sub(4).min(end);
    lines[start..end]
        .iter()
        .any(|line| line.contains("must_use"))
}
625
/// Rewrites every `.expect("message")` on `line` to `.context("message")?`.
///
/// The message is located by scanning for the first unescaped closing quote,
/// so messages containing `\"` or `")` are handled. It is copied verbatim:
/// the text is already valid string-literal source, and escaping it a second
/// time would change the message seen at runtime.
///
/// A call whose literal is not terminated on this line, or whose closing
/// quote is not directly followed by `)`, is left untouched.
fn replace_expect_calls(line: &str) -> String {
    const OPENER: &str = ".expect(\"";

    let mut output = String::with_capacity(line.len());
    let mut rest = line;
    while let Some(start) = rest.find(OPENER) {
        let (before, call) = rest.split_at(start);
        output.push_str(before);

        let literal = &call[OPENER.len()..];
        let Some(end) = closing_quote_offset(literal) else {
            // The literal continues on another line; keep the tail as is.
            output.push_str(call);
            return output;
        };

        let after_quote = &literal[end + 1..];
        match after_quote.strip_prefix(')') {
            Some(after_call) => {
                output.push_str(".context(\"");
                output.push_str(&literal[..end]);
                output.push_str("\")?");
                rest = after_call;
            }
            None => {
                // Something other than a lone string argument: copy the
                // call up to and including the closing quote unchanged.
                output.push_str(OPENER);
                output.push_str(&literal[..=end]);
                rest = after_quote;
            }
        }
    }
    output.push_str(rest);
    output
}

/// Byte offset of the first `"` in `literal` that is not escaped by a
/// backslash.
fn closing_quote_offset(literal: &str) -> Option<usize> {
    let mut escaped = false;
    for (offset, ch) in literal.char_indices() {
        match ch {
            _ if escaped => escaped = false,
            '\\' => escaped = true,
            '"' => return Some(offset),
            _ => {}
        }
    }
    None
}
646
/// Escapes `value` so it can be placed between double quotes in generated
/// Rust source: every backslash and every double quote gets a backslash in
/// front of it. All other characters are copied unchanged.
fn escape_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if matches!(ch, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
650
/// Returns `line` with everything that is not code neutralised, so that
/// pattern checks only ever see code.
///
/// * Text from a `//` comment marker to the end of the line is dropped.
/// * String literals, including their quotes, are replaced by spaces.
/// * Char literals such as `'"'`, `'?'` or `'\n'` are replaced by spaces.
///   Without this a `"` inside a char literal would be mistaken for the
///   start of a string and blank out the rest of the line.
///
/// Each blanked character becomes exactly one space. Lifetimes and loop
/// labels (`'a`) are kept. Raw strings and block comments are not
/// recognised.
fn line_without_comments_or_strings(line: &str) -> String {
    let chars: Vec<char> = line.chars().collect();
    let mut output = String::with_capacity(line.len());
    let mut in_string = false;
    let mut escaped = false;
    let mut index = 0usize;

    while index < chars.len() {
        let ch = chars[index];

        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            output.push(' ');
            index += 1;
            continue;
        }

        if ch == '/' && chars.get(index + 1) == Some(&'/') {
            break;
        }

        if ch == '"' {
            in_string = true;
            output.push(' ');
            index += 1;
            continue;
        }

        if ch == '\'' {
            if let Some(length) = char_literal_len(&chars[index..]) {
                output.extend(std::iter::repeat(' ').take(length));
                index += length;
                continue;
            }
        }

        output.push(ch);
        index += 1;
    }

    output
}

/// Length in characters of the char literal that starts at `chars[0]`, or
/// `None` when the leading `'` starts a lifetime or label instead.
fn char_literal_len(chars: &[char]) -> Option<usize> {
    match chars {
        // Escaped literal (`'\n'`, `'\''`, `'\u{1F600}'`): the character
        // after the backslash belongs to the escape, so the closing quote is
        // searched for after it.
        ['\'', '\\', _, rest @ ..] => rest
            .iter()
            .position(|ch| *ch == '\'')
            .map(|closing| closing + 4),
        ['\'', body, '\'', ..] if *body != '\'' => Some(3),
        _ => None,
    }
}
680
/// Returns `content` with `use anyhow::Context;` added unless the trait is
/// already imported.
///
/// The import is considered present when the text mentions
/// `anyhow::Context` or `Context,`, or when any `use anyhow::..;` statement
/// names `Context` (for example `use anyhow::{bail, Context};`).
///
/// The new line goes before the first non-blank line that follows the
/// leading inner documentation (`//!`) and inner attributes (`#![..]`).
/// Those must stay at the very top of a file; placing a `use` above them
/// produces source that no longer parses. A trailing newline is preserved.
fn ensure_anyhow_context_import(content: &str) -> String {
    if content.contains("anyhow::Context")
        || content.contains("Context,")
        || anyhow_use_names_context(content)
    {
        return content.to_string();
    }

    let mut lines: Vec<&str> = content.lines().collect();

    // Walk the file prologue (blank lines, plain comments, inner docs and
    // inner attributes) and remember where the last inner item ends.
    let mut after_inner_items = 0usize;
    for (index, line) in lines.iter().enumerate() {
        let trimmed = line.trim_start();
        let is_inner_item = trimmed.starts_with("//!") || trimmed.starts_with("#![");
        let is_plain_comment = trimmed.starts_with("//") && !trimmed.starts_with("///");
        if is_inner_item {
            after_inner_items = index + 1;
        } else if !(trimmed.is_empty() || is_plain_comment) {
            break;
        }
    }

    let insert_at = lines[after_inner_items..]
        .iter()
        .position(|line| !line.trim().is_empty())
        .map_or(after_inner_items, |offset| after_inner_items + offset);
    lines.insert(insert_at, "use anyhow::Context;");

    let mut result = lines.join("\n");
    if content.ends_with('\n') {
        result.push('\n');
    }
    result
}

/// Reports whether any `use anyhow::..;` statement in `content` names the
/// identifier `Context`.
fn anyhow_use_names_context(content: &str) -> bool {
    content.match_indices("use anyhow::").any(|(at, _)| {
        let tail = &content[at..];
        let statement = tail.split(';').next().unwrap_or(tail);
        statement
            .split(|ch: char| !(ch.is_alphanumeric() || ch == '_'))
            .any(|word| word == "Context")
    })
}
698
/// Location and signature facts of one function found by
/// `find_function_ranges`. All line numbers are 1-based.
#[derive(Debug)]
struct FunctionRange {
    /// Identifier following the `fn` keyword.
    name: String,
    /// Line on which the signature starts.
    start_line: usize,
    /// Line on which the brace depth of the body returns to zero.
    end_line: usize,
    /// Same value as `start_line`.
    signature_start_line: usize,
    /// Line that contains the opening `{` of the body.
    signature_end_line: usize,
    /// `true` when the signature starts with `pub `.
    is_public: bool,
    /// `true` for `anyhow::Result..` return types, and for `Result<..`
    /// when the file imports the anyhow `Result` alias.
    returns_anyhow_result: bool,
    /// `true` when a return type other than `()` is declared.
    returns_value: bool,
    /// `true` for `Result`, `Option` and `async fn`, which the must_use
    /// recipe skips.
    returns_common_must_use: bool,
}
711
712fn find_function_ranges(content: &str) -> Vec<FunctionRange> {
713    let lines: Vec<&str> = content.lines().collect();
714    let has_anyhow_result_alias =
715        content.contains("use anyhow::Result") || content.contains("use anyhow::{Result");
716    let mut ranges = Vec::new();
717    let mut index = 0;
718    while index < lines.len() {
719        let line = lines[index];
720        if !line.contains("fn ") {
721            index += 1;
722            continue;
723        }
724
725        let mut signature = line.to_string();
726        let start_line = index + 1;
727        let mut open_line = index;
728        while !signature.contains('{') && open_line + 1 < lines.len() {
729            open_line += 1;
730            signature.push(' ');
731            signature.push_str(lines[open_line]);
732        }
733
734        if !signature.contains('{') {
735            index += 1;
736            continue;
737        }
738
739        let Some(name) = function_name(&signature) else {
740            index += 1;
741            continue;
742        };
743
744        let mut depth = 0isize;
745        let mut end_line = open_line + 1;
746        for (body_index, body_line) in lines.iter().enumerate().skip(open_line) {
747            depth += body_line.matches('{').count() as isize;
748            depth -= body_line.matches('}').count() as isize;
749            end_line = body_index + 1;
750            if depth == 0 {
751                break;
752            }
753        }
754
755        let return_text = signature
756            .split_once("->")
757            .map(|(_, rest)| rest.split('{').next().unwrap_or_default().trim())
758            .unwrap_or_default();
759        let returns_anyhow_result = return_text.starts_with("anyhow::Result")
760            || (has_anyhow_result_alias && return_text.starts_with("Result<"));
761        let returns_value = !return_text.is_empty() && return_text != "()";
762        let returns_common_must_use = return_text.starts_with("Result<")
763            || return_text.starts_with("anyhow::Result")
764            || return_text.starts_with("Option<")
765            || signature.contains("async fn ");
766        ranges.push(FunctionRange {
767            name,
768            start_line,
769            end_line,
770            signature_start_line: start_line,
771            signature_end_line: open_line + 1,
772            is_public: signature.trim_start().starts_with("pub "),
773            returns_anyhow_result,
774            returns_value,
775            returns_common_must_use,
776        });
777        index = end_line;
778    }
779    ranges
780}
781
/// Extracts the identifier that follows the first `fn ` in `signature`.
///
/// The name is the longest run of alphanumeric characters and underscores
/// right after the keyword. Returns `None` when the keyword is missing or
/// is not directly followed by such a character.
fn function_name(signature: &str) -> Option<String> {
    let (_, after_keyword) = signature.split_once("fn ")?;
    let identifier: String = after_keyword
        .chars()
        .take_while(|ch| ch.is_alphanumeric() || *ch == '_')
        .collect();

    if identifier.is_empty() {
        return None;
    }
    Some(identifier)
}
793
794fn collect_rust_files(root: &Path, target: Option<&Path>) -> anyhow::Result<Vec<PathBuf>> {
795    let scan_root = target
796        .map(|path| {
797            if path.is_absolute() {
798                path.to_path_buf()
799            } else {
800                root.join(path)
801            }
802        })
803        .unwrap_or_else(|| root.to_path_buf());
804    if !scan_root.starts_with(root) {
805        anyhow::bail!("hardening target is outside root: {}", scan_root.display());
806    }
807
808    if scan_root.is_file() {
809        return Ok(if scan_root.extension().is_some_and(|ext| ext == "rs") {
810            vec![scan_root]
811        } else {
812            Vec::new()
813        });
814    }
815
816    let mut files = Vec::new();
817    for result in ignore::WalkBuilder::new(scan_root)
818        .hidden(false)
819        .filter_entry(|entry| {
820            let name = entry.file_name().to_string_lossy();
821            !matches!(
822                name.as_ref(),
823                "target" | ".git" | ".worktrees" | ".mdx-rust"
824            )
825        })
826        .build()
827    {
828        let entry = result?;
829        let path = entry.path();
830        if path.is_file() && path.extension().is_some_and(|ext| ext == "rs") {
831            files.push(path.to_path_buf());
832        }
833    }
834    files.sort();
835    Ok(files)
836}
837
/// Returns `path` relative to `root`, or `path` unchanged when it does not
/// lie under `root`.
fn relative_path(root: &Path, path: &Path) -> PathBuf {
    match path.strip_prefix(root) {
        Ok(stripped) => stripped.to_path_buf(),
        Err(_) => path.to_path_buf(),
    }
}
841
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Writes `source` to `src/lib.rs` under a fresh temporary root and runs
    /// [`analyze_hardening`] on that root with `target: None` and
    /// `max_files: 10`.
    ///
    /// The temporary directory is dropped when this helper returns. The
    /// returned analysis holds owned `PathBuf`/`String` data, so the tests can
    /// keep inspecting it afterwards.
    fn analyze_lib_source(source: &str) -> HardeningAnalysis {
        let dir = tempdir().expect("failed to create temporary directory");
        let src = dir.path().join("src");
        std::fs::create_dir_all(&src).expect("failed to create src directory");
        std::fs::write(src.join("lib.rs"), source).expect("failed to write src/lib.rs");

        analyze_hardening(
            dir.path(),
            HardeningAnalyzeConfig {
                target: None,
                max_files: 10,
            },
        )
        .expect("hardening analysis failed")
    }

    /// Asserts that the analysis proposes exactly one file change and returns it.
    fn single_change(analysis: &HardeningAnalysis) -> &HardeningFileChange {
        assert_eq!(
            analysis.changes.len(),
            1,
            "expected exactly one change, got {:?}",
            analysis.changes
        );
        &analysis.changes[0]
    }

    /// Asserts that at least one finding id attached to `change` contains `needle`.
    fn assert_has_finding(change: &HardeningFileChange, needle: &str) {
        assert!(
            change.finding_ids.iter().any(|id| id.contains(needle)),
            "no finding id containing {:?} in {:?}",
            needle,
            change.finding_ids
        );
    }

    /// Asserts that the rewritten file content still parses as a Rust file.
    fn assert_parses(change: &HardeningFileChange) {
        assert!(
            syn::parse_file(&change.new_content).is_ok(),
            "rewritten content does not parse:\n{}",
            change.new_content
        );
    }

    #[test]
    fn hardening_rewrites_unwrap_in_anyhow_result_function() {
        let analysis = analyze_lib_source(
            r#"pub fn load() -> anyhow::Result<String> {
    let value = std::fs::read_to_string("config.toml").unwrap();
    Ok(value)
}
"#,
        );

        let change = single_change(&analysis);
        assert!(change.new_content.contains("use anyhow::Context;"));
        assert!(change
            .new_content
            .contains(".context(\"load failed instead of panicking\")?"));
        assert_parses(change);
    }

    #[test]
    fn hardening_adds_context_to_question_mark_boundaries() {
        let analysis = analyze_lib_source(
            r#"pub fn load(path: &str) -> anyhow::Result<String> {
    let value = std::fs::read_to_string(path)?;
    Ok(value)
}
"#,
        );

        let change = single_change(&analysis);
        assert!(change.new_content.contains("use anyhow::Context;"));
        assert!(change
            .new_content
            .contains(".context(\"load failed at filesystem boundary\")?"));
        assert_has_finding(change, "error-context-propagation");
        assert_parses(change);
    }

    #[test]
    fn hardening_does_not_rewrite_plain_result_without_anyhow_alias() {
        let analysis = analyze_lib_source(
            r#"pub fn load() -> Result<String, std::io::Error> {
    let value = std::fs::read_to_string("config.toml").unwrap();
    Ok(value)
}
"#,
        );

        assert!(analysis.changes.is_empty(), "{:?}", analysis.changes);
    }

    #[test]
    fn hardening_tightens_private_borrowed_owned_parameters() {
        let analysis = analyze_lib_source(
            r#"fn score(name: &String, values: &Vec<u8>) -> usize {
    name.len() + values.len()
}
"#,
        );

        let change = single_change(&analysis);
        assert!(change
            .new_content
            .contains("fn score(name: &str, values: &[u8])"));
        assert_has_finding(change, "borrow-parameter-tightening");
        assert_parses(change);
    }

    #[test]
    fn hardening_marks_public_value_returns_must_use() {
        let analysis = analyze_lib_source(
            r#"pub fn total(values: &[u8]) -> usize {
    values.iter().map(|value| *value as usize).sum()
}
"#,
        );

        let change = single_change(&analysis);
        assert!(change.new_content.contains("#[must_use]\npub fn total"));
        assert_has_finding(change, "must-use-public-return");
        assert_parses(change);
    }

    #[test]
    fn hardening_replaces_map_clone_collect_with_to_vec() {
        let analysis = analyze_lib_source(
            r#"pub fn copy_values(values: &[String]) -> Vec<String> {
    values.iter().map(|value| value.clone()).collect()
}
"#,
        );

        let change = single_change(&analysis);
        assert!(change.new_content.contains("values.to_vec()"));
        assert_has_finding(change, "iterator-cloned");
        assert_parses(change);
    }

    #[test]
    fn hardening_does_not_flag_patterns_inside_strings_or_comments() {
        let analysis = analyze_lib_source(
            r#"fn describe() -> &'static str {
    // Command::new("ignored")
    "unsafe std::process::Command env::var("
}
"#,
        );

        assert!(analysis.findings.is_empty(), "{:?}", analysis.findings);
    }
}