use std::{
io::{self, BufRead, IsTerminal, Read, Write},
time::Duration,
};
use eyre::{ContextCompat, Result, WrapErr};
use hpx::ws::message::Message;
use crate::{
cli::Cli,
output::{format_json_pretty, is_terminal, write_body},
};
/// Delay used before the first reconnect attempt.
const BACKOFF_INITIAL: Duration = Duration::from_secs(1);
/// Upper bound on any single reconnect delay.
const BACKOFF_MAX: Duration = Duration::from_secs(30);

/// Exponential backoff delay for the zero-based reconnect `attempt`.
///
/// The delay is `BACKOFF_INITIAL * 2^attempt`, clamped to `BACKOFF_MAX`. The exponent is
/// capped at 30 so the doubling factor always fits and large attempt numbers simply
/// saturate at the maximum.
pub(crate) fn backoff_duration(attempt: u32) -> Duration {
    let doubling: u32 = 1 << attempt.min(30);
    let grown = BACKOFF_INITIAL.saturating_mul(doubling);
    if grown > BACKOFF_MAX {
        BACKOFF_MAX
    } else {
        grown
    }
}
/// Decides whether a failed session should be retried.
///
/// The error is rendered with its full cause chain (`{:#}`). If the text contains any of
/// the markers for an end-of-stream or a closed connection, the error is treated as final
/// and not retried. Every other error is considered reconnectable. Matching is
/// case-sensitive.
pub(crate) fn is_reconnectable_error(err: &eyre::Report) -> bool {
    const TERMINAL_MARKERS: [&str; 3] = ["EOF", "Connection closed", "closed by peer"];
    let rendered = format!("{err:#}");
    !TERMINAL_MARKERS
        .iter()
        .any(|marker| rendered.contains(marker))
}
/// Runs a WebSocket session against `url`, retrying failed sessions when `--reconnect` is set.
///
/// Without `--reconnect` the retry budget is zero, so the first error is returned as-is.
/// With it, an error is retried only while fewer than `cli.reconnect_max` retries have been
/// made and [`is_reconnectable_error`] accepts the error. Before each retry the loop sleeps
/// for [`backoff_duration`] and, unless `--silent`, prints a notice on stderr.
///
/// Note: the retry counter is never reset, so the budget is shared by all failures of this
/// invocation.
pub(crate) async fn execute(cli: &Cli, url: &str) -> Result<()> {
    let max_retries = if cli.reconnect { cli.reconnect_max } else { 0 };
    let mut attempt = 0;
    loop {
        let err = match try_execute(cli, url).await {
            Ok(()) => return Ok(()),
            Err(err) => err,
        };

        // Out of budget, or an error that signals a deliberate/clean shutdown.
        if attempt >= max_retries || !is_reconnectable_error(&err) {
            return Err(err);
        }

        let delay = backoff_duration(attempt);
        if !cli.silent {
            eprintln!(
                "Connection lost. Reconnecting in {delay:?}... (attempt {}/{max_retries})",
                attempt + 1
            );
        }
        tokio::time::sleep(delay).await;
        attempt += 1;
    }
}
async fn try_execute(cli: &Cli, url: &str) -> Result<()> {
let mut builder = hpx::websocket(url);
for (name, value) in &cli.parsed_headers() {
builder = builder.header(name.as_str(), value.as_str());
}
if let Some(ref user) = cli.bearer {
builder = builder.bearer_auth(user);
} else if let Some(ref basic) = cli.basic {
let (user, pass) = basic
.split_once(':')
.wrap_err("basic auth must be in USER:PASS format")?;
builder = builder.basic_auth(user, Some(pass));
}
let resp = builder
.send()
.await
.wrap_err("WebSocket handshake failed")?;
let mut ws = resp
.into_websocket()
.await
.wrap_err("Failed to upgrade to WebSocket")?;
if !cli.silent {
eprintln!("Connected to {url}");
}
if let Some(ref data) = cli.data {
let msg = load_data_payload(data)?;
ws.send(msg)
.await
.wrap_err("Failed to send initial message")?;
} else if let Some(ref json_data) = cli.json {
let msg = load_json_payload(json_data)?;
ws.send(msg)
.await
.wrap_err("Failed to send initial message")?;
}
let stdin = io::stdin();
let is_interactive = stdin.is_terminal() && cli.data.is_none() && cli.json.is_none();
if is_interactive {
run_interactive(&mut ws, cli).await?;
} else {
run_stdin_pump(&mut ws, cli).await?;
}
ws.close(hpx::ws::message::CloseCode::NORMAL, "")
.await
.wrap_err("Failed to close WebSocket")?;
if !cli.silent {
eprintln!("Connection closed");
}
Ok(())
}
/// Interactive mode: each non-empty line typed on stdin is sent as one text message.
///
/// Unless `--silent`, a banner and a `> ` prompt are printed on stderr. Trailing `\n` and
/// `\r` are stripped from each line, and blank lines are skipped. After every successful send,
/// [`drain_pending`] prints any replies that are already waiting. The loop ends on EOF
/// (Ctrl+D), a stdin read error, or a send error. Read and send errors are printed on stderr
/// rather than returned. The only propagated error is a failure to flush the prompt.
async fn run_interactive(ws: &mut hpx::ws::WebSocket, cli: &Cli) -> Result<()> {
    let chatty = !cli.silent;
    if chatty {
        eprintln!("Interactive mode. Type messages and press Enter. Ctrl+D to exit.");
    }

    let mut input = io::stdin().lock();
    let mut line = String::new();
    loop {
        if chatty {
            eprint!("> ");
            io::stderr().flush()?;
        }

        line.clear();
        match input.read_line(&mut line) {
            Ok(0) => break, // EOF
            Ok(_) => {}
            Err(e) => {
                eprintln!("Read error: {e}");
                break;
            }
        }

        let text = line.trim_end_matches('\n').trim_end_matches('\r');
        if text.is_empty() {
            continue;
        }

        if let Err(e) = ws.send(Message::text(text.to_string())).await {
            eprintln!("Send error: {e}");
            break;
        }
        drain_pending(ws, cli).await;
    }
    Ok(())
}
/// Non-interactive mode: reads all of stdin and sends it as a single binary message, then
/// prints incoming messages via [`receive_all`].
///
/// Nothing is sent when stdin is empty.
///
/// # Errors
/// Fails if stdin cannot be read or the binary message cannot be sent.
async fn run_stdin_pump(ws: &mut hpx::ws::WebSocket, cli: &Cli) -> Result<()> {
    let payload = {
        let mut bytes = Vec::new();
        io::stdin()
            .lock()
            .read_to_end(&mut bytes)
            .wrap_err("Failed to read stdin")?;
        bytes
    };

    if payload.is_empty() {
        return receive_all(ws, cli).await;
    }

    ws.send(Message::binary(payload))
        .await
        .wrap_err("Failed to send stdin data")?;
    receive_all(ws, cli).await
}
async fn drain_pending(ws: &mut hpx::ws::WebSocket, cli: &Cli) {
loop {
let msg = tokio::time::timeout(std::time::Duration::from_millis(10), ws.recv()).await;
match msg {
Ok(Some(Ok(message))) => print_message(message, cli),
Ok(Some(Err(e))) => {
eprintln!("Receive error: {e}");
break;
}
Ok(None) => {
eprintln!("Connection closed by server");
break;
}
Err(_) => break, }
}
}
/// Prints every incoming message until the stream ends or a receive error occurs.
///
/// A receive error is printed on stderr and ends the loop. It is not propagated, so this
/// currently always returns `Ok(())`.
async fn receive_all(ws: &mut hpx::ws::WebSocket, cli: &Cli) -> Result<()> {
    while let Some(incoming) = ws.recv().await {
        let message = match incoming {
            Ok(message) => message,
            Err(e) => {
                eprintln!("Receive error: {e}");
                break;
            }
        };
        print_message(message, cli);
    }
    Ok(())
}
/// Renders one received WebSocket message.
///
/// - Text is printed to stdout. When `--format` is `Auto` and the text starts like JSON
///   (see `is_json`), it is pretty-printed via `format_json_pretty`, falling back to the
///   raw text if that fails.
/// - Binary data is handled in one of three ways:
///   - with `--output`, it is written to that file (reporting the byte count unless silent);
///   - otherwise, when `is_terminal()` is true, a size summary is printed on stderr;
///   - otherwise, the raw bytes are written to stdout and write errors are ignored.
/// - Ping/pong frames are noted on stderr only when `--verbose` is set.
/// - Close frames are always noted on stderr, with code and reason when present.
fn print_message(message: Message, cli: &Cli) {
    match message {
        Message::Text(text) => {
            let raw = text.as_str();
            let pretty = if cli.format == crate::cli::OutputFormat::Auto && is_json(raw) {
                format_json_pretty(raw.as_bytes()).ok()
            } else {
                None
            };
            match pretty {
                Some(pretty) => println!("{pretty}"),
                None => println!("{raw}"),
            }
        }
        Message::Binary(data) => match &cli.output {
            Some(path) => match write_body(&data, Some(path)) {
                Err(e) => eprintln!("Failed to write binary to file: {e}"),
                Ok(_) if !cli.silent => eprintln!("Wrote {} bytes to {path}", data.len()),
                Ok(_) => {}
            },
            None if is_terminal() => eprintln!(
                "[binary message, {} bytes. Use --output to save to file.]",
                data.len()
            ),
            None => {
                // Best-effort raw passthrough for piped stdout.
                let _ = io::stdout().write_all(&data);
            }
        },
        Message::Ping(data) if cli.verbose > 0 => eprintln!("[ping, {} bytes]", data.len()),
        Message::Pong(data) if cli.verbose > 0 => eprintln!("[pong, {} bytes]", data.len()),
        Message::Ping(_) | Message::Pong(_) => {}
        Message::Close(Some(frame)) => eprintln!(
            "[close, code={:?}, reason={}]",
            frame.code,
            frame.reason.as_str()
        ),
        Message::Close(None) => eprintln!("[close]"),
    }
}
/// Cheap heuristic: does the text (after leading whitespace) start with `{` or `[`?
///
/// This only looks at the first non-whitespace character; it does not validate JSON.
fn is_json(s: &str) -> bool {
    matches!(s.trim_start().chars().next(), Some('{' | '['))
}
/// Builds the `--data` payload message.
///
/// `@path` reads the file at `path` and sends its bytes as a binary message. Any other
/// value is sent verbatim as a text message.
///
/// # Errors
/// Fails only when an `@path` file cannot be read.
fn load_data_payload(data: &str) -> Result<Message> {
    match data.strip_prefix('@') {
        Some(path) => std::fs::read(path)
            .map(Message::binary)
            .wrap_err_with(|| format!("Failed to read data file: {path}")),
        None => Ok(Message::text(data.to_string())),
    }
}
fn load_json_payload(json: &str) -> Result<Message> {
if let Some(path) = json.strip_prefix('@') {
let bytes =
std::fs::read(path).wrap_err_with(|| format!("Failed to read JSON file: {path}"))?;
let _value: serde_json::Value =
serde_json::from_slice(&bytes).wrap_err("Invalid JSON in file")?;
Ok(Message::binary(bytes))
} else {
let _value: serde_json::Value =
serde_json::from_str(json).wrap_err("Invalid JSON string")?;
Ok(Message::text(json.to_string()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies an ad-hoc error built from a plain message.
    fn reconnectable(message: &str) -> bool {
        is_reconnectable_error(&eyre::eyre!("{message}"))
    }

    #[test]
    fn backoff_first_attempt() {
        assert_eq!(backoff_duration(0), Duration::from_secs(1));
    }

    #[test]
    fn backoff_second_attempt() {
        assert_eq!(backoff_duration(1), Duration::from_secs(2));
    }

    #[test]
    fn backoff_third_attempt() {
        assert_eq!(backoff_duration(2), Duration::from_secs(4));
    }

    #[test]
    fn backoff_caps_at_max() {
        for attempt in [30, 100] {
            assert_eq!(backoff_duration(attempt), BACKOFF_MAX);
        }
    }

    #[test]
    fn backoff_grows_exponentially() {
        let delays: Vec<Duration> = (0..3).map(backoff_duration).collect();
        for pair in delays.windows(2) {
            assert!(pair[1] > pair[0]);
            assert_eq!(pair[1].as_secs(), pair[0].as_secs() * 2);
        }
    }

    #[test]
    fn reconnectable_error_network() {
        assert!(reconnectable("connection reset by peer"));
    }

    #[test]
    fn reconnectable_error_eof() {
        assert!(!reconnectable("unexpected EOF"));
    }

    #[test]
    fn reconnectable_error_clean_close() {
        assert!(!reconnectable("Connection closed"));
    }

    #[test]
    fn reconnectable_error_closed_by_peer() {
        assert!(!reconnectable("closed by peer"));
    }

    #[test]
    fn reconnectable_error_timeout() {
        assert!(reconnectable("timed out"));
    }
}