1pub mod config;
4pub mod error;
5pub mod file_ops;
6pub mod git_ops;
7pub mod regex_ops;
8
9pub use config::Config;
10pub use error::CError;
11use file_ops::read_write_copyright;
12use futures::future::join_all;
13use futures::FutureExt;
14use git_ops::check_for_changes;
15use git_ops::get_added_mod_times_for_file;
16use git_ops::get_files_on_ref;
17use regex_ops::CopyrightCache;
18use regex_ops::{generate_base_regex, generate_copyright_line};
19use serde::Deserialize;
20use std::collections::hash_map::DefaultHasher;
21use std::hash::Hasher;
22use std::path::Path;
23
24#[derive(Debug, Deserialize, Hash, PartialEq)]
25#[serde(untagged)]
26pub enum CommentSign {
27 LeftOnly(String),
28 Enclosing(String, String),
29}
30
31pub async fn check_repo_copyright(
32 repo_path_str: &str,
33 name: &str,
34 fail_on_diff: bool,
35) -> Result<(), CError> {
36 let config = Config::global();
37 let repo_path = Path::new(repo_path_str);
38 let files_to_check = get_files_on_ref(repo_path_str, "HEAD").await?;
39 let files_to_check: Vec<&String> = config
40 .filter_files(files_to_check.iter())
41 .into_iter()
42 .filter(|f| repo_path.join(Path::new(f)).is_file())
43 .collect();
44
45 println!("Checking {} files", files_to_check.len());
46
47 let base_regex = generate_base_regex(name);
48 let regex_cache = CopyrightCache::new(&base_regex);
49
50 let check_and_fix_futures: Vec<_> = files_to_check
51 .iter()
52 .map(|filepath| check_file_copyright(filepath, repo_path_str, name, ®ex_cache))
53 .collect();
54
55 let results = join_all(check_and_fix_futures).await;
56 let failed: Vec<_> = results.iter().filter(|res| res.is_err()).collect();
57 failed.iter().for_each(|res_err| {
58 println!("Error: {}", res_err.as_ref().unwrap_err());
59 });
60
61 if !failed.is_empty() {
62 return Err(CError::FixError);
63 }
64
65 check_for_changes(repo_path_str, fail_on_diff).await?;
66
67 Ok(())
68}
69
70async fn check_file_copyright(
71 filepath: &str,
72 repo_path: &str,
73 name: &str,
74 regex_cache: &CopyrightCache,
75) -> Result<(), CError> {
76 let comment_sign = Config::global().get_comment_sign(filepath)?;
77 let years_fut = get_added_mod_times_for_file(filepath, repo_path).shared();
78 let copyright_line_fut = generate_copyright_line(name, comment_sign, years_fut.clone());
79 let filepath = Path::new(repo_path).join(filepath);
80 let regex = regex_cache.get_regex(comment_sign)?;
81 read_write_copyright(filepath, regex, years_fut, copyright_line_fut).await
82}
83
/// Hashes `obj` with the standard library's `DefaultHasher` and returns the
/// 64-bit digest.
///
/// Every call starts from `DefaultHasher::new()`, so equal inputs produce equal
/// digests within the same build. The `DefaultHasher` algorithm is not
/// guaranteed to stay the same across Rust releases, so these values should not
/// be persisted or compared across toolchain versions.
pub fn get_hash<T: std::hash::Hash>(obj: &T) -> u64 {
    let mut state = DefaultHasher::new();
    std::hash::Hash::hash(obj, &mut state);
    state.finish()
}