use std::convert::TryFrom;
use std::{
io::{self, Error, ErrorKind},
str::{self, FromStr},
};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::{connect::ConnectInfo, header::HeaderMap, inject_io_failure, ServerInfo};
/// An operation received from the server, as produced by [`decode`].
#[derive(Debug)]
pub(crate) enum ServerOp {
    /// `INFO` line: the server info parsed from the text after `INFO`.
    Info(ServerInfo),
    /// `MSG` line followed by its payload.
    Msg {
        /// Subject argument of the `MSG` line.
        subject: String,
        /// `sid` argument of the line (presumably the id used in a prior `SUB` — confirm).
        sid: u64,
        /// Reply subject; `Some` only when the line carries four arguments.
        reply_to: Option<String>,
        /// Payload bytes, excluding the two terminator bytes that follow them.
        payload: Vec<u8>,
    },
    /// `HMSG` line followed by a header block and a payload.
    Hmsg {
        /// Subject argument of the `HMSG` line.
        subject: String,
        /// Header block, parsed via `HeaderMap::try_from`.
        headers: HeaderMap,
        /// `sid` argument of the line (presumably the id used in a prior `SUB` — confirm).
        sid: u64,
        /// Reply subject; `Some` only when the line carries five arguments.
        reply_to: Option<String>,
        /// Bytes after the header block, excluding the two terminator bytes.
        payload: Vec<u8>,
    },
    /// `PING` line.
    Ping,
    /// `PONG` line.
    Pong,
    /// `-ERR` line: the text after `-ERR`, with surrounding whitespace and
    /// single quotes trimmed.
    Err(String),
    /// Any other line, kept verbatim (including its line terminator, if any).
    Unknown(String),
}
/// Reads one line from `r` into `buf` and returns the number of bytes stored.
///
/// Bytes are copied up to and including the first `\n`. If the reader reaches
/// end of stream first, whatever was read so far is returned, so `Ok(0)` means
/// the stream was already exhausted. Reads failing with
/// `ErrorKind::Interrupted` are retried.
///
/// # Errors
///
/// Returns `InvalidInput` if the line would not fit in `buf`. In that case the
/// offending chunk is left unconsumed in the reader. Any other error from
/// `fill_buf` is propagated.
async fn read_line<R: AsyncBufReadExt + ?Sized + std::marker::Unpin>(
    r: &mut R,
    buf: &mut [u8],
) -> io::Result<usize> {
    let mut read = 0;
    loop {
        let available = match r.fill_buf().await {
            Ok(n) => n,
            Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };

        // Take everything up to and including the newline, or the whole
        // buffered chunk when it has none. An empty chunk signals EOF.
        let (done, len) = match memchr::memchr(b'\n', available) {
            Some(i) => (true, i + 1),
            None => (false, available.len()),
        };

        // Check capacity before touching `buf`. Copying first would panic on
        // an out-of-range slice instead of returning an error.
        if read + len > buf.len() {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "received command exceeded fixed command buffer",
            ));
        }

        buf[read..read + len].copy_from_slice(&available[..len]);
        r.consume(len);
        read += len;

        if done || len == 0 {
            return Ok(read);
        }
    }
}
/// Reads one server operation from `stream` and decodes it.
///
/// The command line is read into a fixed 4096-byte buffer. For `MSG` and
/// `HMSG`, the payload announced on the line is then read, followed by two
/// terminator bytes that are discarded.
///
/// The operation name is matched case-insensitively. Lines with an
/// unrecognised name are returned verbatim as [`ServerOp::Unknown`].
///
/// Returns `Ok(None)` when the stream yields no bytes for the command line
/// (end of stream).
///
/// # Errors
///
/// * any error from `inject_io_failure` or from reading `stream` (for
///   example when the stream ends inside a payload);
/// * `InvalidInput` if the command line exceeds the command buffer, is not
///   valid UTF-8, or has missing or malformed `INFO`, `MSG` or `HMSG`
///   arguments;
/// * any error from parsing an `HMSG` header block.
pub(crate) async fn decode(
    mut stream: impl AsyncBufRead + std::marker::Unpin,
) -> io::Result<Option<ServerOp>> {
    inject_io_failure()?;

    // Zero-initialised on purpose: producing `u8`s from uninitialised memory
    // (e.g. via `MaybeUninit::uninit().assume_init()`) is undefined behaviour.
    let mut command_buf = [0_u8; 4096];
    let command_len = read_line(&mut stream, &mut command_buf).await?;
    if command_len == 0 {
        return Ok(None);
    }

    let line = str::from_utf8(&command_buf[..command_len])
        .map_err(|err| Error::new(ErrorKind::InvalidInput, err))?;

    // Split the line into the operation name and the raw argument text after
    // it. Leading whitespace is skipped, so the arguments are sliced from the
    // end of the actual name rather than from a fixed offset into `line`.
    let trimmed = line.trim_start_matches(|c: char| c.is_ascii_whitespace());
    let name_len = trimmed
        .find(|c: char| c.is_ascii_whitespace())
        .unwrap_or(trimmed.len());
    let (name, args) = trimmed.split_at(name_len);

    match name.to_ascii_uppercase().as_str() {
        "PING" => Ok(Some(ServerOp::Ping)),
        "PONG" => Ok(Some(ServerOp::Pong)),
        "INFO" => {
            let server_info = ServerInfo::parse(args)
                .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "cannot parse server info"))?;
            Ok(Some(ServerOp::Info(server_info)))
        }
        "MSG" => {
            // MSG <subject> <sid> [reply-to] <#bytes>
            let fields: Vec<&str> = args.split_whitespace().collect();
            let (subject, sid, reply_to, num_bytes) = match fields[..] {
                [subject, sid, num_bytes] => (subject, sid, None, num_bytes),
                [subject, sid, reply_to, num_bytes] => {
                    (subject, sid, Some(reply_to), num_bytes)
                }
                _ => {
                    return Err(Error::new(
                        ErrorKind::InvalidInput,
                        "invalid number of arguments after MSG",
                    ));
                }
            };
            let sid = u64::from_str(sid).map_err(|_| {
                Error::new(
                    ErrorKind::InvalidInput,
                    "cannot parse sid argument after MSG",
                )
            })?;
            let num_bytes = u32::from_str(num_bytes).map_err(|_| {
                Error::new(
                    ErrorKind::InvalidInput,
                    "cannot parse the number of bytes argument after MSG",
                )
            })?;

            let mut payload = vec![0_u8; num_bytes as usize];
            stream.read_exact(&mut payload[..]).await?;
            // Discard the two-byte terminator that follows the payload (not validated).
            stream.read_exact(&mut [0_u8; 2]).await?;

            Ok(Some(ServerOp::Msg {
                subject: subject.to_string(),
                sid,
                reply_to: reply_to.map(ToString::to_string),
                payload,
            }))
        }
        "HMSG" => {
            // HMSG <subject> <sid> [reply-to] <#header bytes> <#total bytes>
            let fields: Vec<&str> = args.split_whitespace().collect();
            let (subject, sid, reply_to, num_header_bytes, num_bytes) = match fields[..] {
                [subject, sid, num_header_bytes, num_bytes] => {
                    (subject, sid, None, num_header_bytes, num_bytes)
                }
                [subject, sid, reply_to, num_header_bytes, num_bytes] => {
                    (subject, sid, Some(reply_to), num_header_bytes, num_bytes)
                }
                _ => {
                    return Err(Error::new(
                        ErrorKind::InvalidInput,
                        "invalid number of arguments after HMSG",
                    ));
                }
            };
            let sid = u64::from_str(sid).map_err(|_| {
                Error::new(
                    ErrorKind::InvalidInput,
                    "cannot parse sid argument after HMSG",
                )
            })?;
            let num_header_bytes = u32::from_str(num_header_bytes).map_err(|_| {
                Error::new(
                    ErrorKind::InvalidInput,
                    "cannot parse the number of header bytes argument after \
                     HMSG",
                )
            })?;
            let num_bytes = u32::from_str(num_bytes).map_err(|_| {
                Error::new(
                    ErrorKind::InvalidInput,
                    "cannot parse the number of bytes argument after HMSG",
                )
            })?;
            // Equal sizes are valid (headers with an empty payload); only a
            // header block larger than the whole message is rejected.
            if num_header_bytes > num_bytes {
                return Err(Error::new(
                    ErrorKind::InvalidInput,
                    "number of header bytes was greater than the \
                     total number of bytes after HMSG",
                ));
            }
            let num_payload_bytes = num_bytes - num_header_bytes;

            let mut header_payload = vec![0_u8; num_header_bytes as usize];
            stream.read_exact(&mut header_payload[..]).await?;
            let headers = HeaderMap::try_from(&*header_payload)?;

            let mut payload = vec![0_u8; num_payload_bytes as usize];
            stream.read_exact(&mut payload[..]).await?;
            // Discard the two-byte terminator that follows the payload (not validated).
            stream.read_exact(&mut [0_u8; 2]).await?;

            Ok(Some(ServerOp::Hmsg {
                subject: subject.to_string(),
                headers,
                sid,
                reply_to: reply_to.map(ToString::to_string),
                payload,
            }))
        }
        "-ERR" => {
            let msg = args.trim().trim_matches('\'').to_string();
            Ok(Some(ServerOp::Err(msg)))
        }
        _ => Ok(Some(ServerOp::Unknown(line.to_owned()))),
    }
}
/// An operation to send to the server, serialized by [`encode`].
///
/// Every variant holds only references and integers, which is what allows the
/// `Copy` derive.
#[derive(Clone, Copy, Debug)]
pub(crate) enum ClientOp<'a> {
    /// Written as `CONNECT <connect_info.dump()>\r\n`.
    Connect(&'a ConnectInfo),
    /// Written as `PUB <subject> [reply_to] <payload len>\r\n<payload>\r\n`.
    Pub {
        subject: &'a str,
        reply_to: Option<&'a str>,
        payload: &'a [u8],
    },
    /// Written as `HPUB <subject> [reply_to] <header len> <header len + payload len>\r\n`,
    /// followed by the serialized headers, the payload and `\r\n`.
    Hpub {
        subject: &'a str,
        reply_to: Option<&'a str>,
        headers: &'a HeaderMap,
        payload: &'a [u8],
    },
    /// Written as `SUB <subject> [queue_group] <sid>\r\n`.
    Sub {
        subject: &'a str,
        queue_group: Option<&'a str>,
        sid: u64,
    },
    /// Written as `UNSUB <sid> [max_msgs]\r\n`.
    Unsub { sid: u64, max_msgs: Option<u64> },
    /// Written as `PING\r\n`.
    Ping,
    /// Written as `PONG\r\n`.
    Pong,
}
/// Serializes `op` in wire format and writes it to `stream`.
///
/// `Pub` and `Hpub` write a control line, then the (header and) payload bytes,
/// then a trailing `\r\n`. Every other operation is a single `\r\n`-terminated
/// line.
///
/// # Errors
///
/// Returns `InvalidData` if the connect info cannot be serialized (its `dump`
/// returned `None`); nothing is written in that case. Any write error from
/// `stream` is propagated.
pub(crate) async fn encode(
    mut stream: impl AsyncWrite + std::marker::Unpin,
    op: ClientOp<'_>,
) -> io::Result<()> {
    /// Writes `verb`, then `subject` and a space, then `reply_to` and a space
    /// when a reply subject is present.
    async fn write_prefix<W: AsyncWrite + std::marker::Unpin>(
        out: &mut W,
        verb: &[u8],
        subject: &str,
        reply_to: Option<&str>,
    ) -> io::Result<()> {
        out.write_all(verb).await?;
        out.write_all(subject.as_bytes()).await?;
        out.write_all(b" ").await?;
        if let Some(reply) = reply_to {
            out.write_all(reply.as_bytes()).await?;
            out.write_all(b" ").await?;
        }
        Ok(())
    }

    match op {
        ClientOp::Connect(connect_info) => {
            let json = connect_info.dump().ok_or_else(|| {
                Error::new(ErrorKind::InvalidData, "cannot serialize connect info")
            })?;
            let line = format!("CONNECT {}\r\n", json);
            stream.write_all(line.as_bytes()).await?;
        }
        ClientOp::Pub {
            subject,
            reply_to,
            payload,
        } => {
            write_prefix(&mut stream, b"PUB ", subject, reply_to).await?;
            let mut digits = itoa::Buffer::new();
            stream
                .write_all(digits.format(payload.len()).as_bytes())
                .await?;
            stream.write_all(b"\r\n").await?;
            stream.write_all(payload).await?;
            stream.write_all(b"\r\n").await?;
        }
        ClientOp::Hpub {
            subject,
            reply_to,
            headers,
            payload,
        } => {
            write_prefix(&mut stream, b"HPUB ", subject, reply_to).await?;
            let encoded_headers = headers.to_bytes();
            let header_size = encoded_headers.len();
            // One digit buffer serves both numbers; each formatted borrow ends
            // before the next `format` call.
            let mut digits = itoa::Buffer::new();
            stream
                .write_all(digits.format(header_size).as_bytes())
                .await?;
            stream.write_all(b" ").await?;
            stream
                .write_all(digits.format(header_size + payload.len()).as_bytes())
                .await?;
            stream.write_all(b"\r\n").await?;
            stream.write_all(&encoded_headers).await?;
            stream.write_all(payload).await?;
            stream.write_all(b"\r\n").await?;
        }
        ClientOp::Sub {
            subject,
            queue_group,
            sid,
        } => {
            let line = match queue_group {
                Some(group) => format!("SUB {} {} {}\r\n", subject, group, sid),
                None => format!("SUB {} {}\r\n", subject, sid),
            };
            stream.write_all(line.as_bytes()).await?;
        }
        ClientOp::Unsub { sid, max_msgs } => {
            let line = match max_msgs {
                Some(limit) => format!("UNSUB {} {}\r\n", sid, limit),
                None => format!("UNSUB {}\r\n", sid),
            };
            stream.write_all(line.as_bytes()).await?;
        }
        ClientOp::Ping => stream.write_all(b"PING\r\n").await?,
        ClientOp::Pong => stream.write_all(b"PONG\r\n").await?,
    }
    Ok(())
}