1#![deny(missing_docs, unsafe_code)]
48
49use std::{
50 fs,
51 io::{self, BufRead as _, Write as _},
52 iter::FromIterator,
53 path, thread,
54};
55
56pub mod license;
57
58#[derive(Clone)]
60pub struct Header<C: HeaderChecker> {
61 checker: C,
63 header: String,
65}
66
67impl<C: HeaderChecker> Header<C> {
68 pub fn new(checker: C, header: String) -> Self {
74 Self { checker, header }
75 }
76
77 pub fn header_present(&self, input: &mut impl io::Read) -> io::Result<bool> {
79 self.checker.check(input)
80 }
81
82 pub fn add_header_if_missing(&self, p: &path::Path) -> Result<bool, AddHeaderError> {
86 let err_mapper = |e| AddHeaderError::IoError(p.to_path_buf(), e);
87 let contents = fs::read_to_string(p).map_err(err_mapper)?;
88 if self
89 .header_present(&mut contents.as_bytes())
90 .map_err(err_mapper)?
91 {
92 return Ok(false);
93 }
94 let mut effective_header = header_delimiters(p)
95 .ok_or_else(|| AddHeaderError::UnrecognizedExtension(p.to_path_buf()))
96 .map(|d| wrap_header(&self.header, d))?;
97 let mut after_header = contents.as_str();
98 if let Some((first_line, rest)) = contents.split_once('\n') {
100 if MAGIC_FIRST_LINES.iter().any(|l| first_line.contains(l)) {
101 let mut first_line = first_line.to_string();
102 first_line.push('\n');
103 effective_header.insert_str(0, &first_line);
104 after_header = rest;
105 }
106 }
107 let mut f = fs::OpenOptions::new()
109 .write(true)
110 .truncate(true)
111 .open(p)
112 .map_err(err_mapper)?;
113 f.write_all(effective_header.as_bytes())
114 .map_err(err_mapper)?;
115 f.write_all("\n".as_bytes()).map_err(err_mapper)?;
117 f.write_all(after_header.as_bytes()).map_err(err_mapper)?;
118 Ok(true)
119 }
120
121 pub fn delete_header_if_present(&self, p: &path::Path) -> Result<bool, DeleteHeaderError> {
125 let err_mapper = |e| DeleteHeaderError::IoError(p.to_path_buf(), e);
126 let contents = fs::read_to_string(p).map_err(err_mapper)?;
127 if !self
128 .header_present(&mut contents.as_bytes())
129 .map_err(err_mapper)?
130 {
131 return Ok(false);
132 }
133 let mut effective_header = header_delimiters(p)
134 .ok_or_else(|| DeleteHeaderError::UnrecognizedExtension(p.to_path_buf()))
135 .map(|d| wrap_header(&self.header, d))?;
136 effective_header.push('\n');
138
139 if !contents.contains(&effective_header) {
142 return Ok(false);
143 }
144
145 let remainder = contents.replacen(&effective_header, "", 1);
148 let mut f = fs::OpenOptions::new()
150 .write(true)
151 .truncate(true)
152 .open(p)
153 .map_err(err_mapper)?;
154 f.write_all(remainder.as_bytes()).map_err(err_mapper)?;
155 Ok(true)
156 }
157}
158
159#[derive(Debug, thiserror::Error)]
161pub enum AddHeaderError {
162 #[error("I/O error at {0:?}: {1}")]
164 IoError(path::PathBuf, io::Error),
165 #[error("Unknown file extension: {0:?}")]
167 UnrecognizedExtension(path::PathBuf),
168}
169
170#[derive(Debug, thiserror::Error)]
172pub enum DeleteHeaderError {
173 #[error("I/O error at {0:?}: {1}")]
175 IoError(path::PathBuf, io::Error),
176 #[error("Unknown file extension: {0:?}")]
178 UnrecognizedExtension(path::PathBuf),
179}
180
/// Strategy for deciding whether some input already contains a header.
///
/// Implementors must be `Send + Clone`; `check_headers_recursively` clones the checker (via
/// its `Header`) into each worker thread.
pub trait HeaderChecker: Send + Clone {
    /// Reads from `file` and returns `Ok(true)` if the header is considered present.
    ///
    /// # Errors
    ///
    /// Returns any I/O error encountered while reading `file`.
    fn check(&self, file: &mut impl io::Read) -> io::Result<bool>;
}
188
/// A [`HeaderChecker`] that looks for a fixed pattern in the first lines of its input.
#[derive(Clone)]
pub struct SingleLineChecker {
    /// Substring that must appear within a single line for the header to count as present.
    pattern: String,
    /// Maximum number of lines, counted from the start of the input, that are examined.
    max_lines: usize,
}
197
198impl SingleLineChecker {
199 pub fn new(pattern: String, max_lines: usize) -> Self {
201 Self { pattern, max_lines }
202 }
203}
204
205impl HeaderChecker for SingleLineChecker {
206 fn check(&self, input: &mut impl io::Read) -> io::Result<bool> {
207 let mut reader = io::BufReader::new(input);
208 let mut lines_read = 0;
209 let mut line = String::new();
211 while lines_read < self.max_lines {
213 line.clear();
214 let bytes = reader.read_line(&mut line)?;
215 if bytes == 0 {
216 return Ok(false);
218 }
219 lines_read += 1;
220 if line.contains(&self.pattern) {
221 return Ok(true);
222 }
223 }
224 Ok(false)
225 }
226}
227
/// Why a file was reported by a header check.
#[derive(Copy, Clone)]
enum CheckStatus {
    /// The file was read but the header was not found.
    HeaderNotFound,
    /// Reading the file failed with `InvalidData`, so it is treated as binary.
    BinaryFile,
}
236
237#[derive(Clone)]
239struct FileResult {
240 path: path::PathBuf,
241 status: CheckStatus,
242}
243
/// Aggregated outcome of a recursive header check.
#[derive(Clone, Default, PartialEq, Debug)]
pub struct FileResults {
    /// Files in which the header was not found.
    pub no_header_files: Vec<path::PathBuf>,
    /// Files that were treated as binary.
    pub binary_files: Vec<path::PathBuf>,
}
252
253impl FileResults {
254 pub fn has_failure(&self) -> bool {
256 !self.no_header_files.is_empty() || !self.binary_files.is_empty()
257 }
258}
259
260impl FromIterator<FileResult> for FileResults {
261 fn from_iter<I>(iter: I) -> FileResults
262 where
263 I: IntoIterator<Item = FileResult>,
264 {
265 let mut results = FileResults::default();
266 for result in iter {
267 match result.status {
268 CheckStatus::HeaderNotFound => results.no_header_files.push(result.path),
269 CheckStatus::BinaryFile => results.binary_files.push(result.path),
270 }
271 }
272 results
273 }
274}
275
276pub fn check_headers_recursively(
286 root: &path::Path,
287 path_predicate: impl Fn(&path::Path) -> bool,
288 header: Header<impl HeaderChecker + 'static>,
289 num_threads: usize,
290) -> Result<FileResults, CheckHeadersRecursivelyError> {
291 let (path_tx, path_rx) = crossbeam::channel::unbounded::<path::PathBuf>();
292 let (result_tx, result_rx) = crossbeam::channel::unbounded();
293 let handles = (0..num_threads)
295 .map(|_| {
296 let path_rx = path_rx.clone();
297 let result_tx = result_tx.clone();
298 let header = header.clone();
299 thread::spawn(move || {
300 for p in path_rx {
301 match fs::File::open(&p).and_then(|mut f| header.header_present(&mut f)) {
302 Ok(header_present) => {
303 if header_present {
304 } else {
306 let res = FileResult {
307 path: p,
308 status: CheckStatus::HeaderNotFound,
309 };
310 result_tx.send(Ok(res)).unwrap();
311 }
312 }
313 Err(e) if e.kind() == io::ErrorKind::InvalidData => {
314 let res = FileResult {
315 path: p,
316 status: CheckStatus::BinaryFile,
317 };
318 result_tx.send(Ok(res)).unwrap();
319 }
320 Err(e) => result_tx
321 .send(Err(CheckHeadersRecursivelyError::IoError(p, e)))
322 .unwrap(),
323 }
324 }
325 })
327 })
328 .collect::<Vec<thread::JoinHandle<()>>>();
329 drop(result_tx);
331 find_files(root, path_predicate, path_tx)?;
332 let res: FileResults = result_rx.into_iter().collect::<Result<_, _>>()?;
333 for h in handles {
334 h.join().unwrap();
335 }
336 Ok(res)
337}
338
339#[derive(Debug, thiserror::Error)]
341pub enum CheckHeadersRecursivelyError {
342 #[error("I/O error at {0:?}: {1}")]
344 IoError(path::PathBuf, io::Error),
345 #[error("Walkdir error: {0}")]
347 WalkdirError(#[from] walkdir::Error),
348}
349
350pub fn add_headers_recursively(
355 root: &path::Path,
356 path_predicate: impl Fn(&path::Path) -> bool,
357 header: Header<impl HeaderChecker>,
358) -> Result<Vec<path::PathBuf>, AddHeadersRecursivelyError> {
359 recursive_optional_operation(root, path_predicate, |p| {
361 header.add_header_if_missing(p).map_err(|e| e.into())
362 })
363}
364
365#[derive(Debug, thiserror::Error)]
367pub enum AddHeadersRecursivelyError {
368 #[error("I/O error at {0:?}: {1}")]
370 IoError(path::PathBuf, io::Error),
371 #[error("Walkdir error: {0}")]
373 WalkdirError(#[from] walkdir::Error),
374 #[error("Unknown file extension: {0:?}")]
376 UnrecognizedExtension(path::PathBuf),
377}
378
379impl From<AddHeaderError> for AddHeadersRecursivelyError {
380 fn from(value: AddHeaderError) -> Self {
381 match value {
382 AddHeaderError::IoError(p, e) => Self::IoError(p, e),
383 AddHeaderError::UnrecognizedExtension(p) => Self::UnrecognizedExtension(p),
384 }
385 }
386}
387
388pub fn delete_headers_recursively(
393 root: &path::Path,
394 path_predicate: impl Fn(&path::Path) -> bool,
395 header: Header<impl HeaderChecker>,
396) -> Result<Vec<path::PathBuf>, DeleteHeadersRecursivelyError> {
397 recursive_optional_operation(root, path_predicate, |p| {
398 header.delete_header_if_present(p).map_err(|e| e.into())
399 })
400}
401
402#[derive(Debug, thiserror::Error)]
404pub enum DeleteHeadersRecursivelyError {
405 #[error("I/O error at {0:?}: {1}")]
407 IoError(path::PathBuf, io::Error),
408 #[error("Walkdir error: {0}")]
410 WalkdirError(#[from] walkdir::Error),
411 #[error("Unknown file extension: {0:?}")]
413 UnrecognizedExtension(path::PathBuf),
414}
415
416impl From<DeleteHeaderError> for DeleteHeadersRecursivelyError {
417 fn from(value: DeleteHeaderError) -> Self {
418 match value {
419 DeleteHeaderError::IoError(p, e) => Self::IoError(p, e),
420 DeleteHeaderError::UnrecognizedExtension(p) => Self::UnrecognizedExtension(p),
421 }
422 }
423}
424
425fn find_files(
428 root: &path::Path,
429 path_predicate: impl Fn(&path::Path) -> bool,
430 dest: crossbeam::channel::Sender<path::PathBuf>,
431) -> Result<(), walkdir::Error> {
432 for r in walkdir::WalkDir::new(root).into_iter() {
433 let entry = r?;
434 if entry.path().is_dir() || !path_predicate(entry.path()) {
435 continue;
436 }
437 dest.send(entry.into_path()).unwrap()
438 }
439 Ok(())
440}
441
442fn wrap_header(orig_header: &str, delim: HeaderDelimiters) -> String {
447 let mut out = String::new();
448 if !delim.first_line.is_empty() {
449 out.push_str(delim.first_line);
450 out.push('\n');
451 }
452 for line in orig_header.split('\n') {
454 out.push_str(delim.content_line_prefix);
455 out.push_str(line);
456 out.truncate(out.trim_end_matches([' ', '\t']).len());
460 out.push('\n');
461 }
462 if !delim.last_line.is_empty() {
463 out.push_str(delim.last_line);
464 out.push('\n');
465 }
466 out
467}
468
469fn header_delimiters(p: &path::Path) -> Option<HeaderDelimiters> {
472 match p
473 .extension()
474 .and_then(|os_str| os_str.to_str())
476 .unwrap_or("")
477 {
478 "c" | "h" | "gv" | "java" | "scala" | "kt" | "kts" => Some(("/*", " * ", " */")),
479 "js" | "mjs" | "cjs" | "jsx" | "tsx" | "css" | "scss" | "sass" | "ts" => {
480 Some(("/**", " * ", " */"))
481 }
482 "cc" | "cpp" | "cs" | "go" | "hcl" | "hh" | "hpp" | "m" | "mm" | "proto" | "rs"
483 | "swift" | "dart" | "groovy" | "v" | "sv" => Some(("", "// ", "")),
484 "py" | "sh" | "yaml" | "yml" | "dockerfile" | "rb" | "gemfile" | "tcl" | "tf" | "bzl"
485 | "pl" | "pp" | "build" => Some(("", "# ", "")),
486 "el" | "lisp" => Some(("", ";; ", "")),
487 "erl" => Some(("", "% ", "")),
488 "hs" | "lua" | "sql" | "sdl" => Some(("", "-- ", "")),
489 "html" | "xml" | "vue" | "wxi" | "wxl" | "wxs" => Some(("<!--", " ", "-->")),
490 "php" => Some(("", "// ", "")),
491 "ml" | "mli" | "mll" | "mly" => Some(("(**", " ", "*)")),
492 _ => match p
494 .file_name()
495 .and_then(|os_str| os_str.to_str())
496 .unwrap_or("")
497 {
498 "Dockerfile" => Some(("", "# ", "")),
499 _ => None,
500 },
501 }
502 .map(
503 |(first_line, content_line_prefix, last_line)| HeaderDelimiters {
504 first_line,
505 content_line_prefix,
506 last_line,
507 },
508 )
509}
510
/// Comment syntax used to wrap a header for one kind of file.
#[derive(Clone, Copy)]
struct HeaderDelimiters {
    /// Line emitted before the header text; skipped when empty.
    first_line: &'static str,
    /// Prefix placed in front of every line of the header text.
    content_line_prefix: &'static str,
    /// Line emitted after the header text; skipped when empty.
    last_line: &'static str,
}
521
/// Substrings that mark a first line which must stay first in its file; a header is inserted
/// after such a line rather than before it.
const MAGIC_FIRST_LINES: [&str; 8] = [
    "#!",
    "<?xml",
    "<!doctype",
    "# encoding:",
    "# frozen_string_literal:",
    "<?php",
    "# escape",
    "# syntax",
];
533
534fn recursive_optional_operation<E>(
539 root: &path::Path,
540 path_predicate: impl Fn(&path::Path) -> bool,
541 operation: impl Fn(&path::Path) -> Result<bool, E>,
542) -> Result<Vec<path::PathBuf>, E>
543where
544 E: From<walkdir::Error>,
545{
546 let (path_tx, path_rx) = crossbeam::channel::unbounded::<path::PathBuf>();
547 find_files(root, path_predicate, path_tx)?;
548 path_rx
549 .into_iter()
550 .filter_map(|p| match operation(&p) {
552 Ok(operation_applied) => {
553 if operation_applied {
554 Some(Ok(p))
555 } else {
556 None
557 }
558 }
559 Err(e) => Some(Err(e)),
560 })
561 .collect::<Result<Vec<_>, _>>()
562}