use core::future::Future;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{bail, Context as _};
use bytes::Bytes;
use futures::{SinkExt as _, StreamExt as _};
use protocol::{ClientOp, ClientOpEncoder, ConnectOptions, InfoOptions, ServerOp, ServerOpDecoder};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::ToSocketAddrs;
use tokio::select;
use tokio::sync::{mpsc, oneshot, watch};
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{debug, warn};
pub mod protocol;
/// A message delivered to a subscription's channel.
///
/// Fields are public so that consumers outside this crate can inspect what they receive.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Message {
    /// Subject the message was published on.
    pub subject: Bytes,
    /// Reply subject as received from the server (presumably empty when none was given).
    pub respond: Bytes,
    /// Raw header bytes from an `HMSG`; always empty for a plain `MSG`.
    pub headers: Bytes,
    /// Message body.
    pub payload: Bytes,
}
/// A request processed by the connection task driven by [`Conn::run`].
pub enum Command {
    /// Reply with a `watch` receiver that observes the most recent server `INFO`.
    ServerInfo(oneshot::Sender<watch::Receiver<Arc<InfoOptions>>>),
    /// Send `CONNECT` with the given options.
    Connect(ConnectOptions),
    /// Send `SUB` for `subject` under `sid` and route subsequent messages for `sid` into `tx`.
    Subscribe {
        subject: Bytes,
        group: Bytes,
        sid: Bytes,
        tx: mpsc::Sender<Message>,
    },
    /// Publish `payload` on `subject`; encoded as `HPUB` when `headers` is non-empty and as
    /// `PUB` otherwise.
    Publish {
        subject: Bytes,
        respond: Bytes,
        headers: Bytes,
        payload: Bytes,
    },
    /// Process each contained command in order (batches may be nested).
    Batch(Box<[Command]>),
}
/// An established connection: framed write/read halves plus the latest server `INFO`.
pub struct Conn {
    /// Encodes client operations; `feed` buffers them and `run` flushes when data is pending.
    tx: FramedWrite<OwnedWriteHalf, ClientOpEncoder>,
    /// Decodes operations sent by the server.
    rx: FramedRead<OwnedReadHalf, ServerOpDecoder>,
    /// Latest `INFO` from the server; receivers are handed out via `subscribe`.
    info: watch::Sender<Arc<InfoOptions>>,
}
impl Conn {
pub async fn connect(addr: impl ToSocketAddrs) -> anyhow::Result<Self> {
let (tx, mut rx) = protocol::connect_tcp(addr).await?;
let op = rx
.next()
.await
.context("server stream unexpectedly finished")?;
let info = match op.context("failed to receive server info")? {
ServerOp::Info(opts) => opts,
_ => bail!("server did not send INFO"),
};
let (info, _) = watch::channel(Arc::new(info));
Ok(Self { tx, rx, info })
}
async fn handle_op(
&mut self,
subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
op: ServerOp,
) -> anyhow::Result<bool> {
match op {
ServerOp::Info(opts) => {
self.info.send_replace(Arc::new(opts));
Ok(false)
}
ServerOp::Msg {
subject,
sid,
respond,
payload,
} => {
let tx = subs
.get(&sid)
.with_context(|| format!("received MSG for unknown sid: {sid:?}"))?;
if let Err(_) = tx
.send(Message {
subject,
respond,
headers: Bytes::default(),
payload,
})
.await
{
debug!(?sid, "remove unused subscription");
subs.remove(&sid);
}
Ok(false)
}
ServerOp::Hmsg {
subject,
sid,
respond,
headers,
payload,
} => {
let tx = subs
.get(&sid)
.with_context(|| format!("received HMSG for unknown sid: {sid:?}"))?;
if let Err(_) = tx
.send(Message {
subject,
respond,
headers,
payload,
})
.await
{
debug!(?sid, "remove unused subscription");
subs.remove(&sid);
}
Ok(false)
}
ServerOp::Ping => {
self.tx.feed(ClientOp::Pong).await?;
Ok(true)
}
ServerOp::Pong => Ok(false),
ServerOp::Ok => Ok(false),
ServerOp::Err => {
bail!("received an error")
}
}
}
async fn handle_command(
&mut self,
subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
cmd: Command,
) -> anyhow::Result<bool> {
match cmd {
Command::ServerInfo(tx) => {
if let Err(_) = tx.send(self.info.subscribe()) {
warn!("server info receiver dropped");
}
Ok(false)
}
Command::Connect(opts) => {
self.tx.feed(ClientOp::Connect(opts)).await?;
Ok(true)
}
Command::Subscribe {
subject,
group,
sid,
tx,
} => {
self.tx
.feed(ClientOp::Sub {
subject,
group,
sid: sid.clone(),
})
.await?;
subs.insert(sid, tx);
Ok(true)
}
Command::Publish {
subject,
respond,
headers,
payload,
} => {
let op = if headers.is_empty() {
ClientOp::Pub {
subject,
respond,
payload,
}
} else {
ClientOp::Hpub {
subject,
respond,
headers,
payload,
}
};
self.tx.feed(op).await?;
Ok(true)
}
Command::Batch(cmds) => {
let mut has_data = false;
for cmd in cmds {
if Box::pin(self.handle_command(subs, cmd)).await? {
has_data = true
}
}
Ok(has_data)
}
}
}
pub async fn run(&mut self, cmds: &mut mpsc::Receiver<Command>) -> anyhow::Result<()> {
let mut subs = HashMap::default();
let mut has_data = false;
loop {
select! {
Some(op) = self.rx.next() => {
let op = op?;
if self.handle_op(&mut subs, op).await? {
has_data = true;
}
},
Some(cmd) = cmds.recv() => {
if self.handle_command(&mut subs, cmd).await? {
has_data = true;
}
},
res = self.tx.flush(), if has_data => {
res?;
has_data = false;
}
};
}
}
}
/// Produces the `CONNECT` options sent after the server's initial `INFO` has been received.
///
/// Returning `None` makes [`Client::new`] fail.
pub trait Connector {
    /// Builds connect options, optionally taking the server's advertised `info` into account.
    fn connect(&self, info: &InfoOptions) -> impl Future<Output = Option<ConnectOptions>>;
}
/// A [`Connector`] that always sends default connect options.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct DefaultConnector;
impl Connector for DefaultConnector {
    /// Ignores the server info and yields `ConnectOptions::default()`.
    async fn connect(&self, _info: &InfoOptions) -> Option<ConnectOptions> {
        let opts = ConnectOptions::default();
        Some(opts)
    }
}
impl Connector for ConnectOptions {
async fn connect(&self, _: &InfoOptions) -> Option<ConnectOptions> {
Some(self.clone())
}
}
/// Handle for issuing [`Command`]s to a connection task.
///
/// Cloning is cheap: clones share the same underlying command channel.
#[derive(Clone, Debug)]
pub struct Client {
    commands: mpsc::Sender<Command>,
}
impl Client {
pub async fn new(
addr: impl ToSocketAddrs + Clone + Send + 'static,
connector: impl Connector,
) -> anyhow::Result<(Self, impl Future<Output = anyhow::Result<()>>)> {
let (cmd_tx, mut cmd_rx) = mpsc::channel(8192);
let mut conn = Conn::connect(addr)
.await
.context("failed to connect to server")?;
let info = Arc::clone(&conn.info.subscribe().borrow());
let opts = connector
.connect(&info)
.await
.context("failed to construct connect options")?;
conn.tx
.feed(ClientOp::Connect(opts))
.await
.context("failed to send connect operation")?;
Ok((Self { commands: cmd_tx }, async move {
conn.run(&mut cmd_rx).await
}))
}
pub fn commands(&self) -> &mpsc::Sender<Command> {
&self.commands
}
pub async fn server_info(&self) -> Option<watch::Receiver<Arc<InfoOptions>>> {
let (tx, rx) = oneshot::channel();
self.commands.send(Command::ServerInfo(tx)).await.ok()?;
rx.await.ok()
}
pub async fn connect(
&self,
opts: ConnectOptions,
) -> Result<(), mpsc::error::SendError<Command>> {
self.commands.send(Command::Connect(opts)).await
}
pub async fn subscribe(
&self,
subject: Bytes,
group: Bytes,
sid: Bytes,
tx: mpsc::Sender<Message>,
) -> Result<(), mpsc::error::SendError<Command>> {
self.commands
.send(Command::Subscribe {
subject,
group,
sid,
tx,
})
.await
}
pub async fn publish(
&self,
subject: Bytes,
respond: Bytes,
headers: Bytes,
payload: Bytes,
) -> Result<(), mpsc::error::SendError<Command>> {
self.commands
.send(Command::Publish {
subject,
respond,
headers,
payload,
})
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Message a subscriber on `subject` should observe for a publish of `payload` with no
    /// reply subject and no headers.
    fn plain_message(subject: &Bytes, payload: &'static str) -> Message {
        Message {
            subject: subject.clone(),
            respond: Bytes::default(),
            headers: Bytes::default(),
            payload: Bytes::from(payload),
        }
    }

    /// Round-trips two messages through a server expected at `localhost:4222`: one via the
    /// `publish` helper and one via the raw command channel.
    #[test_log::test(tokio::test)]
    async fn connect() -> anyhow::Result<()> {
        let (client, io) = Client::new("localhost:4222", ConnectOptions::default()).await?;
        let io = tokio::spawn(io);

        let info = client
            .server_info()
            .await
            .context("failed to get server info")?;
        let info = Arc::clone(&info.borrow());
        eprintln!("info: {info:?}");

        client
            .connect(ConnectOptions::default())
            .await
            .context("failed to connect")?;

        let subject = Bytes::from("test");
        let (sub_tx, mut sub_rx) = mpsc::channel(128);
        client
            .subscribe(subject.clone(), Bytes::default(), Bytes::from("0"), sub_tx)
            .await
            .context("failed to subscribe")?;

        client
            .publish(
                subject.clone(),
                Bytes::default(),
                Bytes::default(),
                Bytes::from("hello"),
            )
            .await
            .context("failed to publish `hello`")?;
        client
            .commands()
            .send(Command::Publish {
                subject: subject.clone(),
                respond: Bytes::default(),
                headers: Bytes::default(),
                payload: Bytes::from("bye"),
            })
            .await
            .context("failed to publish `bye`")?;

        let hello = sub_rx.recv().await.context("failed to receive `hello`")?;
        assert_eq!(hello, plain_message(&subject, "hello"));
        let bye = sub_rx.recv().await.context("failed to receive `bye`")?;
        assert_eq!(bye, plain_message(&subject, "bye"));

        io.abort();
        Ok(())
    }
}