use anyhow::{Context, Result};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use syn::{
parse_file,
spanned::Spanned,
visit::{self, Visit},
ExprMethodCall, Lit,
};
use tracing::{info, warn};
use crate::types::Receipt;
use crate::verifier::verify;
pub struct AutoRemediator {
pub receipt_path: PathBuf,
pub source_dir: PathBuf,
}
impl AutoRemediator {
pub fn new(receipt_path: impl AsRef<Path>, source_dir: impl AsRef<Path>) -> Self {
Self {
receipt_path: receipt_path.as_ref().to_path_buf(),
source_dir: source_dir.as_ref().to_path_buf(),
}
}
pub fn remediate(&self) -> Result<String> {
let receipt_text = fs::read_to_string(&self.receipt_path)
.with_context(|| format!("reading receipt from {:?}", self.receipt_path))?;
let receipt: Receipt =
serde_json::from_str(&receipt_text).with_context(|| "parsing receipt JSON")?;
let verdict = verify(&receipt);
if verdict.accepted {
return Ok("Receipt is valid. No remediation needed.".to_string());
}
info!("Diagnosed failure: {}", verdict.reason);
let rs_files = self.find_rs_files()?;
for file in rs_files {
let content = fs::read_to_string(&file)?;
if !content.contains("ChainAssembler") && !content.contains("build_event") {
continue;
}
let ast = match parse_file(&content) {
Ok(ast) => ast,
Err(e) => {
warn!("Warning: could not parse {:?}: {}", file, e);
continue;
}
};
let mut visitor = ProvenanceVisitor::new(&content);
visitor.visit_file(&ast);
if let Some(fix) = self.find_fix(&file, &content, &visitor, &receipt) {
return Ok(fix);
}
}
Ok("Could not find a deterministic fix for the diagnosed failure.".to_string())
}
fn find_rs_files(&self) -> Result<Vec<PathBuf>> {
let mut files = Vec::new();
self.collect_rs_files(&self.source_dir, &mut files)?;
Ok(files)
}
fn collect_rs_files(&self, dir: &Path, files: &mut Vec<PathBuf>) -> Result<()> {
if dir.is_dir() {
for entry in fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
let name = path.file_name().and_then(|s| s.to_str()).unwrap_or("");
if name != "target" && name != ".git" && !name.starts_with(".") {
self.collect_rs_files(&path, files)?;
}
} else if path.extension().and_then(|s| s.to_str()) == Some("rs") {
files.push(path);
}
}
}
Ok(())
}
fn find_fix(
&self,
file: &Path,
content: &str,
visitor: &ProvenanceVisitor,
receipt: &Receipt,
) -> Option<String> {
let mut appended_types = Vec::new();
for append in &visitor.appends {
if let Some(event_var) = &append.event_var {
if let Some(build) = visitor.builds.get(event_var) {
appended_types.push(build.event_type.clone());
}
}
}
let receipt_types: Vec<String> = receipt
.events
.iter()
.map(|e| e.event_type.clone())
.collect();
for (var, build) in &visitor.builds {
if !appended_types.contains(&build.event_type) {
if !receipt_types.contains(&build.event_type) {
return self.generate_append_patch(file, content, build, var).ok();
}
}
}
None
}
fn generate_append_patch(
&self,
file: &Path,
content: &str,
build: &BuildInfo,
var_name: &str,
) -> Result<String> {
let lines: Vec<&str> = content.lines().collect();
let build_line_idx = build.line - 1;
if build_line_idx + 1 < lines.len() {
let next_line = lines[build_line_idx + 1];
let indent = next_line
.chars()
.take_while(|c| c.is_whitespace())
.collect::<String>();
let mut new_content_lines = lines.clone();
let suffix = if content.contains("Result<") || content.contains("fn main() -> Result") {
"?"
} else {
".expect(\"auto-remediated append\")"
};
let new_line = format!("{}asm.append({}){};", indent, var_name, suffix);
new_content_lines.insert(build_line_idx + 1, &new_line);
let fixed = new_content_lines.join("\n");
self.make_git_diff(file, content, &fixed)
} else {
anyhow::bail!("Build line at end of file")
}
}
fn make_git_diff(&self, file: &Path, old: &str, new: &str) -> Result<String> {
let temp_old = tempfile::NamedTempFile::new()?;
let temp_new = tempfile::NamedTempFile::new()?;
fs::write(temp_old.path(), old)?;
fs::write(temp_new.path(), new)?;
let output = Command::new("git")
.args([
"diff",
"--no-index",
"--",
temp_old.path().to_str().unwrap(),
temp_new.path().to_str().unwrap(),
])
.output()?;
let mut diff = String::from_utf8_lossy(&output.stdout).to_string();
let old_path = temp_old.path().to_str().unwrap();
let new_path = temp_new.path().to_str().unwrap();
let final_path = file.to_str().unwrap();
diff = diff.replace(old_path, &format!("a/{}", final_path));
diff = diff.replace(new_path, &format!("b/{}", final_path));
Ok(diff)
}
}
struct BuildInfo {
line: usize,
event_type: String,
}
struct AppendInfo {
line: usize,
event_var: Option<String>,
}
struct ProvenanceVisitor<'a> {
_content: &'a str,
builds: HashMap<String, BuildInfo>,
appends: Vec<AppendInfo>,
}
impl<'a> ProvenanceVisitor<'a> {
fn new(content: &'a str) -> Self {
Self {
_content: content,
builds: HashMap::new(),
appends: Vec::new(),
}
}
}
impl<'ast, 'a> Visit<'ast> for ProvenanceVisitor<'a> {
fn visit_stmt(&mut self, i: &'ast syn::Stmt) {
if let syn::Stmt::Local(ref local) = i {
if let Some(ref init) = local.init {
let expr = match &*init.expr {
syn::Expr::Call(c) => Some(c),
syn::Expr::Try(t) => {
if let syn::Expr::Call(c) = &*t.expr {
Some(c)
} else {
None
}
}
syn::Expr::MethodCall(m) => {
if m.method == "expect" || m.method == "unwrap" {
if let syn::Expr::Call(c) = &*m.receiver {
Some(c)
} else {
None
}
} else {
None
}
}
_ => None,
};
if let Some(call) = expr {
if let syn::Expr::Path(ref p) = *call.func {
if p.path
.segments
.last()
.map(|s| s.ident == "build_event")
.unwrap_or(false)
{
if let Some(syn::Expr::Lit(syn::ExprLit {
lit: Lit::Str(ref s),
..
})) = call.args.first()
{
if let syn::Pat::Ident(ref id) = local.pat {
self.builds.insert(
id.ident.to_string(),
BuildInfo {
line: call.span().start().line,
event_type: s.value(),
},
);
}
}
}
}
}
}
}
visit::visit_stmt(self, i);
}
fn visit_expr_method_call(&mut self, i: &'ast ExprMethodCall) {
if i.method == "append" {
let mut event_var = None;
if let Some(arg) = i.args.first() {
if let syn::Expr::Path(ref p) = arg {
event_var = p.path.get_ident().map(|id| id.to_string());
}
}
self.appends.push(AppendInfo {
line: i.span().start().line,
event_var,
});
}
visit::visit_expr_method_call(self, i);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_visitor_collects_builds() {
let code = r#"
let e0 = build_event("seeded", vec![], b"p", &mut c)?;
asm.append(e0)?;
"#;
let ast = parse_file(code).unwrap();
let mut v = ProvenanceVisitor::new(code);
v.visit_file(&ast);
assert!(v.builds.contains_key("e0"));
assert_eq!(v.builds["e0"].event_type, "seeded");
assert_eq!(v.appends.len(), 1);
assert_eq!(v.appends[0].event_var.as_deref(), Some("e0"));
}
}