use clap::Parser;
use std::io::Write;
use std::time::{Duration, Instant};
// Command-line options for `whirl`. The `///` comment on each field is rendered by
// clap as that option's `--help` text, so keep it short and user-facing.
#[derive(Parser, Debug)]
#[command(
    name = "whirl",
    about = "whirl — a tiny curl-like HTTP client powered by wrest"
)]
struct Args {
    /// URL to request
    url: String,

    /// HTTP method to use (overridden by -I/--head)
    #[arg(short = 'X', long = "request", default_value = "GET")]
    method: String,

    /// Extra request header as 'Name: Value' (may be repeated)
    #[arg(short = 'H', long = "header")]
    headers: Vec<String>,

    /// Request body sent as-is (takes precedence over --data-binary)
    #[arg(short = 'd', long = "data")]
    data: Option<String>,

    /// Path of a file whose contents are sent as the request body; a leading '@' is stripped
    #[arg(long = "data-binary")]
    data_binary: Option<String>,

    /// Write the response body to this file instead of stdout
    #[arg(short = 'o', long = "output")]
    output: Option<String>,

    /// No-op; only prints a notice that redirects are followed by default
    #[arg(short = 'L', long = "location")]
    follow_redirects: bool,

    /// Print the response status line and headers to stderr
    #[arg(short = 'v', long = "verbose")]
    verbose: bool,

    /// Hide the download progress line and the failure-status message
    #[arg(short = 's', long = "silent")]
    silent: bool,

    /// Send a HEAD request and print only the status line and headers
    #[arg(short = 'I', long = "head")]
    head_only: bool,

    /// Timeout passed to the HTTP client, in whole seconds
    #[arg(short = 'm', long = "max-time")]
    max_time: Option<u64>,

    /// Value for the User-Agent header
    #[arg(short = 'A', long = "user-agent")]
    user_agent: Option<String>,
}
/// Parses the command line, performs one HTTP request with `wrest`, and reports the
/// response.
///
/// Output:
/// - `-I/--head`: the status line and response headers are written to stdout; the body
///   is not read.
/// - `-v/--verbose`: the status line and headers are written to stderr, followed by a
///   blank line, before the body is handled.
/// - `-o/--output`: the body is streamed into the given file, with a progress line on
///   stderr unless `-s/--silent` is set.
/// - otherwise the body bytes are copied to stdout unchanged.
///
/// # Errors
/// Returns an error for an unknown method, a malformed `-H` value, an unreadable
/// `--data-binary` file, a client build or request failure, or a local I/O failure.
///
/// If the response status is not 1xx/2xx/3xx, the process exits with code 22. This
/// also applies to `-I` requests. A message is printed to stderr first unless
/// `--silent` is set.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args = Args::parse();
    if args.follow_redirects && !args.silent {
        eprintln!("whirl: -L is a no-op (wrest follows redirects by default)");
    }

    // `-I` forces HEAD regardless of `-X`.
    let method_str = if args.head_only { "HEAD" } else { &args.method };
    let method: http::Method = method_str
        .parse()
        .map_err(|_| format!("Unknown HTTP method: {method_str}"))?;

    let mut builder = wrest::Client::builder();
    if let Some(secs) = args.max_time {
        builder = builder.timeout(Duration::from_secs(secs));
    }
    if let Some(ref ua) = args.user_agent {
        builder = builder.user_agent(ua);
    }
    let client = builder.build()?;

    let mut req = client.request(method, &args.url);
    for h in &args.headers {
        let (key, value) = h
            .split_once(':')
            .ok_or_else(|| format!("Invalid header (expected 'Key: Value'): {h}"))?;
        req = req.header(key.trim(), value.trim());
    }
    // `-d` wins when both body options are given.
    if let Some(ref data) = args.data {
        req = req.body(data.clone());
    } else if let Some(ref path) = args.data_binary {
        // The value is always treated as a file path; a curl-style leading '@' is optional.
        let path = path.strip_prefix('@').unwrap_or(path);
        let body = std::fs::read(path)?;
        req = req.body(body);
    }

    let mut resp = req.send().await?;
    let status = resp.status();

    // Status line and headers, rendered once: `-I` prints them to stdout, `-v` to stderr.
    let mut head = format!("{:?} {status}\n", resp.version());
    for (name, value) in resp.headers().iter() {
        head += &format!("{name}: {}\n", value.to_str().unwrap_or("<binary>"));
    }

    if args.head_only {
        let stdout = std::io::stdout();
        let mut out = stdout.lock();
        out.write_all(head.as_bytes())?;
        out.flush()?;
    } else {
        if args.verbose {
            let stderr = std::io::stderr();
            let mut err = stderr.lock();
            err.write_all(head.as_bytes())?;
            writeln!(err)?;
            err.flush()?;
        }

        if let Some(ref path) = args.output {
            let mut file = std::io::BufWriter::new(std::fs::File::create(path)?);
            let total = resp.content_length();
            let mut downloaded: u64 = 0;
            let start = Instant::now();
            let show_progress = !args.silent;
            while let Some(chunk) = resp.chunk().await? {
                file.write_all(&chunk)?;
                downloaded += chunk.len() as u64;
                if show_progress {
                    print_progress(downloaded, total, start.elapsed());
                }
            }
            // Flush explicitly: BufWriter's Drop silently ignores a failed final write.
            file.flush()?;
            if show_progress {
                let elapsed = start.elapsed();
                // Blank out the progress line, then print the summary in its place.
                eprintln!(
                    "\r{:>70}\r Downloaded {downloaded} bytes to {path} in {elapsed:.1?}",
                    ""
                );
            }
        } else {
            // Write the raw bytes. The body may be binary or non-UTF-8, and lossy
            // decoding would replace such bytes with U+FFFD.
            let body = resp.bytes().await?;
            let stdout = std::io::stdout();
            let mut out = stdout.lock();
            out.write_all(&body)?;
            out.flush()?;
        }
    }

    if !status.is_success() && !status.is_informational() && !status.is_redirection() {
        if !args.silent {
            eprintln!("Request failed with status: {status}");
        }
        // All stdout output above was flushed explicitly, so exiting here loses nothing.
        std::process::exit(22);
    }
    Ok(())
}
/// Redraws the one-line download progress indicator on stderr.
///
/// `downloaded` is the number of bytes received so far and `elapsed` the time since
/// the transfer began. The average rate is `downloaded / elapsed`, or 0 when no time
/// has passed. When `total` is known, the line shows a percentage (100% if `total` is
/// 0) plus "received / total". Otherwise it shows only the received amount.
///
/// The line starts with `\r` and is padded to 70 columns, so it overwrites the
/// previous indicator in place.
fn print_progress(downloaded: u64, total: Option<u64>, elapsed: Duration) {
    let received = downloaded as f64;
    let seconds = elapsed.as_secs_f64();
    let rate = if seconds > 0.0 { received / seconds } else { 0.0 };
    let rate_label = format_bytes(rate);

    let status = match total {
        Some(expected) => {
            let percent = match expected {
                0 => 100.0,
                n => received / n as f64 * 100.0,
            };
            format!(
                " {percent:>5.1}% {} / {} {rate_label}/s",
                format_bytes(received),
                format_bytes(expected as f64),
            )
        }
        None => format!(" {} {rate_label}/s", format_bytes(received)),
    };

    eprint!("\r{status:<70}");
}
/// Renders a byte count (or byte rate) using binary units.
///
/// Values of at least 1 KiB are scaled to the largest of GiB/MiB/KiB they reach and
/// shown with one decimal place (e.g. `1.5 KiB`). Smaller values are shown as a whole
/// number of bytes (e.g. `512 B`).
fn format_bytes(bytes: f64) -> String {
    // Checked from largest to smallest; the first scale the value reaches is used.
    const SCALES: [(f64, &str); 3] = [
        (1024.0 * 1024.0 * 1024.0, "GiB"),
        (1024.0 * 1024.0, "MiB"),
        (1024.0, "KiB"),
    ];

    match SCALES.iter().find(|&&(scale, _)| bytes >= scale) {
        Some(&(scale, unit)) => format!("{:.1} {unit}", bytes / scale),
        None => format!("{bytes:.0} B"),
    }
}