1use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use std::path::{Path, PathBuf};
9
/// Result of [`analyze_hardening`]: review findings plus proposed whole-file rewrites.
///
/// Nothing described here has been applied to disk; each [`HardeningFileChange`] carries both
/// the original and the rewritten file contents.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct HardeningAnalysis {
    /// Root directory passed to [`analyze_hardening`], as given by the caller.
    pub root: PathBuf,
    /// Target from the configuration (absolute or relative to `root`), if one was supplied.
    pub target: Option<PathBuf>,
    /// Number of Rust files read: the collected file count capped at `max_files`.
    pub files_scanned: usize,
    /// All findings in sorted file order; within a file, the review-only findings come before
    /// the patchable unwrap findings.
    pub findings: Vec<HardeningFinding>,
    /// Proposed rewrites, at most one per scanned file.
    pub changes: Vec<HardeningFileChange>,
}
18
/// One location flagged by [`analyze_hardening`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct HardeningFinding {
    /// Identifier of the form `<kind>:<relative file>:<line>`, e.g. `unsafe-rust:src/lib.rs:12`.
    pub id: String,
    /// Short human-readable summary.
    pub title: String,
    /// Explanation of why the location deserves attention.
    pub description: String,
    /// File path relative to the analysis root (the path unchanged if it is not under the root).
    pub file: PathBuf,
    /// 1-based line number.
    pub line: usize,
    /// Category of the finding.
    pub strategy: HardeningStrategy,
    /// `true` when an entry in [`HardeningAnalysis::changes`] rewrites this line;
    /// review-only findings are `false`.
    pub patchable: bool,
}
29
/// Category of a [`HardeningFinding`] or [`HardeningFileChange`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub enum HardeningStrategy {
    /// `.unwrap()` / `.expect(..)` inside a function returning `anyhow::Result`, rewritten to
    /// `.context(..)?`. This is the only strategy that comes with an automatic patch.
    ResultUnwrapContext,
    /// Line matching `Command::new(` or `std::process::Command`; review only.
    ProcessExecutionReview,
    /// Line using the `unsafe` keyword; review only.
    UnsafeReview,
    /// Line reading an environment variable (e.g. `env::var(..)`); review only.
    EnvAccessReview,
}
37
/// A proposed rewrite of one whole file. It is only reported here, never written to disk.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct HardeningFileChange {
    /// File path relative to the analysis root.
    pub file: PathBuf,
    /// File contents as read before the rewrite.
    pub old_content: String,
    /// Proposed file contents; checked to parse with `syn` before the change is reported.
    pub new_content: String,
    /// Strategy that produced the rewrite.
    pub strategy: HardeningStrategy,
    /// Ids of the findings (one per rewritten line) that this change addresses.
    pub finding_ids: Vec<String>,
    /// Human-readable summary of the rewrite.
    pub description: String,
}
47
/// Options for [`analyze_hardening`].
#[derive(Debug, Clone, Copy)]
pub struct HardeningAnalyzeConfig<'a> {
    /// File or directory to scan, absolute or relative to the root; it must lie inside the
    /// root. `None` scans the whole root.
    pub target: Option<&'a Path>,
    /// Maximum number of files to read; files are taken in sorted path order.
    pub max_files: usize,
}
53
54pub fn analyze_hardening(
55 root: &Path,
56 config: HardeningAnalyzeConfig<'_>,
57) -> anyhow::Result<HardeningAnalysis> {
58 let files = collect_rust_files(root, config.target)?;
59 let mut findings = Vec::new();
60 let mut changes = Vec::new();
61
62 for file in files.iter().take(config.max_files) {
63 let content = std::fs::read_to_string(file)?;
64 let rel = relative_path(root, file);
65 let function_ranges = find_function_ranges(&content);
66
67 for (index, line) in content.lines().enumerate() {
68 let line_no = index + 1;
69 let pattern_line = line_without_comments_or_strings(line);
70 let trimmed = pattern_line.trim();
71
72 if trimmed.contains("Command::new(") || trimmed.contains("std::process::Command") {
73 findings.push(HardeningFinding {
74 id: format!("process-execution:{}:{line_no}", rel.display()),
75 title: "Process execution surface".to_string(),
76 description:
77 "External process execution should have explicit input validation or allowlisting."
78 .to_string(),
79 file: rel.clone(),
80 line: line_no,
81 strategy: HardeningStrategy::ProcessExecutionReview,
82 patchable: false,
83 });
84 }
85
86 if trimmed.contains("unsafe ") || trimmed == "unsafe" || trimmed.contains("unsafe{") {
87 findings.push(HardeningFinding {
88 id: format!("unsafe-rust:{}:{line_no}", rel.display()),
89 title: "Unsafe Rust requires review".to_string(),
90 description:
91 "Unsafe code should be isolated and documented before automated edits touch it."
92 .to_string(),
93 file: rel.clone(),
94 line: line_no,
95 strategy: HardeningStrategy::UnsafeReview,
96 patchable: false,
97 });
98 }
99
100 if trimmed.contains("std::env::var(") || trimmed.contains("env::var(") {
101 findings.push(HardeningFinding {
102 id: format!("env-access:{}:{line_no}", rel.display()),
103 title: "Environment variable access".to_string(),
104 description:
105 "Environment-derived configuration should return contextual errors at boundaries."
106 .to_string(),
107 file: rel.clone(),
108 line: line_no,
109 strategy: HardeningStrategy::EnvAccessReview,
110 patchable: false,
111 });
112 }
113 }
114
115 if let Some(change) = build_result_context_change(root, file, &content, &function_ranges)? {
116 for id in &change.finding_ids {
117 if !findings.iter().any(|finding| &finding.id == id) {
118 let line = id
119 .rsplit(':')
120 .next()
121 .and_then(|line| line.parse::<usize>().ok())
122 .unwrap_or(1);
123 findings.push(HardeningFinding {
124 id: id.clone(),
125 title: "Panic-prone unwrap in anyhow Result function".to_string(),
126 description: "Replace unwrap/expect with anyhow Context and ? so failure is reported instead of panicking.".to_string(),
127 file: rel.clone(),
128 line,
129 strategy: HardeningStrategy::ResultUnwrapContext,
130 patchable: true,
131 });
132 }
133 }
134 changes.push(change);
135 }
136 }
137
138 Ok(HardeningAnalysis {
139 root: root.to_path_buf(),
140 target: config.target.map(Path::to_path_buf),
141 files_scanned: files.len().min(config.max_files),
142 findings,
143 changes,
144 })
145}
146
147fn build_result_context_change(
148 root: &Path,
149 file: &Path,
150 content: &str,
151 function_ranges: &[FunctionRange],
152) -> anyhow::Result<Option<HardeningFileChange>> {
153 let rel = relative_path(root, file);
154 let mut lines: Vec<String> = content.lines().map(ToString::to_string).collect();
155 let mut changed = false;
156 let mut finding_ids = Vec::new();
157
158 for range in function_ranges {
159 if !range.returns_anyhow_result {
160 continue;
161 }
162
163 for line_index in range.start_line.saturating_sub(1)..range.end_line.min(lines.len()) {
164 let original = lines[line_index].clone();
165 if original.trim_start().starts_with("//") {
166 continue;
167 }
168
169 let mut rewritten = original.clone();
170 if rewritten.contains(".unwrap()") {
171 rewritten = rewritten.replace(
172 ".unwrap()",
173 &format!(".context(\"{} failed instead of panicking\")?", range.name),
174 );
175 }
176 rewritten = replace_expect_calls(&rewritten);
177
178 if rewritten != original {
179 changed = true;
180 lines[line_index] = rewritten;
181 finding_ids.push(format!(
182 "unwrap-in-result:{}:{}",
183 rel.display(),
184 line_index + 1
185 ));
186 }
187 }
188 }
189
190 if !changed {
191 return Ok(None);
192 }
193
194 let mut new_content = lines.join("\n");
195 if content.ends_with('\n') {
196 new_content.push('\n');
197 }
198 new_content = ensure_anyhow_context_import(&new_content);
199 if syn::parse_file(&new_content).is_err() {
200 return Ok(None);
201 }
202
203 Ok(Some(HardeningFileChange {
204 file: rel,
205 old_content: content.to_string(),
206 new_content,
207 strategy: HardeningStrategy::ResultUnwrapContext,
208 finding_ids,
209 description:
210 "Replace panic-prone unwrap/expect calls in anyhow Result functions with Context and ?."
211 .to_string(),
212 }))
213}
214
/// Rewrites `.expect(<string literal>)` calls on a single line into `.context(<same text>)?`.
///
/// Plain literals (`.expect("msg")`) are copied verbatim. Their source text, escapes included,
/// is already a valid plain literal, so re-escaping would change the message.
///
/// Raw literals (`.expect(r"msg")`, `.expect(r#"msg"#)`) have their text escaped so it can be
/// re-emitted as a plain literal.
///
/// A non-literal argument such as `.expect(&msg)` is kept, and scanning continues after it.
/// A literal that is not closed by `)` on this line stops the rewrite; the rest of the line is
/// kept as is.
fn replace_expect_calls(line: &str) -> String {
    const CALL: &str = ".expect(";
    let mut output = String::with_capacity(line.len());
    let mut rest = line;
    while let Some(start) = rest.find(CALL) {
        let (before, call) = rest.split_at(start);
        output.push_str(before);
        let argument = &call[CALL.len()..];
        if !starts_with_string_literal(argument) {
            output.push_str(CALL);
            rest = argument;
            continue;
        }
        let Some((message, consumed)) = expect_literal_message(argument) else {
            output.push_str(call);
            return output;
        };
        output.push_str(".context(\"");
        output.push_str(&message);
        output.push_str("\")?");
        rest = &argument[consumed..];
    }
    output.push_str(rest);
    output
}

/// Whether `argument` begins with a plain (`"`) or raw (`r"`, `r#`) string literal.
fn starts_with_string_literal(argument: &str) -> bool {
    argument.starts_with('"') || argument.starts_with("r\"") || argument.starts_with("r#")
}

/// Parses the string literal at the start of `argument`, which must be followed directly by `)`.
///
/// Returns the message as plain-literal source text, together with the number of bytes
/// consumed (including the `)`).
fn expect_literal_message(argument: &str) -> Option<(String, usize)> {
    let (message, literal_len) = match argument.strip_prefix('"') {
        Some(body) => {
            let end = plain_literal_end(body)?;
            (body[..end].to_string(), end + 2)
        }
        None => {
            let after_r = argument.strip_prefix('r')?;
            let hashes = after_r.len() - after_r.trim_start_matches('#').len();
            let body = after_r[hashes..].strip_prefix('"')?;
            let terminator = format!("\"{}", "#".repeat(hashes));
            let end = body.find(&terminator)?;
            (
                escape_string(&body[..end]),
                1 + hashes + 1 + end + terminator.len(),
            )
        }
    };
    argument[literal_len..]
        .starts_with(')')
        .then_some((message, literal_len + 1))
}

/// Byte index of the unescaped `"` that closes a plain literal whose opening quote was
/// already stripped from `body`.
fn plain_literal_end(body: &str) -> Option<usize> {
    let mut escaped = false;
    for (index, ch) in body.char_indices() {
        if escaped {
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == '"' {
            return Some(index);
        }
    }
    None
}

/// Escapes backslashes and double quotes so raw `value` can sit inside a plain `"…"` literal.
fn escape_string(value: &str) -> String {
    value.replace('\\', "\\\\").replace('"', "\\\"")
}
239
/// Returns `line` with comments removed and the contents of literals blanked out.
///
/// The result is used only for substring pattern matching. Every character belonging to a
/// string literal, char literal, or `/* … */` comment is replaced by a space, which keeps the
/// tokens on either side apart. Everything from a `//` comment onwards is dropped.
///
/// Recognised within the line:
/// - plain strings with `\` escapes;
/// - raw strings (`r"…"`, `r#"…"#`, also after a `b` prefix);
/// - char literals such as `'"'` or `'\''`, told apart from lifetimes and labels like `'a`;
/// - block comments (nesting is not tracked).
///
/// The scan is stateless across lines. A literal or block comment still open at the end of
/// the line blanks the rest of it, and its continuation on later lines is treated as code.
fn line_without_comments_or_strings(line: &str) -> String {
    let chars: Vec<char> = line.chars().collect();
    let mut output = String::with_capacity(line.len());
    let mut index = 0;
    while index < chars.len() {
        let rest = &chars[index..];
        let blanked = match rest {
            ['/', '/', ..] => break,
            ['/', '*', ..] => block_comment_len(rest),
            ['"', ..] => plain_string_len(rest),
            ['r', ..] if starts_raw_string(&chars, index) => raw_string_len(rest),
            ['\'', ..] => char_literal_len(rest).unwrap_or(0),
            _ => 0,
        };
        if blanked == 0 {
            output.push(rest[0]);
            index += 1;
        } else {
            output.extend(rest[..blanked].iter().map(|_| ' '));
            index += blanked;
        }
    }
    output
}

/// Length in chars of the plain string starting at `rest[0]` (a `"`), both quotes included;
/// all of `rest` when the string is not closed on this line.
fn plain_string_len(rest: &[char]) -> usize {
    let mut escaped = false;
    for (offset, &ch) in rest.iter().enumerate().skip(1) {
        if escaped {
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == '"' {
            return offset + 1;
        }
    }
    rest.len()
}

/// Whether the `r` at `chars[index]` opens a raw string (`r"`, `r#"`, `br"`, …).
///
/// Returns false when the `r` ends an identifier such as `bar`, or starts a raw identifier
/// such as `r#type`.
fn starts_raw_string(chars: &[char], index: usize) -> bool {
    let token_start = match index.checked_sub(1).map(|prev| chars[prev]) {
        None => true,
        Some('b') => index < 2 || !is_identifier_char(chars[index - 2]),
        Some(prev) => !is_identifier_char(prev),
    };
    let hashes = chars[index + 1..]
        .iter()
        .take_while(|&&ch| ch == '#')
        .count();
    token_start && chars.get(index + 1 + hashes) == Some(&'"')
}

/// Length of the raw string starting at `rest[0]` (the `r`), through its closing `"` and
/// matching `#`s; all of `rest` when it is not closed on this line.
fn raw_string_len(rest: &[char]) -> usize {
    let hashes = rest[1..].iter().take_while(|&&ch| ch == '#').count();
    (hashes + 2..rest.len())
        .find(|&i| {
            rest[i] == '"' && rest[i + 1..].iter().take_while(|&&ch| ch == '#').count() >= hashes
        })
        .map_or(rest.len(), |close| close + 1 + hashes)
}

/// Length of the block comment starting at `rest[0..2]` (`/*`), through `*/`; all of `rest`
/// when it is not closed on this line.
fn block_comment_len(rest: &[char]) -> usize {
    rest.windows(2)
        .skip(2)
        .position(|pair| matches!(pair, ['*', '/']))
        .map_or(rest.len(), |offset| offset + 4)
}

/// Length of the char literal starting at `rest[0]` (a `'`).
///
/// Returns `None` when the quote starts a lifetime or label (e.g. `'a`, `'static`) instead.
fn char_literal_len(rest: &[char]) -> Option<usize> {
    match rest {
        // Escaped char such as '\n', '\'', '\\' or '\u{1F600}': closes at the next quote
        // after the escaped character.
        ['\'', '\\', _, tail @ ..] => tail
            .iter()
            .position(|&ch| ch == '\'')
            .map(|offset| offset + 4),
        ['\'', _, '\'', ..] => Some(3),
        _ => None,
    }
}

/// Ensures `anyhow::Context` is in scope so rewritten `.context(...)` calls resolve.
///
/// Returns `content` unchanged when a `use` declaration already imports `Context` from
/// `anyhow`. This covers grouped, multi-line, renamed and glob forms. Otherwise a `use` line
/// is inserted at the position chosen by `import_insert_index`.
///
/// When the file already mentions some other `Context` (e.g. `std::task::Context`), the
/// trait is imported as `Context as _` so the new name cannot clash. A trailing newline is
/// preserved.
fn ensure_anyhow_context_import(content: &str) -> String {
    if imports_from_anyhow(content, "Context") {
        return content.to_string();
    }

    let import = if contains_identifier(content, "Context") {
        "use anyhow::Context as _;"
    } else {
        "use anyhow::Context;"
    };
    let mut lines: Vec<&str> = content.lines().collect();
    let insert_at = import_insert_index(&lines);
    lines.insert(insert_at, import);
    let mut result = lines.join("\n");
    if content.ends_with('\n') {
        result.push('\n');
    }
    result
}

/// Whether some `use` declaration in `content` brings `name` in from `anyhow`.
///
/// The name may appear directly or inside a group (`use anyhow::{Context, Result};`, also
/// spread over several lines), or the declaration may be a glob (`use anyhow::*;`). Comments
/// are ignored.
fn imports_from_anyhow(content: &str, name: &str) -> bool {
    let mut statement = String::new();
    for line in content.lines() {
        let code = line_without_comments_or_strings(line);
        let code = code.trim();
        if statement.is_empty() && !starts_use_declaration(code) {
            continue;
        }
        statement.push_str(code);
        statement.push(' ');
        if code.contains(';') {
            if statement.contains("anyhow::")
                && (contains_identifier(&statement, name) || statement.contains("::*"))
            {
                return true;
            }
            statement.clear();
        }
    }
    false
}

/// Whether `code` (already trimmed) begins a `use`, `pub use` or `pub(...) use` declaration.
fn starts_use_declaration(code: &str) -> bool {
    let rest = match code.strip_prefix("pub") {
        Some(after_pub) => match after_pub.strip_prefix('(') {
            Some(scoped) => scoped.split_once(')').map_or("", |(_, tail)| tail),
            None => after_pub,
        },
        None => code,
    };
    rest.trim_start().starts_with("use ")
}

/// Whether `word` occurs in `text` as a whole identifier (not as part of a longer one).
fn contains_identifier(text: &str, word: &str) -> bool {
    text.match_indices(word).any(|(start, _)| {
        let before = text[..start].chars().next_back();
        let after = text[start + word.len()..].chars().next();
        !before.is_some_and(is_identifier_char) && !after.is_some_and(is_identifier_char)
    })
}

/// Whether `ch` can be part of a Rust identifier (ASCII-or-Unicode alphanumeric, or `_`).
fn is_identifier_char(ch: char) -> bool {
    ch.is_alphanumeric() || ch == '_'
}

/// Line index where a new `use` declaration can be inserted.
///
/// - If the header has inner attributes (`#![...]`, possibly spanning several lines) or inner
///   doc comments (`//!`), which must precede every item: just after the last of them. Plain
///   `//` comments and blank lines in between are stepped over.
/// - Otherwise: the first non-blank line.
/// - For a file containing only blank lines: the end of the file.
fn import_insert_index(lines: &[&str]) -> usize {
    let mut after_inner = None;
    let mut first_non_blank = None;
    let mut open_brackets = 0isize;
    for (index, line) in lines.iter().enumerate() {
        let trimmed = line.trim();
        if open_brackets > 0 {
            // Continuation of a multi-line `#![...]`.
            open_brackets += bracket_balance(trimmed);
            if open_brackets <= 0 {
                after_inner = Some(index + 1);
            }
            continue;
        }
        if trimmed.is_empty() {
            continue;
        }
        if first_non_blank.is_none() {
            first_non_blank = Some(index);
        }
        if trimmed.starts_with("#![") {
            open_brackets = bracket_balance(trimmed);
            if open_brackets <= 0 {
                after_inner = Some(index + 1);
            }
        } else if trimmed.starts_with("//!") {
            after_inner = Some(index + 1);
        } else if trimmed.starts_with("///") || !trimmed.starts_with("//") {
            // First item, or an outer doc comment belonging to it.
            break;
        }
    }
    after_inner.or(first_non_blank).unwrap_or(lines.len())
}

/// Number of `[` minus number of `]` in the code part of `line`.
fn bracket_balance(line: &str) -> isize {
    let code = line_without_comments_or_strings(line);
    code.matches('[').count() as isize - code.matches(']').count() as isize
}

/// A function item located by the line-based scan in [`find_function_ranges`].
#[derive(Debug)]
struct FunctionRange {
    /// Function identifier, used in generated `.context(...)` messages.
    name: String,
    /// 1-based line holding the `fn` keyword.
    start_line: usize,
    /// 1-based, inclusive line on which the body's braces balance out.
    end_line: usize,
    /// Whether the signature returns `anyhow::Result`, written out or via an imported alias.
    returns_anyhow_result: bool,
}

/// Locates function items and their bodies with a lightweight line-based scan.
///
/// Comments and literal contents are ignored, so `{`, `}` or `fn` inside them cannot move a
/// boundary.
///
/// - A function starts on a line with an `fn` keyword, and its signature runs to the first `{`.
/// - Signatures that end with `;` first have no body and are skipped. Examples are trait
///   method declarations and `extern` items. Without this, they would be matched to the next
///   function's body.
/// - A body ends on the line where its braces balance.
/// - Scanning resumes after each body, so nested functions are not reported on their own.
fn find_function_ranges(content: &str) -> Vec<FunctionRange> {
    let code_lines: Vec<String> = content
        .lines()
        .map(line_without_comments_or_strings)
        .collect();
    let has_anyhow_result_alias = imports_from_anyhow(content, "Result");
    let mut ranges = Vec::new();
    let mut index = 0;
    while index < code_lines.len() {
        if find_fn_keyword(&code_lines[index]).is_none() {
            index += 1;
            continue;
        }
        let Some(open_line) = body_open_line(&code_lines, index) else {
            index += 1;
            continue;
        };
        let signature = code_lines[index..=open_line].join(" ");
        let Some(name) = function_name(&signature) else {
            index += 1;
            continue;
        };

        let end_line = body_end_line(&code_lines, open_line);
        let returns_anyhow_result = signature.contains("-> anyhow::Result")
            || (has_anyhow_result_alias && signature.contains("-> Result<"));
        ranges.push(FunctionRange {
            name,
            start_line: index + 1,
            end_line,
            returns_anyhow_result,
        });
        index = end_line;
    }
    ranges
}

/// Index of the line holding the `{` that opens the body of the function whose signature
/// starts at `start`.
///
/// Returns `None` if the signature reaches a trailing `;` first (no body) or the input ends.
fn body_open_line(code_lines: &[String], start: usize) -> Option<usize> {
    for (offset, line) in code_lines[start..].iter().enumerate() {
        if line.contains('{') {
            return Some(start + offset);
        }
        if line.trim_end().ends_with(';') {
            return None;
        }
    }
    None
}

/// 1-based line on which the braces counted from `open_line` onwards balance out.
///
/// Returns the last line when they never do.
fn body_end_line(code_lines: &[String], open_line: usize) -> usize {
    let mut depth = 0isize;
    for (index, line) in code_lines.iter().enumerate().skip(open_line) {
        depth += line.matches('{').count() as isize - line.matches('}').count() as isize;
        if depth <= 0 {
            return index + 1;
        }
    }
    code_lines.len()
}

/// Byte offset of an `fn ` keyword that starts a token in `code` (so `myfn ` does not count).
fn find_fn_keyword(code: &str) -> Option<usize> {
    code.match_indices("fn ")
        .map(|(offset, _)| offset)
        .find(|&offset| !code[..offset].chars().next_back().is_some_and(is_identifier_char))
}

/// Extracts the identifier following the first `fn` keyword in `signature`.
fn function_name(signature: &str) -> Option<String> {
    let keyword = find_fn_keyword(signature)?;
    let name: String = signature[keyword + "fn ".len()..]
        .trim_start()
        .chars()
        .take_while(|&ch| is_identifier_char(ch))
        .collect();
    (!name.is_empty()).then_some(name)
}
363
364fn collect_rust_files(root: &Path, target: Option<&Path>) -> anyhow::Result<Vec<PathBuf>> {
365 let scan_root = target
366 .map(|path| {
367 if path.is_absolute() {
368 path.to_path_buf()
369 } else {
370 root.join(path)
371 }
372 })
373 .unwrap_or_else(|| root.to_path_buf());
374 if !scan_root.starts_with(root) {
375 anyhow::bail!("hardening target is outside root: {}", scan_root.display());
376 }
377
378 if scan_root.is_file() {
379 return Ok(if scan_root.extension().is_some_and(|ext| ext == "rs") {
380 vec![scan_root]
381 } else {
382 Vec::new()
383 });
384 }
385
386 let mut files = Vec::new();
387 for result in ignore::WalkBuilder::new(scan_root)
388 .hidden(false)
389 .filter_entry(|entry| {
390 let name = entry.file_name().to_string_lossy();
391 !matches!(
392 name.as_ref(),
393 "target" | ".git" | ".worktrees" | ".mdx-rust"
394 )
395 })
396 .build()
397 {
398 let entry = result?;
399 let path = entry.path();
400 if path.is_file() && path.extension().is_some_and(|ext| ext == "rs") {
401 files.push(path.to_path_buf());
402 }
403 }
404 files.sort();
405 Ok(files)
406}
407
/// `path` expressed relative to `root`, or `path` unchanged when it does not start with `root`.
fn relative_path(root: &Path, path: &Path) -> PathBuf {
    match path.strip_prefix(root) {
        Ok(relative) => relative.to_path_buf(),
        Err(_) => path.to_path_buf(),
    }
}
411
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Writes `source` as `src/lib.rs` under a fresh temporary root and analyzes that root.
    fn analyze_lib_source(source: &str) -> HardeningAnalysis {
        let root = tempdir().unwrap();
        let src_dir = root.path().join("src");
        std::fs::create_dir_all(&src_dir).unwrap();
        std::fs::write(src_dir.join("lib.rs"), source).unwrap();
        let config = HardeningAnalyzeConfig {
            target: None,
            max_files: 10,
        };
        analyze_hardening(root.path(), config).unwrap()
    }

    #[test]
    fn hardening_rewrites_unwrap_in_anyhow_result_function() {
        let analysis = analyze_lib_source(
            r#"pub fn load() -> anyhow::Result<String> {
    let value = std::fs::read_to_string("config.toml").unwrap();
    Ok(value)
}
"#,
        );

        assert_eq!(analysis.changes.len(), 1);
        let new_content = &analysis.changes[0].new_content;
        assert!(new_content.contains("use anyhow::Context;"));
        assert!(new_content.contains(".context(\"load failed instead of panicking\")?"));
        assert!(syn::parse_file(new_content).is_ok());
    }

    #[test]
    fn hardening_does_not_rewrite_plain_result_without_anyhow_alias() {
        let analysis = analyze_lib_source(
            r#"pub fn load() -> Result<String, std::io::Error> {
    let value = std::fs::read_to_string("config.toml").unwrap();
    Ok(value)
}
"#,
        );

        assert!(analysis.changes.is_empty());
    }

    #[test]
    fn hardening_does_not_flag_patterns_inside_strings_or_comments() {
        let analysis = analyze_lib_source(
            r#"pub fn describe() -> &'static str {
    // Command::new("ignored")
    "unsafe std::process::Command env::var("
}
"#,
        );

        assert!(analysis.findings.is_empty(), "{:?}", analysis.findings);
    }
}