use std::fs;
use std::path::{Path, PathBuf};
fn main() {
    // Re-run when this script changes; each scanned `.rs` file registers itself in `scan`.
    println!("cargo:rerun-if-changed=build.rs");

    let src_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("src");
    let mut violations = Vec::new();
    scan(&src_dir, Path::new(""), &mut violations);

    // Surface each finding as a Cargo warning so it is visible in normal build output.
    violations
        .iter()
        .for_each(|v| println!("cargo:warning=progress-debug-formatter: {v}"));

    if violations.is_empty() {
        return;
    }

    // Any finding fails the build with a summary and remediation hints.
    eprintln!(
        "\nerror(progress-debug-formatter): {} `progress!` callsite(s) use Display formatter (`{{}}` or `.display()`) on potentially-attacker-controllable bytes.\n\
         Fix: rewrite to `{{:?}}` Debug formatter — escapes `\\n` so attacker-controlled strings cannot forge a sentinel-shaped stderr line.\n\
         If the argument is provably safe (compile-time literal, integer count), add a `// SAFE: <reason>` comment on the prior non-blank line.\n",
        violations.len(),
    );
    std::process::exit(1);
}
/// Recursively walks `dir`, linting every `.rs` file it finds.
///
/// `rel` is the path of `dir` relative to the scan root and is only used to
/// build readable violation messages. Unreadable directories and files, and
/// entries whose names are not valid UTF-8, are silently skipped. Each linted
/// file is registered with Cargo via `rerun-if-changed`.
fn scan(dir: &Path, rel: &Path, violations: &mut Vec<String>) {
    let entries = match fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    for entry in entries.filter_map(Result::ok) {
        let full = entry.path();
        let Some(file_name) = full.file_name().and_then(|n| n.to_str()) else {
            continue;
        };
        let child_rel = rel.join(file_name);
        if full.is_dir() {
            scan(&full, &child_rel, violations);
        } else if full.extension().and_then(|ext| ext.to_str()) == Some("rs") {
            println!("cargo:rerun-if-changed={}", full.display());
            if let Ok(text) = fs::read_to_string(&full) {
                scan_file(&text, &child_rel, violations);
            }
        }
    }
}
/// Checks every `progress!` invocation in one source file and records
/// violations of the Debug-formatter policy.
///
/// `content` is the file text and `rel` its path relative to `src/` (used only
/// in messages). For each line whose comment-stripped text contains
/// `progress!(`, the whole invocation (possibly spanning several lines) is
/// collected and checked for:
/// - a first argument that is not a string literal,
/// - a Display-style placeholder in the format string,
/// - a `.display()` call inside the invocation.
///
/// A `// SAFE:` comment on the nearest preceding non-blank, non-attribute line
/// suppresses all checks for that invocation. Only the first `progress!(` on a
/// line is examined.
fn scan_file(content: &str, rel: &Path, violations: &mut Vec<String>) {
    let lines: Vec<&str> = content.lines().collect();
    let mut i = 0usize;
    while let Some(line) = lines.get(i).copied() {
        let line_no = i + 1;
        let stripped = strip_line_comment(line);
        let Some(start) = stripped.find("progress!(") else {
            i = i.saturating_add(1);
            continue;
        };
        let (callsite, last) = collect_callsite(&lines, i, start);
        // Resume scanning on the line after the one where this invocation ends.
        let next = last.saturating_add(1);
        if has_safe_marker_above(&lines, i) {
            i = next;
            continue;
        }
        let Some(fmt) = extract_format_string(&callsite) else {
            violations.push(format!(
                "{}:{} — `progress!` called without a string-literal format arg; first arg must be a literal",
                rel.display(),
                line_no,
            ));
            i = next;
            continue;
        };
        if let Some(spec) = find_display_spec(&fmt) {
            violations.push(format!(
                "{}:{} — `progress!` uses Display formatter {spec:?} in format string {fmt:?}; rewrite to `{{:?}}` or add `// SAFE: <reason>`",
                rel.display(),
                line_no,
            ));
        }
        if callsite.contains(".display()") {
            violations.push(format!(
                "{}:{} — `progress!` callsite contains `.display()`; use `{{:?}}` on the Path directly (Debug on Path escapes newlines; Display does not)",
                rel.display(),
                line_no,
            ));
        }
        i = next;
    }
}

/// Collects the text of a `progress!(...)` invocation that begins at byte
/// offset `start` of the comment-stripped line `lines[first]`.
///
/// Returns the invocation text (lines joined with `\n`, ending at the matching
/// `)`) and the index of the line on which it ends. Parentheses inside string
/// literals (string state carries across lines) and inside simple char
/// literals such as `'('` or `'\''` are ignored, so they cannot desynchronize
/// the depth count. If no matching `)` exists, the text runs to end of input
/// and the returned index is `lines.len()`.
fn collect_callsite(lines: &[&str], first: usize, start: usize) -> (String, usize) {
    let mut callsite = String::new();
    let mut depth: i32 = 0;
    let mut in_string = false;
    let mut j = first;
    while let Some(raw) = lines.get(j).copied() {
        let stripped = strip_line_comment(raw);
        // On the first line, begin at the macro name so text before it is ignored.
        let text = if j == first {
            stripped.get(start..).unwrap_or("")
        } else {
            stripped.as_str()
        };
        let bytes = text.as_bytes();
        let mut close = None;
        let mut k = 0usize;
        while let Some(&b) = bytes.get(k) {
            if in_string {
                if b == b'\\' {
                    // Skip the escaped byte (a trailing `\` continues onto the next line).
                    k += 2;
                    continue;
                }
                if b == b'"' {
                    in_string = false;
                }
            } else if b == b'"' {
                in_string = true;
            } else if b == b'\'' {
                // Escaped char literal, e.g. '\'' or '\\'.
                if bytes.get(k + 1) == Some(&b'\\') && bytes.get(k + 3) == Some(&b'\'') {
                    k += 4;
                    continue;
                }
                // Simple char literal, e.g. '(' or ')'; lifetimes like 'a do not match.
                if bytes.get(k + 2) == Some(&b'\'') {
                    k += 3;
                    continue;
                }
            } else if b == b'(' {
                depth += 1;
            } else if b == b')' {
                depth -= 1;
                if depth == 0 {
                    close = Some(k + 1);
                    break;
                }
            }
            k += 1;
        }
        if let Some(end) = close {
            // `end` follows an ASCII `)`, so it is a valid char boundary.
            callsite.push_str(text.get(..end).unwrap_or(text));
            callsite.push('\n');
            return (callsite, j);
        }
        callsite.push_str(text);
        callsite.push('\n');
        j += 1;
    }
    (callsite, j)
}
/// Returns `line` with any trailing `//` comment removed.
///
/// A `//` inside a double-quoted string literal (tracking `\` escapes) is
/// kept. The char literals `'"'` and `'\"'` are recognised so their quote does
/// not open a phantom string. State does not carry across lines, and raw
/// strings and block comments are not understood.
///
/// The result is always a prefix of `line`, so non-ASCII text is preserved
/// intact.
fn strip_line_comment(line: &str) -> String {
    let bytes = line.as_bytes();
    let mut i = 0;
    let mut in_string = false;
    while let Some(&b) = bytes.get(i) {
        if in_string {
            match b {
                // Skip the escaped byte so `\"` does not terminate the string.
                b'\\' => i += 2,
                b'"' => {
                    in_string = false;
                    i += 1;
                }
                _ => i += 1,
            }
            continue;
        }
        match b {
            b'"' => {
                in_string = true;
                i += 1;
            }
            // Char literals `'"'` and `'\"'` hold a quote that must not start a string.
            b'\'' if bytes.get(i + 1..i + 3) == Some(&b"\"'"[..]) => i += 3,
            b'\'' if bytes.get(i + 1..i + 4) == Some(&b"\\\"'"[..]) => i += 4,
            b'/' if bytes.get(i + 1) == Some(&b'/') => break,
            _ => i += 1,
        }
    }
    // Every stop position is on an ASCII byte or at the end, so this slices on a
    // char boundary. The fallback covers a lone trailing `\` inside a string
    // (index one past the end), where the whole line is kept.
    line.get(..i).unwrap_or(line).to_string()
}
/// Reports whether the line at `line_idx` carries a `// SAFE:` opt-out.
///
/// Walks upward from `line_idx`, skipping blank lines and lines whose first
/// non-whitespace character is `#` (attributes). Returns `true` only if the
/// first remaining line contains `// SAFE:` or `//SAFE:`. Returns `false`
/// when there is no such line or `line_idx` is out of range.
fn has_safe_marker_above(lines: &[&str], line_idx: usize) -> bool {
    let Some(preceding) = lines.get(..line_idx) else {
        return false;
    };
    preceding
        .iter()
        .rev()
        .map(|l| l.trim_start())
        .find(|t| !t.is_empty() && !t.starts_with('#'))
        .is_some_and(|t| t.contains("// SAFE:") || t.contains("//SAFE:"))
}
/// Returns the raw (still-escaped) contents of the string literal that is the
/// first argument of the first `progress!(` in `callsite`.
///
/// Ordinary literals `"..."` are returned with their `\` escapes left
/// unprocessed. Raw literals `r"..."` and `r#"..."#` are accepted too. Returns
/// `None` when `progress!(` is absent, the first argument is not a string
/// literal, or the literal is unterminated.
fn extract_format_string(callsite: &str) -> Option<String> {
    const OPEN: &str = "progress!(";
    let start = callsite.find(OPEN)?.saturating_add(OPEN.len());
    let rest = callsite
        .get(start..)?
        .trim_start_matches(|c: char| c.is_ascii_whitespace());

    // Raw string: `r`, N hashes, `"`, body, `"`, N hashes (no escapes inside).
    if let Some(after_r) = rest.strip_prefix('r') {
        let hashes = after_r.len() - after_r.trim_start_matches('#').len();
        let body = after_r.get(hashes..)?.strip_prefix('"')?;
        let terminator = format!("\"{}", "#".repeat(hashes));
        let end = body.find(&terminator)?;
        return body.get(..end).map(str::to_string);
    }

    let body = rest.strip_prefix('"')?;
    let bytes = body.as_bytes();
    let mut i = 0;
    while let Some(&b) = bytes.get(i) {
        match b {
            // Skip the escaped byte so `\"` does not end the literal.
            b'\\' => i += 2,
            // An ASCII `"` is always a char boundary, so the slice is valid UTF-8.
            b'"' => return body.get(..i).map(str::to_string),
            _ => i += 1,
        }
    }
    None
}
/// Returns the first placeholder in `fmt` that would format its argument with
/// `Display`, or `None` if every placeholder is considered safe.
///
/// `fmt` is the raw text of a format-string literal, and `{{` escapes are
/// skipped. A placeholder `{...}` is considered safe when either:
/// - it contains `?` (Debug, e.g. `{:?}`, `{x:#?}`), or
/// - it has a non-empty spec after `:` made only of `b o x X e E . 0-9`
///   (radix/exponent formats, width, precision).
///
/// An empty spec such as `{:}` or `{name:}` is plain Display and is reported.
/// Returns `None` on an unterminated `{` (rustc rejects such strings anyway).
///
/// NOTE(review): a width/precision-only spec like `{:5}` or `{:.3}` is still
/// accepted even though it is Display when the argument is a string. This keeps
/// float formatting such as `{:.1}` passing; tighten if that is not intended.
fn find_display_spec(fmt: &str) -> Option<String> {
    let bytes = fmt.as_bytes();
    let mut i = 0;
    while let Some(&b) = bytes.get(i) {
        if b != b'{' {
            i += 1;
            continue;
        }
        if bytes.get(i + 1) == Some(&b'{') {
            // `{{` is a literal brace, not a placeholder.
            i += 2;
            continue;
        }
        let close = i + 1 + fmt.get(i + 1..)?.find('}')?;
        let spec = fmt.get(i..=close)?;
        let is_debug = spec.contains('?');
        let is_numeric = spec
            .find(':')
            .and_then(|colon| spec.get(colon + 1..spec.len() - 1))
            .is_some_and(|body| {
                // An empty body (`{:}`) is plain Display and must not be allowed.
                !body.is_empty()
                    && body.chars().all(|c| {
                        matches!(c, 'b' | 'o' | 'x' | 'X' | 'e' | 'E' | '.' | '0'..='9')
                    })
            });
        if !is_debug && !is_numeric {
            return Some(spec.to_string());
        }
        i = close + 1;
    }
    None
}