use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
use itertools::{Either, Itertools};
use path_clean::PathClean;
use prek_consts::env_vars::EnvVars;
use prek_identify::{TagSet, tags_from_path};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use rustc_hash::FxHashSet;
use tracing::{debug, error, instrument};
use crate::config::{FilePattern, Stage};
use crate::git::GIT_ROOT;
use crate::hook::Hook;
use crate::workspace::Project;
use crate::{fs, git, warn_user};
/// Accepts or rejects filenames using an optional include pattern and an optional
/// exclude pattern.
pub(crate) struct FilenameFilter<'a> {
    /// If set, only filenames matching this pattern are accepted.
    include: Option<&'a FilePattern>,
    /// If set, filenames matching this pattern are rejected.
    exclude: Option<&'a FilePattern>,
}

impl<'a> FilenameFilter<'a> {
    /// Create a filter; passing `None` for either pattern disables that check.
    pub(crate) fn new(include: Option<&'a FilePattern>, exclude: Option<&'a FilePattern>) -> Self {
        Self { include, exclude }
    }

    /// Returns `true` when `filename` matches `include` (if set) and does not match
    /// `exclude` (if set).
    ///
    /// Paths that are not valid UTF-8 are always rejected, even when neither pattern
    /// is configured.
    pub(crate) fn filter(&self, filename: &Path) -> bool {
        let Some(name) = filename.to_str() else {
            return false;
        };

        // Rejected outright when an include pattern exists and does not match.
        if self.include.is_some_and(|pattern| !pattern.is_match(name)) {
            return false;
        }

        // Otherwise accepted unless an exclude pattern exists and matches.
        !self.exclude.is_some_and(|pattern| pattern.is_match(name))
    }
}
/// Accepts or rejects a file based on the set of tags identified for it.
pub(crate) struct FileTagFilter<'a> {
    /// Every tag in this set must be present (`types`).
    all: Option<&'a TagSet>,
    /// At least one tag in this set must be present (`types_or`); an empty set imposes
    /// no constraint.
    any: Option<&'a TagSet>,
    /// None of the tags in this set may be present (`exclude_types`).
    exclude: Option<&'a TagSet>,
}

impl<'a> FileTagFilter<'a> {
    /// Build a filter from `types`, `types_or` and `exclude_types`; `None` disables the
    /// corresponding check.
    fn new(
        types: Option<&'a TagSet>,
        types_or: Option<&'a TagSet>,
        exclude_types: Option<&'a TagSet>,
    ) -> Self {
        Self {
            all: types,
            any: types_or,
            exclude: exclude_types,
        }
    }

    /// Returns `true` when `file_types` satisfies all three configured tag constraints.
    pub(crate) fn filter(&self, file_types: &TagSet) -> bool {
        let has_all_required = match self.all {
            Some(required) => required.is_subset(file_types),
            None => true,
        };
        if !has_all_required {
            return false;
        }

        let has_any_candidate = match self.any {
            Some(candidates) => candidates.is_empty() || !candidates.is_disjoint(file_types),
            None => true,
        };
        if !has_any_candidate {
            return false;
        }

        match self.exclude {
            Some(excluded) => excluded.is_disjoint(file_types),
            None => true,
        }
    }
}
pub(crate) struct FileFilter<'a> {
filenames: Vec<&'a Path>,
filename_prefix: PathBuf,
}
impl<'a> FileFilter<'a> {
#[instrument(level = "trace", skip_all, fields(project = %project))]
pub(crate) fn for_project<I>(
filenames: I,
project: &Project,
mut consumed_files: Option<&mut FxHashSet<&'a Path>>,
) -> Self
where
I: Iterator<Item = &'a PathBuf> + Send,
{
let filter = FilenameFilter::new(
project.config().files.as_ref(),
project.config().exclude.as_ref(),
);
let relative_path = project.relative_path();
let orphan = project.config().orphan.unwrap_or(false);
let filenames = filenames
.map(PathBuf::as_path)
.filter(|filename| filename.starts_with(relative_path))
.filter(|filename| {
if let Some(consumed_files) = consumed_files.as_mut() {
if orphan {
return consumed_files.insert(filename);
}
!consumed_files.contains(filename)
} else {
true
}
})
.filter(|filename| {
let relative = filename
.strip_prefix(relative_path)
.expect("Filename should start with project relative path");
filter.filter(relative)
})
.collect::<Vec<_>>();
Self {
filenames,
filename_prefix: relative_path.to_path_buf(),
}
}
pub(crate) fn len(&self) -> usize {
self.filenames.len()
}
pub(crate) fn by_type(
&self,
types: Option<&TagSet>,
types_or: Option<&TagSet>,
exclude_types: Option<&TagSet>,
) -> Vec<&Path> {
let filter = FileTagFilter::new(types, types_or, exclude_types);
let filenames: Vec<_> = self
.filenames
.par_iter()
.filter(|filename| match tags_from_path(filename) {
Ok(tags) => filter.filter(&tags),
Err(err) => {
error!(filename = ?filename.display(), error = %err, "Failed to get tags");
false
}
})
.copied()
.collect();
filenames
}
#[instrument(level = "trace", skip_all, fields(hook = ?hook.id))]
pub(crate) fn for_hook(&self, hook: &Hook) -> Vec<&Path> {
let filter = FilenameFilter::new(hook.files.as_ref(), hook.exclude.as_ref());
let filenames = self.filenames.par_iter().filter(|filename| {
if let Ok(relative) = filename.strip_prefix(&self.filename_prefix) {
filter.filter(relative)
} else {
false
}
});
let filter = FileTagFilter::new(
Some(&hook.types),
Some(&hook.types_or),
Some(&hook.exclude_types),
);
let filenames = filenames.filter(|filename| match tags_from_path(filename) {
Ok(tags) => filter.filter(&tags),
Err(err) => {
error!(filename = ?filename.display(), error = %err, "Failed to get tags");
false
}
});
let filenames: Vec<_> = filenames
.map(|p| {
p.strip_prefix(&self.filename_prefix)
.expect("Filename should start with project relative path")
})
.collect();
filenames
}
}
/// Inputs that control which files `collect_files` gathers.
#[derive(Default)]
pub(crate) struct CollectOptions {
    /// The hook stage being run; stages that do not operate on files collect nothing.
    pub(crate) hook_stage: Stage,
    /// Start of a ref range; only used when `to_ref` is also set.
    pub(crate) from_ref: Option<String>,
    /// End of a ref range; only used when `from_ref` is also set.
    pub(crate) to_ref: Option<String>,
    /// Collect every file reported by `git::ls_files` for the workspace root.
    pub(crate) all_files: bool,
    /// Explicit file paths, as given on the command line.
    pub(crate) files: Vec<String>,
    /// Explicit directories, expanded via `git::ls_files`.
    pub(crate) directories: Vec<String>,
    /// Commit message file; must be set for the `PrepareCommitMsg` and `CommitMsg` stages.
    pub(crate) commit_msg_filename: Option<String>,
}

impl CollectOptions {
    /// Options that select every file in the workspace, with all other fields defaulted.
    pub(crate) fn all_files() -> Self {
        Self {
            all_files: true,
            ..Self::default()
        }
    }
}
/// Collect the files to run hooks on, as paths relative to the workspace `root`.
///
/// The paths produced by `collect_files_from_args` are treated as relative to the git
/// root. Any that do not lie under `root` are dropped. The rest are re-expressed
/// relative to `root` and passed through `fs::normalize_path`.
///
/// # Errors
///
/// Fails if the git root is unavailable, if `root` is not inside the git root, or if
/// gathering the files fails.
#[instrument(level = "trace", skip_all)]
pub(crate) async fn collect_files(root: &Path, opts: CollectOptions) -> Result<Vec<PathBuf>> {
    let git_root = GIT_ROOT.as_ref()?;
    let root_in_repo = root.strip_prefix(git_root).with_context(|| {
        format!(
            "Workspace root `{}` is not under git root `{}`",
            root.display(),
            git_root.display()
        )
    })?;

    let repo_relative = collect_files_from_args(
        git_root,
        root,
        opts.hook_stage,
        opts.from_ref,
        opts.to_ref,
        opts.all_files,
        opts.files,
        opts.directories,
        opts.commit_msg_filename,
    )
    .await?;

    let mut filenames = Vec::with_capacity(repo_relative.len());
    for path in repo_relative {
        // Paths outside the workspace root are skipped without a warning.
        if let Ok(inside_root) = path.strip_prefix(root_in_repo) {
            filenames.push(fs::normalize_path(inside_root.to_path_buf()));
        }
    }

    // Optional deterministic ordering, switched on by an internal environment variable.
    if EnvVars::is_set(EnvVars::PREK_INTERNAL__SORT_FILENAMES) {
        filenames.sort_unstable();
    }

    Ok(filenames)
}
/// Resolve `path` against the current working directory, clean it lexically with
/// `PathClean::clean`, and re-express the result relative to `new_cwd` via
/// `fs::relative_to`.
fn adjust_relative_path(path: &str, new_cwd: &Path) -> Result<PathBuf, std::io::Error> {
    std::path::absolute(path)
        .map(|absolute| absolute.clean())
        .and_then(|cleaned| fs::relative_to(cleaned, new_cwd))
}
#[allow(clippy::too_many_arguments)]
async fn collect_files_from_args(
git_root: &Path,
workspace_root: &Path,
hook_stage: Stage,
from_ref: Option<String>,
to_ref: Option<String>,
all_files: bool,
files: Vec<String>,
directories: Vec<String>,
commit_msg_filename: Option<String>,
) -> Result<Vec<PathBuf>> {
if !hook_stage.operate_on_files() {
return Ok(vec![]);
}
if hook_stage == Stage::PrepareCommitMsg || hook_stage == Stage::CommitMsg {
let path = commit_msg_filename.expect("commit_msg_filename should be set");
let path = adjust_relative_path(&path, git_root)?;
return Ok(vec![path]);
}
if let (Some(from_ref), Some(to_ref)) = (from_ref, to_ref) {
let files = git::get_changed_files(&from_ref, &to_ref, workspace_root).await?;
debug!(
"Files changed between {} and {}: {}",
from_ref,
to_ref,
files.len()
);
return Ok(files);
}
if !files.is_empty() || !directories.is_empty() {
let (exists, non_exists): (FxHashSet<_>, Vec<_>) =
files.into_iter().partition_map(|filename| {
if std::fs::exists(&filename).unwrap_or(false) {
Either::Left(filename)
} else {
Either::Right(filename)
}
});
if !non_exists.is_empty() {
if non_exists.len() == 1 {
warn_user!(
"This file does not exist and will be ignored: `{}`",
non_exists[0]
);
} else {
warn_user!(
"These files do not exist and will be ignored: `{}`",
non_exists.join(", ")
);
}
}
let mut exists = exists
.into_iter()
.map(|filename| adjust_relative_path(&filename, git_root).map(fs::normalize_path))
.collect::<Result<FxHashSet<_>, _>>()?;
for dir in directories {
let dir = adjust_relative_path(&dir, git_root)?;
let dir_files = git::ls_files(git_root, &dir).await?;
for file in dir_files {
let file = fs::normalize_path(file);
exists.insert(file);
}
}
debug!("Files passed as arguments: {}", exists.len());
return Ok(exists.into_iter().collect());
}
if all_files {
let files = git::ls_files(git_root, workspace_root).await?;
debug!("All files in the workspace: {}", files.len());
return Ok(files);
}
if git::is_in_merge_conflict().await? {
let files = git::get_conflicted_files(workspace_root).await?;
debug!("Conflicted files: {}", files.len());
return Ok(files);
}
let files = git::get_staged_files(workspace_root).await?;
debug!("Staged files: {}", files.len());
Ok(files)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::GlobPatterns;

    /// Build a `FilePattern` from a single glob.
    fn glob_pattern(pattern: &str) -> FilePattern {
        FilePattern::Glob(GlobPatterns::new(vec![pattern.to_string()]).unwrap())
    }

    #[test]
    fn filename_filter_supports_glob_include_and_exclude() {
        let include = glob_pattern("src/**/*.rs");
        let exclude = glob_pattern("src/**/ignored.rs");
        let filter = FilenameFilter::new(Some(&include), Some(&exclude));
        assert!(filter.filter(Path::new("src/lib/main.rs")));
        assert!(!filter.filter(Path::new("src/lib/ignored.rs")));
        assert!(!filter.filter(Path::new("tests/main.rs")));
    }

    #[test]
    fn filename_filter_without_patterns_accepts_everything() {
        let filter = FilenameFilter::new(None, None);
        assert!(filter.filter(Path::new("src/lib/main.rs")));
        assert!(filter.filter(Path::new("tests/main.rs")));
    }

    #[test]
    fn filename_filter_with_only_exclude() {
        let exclude = glob_pattern("src/**/ignored.rs");
        let filter = FilenameFilter::new(None, Some(&exclude));
        assert!(filter.filter(Path::new("src/lib/main.rs")));
        assert!(filter.filter(Path::new("tests/main.rs")));
        assert!(!filter.filter(Path::new("src/lib/ignored.rs")));
    }

    #[cfg(unix)]
    #[test]
    fn filename_filter_rejects_non_utf8_paths() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        // Even with no patterns configured, non-UTF-8 paths are rejected.
        let filter = FilenameFilter::new(None, None);
        let path = Path::new(OsStr::from_bytes(b"src/\xff.rs"));
        assert!(!filter.filter(path));
    }
}