Skip to main content

mdx_rust_analysis/
hardening.rs

1//! Conservative Rust hardening analysis for ordinary Rust modules.
2//!
3//! This module intentionally starts with high-confidence static patterns. It
4//! can inspect normal Rust crates without requiring agent registration.
5
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use std::path::{Path, PathBuf};
9
10#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
11pub struct HardeningAnalysis {
12    pub root: PathBuf,
13    pub target: Option<PathBuf>,
14    pub files_scanned: usize,
15    pub findings: Vec<HardeningFinding>,
16    pub changes: Vec<HardeningFileChange>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
20pub struct HardeningFinding {
21    pub id: String,
22    pub title: String,
23    pub description: String,
24    pub file: PathBuf,
25    pub line: usize,
26    pub strategy: HardeningStrategy,
27    pub patchable: bool,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
31pub enum HardeningStrategy {
32    ResultUnwrapContext,
33    ProcessExecutionReview,
34    UnsafeReview,
35    EnvAccessReview,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
39pub struct HardeningFileChange {
40    pub file: PathBuf,
41    pub old_content: String,
42    pub new_content: String,
43    pub strategy: HardeningStrategy,
44    pub finding_ids: Vec<String>,
45    pub description: String,
46}
47
/// Knobs for [`analyze_hardening`].
#[derive(Debug, Clone, Copy)]
pub struct HardeningAnalyzeConfig<'a> {
    /// Optional file or directory to restrict the scan to (absolute, or relative to the root).
    pub target: Option<&'a Path>,
    /// Upper bound on the number of collected files that are read.
    pub max_files: usize,
}
53
54pub fn analyze_hardening(
55    root: &Path,
56    config: HardeningAnalyzeConfig<'_>,
57) -> anyhow::Result<HardeningAnalysis> {
58    let files = collect_rust_files(root, config.target)?;
59    let mut findings = Vec::new();
60    let mut changes = Vec::new();
61
62    for file in files.iter().take(config.max_files) {
63        let content = std::fs::read_to_string(file)?;
64        let rel = relative_path(root, file);
65        let function_ranges = find_function_ranges(&content);
66
67        for (index, line) in content.lines().enumerate() {
68            let line_no = index + 1;
69            let pattern_line = line_without_comments_or_strings(line);
70            let trimmed = pattern_line.trim();
71
72            if trimmed.contains("Command::new(") || trimmed.contains("std::process::Command") {
73                findings.push(HardeningFinding {
74                    id: format!("process-execution:{}:{line_no}", rel.display()),
75                    title: "Process execution surface".to_string(),
76                    description:
77                        "External process execution should have explicit input validation or allowlisting."
78                            .to_string(),
79                    file: rel.clone(),
80                    line: line_no,
81                    strategy: HardeningStrategy::ProcessExecutionReview,
82                    patchable: false,
83                });
84            }
85
86            if trimmed.contains("unsafe ") || trimmed == "unsafe" || trimmed.contains("unsafe{") {
87                findings.push(HardeningFinding {
88                    id: format!("unsafe-rust:{}:{line_no}", rel.display()),
89                    title: "Unsafe Rust requires review".to_string(),
90                    description:
91                        "Unsafe code should be isolated and documented before automated edits touch it."
92                            .to_string(),
93                    file: rel.clone(),
94                    line: line_no,
95                    strategy: HardeningStrategy::UnsafeReview,
96                    patchable: false,
97                });
98            }
99
100            if trimmed.contains("std::env::var(") || trimmed.contains("env::var(") {
101                findings.push(HardeningFinding {
102                    id: format!("env-access:{}:{line_no}", rel.display()),
103                    title: "Environment variable access".to_string(),
104                    description:
105                        "Environment-derived configuration should return contextual errors at boundaries."
106                            .to_string(),
107                    file: rel.clone(),
108                    line: line_no,
109                    strategy: HardeningStrategy::EnvAccessReview,
110                    patchable: false,
111                });
112            }
113        }
114
115        if let Some(change) = build_result_context_change(root, file, &content, &function_ranges)? {
116            for id in &change.finding_ids {
117                if !findings.iter().any(|finding| &finding.id == id) {
118                    let line = id
119                        .rsplit(':')
120                        .next()
121                        .and_then(|line| line.parse::<usize>().ok())
122                        .unwrap_or(1);
123                    findings.push(HardeningFinding {
124                        id: id.clone(),
125                        title: "Panic-prone unwrap in anyhow Result function".to_string(),
126                        description: "Replace unwrap/expect with anyhow Context and ? so failure is reported instead of panicking.".to_string(),
127                        file: rel.clone(),
128                        line,
129                        strategy: HardeningStrategy::ResultUnwrapContext,
130                        patchable: true,
131                    });
132                }
133            }
134            changes.push(change);
135        }
136    }
137
138    Ok(HardeningAnalysis {
139        root: root.to_path_buf(),
140        target: config.target.map(Path::to_path_buf),
141        files_scanned: files.len().min(config.max_files),
142        findings,
143        changes,
144    })
145}
146
147fn build_result_context_change(
148    root: &Path,
149    file: &Path,
150    content: &str,
151    function_ranges: &[FunctionRange],
152) -> anyhow::Result<Option<HardeningFileChange>> {
153    let rel = relative_path(root, file);
154    let mut lines: Vec<String> = content.lines().map(ToString::to_string).collect();
155    let mut changed = false;
156    let mut finding_ids = Vec::new();
157
158    for range in function_ranges {
159        if !range.returns_anyhow_result {
160            continue;
161        }
162
163        for line_index in range.start_line.saturating_sub(1)..range.end_line.min(lines.len()) {
164            let original = lines[line_index].clone();
165            if original.trim_start().starts_with("//") {
166                continue;
167            }
168
169            let mut rewritten = original.clone();
170            if rewritten.contains(".unwrap()") {
171                rewritten = rewritten.replace(
172                    ".unwrap()",
173                    &format!(".context(\"{} failed instead of panicking\")?", range.name),
174                );
175            }
176            rewritten = replace_expect_calls(&rewritten);
177
178            if rewritten != original {
179                changed = true;
180                lines[line_index] = rewritten;
181                finding_ids.push(format!(
182                    "unwrap-in-result:{}:{}",
183                    rel.display(),
184                    line_index + 1
185                ));
186            }
187        }
188    }
189
190    if !changed {
191        return Ok(None);
192    }
193
194    let mut new_content = lines.join("\n");
195    if content.ends_with('\n') {
196        new_content.push('\n');
197    }
198    new_content = ensure_anyhow_context_import(&new_content);
199    if syn::parse_file(&new_content).is_err() {
200        return Ok(None);
201    }
202
203    Ok(Some(HardeningFileChange {
204        file: rel,
205        old_content: content.to_string(),
206        new_content,
207        strategy: HardeningStrategy::ResultUnwrapContext,
208        finding_ids,
209        description:
210            "Replace panic-prone unwrap/expect calls in anyhow Result functions with Context and ?."
211                .to_string(),
212    }))
213}
214
/// Rewrites every `.expect("message")` on `line` into `.context("message")?`.
///
/// The message is copied verbatim. It is already valid string-literal source,
/// escapes included, so re-escaping it would change the runtime text.
///
/// The closing quote is found with escape awareness (`\"` does not end the
/// literal). If a literal is unterminated, or its closing quote is not
/// immediately followed by `)`, the rest of the line is kept unchanged.
fn replace_expect_calls(line: &str) -> String {
    const OPEN: &str = ".expect(\"";
    let mut output = String::with_capacity(line.len());
    let mut rest = line;
    while let Some(start) = rest.find(OPEN) {
        output.push_str(&rest[..start]);
        let literal = &rest[start + OPEN.len()..];
        match closing_quote_index(literal) {
            Some(end) if literal[end + 1..].starts_with(')') => {
                output.push_str(".context(\"");
                output.push_str(&literal[..end]);
                output.push_str("\")?");
                // Skip the closing quote and the `)`.
                rest = &literal[end + 2..];
            }
            _ => {
                output.push_str(&rest[start..]);
                rest = "";
            }
        }
    }
    output.push_str(rest);
    output
}

/// Returns the byte index of the first unescaped `"` in `literal`.
///
/// This is the quote that closes a string literal whose opening quote has
/// already been consumed.
fn closing_quote_index(literal: &str) -> Option<usize> {
    let mut escaped = false;
    for (index, ch) in literal.char_indices() {
        match ch {
            _ if escaped => escaped = false,
            '\\' => escaped = true,
            '"' => return Some(index),
            _ => {}
        }
    }
    None
}
239
/// Masks the non-code parts of a single source line.
///
/// * Every character of a `"…"` string literal, quotes included, becomes one space.
/// * Every character of a char literal (`'x'`, `'"'`, `'\''`, `'\u{..}'`) becomes one space.
/// * Everything from a `//` outside a string is dropped.
///
/// Lifetimes and labels (`'a`, `'outer:`) are kept. Block comments, raw
/// strings and literals spanning several lines are not recognised.
fn line_without_comments_or_strings(line: &str) -> String {
    let chars: Vec<char> = line.chars().collect();
    let mut output = String::with_capacity(line.len());
    let mut in_string = false;
    let mut escaped = false;
    let mut index = 0;

    while index < chars.len() {
        let ch = chars[index];

        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            output.push(' ');
            index += 1;
            continue;
        }

        if ch == '/' && chars.get(index + 1) == Some(&'/') {
            break;
        }

        if ch == '"' {
            in_string = true;
            output.push(' ');
            index += 1;
            continue;
        }

        // Without this, a `'"'` literal would open a phantom string and hide
        // the rest of the line.
        if ch == '\'' {
            if let Some(len) = char_literal_len(&chars[index..]) {
                output.push_str(&" ".repeat(len));
                index += len;
                continue;
            }
        }

        output.push(ch);
        index += 1;
    }

    output
}

/// Returns the length, in chars, of the char literal that starts at `chars[0]`.
///
/// `chars[0]` must be `'`. Returns `None` when the quote instead starts a
/// lifetime or label.
fn char_literal_len(chars: &[char]) -> Option<usize> {
    match chars {
        // Escaped form: `'\n'`, `'\''`, `'\u{10FFFF}'`.
        ['\'', '\\', _, tail @ ..] => tail
            .iter()
            .take(10)
            .position(|&c| c == '\'')
            .map(|pos| pos + 4),
        // Plain form: `'x'`.
        ['\'', c, '\'', ..] if *c != '\'' => Some(3),
        _ => None,
    }
}
269
/// Ensures `content` has `anyhow::Context` in scope.
///
/// Nothing is added when an existing `use anyhow::…;` statement already names
/// `Context`, including `{…}` groups and groups spanning several lines.
/// Otherwise an import is added:
///
/// * `use anyhow::Context;` in the normal case;
/// * `use anyhow::Context as _;` when another `Context` name already appears
///   in the file, so only the trait methods come into scope and the names
///   cannot collide.
///
/// The import goes after the leading run of blank lines, inner attributes
/// (`#![…]`) and inner doc comments (`//!`), because items may not precede
/// those. Whether `content` ends with a newline is preserved.
fn ensure_anyhow_context_import(content: &str) -> String {
    if imports_anyhow_context(content) {
        return content.to_string();
    }

    let import = if mentions_word(content, "Context") {
        "use anyhow::Context as _;"
    } else {
        "use anyhow::Context;"
    };

    let mut lines: Vec<&str> = content.lines().collect();
    let insert_at = lines
        .iter()
        .position(|line| {
            let trimmed = line.trim_start();
            !(trimmed.is_empty() || trimmed.starts_with("#![") || trimmed.starts_with("//!"))
        })
        .unwrap_or(lines.len());
    lines.insert(insert_at, import);
    let mut result = lines.join("\n");
    if content.ends_with('\n') {
        result.push('\n');
    }
    result
}

/// Returns true when some `use anyhow::…;` statement names `Context`.
fn imports_anyhow_context(content: &str) -> bool {
    let mut rest = content;
    while let Some(start) = rest.find("use anyhow::") {
        let statement = &rest[start..];
        let end = statement.find(';').map_or(statement.len(), |semi| semi + 1);
        if mentions_word(&statement[..end], "Context") {
            return true;
        }
        rest = &statement[end..];
    }
    false
}

/// Returns true when `word` appears in `text` delimited by non-identifier characters.
fn mentions_word(text: &str, word: &str) -> bool {
    text.split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .any(|candidate| candidate == word)
}
287
/// Line span of one function located by `find_function_ranges`.
#[derive(Debug)]
struct FunctionRange {
    /// Function identifier, used in the generated context message.
    name: String,
    /// 1-based line where the `fn` signature starts.
    start_line: usize,
    /// 1-based line (inclusive) where brace counting returns to zero, or the
    /// file's last line if the braces never balance.
    end_line: usize,
    /// Whether the signature looks like it returns `anyhow::Result`.
    returns_anyhow_result: bool,
}
295
296fn find_function_ranges(content: &str) -> Vec<FunctionRange> {
297    let lines: Vec<&str> = content.lines().collect();
298    let has_anyhow_result_alias =
299        content.contains("use anyhow::Result") || content.contains("use anyhow::{Result");
300    let mut ranges = Vec::new();
301    let mut index = 0;
302    while index < lines.len() {
303        let line = lines[index];
304        if !line.contains("fn ") {
305            index += 1;
306            continue;
307        }
308
309        let mut signature = line.to_string();
310        let start_line = index + 1;
311        let mut open_line = index;
312        while !signature.contains('{') && open_line + 1 < lines.len() {
313            open_line += 1;
314            signature.push(' ');
315            signature.push_str(lines[open_line]);
316        }
317
318        if !signature.contains('{') {
319            index += 1;
320            continue;
321        }
322
323        let Some(name) = function_name(&signature) else {
324            index += 1;
325            continue;
326        };
327
328        let mut depth = 0isize;
329        let mut end_line = open_line + 1;
330        for (body_index, body_line) in lines.iter().enumerate().skip(open_line) {
331            depth += body_line.matches('{').count() as isize;
332            depth -= body_line.matches('}').count() as isize;
333            end_line = body_index + 1;
334            if depth == 0 {
335                break;
336            }
337        }
338
339        let returns_anyhow_result = signature.contains("-> anyhow::Result")
340            || (has_anyhow_result_alias && signature.contains("-> Result<"));
341        ranges.push(FunctionRange {
342            name,
343            start_line,
344            end_line,
345            returns_anyhow_result,
346        });
347        index = end_line;
348    }
349    ranges
350}
351
/// Extracts the identifier that follows the first `fn ` in `signature`.
///
/// Returns `None` when `signature` has no `fn `, or when `fn ` is not directly
/// followed by an identifier character.
fn function_name(signature: &str) -> Option<String> {
    let (_, after_keyword) = signature.split_once("fn ")?;
    let ident_len = after_keyword
        .find(|c: char| !(c.is_alphanumeric() || c == '_'))
        .unwrap_or(after_keyword.len());
    (ident_len > 0).then(|| after_keyword[..ident_len].to_string())
}
363
364fn collect_rust_files(root: &Path, target: Option<&Path>) -> anyhow::Result<Vec<PathBuf>> {
365    let scan_root = target
366        .map(|path| {
367            if path.is_absolute() {
368                path.to_path_buf()
369            } else {
370                root.join(path)
371            }
372        })
373        .unwrap_or_else(|| root.to_path_buf());
374    if !scan_root.starts_with(root) {
375        anyhow::bail!("hardening target is outside root: {}", scan_root.display());
376    }
377
378    if scan_root.is_file() {
379        return Ok(if scan_root.extension().is_some_and(|ext| ext == "rs") {
380            vec![scan_root]
381        } else {
382            Vec::new()
383        });
384    }
385
386    let mut files = Vec::new();
387    for result in ignore::WalkBuilder::new(scan_root)
388        .hidden(false)
389        .filter_entry(|entry| {
390            let name = entry.file_name().to_string_lossy();
391            !matches!(
392                name.as_ref(),
393                "target" | ".git" | ".worktrees" | ".mdx-rust"
394            )
395        })
396        .build()
397    {
398        let entry = result?;
399        let path = entry.path();
400        if path.is_file() && path.extension().is_some_and(|ext| ext == "rs") {
401            files.push(path.to_path_buf());
402        }
403    }
404    files.sort();
405    Ok(files)
406}
407
/// Expresses `path` relative to `root`.
///
/// Returns `path` unchanged when it does not live under `root`.
fn relative_path(root: &Path, path: &Path) -> PathBuf {
    match path.strip_prefix(root) {
        Ok(relative) => relative.to_path_buf(),
        Err(_) => path.to_path_buf(),
    }
}
411
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Writes `source` to `src/lib.rs` in a fresh temporary directory.
    ///
    /// Then runs `analyze_hardening` over that directory with no target and a
    /// 10-file cap. The `TempDir` guard is returned too, so the directory is
    /// only removed when the calling test finishes.
    fn analyze_lib_source(source: &str) -> (tempfile::TempDir, HardeningAnalysis) {
        let workspace = tempdir().unwrap();
        let src_dir = workspace.path().join("src");
        std::fs::create_dir_all(&src_dir).unwrap();
        std::fs::write(src_dir.join("lib.rs"), source).unwrap();

        let config = HardeningAnalyzeConfig {
            target: None,
            max_files: 10,
        };
        let analysis = analyze_hardening(workspace.path(), config).unwrap();
        (workspace, analysis)
    }

    #[test]
    fn hardening_rewrites_unwrap_in_anyhow_result_function() {
        let (_workspace, analysis) = analyze_lib_source(
            r#"pub fn load() -> anyhow::Result<String> {
    let value = std::fs::read_to_string("config.toml").unwrap();
    Ok(value)
}
"#,
        );

        assert_eq!(analysis.changes.len(), 1);
        let rewritten = &analysis.changes[0].new_content;
        assert!(rewritten.contains("use anyhow::Context;"));
        assert!(rewritten.contains(".context(\"load failed instead of panicking\")?"));
        assert!(syn::parse_file(rewritten).is_ok());
    }

    #[test]
    fn hardening_does_not_rewrite_plain_result_without_anyhow_alias() {
        let (_workspace, analysis) = analyze_lib_source(
            r#"pub fn load() -> Result<String, std::io::Error> {
    let value = std::fs::read_to_string("config.toml").unwrap();
    Ok(value)
}
"#,
        );

        assert!(analysis.changes.is_empty());
    }

    #[test]
    fn hardening_does_not_flag_patterns_inside_strings_or_comments() {
        let (_workspace, analysis) = analyze_lib_source(
            r#"pub fn describe() -> &'static str {
    // Command::new("ignored")
    "unsafe std::process::Command env::var("
}
"#,
        );

        assert!(analysis.findings.is_empty(), "{:?}", analysis.findings);
    }
}