Skip to main content

mdx_rust_analysis/
hardening.rs

1//! Conservative Rust hardening analysis for ordinary Rust modules.
2//!
3//! This module intentionally starts with high-confidence static patterns. It
4//! can inspect normal Rust crates without requiring agent registration.
5
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use std::path::{Path, PathBuf};
9
10#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
11pub struct HardeningAnalysis {
12    pub root: PathBuf,
13    pub target: Option<PathBuf>,
14    pub files_scanned: usize,
15    pub findings: Vec<HardeningFinding>,
16    pub changes: Vec<HardeningFileChange>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
20pub struct HardeningFinding {
21    pub id: String,
22    pub title: String,
23    pub description: String,
24    pub file: PathBuf,
25    pub line: usize,
26    pub strategy: HardeningStrategy,
27    pub patchable: bool,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
31pub enum HardeningStrategy {
32    ResultUnwrapContext,
33    ProcessExecutionReview,
34    UnsafeReview,
35    EnvAccessReview,
36    FileIoReview,
37    HttpSurfaceReview,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
41pub struct HardeningFileChange {
42    pub file: PathBuf,
43    pub old_content: String,
44    pub new_content: String,
45    pub strategy: HardeningStrategy,
46    pub finding_ids: Vec<String>,
47    pub description: String,
48}
49
/// Options controlling [`analyze_hardening`].
#[derive(Clone, Copy, Debug)]
pub struct HardeningAnalyzeConfig<'a> {
    /// File or directory to restrict the scan to, relative to the root or absolute. It must lie
    /// inside the root. `None` scans the whole root.
    pub target: Option<&'a Path>,
    /// Maximum number of Rust files to analyze; files are taken in sorted path order.
    pub max_files: usize,
}
55
56pub fn analyze_hardening(
57    root: &Path,
58    config: HardeningAnalyzeConfig<'_>,
59) -> anyhow::Result<HardeningAnalysis> {
60    let files = collect_rust_files(root, config.target)?;
61    let mut findings = Vec::new();
62    let mut changes = Vec::new();
63
64    for file in files.iter().take(config.max_files) {
65        let content = std::fs::read_to_string(file)?;
66        let rel = relative_path(root, file);
67        let function_ranges = find_function_ranges(&content);
68
69        for (index, line) in content.lines().enumerate() {
70            let line_no = index + 1;
71            let pattern_line = line_without_comments_or_strings(line);
72            let trimmed = pattern_line.trim();
73
74            if trimmed.contains("Command::new(") || trimmed.contains("std::process::Command") {
75                findings.push(HardeningFinding {
76                    id: format!("process-execution:{}:{line_no}", rel.display()),
77                    title: "Process execution surface".to_string(),
78                    description:
79                        "External process execution should have explicit input validation or allowlisting."
80                            .to_string(),
81                    file: rel.clone(),
82                    line: line_no,
83                    strategy: HardeningStrategy::ProcessExecutionReview,
84                    patchable: false,
85                });
86            }
87
88            if trimmed.contains("unsafe ") || trimmed == "unsafe" || trimmed.contains("unsafe{") {
89                findings.push(HardeningFinding {
90                    id: format!("unsafe-rust:{}:{line_no}", rel.display()),
91                    title: "Unsafe Rust requires review".to_string(),
92                    description:
93                        "Unsafe code should be isolated and documented before automated edits touch it."
94                            .to_string(),
95                    file: rel.clone(),
96                    line: line_no,
97                    strategy: HardeningStrategy::UnsafeReview,
98                    patchable: false,
99                });
100            }
101
102            if trimmed.contains("std::env::var(") || trimmed.contains("env::var(") {
103                findings.push(HardeningFinding {
104                    id: format!("env-access:{}:{line_no}", rel.display()),
105                    title: "Environment variable access".to_string(),
106                    description:
107                        "Environment-derived configuration should return contextual errors at boundaries."
108                            .to_string(),
109                    file: rel.clone(),
110                    line: line_no,
111                    strategy: HardeningStrategy::EnvAccessReview,
112                    patchable: false,
113                });
114            }
115
116            let filesystem_call = trimmed.contains("std::fs::read_to_string(")
117                || trimmed.contains("fs::read_to_string(")
118                || trimmed.contains("std::fs::write(")
119                || trimmed.contains("fs::write(");
120            let has_visible_error_handling = trimmed.contains('?')
121                || trimmed.contains(".unwrap(")
122                || trimmed.contains(".expect(");
123            if filesystem_call && !has_visible_error_handling {
124                findings.push(HardeningFinding {
125                    id: format!("file-io:{}:{line_no}", rel.display()),
126                    title: "Filesystem boundary".to_string(),
127                    description:
128                        "Filesystem access should preserve contextual errors and validated paths."
129                            .to_string(),
130                    file: rel.clone(),
131                    line: line_no,
132                    strategy: HardeningStrategy::FileIoReview,
133                    patchable: false,
134                });
135            }
136
137            if trimmed.contains("Router::new(")
138                || trimmed.contains(".route(")
139                || trimmed.contains("#[get(")
140                || trimmed.contains("#[post(")
141            {
142                findings.push(HardeningFinding {
143                    id: format!("http-surface:{}:{line_no}", rel.display()),
144                    title: "HTTP or route surface".to_string(),
145                    description:
146                        "HTTP-facing surfaces should validate inputs and preserve typed errors."
147                            .to_string(),
148                    file: rel.clone(),
149                    line: line_no,
150                    strategy: HardeningStrategy::HttpSurfaceReview,
151                    patchable: false,
152                });
153            }
154        }
155
156        if let Some(change) = build_result_context_change(root, file, &content, &function_ranges)? {
157            for id in &change.finding_ids {
158                if !findings.iter().any(|finding| &finding.id == id) {
159                    let line = id
160                        .rsplit(':')
161                        .next()
162                        .and_then(|line| line.parse::<usize>().ok())
163                        .unwrap_or(1);
164                    findings.push(HardeningFinding {
165                        id: id.clone(),
166                        title: "Panic-prone unwrap in anyhow Result function".to_string(),
167                        description: "Replace unwrap/expect with anyhow Context and ? so failure is reported instead of panicking.".to_string(),
168                        file: rel.clone(),
169                        line,
170                        strategy: HardeningStrategy::ResultUnwrapContext,
171                        patchable: true,
172                    });
173                }
174            }
175            changes.push(change);
176        }
177    }
178
179    Ok(HardeningAnalysis {
180        root: root.to_path_buf(),
181        target: config.target.map(Path::to_path_buf),
182        files_scanned: files.len().min(config.max_files),
183        findings,
184        changes,
185    })
186}
187
188fn build_result_context_change(
189    root: &Path,
190    file: &Path,
191    content: &str,
192    function_ranges: &[FunctionRange],
193) -> anyhow::Result<Option<HardeningFileChange>> {
194    let rel = relative_path(root, file);
195    let mut lines: Vec<String> = content.lines().map(ToString::to_string).collect();
196    let mut changed = false;
197    let mut finding_ids = Vec::new();
198
199    for range in function_ranges {
200        if !range.returns_anyhow_result {
201            continue;
202        }
203
204        for line_index in range.start_line.saturating_sub(1)..range.end_line.min(lines.len()) {
205            let original = lines[line_index].clone();
206            if original.trim_start().starts_with("//") {
207                continue;
208            }
209
210            let mut rewritten = original.clone();
211            if rewritten.contains(".unwrap()") {
212                rewritten = rewritten.replace(
213                    ".unwrap()",
214                    &format!(".context(\"{} failed instead of panicking\")?", range.name),
215                );
216            }
217            rewritten = replace_expect_calls(&rewritten);
218
219            if rewritten != original {
220                changed = true;
221                lines[line_index] = rewritten;
222                finding_ids.push(format!(
223                    "unwrap-in-result:{}:{}",
224                    rel.display(),
225                    line_index + 1
226                ));
227            }
228        }
229    }
230
231    if !changed {
232        return Ok(None);
233    }
234
235    let mut new_content = lines.join("\n");
236    if content.ends_with('\n') {
237        new_content.push('\n');
238    }
239    new_content = ensure_anyhow_context_import(&new_content);
240    if syn::parse_file(&new_content).is_err() {
241        return Ok(None);
242    }
243
244    Ok(Some(HardeningFileChange {
245        file: rel,
246        old_content: content.to_string(),
247        new_content,
248        strategy: HardeningStrategy::ResultUnwrapContext,
249        finding_ids,
250        description:
251            "Replace panic-prone unwrap/expect calls in anyhow Result functions with Context and ?."
252                .to_string(),
253    }))
254}
255
/// Rewrites every `.expect("literal")` call on `line` into `.context("literal")?`.
///
/// The string literal is copied verbatim: its body is already escaped Rust source, so sequences
/// such as `\"` or `\n` keep their meaning. Calls whose argument is not a single plain string
/// literal immediately followed by `)` (e.g. `.expect("a" )` or an unterminated literal) are left
/// exactly as written.
fn replace_expect_calls(line: &str) -> String {
    const NEEDLE: &str = ".expect(\"";
    let mut output = String::with_capacity(line.len());
    let mut rest = line;
    while let Some(start) = rest.find(NEEDLE) {
        let (before, call) = rest.split_at(start);
        output.push_str(before);
        let body = &call[NEEDLE.len()..];
        match closing_quote(body) {
            Some(end) if body[end + 1..].starts_with(')') => {
                output.push_str(".context(\"");
                output.push_str(&body[..end]);
                output.push_str("\")?");
                rest = &body[end + 2..];
            }
            _ => {
                // Not a simple literal argument: keep the call and continue scanning after it.
                output.push_str(NEEDLE);
                rest = body;
            }
        }
    }
    output.push_str(rest);
    output
}

/// Returns the byte offset of the unescaped `"` that terminates the string literal body `body`.
fn closing_quote(body: &str) -> Option<usize> {
    let mut escaped = false;
    for (offset, ch) in body.char_indices() {
        if escaped {
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == '"' {
            return Some(offset);
        }
    }
    None
}
276
/// Escapes backslashes and double quotes so `value` can be embedded inside a `"..."` literal.
fn escape_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if matches!(ch, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
280
/// Returns `line` with literal text and comments neutralized for pattern matching.
///
/// String-literal contents (including the quotes), char literals and single-line `/* ... */`
/// comments are replaced by spaces, one space per character. Everything from a `//` line comment
/// onward is dropped. Code characters keep their positions, so character offsets in the result
/// match `line` up to the comment. Lifetimes and labels (`'a`, `'outer`) are kept. Only one line
/// is considered, so strings or block comments spanning lines are not tracked.
fn line_without_comments_or_strings(line: &str) -> String {
    let chars: Vec<char> = line.chars().collect();
    let mut output = String::with_capacity(line.len());
    let mut index = 0;
    let mut in_string = false;
    let mut in_block_comment = false;

    while index < chars.len() {
        let ch = chars[index];
        let next = chars.get(index + 1).copied();

        // Number of characters consumed and blanked in this step (0 = copy `ch` through).
        let blanked = if in_block_comment {
            if ch == '*' && next == Some('/') {
                in_block_comment = false;
                2
            } else {
                1
            }
        } else if in_string {
            if ch == '\\' && next.is_some() {
                2
            } else {
                in_string = ch != '"';
                1
            }
        } else {
            match ch {
                '/' if next == Some('/') => break,
                '/' if next == Some('*') => {
                    in_block_comment = true;
                    2
                }
                '"' => {
                    in_string = true;
                    1
                }
                '\'' => char_literal_len(&chars[index..]),
                _ => 0,
            }
        };

        if blanked == 0 {
            output.push(ch);
            index += 1;
        } else {
            for _ in 0..blanked {
                output.push(' ');
            }
            index += blanked;
        }
    }

    output
}

/// Length in characters of the char literal starting at `chars[0] == '\''`, or 0 when the quote
/// starts a lifetime or label instead.
fn char_literal_len(chars: &[char]) -> usize {
    match chars.get(1) {
        Some('\\') => chars
            .iter()
            .skip(3)
            .position(|&ch| ch == '\'')
            .map_or(0, |offset| offset + 4),
        Some(_) if chars.get(2) == Some(&'\'') => 3,
        _ => 0,
    }
}
310
/// Ensures the `anyhow::Context` trait is in scope so rewritten `.context(...)?` calls compile.
///
/// `content` is returned unchanged when a `use` statement already imports `Context` (or a glob)
/// from `anyhow`. Otherwise an import is inserted after the file header. The header is the
/// leading blank lines, comments, inner doc comments (`//!`, `/*! */`) and inner attributes
/// (`#![...]`, including multi-line ones); all of these must precede any item. If the code
/// already mentions another `Context` (e.g. `std::task::Context` or a local type), the trait is
/// imported as `_` to avoid a name clash. The original line ending style (`\n` or `\r\n`) is kept.
fn ensure_anyhow_context_import(content: &str) -> String {
    // Commented-out lines must not count as real imports or real uses of `Context`.
    let code = content
        .lines()
        .filter(|line| !line.trim_start().starts_with("//"))
        .collect::<Vec<_>>()
        .join("\n");

    let already_imported = use_statements(&code).any(|statement| {
        statement.contains("anyhow")
            && (contains_word(statement, "Context") || statement.contains("anyhow::*"))
    });
    if already_imported {
        return content.to_string();
    }

    // `as _` keeps the trait usable for method calls without binding the name `Context`.
    let import = if contains_word(&code, "Context") {
        "use anyhow::Context as _;"
    } else {
        "use anyhow::Context;"
    };

    let mut lines: Vec<&str> = content.lines().collect();
    let insert_at = header_end(&lines);
    lines.insert(insert_at, import);

    let newline = if content.contains("\r\n") { "\r\n" } else { "\n" };
    let mut result = lines.join(newline);
    if content.ends_with('\n') {
        result.push_str(newline);
    }
    result
}

/// Yields each `use ...;` statement in `code`, from the `use` keyword through the `;`.
fn use_statements(code: &str) -> impl Iterator<Item = &str> + '_ {
    code.match_indices("use ").filter_map(move |(start, _)| {
        if code[..start].chars().next_back().is_some_and(is_ident_char) {
            return None;
        }
        let statement = &code[start..];
        statement.find(';').map(|end| &statement[..=end])
    })
}

/// Returns true when `word` occurs in `text` as a whole identifier (not inside a longer one).
fn contains_word(text: &str, word: &str) -> bool {
    text.match_indices(word).any(|(start, _)| {
        let before = text[..start].chars().next_back();
        let after = text[start + word.len()..].chars().next();
        !before.is_some_and(is_ident_char) && !after.is_some_and(is_ident_char)
    })
}

/// Returns true for characters that can appear inside a Rust identifier.
fn is_ident_char(ch: char) -> bool {
    ch.is_alphanumeric() || ch == '_'
}

/// Index of the first line after the file header (blank lines, comments, inner doc comments and
/// inner attributes); `lines.len()` when the whole file is header.
fn header_end(lines: &[&str]) -> usize {
    let mut attribute_depth = 0isize;
    let mut in_block_comment = false;
    for (index, line) in lines.iter().enumerate() {
        let trimmed = line.trim();
        if in_block_comment {
            in_block_comment = !trimmed.contains("*/");
        } else if attribute_depth > 0 {
            attribute_depth += bracket_balance(trimmed);
        } else if trimmed.starts_with("/*") {
            in_block_comment = !trimmed[2..].contains("*/");
        } else if trimmed.starts_with("#![") {
            attribute_depth = bracket_balance(trimmed);
        } else if !(trimmed.is_empty() || trimmed.starts_with("//")) {
            return index;
        }
    }
    lines.len()
}

/// Net `[` minus `]` count of `text`, used to find where a multi-line `#![...]` attribute ends.
fn bracket_balance(text: &str) -> isize {
    text.chars().fold(0, |balance, ch| match ch {
        '[' => balance + 1,
        ']' => balance - 1,
        _ => balance,
    })
}
328
/// Line span of one function definition located by `find_function_ranges`.
#[derive(Debug)]
struct FunctionRange {
    /// Function identifier; used in generated `.context(...)` messages.
    name: String,
    /// 1-based line on which the signature starts.
    start_line: usize,
    /// 1-based line on which the body's brace depth returns to zero (inclusive).
    end_line: usize,
    /// Whether the signature returns `anyhow::Result`, directly or through an imported alias.
    returns_anyhow_result: bool,
}
336
337fn find_function_ranges(content: &str) -> Vec<FunctionRange> {
338    let lines: Vec<&str> = content.lines().collect();
339    let has_anyhow_result_alias =
340        content.contains("use anyhow::Result") || content.contains("use anyhow::{Result");
341    let mut ranges = Vec::new();
342    let mut index = 0;
343    while index < lines.len() {
344        let line = lines[index];
345        if !line.contains("fn ") {
346            index += 1;
347            continue;
348        }
349
350        let mut signature = line.to_string();
351        let start_line = index + 1;
352        let mut open_line = index;
353        while !signature.contains('{') && open_line + 1 < lines.len() {
354            open_line += 1;
355            signature.push(' ');
356            signature.push_str(lines[open_line]);
357        }
358
359        if !signature.contains('{') {
360            index += 1;
361            continue;
362        }
363
364        let Some(name) = function_name(&signature) else {
365            index += 1;
366            continue;
367        };
368
369        let mut depth = 0isize;
370        let mut end_line = open_line + 1;
371        for (body_index, body_line) in lines.iter().enumerate().skip(open_line) {
372            depth += body_line.matches('{').count() as isize;
373            depth -= body_line.matches('}').count() as isize;
374            end_line = body_index + 1;
375            if depth == 0 {
376                break;
377            }
378        }
379
380        let returns_anyhow_result = signature.contains("-> anyhow::Result")
381            || (has_anyhow_result_alias && signature.contains("-> Result<"));
382        ranges.push(FunctionRange {
383            name,
384            start_line,
385            end_line,
386            returns_anyhow_result,
387        });
388        index = end_line;
389    }
390    ranges
391}
392
/// Extracts the function identifier that follows the first `fn ` in `signature`.
///
/// Extra whitespace after `fn` is tolerated. The `r#` prefix of raw identifiers is dropped
/// (`fn r#match()` yields `match`). Returns `None` when no identifier follows, as in a
/// fn-pointer type such as `fn (u8)`.
fn function_name(signature: &str) -> Option<String> {
    let rest = signature.split_once("fn ")?.1.trim_start();
    let rest = rest.strip_prefix("r#").unwrap_or(rest);
    let name: String = rest
        .chars()
        .take_while(|&ch| ch.is_alphanumeric() || ch == '_')
        .collect();
    (!name.is_empty()).then_some(name)
}
404
405fn collect_rust_files(root: &Path, target: Option<&Path>) -> anyhow::Result<Vec<PathBuf>> {
406    let scan_root = target
407        .map(|path| {
408            if path.is_absolute() {
409                path.to_path_buf()
410            } else {
411                root.join(path)
412            }
413        })
414        .unwrap_or_else(|| root.to_path_buf());
415    if !scan_root.starts_with(root) {
416        anyhow::bail!("hardening target is outside root: {}", scan_root.display());
417    }
418
419    if scan_root.is_file() {
420        return Ok(if scan_root.extension().is_some_and(|ext| ext == "rs") {
421            vec![scan_root]
422        } else {
423            Vec::new()
424        });
425    }
426
427    let mut files = Vec::new();
428    for result in ignore::WalkBuilder::new(scan_root)
429        .hidden(false)
430        .filter_entry(|entry| {
431            let name = entry.file_name().to_string_lossy();
432            !matches!(
433                name.as_ref(),
434                "target" | ".git" | ".worktrees" | ".mdx-rust"
435            )
436        })
437        .build()
438    {
439        let entry = result?;
440        let path = entry.path();
441        if path.is_file() && path.extension().is_some_and(|ext| ext == "rs") {
442            files.push(path.to_path_buf());
443        }
444    }
445    files.sort();
446    Ok(files)
447}
448
/// Expresses `path` relative to `root`, returning it unchanged when it is not under `root`.
fn relative_path(root: &Path, path: &Path) -> PathBuf {
    match path.strip_prefix(root) {
        Ok(relative) => relative.to_path_buf(),
        Err(_) => path.to_path_buf(),
    }
}
452
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Creates a throwaway crate whose only source file, `src/lib.rs`, contains `source`.
    fn crate_with_lib_rs(source: &str) -> TempDir {
        let dir = tempdir().unwrap();
        let src_dir = dir.path().join("src");
        std::fs::create_dir_all(&src_dir).unwrap();
        std::fs::write(src_dir.join("lib.rs"), source).unwrap();
        dir
    }

    /// Analyzes the whole temporary crate with a generous file budget.
    fn analyze_all(dir: &TempDir) -> HardeningAnalysis {
        let config = HardeningAnalyzeConfig {
            target: None,
            max_files: 10,
        };
        analyze_hardening(dir.path(), config).unwrap()
    }

    #[test]
    fn hardening_rewrites_unwrap_in_anyhow_result_function() {
        let dir = crate_with_lib_rs(
            r#"pub fn load() -> anyhow::Result<String> {
    let value = std::fs::read_to_string("config.toml").unwrap();
    Ok(value)
}
"#,
        );
        let analysis = analyze_all(&dir);

        assert_eq!(analysis.changes.len(), 1);
        let new_content = &analysis.changes[0].new_content;
        assert!(new_content.contains("use anyhow::Context;"));
        assert!(new_content.contains(".context(\"load failed instead of panicking\")?"));
        assert!(syn::parse_file(new_content).is_ok());
    }

    #[test]
    fn hardening_does_not_rewrite_plain_result_without_anyhow_alias() {
        let dir = crate_with_lib_rs(
            r#"pub fn load() -> Result<String, std::io::Error> {
    let value = std::fs::read_to_string("config.toml").unwrap();
    Ok(value)
}
"#,
        );

        assert!(analyze_all(&dir).changes.is_empty());
    }

    #[test]
    fn hardening_does_not_flag_patterns_inside_strings_or_comments() {
        let dir = crate_with_lib_rs(
            r#"pub fn describe() -> &'static str {
    // Command::new("ignored")
    "unsafe std::process::Command env::var("
}
"#,
        );
        let analysis = analyze_all(&dir);

        assert!(analysis.findings.is_empty(), "{:?}", analysis.findings);
    }
}