//! ants 0.1.0-alpha.2
//!
//! Low-level async NATS.io client
//!
//! Documentation
use core::future::Future;

use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{bail, Context as _};
use bytes::Bytes;
use futures::{SinkExt as _, StreamExt as _};
use protocol::{ClientOp, ClientOpEncoder, ConnectOptions, InfoOptions, ServerOp, ServerOpDecoder};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::ToSocketAddrs;
use tokio::select;
use tokio::sync::{mpsc, oneshot, watch};
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{debug, warn};

pub mod protocol;

/// A message handed to a subscriber's channel.
///
/// Built by the connection from server `MSG` and `HMSG` operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
    /// Subject carried by the server operation.
    pub subject: Bytes,
    /// Reply field carried by the server operation.
    pub reply: Bytes,
    /// Header bytes; set to an empty `Bytes` for messages built from `MSG`.
    pub headers: Bytes,
    /// Message payload.
    pub payload: Bytes,
}

/// A request sent from a [`Client`] to the connection's I/O loop.
pub enum Command {
    /// Queue a `ClientOp::Connect` with the given options.
    Connect(ConnectOptions),
    /// Queue a publish: `ClientOp::Pub` when `headers` is empty,
    /// `ClientOp::Hpub` otherwise.
    Publish {
        subject: Bytes,
        reply: Bytes,
        headers: Bytes,
        payload: Bytes,
    },
    /// Queue a `ClientOp::Sub` and register `tx` as the destination for
    /// messages the server tags with `sid`.
    Subscribe {
        subject: Bytes,
        group: Bytes,
        sid: Bytes,
        tx: mpsc::Sender<Message>,
    },
    /// Ask for a `watch` receiver over the most recent server `InfoOptions`;
    /// the receiver is sent back through the provided oneshot sender.
    ServerInfo(oneshot::Sender<watch::Receiver<Arc<InfoOptions>>>),
    /// Handle several commands one after another, in slice order.
    Batch(Box<[Command]>),
}

/// A framed connection to a server: a reader `I`, a writer `O`, and the
/// latest server info.
pub struct Conn<I, O> {
    /// Incoming half, decoding server operations.
    rx: FramedRead<I, ServerOpDecoder>,
    /// Outgoing half, encoding client operations.
    tx: FramedWrite<O, ClientOpEncoder>,
    /// Holds the most recently received `InfoOptions`; receivers are created
    /// on demand via `subscribe()`.
    info: watch::Sender<Arc<InfoOptions>>,
}

impl Conn<OwnedReadHalf, OwnedWriteHalf> {
    /// Opens a TCP connection via `protocol::connect_tcp` and waits for the
    /// server's first operation, which must be `INFO`.
    ///
    /// Returns the connection together with the received server info.
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be established, if the stream ends or
    /// yields an error before the first operation is read, or if that first
    /// operation is anything other than `INFO`.
    pub async fn connect_tcp(addr: impl ToSocketAddrs) -> anyhow::Result<(Self, Arc<InfoOptions>)> {
        let (tx, mut rx) = protocol::connect_tcp(addr).await?;
        let first = rx
            .next()
            .await
            .context("server stream unexpectedly finished")?
            .context("failed to receive server info")?;
        let ServerOp::Info(opts) = first else {
            bail!("server did not send INFO");
        };
        let info = Arc::new(opts);
        // The initial receiver is dropped right away; later receivers are
        // obtained from the sender with `subscribe()`.
        let (info_tx, _) = watch::channel(Arc::clone(&info));
        let conn = Conn {
            rx,
            tx,
            info: info_tx,
        };
        Ok((conn, info))
    }
}

impl<I, O> Conn<I, O>
where
    I: AsyncRead + Unpin,
    O: AsyncWrite + Unpin,
{
    /// Forwards `msg` to the subscriber registered under `sid`.
    ///
    /// A message for an unknown `sid` is logged and discarded instead of being
    /// treated as fatal: subscriptions are removed locally when their receiver
    /// is dropped, and the server may keep delivering for that `sid` afterwards.
    /// NOTE(review): consider sending an unsubscribe operation on removal so
    /// the server stops delivering — confirm what `ClientOp` offers.
    async fn deliver(
        subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
        kind: &'static str,
        sid: Bytes,
        msg: Message,
    ) {
        let Some(tx) = subs.get(&sid) else {
            warn!(?sid, kind, "discard message for unknown sid");
            return;
        };
        if tx.send(msg).await.is_err() {
            debug!(?sid, "remove unused subscription");
            subs.remove(&sid);
        }
    }

    /// Handles one operation received from the server.
    ///
    /// Returns `Ok(true)` when something was queued on the outgoing sink and a
    /// flush is required, `Ok(false)` otherwise.
    ///
    /// # Errors
    ///
    /// Fails if queuing the `PONG` reply fails or if the server sent an error
    /// operation.
    async fn handle_op(
        &mut self,
        subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
        op: ServerOp,
    ) -> anyhow::Result<bool> {
        match op {
            ServerOp::Info(opts) => {
                // `send_replace` stores the value even when no receiver exists.
                self.info.send_replace(Arc::new(opts));
                Ok(false)
            }
            ServerOp::Msg {
                subject,
                sid,
                reply,
                payload,
            } => {
                let msg = Message {
                    subject,
                    reply,
                    headers: Bytes::default(),
                    payload,
                };
                Self::deliver(subs, "MSG", sid, msg).await;
                Ok(false)
            }
            ServerOp::Hmsg {
                subject,
                sid,
                reply,
                headers,
                payload,
            } => {
                let msg = Message {
                    subject,
                    reply,
                    headers,
                    payload,
                };
                Self::deliver(subs, "HMSG", sid, msg).await;
                Ok(false)
            }
            ServerOp::Ping => {
                self.tx.feed(ClientOp::Pong).await?;
                Ok(true)
            }
            ServerOp::Pong | ServerOp::Ok => Ok(false),
            ServerOp::Err => {
                // TODO: Implement
                bail!("received an error")
            }
        }
    }

    /// Handles one command received from a client handle.
    ///
    /// Returns `Ok(true)` when something was queued on the outgoing sink and a
    /// flush is required, `Ok(false)` otherwise.
    ///
    /// # Errors
    ///
    /// Fails if queuing an operation on the outgoing sink fails.
    async fn handle_command(
        &mut self,
        subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
        cmd: Command,
    ) -> anyhow::Result<bool> {
        match cmd {
            Command::ServerInfo(tx) => {
                if tx.send(self.info.subscribe()).is_err() {
                    warn!("server info receiver dropped");
                }
                Ok(false)
            }
            Command::Connect(opts) => {
                self.tx.feed(ClientOp::Connect(opts)).await?;
                Ok(true)
            }
            Command::Subscribe {
                subject,
                group,
                sid,
                tx,
            } => {
                self.tx
                    .feed(ClientOp::Sub {
                        subject,
                        group,
                        sid: sid.clone(),
                    })
                    .await?;
                subs.insert(sid, tx);
                Ok(true)
            }
            Command::Publish {
                subject,
                reply,
                headers,
                payload,
            } => {
                // Only use the header-carrying operation when there are headers.
                let op = if headers.is_empty() {
                    ClientOp::Pub {
                        subject,
                        reply,
                        payload,
                    }
                } else {
                    ClientOp::Hpub {
                        subject,
                        reply,
                        headers,
                        payload,
                    }
                };
                self.tx.feed(op).await?;
                Ok(true)
            }
            Command::Batch(cmds) => {
                let mut has_data = false;
                for cmd in cmds {
                    // Boxed because an async fn cannot recurse directly.
                    has_data |= Box::pin(self.handle_command(subs, cmd)).await?;
                }
                Ok(has_data)
            }
        }
    }

    /// Drives the connection: dispatches server operations, executes client
    /// commands and flushes the outgoing sink whenever data is pending.
    ///
    /// Once the command channel is closed the loop keeps serving the existing
    /// subscriptions; it only returns with an error.
    ///
    /// # Errors
    ///
    /// Fails if the server stream ends or yields an error, if handling an
    /// operation or command fails, or if flushing fails.
    async fn run(&mut self, cmds: &mut mpsc::Receiver<Command>) -> anyhow::Result<()> {
        let mut subs = HashMap::default();
        let mut has_data = false;
        let mut cmds_open = true;
        loop {
            select! {
                // Bound irrefutably so that end-of-stream is observed instead
                // of silently disabling the branch.
                op = self.rx.next() => {
                    let op = op.context("server stream unexpectedly finished")?;
                    let op = op?;
                    has_data |= self.handle_op(&mut subs, op).await?;
                },
                cmd = cmds.recv(), if cmds_open => {
                    match cmd {
                        Some(cmd) => {
                            has_data |= self.handle_command(&mut subs, cmd).await?;
                        }
                        // Channel closed: stop polling it, keep the rest running.
                        None => cmds_open = false,
                    }
                },
                res = self.tx.flush(), if has_data => {
                    res?;
                    has_data = false;
                }
            };
        }
    }
}

/// Produces the `ConnectOptions` to use for a server, given the
/// `InfoOptions` that server advertised.
///
/// NOTE(review): no implementor or caller is visible in this file — confirm
/// the intended use against the rest of the crate.
pub trait Connector {
    /// Error returned when the options cannot be produced.
    type Error: core::error::Error + Send + Sync + 'static;

    /// Consumes the connector and asynchronously yields the connect options
    /// for a server described by `opts`.
    fn connect(
        self,
        opts: &InfoOptions,
    ) -> impl Future<Output = Result<ConnectOptions, Self::Error>>;
}

/// Cloneable handle that submits [`Command`]s to a connection's I/O loop.
#[derive(Clone, Debug)]
pub struct Client {
    /// Sending side of the command channel read by the I/O loop.
    commands: mpsc::Sender<Command>,
}

impl Client {
    pub async fn new(
        mut conn: Conn<impl AsyncRead + Unpin, impl AsyncWrite + Unpin>,
        opts: ConnectOptions,
    ) -> anyhow::Result<(Self, impl Future<Output = anyhow::Result<()>>)> {
        let (cmd_tx, mut cmd_rx) = mpsc::channel(8192);
        conn.tx
            .feed(ClientOp::Connect(opts))
            .await
            .context("failed to send connect operation")?;
        Ok((Self { commands: cmd_tx }, async move {
            conn.run(&mut cmd_rx).await
        }))
    }

    #[must_use]
    pub fn commands(&self) -> &mpsc::Sender<Command> {
        &self.commands
    }

    #[must_use]
    pub async fn server_info(&self) -> Option<watch::Receiver<Arc<InfoOptions>>> {
        let (tx, rx) = oneshot::channel();
        self.commands.send(Command::ServerInfo(tx)).await.ok()?;
        rx.await.ok()
    }

    pub async fn subscribe(
        &self,
        subject: impl Into<Bytes>,
        group: impl Into<Bytes>,
        sid: impl Into<Bytes>,
        tx: mpsc::Sender<Message>,
    ) -> Result<(), mpsc::error::SendError<Command>> {
        self.commands
            .send(Command::Subscribe {
                subject: subject.into(),
                group: group.into(),
                sid: sid.into(),
                tx,
            })
            .await
    }

    pub async fn publish(
        &self,
        subject: impl Into<Bytes>,
        reply: impl Into<Bytes>,
        headers: impl Into<Bytes>,
        payload: impl Into<Bytes>,
    ) -> Result<(), mpsc::error::SendError<Command>> {
        self.commands
            .send(Command::Publish {
                subject: subject.into(),
                reply: reply.into(),
                headers: headers.into(),
                payload: payload.into(),
            })
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end round trip: connect, subscribe, publish two messages and
    /// receive both back in order.
    ///
    /// Connects to `localhost:4222`, so a server must be listening there.
    #[test_log::test(tokio::test)]
    async fn connect() -> anyhow::Result<()> {
        let (conn, _) = Conn::connect_tcp("localhost:4222").await?;
        let (clt, io) = Client::new(conn, ConnectOptions::default()).await?;
        // The I/O future must be running for any command to be processed.
        let io = tokio::spawn(io);
        let info = clt
            .server_info()
            .await
            .context("failed to get server info")?;
        let info = Arc::clone(&info.borrow());
        eprintln!("info: {info:?}");
        let (tx, mut rx) = mpsc::channel(128);
        let subject = Bytes::from("test");
        // Subscribe before publishing so both messages are delivered to `rx`.
        clt.subscribe(subject.clone(), "", "0", tx)
            .await
            .context("failed to subscribe")?;
        let cmds = clt.commands();
        // First message goes through the convenience method...
        clt.publish(subject.clone(), "", "", "hello")
            .await
            .context("failed to publish `hello`")?;
        // ...the second through the raw command sender, with a payload that
        // itself contains CRLF sequences.
        cmds.send(Command::Publish {
            subject: subject.clone(),
            headers: Bytes::default(),
            reply: Bytes::default(),
            payload: Bytes::from("bye\r\n\r\n"),
        })
        .await
        .context("failed to publish `bye`")?;
        let msg = rx.recv().await.context("failed to receive `hello`")?;
        assert_eq!(
            msg,
            Message {
                subject: subject.clone(),
                reply: Bytes::default(),
                headers: Bytes::default(),
                payload: Bytes::from("hello"),
            }
        );
        let msg = rx.recv().await.context("failed to receive `bye`")?;
        assert_eq!(
            msg,
            Message {
                subject: subject.clone(),
                reply: Bytes::default(),
                headers: Bytes::default(),
                payload: Bytes::from("bye\r\n\r\n"),
            }
        );
        // The I/O loop never finishes on its own; stop it explicitly.
        io.abort();
        Ok(())
    }
}