#![allow(clippy::large_futures)]
#![allow(clippy::struct_excessive_bools)]
use anyhow::{Context, Result};
use clap::Args;
use dialoguer::{Input, Select};
use mqtt5::client::auth_handlers::{JwtAuthHandler, ScramSha256AuthHandler};
use mqtt5::time::Duration;
#[cfg(feature = "codec")]
use mqtt5::{CodecRegistry, DeflateCodec, GzipCodec};
use mqtt5::{ConnectOptions, MqttClient, ProtocolVersion, QoS, WillMessage};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tokio::signal;
use tokio::sync::Notify;
use tracing::{debug, info};
use super::parsers::{duration_secs_to_u32, parse_duration_secs, parse_stream_strategy};
#[derive(Args)]
pub struct SubCommand {
#[arg(long, short, env = "MQTT5_TOPIC")]
pub topic: Option<String>,
#[arg(long, short = 'U', conflicts_with_all = &["host", "port"], env = "MQTT5_URL")]
pub url: Option<String>,
#[arg(long, short = 'H', default_value = "localhost", env = "MQTT5_HOST")]
pub host: String,
#[arg(long, short, default_value = "1883", env = "MQTT5_PORT")]
pub port: u16,
#[arg(long, short, value_parser = parse_qos, env = "MQTT5_QOS")]
pub qos: Option<QoS>,
#[arg(long, short, env = "MQTT5_USERNAME")]
pub username: Option<String>,
#[arg(long, short = 'P', env = "MQTT5_PASSWORD")]
pub password: Option<String>,
#[arg(long, value_parser = ["password", "scram", "jwt"], env = "MQTT5_AUTH_METHOD")]
pub auth_method: Option<String>,
#[arg(long, env = "MQTT5_JWT_TOKEN")]
pub jwt_token: Option<String>,
#[arg(long, short, env = "MQTT5_CLIENT_ID")]
pub client_id: Option<String>,
#[arg(long, short)]
pub verbose: bool,
#[arg(long, env = "MQTT5_NON_INTERACTIVE")]
pub non_interactive: bool,
#[arg(long, short = 'n', default_value = "0", env = "MQTT5_COUNT")]
pub count: u32,
#[arg(long = "no-clean-start", env = "MQTT5_NO_CLEAN_START")]
pub no_clean_start: bool,
#[arg(long, value_parser = parse_duration_secs, env = "MQTT5_SESSION_EXPIRY")]
pub session_expiry: Option<u64>,
#[arg(long, short = 'k', default_value = "60", value_parser = parse_duration_secs, env = "MQTT5_KEEP_ALIVE")]
pub keep_alive: u64,
#[arg(long, value_parser = parse_protocol_version, env = "MQTT5_PROTOCOL_VERSION")]
pub protocol_version: Option<ProtocolVersion>,
#[arg(long, env = "MQTT5_WILL_TOPIC")]
pub will_topic: Option<String>,
#[arg(long, env = "MQTT5_WILL_MESSAGE")]
pub will_message: Option<String>,
#[arg(long, value_parser = parse_qos, env = "MQTT5_WILL_QOS")]
pub will_qos: Option<QoS>,
#[arg(long, env = "MQTT5_WILL_RETAIN")]
pub will_retain: bool,
#[arg(long, value_parser = parse_duration_secs, env = "MQTT5_WILL_DELAY")]
pub will_delay: Option<u64>,
#[arg(long, env = "MQTT5_CERT")]
pub cert: Option<PathBuf>,
#[arg(long, env = "MQTT5_KEY")]
pub key: Option<PathBuf>,
#[arg(long, env = "MQTT5_CA_CERT")]
pub ca_cert: Option<PathBuf>,
#[arg(long, env = "MQTT5_INSECURE")]
pub insecure: bool,
#[arg(long, env = "MQTT5_AUTO_RECONNECT")]
pub auto_reconnect: bool,
#[arg(long, env = "MQTT5_NO_LOCAL")]
pub no_local: bool,
#[arg(long, env = "MQTT5_SUBSCRIPTION_IDENTIFIER")]
pub subscription_identifier: Option<u32>,
#[arg(long, value_parser = parse_retain_handling, env = "MQTT5_RETAIN_HANDLING")]
pub retain_handling: Option<mqtt5::RetainHandling>,
#[arg(long, env = "MQTT5_RETAIN_AS_PUBLISHED")]
pub retain_as_published: bool,
#[arg(long, value_parser = parse_stream_strategy, env = "MQTT5_QUIC_STREAM_STRATEGY")]
pub quic_stream_strategy: Option<mqtt5::transport::StreamStrategy>,
#[arg(long, env = "MQTT5_QUIC_FLOW_HEADERS")]
pub quic_flow_headers: bool,
#[arg(long, default_value = "300", value_parser = parse_duration_secs, env = "MQTT5_QUIC_FLOW_EXPIRE")]
pub quic_flow_expire: u64,
#[arg(long, env = "MQTT5_QUIC_MAX_STREAMS")]
pub quic_max_streams: Option<usize>,
#[arg(long, env = "MQTT5_QUIC_DATAGRAMS")]
pub quic_datagrams: bool,
#[arg(long, default_value = "30", value_parser = parse_duration_secs, env = "MQTT5_QUIC_CONNECT_TIMEOUT")]
pub quic_connect_timeout: u64,
#[arg(long, env = "MQTT5_QUIC_EARLY_DATA")]
pub quic_early_data: bool,
#[cfg(feature = "opentelemetry")]
#[arg(long, env = "MQTT5_OTEL_ENDPOINT")]
pub otel_endpoint: Option<String>,
#[cfg(feature = "opentelemetry")]
#[arg(long, default_value = "mqttv5-sub", env = "MQTT5_OTEL_SERVICE_NAME")]
pub otel_service_name: String,
#[cfg(feature = "opentelemetry")]
#[arg(long, default_value = "1.0", env = "MQTT5_OTEL_SAMPLING")]
pub otel_sampling: f64,
#[cfg(feature = "codec")]
#[arg(long, value_parser = ["gzip", "deflate", "all"], env = "MQTT5_CODEC")]
pub codec: Option<String>,
}
fn parse_qos(s: &str) -> Result<QoS, String> {
match s {
"0" => Ok(QoS::AtMostOnce),
"1" => Ok(QoS::AtLeastOnce),
"2" => Ok(QoS::ExactlyOnce),
_ => Err(format!("QoS must be 0, 1, or 2, got: {s}")),
}
}
fn parse_protocol_version(s: &str) -> Result<ProtocolVersion, String> {
match s {
"3.1.1" | "311" | "4" => Ok(ProtocolVersion::V311),
"5" | "5.0" => Ok(ProtocolVersion::V5),
_ => Err(format!("Invalid protocol version: {s}. Use '3.1.1' or '5'")),
}
}
fn parse_retain_handling(s: &str) -> Result<mqtt5::RetainHandling, String> {
match s {
"0" => Ok(mqtt5::RetainHandling::SendAtSubscribe),
"1" => Ok(mqtt5::RetainHandling::SendIfNew),
"2" => Ok(mqtt5::RetainHandling::DontSend),
_ => Err(format!("retain_handling must be 0, 1, or 2, got: {s}")),
}
}
fn prompt_topic_and_qos(cmd: &mut SubCommand) -> Result<(String, QoS)> {
if cmd.topic.is_none() && !cmd.non_interactive {
let topic = Input::<String>::new()
.with_prompt("MQTT topic to subscribe to (e.g., sensors/+, home/#)")
.interact()
.context("Failed to get topic input")?;
cmd.topic = Some(topic);
}
let topic = cmd.topic.take().ok_or_else(|| {
anyhow::anyhow!("Topic is required. Use --topic or run without --non-interactive")
})?;
validate_topic_filter(&topic)?;
let qos = if cmd.qos.is_none() && !cmd.non_interactive {
let qos_options = vec!["0 (At most once)", "1 (At least once)", "2 (Exactly once)"];
let selection = Select::new()
.with_prompt("Quality of Service level")
.items(&qos_options)
.default(0)
.interact()
.context("Failed to get QoS selection")?;
match selection {
1 => QoS::AtLeastOnce,
2 => QoS::ExactlyOnce,
_ => QoS::AtMostOnce,
}
} else {
cmd.qos.unwrap_or(QoS::AtMostOnce)
};
Ok((topic, qos))
}
fn build_connect_options(cmd: &SubCommand, client_id: &str) -> ConnectOptions {
let mut options = ConnectOptions::new(client_id.to_owned())
.with_clean_start(!cmd.no_clean_start)
.with_keep_alive(Duration::from_secs(cmd.keep_alive));
if cmd.auto_reconnect {
options = options.with_automatic_reconnect(true);
}
if let Some(version) = cmd.protocol_version {
options = options.with_protocol_version(version);
}
if let Some(expiry) = cmd.session_expiry {
options = options.with_session_expiry_interval(duration_secs_to_u32(expiry));
}
options
}
async fn configure_auth(
client: &MqttClient,
options: &mut ConnectOptions,
cmd: &SubCommand,
) -> Result<()> {
match cmd.auth_method.as_deref() {
Some("scram") => {
let username = cmd.username.clone().ok_or_else(|| {
anyhow::anyhow!("--username is required for SCRAM authentication")
})?;
let password = cmd.password.clone().ok_or_else(|| {
anyhow::anyhow!("--password is required for SCRAM authentication")
})?;
*options = std::mem::take(options)
.with_credentials(username.clone(), Vec::new())
.with_authentication_method("SCRAM-SHA-256");
let handler = ScramSha256AuthHandler::new(username, password);
client.set_auth_handler(handler).await;
debug!("SCRAM-SHA-256 authentication configured");
}
Some("jwt") => {
let token = cmd
.jwt_token
.clone()
.ok_or_else(|| anyhow::anyhow!("--jwt-token is required for JWT authentication"))?;
*options = std::mem::take(options).with_authentication_method("JWT");
let handler = JwtAuthHandler::new(token);
client.set_auth_handler(handler).await;
debug!("JWT authentication configured");
}
Some("password") | None => {
if let (Some(username), Some(password)) = (cmd.username.clone(), cmd.password.clone()) {
*options =
std::mem::take(options).with_credentials(username, password.into_bytes());
} else if let Some(username) = cmd.username.clone() {
*options = std::mem::take(options).with_credentials(username, Vec::new());
}
}
Some(other) => anyhow::bail!("Unknown auth method: {other}"),
}
Ok(())
}
fn configure_will(options: &mut ConnectOptions, cmd: &SubCommand) {
if let Some(topic) = cmd.will_topic.clone() {
let payload = cmd.will_message.clone().unwrap_or_default();
let mut will = WillMessage::new(topic, payload.into_bytes()).with_retain(cmd.will_retain);
if let Some(qos) = cmd.will_qos {
will = will.with_qos(qos);
}
if let Some(delay) = cmd.will_delay {
will.properties.will_delay_interval = Some(duration_secs_to_u32(delay));
}
*options = std::mem::take(options).with_will(will);
}
}
async fn configure_quic_transport(client: &MqttClient, cmd: &SubCommand) {
if let Some(strategy) = cmd.quic_stream_strategy {
client.set_quic_stream_strategy(strategy).await;
debug!("QUIC stream strategy: {:?}", strategy);
}
if cmd.quic_flow_headers {
client.set_quic_flow_headers(true).await;
debug!("QUIC flow headers enabled");
}
client
.set_quic_flow_expire(std::time::Duration::from_secs(cmd.quic_flow_expire))
.await;
if let Some(max) = cmd.quic_max_streams {
client.set_quic_max_streams(Some(max)).await;
debug!("QUIC max streams: {max}");
}
if cmd.quic_datagrams {
client.set_quic_datagrams(true).await;
debug!("QUIC datagrams enabled");
}
client
.set_quic_connect_timeout(Duration::from_secs(cmd.quic_connect_timeout))
.await;
if cmd.quic_early_data {
client.set_quic_early_data(true).await;
debug!("QUIC 0-RTT early data enabled");
}
}
async fn configure_tls_certs(
client: &MqttClient,
broker_url: &str,
cmd: &SubCommand,
) -> Result<()> {
let is_secure = broker_url.starts_with("ssl://")
|| broker_url.starts_with("mqtts://")
|| broker_url.starts_with("quics://");
if !is_secure || (cmd.cert.is_none() && cmd.key.is_none() && cmd.ca_cert.is_none()) {
return Ok(());
}
let cert_pem =
if let Some(cert_path) = &cmd.cert {
Some(std::fs::read(cert_path).with_context(|| {
format!("Failed to read certificate file: {}", cert_path.display())
})?)
} else {
None
};
let key_pem = if let Some(key_path) = &cmd.key {
Some(
std::fs::read(key_path)
.with_context(|| format!("Failed to read key file: {}", key_path.display()))?,
)
} else {
None
};
let ca_pem = if let Some(ca_path) = &cmd.ca_cert {
Some(std::fs::read(ca_path).with_context(|| {
format!("Failed to read CA certificate file: {}", ca_path.display())
})?)
} else {
None
};
client.set_tls_config(cert_pem, key_pem, ca_pem).await;
Ok(())
}
#[cfg(feature = "codec")]
fn configure_codec(options: &mut ConnectOptions, cmd: &SubCommand) -> Result<()> {
if let Some(ref codec_name) = cmd.codec {
let registry = CodecRegistry::new();
match codec_name.as_str() {
"gzip" => {
registry.register(GzipCodec::new());
debug!("Gzip codec decoding enabled");
}
"deflate" => {
registry.register(DeflateCodec::new());
debug!("Deflate codec decoding enabled");
}
"all" => {
registry.register(GzipCodec::new());
registry.register(DeflateCodec::new());
debug!("All codec decoding enabled (gzip, deflate)");
}
_ => anyhow::bail!("Unknown codec: {codec_name}"),
}
*options = std::mem::take(options).with_codec_registry(Arc::new(registry));
}
Ok(())
}
async fn subscribe_and_print(
client: &MqttClient,
topic: &str,
cmd: &SubCommand,
qos: QoS,
) -> Result<Arc<Notify>> {
let target_count = cmd.count;
let verbose = cmd.verbose;
info!("Subscribing to '{}' (QoS {})...", topic, qos as u8);
let message_count_clone = Arc::new(AtomicU32::new(0));
let done_notify = Arc::new(Notify::new());
let done_notify_clone = done_notify.clone();
if let Some(sub_id) = cmd.subscription_identifier {
if sub_id == 0 || sub_id > 268_435_455 {
anyhow::bail!("Subscription identifier must be between 1 and 268435455, got: {sub_id}");
}
}
let subscribe_options = mqtt5::SubscribeOptions {
qos,
no_local: cmd.no_local,
retain_as_published: cmd.retain_as_published,
retain_handling: cmd
.retain_handling
.unwrap_or(mqtt5::RetainHandling::SendAtSubscribe),
subscription_identifier: cmd.subscription_identifier,
};
let (packet_id, granted_qos) = client
.subscribe_with_options(topic, subscribe_options.clone(), move |message| {
let count = message_count_clone.fetch_add(1, Ordering::Relaxed) + 1;
if verbose {
println!(
"{}: {}",
message.topic,
String::from_utf8_lossy(&message.payload)
);
} else {
println!("{}", String::from_utf8_lossy(&message.payload));
}
if target_count > 0 && count >= target_count {
println!("✓ Received {target_count} messages, exiting");
done_notify_clone.notify_one();
}
})
.await?;
debug!(
"Subscription confirmed - packet_id: {}, granted_qos: {:?}",
packet_id, granted_qos
);
println!("✓ Subscribed to '{topic}' (granted QoS {granted_qos:?}) - waiting for messages (Ctrl+C to exit)");
Ok(done_notify)
}
async fn wait_for_completion(
client: &MqttClient,
done_notify: &Notify,
auto_reconnect: bool,
) -> Result<()> {
if auto_reconnect {
tokio::select! {
() = done_notify.notified() => {}
() = async { let _ = signal::ctrl_c().await; } => {
println!("\n✓ Received Ctrl+C, disconnecting...");
}
}
} else {
let mut check_interval = tokio::time::interval(Duration::from_millis(500));
loop {
tokio::select! {
() = done_notify.notified() => {
break;
}
() = async { let _ = signal::ctrl_c().await; } => {
println!("\n✓ Received Ctrl+C, disconnecting...");
break;
}
_ = check_interval.tick() => {
if !client.is_connected().await {
println!("\n✗ Disconnected from broker");
return Err(anyhow::anyhow!("Connection lost"));
}
}
}
}
}
client.disconnect().await?;
Ok(())
}
pub async fn execute(mut cmd: SubCommand, verbose: bool, debug: bool) -> Result<()> {
#[cfg(feature = "opentelemetry")]
let has_otel = cmd.otel_endpoint.is_some();
#[cfg(not(feature = "opentelemetry"))]
let has_otel = false;
if !has_otel {
crate::init_basic_tracing(verbose, debug);
}
let (topic, qos) = prompt_topic_and_qos(&mut cmd)?;
let broker_url = cmd
.url
.take()
.unwrap_or_else(|| format!("mqtt://{}:{}", cmd.host, cmd.port));
debug!("Connecting to broker: {}", broker_url);
let client_id = cmd
.client_id
.take()
.unwrap_or_else(|| format!("mqttv5-sub-{}", rand::rng().random::<u32>()));
let client = MqttClient::new(&client_id);
#[cfg(feature = "opentelemetry")]
if let Some(endpoint) = &cmd.otel_endpoint {
use mqtt5::telemetry::TelemetryConfig;
let telemetry_config = TelemetryConfig::new(&cmd.otel_service_name)
.with_endpoint(endpoint)
.with_sampling_ratio(cmd.otel_sampling);
mqtt5::telemetry::init_tracing_subscriber(&telemetry_config)?;
info!(
"OpenTelemetry enabled: endpoint={}, service={}, sampling={}",
endpoint, cmd.otel_service_name, cmd.otel_sampling
);
}
let mut options = build_connect_options(&cmd, &client_id);
configure_auth(&client, &mut options, &cmd).await?;
configure_will(&mut options, &cmd);
#[cfg(feature = "codec")]
configure_codec(&mut options, &cmd)?;
if cmd.insecure {
client.set_insecure_tls(true).await;
info!("Insecure TLS mode enabled (certificate verification disabled)");
}
configure_quic_transport(&client, &cmd).await;
configure_tls_certs(&client, &broker_url, &cmd).await?;
info!("Connecting to {}...", broker_url);
let result = client
.connect_with_options(&broker_url, options)
.await
.context("Failed to connect to MQTT broker")?;
if result.session_present {
info!("Resumed existing session");
}
let done_notify = subscribe_and_print(&client, &topic, &cmd, qos).await?;
wait_for_completion(&client, &done_notify, cmd.auto_reconnect).await
}
fn validate_topic_filter(topic: &str) -> Result<()> {
if topic.is_empty() {
anyhow::bail!("Topic cannot be empty");
}
if topic.contains("//") {
anyhow::bail!(
"Invalid topic filter '{}' - cannot have empty segments\nDid you mean '{}'?",
topic,
topic.replace("//", "/")
);
}
let segments: Vec<&str> = topic.split('/').collect();
for (i, segment) in segments.iter().enumerate() {
if segment == &"#" && i != segments.len() - 1 {
anyhow::bail!("Invalid topic filter '{topic}' - '#' wildcard must be the last segment");
} else if segment.contains('#') && segment != &"#" {
anyhow::bail!(
"Invalid topic filter '{topic}' - '#' wildcard must be alone in its segment"
);
} else if segment.contains('+') && segment != &"+" {
anyhow::bail!(
"Invalid topic filter '{topic}' - '+' wildcard must be alone in its segment"
);
}
}
Ok(())
}
use rand::Rng;