1use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use std::path::{Path, PathBuf};
9
10#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
11pub struct HardeningAnalysis {
12 pub root: PathBuf,
13 pub target: Option<PathBuf>,
14 pub files_scanned: usize,
15 pub findings: Vec<HardeningFinding>,
16 pub changes: Vec<HardeningFileChange>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
20pub struct HardeningFinding {
21 pub id: String,
22 pub title: String,
23 pub description: String,
24 pub file: PathBuf,
25 pub line: usize,
26 pub strategy: HardeningStrategy,
27 pub patchable: bool,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
31pub enum HardeningStrategy {
32 ResultUnwrapContext,
33 ProcessExecutionReview,
34 UnsafeReview,
35 EnvAccessReview,
36 FileIoReview,
37 HttpSurfaceReview,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
41pub struct HardeningFileChange {
42 pub file: PathBuf,
43 pub old_content: String,
44 pub new_content: String,
45 pub strategy: HardeningStrategy,
46 pub finding_ids: Vec<String>,
47 pub description: String,
48}
49
/// Options for [`analyze_hardening`].
#[derive(Clone, Copy, Debug)]
pub struct HardeningAnalyzeConfig<'a> {
    /// Optional file or directory to narrow the scan to, absolute or relative to
    /// the analysis root; it must lie inside the root.
    pub target: Option<&'a Path>,
    /// Upper bound on the number of files read (taken in sorted path order).
    pub max_files: usize,
}
55
56pub fn analyze_hardening(
57 root: &Path,
58 config: HardeningAnalyzeConfig<'_>,
59) -> anyhow::Result<HardeningAnalysis> {
60 let files = collect_rust_files(root, config.target)?;
61 let mut findings = Vec::new();
62 let mut changes = Vec::new();
63
64 for file in files.iter().take(config.max_files) {
65 let content = std::fs::read_to_string(file)?;
66 let rel = relative_path(root, file);
67 let function_ranges = find_function_ranges(&content);
68
69 for (index, line) in content.lines().enumerate() {
70 let line_no = index + 1;
71 let pattern_line = line_without_comments_or_strings(line);
72 let trimmed = pattern_line.trim();
73
74 if trimmed.contains("Command::new(") || trimmed.contains("std::process::Command") {
75 findings.push(HardeningFinding {
76 id: format!("process-execution:{}:{line_no}", rel.display()),
77 title: "Process execution surface".to_string(),
78 description:
79 "External process execution should have explicit input validation or allowlisting."
80 .to_string(),
81 file: rel.clone(),
82 line: line_no,
83 strategy: HardeningStrategy::ProcessExecutionReview,
84 patchable: false,
85 });
86 }
87
88 if trimmed.contains("unsafe ") || trimmed == "unsafe" || trimmed.contains("unsafe{") {
89 findings.push(HardeningFinding {
90 id: format!("unsafe-rust:{}:{line_no}", rel.display()),
91 title: "Unsafe Rust requires review".to_string(),
92 description:
93 "Unsafe code should be isolated and documented before automated edits touch it."
94 .to_string(),
95 file: rel.clone(),
96 line: line_no,
97 strategy: HardeningStrategy::UnsafeReview,
98 patchable: false,
99 });
100 }
101
102 if trimmed.contains("std::env::var(") || trimmed.contains("env::var(") {
103 findings.push(HardeningFinding {
104 id: format!("env-access:{}:{line_no}", rel.display()),
105 title: "Environment variable access".to_string(),
106 description:
107 "Environment-derived configuration should return contextual errors at boundaries."
108 .to_string(),
109 file: rel.clone(),
110 line: line_no,
111 strategy: HardeningStrategy::EnvAccessReview,
112 patchable: false,
113 });
114 }
115
116 let filesystem_call = trimmed.contains("std::fs::read_to_string(")
117 || trimmed.contains("fs::read_to_string(")
118 || trimmed.contains("std::fs::write(")
119 || trimmed.contains("fs::write(");
120 let has_visible_error_handling = trimmed.contains('?')
121 || trimmed.contains(".unwrap(")
122 || trimmed.contains(".expect(");
123 if filesystem_call && !has_visible_error_handling {
124 findings.push(HardeningFinding {
125 id: format!("file-io:{}:{line_no}", rel.display()),
126 title: "Filesystem boundary".to_string(),
127 description:
128 "Filesystem access should preserve contextual errors and validated paths."
129 .to_string(),
130 file: rel.clone(),
131 line: line_no,
132 strategy: HardeningStrategy::FileIoReview,
133 patchable: false,
134 });
135 }
136
137 if trimmed.contains("Router::new(")
138 || trimmed.contains(".route(")
139 || trimmed.contains("#[get(")
140 || trimmed.contains("#[post(")
141 {
142 findings.push(HardeningFinding {
143 id: format!("http-surface:{}:{line_no}", rel.display()),
144 title: "HTTP or route surface".to_string(),
145 description:
146 "HTTP-facing surfaces should validate inputs and preserve typed errors."
147 .to_string(),
148 file: rel.clone(),
149 line: line_no,
150 strategy: HardeningStrategy::HttpSurfaceReview,
151 patchable: false,
152 });
153 }
154 }
155
156 if let Some(change) = build_result_context_change(root, file, &content, &function_ranges)? {
157 for id in &change.finding_ids {
158 if !findings.iter().any(|finding| &finding.id == id) {
159 let line = id
160 .rsplit(':')
161 .next()
162 .and_then(|line| line.parse::<usize>().ok())
163 .unwrap_or(1);
164 findings.push(HardeningFinding {
165 id: id.clone(),
166 title: "Panic-prone unwrap in anyhow Result function".to_string(),
167 description: "Replace unwrap/expect with anyhow Context and ? so failure is reported instead of panicking.".to_string(),
168 file: rel.clone(),
169 line,
170 strategy: HardeningStrategy::ResultUnwrapContext,
171 patchable: true,
172 });
173 }
174 }
175 changes.push(change);
176 }
177 }
178
179 Ok(HardeningAnalysis {
180 root: root.to_path_buf(),
181 target: config.target.map(Path::to_path_buf),
182 files_scanned: files.len().min(config.max_files),
183 findings,
184 changes,
185 })
186}
187
188fn build_result_context_change(
189 root: &Path,
190 file: &Path,
191 content: &str,
192 function_ranges: &[FunctionRange],
193) -> anyhow::Result<Option<HardeningFileChange>> {
194 let rel = relative_path(root, file);
195 let mut lines: Vec<String> = content.lines().map(ToString::to_string).collect();
196 let mut changed = false;
197 let mut finding_ids = Vec::new();
198
199 for range in function_ranges {
200 if !range.returns_anyhow_result {
201 continue;
202 }
203
204 for line_index in range.start_line.saturating_sub(1)..range.end_line.min(lines.len()) {
205 let original = lines[line_index].clone();
206 if original.trim_start().starts_with("//") {
207 continue;
208 }
209
210 let mut rewritten = original.clone();
211 if rewritten.contains(".unwrap()") {
212 rewritten = rewritten.replace(
213 ".unwrap()",
214 &format!(".context(\"{} failed instead of panicking\")?", range.name),
215 );
216 }
217 rewritten = replace_expect_calls(&rewritten);
218
219 if rewritten != original {
220 changed = true;
221 lines[line_index] = rewritten;
222 finding_ids.push(format!(
223 "unwrap-in-result:{}:{}",
224 rel.display(),
225 line_index + 1
226 ));
227 }
228 }
229 }
230
231 if !changed {
232 return Ok(None);
233 }
234
235 let mut new_content = lines.join("\n");
236 if content.ends_with('\n') {
237 new_content.push('\n');
238 }
239 new_content = ensure_anyhow_context_import(&new_content);
240 if syn::parse_file(&new_content).is_err() {
241 return Ok(None);
242 }
243
244 Ok(Some(HardeningFileChange {
245 file: rel,
246 old_content: content.to_string(),
247 new_content,
248 strategy: HardeningStrategy::ResultUnwrapContext,
249 finding_ids,
250 description:
251 "Replace panic-prone unwrap/expect calls in anyhow Result functions with Context and ?."
252 .to_string(),
253 }))
254}
255
/// Rewrites each `.expect(<string literal>)` on `line` into `.context(<same message>)?`.
///
/// A normal literal is copied verbatim: its body is already valid escaped Rust
/// source, so escaping it again would add visible backslashes to the message.
/// A raw literal (`r"..."`, `r#"..."#`) is re-encoded as a normal literal with
/// [`escape_string`]. Calls whose argument is not a single string literal
/// immediately followed by `)` are left unchanged.
fn replace_expect_calls(line: &str) -> String {
    const CALL: &str = ".expect(";
    let mut output = String::with_capacity(line.len());
    let mut rest = line;
    while let Some(start) = rest.find(CALL) {
        output.push_str(&rest[..start]);
        let argument = &rest[start + CALL.len()..];
        match parse_expect_message(argument) {
            Some((message, consumed)) => {
                output.push_str(".context(\"");
                output.push_str(&message);
                output.push_str("\")?");
                rest = &argument[consumed..];
            }
            None => {
                output.push_str(CALL);
                rest = argument;
            }
        }
    }
    output.push_str(rest);
    output
}

/// Parses the string literal at the start of `argument` (the text right after
/// `.expect(`) up to and including the closing `)`.
///
/// Returns the literal's body, ready to place between `"` quotes, together with
/// the number of bytes consumed. Returns `None` when `argument` does not start
/// with a complete literal directly followed by `)`.
fn parse_expect_message(argument: &str) -> Option<(String, usize)> {
    if let Some(body) = argument.strip_prefix('"') {
        let mut escaped = false;
        for (offset, ch) in body.char_indices() {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                // Consumed: opening quote + body + closing quote + `)`.
                return body[offset + 1..]
                    .starts_with(')')
                    .then(|| (body[..offset].to_string(), offset + 3));
            }
        }
        return None;
    }

    let after_r = argument.strip_prefix('r')?;
    let hashes = after_r.len() - after_r.trim_start_matches('#').len();
    let body = after_r[hashes..].strip_prefix('"')?;
    let terminator = format!("\"{}", "#".repeat(hashes));
    let end = body.find(&terminator)?;
    if !body[end + terminator.len()..].starts_with(')') {
        return None;
    }
    // Consumed: `r` + hashes + opening quote + body + terminator + `)`.
    let consumed = 1 + hashes + 1 + end + terminator.len() + 1;
    Some((escape_string(&body[..end]), consumed))
}

/// Escapes backslashes and double quotes so a raw `value` can be placed inside
/// a normal `"..."` Rust string literal.
fn escape_string(value: &str) -> String {
    value.replace('\\', "\\\\").replace('"', "\\\"")
}
280
/// Returns a copy of `line` in which literals and comments can no longer match
/// code patterns.
///
/// The contents of `"..."`, raw `r"..."` / `r#"..."#`, and `'x'` literals, and of
/// inline `/* ... */` comments, are replaced by one space per character
/// (delimiters included). Everything from a `//` comment onward is dropped.
/// Lifetimes and labels such as `'a` are kept as code. Only this line is
/// examined, so literals or block comments spanning several lines are not
/// recognized.
fn line_without_comments_or_strings(line: &str) -> String {
    /// Exclusive end of the `"..."` literal whose body starts at `from`, or the
    /// end of the line when it is unterminated.
    fn quoted_end(chars: &[char], from: usize) -> usize {
        let mut escaped = false;
        for (index, &ch) in chars.iter().enumerate().skip(from) {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                return index + 1;
            }
        }
        chars.len()
    }

    /// Exclusive end of the literal or inline block comment starting at `start`,
    /// or `None` when `chars[start]` begins ordinary code.
    fn literal_or_comment_end(chars: &[char], start: usize) -> Option<usize> {
        let is_ident = |c: char| c.is_alphanumeric() || c == '_';
        let at = |index: usize| chars.get(index).copied();
        match chars[start] {
            '"' => Some(quoted_end(chars, start + 1)),
            '/' if at(start + 1) == Some('*') => Some(
                (start + 2..chars.len())
                    .find(|&index| chars[index] == '*' && at(index + 1) == Some('/'))
                    .map_or(chars.len(), |index| index + 2),
            ),
            '\'' => match at(start + 1) {
                // Escaped char literal such as '\n', '\'' or '\u{7f}'.
                Some('\\') => (start + 3..chars.len())
                    .find(|&index| chars[index] == '\'')
                    .map(|index| index + 1),
                // Plain char literal 'x'; anything else is a lifetime or label.
                Some(_) if at(start + 2) == Some('\'') => Some(start + 3),
                _ => None,
            },
            'r' => {
                // `r` opens a raw string only when it does not end an identifier;
                // a `b` prefix (`br"..."`) is allowed.
                let standalone = match start.checked_sub(1).map(|i| chars[i]) {
                    None => true,
                    Some('b') => !start.checked_sub(2).map(|i| chars[i]).is_some_and(is_ident),
                    Some(prev) => !is_ident(prev),
                };
                let hashes = chars[start + 1..].iter().take_while(|&&c| c == '#').count();
                if !standalone || at(start + 1 + hashes) != Some('"') {
                    return None;
                }
                let closes = |index: usize| {
                    chars[index] == '"'
                        && chars[index + 1..].iter().take_while(|&&c| c == '#').count() >= hashes
                };
                Some(
                    (start + 2 + hashes..chars.len())
                        .find(|&index| closes(index))
                        .map_or(chars.len(), |index| index + 1 + hashes),
                )
            }
            _ => None,
        }
    }

    let chars: Vec<char> = line.chars().collect();
    let mut output = String::with_capacity(line.len());
    let mut index = 0;
    while index < chars.len() {
        // `//`, `///` and `//!` comments run to the end of the line.
        if chars[index] == '/' && chars.get(index + 1) == Some(&'/') {
            break;
        }
        if let Some(end) = literal_or_comment_end(&chars, index) {
            output.push_str(&" ".repeat(end - index));
            index = end;
        } else {
            output.push(chars[index]);
            index += 1;
        }
    }
    output
}
310
/// Makes sure the `anyhow::Context` trait is in scope so rewritten `.context(...)`
/// calls resolve.
///
/// Returns `content` unchanged when the trait is already imported: directly, via
/// an `anyhow::{...}` group, or through `use anyhow::*`. Otherwise a `use` line is
/// inserted after any leading inner attributes (`#![...]`, including multi-line
/// ones) and `//!` module docs, which must precede items. If another `Context`
/// name already appears in the file, the trait is imported `as _` to avoid a name
/// clash. The file's line ending (`\n` or `\r\n`) and trailing newline are
/// preserved.
fn ensure_anyhow_context_import(content: &str) -> String {
    if imports_anyhow_context(content) {
        return content.to_string();
    }

    // An unrelated `Context` (e.g. `std::task::Context` or a local type) would clash
    // with a plain import; `as _` brings the trait methods in without binding a name.
    let import = if contains_identifier(content, "Context") {
        "use anyhow::Context as _;"
    } else {
        "use anyhow::Context;"
    };
    let newline = if content.contains("\r\n") { "\r\n" } else { "\n" };

    let mut lines: Vec<&str> = content.lines().collect();
    let insert_at = context_import_position(&lines);
    lines.insert(insert_at, import);
    let mut result = lines.join(newline);
    if content.ends_with('\n') {
        result.push_str(newline);
    }
    result
}

/// Whether `anyhow::Context` is already brought into scope by `content`.
fn imports_anyhow_context(content: &str) -> bool {
    content.contains("anyhow::Context")
        || content.contains("use anyhow::*")
        || content.match_indices("anyhow::{").any(|(pos, marker)| {
            content[pos + marker.len()..]
                .split('}')
                .next()
                .is_some_and(|group| {
                    group.split(',').any(|item| {
                        let item = item.trim();
                        item == "Context" || item.starts_with("Context as ")
                    })
                })
        })
}

/// Whether `word` occurs in `text` as a standalone identifier.
fn contains_identifier(text: &str, word: &str) -> bool {
    let is_ident = |c: char| c.is_alphanumeric() || c == '_';
    text.match_indices(word).any(|(pos, _)| {
        let before = text[..pos].chars().next_back();
        let after = text[pos + word.len()..].chars().next();
        !before.is_some_and(is_ident) && !after.is_some_and(is_ident)
    })
}

/// Index of the first line after the file's leading inner attributes and `//!`
/// docs. Blank lines and ordinary `//` comments are scanned past but do not move
/// the insertion point.
fn context_import_position(lines: &[&str]) -> usize {
    let mut insert_at = 0;
    let mut open_brackets = 0isize;
    for (index, line) in lines.iter().enumerate() {
        let trimmed = line.trim();
        if open_brackets > 0 || trimmed.starts_with("#![") {
            // Track `[`/`]` so a multi-line `#![...]` is skipped as a whole.
            open_brackets += trimmed.matches('[').count() as isize;
            open_brackets -= trimmed.matches(']').count() as isize;
            insert_at = index + 1;
        } else if trimmed.starts_with("//!") {
            insert_at = index + 1;
        } else if !trimmed.is_empty() && !trimmed.starts_with("//") {
            break;
        }
    }
    insert_at
}
328
/// Line span of one function found by `find_function_ranges`.
#[derive(Debug)]
struct FunctionRange {
    /// Function name, used in generated `.context(...)` messages.
    name: String,
    /// 1-based line on which the function's `fn` keyword was found.
    start_line: usize,
    /// 1-based, inclusive last line of the function body.
    end_line: usize,
    /// Whether the signature was judged to return an anyhow `Result`.
    returns_anyhow_result: bool,
}
336
337fn find_function_ranges(content: &str) -> Vec<FunctionRange> {
338 let lines: Vec<&str> = content.lines().collect();
339 let has_anyhow_result_alias =
340 content.contains("use anyhow::Result") || content.contains("use anyhow::{Result");
341 let mut ranges = Vec::new();
342 let mut index = 0;
343 while index < lines.len() {
344 let line = lines[index];
345 if !line.contains("fn ") {
346 index += 1;
347 continue;
348 }
349
350 let mut signature = line.to_string();
351 let start_line = index + 1;
352 let mut open_line = index;
353 while !signature.contains('{') && open_line + 1 < lines.len() {
354 open_line += 1;
355 signature.push(' ');
356 signature.push_str(lines[open_line]);
357 }
358
359 if !signature.contains('{') {
360 index += 1;
361 continue;
362 }
363
364 let Some(name) = function_name(&signature) else {
365 index += 1;
366 continue;
367 };
368
369 let mut depth = 0isize;
370 let mut end_line = open_line + 1;
371 for (body_index, body_line) in lines.iter().enumerate().skip(open_line) {
372 depth += body_line.matches('{').count() as isize;
373 depth -= body_line.matches('}').count() as isize;
374 end_line = body_index + 1;
375 if depth == 0 {
376 break;
377 }
378 }
379
380 let returns_anyhow_result = signature.contains("-> anyhow::Result")
381 || (has_anyhow_result_alias && signature.contains("-> Result<"));
382 ranges.push(FunctionRange {
383 name,
384 start_line,
385 end_line,
386 returns_anyhow_result,
387 });
388 index = end_line;
389 }
390 ranges
391}
392
/// Extracts the name of the first function declared in `signature`: the
/// identifier after a standalone `fn ` keyword.
///
/// Identifiers that merely end in `fn` (e.g. `callback_fn`) are not mistaken for
/// the keyword. The `r#` prefix of a raw identifier is dropped (`fn r#match`
/// yields `match`). Returns `None` when there is no `fn` keyword or no name
/// follows it.
fn function_name(signature: &str) -> Option<String> {
    let is_ident = |c: char| c.is_alphanumeric() || c == '_';
    let keyword_end = signature
        .match_indices("fn ")
        .find(|&(pos, _)| !signature[..pos].chars().next_back().is_some_and(is_ident))
        .map(|(pos, keyword)| pos + keyword.len())?;
    let rest = signature[keyword_end..].trim_start();
    let rest = rest.strip_prefix("r#").unwrap_or(rest);
    let name: String = rest.chars().take_while(|&c| is_ident(c)).collect();
    (!name.is_empty()).then_some(name)
}
404
405fn collect_rust_files(root: &Path, target: Option<&Path>) -> anyhow::Result<Vec<PathBuf>> {
406 let scan_root = target
407 .map(|path| {
408 if path.is_absolute() {
409 path.to_path_buf()
410 } else {
411 root.join(path)
412 }
413 })
414 .unwrap_or_else(|| root.to_path_buf());
415 if !scan_root.starts_with(root) {
416 anyhow::bail!("hardening target is outside root: {}", scan_root.display());
417 }
418
419 if scan_root.is_file() {
420 return Ok(if scan_root.extension().is_some_and(|ext| ext == "rs") {
421 vec![scan_root]
422 } else {
423 Vec::new()
424 });
425 }
426
427 let mut files = Vec::new();
428 for result in ignore::WalkBuilder::new(scan_root)
429 .hidden(false)
430 .filter_entry(|entry| {
431 let name = entry.file_name().to_string_lossy();
432 !matches!(
433 name.as_ref(),
434 "target" | ".git" | ".worktrees" | ".mdx-rust"
435 )
436 })
437 .build()
438 {
439 let entry = result?;
440 let path = entry.path();
441 if path.is_file() && path.extension().is_some_and(|ext| ext == "rs") {
442 files.push(path.to_path_buf());
443 }
444 }
445 files.sort();
446 Ok(files)
447}
448
/// Expresses `path` relative to `root`, returning `path` unchanged when it does
/// not live under `root`.
fn relative_path(root: &Path, path: &Path) -> PathBuf {
    match path.strip_prefix(root) {
        Ok(stripped) => stripped.to_path_buf(),
        Err(_) => path.to_path_buf(),
    }
}
452
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Writes `source` as `src/lib.rs` of a fresh temporary directory and runs
    /// the analysis over it with a ten-file cap.
    fn analyze_lib_source(source: &str) -> HardeningAnalysis {
        let workspace = tempdir().unwrap();
        let src_dir = workspace.path().join("src");
        std::fs::create_dir_all(&src_dir).unwrap();
        std::fs::write(src_dir.join("lib.rs"), source).unwrap();
        let config = HardeningAnalyzeConfig {
            target: None,
            max_files: 10,
        };
        analyze_hardening(workspace.path(), config).unwrap()
    }

    #[test]
    fn hardening_rewrites_unwrap_in_anyhow_result_function() {
        let analysis = analyze_lib_source(
            r#"pub fn load() -> anyhow::Result<String> {
    let value = std::fs::read_to_string("config.toml").unwrap();
    Ok(value)
}
"#,
        );

        let [change] = analysis.changes.as_slice() else {
            panic!("expected exactly one change, got {:?}", analysis.changes);
        };
        assert!(change.new_content.contains("use anyhow::Context;"));
        assert!(change
            .new_content
            .contains(".context(\"load failed instead of panicking\")?"));
        assert!(syn::parse_file(&change.new_content).is_ok());
    }

    #[test]
    fn hardening_does_not_rewrite_plain_result_without_anyhow_alias() {
        let analysis = analyze_lib_source(
            r#"pub fn load() -> Result<String, std::io::Error> {
    let value = std::fs::read_to_string("config.toml").unwrap();
    Ok(value)
}
"#,
        );

        assert!(analysis.changes.is_empty());
    }

    #[test]
    fn hardening_does_not_flag_patterns_inside_strings_or_comments() {
        let analysis = analyze_lib_source(
            r#"pub fn describe() -> &'static str {
    // Command::new("ignored")
    "unsafe std::process::Command env::var("
}
"#,
        );

        assert!(analysis.findings.is_empty(), "{:?}", analysis.findings);
    }
}