file_header/
lib.rs

1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Tools for checking for, adding, or deleting headers (e.g. licenses) in files.
16//!
17//! See the [license::spdx] module for more on using licenses as headers.
18//!
19//! # Examples
20//!
21//! Checking for a header:
22//!
23//! ```
24//! // Copyright 2023 Google LLC.
25//! // SPDX-License-Identifier: Apache-2.0
26//! use file_header::*;
27//! use std::path::Path;
28//!
29//! let checker = SingleLineChecker::new("Foo License".to_string(), 10);
30//! let header = Header::new(checker, "Foo License\nmore license text".to_string());
31//!
32//! match check_headers_recursively(
33//!     Path::new("/some/dir"),
34//!     // check every file -- see `globset` crate for path patterns
35//!     |p| true,
36//!     header,
37//!     // check with 4 threads
38//!     4
39//! ) {
40//!     Ok(fr) => { println!("files without the header: {:?}", fr.no_header_files) }
41//!     Err(e) => { println!("got an error: {:?}", e) }
42//! }
43//!
44//! ```
45//!
46
47#![deny(missing_docs, unsafe_code)]
48
49use std::{
50    fs,
51    io::{self, BufRead as _, Write as _},
52    iter::FromIterator,
53    path, thread,
54};
55
56pub mod license;
57
58/// A file header to check for, or add to, files.
59#[derive(Clone)]
60pub struct Header<C: HeaderChecker> {
61    /// A checker to determine if the desired header is already present.
62    checker: C,
63    /// The header text to add, without comments or other filetype-specific framing.
64    header: String,
65}
66
67impl<C: HeaderChecker> Header<C> {
68    /// Construct a new `Header` with the `checker` used to determine if the header is already
69    /// present, and the plain `header` text to add.
70    ///
71    /// `header` does not need to have applicable comment syntax, etc, as that will be added for
72    /// each file type encountered.
73    pub fn new(checker: C, header: String) -> Self {
74        Self { checker, header }
75    }
76
77    /// Return `true` if the file has the desired header, false otherwise.
78    pub fn header_present(&self, input: &mut impl io::Read) -> io::Result<bool> {
79        self.checker.check(input)
80    }
81
82    /// Add the header, with appropriate formatting for the type of file indicated by `p`'s
83    /// extension, if the header is not already present.
84    /// Returns `true` if the header was added.
85    pub fn add_header_if_missing(&self, p: &path::Path) -> Result<bool, AddHeaderError> {
86        let err_mapper = |e| AddHeaderError::IoError(p.to_path_buf(), e);
87        let contents = fs::read_to_string(p).map_err(err_mapper)?;
88        if self
89            .header_present(&mut contents.as_bytes())
90            .map_err(err_mapper)?
91        {
92            return Ok(false);
93        }
94        let mut effective_header = header_delimiters(p)
95            .ok_or_else(|| AddHeaderError::UnrecognizedExtension(p.to_path_buf()))
96            .map(|d| wrap_header(&self.header, d))?;
97        let mut after_header = contents.as_str();
98        // check for a magic first line and if present, add the license after the first line
99        if let Some((first_line, rest)) = contents.split_once('\n') {
100            if MAGIC_FIRST_LINES.iter().any(|l| first_line.contains(l)) {
101                let mut first_line = first_line.to_string();
102                first_line.push('\n');
103                effective_header.insert_str(0, &first_line);
104                after_header = rest;
105            }
106        }
107        // write the license
108        let mut f = fs::OpenOptions::new()
109            .write(true)
110            .truncate(true)
111            .open(p)
112            .map_err(err_mapper)?;
113        f.write_all(effective_header.as_bytes())
114            .map_err(err_mapper)?;
115        // newline to separate the header from previous contents
116        f.write_all("\n".as_bytes()).map_err(err_mapper)?;
117        f.write_all(after_header.as_bytes()).map_err(err_mapper)?;
118        Ok(true)
119    }
120
121    /// Delete the header, with appropriate formatting for the type of file indicated by `p`'s
122    /// extension, if the header is already present.
123    /// Returns `true` if the header was deleted.
124    pub fn delete_header_if_present(&self, p: &path::Path) -> Result<bool, DeleteHeaderError> {
125        let err_mapper = |e| DeleteHeaderError::IoError(p.to_path_buf(), e);
126        let contents = fs::read_to_string(p).map_err(err_mapper)?;
127        if !self
128            .header_present(&mut contents.as_bytes())
129            .map_err(err_mapper)?
130        {
131            return Ok(false);
132        }
133        let mut effective_header = header_delimiters(p)
134            .ok_or_else(|| DeleteHeaderError::UnrecognizedExtension(p.to_path_buf()))
135            .map(|d| wrap_header(&self.header, d))?;
136        // include the newline separator appended by add_header_if_missing()
137        effective_header.push('\n');
138
139        // the checker is conservative: it may look for only a substring of the license, but
140        // deletion will only have an effect if the entire wrapped header is present.
141        if !contents.contains(&effective_header) {
142            return Ok(false);
143        }
144
145        // remove the first copy of the header to avoid touching the license text in a string
146        // literal, etc.
147        let remainder = contents.replacen(&effective_header, "", 1);
148        // write the remainder
149        let mut f = fs::OpenOptions::new()
150            .write(true)
151            .truncate(true)
152            .open(p)
153            .map_err(err_mapper)?;
154        f.write_all(remainder.as_bytes()).map_err(err_mapper)?;
155        Ok(true)
156    }
157}
158
159/// Errors that can occur when adding a header
160#[derive(Debug, thiserror::Error)]
161pub enum AddHeaderError {
162    /// IO error while adding the header to the path
163    #[error("I/O error at {0:?}: {1}")]
164    IoError(path::PathBuf, io::Error),
165    /// The file at the path had an unrecognized extension
166    #[error("Unknown file extension: {0:?}")]
167    UnrecognizedExtension(path::PathBuf),
168}
169
170/// Errors that can occur when deleting a header
171#[derive(Debug, thiserror::Error)]
172pub enum DeleteHeaderError {
173    /// IO error while deleting the header from the path
174    #[error("I/O error at {0:?}: {1}")]
175    IoError(path::PathBuf, io::Error),
176    /// The file at the path had an unrecognized extension
177    #[error("Unknown file extension: {0:?}")]
178    UnrecognizedExtension(path::PathBuf),
179}
180
/// Strategy for detecting a header (a license, author attribution, ...) in a file.
///
/// Implementations are meant to be driven through [`Header`] rather than invoked directly.
/// The `Send + Clone` bounds allow a checker to be copied into each worker thread of
/// [`check_headers_recursively`].
pub trait HeaderChecker: Send + Clone {
    /// Read from `file` and report whether the desired header is there.
    ///
    /// Returns `Ok(true)` when the header is found and `Ok(false)` when it is not.
    fn check(&self, file: &mut impl io::Read) -> io::Result<bool>;
}
188
/// A checker that searches the leading lines of a file for a fixed piece of text.
#[derive(Clone)]
pub struct SingleLineChecker {
    /// Text that must occur, as a substring, within a single one of the inspected lines
    pattern: String,
    /// How many lines from the top of the file are inspected
    max_lines: usize,
}
197
198impl SingleLineChecker {
199    /// Construct a `SingleLineChecker` that looks for `pattern` in the first `max_lines` of a file.
200    pub fn new(pattern: String, max_lines: usize) -> Self {
201        Self { pattern, max_lines }
202    }
203}
204
205impl HeaderChecker for SingleLineChecker {
206    fn check(&self, input: &mut impl io::Read) -> io::Result<bool> {
207        let mut reader = io::BufReader::new(input);
208        let mut lines_read = 0;
209        // reuse buffer to minimize allocation
210        let mut line = String::new();
211        // only read the first bit of the file
212        while lines_read < self.max_lines {
213            line.clear();
214            let bytes = reader.read_line(&mut line)?;
215            if bytes == 0 {
216                // EOF
217                return Ok(false);
218            }
219            lines_read += 1;
220            if line.contains(&self.pattern) {
221                return Ok(true);
222            }
223        }
224        Ok(false)
225    }
226}
227
/// Why a checked file is being reported.
#[derive(Copy, Clone)]
enum CheckStatus {
    /// The checker did not find the header in the file
    HeaderNotFound,
    /// The file looks like binary data rather than text
    BinaryFile,
}
236
237/// The output of checking a single file
238#[derive(Clone)]
239struct FileResult {
240    path: path::PathBuf,
241    status: CheckStatus,
242}
243
/// Everything reported by a recursive header check over a directory tree, grouped by reason.
#[derive(Clone, Default, PartialEq, Debug)]
pub struct FileResults {
    /// Files in which the header was not found
    pub no_header_files: Vec<path::PathBuf>,
    /// Files that appeared to be binary rather than UTF-8 text
    pub binary_files: Vec<path::PathBuf>,
}
252
253impl FileResults {
254    /// Returns `true` if any files scanned did not have a header
255    pub fn has_failure(&self) -> bool {
256        !self.no_header_files.is_empty() || !self.binary_files.is_empty()
257    }
258}
259
260impl FromIterator<FileResult> for FileResults {
261    fn from_iter<I>(iter: I) -> FileResults
262    where
263        I: IntoIterator<Item = FileResult>,
264    {
265        let mut results = FileResults::default();
266        for result in iter {
267            match result.status {
268                CheckStatus::HeaderNotFound => results.no_header_files.push(result.path),
269                CheckStatus::BinaryFile => results.binary_files.push(result.path),
270            }
271        }
272        results
273    }
274}
275
276/// Recursively check for `header` in every file in `root` that matches `path_predicate`.
277///
278/// Checking the discovered files is parallelized across `num_threads` threads.
279///
280/// [`globset`](https://crates.io/crates/globset) is a useful crate for ignoring unwanted files in
281/// `path_predicate`.
282///
283/// Returns a [`FileResults`] object containing the paths without headers detected, and the paths
284/// which were not UTF-8 text.
285pub fn check_headers_recursively(
286    root: &path::Path,
287    path_predicate: impl Fn(&path::Path) -> bool,
288    header: Header<impl HeaderChecker + 'static>,
289    num_threads: usize,
290) -> Result<FileResults, CheckHeadersRecursivelyError> {
291    let (path_tx, path_rx) = crossbeam::channel::unbounded::<path::PathBuf>();
292    let (result_tx, result_rx) = crossbeam::channel::unbounded();
293    // spawn a few threads to handle files in parallel
294    let handles = (0..num_threads)
295        .map(|_| {
296            let path_rx = path_rx.clone();
297            let result_tx = result_tx.clone();
298            let header = header.clone();
299            thread::spawn(move || {
300                for p in path_rx {
301                    match fs::File::open(&p).and_then(|mut f| header.header_present(&mut f)) {
302                        Ok(header_present) => {
303                            if header_present {
304                                // no op
305                            } else {
306                                let res = FileResult {
307                                    path: p,
308                                    status: CheckStatus::HeaderNotFound,
309                                };
310                                result_tx.send(Ok(res)).unwrap();
311                            }
312                        }
313                        Err(e) if e.kind() == io::ErrorKind::InvalidData => {
314                            let res = FileResult {
315                                path: p,
316                                status: CheckStatus::BinaryFile,
317                            };
318                            result_tx.send(Ok(res)).unwrap();
319                        }
320                        Err(e) => result_tx
321                            .send(Err(CheckHeadersRecursivelyError::IoError(p, e)))
322                            .unwrap(),
323                    }
324                }
325                // no more files
326            })
327        })
328        .collect::<Vec<thread::JoinHandle<()>>>();
329    // make sure result channel closes when threads complete
330    drop(result_tx);
331    find_files(root, path_predicate, path_tx)?;
332    let res: FileResults = result_rx.into_iter().collect::<Result<_, _>>()?;
333    for h in handles {
334        h.join().unwrap();
335    }
336    Ok(res)
337}
338
339/// Errors that can occur when checking for headers recursively
340#[derive(Debug, thiserror::Error)]
341pub enum CheckHeadersRecursivelyError {
342    /// An I/O error occurred while checking the path
343    #[error("I/O error at {0:?}: {1}")]
344    IoError(path::PathBuf, io::Error),
345    /// `walkdir` could not navigate the directory structure
346    #[error("Walkdir error: {0}")]
347    WalkdirError(#[from] walkdir::Error),
348}
349
350/// Add the provided `header` to any file in `root` that matches `path_predicate` and that doesn't
351/// already have a header as determined by `checker`.
352///
353/// Returns a list of paths that had headers added.
354pub fn add_headers_recursively(
355    root: &path::Path,
356    path_predicate: impl Fn(&path::Path) -> bool,
357    header: Header<impl HeaderChecker>,
358) -> Result<Vec<path::PathBuf>, AddHeadersRecursivelyError> {
359    // likely no need for threading since adding headers is only done occasionally
360    recursive_optional_operation(root, path_predicate, |p| {
361        header.add_header_if_missing(p).map_err(|e| e.into())
362    })
363}
364
365/// Errors that can occur when adding a header recursively
366#[derive(Debug, thiserror::Error)]
367pub enum AddHeadersRecursivelyError {
368    /// An I/O error occurred while adding the header to the path
369    #[error("I/O error at {0:?}: {1}")]
370    IoError(path::PathBuf, io::Error),
371    /// `walkdir` could not navigate the directory structure
372    #[error("Walkdir error: {0}")]
373    WalkdirError(#[from] walkdir::Error),
374    /// A file with an unrecognized extension was encountered at the path
375    #[error("Unknown file extension: {0:?}")]
376    UnrecognizedExtension(path::PathBuf),
377}
378
379impl From<AddHeaderError> for AddHeadersRecursivelyError {
380    fn from(value: AddHeaderError) -> Self {
381        match value {
382            AddHeaderError::IoError(p, e) => Self::IoError(p, e),
383            AddHeaderError::UnrecognizedExtension(p) => Self::UnrecognizedExtension(p),
384        }
385    }
386}
387
388/// Delete the provided `header` from any file in `root` that matches `path_predicate` and that
389/// already has a header as determined by `header`'s checker.
390///
391/// Returns a list of paths that had headers removed.
392pub fn delete_headers_recursively(
393    root: &path::Path,
394    path_predicate: impl Fn(&path::Path) -> bool,
395    header: Header<impl HeaderChecker>,
396) -> Result<Vec<path::PathBuf>, DeleteHeadersRecursivelyError> {
397    recursive_optional_operation(root, path_predicate, |p| {
398        header.delete_header_if_present(p).map_err(|e| e.into())
399    })
400}
401
402/// Errors that can occur when adding a header recursively
403#[derive(Debug, thiserror::Error)]
404pub enum DeleteHeadersRecursivelyError {
405    /// An I/O error occurred while removing the header from the path
406    #[error("I/O error at {0:?}: {1}")]
407    IoError(path::PathBuf, io::Error),
408    /// `walkdir` could not navigate the directory structure
409    #[error("Walkdir error: {0}")]
410    WalkdirError(#[from] walkdir::Error),
411    /// A file with an unrecognized extension was encountered at the path
412    #[error("Unknown file extension: {0:?}")]
413    UnrecognizedExtension(path::PathBuf),
414}
415
416impl From<DeleteHeaderError> for DeleteHeadersRecursivelyError {
417    fn from(value: DeleteHeaderError) -> Self {
418        match value {
419            DeleteHeaderError::IoError(p, e) => Self::IoError(p, e),
420            DeleteHeaderError::UnrecognizedExtension(p) => Self::UnrecognizedExtension(p),
421        }
422    }
423}
424
425/// Find all files starting from `root` that do not match the globs in `ignore`, publishing the
426/// resulting paths into `dest`.
427fn find_files(
428    root: &path::Path,
429    path_predicate: impl Fn(&path::Path) -> bool,
430    dest: crossbeam::channel::Sender<path::PathBuf>,
431) -> Result<(), walkdir::Error> {
432    for r in walkdir::WalkDir::new(root).into_iter() {
433        let entry = r?;
434        if entry.path().is_dir() || !path_predicate(entry.path()) {
435            continue;
436        }
437        dest.send(entry.into_path()).unwrap()
438    }
439    Ok(())
440}
441
442/// Prepare a header for inclusion in a particular file syntax by wrapping it with
443/// comment characters as per the provided `delim`.
444///
445/// Trailing whitespace will be removed to avoid linters disliking the resulting text.
446fn wrap_header(orig_header: &str, delim: HeaderDelimiters) -> String {
447    let mut out = String::new();
448    if !delim.first_line.is_empty() {
449        out.push_str(delim.first_line);
450        out.push('\n');
451    }
452    // assumes header uses \n
453    for line in orig_header.split('\n') {
454        out.push_str(delim.content_line_prefix);
455        out.push_str(line);
456        // Remove any trailing whitespaces (excluding newlines) from `content_line_prefix + line`.
457        // For example, if `content_line_prefix` is `// ` and `line` is empty, the resulting string
458        // should be truncated to `//`.
459        out.truncate(out.trim_end_matches([' ', '\t']).len());
460        out.push('\n');
461    }
462    if !delim.last_line.is_empty() {
463        out.push_str(delim.last_line);
464        out.push('\n');
465    }
466    out
467}
468
469/// Returns the header prefix line, content line prefix, and suffix line for the extension of the
470/// provided path, or `None` if the extension is not recognized.
471fn header_delimiters(p: &path::Path) -> Option<HeaderDelimiters> {
472    match p
473        .extension()
474        // if the extension isn't UTF-8, oh well
475        .and_then(|os_str| os_str.to_str())
476        .unwrap_or("")
477    {
478        "c" | "h" | "gv" | "java" | "scala" | "kt" | "kts" => Some(("/*", " * ", " */")),
479        "js" | "mjs" | "cjs" | "jsx" | "tsx" | "css" | "scss" | "sass" | "ts" => {
480            Some(("/**", " * ", " */"))
481        }
482        "cc" | "cpp" | "cs" | "go" | "hcl" | "hh" | "hpp" | "m" | "mm" | "proto" | "rs"
483        | "swift" | "dart" | "groovy" | "v" | "sv" => Some(("", "// ", "")),
484        "py" | "sh" | "yaml" | "yml" | "dockerfile" | "rb" | "gemfile" | "tcl" | "tf" | "bzl"
485        | "pl" | "pp" | "build" => Some(("", "# ", "")),
486        "el" | "lisp" => Some(("", ";; ", "")),
487        "erl" => Some(("", "% ", "")),
488        "hs" | "lua" | "sql" | "sdl" => Some(("", "-- ", "")),
489        "html" | "xml" | "vue" | "wxi" | "wxl" | "wxs" => Some(("<!--", " ", "-->")),
490        "php" => Some(("", "// ", "")),
491        "ml" | "mli" | "mll" | "mly" => Some(("(**", "   ", "*)")),
492        // also handle whole filenames if extensions didn't match
493        _ => match p
494            .file_name()
495            .and_then(|os_str| os_str.to_str())
496            .unwrap_or("")
497        {
498            "Dockerfile" => Some(("", "# ", "")),
499            _ => None,
500        },
501    }
502    .map(
503        |(first_line, content_line_prefix, last_line)| HeaderDelimiters {
504            first_line,
505            content_line_prefix,
506            last_line,
507        },
508    )
509}
510
/// Comment framing used to embed a header in a particular file syntax.
#[derive(Clone, Copy)]
struct HeaderDelimiters {
    /// Line emitted before the header text; nothing is emitted when empty
    first_line: &'static str,
    /// Text put in front of every line of the header itself
    content_line_prefix: &'static str,
    /// Line emitted after the header text; nothing is emitted when empty
    last_line: &'static str,
}
521
/// Markers of a "magic" first line. When a file's first line contains one of these, the header
/// is added after that line rather than before it.
const MAGIC_FIRST_LINES: [&str; 8] = [
    // shell script
    "#!",
    // XML declaration
    "<?xml",
    // HTML doctype
    "<!doctype",
    // Ruby encoding
    "# encoding:",
    // Ruby interpreter instruction
    "# frozen_string_literal:",
    // PHP opening tag
    "<?php",
    // Dockerfile directives
    // https://docs.docker.com/engine/reference/builder/#parser-directives
    "# escape",
    "# syntax",
];
533
534/// Apply `operation` to each discovered path in `root` that passes `path_predicate`.
535///
536/// Return the paths for which `operation` took action, as indicated by `operation` returning
537/// `true`.
538fn recursive_optional_operation<E>(
539    root: &path::Path,
540    path_predicate: impl Fn(&path::Path) -> bool,
541    operation: impl Fn(&path::Path) -> Result<bool, E>,
542) -> Result<Vec<path::PathBuf>, E>
543where
544    E: From<walkdir::Error>,
545{
546    let (path_tx, path_rx) = crossbeam::channel::unbounded::<path::PathBuf>();
547    find_files(root, path_predicate, path_tx)?;
548    path_rx
549        .into_iter()
550        // keep the paths for which the operation took action, and the errors
551        .filter_map(|p| match operation(&p) {
552            Ok(operation_applied) => {
553                if operation_applied {
554                    Some(Ok(p))
555                } else {
556                    None
557                }
558            }
559            Err(e) => Some(Err(e)),
560        })
561        .collect::<Result<Vec<_>, _>>()
562}