use std::path::PathBuf;
use clap::{Args, Parser, Subcommand, ValueEnum};
use clap_complete::Shell;
/// Subcommand names that `rewrite_implicit_stream` treats as an explicit
/// subcommand (and therefore does not prefix with an implicit `stream`).
///
/// Maintained by hand: keep it in sync with the variants of `Command`. It lists
/// feature-gated subcommands (`serve`, `convert`, `info`) unconditionally. It also
/// lists `help`, even though `Cli` sets `disable_help_subcommand`.
const SUBCOMMAND_NAMES: &[&str] = &[
    "stream", "serve", "convert", "info", "analyze",
    "index", "verify", "completions", "man", "help",
];
// Top-level clap parser for the `shuflr` binary.
//
// Beyond clap's help flag and the `--version` flag enabled by `version`, there
// are no top-level options; everything lives on the subcommands.
// `disable_help_subcommand` turns off clap's built-in `help` subcommand.
// `arg_required_else_help` makes a bare `shuflr` print help.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive, `///` doc
// comments on these items become user-visible help text.
#[derive(Parser, Debug)]
#[command(
    name = "shuflr",
    version,
    about = "Stream large JSONL in shuffled order, without loading it into memory.",
    long_about = "shuflr streams records from JSONL files (optionally compressed) in \
shuffled order without loading the file into memory. Works as a CLI \
pipe; builds with --features serve also expose HTTP and shuflr-wire/1 \
listeners. See the subcommands for specific workflows; \
`shuflr file.jsonl` is shorthand for `shuflr stream file.jsonl`.",
    disable_help_subcommand = true,
    arg_required_else_help = true
)]
pub struct Cli {
    // The selected subcommand. `parse()` may rewrite argv first to insert an
    // implicit `stream` (see `rewrite_implicit_stream`).
    #[command(subcommand)]
    pub command: Command,
}
// Every subcommand. clap derives each subcommand name from its variant name
// (`Stream` -> `stream`, `Completions` -> `completions`).
// `Serve` exists only with the `serve` feature; `Convert` and `Info` only with
// the `zstd` feature.
// When adding or renaming a variant, update `SUBCOMMAND_NAMES` as well, because
// the implicit-`stream` rewrite relies on it.
// (Plain `//` comments: `///` here would become subcommand help text.)
#[derive(Subcommand, Debug)]
pub enum Command {
    Stream(StreamArgs),
    #[cfg(feature = "serve")]
    Serve(ServeArgs),
    #[cfg(feature = "zstd")]
    Convert(ConvertArgs),
    #[cfg(feature = "zstd")]
    Info(InfoArgs),
    Analyze(AnalyzeArgs),
    Index(IndexArgs),
    Verify(VerifyArgs),
    Completions(CompletionsArgs),
    Man(ManArgs),
}
impl Command {
    /// Log-level string to use while running this subcommand.
    ///
    /// `stream`, `serve` and `convert` have a `--log-level` flag (also fed by the
    /// `SHUFLR_LOG` env var), and that value is returned as-is. The remaining
    /// subcommands use a fixed level:
    /// - `"info"` for `info`, `analyze`, `index` and `verify`;
    /// - `"warn"` for `completions` and `man`.
    pub fn log_level(&self) -> &str {
        match self {
            Self::Stream(args) => args.log_level.as_str(),
            #[cfg(feature = "serve")]
            Self::Serve(args) => args.log_level.as_str(),
            #[cfg(feature = "zstd")]
            Self::Convert(args) => args.log_level.as_str(),
            #[cfg(feature = "zstd")]
            Self::Info(_) => "info",
            Self::Analyze(_) | Self::Index(_) | Self::Verify(_) => "info",
            Self::Completions(_) | Self::Man(_) => "warn",
        }
    }
}
// Positional input list, flattened into several subcommands via
// `#[command(flatten)]`. When no inputs are given, clap supplies the single
// default `-`, which presumably means stdin (confirm in the reader code).
#[derive(Args, Debug, Clone)]
pub struct InputArgs {
    #[arg(value_name = "INPUT", default_value = "-")]
    pub inputs: Vec<PathBuf>,
}
// Options for `shuflr stream` (also reached implicitly via `shuflr <file>`).
// Field order sets the order of options in `--help`; don't reorder casually.
// (Plain `//` comments: `///` here would become per-flag help text.)
#[derive(Args, Debug)]
pub struct StreamArgs {
    #[command(flatten)]
    pub input: InputArgs,
    // `-s/--shuffle`: shuffle strategy; defaults to `chunk-shuffled`.
    #[arg(
        short = 's',
        long,
        value_enum,
        default_value_t = ShuffleMode::ChunkShuffled,
        value_name = "MODE",
    )]
    pub shuffle: ShuffleMode,
    // Seed value; also taken from the `SHUFLR_SEED` env var. Unset means `None`.
    #[arg(long, env = "SHUFLR_SEED", value_name = "U64")]
    pub seed: Option<u64>,
    // `-n/--sample`: optional count. Presumably caps the number of emitted
    // records (TODO confirm).
    #[arg(short = 'n', long, value_name = "N")]
    pub sample: Option<u64>,
    // `-e/--epochs`: number of epochs, default 1.
    #[arg(short = 'e', long, default_value_t = 1, value_name = "N")]
    pub epochs: u64,
    // `--rank` and `--world-size` must be given together (`requires` in both
    // directions). Presumably they select one shard of a distributed run
    // (TODO confirm).
    #[arg(long, requires = "world_size", value_name = "R")]
    pub rank: Option<u32>,
    #[arg(long, requires = "rank", value_name = "W")]
    pub world_size: Option<u32>,
    // What to do on a bad record; default `skip`.
    #[arg(long, value_enum, default_value_t = OnErrorPolicy::Skip, value_name = "POLICY")]
    pub on_error: OnErrorPolicy,
    // Maximum line length in bytes. Accepts human sizes via `parse_bytes`;
    // default 16MiB.
    #[arg(long, default_value = "16MiB", value_parser = parse_bytes, value_name = "BYTES")]
    pub max_line: u64,
    // Default 100_000. Presumably the window for `--shuffle buffer` (TODO confirm).
    #[arg(long, default_value_t = 100_000, value_name = "K")]
    pub buffer_size: u64,
    // Default 0. Presumably 0 means "choose automatically" (TODO confirm).
    #[arg(long, default_value_t = 0, value_name = "N")]
    pub build_threads: usize,
    // Default 1.
    #[arg(long, default_value_t = 1, value_name = "N")]
    pub emit_threads: usize,
    // Default 32.
    #[arg(long, default_value_t = 32, value_name = "K")]
    pub emit_prefetch: usize,
    // Default 10_000. Presumably the reservoir capacity for
    // `--shuffle reservoir` (TODO confirm).
    #[arg(long, default_value_t = 10_000, value_name = "K")]
    pub reservoir_size: u64,
    // Progress display: never / auto / always; default auto.
    #[arg(long, value_enum, default_value_t = When::Auto, value_name = "WHEN")]
    pub progress: When,
    // Log level string; also taken from `SHUFLR_LOG`. Returned by
    // `Command::log_level`.
    #[arg(long, env = "SHUFLR_LOG", default_value = "info", value_name = "LEVEL")]
    pub log_level: String,
}
// Options for `shuflr serve` (only built with the `serve` feature).
// (Plain `//` comments: `///` here would become per-flag help text.)
#[cfg(feature = "serve")]
#[derive(Args, Debug)]
pub struct ServeArgs {
    // Listen address for the HTTP listener; `None` when the flag is absent.
    #[arg(long, value_name = "ADDR")]
    pub http: Option<String>,
    // Listen address for the shuflr-wire listener.
    #[arg(long, value_name = "ADDR")]
    pub wire: Option<String>,
    // Listen address for gRPC (only with the `grpc` feature).
    #[cfg(feature = "grpc")]
    #[arg(long, value_name = "ADDR")]
    pub grpc: Option<String>,
    // Repeatable `--dataset ID=PATH`. Kept as raw strings here; the split on
    // `=` presumably happens in the serve code (confirm).
    #[arg(long = "dataset", value_name = "ID=PATH")]
    pub datasets: Vec<String>,
    // Boolean opt-ins. Presumably they gate binding to non-loopback addresses
    // (TODO confirm against the serve code).
    #[arg(long)]
    pub bind_public: bool,
    #[arg(long)]
    pub insecure_public: bool,
    // TLS certificate, key, and (optional) client-CA bundle paths.
    #[arg(long, value_name = "PATH")]
    pub tls_cert: Option<std::path::PathBuf>,
    #[arg(long, value_name = "PATH")]
    pub tls_key: Option<std::path::PathBuf>,
    #[arg(long, value_name = "PATH")]
    pub tls_client_ca: Option<std::path::PathBuf>,
    // Authentication kind; default `none`.
    #[arg(long, value_enum, default_value_t = AuthKind::None, value_name = "KIND")]
    pub auth: AuthKind,
    // Token file path. Presumably used with `--auth bearer` (confirm).
    #[arg(long, value_name = "PATH")]
    pub auth_tokens: Option<std::path::PathBuf>,
    // Log level string; also taken from `SHUFLR_LOG`.
    #[arg(long, env = "SHUFLR_LOG", default_value = "info", value_name = "LEVEL")]
    pub log_level: String,
}
// `--auth` choices for `serve`: none / bearer / mtls (clap value names are the
// lower-cased variant names).
// (Plain `//` comments: `///` on ValueEnum variants becomes possible-value help.)
#[cfg(feature = "serve")]
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq, Eq)]
pub enum AuthKind {
    None,
    Bearer,
    Mtls,
}
// Options for `shuflr convert` (only built with the `zstd` feature).
// (Plain `//` comments: `///` here would become per-flag help text.)
#[cfg(feature = "zstd")]
#[derive(Args, Debug)]
pub struct ConvertArgs {
    #[command(flatten)]
    pub input: InputArgs,
    // `-o/--output`: required, because the field is not an `Option`.
    #[arg(short = 'o', long, value_name = "PATH")]
    pub output: PathBuf,
    // `-l/--level`: compression level, restricted to 1..=22 by the value
    // parser; default 3.
    #[arg(short = 'l', long, default_value_t = 3, value_parser = clap::value_parser!(u32).range(1..=22))]
    pub level: u32,
    // `-f/--frame-size`: human byte size via `parse_bytes`; default 2MiB.
    #[arg(short = 'f', long, default_value = "2MiB", value_parser = parse_bytes, value_name = "BYTES")]
    pub frame_size: u64,
    // `-T/--threads`: default 0. Presumably 0 means automatic (TODO confirm).
    #[arg(short = 'T', long, default_value_t = 0, value_name = "N")]
    pub threads: u32,
    // Input format; default `auto`.
    #[arg(long, value_enum, default_value_t = InputFormat::Auto)]
    pub input_format: InputFormat,
    // Boolean switches; their exact effect is defined by the convert code.
    #[arg(long)]
    pub no_checksum: bool,
    #[arg(long)]
    pub no_record_align: bool,
    #[arg(long)]
    pub verify: bool,
    // `-n/--limit`: optional count. Presumably the max number of records
    // (confirm).
    #[arg(short = 'n', long, value_name = "N")]
    pub limit: Option<u64>,
    // Validated by `parse_probability` to lie in [0.0, 1.0].
    #[arg(long, value_name = "P", value_parser = parse_probability)]
    pub sample_rate: Option<f64>,
    // Entropy bounds in bits, each validated by `parse_entropy_bits` to [0, 8].
    #[arg(long, value_name = "BITS", value_parser = parse_entropy_bits)]
    pub min_entropy: Option<f64>,
    #[arg(long, value_name = "BITS", value_parser = parse_entropy_bits)]
    pub max_entropy: Option<f64>,
    // Seed value; also taken from `SHUFLR_SEED`.
    #[arg(long, env = "SHUFLR_SEED", value_name = "U64")]
    pub seed: Option<u64>,
    // Comma-separated column list (only with the `parquet` feature).
    #[cfg(feature = "parquet")]
    #[arg(long, value_name = "COL1,COL2,...", value_delimiter = ',')]
    pub parquet_project: Option<Vec<String>>,
    // Progress display; default auto.
    #[arg(long, value_enum, default_value_t = When::Auto, value_name = "WHEN")]
    pub progress: When,
    // Log level string; also taken from `SHUFLR_LOG`.
    #[arg(long, env = "SHUFLR_LOG", default_value = "info", value_name = "LEVEL")]
    pub log_level: String,
}
// Options for `shuflr info` (only with the `zstd` feature): one required FILE
// positional plus a `--json` switch.
#[cfg(feature = "zstd")]
#[derive(Args, Debug)]
pub struct InfoArgs {
    #[arg(value_name = "FILE")]
    pub input: PathBuf,
    #[arg(long)]
    pub json: bool,
}
// Options for `shuflr analyze`: shared inputs, `--sample-chunks` (default 32),
// and `--strict` / `--json` switches.
#[derive(Args, Debug)]
pub struct AnalyzeArgs {
    #[command(flatten)]
    pub input: InputArgs,
    #[arg(long, default_value_t = 32, value_name = "N")]
    pub sample_chunks: u32,
    #[arg(long)]
    pub strict: bool,
    #[arg(long)]
    pub json: bool,
}
// Options for `shuflr index`: shared inputs, optional `-o/--output` path, and
// `--threads` (default 0; presumably 0 means automatic — confirm).
#[derive(Args, Debug)]
pub struct IndexArgs {
    #[command(flatten)]
    pub input: InputArgs,
    #[arg(short = 'o', long, value_name = "PATH")]
    pub output: Option<PathBuf>,
    #[arg(long, default_value_t = 0, value_name = "N")]
    pub threads: usize,
}
// Options for `shuflr verify`: shared inputs plus a `--deep` switch.
#[derive(Args, Debug)]
pub struct VerifyArgs {
    #[command(flatten)]
    pub input: InputArgs,
    #[arg(long)]
    pub deep: bool,
}
// Options for `shuflr completions`: the target shell, as a required positional
// whose values come from `clap_complete::Shell`.
#[derive(Args, Debug)]
pub struct CompletionsArgs {
    #[arg(value_enum)]
    pub shell: Shell,
}
// Options for `shuflr man`: an optional SUBCOMMAND positional.
#[derive(Args, Debug)]
pub struct ManArgs {
    #[arg(value_name = "SUBCOMMAND")]
    pub subcommand: Option<String>,
}
// `--shuffle` strategies. clap exposes them in kebab-case: none, chunk-rr,
// chunk-shuffled, index-perm, buffer, reservoir. `StreamArgs` defaults to
// `ChunkShuffled`.
// (Plain `//` comments: `///` on ValueEnum variants becomes possible-value help.)
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq, Eq)]
pub enum ShuffleMode {
    None,
    ChunkRr,
    ChunkShuffled,
    IndexPerm,
    Buffer,
    Reservoir,
}
// `--on-error` choices (skip / fail / passthrough). They map one-to-one onto
// `shuflr::OnError` through the `From` impl that follows this enum.
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq, Eq)]
pub enum OnErrorPolicy {
    Skip,
    Fail,
    Passthrough,
}
impl From<OnErrorPolicy> for shuflr::OnError {
    /// Translate the CLI `--on-error` choice into the library's policy,
    /// variant for variant.
    fn from(policy: OnErrorPolicy) -> Self {
        match policy {
            OnErrorPolicy::Passthrough => Self::Passthrough,
            OnErrorPolicy::Fail => Self::Fail,
            OnErrorPolicy::Skip => Self::Skip,
        }
    }
}
// Tri-state switch used by `--progress` (never / auto / always).
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq, Eq)]
pub enum When {
    Never,
    Auto,
    Always,
}
// `--input-format` choices for `convert`: auto / plain / gzip / zstd, plus
// `bz2` with the `bzip2` feature and `xz` with the `xz` feature.
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq, Eq)]
pub enum InputFormat {
    Auto,
    Plain,
    Gzip,
    Zstd,
    #[cfg(feature = "bzip2")]
    Bz2,
    #[cfg(feature = "xz")]
    Xz,
}
/// Clap value parser for entropy thresholds, expressed in bits.
///
/// Accepts any `f64` in the closed interval `[0, 8]` (8 bits is the maximum
/// entropy of a byte distribution). NaN and infinities are not inside the
/// interval, so they are rejected too.
///
/// # Errors
/// Returns a message if `raw` is not a float or falls outside `[0, 8]`.
fn parse_entropy_bits(raw: &str) -> std::result::Result<f64, String> {
    match raw.parse::<f64>() {
        Err(e) => Err(format!("invalid entropy '{raw}': {e}")),
        Ok(bits) if (0.0..=8.0).contains(&bits) => Ok(bits),
        Ok(h) => Err(format!(
            "entropy {h} bits is outside [0, 8] (max entropy for a byte is 8 bits)"
        )),
    }
}
/// Clap value parser for a probability such as `--sample-rate`.
///
/// Accepts any `f64` in the closed interval `[0.0, 1.0]`. NaN is not inside
/// the interval and is rejected.
///
/// # Errors
/// Returns a message if `raw` is not a float or lies outside `[0.0, 1.0]`.
fn parse_probability(raw: &str) -> std::result::Result<f64, String> {
    let p = raw
        .parse::<f64>()
        .map_err(|e| format!("invalid probability '{raw}': {e}"))?;
    (0.0..=1.0)
        .contains(&p)
        .then_some(p)
        .ok_or_else(|| format!("probability {p} is outside [0.0, 1.0] (pass e.g. 0.01 for 1%)"))
}
/// Parse a human-readable byte count such as `1024`, `16KiB`, `2MiB` or `1.5GB`.
///
/// After trimming, the input is split at the first character that is neither an
/// ASCII digit nor `.`. The leading part is the number; the remainder (trimmed,
/// case-insensitive) is the unit:
/// - no suffix or `b` means bytes;
/// - `k`/`kb`, `m`/`mb`, `g`/`gb`, `t`/`tb` are powers of 1000;
/// - `ki`/`kib`, `mi`/`mib`, `gi`/`gib`, `ti`/`tib` are powers of 1024.
///
/// Integer inputs are computed exactly with checked `u64` arithmetic. Fractional
/// inputs (e.g. `1.5MiB`) are computed in `f64` and truncated toward zero.
///
/// # Errors
/// Returns a message when the number is malformed (including an empty number or
/// a leading sign), the suffix is unknown, or the result does not fit in a `u64`.
/// Previously, out-of-range results silently saturated to `u64::MAX`.
pub(crate) fn parse_bytes(raw: &str) -> std::result::Result<u64, String> {
    let raw = raw.trim();
    let (num, suffix) = raw
        .find(|c: char| !c.is_ascii_digit() && c != '.')
        .map(|i| raw.split_at(i))
        .unwrap_or((raw, ""));
    // Validate the number before the suffix so the error precedence matches the
    // documented order (bad number reported first).
    let value: f64 = num
        .parse()
        .map_err(|e| format!("invalid number '{num}': {e}"))?;
    let mult: u64 = match suffix.trim().to_ascii_lowercase().as_str() {
        "" | "b" => 1,
        "k" | "kb" => 1_000,
        "ki" | "kib" => 1 << 10,
        "m" | "mb" => 1_000_000,
        "mi" | "mib" => 1 << 20,
        "g" | "gb" => 1_000_000_000,
        "gi" | "gib" => 1 << 30,
        "t" | "tb" => 1_000_000_000_000,
        "ti" | "tib" => 1 << 40,
        other => return Err(format!("unknown byte suffix '{other}' (try KiB, MiB, GiB)")),
    };
    let overflow = || format!("byte count '{raw}' does not fit in 64 bits");
    // Exact path for whole numbers: f64 cannot represent every integer above 2^53.
    if let Ok(whole) = num.parse::<u64>() {
        return whole.checked_mul(mult).ok_or_else(overflow);
    }
    // Fractional (or too-long-for-u64) input. `u64::MAX as f64` rounds up to 2^64,
    // so a strict `<` rejects every value that `as u64` would saturate.
    let bytes = value * mult as f64;
    if bytes < u64::MAX as f64 {
        Ok(bytes as u64)
    } else {
        Err(overflow())
    }
}
/// Parse the process arguments into a [`Cli`].
///
/// argv is first passed through `rewrite_implicit_stream`, so `shuflr file.jsonl`
/// is accepted as shorthand for `shuflr stream file.jsonl`. Parsing uses clap's
/// `parse_from`, which prints help/version or an error and exits the process
/// when appropriate.
pub fn parse() -> Cli {
    let argv = rewrite_implicit_stream(std::env::args_os().collect());
    Cli::parse_from(argv)
}
/// Rewrite argv so that `shuflr file.jsonl ...` is parsed as `shuflr stream file.jsonl ...`.
///
/// `argv[0]` is the program name. Leading top-level flags (`-h`, `--help`,
/// `--help-full`, `-V`, `--version`) are skipped, and the first remaining argument
/// decides the outcome:
/// - none remains (argv is empty or holds only top-level flags): return argv
///   unchanged, so clap prints the top-level help/version;
/// - it is a name in [`SUBCOMMAND_NAMES`]: return argv unchanged;
/// - anything else: insert `"stream"` at index 1.
///
/// Only that first argument is inspected. A flag value or input path that happens
/// to equal a subcommand name (e.g. `--log-level info`, or a file named `index`)
/// must not be mistaken for an explicit subcommand. Previously, any such argument
/// anywhere in argv suppressed the rewrite and made clap reject the command line.
fn rewrite_implicit_stream(mut argv: Vec<std::ffi::OsString>) -> Vec<std::ffi::OsString> {
    // Flags clap accepts before a subcommand; they don't select one themselves.
    fn is_top_level_flag(arg: &str) -> bool {
        matches!(arg, "--help" | "-h" | "--help-full" | "--version" | "-V")
    }
    let needs_stream = argv
        .iter()
        .skip(1)
        .map(|a| a.to_string_lossy())
        .find(|a| !is_top_level_flag(a))
        .is_some_and(|first| !SUBCOMMAND_NAMES.contains(&first.as_ref()));
    if needs_stream {
        argv.insert(1, std::ffi::OsString::from("stream"));
    }
    argv
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;

    /// Build an owned argv vector from string literals.
    fn argv(args: &[&str]) -> Vec<OsString> {
        args.iter().copied().map(OsString::from).collect()
    }

    #[test]
    fn implicit_stream_bare_file() {
        let rewritten = rewrite_implicit_stream(argv(&["shuflr", "data.jsonl"]));
        assert_eq!(rewritten, argv(&["shuflr", "stream", "data.jsonl"]));
    }

    #[test]
    fn explicit_subcommand_preserved() {
        let original = argv(&["shuflr", "verify", "x.jsonl"]);
        assert_eq!(rewrite_implicit_stream(original.clone()), original);
    }

    #[test]
    fn top_level_flag_preserved() {
        let original = argv(&["shuflr", "--version"]);
        assert_eq!(rewrite_implicit_stream(original.clone()), original);
    }

    #[test]
    fn implicit_stream_with_leading_flag() {
        let rewritten =
            rewrite_implicit_stream(argv(&["shuflr", "--shuffle", "none", "data.jsonl"]));
        assert_eq!(
            rewritten,
            argv(&["shuflr", "stream", "--shuffle", "none", "data.jsonl"])
        );
    }

    #[test]
    fn parse_bytes_variants() {
        let exact: [(&str, u64); 4] = [
            ("1024", 1024),
            ("16KiB", 16 << 10),
            ("2MiB", 2 << 20),
            ("1GB", 1_000_000_000),
        ];
        for (raw, expected) in exact {
            assert_eq!(parse_bytes(raw).unwrap(), expected, "parsing {raw:?}");
        }
        let fractional = (1.5 * (1 << 20) as f64) as u64;
        assert_eq!(parse_bytes("1.5MiB").unwrap(), fractional);
        assert!(parse_bytes("2zz").is_err());
    }
}