use clap::{Parser, ValueEnum};
use std::path::PathBuf;
#[derive(Parser, Debug)]
#[command(name = "pfind")]
#[command(author, version, about, long_about = None)]
pub struct Cli {
#[arg(default_value = ".")]
pub root: PathBuf,
#[arg(short = 't', long, value_enum)]
pub file_type: Option<FileType>,
#[arg(short, long, value_delimiter = ',')]
pub ext: Vec<String>,
#[arg(long)]
pub min_size: Option<String>,
#[arg(long)]
pub max_size: Option<String>,
#[arg(long, value_delimiter = ',')]
pub skip: Vec<String>,
#[arg(long)]
pub no_skip_defaults: bool,
#[arg(short, long)]
pub count: bool,
#[arg(short = 's', long)]
pub sum_size: bool,
#[arg(short = 'l', long)]
pub long: bool,
#[arg(short = 'j', long)]
pub threads: Option<usize>,
#[arg(short = 'd', long)]
pub dirs: bool,
#[arg(short = 'n', long)]
pub name: Option<String>,
}
// Preset file categories selectable with `-t/--file-type`; each variant maps to a fixed
// extension list via `FileType::extensions`.
// The `ValueEnum` derive turns every variant into an accepted value for that flag, so
// renaming or reordering variants changes the CLI's accepted values and help output.
// (Plain `//` comments: `///` on variants would become clap help text.)
#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)]
pub enum FileType {
Models,
Checkpoints,
Data,
Images,
Audio,
Video,
Configs,
Python,
Text,
Logs,
}
impl FileType {
pub fn extensions(&self) -> &'static [&'static str] {
match self {
FileType::Models => &[
"pt",
"pth",
"ckpt",
"safetensors",
"onnx",
"h5",
"hdf5",
"pb",
"tflite",
"bin",
"model",
"weights",
],
FileType::Checkpoints => &["ckpt", "pt", "pth", "safetensors"],
FileType::Data => &[
"csv", "parquet", "arrow", "tfrecord", "jsonl", "json", "tsv", "npy", "npz", "pkl",
"pickle", "feather", "hdf5", "h5",
],
FileType::Images => &[
"jpg", "jpeg", "png", "webp", "bmp", "tiff", "tif", "gif", "ico",
],
FileType::Audio => &["wav", "mp3", "flac", "ogg", "m4a", "aac", "wma"],
FileType::Video => &["mp4", "avi", "mkv", "mov", "webm", "flv", "wmv", "m4v"],
FileType::Configs => &["yaml", "yml", "json", "toml", "cfg", "ini", "conf"],
FileType::Python => &["py", "pyi", "ipynb", "pyx", "pxd"],
FileType::Text => &["txt", "md", "rst", "tex", "rtf"],
FileType::Logs => &["log", "out", "err"],
}
}
}
/// Filter settings built from the parsed command line by `Cli::filter_config`.
#[derive(Debug, Clone, Default)]
pub struct FilterConfig {
/// Extensions from `Cli::all_extensions`: no leading '.', lowercased.
pub extensions: Vec<String>,
/// Names from `Cli::skip_dirs` (`--skip` entries plus the built-in defaults unless
/// `--no-skip-defaults` was given).
pub skip_dirs: Vec<String>,
/// `--min-size` converted to bytes; `None` when the flag is unset or its value could not
/// be parsed. Whether the bound is inclusive is decided by the consumer — confirm.
pub min_size: Option<u64>,
/// `--max-size` converted to bytes; `None` when the flag is unset or its value could not
/// be parsed. Whether the bound is inclusive is decided by the consumer — confirm.
pub max_size: Option<u64>,
/// Raw `--name` value; how it is matched (substring/glob/regex) is not visible here.
pub name_pattern: Option<String>,
/// Set by `--dirs`; presumably lets directories themselves count as matches — confirm
/// against the walker.
pub match_dirs: bool,
}
/// Output-mode switches built by `Cli::output_config`; the code that acts on them lives
/// outside this file.
#[derive(Debug, Clone, Copy, Default)]
pub struct OutputConfig {
/// Copied from `-c/--count`.
pub count_only: bool,
/// Copied from `-s/--sum-size`.
pub sum_size: bool,
/// Copied from `-l/--long`.
pub long_format: bool,
}
impl Cli {
    /// Builds the [`FilterConfig`] for the search from the parsed flags.
    pub fn filter_config(&self) -> FilterConfig {
        FilterConfig {
            extensions: self.all_extensions(),
            skip_dirs: self.skip_dirs().into_iter().map(String::from).collect(),
            min_size: self.min_size_bytes(),
            max_size: self.max_size_bytes(),
            name_pattern: self.name.clone(),
            match_dirs: self.dirs,
        }
    }

    /// Builds the [`OutputConfig`] from `--count`, `--sum-size` and `--long`.
    pub fn output_config(&self) -> OutputConfig {
        OutputConfig {
            count_only: self.count,
            sum_size: self.sum_size,
            long_format: self.long,
        }
    }

    /// Returns the extensions to match: `--ext` values first, then the `--file-type`
    /// preset (if any).
    ///
    /// User values are whitespace-trimmed, stripped of leading dots and lowercased, so
    /// `--ext ".PY, txt"` yields `["py", "txt"]`. Each extension appears once, in
    /// first-seen order.
    pub fn all_extensions(&self) -> Vec<String> {
        // Trim first: clap splits "py, txt" on ',' only, leaving " txt", which never matches.
        let user = self
            .ext
            .iter()
            .map(|e| e.trim().trim_start_matches('.').to_lowercase());
        let preset = self
            .file_type
            .iter()
            .flat_map(|ft| ft.extensions().iter().copied())
            .map(String::from);

        let mut exts: Vec<String> = Vec::new();
        for e in user.chain(preset) {
            // Linear `contains` is fine: the list holds at most a few dozen entries.
            if !exts.contains(&e) {
                exts.push(e);
            }
        }
        exts
    }

    /// Returns names to skip while walking: `--skip` entries first, then the built-in
    /// defaults unless `--no-skip-defaults` was given.
    ///
    /// `--skip` entries are whitespace-trimmed; each name appears once, in first-seen order.
    pub fn skip_dirs(&self) -> Vec<&str> {
        // NOTE(review): ".DS_Store" and "Thumbs.db" are normally files and "*.egg-info" is a
        // glob; whether they take effect depends on how the walker matches names — confirm.
        const DEFAULTS: &[&str] = &[
            ".git",
            ".hg",
            ".svn",
            "__pycache__",
            ".pytest_cache",
            ".mypy_cache",
            ".ruff_cache",
            "node_modules",
            ".venv",
            "venv",
            ".env",
            "env",
            ".tox",
            ".nox",
            ".eggs",
            "*.egg-info",
            "build",
            "dist",
            ".ipynb_checkpoints",
            "wandb",
            "mlruns",
            "lightning_logs",
            "outputs",
            ".cache",
            "__MACOSX",
            ".DS_Store",
            "Thumbs.db",
        ];
        let defaults: &[&str] = if self.no_skip_defaults { &[] } else { DEFAULTS };

        let mut dirs: Vec<&str> = Vec::new();
        for d in self
            .skip
            .iter()
            .map(|s| s.trim())
            .chain(defaults.iter().copied())
        {
            if !dirs.contains(&d) {
                dirs.push(d);
            }
        }
        dirs
    }

    /// Parses a human-readable size such as `"512"`, `"10K"` or `"1.5GB"` into bytes.
    ///
    /// Units are case-insensitive and binary (powers of 1024): `B`, `K`/`KB`, `M`/`MB`,
    /// `G`/`GB`, `T`/`TB`; a bare number means bytes. Whitespace around the input and
    /// between number and unit is ignored. Integer amounts are computed exactly; fractional
    /// amounts go through `f64` and are truncated to whole bytes.
    ///
    /// # Errors
    ///
    /// Returns `"Invalid size: <input>"` when the number is missing or malformed, when it is
    /// negative, NaN or infinite, or when the result does not fit in a `u64`.
    pub fn parse_size(s: &str) -> Result<u64, String> {
        // Longer suffixes come first so "GB" is not consumed as a bare "B".
        const UNITS: &[(&str, u64)] = &[
            ("TB", 1 << 40),
            ("T", 1 << 40),
            ("GB", 1 << 30),
            ("G", 1 << 30),
            ("MB", 1 << 20),
            ("M", 1 << 20),
            ("KB", 1 << 10),
            ("K", 1 << 10),
            ("B", 1),
        ];

        let input = s.trim();
        // Report the caller's spelling, not the upper-cased working copy.
        let invalid = || format!("Invalid size: {}", input);
        let upper = input.to_uppercase();

        // Strip exactly one unit suffix, so garbage such as "1GGG" or "1BB" is rejected.
        let (num, mult) = UNITS
            .iter()
            .find_map(|&(suffix, mult)| upper.strip_suffix(suffix).map(|n| (n, mult)))
            .unwrap_or((upper.as_str(), 1));
        let num = num.trim();

        // Integer fast path: exact for values beyond f64's 53-bit mantissa.
        if let Ok(n) = num.parse::<u64>() {
            return n.checked_mul(mult).ok_or_else(invalid);
        }

        let n: f64 = num.parse().map_err(|_| invalid())?;
        // f64 parsing also accepts negatives, "inf" and "nan"; none of these is a size,
        // and `as u64` would otherwise silently saturate them to 0 or u64::MAX.
        if !n.is_finite() || n < 0.0 {
            return Err(invalid());
        }
        let bytes = n * mult as f64;
        // `u64::MAX as f64` rounds up to 2^64, the first value that no longer fits.
        if bytes >= u64::MAX as f64 {
            return Err(invalid());
        }
        Ok(bytes as u64)
    }

    /// `--min-size` in bytes, or `None` when the flag is absent or
    /// [`Cli::parse_size`] rejects its value.
    pub fn min_size_bytes(&self) -> Option<u64> {
        self.min_size
            .as_deref()
            .and_then(|s| Self::parse_size(s).ok())
    }

    /// `--max-size` in bytes, or `None` when the flag is absent or
    /// [`Cli::parse_size`] rejects its value.
    pub fn max_size_bytes(&self) -> Option<u64> {
        self.max_size
            .as_deref()
            .and_then(|s| Self::parse_size(s).ok())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use clap::CommandFactory;

    /// Clap's recommended sanity check: fails on misconfigured args (e.g. duplicate short
    /// flags) at test time rather than on a user's first invocation.
    #[test]
    fn cli_definition_is_valid() {
        Cli::command().debug_assert();
    }

    #[test]
    fn test_parse_size() {
        assert_eq!(Cli::parse_size("100").unwrap(), 100);
        assert_eq!(Cli::parse_size("100B").unwrap(), 100);
        assert_eq!(Cli::parse_size("1K").unwrap(), 1024);
        assert_eq!(Cli::parse_size("1KB").unwrap(), 1024);
        assert_eq!(Cli::parse_size(" 3 kb ").unwrap(), 3 * 1024);
        assert_eq!(Cli::parse_size("1M").unwrap(), 1024 * 1024);
        assert_eq!(Cli::parse_size("2MB").unwrap(), 2 * 1024 * 1024);
        assert_eq!(Cli::parse_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(
            Cli::parse_size("1.5G").unwrap(),
            (1.5 * 1024.0 * 1024.0 * 1024.0) as u64
        );
    }

    #[test]
    fn test_parse_size_rejects_malformed() {
        assert!(Cli::parse_size("").is_err());
        assert!(Cli::parse_size("abc").is_err());
        assert!(Cli::parse_size("KB").is_err());
    }

    #[test]
    fn test_file_type_extensions() {
        let exts = FileType::Models.extensions();
        assert!(exts.contains(&"pt"));
        assert!(exts.contains(&"safetensors"));
        assert!(exts.contains(&"onnx"));
    }

    #[test]
    fn test_all_extensions_merges_ext_and_type() {
        let cli = Cli::parse_from(["pfind", "-t", "checkpoints", "-e", ".PT,bin"]);
        assert_eq!(
            cli.all_extensions(),
            vec!["pt", "bin", "ckpt", "pth", "safetensors"]
        );
    }

    #[test]
    fn test_skip_dirs_defaults_toggle() {
        let with_defaults = Cli::parse_from(["pfind", "--skip", "data,tmp"]);
        let dirs = with_defaults.skip_dirs();
        assert_eq!(&dirs[..2], &["data", "tmp"]);
        assert!(dirs.contains(&".git"));

        let without = Cli::parse_from(["pfind", "--no-skip-defaults", "--skip", "data,tmp"]);
        assert_eq!(without.skip_dirs(), vec!["data", "tmp"]);
    }

    #[test]
    fn test_size_flags_to_bytes() {
        let cli = Cli::parse_from(["pfind", "--min-size", "1K", "--max-size", "2MB"]);
        assert_eq!(cli.min_size_bytes(), Some(1024));
        assert_eq!(cli.max_size_bytes(), Some(2 * 1024 * 1024));
        let cfg = cli.filter_config();
        assert_eq!(
            (cfg.min_size, cfg.max_size),
            (Some(1024), Some(2 * 1024 * 1024))
        );
    }
}