use std::ffi::OsString;
use std::io::{self, stderr, stdout, Write};
use std::path::PathBuf;
use std::process::{Command, exit, Stdio};
use std::time::Duration;
use anyhow::{Context, Result};
use clap::error::{ContextKind, ContextValue, ErrorKind};
use clap::Parser;
use bkt::{CommandDesc, Bkt};
/// A [`Write`] adapter that silently discards output once the reader has gone away.
///
/// Writes and flushes on the wrapped writer that fail with
/// [`io::ErrorKind::BrokenPipe`] are reported as successful, so a consumer closing its
/// end of the pipe early is not surfaced as an error. Every other result (data written,
/// other errors, short writes) is passed through unchanged.
struct DisregardBrokenPipe(Box<dyn Write + Send>);

impl Write for DisregardBrokenPipe {
    /// Writes `buf` to the inner writer. On `BrokenPipe` the whole buffer is reported
    /// as written (and dropped).
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.0.write(buf) {
            // Claim the full length rather than `Ok(0)`: to `io::Write` callers `Ok(0)`
            // means "cannot accept more data", which the default `write_all`/`write_fmt`
            // turn into a `WriteZero` error and `write_vectored` reports as a 0-byte write.
            Err(e) if e.kind() == io::ErrorKind::BrokenPipe => Ok(buf.len()),
            r => r,
        }
    }

    // `write_all` is intentionally not overridden: the standard implementation retries on
    // `Interrupted`, succeeds once `write` has accepted everything (including discarded
    // broken-pipe data), and still reports a genuine `Ok(0)` from the inner writer as
    // `WriteZero` instead of silently dropping the remaining bytes.

    /// Flushes the inner writer, treating `BrokenPipe` as success.
    fn flush(&mut self) -> io::Result<()> {
        match self.0.flush() {
            Err(e) if e.kind() == io::ErrorKind::BrokenPipe => Ok(()),
            r => r,
        }
    }
}
/// Re-invokes this executable in the background with `--force` prepended, so the cache
/// is refreshed without blocking the caller.
///
/// The child receives this process's arguments (minus argv\[0\]) after `--force`. A
/// `--warm` flag given to bkt itself is dropped, since it conflicts with `--force`. The
/// child's stdout and stderr are discarded, and the spawned process is not waited on.
///
/// # Errors
/// Returns an error if the background process cannot be spawned.
fn force_update_async() -> Result<()> {
    let mut args = std::env::args_os();
    let arg0 = args.next().expect("Must always be a 0th argument");
    // Prefer the resolved executable path; fall back to argv[0] if it can't be determined.
    let mut command = match std::env::current_exe() {
        Ok(path) => Command::new(path),
        Err(_) => Command::new(arg0),
    };
    // Only strip `--warm` from bkt's own flags. Everything after the `--` separator is the
    // cached command and must be forwarded verbatim. Otherwise a command that itself takes
    // `--warm` would be altered, and the refresh would run a different command than the
    // one the caller asked to warm.
    let mut in_command = false;
    let forwarded = args.filter(|arg| {
        if in_command {
            return true;
        }
        if arg == "--" {
            in_command = true;
            return true;
        }
        arg != "--warm"
    });
    command
        .arg("--force")
        .args(forwarded)
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .spawn()
        .context("Failed to start background process")?;
    Ok(())
}
fn run(cli: Cli) -> Result<i32> {
let ttl: Duration = cli.ttl.into();
let stale: Option<Duration> = cli.stale.map(Into::into);
assert!(!ttl.is_zero(), "--ttl cannot be zero");
if let Some(stale) = stale {
assert!(!stale.is_zero(), "--stale cannot be zero");
assert!(stale < ttl, "--stale must be less than --ttl");
}
let mut bkt = match cli.cache_dir {
Some(cache_dir) => Bkt::create(cache_dir)?,
None => Bkt::in_tmp()?,
};
if let Some(scope) = cli.scope {
bkt = bkt.scoped(scope);
}
let mut command = CommandDesc::new(cli.command);
if cli.cwd {
command = command.with_cwd();
}
let envs = cli.env;
if !envs.is_empty() {
command = command.with_envs(&envs);
}
let files = cli.modtime;
if !files.is_empty() {
command = command.with_modtimes(&files);
}
if cli.discard_failures {
command = command.with_discard_failures(true);
}
if cli.warm && !cli.force {
force_update_async()?;
return Ok(0);
}
let invocation = if cli.force {
bkt.refresh_streaming(&command, ttl, DisregardBrokenPipe(
Box::new(stdout())), DisregardBrokenPipe(Box::new(stderr())))?.0
} else {
let (invocation, status) = bkt.retrieve_streaming(
&command, ttl, DisregardBrokenPipe(Box::new(stdout())), DisregardBrokenPipe(Box::new(stderr())))?;
if let Some(stale) = stale {
if let bkt::CacheStatus::Hit(cached_at) = status {
if cached_at.elapsed().unwrap_or(Duration::MAX) > stale {
force_update_async()?;
}
}
}
invocation
};
Ok(invocation.exit_code())
}
// Command-line interface, parsed with clap's derive API.
// NOTE: plain `//` comments are used here on purpose. clap's derive turns `///` doc
// comments into `--help` text, so adding them would change the program's output.
#[derive(Debug, Parser)]
#[command(about, version)]
struct Cli {
// The command to execute and cache, with its arguments. Required; `last = true` means it
// must be given after a `--` separator.
#[arg(required = true, last = true)]
command: Vec<OsString>,
// How long a cached result stays valid (e.g. `1m`). `run` rejects zero. Can also be set
// via `BKT_TTL`; `--time-to-live` is a visible alias.
#[arg(long, value_name = "DURATION", visible_alias = "time-to-live", env = "BKT_TTL")]
ttl: humantime::Duration,
// Optional age after which a cache hit also kicks off a background `--force` refresh (the
// cached output is still returned). `run` requires it to be non-zero and less than
// `--ttl`. Conflicts with `--warm`.
#[arg(long, value_name = "DURATION", conflicts_with = "warm")]
stale: Option<humantime::Duration>,
// Start a background `--force` refresh and exit immediately with status 0, printing
// nothing.
#[arg(long)]
warm: bool,
// Re-run the command and refresh the cache regardless of any cached result. Conflicts
// with `--warm`.
#[arg(long, conflicts_with = "warm")]
force: bool,
// Enables `CommandDesc::with_cwd`; presumably makes the working directory part of the
// cache key. TODO confirm in bkt.
#[arg(long, visible_alias = "use-working-dir")]
cwd: bool,
// Environment variable names passed to `CommandDesc::with_envs` (flag may be repeated).
// Presumably their values become part of the cache key. TODO confirm in bkt.
#[arg(long, value_name = "NAME", visible_alias = "use-environment")]
env: Vec<OsString>,
// Files passed to `CommandDesc::with_modtimes` (flag may be repeated). Presumably their
// modification times become part of the cache key. TODO confirm in bkt.
#[arg(long, value_name = "FILE", visible_alias = "use-file-modtime")]
modtime: Vec<OsString>,
// Enables `CommandDesc::with_discard_failures(true)`; presumably failed invocations are
// not cached. TODO confirm in bkt.
#[arg(long)]
discard_failures: bool,
// Optional scope applied via `Bkt::scoped`; can also be set via `BKT_SCOPE`.
// Presumably separates cache entries by scope. TODO confirm in bkt.
#[arg(long, value_name = "NAME", env = "BKT_SCOPE")]
scope: Option<String>,
// Cache directory passed to `Bkt::create`; when absent `Bkt::in_tmp()` is used. Can also
// be set via `BKT_CACHE_DIR`.
#[arg(long, value_name = "DIR", env = "BKT_CACHE_DIR")]
cache_dir: Option<PathBuf>,
}
/// Entry point: parses the command line, runs it, and exits with the resulting code.
///
/// On a parse error, clap's message is printed and the process exits via
/// `clap::Error::exit`. If `--ttl` is among the missing required arguments, a hint about
/// the pre-0.8.0 default is appended first. Errors returned by `run` are printed as
/// `bkt: <error>` (including their causes, via `{:#}`) and the process exits with 127.
fn main() {
    let cli = match Cli::try_parse() {
        Ok(cli) => cli,
        Err(mut err) => {
            // A missing `<COMMAND>` raises the same error kind as a missing `--ttl`, so
            // only show the hint when `--ttl` is actually one of the missing arguments.
            // If clap does not report which arguments are missing, keep showing it.
            let ttl_missing = matches!(err.kind(), ErrorKind::MissingRequiredArgument)
                && match err.get(ContextKind::InvalidArg) {
                    Some(ContextValue::Strings(args)) => {
                        args.iter().any(|arg| arg.contains("--ttl"))
                    }
                    Some(ContextValue::String(arg)) => arg.contains("--ttl"),
                    _ => true,
                };
            if ttl_missing {
                err.insert(ContextKind::Suggested, ContextValue::StyledStrs(vec![[
                    "Prior to 0.8.0 --ttl was optional, and defaulted to 60 seconds.",
                    "To preserve this behavior pass `--ttl=1m` or set `BKT_TTL=1m` in your environment."
                ].join(" ").into()]));
            }
            err.exit();
        }
    };
    match run(cli) {
        Ok(code) => exit(code),
        Err(msg) => {
            eprintln!("bkt: {:#}", msg);
            exit(127);
        }
    }
}