use crate::channel::mpsc;
use crate::cx::Cx;
use crate::io::{AsyncRead, AsyncWriteExt, ReadBuf};
use crate::net::TcpStream;
use crate::tracing_compat::warn;
use crate::types::Time;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::Poll;
use std::time::Duration;
/// Message carried by the `io::ErrorKind::TimedOut` error built in
/// `request_timeout_error`; `request` also passes it as the cleanup reason when
/// waiting for a reply fails.
const REQUEST_TIMEOUT_MESSAGE: &str = "request timeout";
/// Current time used for computing and checking request deadlines.
///
/// Uses the timer driver attached to `cx` when there is one, so deadlines follow
/// the runtime's clock; otherwise falls back to `crate::time::wall_now`.
fn timeout_now(cx: &Cx) -> Time {
    match cx.timer_driver() {
        Some(driver) => driver.now(),
        None => crate::time::wall_now(),
    }
}
fn request_timeout_error() -> NatsError {
NatsError::Io(io::Error::new(
io::ErrorKind::TimedOut,
REQUEST_TIMEOUT_MESSAGE,
))
}
/// Errors produced by the NATS client.
#[derive(Debug)]
pub enum NatsError {
/// Socket I/O failure. Request timeouts are also reported this way, with
/// `io::ErrorKind::TimedOut`.
Io(io::Error),
/// Malformed data from the server, or a local argument that violates the
/// protocol (e.g. an invalid subject or an oversized payload).
Protocol(String),
/// Text of a server `-ERR` message, with surrounding single quotes trimmed.
Server(String),
/// The URL given to `NatsConfig::from_url` could not be parsed.
InvalidUrl(String),
/// The operation observed cancellation through its `Cx`.
Cancelled,
/// The server closed the connection (a read returned 0 bytes).
Closed,
/// Carries the subscription ID that could not be found.
SubscriptionNotFound(u64),
/// The client is not connected, or an earlier write left it unusable.
NotConnected,
}
impl fmt::Display for NatsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Io(e) => write!(f, "NATS I/O error: {e}"),
Self::Protocol(msg) => write!(f, "NATS protocol error: {msg}"),
Self::Server(msg) => write!(f, "NATS server error: {msg}"),
Self::InvalidUrl(url) => write!(f, "Invalid NATS URL: {url}"),
Self::Cancelled => write!(f, "NATS operation cancelled"),
Self::Closed => write!(f, "NATS connection closed"),
Self::SubscriptionNotFound(sid) => write!(f, "NATS subscription not found: {sid}"),
Self::NotConnected => write!(f, "NATS not connected"),
}
}
}
impl std::error::Error for NatsError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
_ => None,
}
}
}
impl From<io::Error> for NatsError {
fn from(err: io::Error) -> Self {
Self::Io(err)
}
}
impl NatsError {
    /// `true` for errors that may clear up on retry or reconnect.
    ///
    /// This is currently the same set as [`NatsError::is_connection_error`].
    #[must_use]
    pub fn is_transient(&self) -> bool {
        self.is_connection_error()
    }

    /// `true` for I/O failures, a closed connection, or a client that is not connected.
    #[must_use]
    pub fn is_connection_error(&self) -> bool {
        matches!(self, Self::NotConnected | Self::Closed | Self::Io(_))
    }

    /// Always `false`: this client has no capacity-related error variants.
    #[must_use]
    pub fn is_capacity_error(&self) -> bool {
        false
    }

    /// `true` when the error is an I/O error of kind `TimedOut`, as produced for
    /// request timeouts.
    #[must_use]
    pub fn is_timeout(&self) -> bool {
        if let Self::Io(err) = self {
            err.kind() == io::ErrorKind::TimedOut
        } else {
            false
        }
    }

    /// Whether retrying the operation is reasonable; same as [`NatsError::is_transient`].
    #[must_use]
    pub fn is_retryable(&self) -> bool {
        self.is_transient()
    }
}
/// Connection settings for [`NatsClient`].
#[derive(Debug, Clone)]
pub struct NatsConfig {
/// Server host. `from_url` keeps bracketed IPv6 literals as-is, e.g. `[::1]`.
pub host: String,
/// Server TCP port.
pub port: u16,
/// User name, sent as `user` in `CONNECT`.
pub user: Option<String>,
/// Password, sent as `pass` in `CONNECT`.
pub password: Option<String>,
/// Auth token, sent as `auth_token` in `CONNECT`.
pub token: Option<String>,
/// Client name, sent as `name` in `CONNECT`.
pub name: Option<String>,
/// Sent as `verbose` in `CONNECT`; when set, connecting waits for the server's `+OK`.
pub verbose: bool,
/// Sent as `pedantic` in `CONNECT`.
pub pedantic: bool,
/// How long `NatsClient::request` waits for a reply.
pub request_timeout: Duration,
/// Largest payload `publish` accepts. Lowered on connect if the server
/// advertises a smaller non-zero `max_payload`.
pub max_payload: usize,
/// Cap on unconsumed bytes held in the read buffer. Also the largest `MSG`
/// payload length accepted from the server.
pub max_read_buffer: usize,
}
impl Default for NatsConfig {
fn default() -> Self {
Self {
host: "127.0.0.1".to_string(),
port: 4222,
user: None,
password: None,
token: None,
name: None,
verbose: false,
pedantic: false,
request_timeout: Duration::from_secs(10),
max_payload: 1_048_576, max_read_buffer: DEFAULT_MAX_READ_BUFFER,
}
}
}
impl NatsConfig {
    /// Parses a `nats://` URL into a configuration.
    ///
    /// Accepted shape: `nats://[user:pass@ | token@]host[:port]`.
    ///
    /// - `host` may be a bracketed IPv6 literal such as `[::1]`.
    /// - Credentials end at the *last* `@`, so a password may contain `@`.
    /// - An unbracketed host containing more than one `:` is kept whole as the
    ///   host, with the default port.
    /// - Anything the URL does not specify keeps its [`Default`] value.
    ///
    /// # Errors
    ///
    /// Returns [`NatsError::InvalidUrl`] when:
    ///
    /// - the scheme is not `nats://`;
    /// - the port is not a valid `u16`;
    /// - a bracketed host is unterminated or followed by anything but `:port`;
    /// - the resulting host is empty.
    pub fn from_url(url: &str) -> Result<Self, NatsError> {
        let Some(without_scheme) = url.strip_prefix("nats://") else {
            return Err(NatsError::InvalidUrl(url.to_string()));
        };
        let parse_port = |raw: &str| -> Result<u16, NatsError> {
            raw.parse::<u16>()
                .map_err(|_| NatsError::InvalidUrl(format!("invalid port: {raw}")))
        };
        let mut config = Self::default();

        // Credentials end at the last '@' so a password may itself contain '@'.
        let host_port = match without_scheme.rsplit_once('@') {
            None => without_scheme,
            Some((credentials, remainder)) => {
                match credentials.split_once(':') {
                    Some((user, pass)) => {
                        config.user = Some(user.to_string());
                        config.password = Some(pass.to_string());
                    }
                    None => config.token = Some(credentials.to_string()),
                }
                remainder
            }
        };

        if let Some(bracketed) = host_port.strip_prefix('[') {
            let Some((host_body, after_host)) = bracketed.split_once(']') else {
                return Err(NatsError::InvalidUrl("invalid IPv6 host".to_string()));
            };
            config.host = format!("[{host_body}]");
            match after_host.strip_prefix(':') {
                Some(port) => config.port = parse_port(port)?,
                None if after_host.is_empty() => {}
                None => {
                    return Err(NatsError::InvalidUrl(format!(
                        "invalid host/port: {host_port}"
                    )));
                }
            }
        } else if host_port.matches(':').count() > 1 {
            // Unbracketed IPv6 literal: keep it whole as the host, default port.
            config.host = host_port.to_string();
        } else if let Some((host, port)) = host_port.rsplit_once(':') {
            config.host = host.to_string();
            config.port = parse_port(port)?;
        } else if !host_port.is_empty() {
            config.host = host_port.to_string();
        }

        if config.host.is_empty() {
            return Err(NatsError::InvalidUrl("host must not be empty".to_string()));
        }
        Ok(config)
    }
}
/// A message delivered by the server in a `MSG` frame.
#[derive(Debug, Clone)]
pub struct Message {
/// Subject the message was published to.
pub subject: String,
/// ID of the subscription the server delivered it for.
pub sid: u64,
/// Reply subject, if the publisher supplied one.
pub reply_to: Option<String>,
/// Raw payload bytes.
pub payload: Vec<u8>,
}
/// Fields extracted from the server's `INFO` frame.
#[derive(Debug, Clone, Default)]
pub struct ServerInfo {
/// Server identifier.
pub server_id: String,
/// Server name.
pub server_name: String,
/// Server version string.
pub version: String,
/// Protocol version advertised by the server.
pub proto: i32,
/// Maximum payload size advertised by the server; `0` when absent.
pub max_payload: usize,
/// Whether the server requires TLS.
pub tls_required: bool,
/// Whether the server offers TLS.
pub tls_available: bool,
/// Cluster URLs. Note: `ServerInfo::parse` does not fill this in.
pub connect_urls: Vec<String>,
}
impl ServerInfo {
    /// Builds a `ServerInfo` from the JSON body of a server `INFO` line.
    ///
    /// Parsing is lenient: a missing or malformed field keeps its `Default`
    /// value. Numeric fields whose value does not fit the Rust field type fall
    /// back to `0` rather than being silently truncated. `connect_urls` is not
    /// parsed.
    fn parse(json: &str) -> Self {
        let mut info = Self::default();
        if let Some(v) = extract_json_string(json, "server_id") {
            info.server_id = v;
        }
        if let Some(v) = extract_json_string(json, "server_name") {
            info.server_name = v;
        }
        if let Some(v) = extract_json_string(json, "version") {
            info.version = v;
        }
        if let Some(v) = extract_json_i64(json, "proto") {
            // `v as i32` would wrap out-of-range values into a bogus protocol
            // version; treat them like `max_payload` does and fall back to 0.
            info.proto = i32::try_from(v).unwrap_or(0);
        }
        if let Some(v) = extract_json_i64(json, "max_payload") {
            info.max_payload = usize::try_from(v).unwrap_or(0);
        }
        if let Some(v) = extract_json_bool(json, "tls_required") {
            info.tls_required = v;
        }
        if let Some(v) = extract_json_bool(json, "tls_available") {
            info.tls_available = v;
        }
        info
    }
}
/// Extracts and unescapes the string value stored under `key` in flat JSON.
///
/// This is a minimal scanner for server `INFO` payloads, not a general JSON
/// parser. It finds the first `"key":`, skips any whitespace, and decodes the
/// string literal that follows.
///
/// Standard escapes are supported, including `\uXXXX` and UTF-16 surrogate
/// pairs for characters outside the Basic Multilingual Plane. A `\u` escape
/// that cannot be decoded (invalid hex, unpaired surrogate) is dropped.
///
/// Returns `None` if the key is missing, its value is not a string, or the
/// string is unterminated.
fn extract_json_string(json: &str, key: &str) -> Option<String> {
    // Reads the four hex digits of a `\uXXXX` escape. The outer `None` means
    // the input ended early; the inner `None` means the digits were not hex.
    fn read_hex4(chars: &mut std::str::Chars<'_>) -> Option<Option<u32>> {
        let mut hex = String::with_capacity(4);
        for _ in 0..4 {
            hex.push(chars.next()?);
        }
        Some(u32::from_str_radix(&hex, 16).ok())
    }

    let pattern = format!("\"{key}\":");
    let start = json.find(&pattern)? + pattern.len();
    // Tolerate whitespace after the colon, as the integer/bool extractors do.
    let slice = json[start..].trim_start().strip_prefix('"')?;
    let mut out = String::with_capacity(slice.len());
    let mut chars = slice.chars();
    loop {
        match chars.next()? {
            '"' => return Some(out),
            '\\' => match chars.next()? {
                'b' => out.push('\x08'),
                'f' => out.push('\x0C'),
                'n' => out.push('\n'),
                'r' => out.push('\r'),
                't' => out.push('\t'),
                'u' => {
                    let Some(unit) = read_hex4(&mut chars)? else {
                        continue;
                    };
                    if (0xD800..0xDC00).contains(&unit) {
                        // High surrogate: JSON writes non-BMP characters as a
                        // `\uD8xx\uDCxx` pair. Only consume the second escape
                        // when it really is a low surrogate.
                        let mut lookahead = chars.clone();
                        if lookahead.next() == Some('\\') && lookahead.next() == Some('u') {
                            if let Some(Some(low)) = read_hex4(&mut lookahead) {
                                if (0xDC00..0xE000).contains(&low) {
                                    let scalar =
                                        0x1_0000 + ((unit - 0xD800) << 10) + (low - 0xDC00);
                                    if let Some(c) = char::from_u32(scalar) {
                                        out.push(c);
                                    }
                                    chars = lookahead;
                                }
                            }
                        }
                        // An unpaired high surrogate is not a valid char and is dropped.
                    } else if let Some(c) = char::from_u32(unit) {
                        out.push(c);
                    }
                }
                other => out.push(other),
            },
            c => out.push(c),
        }
    }
}
/// Escapes `s` so it can be embedded inside a JSON string literal.
///
/// Quotes, backslashes, `\n`, `\r` and `\t` use their short escapes. Every
/// other control character becomes a `\uXXXX` escape with lowercase hex digits.
fn nats_json_escape(s: &str) -> String {
    use std::fmt::Write;
    s.chars()
        .fold(String::with_capacity(s.len()), |mut escaped, ch| {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                c if c.is_control() => {
                    write!(escaped, "\\u{:04x}", u32::from(c)).expect("write to String");
                }
                c => escaped.push(c),
            }
            escaped
        })
}
/// Reads the integer stored under `"key":` in flat JSON.
///
/// Whitespace after the colon is skipped. The longest following run of ASCII
/// digits and `-` signs is then parsed as an `i64`.
///
/// Returns `None` when the key is missing or that run does not parse.
fn extract_json_i64(json: &str, key: &str) -> Option<i64> {
    let needle = format!("\"{key}\":");
    let value_at = json.find(needle.as_str())? + needle.len();
    let value = json[value_at..].trim_start();
    let run_len = value
        .char_indices()
        .find(|&(_, c)| !(c.is_ascii_digit() || c == '-'))
        .map_or(value.len(), |(idx, _)| idx);
    value[..run_len].parse::<i64>().ok()
}
/// Reads the boolean stored under `"key":` in flat JSON.
///
/// Whitespace after the colon is skipped. The value matches when it starts
/// with `true` or `false`; anything else, or a missing key, yields `None`.
fn extract_json_bool(json: &str, key: &str) -> Option<bool> {
    let needle = format!("\"{key}\":");
    let value_at = json.find(needle.as_str())? + needle.len();
    let value = json[value_at..].trim_start();
    [("true", true), ("false", false)]
        .iter()
        .find_map(|&(literal, flag)| value.starts_with(literal).then_some(flag))
}
/// Checks that a protocol token (subject, reply-to, queue group) can be sent as-is.
///
/// `field` names the token in the error message.
///
/// # Errors
///
/// Returns [`NatsError::Protocol`] if `value` is empty or contains any
/// whitespace or ASCII control character.
fn validate_nats_token(value: &str, field: &str) -> Result<(), NatsError> {
    let has_illegal_char = value
        .chars()
        .any(|ch| ch.is_whitespace() || ch.is_ascii_control());
    match (value.is_empty(), has_illegal_char) {
        (true, _) => Err(NatsError::Protocol(format!("{field} must not be empty"))),
        (false, true) => Err(NatsError::Protocol(format!(
            "{field} contains illegal whitespace/control characters"
        ))),
        (false, false) => Ok(()),
    }
}
/// One `.`-separated token of a parsed subscription pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SubscriptionPatternToken<'a> {
/// Must equal the subject token exactly.
Literal(&'a str),
/// `*`: matches exactly one subject token.
SingleWildcard,
/// `>`: matches one or more remaining tokens; only valid as the last token.
TailWildcard,
}
/// `true` if `token` is a non-empty subject token free of whitespace and ASCII
/// control characters. Wildcards are not checked here.
fn is_valid_nats_segment(token: &str) -> bool {
    if token.is_empty() {
        return false;
    }
    token
        .chars()
        .all(|ch| !ch.is_ascii_control() && !ch.is_whitespace())
}
/// Splits a subscription pattern into tokens, validating wildcard placement.
///
/// Returns `None` in any of these cases:
///
/// - the pattern is empty;
/// - any token is empty or contains whitespace/control characters;
/// - `>` appears anywhere but last;
/// - `*` or `>` is mixed with other characters inside a token.
fn parse_subscription_pattern(pattern: &str) -> Option<Vec<SubscriptionPatternToken<'_>>> {
    if pattern.is_empty() {
        return None;
    }
    let segments: Vec<&str> = pattern.split('.').collect();
    if !segments.iter().all(|segment| is_valid_nats_segment(segment)) {
        return None;
    }
    let last = segments.len() - 1;
    segments
        .iter()
        .enumerate()
        .map(|(position, &segment)| match segment {
            "*" => Some(SubscriptionPatternToken::SingleWildcard),
            ">" if position == last => Some(SubscriptionPatternToken::TailWildcard),
            ">" => None,
            literal if literal.contains('*') || literal.contains('>') => None,
            literal => Some(SubscriptionPatternToken::Literal(literal)),
        })
        .collect()
}
/// Splits a concrete (publishable) subject into its tokens.
///
/// Returns `None` in any of these cases:
///
/// - the subject is empty;
/// - any token is invalid;
/// - any token contains a wildcard character.
fn parse_publish_subject(subject: &str) -> Option<Vec<&str>> {
    if subject.is_empty() {
        return None;
    }
    let tokens: Vec<&str> = subject.split('.').collect();
    let all_concrete = tokens.iter().all(|token| {
        is_valid_nats_segment(token) && !token.contains('*') && !token.contains('>')
    });
    all_concrete.then_some(tokens)
}
/// `true` when the concrete `subject` is matched by subscription `pattern`.
///
/// In the pattern, `*` consumes exactly one subject token and a trailing `>`
/// requires at least one more token. Invalid patterns or subjects never match.
fn subscription_matches_subject_impl(pattern: &str, subject: &str) -> bool {
    let (Some(pattern_tokens), Some(subject_tokens)) =
        (parse_subscription_pattern(pattern), parse_publish_subject(subject))
    else {
        return false;
    };
    let mut remaining = subject_tokens.iter();
    for token in pattern_tokens {
        match token {
            SubscriptionPatternToken::TailWildcard => return remaining.next().is_some(),
            SubscriptionPatternToken::SingleWildcard => {
                if remaining.next().is_none() {
                    return false;
                }
            }
            SubscriptionPatternToken::Literal(expected) => {
                if remaining.next() != Some(&expected) {
                    return false;
                }
            }
        }
    }
    // Every subject token must have been consumed by the pattern.
    remaining.next().is_none()
}
/// Test-only entry point to the subject-matching logic.
///
/// Returns `true` when the concrete `subject` matches `pattern`, where:
///
/// - `*` matches exactly one token;
/// - a trailing `>` matches one or more remaining tokens.
///
/// Invalid patterns or subjects never match.
#[cfg(any(test, feature = "test-internals"))]
#[doc(hidden)]
pub fn subscription_matches_subject(pattern: &str, subject: &str) -> bool {
subscription_matches_subject_impl(pattern, subject)
}
/// Random suffix used to make request inbox subjects unique.
///
/// Produces 32 lowercase hex digits from two independent draws. The previous
/// implementation XOR-ed the draws together, which collapsed them back to
/// 64 bits of entropy and made the second draw pointless.
fn random_suffix(cx: &Cx) -> String {
    let hi = cx.random_u64();
    let lo = cx.random_u64();
    format!("{hi:016x}{lo:016x}")
}
/// Default cap (8 MiB) on unconsumed bytes in the read buffer; also the
/// default value of `NatsConfig::max_read_buffer`.
const DEFAULT_MAX_READ_BUFFER: usize = 8 * 1024 * 1024;
/// Bytes read from the socket, waiting to be parsed into protocol frames.
#[derive(Debug)]
struct NatsReadBuffer {
/// Raw bytes; `buf[..pos]` has already been consumed.
buf: Vec<u8>,
/// Offset of the first unconsumed byte.
pos: usize,
/// Maximum number of unconsumed bytes `extend` accepts.
max_size: usize,
}
impl NatsReadBuffer {
#[cfg(test)]
fn new() -> Self {
Self::with_limit(DEFAULT_MAX_READ_BUFFER)
}
fn with_limit(max_size: usize) -> Self {
Self {
buf: Vec::new(),
pos: 0,
max_size,
}
}
fn available(&self) -> &[u8] {
&self.buf[self.pos..]
}
fn extend(&mut self, bytes: &[u8]) -> Result<(), NatsError> {
if self.buf.len() + bytes.len() - self.pos > self.max_size {
return Err(NatsError::Protocol(format!(
"read buffer exceeds maximum size ({} bytes)",
self.max_size
)));
}
self.buf.extend_from_slice(bytes);
Ok(())
}
fn consume(&mut self, n: usize) {
self.pos = self.pos.saturating_add(n).min(self.buf.len());
if self.pos > 0 && (self.pos > 4096 && self.pos > (self.buf.len() / 2)) {
self.buf.drain(..self.pos);
self.pos = 0;
}
}
fn find_crlf(&self) -> Option<usize> {
let buf = self.available();
(0..buf.len().saturating_sub(1)).find(|&i| buf[i] == b'\r' && buf[i + 1] == b'\n')
}
}
/// A frame parsed from the server stream.
#[derive(Debug)]
enum NatsMessage {
/// `INFO {json}`.
Info(ServerInfo),
/// `MSG subject sid [reply-to] len` plus its payload.
Msg(Message),
/// `+OK`.
Ok,
/// `-ERR ...`, holding the text with surrounding single quotes trimmed.
Err(String),
/// `PING` from the server; answered with `PONG`.
Ping,
/// `PONG` reply to a client `PING`.
Pong,
}
/// Routing entry for one subscription, keyed by SID in `SharedState::subscriptions`.
struct SubscriptionState {
// NOTE(review): `dead_code` is presumably allowed because the only reader is the
// `warn!` in `dispatch_message`, which may compile to a no-op; confirm.
#[allow(dead_code)] subject: String,
/// Sending half of the channel drained by the matching `Subscription`.
sender: mpsc::Sender<Message>,
}
/// State shared (via `Arc`) between a `NatsClient` and its `Subscription`s.
struct SharedState {
/// Routing table from SID to the channel feeding that subscription.
subscriptions: Mutex<HashMap<u64, SubscriptionState>>,
/// `INFO` received during connect; `None` until then.
server_info: Mutex<Option<ServerInfo>>,
/// Set by `NatsClient::close`. `Subscription::next` returns `None` once it
/// is set and nothing is queued.
closed: std::sync::atomic::AtomicBool,
}
impl SharedState {
fn new() -> Self {
Self {
subscriptions: Mutex::new(HashMap::new()),
server_info: Mutex::new(None),
closed: std::sync::atomic::AtomicBool::new(false),
}
}
}
/// Rolls back a routing entry inserted before its `SUB` frame was written.
///
/// On drop, the entry for `sid` is removed unless `defused` was set. This
/// covers both write errors and the future being dropped mid-write.
struct SubscribeGuard<'a> {
/// The routing table the entry was inserted into.
subs: &'a Mutex<HashMap<u64, SubscriptionState>>,
/// SID of the entry to remove.
sid: u64,
/// Set once the `SUB` write succeeded; disables the rollback.
defused: bool,
}
impl Drop for SubscribeGuard<'_> {
    /// Removes the pending routing entry unless the guard was defused.
    fn drop(&mut self) {
        if self.defused {
            return;
        }
        let mut subs = self.subs.lock();
        subs.remove(&self.sid);
    }
}
/// A single-connection NATS client over TCP.
pub struct NatsClient {
/// Connection settings; `max_payload` may be lowered from the server's `INFO`.
config: NatsConfig,
/// Socket to the server.
stream: TcpStream,
/// Bytes read from `stream` that are awaiting parsing.
read_buf: NatsReadBuffer,
/// State shared with every `Subscription` this client creates.
state: Arc<SharedState>,
/// Next subscription ID to hand out (starts at 1).
next_sid: AtomicU64,
/// `false` before the handshake completes, while a frame is being written, and
/// after a write failed. Public operations refuse to run while it is `false`.
connected: bool,
}
impl fmt::Debug for NatsClient {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("NatsClient")
.field("host", &self.config.host)
.field("port", &self.config.port)
.field("connected", &self.connected)
.finish_non_exhaustive()
}
}
impl NatsClient {
    /// Connects to the server described by a `nats://` URL.
    ///
    /// The URL is parsed with [`NatsConfig::from_url`]; the handshake itself is
    /// done by [`NatsClient::connect_with_config`].
    ///
    /// # Errors
    ///
    /// Returns:
    ///
    /// - [`NatsError::Cancelled`] if `cx` is cancelled;
    /// - [`NatsError::InvalidUrl`] for an unparsable URL;
    /// - any error from `connect_with_config`.
    pub async fn connect(cx: &Cx, url: &str) -> Result<Self, NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        let config = NatsConfig::from_url(url)?;
        Self::connect_with_config(cx, config).await
    }

    /// Opens a TCP connection, reads the server `INFO`, and sends `CONNECT`.
    ///
    /// If the server advertises a non-zero `max_payload` smaller than the
    /// configured one, the configured limit is lowered to match. With
    /// `config.verbose` set, the server's `+OK` to `CONNECT` is awaited.
    ///
    /// # Errors
    ///
    /// Returns:
    ///
    /// - I/O errors from connecting, reading or writing;
    /// - [`NatsError::Server`] if the server answers with `-ERR`;
    /// - [`NatsError::Protocol`] if the first frame is not `INFO`;
    /// - [`NatsError::Cancelled`] on cancellation.
    pub async fn connect_with_config(cx: &Cx, config: NatsConfig) -> Result<Self, NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        cx.trace(&format!(
            "nats: connecting to {}:{}",
            config.host, config.port
        ));
        let addr = format!("{}:{}", config.host, config.port);
        let stream = TcpStream::connect(addr).await?;
        let read_buf = NatsReadBuffer::with_limit(config.max_read_buffer);
        let mut client = Self {
            config,
            stream,
            read_buf,
            state: Arc::new(SharedState::new()),
            next_sid: AtomicU64::new(1),
            connected: false,
        };
        let info = client.read_info(cx).await?;
        if info.max_payload > 0 && info.max_payload < client.config.max_payload {
            client.config.max_payload = info.max_payload;
        }
        // `info` is not used again, so move it instead of cloning it.
        *client.state.server_info.lock() = Some(info);
        client.send_connect(cx).await?;
        client.connected = true;
        cx.trace("nats: connection established");
        Ok(client)
    }

    /// Waits for the server's initial `INFO` frame.
    ///
    /// # Errors
    ///
    /// Returns:
    ///
    /// - [`NatsError::Server`] on `-ERR`;
    /// - [`NatsError::Protocol`] if any other frame arrives first;
    /// - read errors and cancellation.
    async fn read_info(&mut self, cx: &Cx) -> Result<ServerInfo, NatsError> {
        loop {
            cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
            if let Some(msg) = self.try_parse_message()? {
                return match msg {
                    NatsMessage::Info(info) => Ok(info),
                    NatsMessage::Err(e) => Err(NatsError::Server(e)),
                    _ => Err(NatsError::Protocol(
                        "expected INFO message from server".to_string(),
                    )),
                };
            }
            self.read_more().await?;
        }
    }

    /// Sends the `CONNECT` frame built from `self.config`.
    ///
    /// When `verbose` is set, it then waits for the server's `+OK`.
    async fn send_connect(&mut self, cx: &Cx) -> Result<(), NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        let mut connect = String::from("{");
        connect.push_str("\"verbose\":");
        connect.push_str(if self.config.verbose { "true" } else { "false" });
        connect.push_str(",\"pedantic\":");
        connect.push_str(if self.config.pedantic { "true" } else { "false" });
        connect.push_str(",\"lang\":\"rust\"");
        connect.push_str(",\"version\":\"0.1.0\"");
        connect.push_str(",\"protocol\":1");
        // Optional string fields. Their values come from user config or the URL,
        // so each one is JSON-escaped before being embedded.
        let optional_fields = [
            ("name", self.config.name.as_deref()),
            ("user", self.config.user.as_deref()),
            ("pass", self.config.password.as_deref()),
            ("auth_token", self.config.token.as_deref()),
        ];
        for (key, value) in optional_fields {
            if let Some(value) = value {
                connect.push_str(",\"");
                connect.push_str(key);
                connect.push_str("\":\"");
                connect.push_str(&nats_json_escape(value));
                connect.push('"');
            }
        }
        connect.push('}');
        let cmd = format!("CONNECT {connect}\r\n");
        self.write_frame(&[cmd.as_bytes()]).await?;
        if self.config.verbose {
            self.expect_ok(cx).await?;
        }
        Ok(())
    }

    /// Waits for `+OK`, answering server `PING`s in the meantime.
    ///
    /// Only used during the handshake (from `send_connect`), before any
    /// subscription exists, so every other frame is ignored.
    async fn expect_ok(&mut self, cx: &Cx) -> Result<(), NatsError> {
        loop {
            cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
            match self.try_parse_message()? {
                Some(NatsMessage::Ok) => return Ok(()),
                Some(NatsMessage::Err(e)) => return Err(NatsError::Server(e)),
                Some(NatsMessage::Ping) => self.send_server_pong().await?,
                Some(_) => {}
                None => self.read_more().await?,
            }
        }
    }

    /// Reads up to 4 KiB from the socket into the read buffer.
    ///
    /// Cancellation of the ambient `Cx::current()` is checked on every poll, so
    /// a blocked read can be interrupted.
    ///
    /// # Errors
    ///
    /// Returns:
    ///
    /// - [`NatsError::Closed`] on EOF;
    /// - [`NatsError::Io`] on socket errors (`Interrupted` when cancelled);
    /// - [`NatsError::Protocol`] if the read buffer would exceed its limit.
    async fn read_more(&mut self) -> Result<(), NatsError> {
        let mut tmp = [0u8; 4096];
        let n = std::future::poll_fn(|task_cx| {
            if crate::cx::Cx::current().is_some_and(|c| c.checkpoint().is_err()) {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::Interrupted,
                    "cancelled",
                )));
            }
            let mut read_buf = ReadBuf::new(&mut tmp);
            match Pin::new(&mut self.stream).poll_read(task_cx, &mut read_buf) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            }
        })
        .await?;
        if n == 0 {
            return Err(NatsError::Closed);
        }
        self.read_buf.extend(&tmp[..n])?;
        Ok(())
    }

    /// Like [`NatsClient::read_more`], but fails with the request-timeout error
    /// if no data arrives before `deadline`.
    async fn read_more_until(&mut self, cx: &Cx, deadline: Time) -> Result<(), NatsError> {
        let now = timeout_now(cx);
        let remaining = Duration::from_nanos(deadline.duration_since(now));
        crate::time::timeout(now, remaining, self.read_more())
            .await
            .unwrap_or_else(|_| Err(request_timeout_error()))
    }

    /// Best-effort teardown of a `request()` inbox subscription.
    ///
    /// The local routing entry is always removed. The server-side `UNSUB` is
    /// attempted, but its failure is ignored on purpose: the caller is already
    /// returning a more relevant result (the reply, or the original error).
    /// `_reason` is a diagnostic label describing why cleanup happened.
    async fn cleanup_request_subscription(&mut self, cx: &Cx, sid: u64, _reason: &str) {
        // `unsubscribe` returns early on cancellation, before touching local
        // state, so drop the routing entry here unconditionally.
        self.remove_local_subscription(sid);
        let _ = self.unsubscribe(cx, sid).await;
    }

    /// Answers a server `PING` with `PONG`.
    async fn send_server_pong(&mut self) -> Result<(), NatsError> {
        self.write_frame(&[b"PONG\r\n".as_slice()]).await
    }

    /// Writes `chunks` back-to-back and flushes them as one protocol frame.
    ///
    /// `connected` is cleared for the duration of the write and restored to its
    /// previous value only after the flush succeeds. If a write fails, or the
    /// future is dropped mid-frame, the client stays marked as not connected:
    /// a partial frame may have reached the server.
    async fn write_frame(&mut self, chunks: &[&[u8]]) -> Result<(), NatsError> {
        let was_connected = self.connected;
        self.connected = false;
        for &chunk in chunks {
            self.stream.write_all(chunk).await?;
        }
        self.stream.flush().await?;
        self.connected = was_connected;
        Ok(())
    }

    /// Drops the local routing entry for `sid` (no network traffic).
    fn remove_local_subscription(&self, sid: u64) {
        self.state.subscriptions.lock().remove(&sid);
    }

    /// Parses one complete frame from the read buffer, if one is available.
    ///
    /// Returns `Ok(None)` when more bytes are needed.
    ///
    /// # Errors
    ///
    /// [`NatsError::Protocol`] for malformed or unknown frames.
    fn try_parse_message(&mut self) -> Result<Option<NatsMessage>, NatsError> {
        let buf = self.read_buf.available();
        if buf.is_empty() {
            return Ok(None);
        }
        if buf.starts_with(b"INFO ") {
            return self.parse_info();
        } else if buf.starts_with(b"MSG ") {
            return self.parse_msg();
        } else if buf.starts_with(b"+OK") {
            if buf.len() >= 5 && buf[3] == b'\r' && buf[4] == b'\n' {
                self.read_buf.consume(5);
                return Ok(Some(NatsMessage::Ok));
            } else if buf.len() < 5 {
                // Terminator not received yet.
                return Ok(None);
            }
            return Err(NatsError::Protocol("malformed +OK frame".to_string()));
        } else if buf.starts_with(b"-ERR ") {
            return self.parse_err();
        } else if buf.starts_with(b"PING") {
            if buf.len() >= 6 && buf[4] == b'\r' && buf[5] == b'\n' {
                self.read_buf.consume(6);
                return Ok(Some(NatsMessage::Ping));
            } else if buf.len() < 6 {
                return Ok(None);
            }
            return Err(NatsError::Protocol("malformed PING frame".to_string()));
        } else if buf.starts_with(b"PONG") {
            if buf.len() >= 6 && buf[4] == b'\r' && buf[5] == b'\n' {
                self.read_buf.consume(6);
                return Ok(Some(NatsMessage::Pong));
            } else if buf.len() < 6 {
                return Ok(None);
            }
            return Err(NatsError::Protocol("malformed PONG frame".to_string()));
        }
        // Unrecognised (or still-partial) op: only report it once the whole line is in.
        let Some(line_end) = self.read_buf.find_crlf() else {
            return Ok(None);
        };
        let line = String::from_utf8_lossy(&self.read_buf.available()[..line_end]);
        Err(NatsError::Protocol(format!("unknown message: {line}")))
    }

    /// Parses an `INFO {json}` line.
    fn parse_info(&mut self) -> Result<Option<NatsMessage>, NatsError> {
        let buf = self.read_buf.available();
        let Some(end) = self.read_buf.find_crlf() else {
            return Ok(None);
        };
        let line = std::str::from_utf8(&buf[..end])
            .map_err(|_| NatsError::Protocol("invalid UTF-8 in INFO".to_string()))?;
        let json = line
            .strip_prefix("INFO ")
            .ok_or_else(|| NatsError::Protocol("malformed INFO".to_string()))?;
        let info = ServerInfo::parse(json);
        self.read_buf.consume(end + 2);
        Ok(Some(NatsMessage::Info(info)))
    }

    /// Parses `MSG <subject> <sid> [reply-to] <len>\r\n<payload>\r\n`.
    ///
    /// Returns `Ok(None)` until the header and the full payload are buffered.
    fn parse_msg(&mut self) -> Result<Option<NatsMessage>, NatsError> {
        let buf = self.read_buf.available();
        let Some(header_end) = self.read_buf.find_crlf() else {
            return Ok(None);
        };
        let header = std::str::from_utf8(&buf[..header_end])
            .map_err(|_| NatsError::Protocol("invalid UTF-8 in MSG header".to_string()))?;
        let mut parts = header.split_whitespace();
        // Skip the "MSG" op itself.
        let _msg = parts.next();
        let subject_str = parts
            .next()
            .ok_or_else(|| NatsError::Protocol(format!("malformed MSG header: {header}")))?;
        let sid_str = parts
            .next()
            .ok_or_else(|| NatsError::Protocol(format!("malformed MSG header: {header}")))?;
        let third = parts
            .next()
            .ok_or_else(|| NatsError::Protocol(format!("malformed MSG header: {header}")))?;
        let fourth = parts.next();
        if parts.next().is_some() {
            return Err(NatsError::Protocol(format!(
                "malformed MSG header (too many fields): {header}"
            )));
        }
        let subject = subject_str.to_string();
        let sid: u64 = sid_str
            .parse()
            .map_err(|_| NatsError::Protocol(format!("invalid SID: {sid_str}")))?;
        // With four fields the third is the reply-to subject; otherwise it is the length.
        let (reply_to, payload_len) = if let Some(len_str) = fourth {
            (
                Some(third.to_string()),
                len_str.parse::<usize>().map_err(|_| {
                    NatsError::Protocol(format!("invalid payload length: {len_str}"))
                })?,
            )
        } else {
            (
                None,
                third
                    .parse::<usize>()
                    .map_err(|_| NatsError::Protocol(format!("invalid payload length: {third}")))?,
            )
        };
        let max_buf = self.config.max_read_buffer;
        if payload_len > max_buf {
            return Err(NatsError::Protocol(format!(
                "MSG payload length {payload_len} exceeds maximum ({max_buf} bytes)"
            )));
        }
        let payload_start = header_end + 2;
        let payload_end = payload_start + payload_len;
        let total_len = payload_end + 2;
        if buf.len() < total_len {
            // Payload (or its terminator) not fully received yet.
            return Ok(None);
        }
        if buf[payload_end] != b'\r' || buf[payload_end + 1] != b'\n' {
            return Err(NatsError::Protocol(
                "malformed MSG payload terminator".to_string(),
            ));
        }
        let payload = buf[payload_start..payload_end].to_vec();
        self.read_buf.consume(total_len);
        Ok(Some(NatsMessage::Msg(Message {
            subject,
            sid,
            reply_to,
            payload,
        })))
    }

    /// Parses `-ERR 'text'`, trimming the surrounding single quotes.
    fn parse_err(&mut self) -> Result<Option<NatsMessage>, NatsError> {
        let buf = self.read_buf.available();
        let Some(end) = self.read_buf.find_crlf() else {
            return Ok(None);
        };
        let line = std::str::from_utf8(&buf[..end])
            .map_err(|_| NatsError::Protocol("invalid UTF-8 in -ERR".to_string()))?;
        let msg = line
            .strip_prefix("-ERR ")
            .unwrap_or(line)
            .trim_matches('\'')
            .to_string();
        self.read_buf.consume(end + 2);
        Ok(Some(NatsMessage::Err(msg)))
    }

    /// Publishes `payload` to `subject`, then processes any already-buffered
    /// server frames (answering `PING`s, dispatching `MSG`s).
    ///
    /// # Errors
    ///
    /// Returns:
    ///
    /// - [`NatsError::NotConnected`] when not connected;
    /// - [`NatsError::Protocol`] for an invalid subject or an oversized payload;
    /// - I/O errors;
    /// - a buffered server `-ERR`.
    pub async fn publish(
        &mut self,
        cx: &Cx,
        subject: &str,
        payload: &[u8],
    ) -> Result<(), NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        if !self.connected {
            return Err(NatsError::NotConnected);
        }
        validate_nats_token(subject, "subject")?;
        self.write_pub(subject, None, payload).await?;
        self.handle_pending_messages(cx).await
    }

    /// Publishes `payload` to `subject` with `reply_to` as the reply subject.
    ///
    /// # Errors
    ///
    /// Same as [`NatsClient::publish`], except buffered frames are not processed.
    pub async fn publish_request(
        &mut self,
        cx: &Cx,
        subject: &str,
        reply_to: &str,
        payload: &[u8],
    ) -> Result<(), NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        if !self.connected {
            return Err(NatsError::NotConnected);
        }
        validate_nats_token(subject, "subject")?;
        validate_nats_token(reply_to, "reply-to subject")?;
        self.write_pub(subject, Some(reply_to), payload).await
    }

    /// Checks the payload size, then writes a complete `PUB` frame.
    async fn write_pub(
        &mut self,
        subject: &str,
        reply_to: Option<&str>,
        payload: &[u8],
    ) -> Result<(), NatsError> {
        if payload.len() > self.config.max_payload {
            return Err(NatsError::Protocol(format!(
                "payload too large: {} > {}",
                payload.len(),
                self.config.max_payload
            )));
        }
        let header = match reply_to {
            Some(reply_to) => format!("PUB {subject} {reply_to} {}\r\n", payload.len()),
            None => format!("PUB {subject} {}\r\n", payload.len()),
        };
        self.write_frame(&[header.as_bytes(), payload, b"\r\n".as_slice()])
            .await
    }

    /// Publishes `payload` to `subject` and waits for a single reply.
    ///
    /// A unique `_INBOX.<sid>.<random>` subject is subscribed and used as the
    /// reply-to. While waiting, server `PING`s are answered and messages for
    /// other subscriptions are dispatched normally. The wait is bounded by
    /// `config.request_timeout`.
    ///
    /// # Errors
    ///
    /// Returns:
    ///
    /// - [`NatsError::NotConnected`];
    /// - validation and I/O errors;
    /// - a server `-ERR`;
    /// - [`NatsError::Cancelled`];
    /// - a timeout (an `Io` error for which [`NatsError::is_timeout`] is `true`).
    ///
    /// Once the inbox is subscribed, every non-cancellation error path tears it
    /// down best-effort.
    pub async fn request(
        &mut self,
        cx: &Cx,
        subject: &str,
        payload: &[u8],
    ) -> Result<Message, NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        if !self.connected {
            return Err(NatsError::NotConnected);
        }
        validate_nats_token(subject, "subject")?;
        let inbox = format!(
            "_INBOX.{}.{}",
            self.next_sid.load(Ordering::Relaxed),
            random_suffix(cx)
        );
        let mut sub = self.subscribe(cx, &inbox).await?;
        let sid = sub.sid();
        if let Err(err) = self.publish_request(cx, subject, &inbox, payload).await {
            self.cleanup_request_subscription(cx, sid, "publish_request_failed")
                .await;
            return Err(err);
        }
        let deadline = timeout_now(cx) + self.config.request_timeout;
        loop {
            cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
            let mut processed_any = false;
            loop {
                let message = match self.try_parse_message() {
                    Ok(message) => message,
                    Err(err) => {
                        self.cleanup_request_subscription(cx, sid, "parse_failed")
                            .await;
                        return Err(err);
                    }
                };
                match message {
                    Some(NatsMessage::Ping) => {
                        if let Err(err) = self.send_server_pong().await {
                            self.cleanup_request_subscription(
                                cx,
                                sid,
                                "server_ping_write_failed",
                            )
                            .await;
                            return Err(err);
                        }
                        processed_any = true;
                    }
                    Some(NatsMessage::Msg(m)) if m.sid == sid => {
                        // The reply is already in hand. Tear the inbox down
                        // best-effort rather than discarding the reply when the
                        // UNSUB cannot be sent.
                        self.cleanup_request_subscription(cx, sid, "reply_received")
                            .await;
                        return Ok(m);
                    }
                    Some(NatsMessage::Msg(m)) => {
                        self.dispatch_message(m);
                        processed_any = true;
                    }
                    Some(NatsMessage::Err(e)) => {
                        self.cleanup_request_subscription(cx, sid, "server_error")
                            .await;
                        return Err(NatsError::Server(e));
                    }
                    Some(_) => {
                        processed_any = true;
                    }
                    None => {
                        if processed_any {
                            break;
                        }
                        // `read_more_until` enforces the deadline only while
                        // waiting. Once it has passed, a timer that polls the
                        // read first lets a busy connection keep succeeding, so
                        // the request would never time out. Check explicitly,
                        // after the buffer has been fully drained.
                        let read = if deadline.duration_since(timeout_now(cx)) == 0 {
                            Err(request_timeout_error())
                        } else {
                            self.read_more_until(cx, deadline).await
                        };
                        if let Err(err) = read {
                            self.cleanup_request_subscription(
                                cx,
                                sid,
                                REQUEST_TIMEOUT_MESSAGE,
                            )
                            .await;
                            return Err(err);
                        }
                        processed_any = true;
                    }
                }
            }
            if let Some(msg) = sub.try_next() {
                self.cleanup_request_subscription(cx, sid, "reply_received")
                    .await;
                return Ok(msg);
            }
        }
    }

    /// Subscribes to `subject` (wildcards allowed).
    ///
    /// # Errors
    ///
    /// [`NatsError::NotConnected`], [`NatsError::Protocol`] for an invalid
    /// subject, I/O errors, or cancellation.
    pub async fn subscribe(&mut self, cx: &Cx, subject: &str) -> Result<Subscription, NatsError> {
        self.subscribe_inner(cx, subject, None).await
    }

    /// Subscribes to `subject` as a member of `queue_group`. The server
    /// delivers each message to one member of the group.
    ///
    /// # Errors
    ///
    /// As [`NatsClient::subscribe`], plus an invalid queue group name.
    pub async fn queue_subscribe(
        &mut self,
        cx: &Cx,
        subject: &str,
        queue_group: &str,
    ) -> Result<Subscription, NatsError> {
        self.subscribe_inner(cx, subject, Some(queue_group)).await
    }

    /// Shared implementation of `subscribe` and `queue_subscribe`.
    ///
    /// The local routing entry is inserted *before* `SUB` is written, so a
    /// `MSG` that arrives right after the subscription already has somewhere to
    /// go. A [`SubscribeGuard`] removes the entry again if the write fails or
    /// the future is dropped mid-write.
    async fn subscribe_inner(
        &mut self,
        cx: &Cx,
        subject: &str,
        queue_group: Option<&str>,
    ) -> Result<Subscription, NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        if !self.connected {
            return Err(NatsError::NotConnected);
        }
        validate_nats_token(subject, "subject")?;
        if let Some(group) = queue_group {
            validate_nats_token(group, "queue group")?;
        }
        let sid = self.next_sid.fetch_add(1, Ordering::Relaxed);
        let (tx, rx) = mpsc::channel(256);
        // Cloned so the guard below borrows this local rather than `self`,
        // which `write_frame` needs mutably.
        let state = Arc::clone(&self.state);
        state.subscriptions.lock().insert(
            sid,
            SubscriptionState {
                subject: subject.to_string(),
                sender: tx,
            },
        );
        let cmd = match queue_group {
            Some(group) => format!("SUB {subject} {group} {sid}\r\n"),
            None => format!("SUB {subject} {sid}\r\n"),
        };
        {
            let mut guard = SubscribeGuard {
                subs: &state.subscriptions,
                sid,
                defused: false,
            };
            self.write_frame(&[cmd.as_bytes()]).await?;
            guard.defused = true;
        }
        match queue_group {
            Some(group) => cx.trace(&format!(
                "nats: subscribed to {subject} in queue group {group} (sid={sid})"
            )),
            None => cx.trace(&format!("nats: subscribed to {subject} (sid={sid})")),
        }
        Ok(Subscription {
            sid,
            subject: subject.to_string(),
            rx,
            state,
        })
    }

    /// Removes the local routing for `sid` and sends `UNSUB` to the server.
    ///
    /// # Errors
    ///
    /// Returns:
    ///
    /// - [`NatsError::Cancelled`] (checked before local state is touched);
    /// - [`NatsError::NotConnected`] (after local removal);
    /// - I/O errors.
    pub async fn unsubscribe(&mut self, cx: &Cx, sid: u64) -> Result<(), NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        self.remove_local_subscription(sid);
        if !self.connected {
            return Err(NatsError::NotConnected);
        }
        let cmd = format!("UNSUB {sid}\r\n");
        self.write_frame(&[cmd.as_bytes()]).await
    }

    /// Sends `PING` and waits for the server's `PONG`.
    ///
    /// While waiting, server `PING`s are answered and `MSG`s are dispatched.
    ///
    /// # Errors
    ///
    /// I/O errors, a server `-ERR`, or cancellation. Any of these leaves the
    /// client marked as not connected.
    pub async fn ping(&mut self, cx: &Cx) -> Result<(), NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        if !self.connected {
            return Err(NatsError::NotConnected);
        }
        self.write_frame(&[b"PING\r\n".as_slice()]).await?;
        // Treat the client as not connected until the matching PONG arrives,
        // so every early return below leaves it marked unusable.
        self.connected = false;
        loop {
            cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
            match self.try_parse_message()? {
                Some(NatsMessage::Pong) => {
                    self.connected = true;
                    return Ok(());
                }
                Some(NatsMessage::Err(e)) => return Err(NatsError::Server(e)),
                Some(NatsMessage::Ping) => self.send_server_pong().await?,
                Some(NatsMessage::Msg(m)) => self.dispatch_message(m),
                Some(_) => {}
                None => self.read_more().await?,
            }
        }
    }

    /// Processes frames already in the read buffer without reading the socket.
    async fn handle_pending_messages(&mut self, _cx: &Cx) -> Result<(), NatsError> {
        loop {
            match self.try_parse_message()? {
                Some(NatsMessage::Ping) => {
                    self.send_server_pong().await?;
                }
                Some(NatsMessage::Msg(m)) => {
                    self.dispatch_message(m);
                }
                Some(NatsMessage::Err(e)) => {
                    return Err(NatsError::Server(e));
                }
                Some(_) => {}
                None => break,
            }
        }
        Ok(())
    }

    /// Routes `msg` to its subscription's channel.
    ///
    /// The message is dropped (with a warning) when the channel is full or
    /// closed, and silently when no subscription matches its SID.
    fn dispatch_message(&self, msg: Message) {
        let subs = self.state.subscriptions.lock();
        if let Some(sub) = subs.get(&msg.sid) {
            if sub.sender.try_send(msg).is_err() {
                warn!(
                    subject = %sub.subject,
                    "NATS message dropped due to backpressure - consumer too slow"
                );
            }
        }
    }

    /// Reads from the socket as needed and handles at least one read or frame,
    /// draining whatever is then buffered.
    ///
    /// # Errors
    ///
    /// Read/write errors, a server `-ERR`, or cancellation.
    pub async fn process(&mut self, cx: &Cx) -> Result<(), NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        let mut processed_any = false;
        loop {
            match self.try_parse_message()? {
                Some(NatsMessage::Ping) => {
                    self.send_server_pong().await?;
                    processed_any = true;
                }
                Some(NatsMessage::Msg(m)) => {
                    self.dispatch_message(m);
                    processed_any = true;
                }
                Some(NatsMessage::Err(e)) => {
                    return Err(NatsError::Server(e));
                }
                Some(_) => {
                    processed_any = true;
                }
                None => {
                    if processed_any {
                        break;
                    }
                    self.read_more().await?;
                    processed_any = true;
                }
            }
        }
        Ok(())
    }

    /// Closes the client.
    ///
    /// Marks the shared state closed and drops every local subscription, so
    /// `Subscription::next` returns `None` once its queue is empty. If still
    /// connected, it flushes and shuts the socket down, ignoring any errors.
    ///
    /// # Errors
    ///
    /// Only [`NatsError::Cancelled`].
    pub async fn close(&mut self, cx: &Cx) -> Result<(), NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        self.state.closed.store(true, Ordering::Release);
        self.state.subscriptions.lock().clear();
        if self.connected {
            let _ = self.stream.flush().await;
            let _ = self.stream.shutdown(std::net::Shutdown::Both);
        }
        self.connected = false;
        Ok(())
    }

    /// The `INFO` received during connect, if any.
    pub fn server_info(&self) -> Option<ServerInfo> {
        self.state.server_info.lock().clone()
    }
}
/// Receiving end of one subscription created by a `NatsClient`.
///
/// Dropping it removes the local routing entry. It does not send `UNSUB`.
pub struct Subscription {
/// Subscription ID assigned by the client.
sid: u64,
/// Subject (or pattern) that was subscribed.
subject: String,
/// Channel fed by the client's message dispatch.
rx: mpsc::Receiver<Message>,
/// State shared with the owning client.
state: Arc<SharedState>,
}
impl fmt::Debug for Subscription {
    /// Shows the SID and subject; the channel and shared state are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Subscription");
        out.field("sid", &self.sid);
        out.field("subject", &self.subject);
        out.finish_non_exhaustive()
    }
}
impl Subscription {
    /// Subscription ID used in `SUB`/`UNSUB` frames.
    #[must_use]
    pub fn sid(&self) -> u64 {
        self.sid
    }

    /// The subject (or pattern) this subscription was created for.
    #[must_use]
    pub fn subject(&self) -> &str {
        self.subject.as_str()
    }

    /// Waits for the next message.
    ///
    /// Messages already queued are returned even after the client was closed.
    /// Returns `Ok(None)` once the client is closed, or the channel is
    /// disconnected, and nothing is queued.
    ///
    /// # Errors
    ///
    /// [`NatsError::Cancelled`] if `cx` is cancelled.
    pub async fn next(&mut self, cx: &Cx) -> Result<Option<Message>, NatsError> {
        cx.checkpoint().map_err(|_| NatsError::Cancelled)?;
        if let Ok(queued) = self.rx.try_recv() {
            return Ok(Some(queued));
        }
        if self.state.closed.load(Ordering::Acquire) {
            return Ok(None);
        }
        self.rx.recv(cx).await.map(Some).or_else(|err| match err {
            mpsc::RecvError::Cancelled => Err(NatsError::Cancelled),
            mpsc::RecvError::Disconnected | mpsc::RecvError::Empty => Ok(None),
        })
    }

    /// Returns a queued message without waiting, if one is available.
    pub fn try_next(&mut self) -> Option<Message> {
        match self.rx.try_recv() {
            Ok(message) => Some(message),
            Err(_) => None,
        }
    }
}
impl Drop for Subscription {
    /// Stops local routing for this SID. No `UNSUB` is sent to the server.
    fn drop(&mut self) {
        let sid = self.sid;
        self.state.subscriptions.lock().remove(&sid);
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::{assert_completes_within, run_test_with_cx};
use serde_json::json;
use socket2::SockRef;
use std::io::{BufRead, BufReader, Read, Write};
use std::net::TcpListener;
use std::sync::mpsc as std_mpsc;
use std::thread;
/// Replaces `_INBOX.`-prefixed reply subjects, which differ per run, with a
/// fixed placeholder so snapshots are stable; other values pass through.
fn scrub_reply_subject(reply_to: Option<&str>) -> Option<&str> {
let value = reply_to?;
Some(if value.starts_with("_INBOX.") {
"_INBOX.[SCRUBBED]"
} else {
value
})
}
/// JSON snapshot of a received message, with the reply subject scrubbed and
/// the payload rendered as lossy UTF-8 plus its byte length.
fn message_event_snapshot(message: &Message) -> serde_json::Value {
json!({
"subject": message.subject,
"sid": message.sid,
"reply_to": scrub_reply_subject(message.reply_to.as_deref()),
"payload_utf8": String::from_utf8_lossy(&message.payload),
"payload_len": message.payload.len(),
})
}
/// Reads one line from the peer socket, panicking if the peer closed first.
fn read_protocol_line(reader: &mut BufReader<std::net::TcpStream>) -> String {
let mut line = String::new();
let bytes = reader.read_line(&mut line).expect("read protocol line");
assert!(bytes > 0, "peer closed before sending a full protocol line");
line
}
/// Returns the payload length from a `PUB <subject> <reply-to> <len>` header,
/// asserting that the reply-to field is present.
fn parse_pub_payload_len(header: &str) -> usize {
let parts: Vec<_> = header.split_whitespace().collect();
assert_eq!(parts.first().copied(), Some("PUB"));
assert_eq!(parts.len(), 4, "request publish must include reply-to");
parts[3].parse().expect("parse PUB payload length")
}
/// Like `read_protocol_line`, but returns `None` on EOF or on a
/// `WouldBlock`/`TimedOut` read error instead of panicking.
fn read_optional_protocol_line(reader: &mut BufReader<std::net::TcpStream>) -> Option<String> {
let mut line = String::new();
match reader.read_line(&mut line) {
Ok(0) => None,
Ok(_) => Some(line),
Err(err)
if matches!(
err.kind(),
io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut
) =>
{
None
}
Err(err) => panic!("read protocol line: {err}"),
}
}
#[test]
fn test_config_from_url_simple() {
let config = NatsConfig::from_url("nats://localhost:4222").unwrap();
assert_eq!(config.host, "localhost");
assert_eq!(config.port, 4222);
assert!(config.user.is_none());
assert!(config.password.is_none());
}
#[test]
fn test_config_from_url_with_auth() {
let config = NatsConfig::from_url("nats://user:pass@localhost:4222").unwrap();
assert_eq!(config.host, "localhost");
assert_eq!(config.port, 4222);
assert_eq!(config.user, Some("user".to_string()));
assert_eq!(config.password, Some("pass".to_string()));
}
/// Userinfo without a `:` is treated as an auth token.
#[test]
fn test_config_from_url_with_token() {
    let parsed = NatsConfig::from_url("nats://mytoken@localhost:4222").unwrap();
    assert_eq!(parsed.token.as_deref(), Some("mytoken"));
    assert_eq!(parsed.host, "localhost");
    assert_eq!(parsed.port, 4222);
}
/// Omitting the port falls back to 4222.
#[test]
fn test_config_from_url_default_port() {
    let parsed = NatsConfig::from_url("nats://localhost").unwrap();
    assert_eq!(parsed.port, 4222);
    assert_eq!(parsed.host, "localhost");
}
/// A bracketed IPv6 literal keeps its brackets in `host`, and the trailing port is parsed.
#[test]
fn test_config_from_url_ipv6() {
    let parsed = NatsConfig::from_url("nats://[::1]:4333").unwrap();
    assert_eq!(parsed.port, 4333);
    assert_eq!(parsed.host, "[::1]");
}
/// An `@` inside the password must not be mistaken for the host separator.
#[test]
fn test_config_from_url_password_with_at_sign() {
    let parsed = NatsConfig::from_url("nats://user:pa@ss@localhost:4222").unwrap();
    assert_eq!(parsed.host, "localhost");
    assert_eq!(parsed.port, 4222);
    assert_eq!(parsed.user.as_deref(), Some("user"));
    assert_eq!(parsed.password.as_deref(), Some("pa@ss"));
}
/// A complete `INFO` JSON body populates every field it mentions.
#[test]
fn test_server_info_parse() {
    let info = ServerInfo::parse(
        r#"{"server_id":"id123","server_name":"test","version":"2.9.0","proto":1,"max_payload":1048576,"tls_required":false}"#,
    );
    assert_eq!(info.server_id, "id123");
    assert_eq!(info.server_name, "test");
    assert_eq!(info.version, "2.9.0");
    assert_eq!(info.proto, 1);
    assert_eq!(info.max_payload, 1_048_576);
    assert!(!info.tls_required);
}
/// String lookup returns the value for a present key and `None` otherwise.
#[test]
fn test_extract_json_string() {
    let doc = r#"{"name":"value","other":123}"#;
    assert_eq!(extract_json_string(doc, "name").as_deref(), Some("value"));
    assert!(extract_json_string(doc, "missing").is_none());
}
/// Integer lookup handles positive and negative values and missing keys.
#[test]
fn test_extract_json_i64() {
    let doc = r#"{"count":42,"neg":-5}"#;
    for (key, expected) in [("count", Some(42)), ("neg", Some(-5)), ("missing", None)] {
        assert_eq!(extract_json_i64(doc, key), expected, "key {key}");
    }
}
/// Boolean lookup handles `true`, `false`, and missing keys.
#[test]
fn test_extract_json_bool() {
    let doc = r#"{"enabled":true,"disabled":false}"#;
    for (key, expected) in [("enabled", Some(true)), ("disabled", Some(false)), ("missing", None)] {
        assert_eq!(extract_json_bool(doc, key), expected, "key {key}");
    }
}
/// A non-`nats://` scheme is rejected as an invalid URL.
#[test]
fn test_config_invalid_url() {
    let Err(NatsError::InvalidUrl(_)) = NatsConfig::from_url("http://localhost:4222") else {
        panic!("expected InvalidUrl for http:// scheme");
    };
}
/// A non-numeric port is rejected as an invalid URL.
#[test]
fn test_config_invalid_port() {
    let outcome = NatsConfig::from_url("nats://localhost:notaport");
    assert!(matches!(outcome, Err(NatsError::InvalidUrl(_))));
}
/// An empty host is rejected as an invalid URL.
#[test]
fn test_config_invalid_empty_host() {
    let outcome = NatsConfig::from_url("nats://:4222");
    assert!(matches!(outcome, Err(NatsError::InvalidUrl(_))));
}
/// Pins the exact `Display` text of each non-I/O error variant.
#[test]
fn test_nats_error_display() {
    let cases = [
        (NatsError::Cancelled, "NATS operation cancelled"),
        (NatsError::Closed, "NATS connection closed"),
        (NatsError::NotConnected, "NATS not connected"),
        (NatsError::SubscriptionNotFound(42), "NATS subscription not found: 42"),
        (
            NatsError::Server("auth error".to_string()),
            "NATS server error: auth error",
        ),
        (
            NatsError::Protocol("parse error".to_string()),
            "NATS protocol error: parse error",
        ),
        (NatsError::InvalidUrl("bad".to_string()), "Invalid NATS URL: bad"),
    ];
    for (err, expected) in cases {
        assert_eq!(err.to_string(), expected);
    }
}
/// Tokens are rejected when empty or when they contain spaces, CR/LF, or tabs
/// (the CR/LF case would otherwise allow protocol injection).
#[test]
fn test_validate_nats_token_rejects_whitespace_and_controls() {
    assert!(validate_nats_token("foo.bar", "subject").is_ok());
    let rejected = [
        ("", "subject"),
        ("foo bar", "subject"),
        ("foo\r\nPUB x 1\r\nx", "subject"),
        ("queue\tgroup", "queue group"),
    ];
    for (token, label) in rejected {
        assert!(validate_nats_token(token, label).is_err(), "{token:?} should be rejected");
    }
}
/// Exact patterns match themselves; `*` matches exactly one token.
#[test]
fn test_subscription_matches_subject_exact_and_single_wildcard() {
    let cases = [
        ("time.us.east", "time.us.east", true),
        ("time.*.east", "time.us.east", true),
        ("time.*.east", "time.us.east.atlanta", false),
        ("time.*.east", "time.east", false),
    ];
    for (pattern, subject, expected) in cases {
        assert_eq!(
            subscription_matches_subject(pattern, subject),
            expected,
            "{pattern} vs {subject}"
        );
    }
}
/// `>` matches one or more trailing tokens, never zero.
#[test]
fn test_subscription_matches_subject_tail_wildcard_requires_trailing_tokens() {
    let cases = [
        ("time.>", "time.us", true),
        ("time.>", "time.us.east.atlanta", true),
        ("time.>", "time", false),
        (">", "time.us", true),
    ];
    for (pattern, subject, expected) in cases {
        assert_eq!(
            subscription_matches_subject(pattern, subject),
            expected,
            "{pattern} vs {subject}"
        );
    }
}
/// Wildcards only count as whole tokens (and `>` only in last position); a
/// wildcard in the subject side is literal, so none of these match.
#[test]
fn test_subscription_matches_subject_rejects_invalid_wildcard_placements() {
    let never_match = [
        ("time>.east", "time.us.east"),
        ("time.>.east", "time.us.east"),
        ("time.*east", "time.us.east"),
        ("time.east", "time.*"),
    ];
    for (pattern, subject) in never_match {
        assert!(
            !subscription_matches_subject(pattern, subject),
            "{pattern} must not match {subject}"
        );
    }
}
/// Empty tokens (doubled, leading, or trailing dots) on either side never match.
#[test]
fn test_subscription_matches_subject_rejects_empty_tokens() {
    let never_match = [
        ("time..east", "time.us.east"),
        (".time.east", "time.us.east"),
        ("time.east", "time..east"),
        ("time.east", "time.east."),
    ];
    for (pattern, subject) in never_match {
        assert!(
            !subscription_matches_subject(pattern, subject),
            "{pattern} must not match {subject}"
        );
    }
}
/// Each suffix is 16 hex digits, and two consecutive draws differ.
#[test]
fn test_random_suffix_format() {
    let cx: Cx = Cx::for_testing();
    let first = random_suffix(&cx);
    let second = random_suffix(&cx);
    for suffix in [&first, &second] {
        assert_eq!(suffix.len(), 16);
        assert!(suffix.chars().all(|c| c.is_ascii_hexdigit()));
    }
    assert_ne!(first, second);
}
/// An empty JSON object yields empty/zero/false fields.
#[test]
fn test_server_info_parse_minimal() {
    let info = ServerInfo::parse("{}");
    assert!(info.server_id.is_empty());
    assert_eq!(info.max_payload, 0);
    assert!(!info.tls_required);
}
/// Both TLS flags are read from the `INFO` body.
#[test]
fn test_server_info_parse_with_tls() {
    let parsed = ServerInfo::parse(r#"{"tls_required":true,"tls_available":true}"#);
    assert!(parsed.tls_required);
    assert!(parsed.tls_available);
}
/// Pins every default of `NatsConfig`.
#[test]
fn test_nats_config_default() {
    let defaults = NatsConfig::default();
    // Endpoint.
    assert_eq!(defaults.host, "127.0.0.1");
    assert_eq!(defaults.port, 4222);
    // No credentials of any kind.
    assert!(defaults.user.is_none());
    assert!(defaults.password.is_none());
    assert!(defaults.token.is_none());
    // Protocol flags and limits.
    assert!(!defaults.verbose);
    assert!(!defaults.pedantic);
    assert_eq!(defaults.max_payload, 1_048_576);
    assert_eq!(defaults.request_timeout, Duration::from_secs(10));
}
/// Extend, locate the CRLF, then consume the full line to empty the buffer.
#[test]
fn test_read_buffer_operations() {
    let mut rb = NatsReadBuffer::new();
    assert!(rb.available().is_empty());

    rb.extend(b"hello\r\n").unwrap();
    assert_eq!(rb.available(), b"hello\r\n");
    // CRLF starts right after the five bytes of "hello".
    assert_eq!(rb.find_crlf(), Some(5));

    // Consume the line plus its two-byte terminator.
    rb.consume(7);
    assert!(rb.available().is_empty());
}
/// A CR without its LF is not a line terminator until the LF arrives.
#[test]
fn test_read_buffer_partial_crlf() {
    let mut rb = NatsReadBuffer::new();
    rb.extend(b"hello\r").unwrap();
    assert!(rb.find_crlf().is_none());
    rb.extend(b"\n").unwrap();
    assert_eq!(rb.find_crlf(), Some(5));
}
/// C1 control characters (U+0080..U+009F) are emitted as `\uXXXX`.
#[test]
fn test_nats_json_escape_c1_control() {
    assert_eq!(nats_json_escape("\u{0080}"), "\\u0080");
}
/// C0 control characters without a short escape are emitted as `\uXXXX`.
#[test]
fn test_nats_json_escape_c0_control() {
    let out = nats_json_escape("\u{0001}");
    assert_eq!(out, "\\u0001");
}
/// Quotes, backslashes, and newlines get their short escapes; plain text is untouched.
#[test]
fn test_nats_json_escape_common_chars() {
    let cases = [
        (r#"hello"world"#, r#"hello\"world"#),
        ("back\\slash", "back\\\\slash"),
        ("new\nline", "new\\nline"),
        ("plain", "plain"),
    ];
    for (raw, escaped) in cases {
        assert_eq!(nats_json_escape(raw), escaped);
    }
}
/// Each variant's `Display` text contains its identifying phrase or payload.
#[test]
fn nats_error_display_all_variants() {
    let cases = [
        (NatsError::Io(io::Error::other("e")), "I/O error"),
        (NatsError::Protocol("p".into()), "protocol error"),
        (NatsError::Server("s".into()), "server error"),
        (NatsError::InvalidUrl("bad://".into()), "bad://"),
        (NatsError::Cancelled, "cancelled"),
        (NatsError::Closed, "closed"),
        (NatsError::SubscriptionNotFound(42), "42"),
        (NatsError::NotConnected, "not connected"),
    ];
    for (err, needle) in cases {
        let rendered = err.to_string();
        assert!(rendered.contains(needle), "{rendered:?} should contain {needle:?}");
    }
}
/// The derived `Debug` output names the variant.
#[test]
fn nats_error_debug() {
    let rendered = format!("{:?}", NatsError::Closed);
    assert!(rendered.contains("Closed"));
}
/// An I/O-wrapping error exposes the wrapped error as its `source()`.
#[test]
fn nats_error_source_io() {
    use std::error::Error as _;
    let wrapped = NatsError::Io(io::Error::other("disk"));
    assert!(wrapped.source().is_some());
}
/// The synthetic request-timeout error must be recognised by `is_timeout`.
#[test]
fn nats_request_timeout_error_is_classified_as_timeout() {
    let err = request_timeout_error();
    assert!(err.is_timeout());
}
/// Verifies that `NatsClient::request` honours `request_timeout` when no reply
/// ever arrives, and that timeout cleanup unsubscribes the temporary inbox.
///
/// A scripted server thread sends `INFO`, then checks that the client writes
/// `CONNECT`, `SUB _INBOX.…`, and `PUB svc.echo _INBOX.… <len>` followed by the
/// `ping` payload. It never replies; the next line it reads is returned from
/// the thread and must be an `UNSUB`.
#[test]
fn request_enforces_timeout_while_socket_reads_are_pending() {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind test listener");
let addr = listener.local_addr().expect("listener addr");
let server = thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("accept client");
// Bound each server-side read so a silent client fails the test instead of hanging it.
stream
.set_read_timeout(Some(Duration::from_secs(2)))
.expect("set read timeout");
stream
.write_all(
b"INFO {\"server_id\":\"id\",\"server_name\":\"test\",\"version\":\"2.10.0\",\"proto\":1,\"max_payload\":1048576}\r\n",
)
.expect("write INFO");
stream.flush().expect("flush INFO");
let mut reader = BufReader::new(stream);
let connect = read_protocol_line(&mut reader);
assert!(
connect.starts_with("CONNECT "),
"unexpected CONNECT: {connect:?}"
);
let subscribe = read_protocol_line(&mut reader);
assert!(
subscribe.starts_with("SUB _INBOX."),
"unexpected SUB: {subscribe:?}"
);
let publish = read_protocol_line(&mut reader);
assert!(
publish.starts_with("PUB svc.echo _INBOX."),
"unexpected PUB: {publish:?}"
);
let payload_len = parse_pub_payload_len(&publish);
// Payload bytes are followed by their own CRLF terminator.
let mut payload = vec![0_u8; payload_len + 2];
reader
.read_exact(&mut payload)
.expect("read request payload");
assert_eq!(&payload[..payload_len], b"ping");
assert_eq!(&payload[payload_len..], b"\r\n");
// Deliberately send no reply; the next line is inspected after `join` below.
read_protocol_line(&mut reader)
});
run_test_with_cx(|cx| async move {
// 100 ms request timeout, well inside the 2 s bound of `assert_completes_within`.
let config = NatsConfig {
host: addr.ip().to_string(),
port: addr.port(),
request_timeout: Duration::from_millis(100),
..Default::default()
};
assert_completes_within(
Duration::from_secs(2),
"nats request timeout enforcement",
move || {
let config = config.clone();
Box::pin(async move {
let mut client = NatsClient::connect_with_config(&cx, config)
.await
.expect("connect to test server");
let err = client
.request(&cx, "svc.echo", b"ping")
.await
.expect_err("request must time out");
// The timeout surfaces as an I/O error of kind `TimedOut`.
assert!(
matches!(err, NatsError::Io(ref io_err) if io_err.kind() == io::ErrorKind::TimedOut),
"expected timed out I/O error, got {err:?}"
);
assert!(err.is_timeout(), "expected timeout classification");
})
},
)
.await;
});
let unsubscribe = server.join().expect("server join");
assert!(
unsubscribe.starts_with("UNSUB "),
"timeout cleanup must unsubscribe, got {unsubscribe:?}"
);
}
/// Verifies that `unsubscribe` on a client marked disconnected fails with
/// `NotConnected`, still removes the local subscription entry, and writes
/// nothing to the socket.
///
/// The server thread waits up to 250 ms for a line; `None` (timeout or EOF)
/// means no `UNSUB` was emitted.
#[test]
fn unsubscribe_on_disconnected_client_skips_wire_write() {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind test listener");
let addr = listener.local_addr().expect("listener addr");
let server = thread::spawn(move || {
let (stream, _) = listener.accept().expect("accept client");
stream
.set_read_timeout(Some(Duration::from_millis(250)))
.expect("set read timeout");
let mut reader = BufReader::new(stream);
read_optional_protocol_line(&mut reader)
});
run_test_with_cx(|cx| async move {
let stream = TcpStream::connect(format!("{addr}"))
.await
.expect("connect client");
let state = Arc::new(SharedState::new());
// Register a local subscription directly so there is something to remove.
let sid = 41;
let (tx, _rx) = mpsc::channel(8);
state.subscriptions.lock().insert(
sid,
SubscriptionState {
subject: "svc.echo".to_string(),
sender: tx,
},
);
// Build the client by hand with `connected: false` (no handshake is performed).
let mut client = NatsClient {
config: NatsConfig::default(),
stream,
read_buf: NatsReadBuffer::new(),
state: Arc::clone(&state),
next_sid: AtomicU64::new(1),
connected: false,
};
let err = client
.unsubscribe(&cx, sid)
.await
.expect_err("disconnected unsubscribe must fail closed");
assert!(matches!(err, NatsError::NotConnected));
assert!(
!state.subscriptions.lock().contains_key(&sid),
"local subscription must still be removed"
);
});
let line = server.join().expect("server join");
assert!(
line.is_none(),
"disconnected unsubscribe must not emit UNSUB, got {line:?}"
);
}
/// Verifies that `ping` on a client marked disconnected fails with
/// `NotConnected` without writing anything to the socket.
///
/// The server thread waits up to 250 ms for a line; `None` (timeout or EOF)
/// means no bytes were emitted.
#[test]
fn ping_on_disconnected_client_skips_wire_write() {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind test listener");
let addr = listener.local_addr().expect("listener addr");
let server = thread::spawn(move || {
let (stream, _) = listener.accept().expect("accept client");
stream
.set_read_timeout(Some(Duration::from_millis(250)))
.expect("set read timeout");
let mut reader = BufReader::new(stream);
read_optional_protocol_line(&mut reader)
});
run_test_with_cx(|cx| async move {
let stream = TcpStream::connect(format!("{addr}"))
.await
.expect("connect client");
// Build the client by hand with `connected: false` (no handshake is performed).
let mut client = NatsClient {
config: NatsConfig::default(),
stream,
read_buf: NatsReadBuffer::new(),
state: Arc::new(SharedState::new()),
next_sid: AtomicU64::new(1),
connected: false,
};
let err = client
.ping(&cx)
.await
.expect_err("disconnected ping must fail closed");
assert!(matches!(err, NatsError::NotConnected));
});
let line = server.join().expect("server join");
assert!(
line.is_none(),
"disconnected ping must not emit wire bytes, got {line:?}"
);
}
/// Verifies that a failed `PONG` write while processing a buffered server
/// `PING` marks the client disconnected, so later operations fail closed.
///
/// The server aborts the connection with zero linger (forcing a reset) only
/// after the client has buffered the `PING`, so the `PONG` reply written by
/// `process` is expected to hit a reset peer.
#[test]
fn process_ping_write_failure_marks_client_disconnected() {
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind test listener");
    let addr = listener.local_addr().expect("listener addr");
    let (close_tx, close_rx) = std_mpsc::channel();
    let (closed_tx, closed_rx) = std_mpsc::channel();
    let server = thread::spawn(move || {
        let (stream, _) = listener.accept().expect("accept client");
        close_rx.recv().expect("close signal");
        // Zero linger turns the close into a reset rather than a graceful shutdown.
        SockRef::from(&stream)
            .set_linger(Some(Duration::ZERO))
            .expect("force reset on close");
        drop(stream);
        closed_tx.send(()).expect("closed ack");
    });
    run_test_with_cx(|cx| async move {
        let stream = TcpStream::connect(format!("{addr}"))
            .await
            .expect("connect client");
        let mut client = NatsClient {
            config: NatsConfig::default(),
            stream,
            read_buf: NatsReadBuffer::new(),
            state: Arc::new(SharedState::new()),
            next_sid: AtomicU64::new(1),
            connected: true,
        };
        client.read_buf.extend(b"PING\r\n").expect("buffer ping");
        close_tx.send(()).expect("signal close");
        closed_rx.recv().expect("server closed");
        // Give the reset a moment to reach the client socket before writing.
        // NOTE(review): this is a blocking sleep inside the async body;
        // presumably acceptable under the test executor — confirm.
        thread::sleep(Duration::from_millis(20));
        let err = client
            .process(&cx)
            .await
            .expect_err("PONG write must fail against reset peer");
        assert!(
            matches!(err, NatsError::Io(_)),
            "expected I/O error, got {err:?}"
        );
        // The client started out connected, so the failure must flip the flag.
        assert!(
            !client.connected,
            "client must be marked disconnected after failed PONG write"
        );
        let follow_up = client
            .publish(&cx, "svc.echo", b"ping")
            .await
            .expect_err("fail-closed client must reject follow-up publish");
        assert!(matches!(follow_up, NatsError::NotConnected));
    });
    server.join().expect("server join");
}
/// Verifies that when `request` fails because the `PONG` reply to a buffered
/// `PING` cannot be written, the temporary `_INBOX` subscription is removed
/// and the client is marked disconnected.
///
/// The server reads the `SUB` and `PUB` (plus payload) the request sends, then
/// aborts the connection with zero linger so the subsequent `PONG` write fails.
#[test]
fn request_ping_write_failure_cleans_up_temporary_subscription() {
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind test listener");
    let addr = listener.local_addr().expect("listener addr");
    let server = thread::spawn(move || {
        let (stream, _) = listener.accept().expect("accept client");
        // Bound each read, like the other scripted servers, so a stalled
        // client makes this thread fail instead of blocking forever.
        stream
            .set_read_timeout(Some(Duration::from_secs(2)))
            .expect("set read timeout");
        let mut reader = BufReader::new(stream);
        let subscribe = read_protocol_line(&mut reader);
        assert!(
            subscribe.starts_with("SUB _INBOX."),
            "unexpected SUB: {subscribe:?}"
        );
        let publish = read_protocol_line(&mut reader);
        assert!(
            publish.starts_with("PUB svc.echo _INBOX."),
            "unexpected PUB: {publish:?}"
        );
        let payload_len = parse_pub_payload_len(&publish);
        // Payload bytes are followed by their own CRLF terminator.
        let mut payload = vec![0_u8; payload_len + 2];
        reader
            .read_exact(&mut payload)
            .expect("read request payload");
        assert_eq!(&payload[..payload_len], b"ping");
        assert_eq!(&payload[payload_len..], b"\r\n");
        let stream = reader.into_inner();
        // Zero linger turns the close into a reset rather than a graceful shutdown.
        SockRef::from(&stream)
            .set_linger(Some(Duration::ZERO))
            .expect("force reset on close");
        drop(stream);
    });
    run_test_with_cx(|cx| async move {
        let stream = TcpStream::connect(format!("{addr}"))
            .await
            .expect("connect client");
        let state = Arc::new(SharedState::new());
        let mut client = NatsClient {
            config: NatsConfig::default(),
            stream,
            read_buf: NatsReadBuffer::new(),
            state: Arc::clone(&state),
            next_sid: AtomicU64::new(1),
            connected: true,
        };
        client.read_buf.extend(b"PING\r\n").expect("buffer ping");
        let err = client
            .request(&cx, "svc.echo", b"ping")
            .await
            .expect_err("request must fail when PONG write fails");
        assert!(
            matches!(err, NatsError::Io(_)),
            "expected I/O error, got {err:?}"
        );
        assert!(
            state.subscriptions.lock().is_empty(),
            "temporary request inbox subscription must be cleaned up after PONG write failure"
        );
        // The client started out connected, so the failure must flip the flag.
        assert!(
            !client.connected,
            "client must be marked disconnected after failed PONG write"
        );
    });
    server.join().expect("server join");
}
/// Only `NatsError::Io` wraps another error; every other variant must report
/// no `source()`. Covers all non-I/O variants, not just the payload-free ones.
#[test]
fn nats_error_source_none_for_others() {
    let others = [
        NatsError::Protocol("p".into()),
        NatsError::Server("s".into()),
        NatsError::InvalidUrl("u".into()),
        NatsError::Cancelled,
        NatsError::Closed,
        NatsError::SubscriptionNotFound(7),
        NatsError::NotConnected,
    ];
    for err in &others {
        assert!(
            std::error::Error::source(err).is_none(),
            "{err:?} should have no source"
        );
    }
}
/// `From<io::Error>` wraps the original error unchanged in `NatsError::Io`;
/// checks that kind and message survive, not just the variant.
#[test]
fn nats_error_from_io() {
    let io_err = io::Error::other("net");
    let err: NatsError = NatsError::from(io_err);
    assert!(
        matches!(
            &err,
            NatsError::Io(inner) if inner.kind() == io::ErrorKind::Other && inner.to_string() == "net"
        ),
        "expected wrapped I/O error, got {err:?}"
    );
}
/// `NatsConfig` has a `Debug` impl naming the type, and `Clone` yields an
/// equal copy.
///
/// The previous `let cloned = cfg;` was a move and never exercised `Clone`;
/// the original is now kept in use after cloning so the clone is meaningful.
#[test]
fn nats_config_debug_clone() {
    let cfg = NatsConfig::default();
    let rendered = format!("{cfg:?}");
    assert!(rendered.contains("NatsConfig"));
    let cloned = cfg.clone();
    assert_eq!(cloned.host, cfg.host);
    assert_eq!(cloned.port, cfg.port);
    assert_eq!(cloned.host, "127.0.0.1");
    assert_eq!(cloned.port, 4222);
}
/// Credentials plus a non-default host and port are all parsed.
#[test]
fn nats_config_from_url_with_creds() {
    let parsed = NatsConfig::from_url("nats://user:pass@myhost:4223").unwrap();
    assert_eq!(parsed.user.as_deref(), Some("user"));
    assert_eq!(parsed.password.as_deref(), Some("pass"));
    assert_eq!(parsed.host, "myhost");
    assert_eq!(parsed.port, 4223);
}
/// A token URL fills `token` and leaves `user` unset.
#[test]
fn nats_config_from_url_with_token() {
    let parsed = NatsConfig::from_url("nats://mytoken@server:4222").unwrap();
    assert!(parsed.user.is_none());
    assert_eq!(parsed.token.as_deref(), Some("mytoken"));
}
/// A host-only URL uses port 4222.
#[test]
fn nats_config_from_url_host_only() {
    let parsed = NatsConfig::from_url("nats://myhost").unwrap();
    assert_eq!(parsed.port, 4222);
    assert_eq!(parsed.host, "myhost");
}
/// A non-`nats://` scheme is an error.
#[test]
fn nats_config_from_url_invalid_scheme() {
    let outcome = NatsConfig::from_url("http://localhost");
    assert!(outcome.is_err());
}
/// `Message`'s `Debug` output includes subject and reply subject, and the
/// value survives being moved into a new binding.
#[test]
fn message_debug_clone() {
    let original = Message {
        subject: "foo.bar".into(),
        sid: 1,
        reply_to: Some("_INBOX.123".into()),
        payload: b"hello".to_vec(),
    };
    let rendered = format!("{original:?}");
    for needle in ["foo.bar", "_INBOX"] {
        assert!(rendered.contains(needle));
    }
    // NOTE(review): despite the test name this is a move, not a `Clone`;
    // switch to `original.clone()` once `Message: Clone` is confirmed.
    let moved = original;
    assert_eq!(moved.subject, "foo.bar");
    assert_eq!(moved.sid, 1);
    assert_eq!(moved.payload, b"hello");
}
/// A message can be built without a reply subject and with an empty payload.
#[test]
fn message_no_reply() {
    let bare = Message {
        subject: "test".into(),
        sid: 0,
        reply_to: None,
        payload: Vec::new(),
    };
    assert!(bare.reply_to.is_none());
    assert!(bare.payload.is_empty());
}
/// Snapshot of a delivered message with its `_INBOX.` reply subject scrubbed.
#[test]
fn nats_pubsub_event_snapshot_scrubbed() {
    let delivered = Message {
        subject: "svc.echo".into(),
        sid: 7,
        reply_to: Some("_INBOX.42.reply".into()),
        payload: b"{\"event\":\"published\",\"seq\":12}".to_vec(),
    };
    let snapshot = message_event_snapshot(&delivered);
    insta::assert_json_snapshot!("nats_pubsub_event_scrubbed", snapshot);
}
/// `ServerInfo::default()` is entirely empty/zero/false.
#[test]
fn server_info_default() {
    let blank = ServerInfo::default();
    for text in [&blank.server_id, &blank.server_name, &blank.version] {
        assert!(text.is_empty());
    }
    assert_eq!(blank.proto, 0);
    assert_eq!(blank.max_payload, 0);
    assert!(!blank.tls_required);
    assert!(!blank.tls_available);
    assert!(blank.connect_urls.is_empty());
}
/// `ServerInfo` has a `Debug` impl naming the type, and `Clone` yields an
/// equal copy.
///
/// The previous `let cloned = info;` was a move and never exercised `Clone`;
/// the original is now kept in use after cloning so the clone is meaningful.
#[test]
fn server_info_debug_clone() {
    let info = ServerInfo {
        server_id: "test-id".into(),
        ..Default::default()
    };
    let rendered = format!("{info:?}");
    assert!(rendered.contains("ServerInfo"));
    let cloned = info.clone();
    assert_eq!(cloned.server_id, info.server_id);
    assert_eq!(cloned.server_id, "test-id");
}
/// Identity, version, protocol, and payload limit are parsed from `INFO`.
#[test]
fn server_info_parse_full() {
    let parsed = ServerInfo::parse(
        r#"{"server_id":"abc","server_name":"srv","version":"2.10","proto":1,"max_payload":1048576}"#,
    );
    assert_eq!(parsed.server_id, "abc");
    assert_eq!(parsed.server_name, "srv");
    assert_eq!(parsed.version, "2.10");
    assert_eq!(parsed.proto, 1);
    assert_eq!(parsed.max_payload, 1_048_576);
}
/// Parsing `{}` leaves the id empty and the protocol at 0.
#[test]
fn server_info_parse_empty() {
    let parsed = ServerInfo::parse("{}");
    assert_eq!(parsed.proto, 0);
    assert!(parsed.server_id.is_empty());
}
/// A clone of the default config keeps the default endpoint and flags, and
/// the original still formats with `Debug`.
#[test]
fn nats_config_debug_clone_default() {
    let original = NatsConfig::default();
    let copy = original.clone();
    assert_eq!(copy.host, "127.0.0.1");
    assert_eq!(copy.port, 4222);
    assert!(!copy.verbose);
    assert!(!copy.pedantic);
    let rendered = format!("{original:?}");
    assert!(rendered.contains("NatsConfig"));
}
/// A clone of the default `ServerInfo` keeps zero fields, and the original
/// still formats with `Debug`.
#[test]
fn server_info_debug_clone_default() {
    let original = ServerInfo::default();
    assert!(original.server_id.is_empty());
    assert_eq!(original.proto, 0);
    assert!(!original.tls_required);
    let copy = original.clone();
    assert_eq!(copy.max_payload, 0);
    let rendered = format!("{original:?}");
    assert!(rendered.contains("ServerInfo"));
}
/// The default read-buffer cap is 8 MiB.
#[test]
fn test_max_read_buffer_constant() {
    const EIGHT_MIB: usize = 8 * 1024 * 1024;
    assert_eq!(DEFAULT_MAX_READ_BUFFER, EIGHT_MIB);
}
/// Extending past the cap by one byte is a protocol error.
#[test]
fn test_read_buffer_rejects_oversized() {
    let oversized = vec![0u8; DEFAULT_MAX_READ_BUFFER + 1];
    let mut rb = NatsReadBuffer::new();
    match rb.extend(&oversized) {
        Err(NatsError::Protocol(_)) => {}
        other => panic!("expected protocol error, got {other:?}"),
    }
}
/// Filling the buffer exactly to the cap is allowed.
#[test]
fn test_read_buffer_accepts_max() {
    let exactly_max = vec![0u8; DEFAULT_MAX_READ_BUFFER];
    let mut rb = NatsReadBuffer::new();
    rb.extend(&exactly_max).unwrap();
    assert_eq!(rb.available().len(), DEFAULT_MAX_READ_BUFFER);
}
/// Consumed bytes do not count against the size cap: after consuming most of
/// a near-cap buffer, appending more must succeed even though the total bytes
/// ever appended exceed the cap.
///
/// Also pins the resulting unconsumed length, which the test previously never checked.
#[test]
fn test_read_buffer_consumed_data_not_counted() {
    let mut buf = NatsReadBuffer::new();
    let data = vec![0u8; DEFAULT_MAX_READ_BUFFER - 100];
    buf.extend(&data).unwrap();
    buf.consume(DEFAULT_MAX_READ_BUFFER - 200);
    assert_eq!(buf.available().len(), 100);
    let more = vec![0u8; 200];
    buf.extend(&more).unwrap();
    assert_eq!(buf.available().len(), 300);
}
/// Consuming more than is available empties the buffer without corrupting it;
/// later appends are visible as normal.
#[test]
fn test_read_buffer_consume_clamps_when_over_consumed() {
    let mut rb = NatsReadBuffer::new();
    rb.extend(b"abc").unwrap();
    rb.consume(usize::MAX);
    assert!(rb.available().is_empty());

    rb.extend(b"xy").unwrap();
    assert_eq!(rb.available(), b"xy");
}
/// The default client-side payload limit is 1 MiB.
#[test]
fn test_config_max_payload_default() {
    assert_eq!(NatsConfig::default().max_payload, 1_048_576);
}
/// A non-default `max_payload` advertised by the server is parsed.
#[test]
fn test_server_info_parse_max_payload() {
    let parsed = ServerInfo::parse(r#"{"max_payload":524288}"#);
    assert_eq!(parsed.max_payload, 524_288);
}
/// Dotted subjects, wildcards, and inbox-style subjects are valid tokens.
#[test]
fn test_validate_nats_token_accepts_valid() {
    for token in ["foo.bar.>", "*", "_INBOX.123.abc"] {
        assert!(validate_nats_token(token, "subject").is_ok(), "{token:?} should be valid");
    }
}
/// An empty token is rejected.
#[test]
fn test_validate_nats_token_rejects_empty() {
    let verdict = validate_nats_token("", "subject");
    assert!(verdict.is_err());
}
/// Embedded CRLF (which could smuggle an extra protocol command) is rejected.
#[test]
fn test_validate_nats_token_rejects_newline_injection() {
    let injected = "foo\r\nPUB evil 0\r\n";
    assert!(validate_nats_token(injected, "subject").is_err());
}
/// A tab inside a queue-group name is rejected.
#[test]
fn test_validate_nats_token_rejects_tab() {
    let verdict = validate_nats_token("foo\tbar", "queue");
    assert!(verdict.is_err());
}
/// Escaping the empty string yields the empty string.
#[test]
fn test_nats_json_escape_empty() {
    assert!(nats_json_escape("").is_empty());
}
/// Tab and carriage return use their short JSON escapes.
#[test]
fn test_nats_json_escape_tab_and_cr() {
    for (raw, escaped) in [("\t", "\\t"), ("\r", "\\r")] {
        assert_eq!(nats_json_escape(raw), escaped);
    }
}
/// An escaped quote inside a string value is returned unescaped.
#[test]
fn test_extract_json_string_with_escape() {
    let doc = r#"{"key":"val\"ue"}"#;
    let value = extract_json_string(doc, "key");
    assert_eq!(value.as_deref(), Some("val\"ue"));
}
/// A negative integer value keeps its sign.
#[test]
fn test_extract_json_i64_negative() {
    let value = extract_json_i64(r#"{"val":-42}"#, "val");
    assert_eq!(value, Some(-42));
}
/// A key absent from the document yields `None`.
#[test]
fn test_extract_json_bool_missing() {
    let value = extract_json_bool(r#"{"other":42}"#, "missing");
    assert!(value.is_none());
}
/// A bracketed IPv6 host without a port uses port 4222.
#[test]
fn test_config_from_url_ipv6_default_port() {
    let parsed = NatsConfig::from_url("nats://[::1]").unwrap();
    assert_eq!(parsed.port, 4222);
    assert_eq!(parsed.host, "[::1]");
}
/// An unterminated IPv6 bracket is an invalid URL.
#[test]
fn test_config_from_url_ipv6_invalid() {
    let Err(NatsError::InvalidUrl(_)) = NatsConfig::from_url("nats://[::1") else {
        panic!("expected InvalidUrl for unterminated IPv6 bracket");
    };
}
/// Verifies that a `-ERR` line already sitting in the read buffer makes
/// `handle_pending_messages` return `NatsError::Server` carrying the server's text.
///
/// The peer accepts and immediately drops its end; the socket only exists so
/// a `NatsClient` can be built. The error is expected to come from the
/// pre-buffered line.
#[test]
fn handle_pending_messages_propagates_server_error() {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind test listener");
let addr = listener.local_addr().expect("listener addr");
let server = thread::spawn(move || {
let (stream, _) = listener.accept().expect("accept client");
drop(stream);
});
run_test_with_cx(|cx| async move {
let stream = TcpStream::connect(format!("{addr}"))
.await
.expect("connect client");
// Build the client by hand with `connected: true` (no handshake is performed).
let mut client = NatsClient {
config: NatsConfig::default(),
stream,
read_buf: NatsReadBuffer::new(),
state: Arc::new(SharedState::new()),
next_sid: AtomicU64::new(1),
connected: true,
};
// Inject the server error directly into the read buffer.
client
.read_buf
.extend(b"-ERR 'Permissions Violation'\r\n")
.expect("buffer server error");
let err = client
.handle_pending_messages(&cx)
.await
.expect_err("server -ERR must propagate as error");
assert!(
matches!(&err, NatsError::Server(msg) if msg.contains("Permissions Violation")),
"expected server error with permissions message, got {err:?}"
);
});
server.join().expect("server join");
}
}