#![deny(missing_docs, unsafe_code)]
use std::{
fs,
io::{self, BufRead as _, Write as _},
iter::FromIterator,
path, thread,
};
pub mod license;
/// A file header together with the logic used to detect it.
///
/// `checker` decides whether a file already carries the header; `header` is the
/// raw header text, which is wrapped in the comment delimiters that fit each
/// file before it is written.
#[derive(Clone)]
pub struct Header<C: HeaderChecker> {
// Detection strategy consulted by `header_present`.
checker: C,
// Raw header text, without any comment delimiters.
header: String,
}
impl<C: HeaderChecker> Header<C> {
    /// Creates a header from a detection strategy and the raw header text.
    ///
    /// `header` is plain text; comment delimiters suited to each file are added
    /// when the header is written.
    pub fn new(checker: C, header: String) -> Self {
        Self { checker, header }
    }

    /// Returns `true` if the checker detects the header in `input`.
    ///
    /// # Errors
    ///
    /// Returns whatever I/O error the checker reports while reading `input`.
    pub fn header_present(&self, input: &mut impl io::Read) -> io::Result<bool> {
        self.checker.check(input)
    }

    /// Adds the header to the file at `p` unless the checker already detects it.
    ///
    /// The header is wrapped in the comment delimiters chosen for `p` and
    /// followed by one blank line. If the first line of the file contains one
    /// of the magic markers (for example `#!`), that line stays first and the
    /// header is inserted right after it. This also holds when that line is the
    /// whole file and has no trailing newline.
    ///
    /// Returns `Ok(true)` if the file was rewritten and `Ok(false)` if the
    /// header was already present.
    ///
    /// # Errors
    ///
    /// * [`AddHeaderError::IoError`] if reading, checking or writing the file fails.
    /// * [`AddHeaderError::UnrecognizedExtension`] if no comment style is known for `p`.
    pub fn add_header_if_missing(&self, p: &path::Path) -> Result<bool, AddHeaderError> {
        let err_mapper = |e| AddHeaderError::IoError(p.to_path_buf(), e);
        let contents = fs::read_to_string(p).map_err(err_mapper)?;
        if self
            .header_present(&mut contents.as_bytes())
            .map_err(err_mapper)?
        {
            return Ok(false);
        }
        let wrapped = header_delimiters(p)
            .ok_or_else(|| AddHeaderError::UnrecognizedExtension(p.to_path_buf()))
            .map(|d| wrap_header(&self.header, d))?;

        // A file without any newline is treated as a single first line, so that a
        // lone magic line (e.g. a shebang) is still kept above the header.
        let (first_line, rest) = contents
            .split_once('\n')
            .unwrap_or((contents.as_str(), ""));
        let keep_first_line = MAGIC_FIRST_LINES
            .iter()
            .any(|magic| first_line.contains(*magic));

        let mut output = String::with_capacity(contents.len() + wrapped.len() + 2);
        let after_header = if keep_first_line {
            output.push_str(first_line);
            output.push('\n');
            rest
        } else {
            contents.as_str()
        };
        output.push_str(&wrapped);
        // Blank separator line between the header and the rest of the file.
        output.push('\n');
        output.push_str(after_header);

        Self::overwrite(p, &output).map_err(err_mapper)?;
        Ok(true)
    }

    /// Removes the header from the file at `p` if it is present.
    ///
    /// The text removed is the header wrapped for `p` plus the blank line that
    /// [`Header::add_header_if_missing`] writes after it; only the first
    /// occurrence is removed.
    ///
    /// Returns `Ok(true)` if the file was rewritten. Returns `Ok(false)` if the
    /// checker does not detect the header, or if it does but the exact wrapped
    /// text is not found in the file.
    ///
    /// # Errors
    ///
    /// * [`DeleteHeaderError::IoError`] if reading, checking or writing the file fails.
    /// * [`DeleteHeaderError::UnrecognizedExtension`] if no comment style is known for `p`.
    pub fn delete_header_if_present(&self, p: &path::Path) -> Result<bool, DeleteHeaderError> {
        let err_mapper = |e| DeleteHeaderError::IoError(p.to_path_buf(), e);
        let contents = fs::read_to_string(p).map_err(err_mapper)?;
        if !self
            .header_present(&mut contents.as_bytes())
            .map_err(err_mapper)?
        {
            return Ok(false);
        }
        let mut effective_header = header_delimiters(p)
            .ok_or_else(|| DeleteHeaderError::UnrecognizedExtension(p.to_path_buf()))
            .map(|d| wrap_header(&self.header, d))?;
        // Also strip the blank separator line written after the header.
        effective_header.push('\n');

        // Locate the header once and cut it out, instead of scanning the file twice.
        let start = match contents.find(effective_header.as_str()) {
            Some(index) => index,
            None => return Ok(false),
        };
        let end = start + effective_header.len();
        let mut remainder = String::with_capacity(contents.len() - effective_header.len());
        remainder.push_str(&contents[..start]);
        remainder.push_str(&contents[end..]);

        Self::overwrite(p, &remainder).map_err(err_mapper)?;
        Ok(true)
    }

    /// Replaces the contents of the already existing file at `p` with `contents`.
    fn overwrite(p: &path::Path, contents: &str) -> io::Result<()> {
        let mut f = fs::OpenOptions::new().write(true).truncate(true).open(p)?;
        f.write_all(contents.as_bytes())
    }
}
/// Errors that can occur while adding a header to a single file.
#[derive(Debug, thiserror::Error)]
pub enum AddHeaderError {
/// An I/O operation on the file at the given path failed.
#[error("I/O error at {0:?}: {1}")]
IoError(path::PathBuf, io::Error),
/// No comment delimiters are known for the file at the given path.
#[error("Unknown file extension: {0:?}")]
UnrecognizedExtension(path::PathBuf),
}
/// Errors that can occur while deleting a header from a single file.
#[derive(Debug, thiserror::Error)]
pub enum DeleteHeaderError {
/// An I/O operation on the file at the given path failed.
#[error("I/O error at {0:?}: {1}")]
IoError(path::PathBuf, io::Error),
/// No comment delimiters are known for the file at the given path.
#[error("Unknown file extension: {0:?}")]
UnrecognizedExtension(path::PathBuf),
}
/// Strategy for deciding whether some input already contains a header.
///
/// Checkers must be `Send + Clone`: `check_headers_recursively` clones the
/// header (and with it the checker) once per worker thread and moves the clone
/// into that thread.
pub trait HeaderChecker: Send + Clone {
/// Returns `Ok(true)` if the header is found in `file`.
///
/// Errors are returned as `io::Error`. `check_headers_recursively` reports a
/// file whose check fails with `io::ErrorKind::InvalidData` as a binary file
/// rather than as an error.
fn check(&self, file: &mut impl io::Read) -> io::Result<bool>;
}
/// A [`HeaderChecker`] that looks for a fixed substring on a single line near
/// the start of the input.
#[derive(Clone)]
pub struct SingleLineChecker {
// Substring that must appear within one line for the header to count as present.
pattern: String,
// Maximum number of lines read from the start of the input before giving up.
max_lines: usize,
}
impl SingleLineChecker {
/// Creates a checker that searches for `pattern` within the first
/// `max_lines` lines of its input.
pub fn new(pattern: String, max_lines: usize) -> Self {
Self { pattern, max_lines }
}
}
impl HeaderChecker for SingleLineChecker {
    /// Scans at most `max_lines` lines of `input` and reports whether any of
    /// them contains `pattern`.
    ///
    /// Scanning stops at the first matching line or at end of input, whichever
    /// comes first. Read errors from `input` are returned unchanged.
    fn check(&self, input: &mut impl io::Read) -> io::Result<bool> {
        let mut reader = io::BufReader::new(input);
        // One buffer reused for every line.
        let mut current = String::new();
        for _ in 0..self.max_lines {
            current.clear();
            if reader.read_line(&mut current)? == 0 {
                // End of input before the pattern showed up.
                break;
            }
            if current.contains(self.pattern.as_str()) {
                return Ok(true);
            }
        }
        Ok(false)
    }
}
/// Why a file is reported by `check_headers_recursively`.
#[derive(Copy, Clone)]
enum CheckStatus {
/// The checker ran successfully and did not find the header.
HeaderNotFound,
/// The check failed with `io::ErrorKind::InvalidData`.
BinaryFile,
}
/// Outcome for one reported file, as sent from a worker thread to the collector.
#[derive(Clone)]
struct FileResult {
// Path of the file that was checked.
path: path::PathBuf,
// Reason the file is being reported.
status: CheckStatus,
}
/// Aggregated outcome of a recursive header check.
#[derive(Clone, Default, PartialEq, Debug)]
pub struct FileResults {
/// Files in which the checker did not find the header.
pub no_header_files: Vec<path::PathBuf>,
/// Files whose check failed with `io::ErrorKind::InvalidData`.
pub binary_files: Vec<path::PathBuf>,
}
impl FileResults {
    /// Returns `true` if at least one file is listed as lacking the header or
    /// as binary.
    pub fn has_failure(&self) -> bool {
        [&self.no_header_files, &self.binary_files]
            .iter()
            .any(|files| !files.is_empty())
    }
}
impl FromIterator<FileResult> for FileResults {
fn from_iter<I>(iter: I) -> FileResults
where
I: IntoIterator<Item = FileResult>,
{
let mut results = FileResults::default();
for result in iter {
match result.status {
CheckStatus::HeaderNotFound => results.no_header_files.push(result.path),
CheckStatus::BinaryFile => results.binary_files.push(result.path),
}
}
results
}
}
/// Checks every selected file under `root` for the header, using worker threads.
///
/// The tree below `root` is walked recursively. Every non-directory path for
/// which `path_predicate` returns `true` is opened and passed to
/// `header.header_present` on one of the worker threads.
///
/// `num_threads` is the number of worker threads to spawn. A value of `0` is
/// treated as `1`; with no workers the queued paths would never be checked and
/// the function would report an empty, successful result.
///
/// Returns the files in which the header was not found, plus the files whose
/// check failed with `io::ErrorKind::InvalidData` (reported as binary files).
///
/// # Errors
///
/// * [`CheckHeadersRecursivelyError::WalkdirError`] if walking the tree fails.
/// * [`CheckHeadersRecursivelyError::IoError`] if opening or checking a file
///   fails with any error other than `InvalidData`.
///
/// # Panics
///
/// Panics if a worker thread panicked.
pub fn check_headers_recursively(
    root: &path::Path,
    path_predicate: impl Fn(&path::Path) -> bool,
    header: Header<impl HeaderChecker + 'static>,
    num_threads: usize,
) -> Result<FileResults, CheckHeadersRecursivelyError> {
    let (path_tx, path_rx) = crossbeam::channel::unbounded::<path::PathBuf>();
    let (result_tx, result_rx) = crossbeam::channel::unbounded();
    // At least one worker is needed, otherwise no path is ever checked.
    let worker_count = num_threads.max(1);
    let handles = (0..worker_count)
        .map(|_| {
            let path_rx = path_rx.clone();
            let result_tx = result_tx.clone();
            let header = header.clone();
            thread::spawn(move || {
                for p in path_rx {
                    let checked =
                        fs::File::open(&p).and_then(|mut f| header.header_present(&mut f));
                    let outcome = match checked {
                        // Header found: nothing to report for this file.
                        Ok(true) => continue,
                        Ok(false) => Ok(FileResult {
                            path: p,
                            status: CheckStatus::HeaderNotFound,
                        }),
                        Err(e) if e.kind() == io::ErrorKind::InvalidData => Ok(FileResult {
                            path: p,
                            status: CheckStatus::BinaryFile,
                        }),
                        Err(e) => Err(CheckHeadersRecursivelyError::IoError(p, e)),
                    };
                    if result_tx.send(outcome).is_err() {
                        // The receiver is gone, so the caller has already returned
                        // with an error. Stop quietly instead of panicking.
                        break;
                    }
                }
            })
        })
        .collect::<Vec<thread::JoinHandle<()>>>();
    // Only the workers keep result senders, so the result channel closes once
    // all of them have finished.
    drop(result_tx);
    // `path_tx` is moved into `find_files` and dropped when it returns, which
    // lets the workers' receive loops end.
    find_files(root, path_predicate, path_tx)?;
    let res: FileResults = result_rx.into_iter().collect::<Result<_, _>>()?;
    for h in handles {
        h.join().expect("header check worker thread panicked");
    }
    Ok(res)
}
/// Errors that can occur while recursively checking files for a header.
#[derive(Debug, thiserror::Error)]
pub enum CheckHeadersRecursivelyError {
/// Opening or checking the file at the given path failed.
#[error("I/O error at {0:?}: {1}")]
IoError(path::PathBuf, io::Error),
/// Walking the directory tree failed.
#[error("Walkdir error: {0}")]
WalkdirError(#[from] walkdir::Error),
}
pub fn add_headers_recursively(
root: &path::Path,
path_predicate: impl Fn(&path::Path) -> bool,
header: Header<impl HeaderChecker>,
) -> Result<Vec<path::PathBuf>, AddHeadersRecursivelyError> {
recursive_optional_operation(root, path_predicate, |p| {
header.add_header_if_missing(p).map_err(|e| e.into())
})
}
/// Errors that can occur while recursively adding headers.
#[derive(Debug, thiserror::Error)]
pub enum AddHeadersRecursivelyError {
/// An I/O operation on the file at the given path failed.
#[error("I/O error at {0:?}: {1}")]
IoError(path::PathBuf, io::Error),
/// Walking the directory tree failed.
#[error("Walkdir error: {0}")]
WalkdirError(#[from] walkdir::Error),
/// No comment delimiters are known for the file at the given path.
#[error("Unknown file extension: {0:?}")]
UnrecognizedExtension(path::PathBuf),
}
impl From<AddHeaderError> for AddHeadersRecursivelyError {
fn from(value: AddHeaderError) -> Self {
match value {
AddHeaderError::IoError(p, e) => Self::IoError(p, e),
AddHeaderError::UnrecognizedExtension(p) => Self::UnrecognizedExtension(p),
}
}
}
pub fn delete_headers_recursively(
root: &path::Path,
path_predicate: impl Fn(&path::Path) -> bool,
header: Header<impl HeaderChecker>,
) -> Result<Vec<path::PathBuf>, DeleteHeadersRecursivelyError> {
recursive_optional_operation(root, path_predicate, |p| {
header.delete_header_if_present(p).map_err(|e| e.into())
})
}
/// Errors that can occur while recursively deleting headers.
#[derive(Debug, thiserror::Error)]
pub enum DeleteHeadersRecursivelyError {
/// An I/O operation on the file at the given path failed.
#[error("I/O error at {0:?}: {1}")]
IoError(path::PathBuf, io::Error),
/// Walking the directory tree failed.
#[error("Walkdir error: {0}")]
WalkdirError(#[from] walkdir::Error),
/// No comment delimiters are known for the file at the given path.
#[error("Unknown file extension: {0:?}")]
UnrecognizedExtension(path::PathBuf),
}
impl From<DeleteHeaderError> for DeleteHeadersRecursivelyError {
fn from(value: DeleteHeaderError) -> Self {
match value {
DeleteHeaderError::IoError(p, e) => Self::IoError(p, e),
DeleteHeaderError::UnrecognizedExtension(p) => Self::UnrecognizedExtension(p),
}
}
}
/// Walks `root` recursively and sends every selected path to `dest`.
///
/// A path is selected when it is not a directory and `path_predicate` returns
/// `true` for it; the predicate is not consulted for directories. `dest` is
/// taken by value and therefore dropped when this function returns.
///
/// Returns the first directory walk error, if any.
///
/// # Panics
///
/// Panics if sending on `dest` fails.
fn find_files(
    root: &path::Path,
    path_predicate: impl Fn(&path::Path) -> bool,
    dest: crossbeam::channel::Sender<path::PathBuf>,
) -> Result<(), walkdir::Error> {
    for visited in walkdir::WalkDir::new(root) {
        let candidate = visited?.into_path();
        let selected = !candidate.is_dir() && path_predicate(candidate.as_path());
        if selected {
            dest.send(candidate).unwrap();
        }
    }
    Ok(())
}
fn wrap_header(orig_header: &str, delim: HeaderDelimiters) -> String {
let mut out = String::new();
if !delim.first_line.is_empty() {
out.push_str(delim.first_line);
out.push('\n');
}
for line in orig_header.split('\n') {
out.push_str(delim.content_line_prefix);
out.push_str(line);
out.truncate(out.trim_end_matches([' ', '\t']).len());
out.push('\n');
}
if !delim.last_line.is_empty() {
out.push_str(delim.last_line);
out.push('\n');
}
out
}
/// Chooses the comment delimiters for the file at `p`.
///
/// The decision is based on the file extension, compared case-sensitively.
/// When the extension is missing or not listed, a file named exactly
/// `Dockerfile` is still recognised. Returns `None` for everything else.
fn header_delimiters(p: &path::Path) -> Option<HeaderDelimiters> {
    let extension = p.extension().and_then(|ext| ext.to_str()).unwrap_or("");
    let (first_line, content_line_prefix, last_line) = match extension {
        "c" | "h" | "gv" | "java" | "scala" | "kt" | "kts" => ("/*", " * ", " */"),
        "js" | "mjs" | "cjs" | "jsx" | "tsx" | "css" | "scss" | "sass" | "ts" => {
            ("/**", " * ", " */")
        }
        "cc" | "cpp" | "cs" | "go" | "hcl" | "hh" | "hpp" | "m" | "mm" | "proto" | "rs"
        | "swift" | "dart" | "groovy" | "v" | "sv" | "php" => ("", "// ", ""),
        "py" | "sh" | "yaml" | "yml" | "dockerfile" | "rb" | "gemfile" | "tcl" | "tf" | "bzl"
        | "pl" | "pp" | "build" => ("", "# ", ""),
        "el" | "lisp" => ("", ";; ", ""),
        "erl" => ("", "% ", ""),
        "hs" | "lua" | "sql" | "sdl" => ("", "-- ", ""),
        "html" | "xml" | "vue" | "wxi" | "wxl" | "wxs" => ("<!--", " ", "-->"),
        "ml" | "mli" | "mll" | "mly" => ("(**", " ", "*)"),
        _ => {
            // Extension unknown: fall back to well-known whole file names.
            let file_name = p.file_name().and_then(|name| name.to_str()).unwrap_or("");
            if file_name != "Dockerfile" {
                return None;
            }
            ("", "# ", "")
        }
    };
    Some(HeaderDelimiters {
        first_line,
        content_line_prefix,
        last_line,
    })
}
/// Comment syntax used by `wrap_header` to turn raw header text into a comment.
#[derive(Clone, Copy)]
struct HeaderDelimiters {
// Line emitted before the header text; skipped when empty.
first_line: &'static str,
// Prefix put in front of every line of header text.
content_line_prefix: &'static str,
// Line emitted after the header text; skipped when empty.
last_line: &'static str,
}
// Markers identifying a first line that must stay at the top of a file.
// `add_header_if_missing` inserts the header after the first line when that
// line contains any of these substrings. The comparison is case-sensitive.
const MAGIC_FIRST_LINES: [&str; 8] = [
"#!", "<?xml", "<!doctype", "# encoding:", "# frozen_string_literal:", "<?php", "# escape", "# syntax", ];
fn recursive_optional_operation<E>(
root: &path::Path,
path_predicate: impl Fn(&path::Path) -> bool,
operation: impl Fn(&path::Path) -> Result<bool, E>,
) -> Result<Vec<path::PathBuf>, E>
where
E: From<walkdir::Error>,
{
let (path_tx, path_rx) = crossbeam::channel::unbounded::<path::PathBuf>();
find_files(root, path_predicate, path_tx)?;
path_rx
.into_iter()
.filter_map(|p| match operation(&p) {
Ok(operation_applied) => {
if operation_applied {
Some(Ok(p))
} else {
None
}
}
Err(e) => Some(Err(e)),
})
.collect::<Result<Vec<_>, _>>()
}