use core::future::Future;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{bail, Context as _};
use bytes::Bytes;
use futures::{SinkExt as _, StreamExt as _};
use protocol::{ClientOp, ClientOpEncoder, ConnectOptions, InfoOptions, ServerOp, ServerOpDecoder};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::ToSocketAddrs;
use tokio::select;
use tokio::sync::{mpsc, oneshot, watch};
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{debug, warn};
pub mod protocol;
/// A message delivered to a subscriber.
///
/// Built by the connection driver from a server `MSG` operation (in which case `headers` is
/// left empty) or from an `HMSG` operation (which carries headers).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Message {
    /// Subject the message was received on.
    pub subject: Bytes,
    /// Reply subject as sent by the server; may be empty.
    pub reply: Bytes,
    /// Raw header block from `HMSG`; empty for plain `MSG`.
    pub headers: Bytes,
    /// Message body.
    pub payload: Bytes,
}
/// A request processed by the connection driver (see [`Client::commands`]).
pub enum Command {
    /// Write a `CONNECT` operation with the given options.
    Connect(ConnectOptions),
    /// Publish `payload` to `subject`.
    ///
    /// Encoded as `PUB` when `headers` is empty and as `HPUB` otherwise.
    Publish {
        subject: Bytes,
        reply: Bytes,
        headers: Bytes,
        payload: Bytes,
    },
    /// Write a `SUB` operation and route incoming messages for `sid` into `tx`.
    ///
    /// `group` is forwarded unchanged in the `SUB` operation.
    Subscribe {
        subject: Bytes,
        group: Bytes,
        sid: Bytes,
        tx: mpsc::Sender<Message>,
    },
    /// Ask the driver for a `watch` receiver that tracks the most recent server `INFO`.
    ServerInfo(oneshot::Sender<watch::Receiver<Arc<InfoOptions>>>),
    /// Handle each contained command in order.
    Batch(Box<[Command]>),
}
/// A framed protocol connection: a decoded read half, an encoded write half and the latest
/// server `INFO`.
pub struct Conn<I, O> {
    /// Stream of decoded server operations.
    rx: FramedRead<I, ServerOpDecoder>,
    /// Sink for client operations; operations are buffered until flushed.
    tx: FramedWrite<O, ClientOpEncoder>,
    /// Publishes the most recent server `INFO` to any subscribed receivers.
    info: watch::Sender<Arc<InfoOptions>>,
}
impl Conn<OwnedReadHalf, OwnedWriteHalf> {
    /// Opens a TCP connection to `addr` and waits for the server's initial `INFO`.
    ///
    /// Returns the connection together with the received `INFO`.
    ///
    /// # Errors
    ///
    /// Fails if the TCP connection cannot be established, if the server stream ends before
    /// any operation arrives, if the first operation cannot be decoded, or if it is not `INFO`.
    pub async fn connect_tcp(addr: impl ToSocketAddrs) -> anyhow::Result<(Self, Arc<InfoOptions>)> {
        let (tx, mut rx) = protocol::connect_tcp(addr).await?;

        let first = rx
            .next()
            .await
            .context("server stream unexpectedly finished")?;
        let ServerOp::Info(opts) = first.context("failed to receive server info")? else {
            bail!("server did not send INFO");
        };

        let initial = Arc::new(opts);
        // The receiver half is not needed here: `Conn` hands out receivers on demand via
        // `watch::Sender::subscribe`.
        let (info, _) = watch::channel(Arc::clone(&initial));
        let conn = Conn { rx, tx, info };
        Ok((conn, initial))
    }
}
impl<I, O> Conn<I, O>
where
I: AsyncRead + Unpin,
O: AsyncWrite + Unpin,
{
async fn handle_op(
&mut self,
subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
op: ServerOp,
) -> anyhow::Result<bool> {
match op {
ServerOp::Info(opts) => {
self.info.send_replace(Arc::new(opts));
Ok(false)
}
ServerOp::Msg {
subject,
sid,
reply,
payload,
} => {
let tx = subs
.get(&sid)
.with_context(|| format!("received MSG for unknown sid: {sid:?}"))?;
if let Err(_) = tx
.send(Message {
subject,
reply,
headers: Bytes::default(),
payload,
})
.await
{
debug!(?sid, "remove unused subscription");
subs.remove(&sid);
}
Ok(false)
}
ServerOp::Hmsg {
subject,
sid,
reply,
headers,
payload,
} => {
let tx = subs
.get(&sid)
.with_context(|| format!("received HMSG for unknown sid: {sid:?}"))?;
if let Err(_) = tx
.send(Message {
subject,
reply,
headers,
payload,
})
.await
{
debug!(?sid, "remove unused subscription");
subs.remove(&sid);
}
Ok(false)
}
ServerOp::Ping => {
self.tx.feed(ClientOp::Pong).await?;
Ok(true)
}
ServerOp::Pong => Ok(false),
ServerOp::Ok => Ok(false),
ServerOp::Err => {
bail!("received an error")
}
}
}
async fn handle_command(
&mut self,
subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
cmd: Command,
) -> anyhow::Result<bool> {
match cmd {
Command::ServerInfo(tx) => {
if let Err(_) = tx.send(self.info.subscribe()) {
warn!("server info receiver dropped");
}
Ok(false)
}
Command::Connect(opts) => {
self.tx.feed(ClientOp::Connect(opts)).await?;
Ok(true)
}
Command::Subscribe {
subject,
group,
sid,
tx,
} => {
self.tx
.feed(ClientOp::Sub {
subject,
group,
sid: sid.clone(),
})
.await?;
subs.insert(sid, tx);
Ok(true)
}
Command::Publish {
subject,
reply,
headers,
payload,
} => {
let op = if headers.is_empty() {
ClientOp::Pub {
subject,
reply,
payload,
}
} else {
ClientOp::Hpub {
subject,
reply,
headers,
payload,
}
};
self.tx.feed(op).await?;
Ok(true)
}
Command::Batch(cmds) => {
let mut has_data = false;
for cmd in cmds {
if Box::pin(self.handle_command(subs, cmd)).await? {
has_data = true;
}
}
Ok(has_data)
}
}
}
async fn run(&mut self, cmds: &mut mpsc::Receiver<Command>) -> anyhow::Result<()> {
let mut subs = HashMap::default();
let mut has_data = false;
loop {
select! {
Some(op) = self.rx.next() => {
let op = op?;
if self.handle_op(&mut subs, op).await? {
has_data = true;
}
},
Some(cmd) = cmds.recv() => {
if self.handle_command(&mut subs, cmd).await? {
has_data = true;
}
},
res = self.tx.flush(), if has_data => {
res?;
has_data = false;
}
};
}
}
}
pub trait Connector {
type Error: core::error::Error + Send + Sync + 'static;
fn connect(
self,
opts: &InfoOptions,
) -> impl Future<Output = Result<ConnectOptions, Self::Error>>;
}
/// Cloneable handle for issuing [`Command`]s to a connection.
///
/// Commands only take effect while the driver future returned by [`Client::new`] is polled.
#[derive(Clone, Debug)]
pub struct Client {
    /// Queue of commands consumed by the connection driver.
    commands: mpsc::Sender<Command>,
}
impl Client {
pub async fn new(
mut conn: Conn<impl AsyncRead + Unpin, impl AsyncWrite + Unpin>,
opts: ConnectOptions,
) -> anyhow::Result<(Self, impl Future<Output = anyhow::Result<()>>)> {
let (cmd_tx, mut cmd_rx) = mpsc::channel(8192);
conn.tx
.feed(ClientOp::Connect(opts))
.await
.context("failed to send connect operation")?;
Ok((Self { commands: cmd_tx }, async move {
conn.run(&mut cmd_rx).await
}))
}
#[must_use]
pub fn commands(&self) -> &mpsc::Sender<Command> {
&self.commands
}
#[must_use]
pub async fn server_info(&self) -> Option<watch::Receiver<Arc<InfoOptions>>> {
let (tx, rx) = oneshot::channel();
self.commands.send(Command::ServerInfo(tx)).await.ok()?;
rx.await.ok()
}
pub async fn subscribe(
&self,
subject: impl Into<Bytes>,
group: impl Into<Bytes>,
sid: impl Into<Bytes>,
tx: mpsc::Sender<Message>,
) -> Result<(), mpsc::error::SendError<Command>> {
self.commands
.send(Command::Subscribe {
subject: subject.into(),
group: group.into(),
sid: sid.into(),
tx,
})
.await
}
pub async fn publish(
&self,
subject: impl Into<Bytes>,
reply: impl Into<Bytes>,
headers: impl Into<Bytes>,
payload: impl Into<Bytes>,
) -> Result<(), mpsc::error::SendError<Command>> {
self.commands
.send(Command::Publish {
subject: subject.into(),
reply: reply.into(),
headers: headers.into(),
payload: payload.into(),
})
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the message expected back on `subject`, with no reply subject and no headers.
    fn plain_message(subject: &Bytes, payload: &'static str) -> Message {
        Message {
            subject: subject.clone(),
            reply: Bytes::default(),
            headers: Bytes::default(),
            payload: Bytes::from(payload),
        }
    }

    /// Round trip against a server on `localhost:4222`: subscribe, publish twice (via the
    /// helper and via a raw command), and expect both messages back in order.
    #[test_log::test(tokio::test)]
    async fn connect() -> anyhow::Result<()> {
        let (conn, _) = Conn::connect_tcp("localhost:4222").await?;
        let (client, driver) = Client::new(conn, ConnectOptions::default()).await?;
        let driver = tokio::spawn(driver);

        let info_rx = client
            .server_info()
            .await
            .context("failed to get server info")?;
        let info = Arc::clone(&info_rx.borrow());
        eprintln!("info: {info:?}");

        let (msg_tx, mut msg_rx) = mpsc::channel(128);
        let topic = Bytes::from("test");
        client
            .subscribe(topic.clone(), "", "0", msg_tx)
            .await
            .context("failed to subscribe")?;

        let raw = client.commands();
        client
            .publish(topic.clone(), "", "", "hello")
            .await
            .context("failed to publish `hello`")?;
        raw.send(Command::Publish {
            subject: topic.clone(),
            reply: Bytes::default(),
            headers: Bytes::default(),
            payload: Bytes::from("bye\r\n\r\n"),
        })
        .await
        .context("failed to publish `bye`")?;

        let expectations = [
            ("hello", "failed to receive `hello`"),
            ("bye\r\n\r\n", "failed to receive `bye`"),
        ];
        for (payload, failure) in expectations {
            let received = msg_rx.recv().await.context(failure)?;
            assert_eq!(received, plain_message(&topic, payload));
        }

        driver.abort();
        Ok(())
    }
}