use std::cmp::max;
use std::ffi::OsStr;
use std::path::Path;
use std::sync::LazyLock;
use anstream::ColorChoice;
use futures::{StreamExt, TryStreamExt};
use prek_consts::env_vars::EnvVars;
use rustc_hash::FxHashMap;
use tracing::trace;
use crate::config::PassFilenames;
use crate::hook::Hook;
use crate::warn_user;
/// Whether our own stderr output should be colorized, resolved once at first use.
///
/// NOTE(review): relies on anstream resolving `Auto` into `Always`/`Never`
/// before `choice` returns — confirm against the anstream version in use.
pub(crate) static USE_COLOR: LazyLock<bool> =
    LazyLock::new(|| match anstream::Stderr::choice(&std::io::stderr()) {
        ColorChoice::Always | ColorChoice::AlwaysAnsi => true,
        ColorChoice::Never => false,
        // anstream is expected to have already resolved `Auto` here.
        ColorChoice::Auto => unreachable!(),
    });
fn resolve_concurrency(no_concurrency: bool, max_concurrency: Option<&str>, cpu: usize) -> usize {
if no_concurrency {
return 1;
}
if let Some(v) = max_concurrency {
if let Ok(cap) = v.parse::<usize>() {
return cap.max(1);
}
warn_user!(
"Invalid value for {}: {v:?}, using default ({cpu})",
EnvVars::PREK_MAX_CONCURRENCY,
);
}
cpu
}
/// Process-wide default concurrency, computed once: the available CPU
/// parallelism, overridden by the `PREK_NO_CONCURRENCY` and
/// `PREK_MAX_CONCURRENCY` environment variables.
pub(crate) static CONCURRENCY: LazyLock<usize> = LazyLock::new(|| {
    // Fall back to 1 if the parallelism query is unsupported on this platform.
    let cpu = std::thread::available_parallelism()
        .map(std::num::NonZero::get)
        .unwrap_or(1);
    resolve_concurrency(
        EnvVars::is_set(EnvVars::PREK_NO_CONCURRENCY),
        EnvVars::var(EnvVars::PREK_MAX_CONCURRENCY).ok().as_deref(),
        cpu,
    )
});
/// Concurrency to use for one hook run: 1 when the hook must run serially,
/// otherwise the process-wide `CONCURRENCY` value.
fn target_concurrency(serial: bool) -> usize {
    match serial {
        true => 1,
        false => *CONCURRENCY,
    }
}
/// Iterator state for splitting a file list into command-line-sized batches.
struct Partitions<'a> {
    /// Full set of files to distribute across batches.
    filenames: &'a [&'a Path],
    /// Index of the next filename to place into a batch.
    current_index: usize,
    /// Hard cap on the number of files per batch.
    max_per_batch: usize,
    /// Command-line byte budget left for filenames after the fixed parts
    /// (entry, args, and on unix the environment) are accounted for.
    remaining_arg_length: usize,
}
/// Assumed pointer width for argv/envp bookkeeping; 8 bytes covers 64-bit
/// platforms and safely over-estimates on 32-bit ones.
const POINTER_SIZE_CONSERVATIVE: usize = 8;

/// Fixed slack subtracted from platform limits so we never sit exactly at
/// the edge of the allowed command-line length.
const ARG_HEADROOM: usize = 2048;

/// Bytes one environment variable occupies in the exec budget: a pointer
/// slot plus `KEY=VALUE\0`.
fn environment_variable_size<O: AsRef<OsStr>>(key: O, value: O) -> usize {
    let key_len = key.as_ref().len();
    let value_len = value.as_ref().len();
    // pointer + key + '=' + value + NUL terminator
    POINTER_SIZE_CONSERVATIVE + key_len + 1 + value_len + 1
}

/// Bytes one argv entry occupies: a pointer slot plus the string and its NUL.
fn arg_size<O: AsRef<OsStr>>(arg: O) -> usize {
    POINTER_SIZE_CONSERVATIVE + arg.as_ref().len() + 1
}
#[cfg(unix)]
/// Kernel limit for total argv + envp bytes, from `sysconf(_SC_ARG_MAX)`.
/// Falls back to 4 KiB when the query fails (sysconf returns <= 0).
static ARG_MAX: LazyLock<usize> = LazyLock::new(|| {
    // SAFETY: sysconf is a plain libc query with no preconditions.
    let arg_max = unsafe { libc::sysconf(libc::_SC_ARG_MAX) };
    if arg_max <= 0 {
        1 << 12
    } else {
        usize::try_from(arg_max).expect("SC_ARG_MAX too large")
    }
});

#[cfg(unix)]
/// System page size from `sysconf(_SC_PAGE_SIZE)`, floored at 4 KiB
/// (this also covers the error case where sysconf returns a negative value).
static PAGE_SIZE: LazyLock<usize> = LazyLock::new(|| {
    // SAFETY: sysconf is a plain libc query with no preconditions.
    let page_size = unsafe { libc::sysconf(libc::_SC_PAGE_SIZE) };
    if page_size < 4096 {
        4096
    } else {
        usize::try_from(page_size).expect("SC_PAGE_SIZE too large")
    }
});
/// Conservative upper bound (in bytes) on the total command-line length for
/// the current platform.
///
/// On unix this is `ARG_MAX` minus one page of kernel bookkeeping and a fixed
/// headroom; on Windows it is the 32 KiB `CreateProcess` limit minus headroom.
fn platform_max_cli_length() -> usize {
    #[cfg(unix)]
    {
        // Use saturating_sub: if sysconf(_SC_ARG_MAX) failed, ARG_MAX falls
        // back to 4096, and 4096 - PAGE_SIZE (>= 4096) - ARG_HEADROOM would
        // underflow (panicking in debug builds). Saturate to 0 and let the
        // clamp floor the result at 4 KiB instead.
        (*ARG_MAX)
            .saturating_sub(*PAGE_SIZE)
            .saturating_sub(ARG_HEADROOM)
            .clamp(1 << 12, 1 << 20)
    }
    #[cfg(windows)]
    {
        // CreateProcess caps the command line at 32767 UTF-16 units.
        (1 << 15) - ARG_HEADROOM
    }
    #[cfg(not(any(unix, windows)))]
    {
        1 << 12
    }
}
/// Approximate byte footprint of the environment a child process will
/// receive: every inherited variable (excluding ones shadowed by
/// `override_envs`) plus every override itself.
fn env_size(override_envs: &FxHashMap<String, String>) -> usize {
    let inherited: usize = std::env::vars_os()
        .filter(|(key, _)| {
            // Skip inherited variables the hook overrides; those are counted
            // once from `override_envs` below. Non-UTF-8 keys can never match
            // an override, so they are always counted as inherited.
            !key.to_str()
                .is_some_and(|key| override_envs.contains_key(key))
        })
        .map(|(key, value)| environment_variable_size(&key, &value))
        .sum();
    let overridden: usize = override_envs
        .iter()
        .map(|(key, value)| environment_variable_size(key, value))
        .sum();
    inherited + overridden
}
impl<'a> Partitions<'a> {
    /// Compute batching parameters for running `hook` over `filenames`.
    ///
    /// The per-batch file count is capped by the hook's `pass_filenames`
    /// limit (if any), otherwise sized to spread files evenly across
    /// `concurrency` workers (minimum 4 per batch). The per-batch byte
    /// budget is the platform command-line limit minus the environment
    /// (unix), the command, and its fixed arguments.
    ///
    /// # Errors
    ///
    /// Fails if the fixed part of the command line (entry + args) alone
    /// already exceeds the platform limit, i.e. no filenames could ever fit.
    fn split(
        hook: &'a Hook,
        entry: &'a [String],
        filenames: &'a [&'a Path],
        concurrency: usize,
    ) -> anyhow::Result<Self> {
        let max_per_batch = match hook.pass_filenames {
            PassFilenames::Limited(n) => n.get(),
            // Aim for one batch per worker, but keep batches at least 4
            // files so tiny runs don't spawn a process per file.
            _ => max(4, filenames.len().div_ceil(concurrency)),
        };
        let mut arg_max = platform_max_cli_length();
        let cmd = Path::new(&entry[0]);
        if cfg!(windows)
            && cmd.extension().is_some_and(|ext| {
                ext.eq_ignore_ascii_case("cmd") || ext.eq_ignore_ascii_case("bat")
            })
        {
            // `cmd.exe` has a much lower limit (~8191 chars) than CreateProcess.
            arg_max = 8192 - 1024;
        } else if cfg!(unix) {
            // On unix the environment shares the ARG_MAX budget with argv.
            // Saturate instead of `-=` so an oversized environment degrades
            // into the "exceeds platform limit" error below rather than
            // underflowing (a panic in debug builds).
            arg_max = arg_max
                .saturating_sub(env_size(&hook.env))
                .saturating_sub(POINTER_SIZE_CONSERVATIVE);
        }
        let args_size = entry
            .iter()
            .chain(hook.args.iter())
            .map(arg_size)
            .sum::<usize>()
            + POINTER_SIZE_CONSERVATIVE;
        if args_size >= arg_max {
            anyhow::bail!(
                "Command line length ({args_size} bytes) exceeds platform limit ({arg_max} bytes).
\nhint: Shorten the hook `entry`/`args` or wrap the command in a script to reduce command-line length.",
            );
        }
        // Safe: the check above guarantees args_size < arg_max.
        arg_max -= args_size;
        Ok(Self {
            filenames,
            current_index: 0,
            max_per_batch,
            remaining_arg_length: arg_max,
        })
    }
}
impl<'a> Iterator for Partitions<'a> {
    type Item = &'a [&'a Path];

    /// Yield the next batch: a run of consecutive filenames that fits both
    /// the remaining command-line byte budget and the per-batch count cap.
    ///
    /// # Panics
    ///
    /// Panics if a single filename alone exceeds the remaining budget —
    /// it could never be passed on any command line.
    fn next(&mut self) -> Option<Self::Item> {
        // No filenames at all: yield exactly one empty batch so the hook
        // still runs once; `current_index` doubles as the "already yielded"
        // marker for this sentinel case.
        if self.filenames.is_empty() && self.current_index == 0 {
            self.current_index = 1;
            return Some(&[]);
        }
        if self.current_index >= self.filenames.len() {
            return None;
        }
        let start_index = self.current_index;
        // Each batch starts with the full byte budget.
        let mut remaining_length = self.remaining_arg_length;
        while self.current_index < self.filenames.len() {
            let filename = self.filenames[self.current_index];
            let length = arg_size(filename);
            // Stop when the next file would overflow the budget or the batch
            // already holds `max_per_batch` files.
            if length > remaining_length || self.current_index - start_index >= self.max_per_batch {
                break;
            }
            remaining_length -= length;
            self.current_index += 1;
        }
        if self.current_index == start_index {
            // Not even one file fit: this filename is longer than the whole
            // budget, which can never succeed — fail loudly.
            let filename = self.filenames[self.current_index];
            let length = arg_size(filename);
            panic!(
                "Filename `{}` is too long ({length} bytes) to fit in command line (remaining {remaining_length} bytes).",
                filename.display(),
            );
        } else {
            Some(&self.filenames[start_index..self.current_index])
        }
    }
}
/// Run `run` over `filenames` split into command-line-sized batches, with at
/// most the target concurrency's worth of batches in flight at once.
///
/// Batch sizing comes from `Partitions::split`; results are collected in
/// batch order. Returns an error if the fixed command line is already too
/// long or if any `run` invocation fails (short-circuiting via `try_collect`).
pub(crate) async fn run_by_batch<T, F>(
    hook: &Hook,
    filenames: &[&Path],
    entry: &[String],
    run: F,
) -> anyhow::Result<Vec<T>>
where
    F: for<'a> AsyncFn(&'a [&'a Path]) -> anyhow::Result<T>,
    T: Send + 'static,
{
    let concurrency = target_concurrency(hook.require_serial);
    let partitions = Partitions::split(hook, entry, filenames, concurrency)?;
    trace!(
        total_files = filenames.len(),
        concurrency = concurrency,
        "Running {}",
        hook.id,
    );
    // The closure is not redundant here: `run` is an `AsyncFn` bound and
    // cannot be passed to `map` by name directly.
    #[allow(clippy::redundant_closure)]
    let results: Vec<_> = futures::stream::iter(partitions)
        .map(|batch| run(batch))
        .buffered(concurrency)
        .try_collect()
        .await?;
    Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    /// Build a `Partitions` directly (bypassing `split`) so tests can pin
    /// the byte budget and batch cap without constructing a `Hook`.
    fn create_test_partitions<'a>(
        filenames: &'a [&'a Path],
        remaining_arg_length: usize,
        max_per_batch: usize,
    ) -> Partitions<'a> {
        Partitions {
            filenames,
            current_index: 0,
            remaining_arg_length,
            max_per_batch,
        }
    }

    // Every file is yielded exactly once when everything fits in one batch.
    #[test]
    fn test_partitions_normal_filenames() {
        let file1 = PathBuf::from("file1.txt");
        let file2 = PathBuf::from("file2.txt");
        let file3 = PathBuf::from("file3.txt");
        let filenames: Vec<&Path> = vec![&file1, &file2, &file3];
        let partitions = create_test_partitions(&filenames, 4096, 10);
        let total_files: usize = partitions.map(<[&Path]>::len).sum();
        assert_eq!(total_files, 3);
    }

    // An empty file list still yields exactly one empty batch, then ends.
    #[test]
    fn test_partitions_empty_filenames() {
        let filenames: Vec<&Path> = vec![];
        let mut partitions = create_test_partitions(&filenames, 4096, 10);
        let batch = partitions.next();
        assert!(batch.is_some());
        assert_eq!(batch.unwrap().len(), 0);
        let batch = partitions.next();
        assert!(batch.is_none());
    }

    // A filename longer than the whole budget can never fit — must panic.
    #[test]
    #[should_panic(expected = "is too long")]
    fn test_partitions_long_filename_in_middle_panics() {
        let file1 = PathBuf::from("file1.txt");
        let long_name = "a".repeat(5000);
        let long_file = PathBuf::from(&long_name);
        let file3 = PathBuf::from("file3.txt");
        let filenames: Vec<&Path> = vec![&file1, &long_file, &file3];
        let mut partitions = create_test_partitions(&filenames, 1000, 10);
        let batch1 = partitions.next();
        assert!(batch1.is_some());
        partitions.next();
    }

    // 100 files with a cap of 25 per batch → at least 4 batches, no file lost.
    #[test]
    fn test_partitions_respects_max_per_batch() {
        let files: Vec<PathBuf> = (0..100)
            .map(|i| PathBuf::from(format!("f{i}.txt")))
            .collect();
        let file_refs: Vec<&Path> = files.iter().map(PathBuf::as_path).collect();
        let partitions = create_test_partitions(&file_refs, 100_000, 25);
        let all_batches: Vec<_> = partitions.map(<[&Path]>::len).collect();
        assert!(all_batches.len() >= 4);
        let total_files: usize = all_batches.iter().sum();
        assert_eq!(total_files, 100);
    }

    #[test]
    fn test_resolve_concurrency_defaults_to_cpu() {
        assert_eq!(resolve_concurrency(false, None, 16), 16);
    }

    #[test]
    fn test_resolve_concurrency_max_caps_value() {
        assert_eq!(resolve_concurrency(false, Some("4"), 16), 4);
    }

    // The override is taken at face value even above the CPU count.
    #[test]
    fn test_resolve_concurrency_max_above_cpu() {
        assert_eq!(resolve_concurrency(false, Some("32"), 8), 32);
    }

    #[test]
    fn test_resolve_concurrency_max_zero_floors_to_one() {
        assert_eq!(resolve_concurrency(false, Some("0"), 16), 1);
    }

    #[test]
    fn test_resolve_concurrency_max_invalid_falls_back() {
        assert_eq!(resolve_concurrency(false, Some("abc"), 16), 16);
    }

    #[test]
    fn test_resolve_concurrency_max_empty_falls_back() {
        assert_eq!(resolve_concurrency(false, Some(""), 16), 16);
    }

    #[test]
    fn test_resolve_concurrency_no_concurrency() {
        assert_eq!(resolve_concurrency(true, None, 16), 1);
    }

    // `no_concurrency` wins over an explicit max override.
    #[test]
    fn test_resolve_concurrency_no_concurrency_overrides_max() {
        assert_eq!(resolve_concurrency(true, Some("8"), 16), 1);
    }

    // A tight byte budget forces multiple batches, still covering all files.
    #[test]
    fn test_partitions_respects_cli_length_limit() {
        let files: Vec<PathBuf> = (0..10)
            .map(|i| PathBuf::from(format!("file{i}.txt")))
            .collect();
        let file_refs: Vec<&Path> = files.iter().map(PathBuf::as_path).collect();
        let partitions = create_test_partitions(&file_refs, 100, 100);
        let all_batches: Vec<_> = partitions.map(<[&Path]>::len).collect();
        assert!(all_batches.len() > 1);
        let total_files: usize = all_batches.iter().sum();
        assert_eq!(total_files, 10);
    }
}