pub mod buffer;
pub mod error;
pub use error::Error;
use std::path::PathBuf;
/// Destination that receives the soaked-up input after the reader has been
/// fully drained.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Target {
    /// Emit the buffered bytes on the process's standard output.
    Stdout,
    /// Emit the buffered bytes into the file at this path.
    File(PathBuf),
}
/// Selects between normal behaviour and the strict compatibility path.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompatibilityMode {
    /// Normal behaviour; this is what `CompatibilityMode::default()` yields.
    #[default]
    Default,
    /// Strict compatibility. The CLI entry point delegates to its strict
    /// runner when this is selected early, and `SpongeBuilder::build`
    /// rejects a non-default spill threshold in this mode.
    Strict,
}
/// A fully configured sponge: reads an entire input stream first, then
/// writes it to its [`Target`]. Obtain one via `SpongeBuilder::build`.
#[derive(Debug)]
#[non_exhaustive]
pub struct Sponge {
    /// Where the buffered output goes.
    target: Target,
    /// Forwarded to the file writers to request appending instead of
    /// replacing the target's contents.
    append: bool,
    /// Byte threshold handed to the buffer while draining the reader.
    spill_threshold: usize,
    // Recorded at build time but not read by `Sponge::run`, hence the lint
    // allowance. NOTE(review): confirm whether it is meant to affect writing.
    #[allow(dead_code)]
    compat: CompatibilityMode,
}
/// Default byte threshold passed to the buffer when draining input: 128 MiB.
/// Presumably the amount held in memory before spilling to the spill
/// directory — confirm against `buffer::Buffer::drain_reader`.
pub const DEFAULT_SPILL_THRESHOLD: usize = 128 << 20;
/// Step-by-step configuration for a [`Sponge`]; validated by `build`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SpongeBuilder {
    /// Destination for the output; starts as standard output.
    target: Target,
    /// Append-instead-of-replace request; only valid with a file target.
    append: bool,
    /// Buffer spill threshold in bytes; starts at `DEFAULT_SPILL_THRESHOLD`.
    spill_threshold: usize,
    /// Compatibility mode the resulting sponge is built for.
    compat: CompatibilityMode,
}
impl Default for SpongeBuilder {
fn default() -> Self {
Self::new()
}
}
impl SpongeBuilder {
#[must_use]
pub fn new() -> Self {
Self {
target: Target::Stdout,
append: false,
spill_threshold: DEFAULT_SPILL_THRESHOLD,
compat: CompatibilityMode::Default,
}
}
#[must_use]
pub fn target(mut self, target: Target) -> Self {
self.target = target;
self
}
#[must_use]
pub fn append(mut self, append: bool) -> Self {
self.append = append;
self
}
#[must_use]
pub fn spill_threshold(mut self, bytes: usize) -> Self {
self.spill_threshold = bytes;
self
}
#[must_use]
pub fn compat(mut self, compat: CompatibilityMode) -> Self {
self.compat = compat;
self
}
pub fn build(self) -> Result<Sponge, Error> {
if self.append && matches!(self.target, Target::Stdout) {
return Err(Error::InvalidBuilderConfiguration(
"append requires a file target",
));
}
if self.compat == CompatibilityMode::Strict
&& self.spill_threshold != DEFAULT_SPILL_THRESHOLD
{
return Err(Error::CompatibilityViolation(
"explicit spill threshold not honored in Strict mode",
));
}
Ok(Sponge {
target: self.target,
append: self.append,
spill_threshold: self.spill_threshold,
compat: self.compat,
})
}
}
impl Sponge {
pub fn run<R: std::io::Read>(&mut self, reader: R) -> Result<(), Error> {
match &self.target {
Target::Stdout => {
let mut buf = buffer::Buffer::new();
let spill_dir = std::env::temp_dir();
buf.drain_reader(reader, self.spill_threshold, &spill_dir)?;
let stdout = std::io::stdout();
let mut locked = stdout.lock();
buf.write_to(&mut locked)?;
Ok(())
}
Target::File(path) => {
validate_target_path(path)?;
let spill_dir = path
.parent()
.filter(|p| !p.as_os_str().is_empty())
.map(std::path::PathBuf::from)
.unwrap_or_else(|| std::path::PathBuf::from("."));
let mut buf = buffer::Buffer::new();
buf.drain_reader(reader, self.spill_threshold, &spill_dir)?;
if writethrough::requires_write_through(path) {
writethrough::write_through(buf, path, self.append)?;
} else {
atomic::write_atomic(buf, path, self.append)?;
}
Ok(())
}
}
}
}
pub mod atomic;
pub mod writethrough;
fn validate_target_path(target: &std::path::Path) -> Result<(), Error> {
if let Ok(meta) = std::fs::symlink_metadata(target) {
if meta.is_dir() {
return Err(Error::TargetIsDirectory(target.to_path_buf()));
}
}
Ok(())
}
#[cfg(feature = "cli")]
pub mod cli;
#[cfg(feature = "cli")]
pub mod mode;
#[cfg(feature = "cli")]
pub mod signal;
#[cfg(feature = "cli")]
pub mod strict;
#[cfg(feature = "cli")]
/// CLI entry point: resolves the compatibility mode, parses arguments, and
/// runs a [`Sponge`] over stdin.
///
/// Exit codes: 0 on success (and for `--help`/`--version`/completions),
/// 2 on argument errors, 130 when the run is interrupted, 1 on any other
/// error. Strict mode is delegated entirely to `strict::run`.
pub fn run() -> std::process::ExitCode {
    use clap::Parser;
    use std::process::ExitCode;
    if let Err(e) = signal::install_handlers() {
        eprintln!("warning: could not install signal handlers: {e}");
    }
    // Snapshot argv and the strict-mode env var once. The early mode check,
    // clap parsing, and the final mode resolution all use the same inputs.
    let raw_argv: Vec<std::ffi::OsString> = std::env::args_os().collect();
    let pre_strict = strict::pre_scan_strict_flag(&raw_argv);
    let env_strict = std::env::var_os("RUSTY_SPONGE_STRICT");
    let argv0 = raw_argv.first().cloned();
    let early_mode = mode::resolve(pre_strict, env_strict.as_deref(), argv0.as_deref());
    if early_mode == CompatibilityMode::Strict {
        return strict::run(&raw_argv);
    }
    let cli_args = match cli::Cli::try_parse_from(&raw_argv) {
        Ok(args) => args,
        Err(e) => {
            e.print().ok();
            return match e.kind() {
                clap::error::ErrorKind::DisplayHelp | clap::error::ErrorKind::DisplayVersion => {
                    ExitCode::SUCCESS
                }
                _ => ExitCode::from(2),
            };
        }
    };
    if let Some(cli::Subcommand::Completions { shell }) = cli_args.command {
        use clap::CommandFactory;
        let mut cmd = cli::Cli::command();
        let name = cmd.get_name().to_string();
        clap_complete::generate(shell, &mut cmd, name, &mut std::io::stdout());
        return ExitCode::SUCCESS;
    }
    // Re-resolve with the flag as clap parsed it; env and argv0 come from the
    // snapshot taken above.
    let compat = mode::resolve(
        cli::strict_flag(&cli_args),
        env_strict.as_deref(),
        argv0.as_deref(),
    );
    let spill_threshold = resolve_spill_threshold(&cli_args, compat);
    let target = match cli_args.target {
        Some(path) => Target::File(path),
        None => Target::Stdout,
    };
    let result = SpongeBuilder::new()
        .target(target)
        .append(cli_args.append)
        .spill_threshold(spill_threshold)
        .compat(compat)
        .build();
    let mut sponge = match result {
        Ok(s) => s,
        Err(e) => {
            eprintln!("rusty-sponge: {e}");
            return ExitCode::from(1);
        }
    };
    let stdin = std::io::stdin();
    let locked = stdin.lock();
    match sponge.run(locked) {
        Ok(()) => ExitCode::SUCCESS,
        Err(Error::Io(io_err)) if io_err.kind() == std::io::ErrorKind::Interrupted => {
            eprintln!("rusty-sponge: cancelled");
            ExitCode::from(130)
        }
        Err(e) => {
            eprintln!("rusty-sponge: {e}");
            ExitCode::from(1)
        }
    }
}
#[cfg(feature = "cli")]
/// Converts the user-supplied spill size (whole MiB) into a byte threshold.
///
/// Strict mode always uses `DEFAULT_SPILL_THRESHOLD`, as does an absent
/// value. A value that does not parse as a positive integer (after trimming
/// whitespace) prints a warning and falls back to the default. Large values
/// saturate at `usize::MAX` instead of overflowing.
fn resolve_spill_threshold(cli_args: &cli::Cli, compat: CompatibilityMode) -> usize {
    const MIB: usize = 1024 * 1024;
    if compat == CompatibilityMode::Strict {
        return DEFAULT_SPILL_THRESHOLD;
    }
    let Some(raw) = cli_args.spill_mb.as_deref() else {
        return DEFAULT_SPILL_THRESHOLD;
    };
    match raw.trim().parse::<usize>() {
        Ok(mb) if mb > 0 => mb.saturating_mul(MIB),
        _ => {
            // Derived from the constant so the message cannot drift from the
            // actual default.
            let default_mib = DEFAULT_SPILL_THRESHOLD / MIB;
            eprintln!(
                "warning: invalid RUSTY_SPONGE_SPILL_MB value '{raw}'; using default {default_mib} MiB"
            );
            DEFAULT_SPILL_THRESHOLD
        }
    }
}