ants 0.1.0-alpha.1

Low-level async NATS.io client
Documentation
use core::future::Future;

use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{bail, Context as _};
use bytes::Bytes;
use futures::{SinkExt as _, StreamExt as _};
use protocol::{ClientOp, ClientOpEncoder, ConnectOptions, InfoOptions, ServerOp, ServerOpDecoder};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::ToSocketAddrs;
use tokio::select;
use tokio::sync::{mpsc, oneshot, watch};
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{debug, warn};

pub mod protocol;

/// A message delivered to a subscriber registered via [`Command::Subscribe`].
///
/// Fields are public so that receivers outside this crate can actually inspect what they
/// receive (mirroring the public fields of [`Command::Publish`]).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
    /// Subject the message was delivered on.
    pub subject: Bytes,
    /// Reply subject as received from the server (NOTE(review): presumably empty when the
    /// publisher did not request a reply — confirm against the protocol decoder).
    pub respond: Bytes,
    /// Raw header block; always empty for messages received via plain `MSG`.
    pub headers: Bytes,
    /// Message body.
    pub payload: Bytes,
}

/// Instructions sent from a [`Client`] to the task driving a [`Conn`].
pub enum Command {
    /// Write a `CONNECT` operation with the given options.
    Connect(ConnectOptions),
    /// Publish `payload` on `subject`.
    ///
    /// Encoded as `PUB` when `headers` is empty and as `HPUB` otherwise.
    Publish {
        subject: Bytes,
        respond: Bytes,
        headers: Bytes,
        payload: Bytes,
    },
    /// Write a `SUB` operation and route messages received for `sid` into `tx`.
    Subscribe {
        subject: Bytes,
        group: Bytes,
        sid: Bytes,
        tx: mpsc::Sender<Message>,
    },
    /// Reply with a receiver that observes the most recent server `INFO`.
    ServerInfo(oneshot::Sender<watch::Receiver<Arc<InfoOptions>>>),
    /// Execute the contained commands one after another, in order.
    Batch(Box<[Command]>),
}

/// A single client connection to a NATS server, split into framed read and write halves.
pub struct Conn {
    /// Write half; client operations are encoded by `ClientOpEncoder`.
    tx: FramedWrite<OwnedWriteHalf, ClientOpEncoder>,
    /// Read half; server operations are decoded by `ServerOpDecoder`.
    rx: FramedRead<OwnedReadHalf, ServerOpDecoder>,
    /// Latest `INFO` received from the server, replaced whenever a new one arrives.
    info: watch::Sender<Arc<InfoOptions>>,
}

impl Conn {
    /// Opens a TCP connection to the NATS server at `addr` and reads its initial `INFO`.
    ///
    /// # Errors
    ///
    /// Fails if the TCP connection cannot be established, if the server stream ends
    /// before anything is received, if the first operation cannot be decoded, or if the
    /// first operation is not `INFO`.
    pub async fn connect(addr: impl ToSocketAddrs) -> anyhow::Result<Self> {
        let (tx, mut rx) = protocol::connect_tcp(addr).await?;
        let op = rx
            .next()
            .await
            .context("server stream unexpectedly finished")?;
        let info = match op.context("failed to receive server info")? {
            ServerOp::Info(opts) => opts,
            _ => bail!("server did not send INFO"),
        };
        // Only the sender is kept; receivers are created on demand via `subscribe`.
        let (info, _) = watch::channel(Arc::new(info));
        Ok(Self { tx, rx, info })
    }

    /// Handles one operation received from the server.
    ///
    /// Returns `true` when a reply was queued on the write half and needs flushing.
    ///
    /// # Errors
    ///
    /// Fails if writing a reply fails or if the server sends `ServerOp::Err`.
    async fn handle_op(
        &mut self,
        subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
        op: ServerOp,
    ) -> anyhow::Result<bool> {
        match op {
            ServerOp::Info(opts) => {
                self.info.send_replace(Arc::new(opts));
                Ok(false)
            }
            ServerOp::Msg {
                subject,
                sid,
                respond,
                payload,
            } => {
                let msg = Message {
                    subject,
                    respond,
                    headers: Bytes::default(),
                    payload,
                };
                Self::deliver(subs, sid, msg).await;
                Ok(false)
            }
            ServerOp::Hmsg {
                subject,
                sid,
                respond,
                headers,
                payload,
            } => {
                let msg = Message {
                    subject,
                    respond,
                    headers,
                    payload,
                };
                Self::deliver(subs, sid, msg).await;
                Ok(false)
            }
            ServerOp::Ping => {
                self.tx.feed(ClientOp::Pong).await?;
                Ok(true)
            }
            ServerOp::Pong | ServerOp::Ok => Ok(false),
            ServerOp::Err => {
                // TODO: Implement
                bail!("received an error")
            }
        }
    }

    /// Forwards `msg` to the subscriber registered under `sid`.
    ///
    /// Subscriptions whose receiver has been dropped are removed. Messages for unknown
    /// sids are logged and dropped rather than treated as fatal: once a subscription is
    /// removed locally, the server keeps delivering for that sid, so erroring here would
    /// tear down the whole connection.
    // NOTE(review): the server is never told to UNSUB a removed sid; consider sending an
    // unsubscribe operation here if the protocol module provides one.
    async fn deliver(subs: &mut HashMap<Bytes, mpsc::Sender<Message>>, sid: Bytes, msg: Message) {
        let Some(tx) = subs.get(&sid) else {
            debug!(?sid, "drop message for unknown subscription");
            return;
        };
        if tx.send(msg).await.is_err() {
            debug!(?sid, "remove unused subscription");
            subs.remove(&sid);
        }
    }

    /// Executes one command issued by a client.
    ///
    /// Returns `true` when data was queued on the write half and needs flushing.
    ///
    /// # Errors
    ///
    /// Fails if queueing an operation on the write half fails.
    async fn handle_command(
        &mut self,
        subs: &mut HashMap<Bytes, mpsc::Sender<Message>>,
        cmd: Command,
    ) -> anyhow::Result<bool> {
        match cmd {
            Command::ServerInfo(tx) => {
                if tx.send(self.info.subscribe()).is_err() {
                    warn!("server info receiver dropped");
                }
                Ok(false)
            }
            Command::Connect(opts) => {
                self.tx.feed(ClientOp::Connect(opts)).await?;
                Ok(true)
            }
            Command::Subscribe {
                subject,
                group,
                sid,
                tx,
            } => {
                self.tx
                    .feed(ClientOp::Sub {
                        subject,
                        group,
                        sid: sid.clone(),
                    })
                    .await?;
                subs.insert(sid, tx);
                Ok(true)
            }
            Command::Publish {
                subject,
                respond,
                headers,
                payload,
            } => {
                let op = if headers.is_empty() {
                    ClientOp::Pub {
                        subject,
                        respond,
                        payload,
                    }
                } else {
                    ClientOp::Hpub {
                        subject,
                        respond,
                        headers,
                        payload,
                    }
                };
                self.tx.feed(op).await?;
                Ok(true)
            }
            Command::Batch(cmds) => {
                let mut has_data = false;
                for cmd in cmds {
                    // Boxed because an `async fn` cannot recurse into itself directly.
                    has_data |= Box::pin(self.handle_command(subs, cmd)).await?;
                }
                Ok(has_data)
            }
        }
    }

    /// Drives the connection: reads server operations, executes commands from `cmds`, and
    /// flushes the write half whenever something has been queued on it.
    ///
    /// When `cmds` is closed, the connection keeps serving server operations (e.g. for
    /// existing subscriptions).
    ///
    /// # Errors
    ///
    /// Fails when the server stream ends, when reading, writing or flushing fails, or when
    /// handling a server operation fails.
    pub async fn run(&mut self, cmds: &mut mpsc::Receiver<Command>) -> anyhow::Result<()> {
        let mut subs = HashMap::default();
        // Flush first so that anything fed before `run` was called (such as the initial
        // CONNECT) reaches the server without waiting for another command.
        let mut has_data = true;
        loop {
            select! {
                op = self.rx.next() => {
                    // Matching `Some(op)` in the pattern would silently disable this branch
                    // on EOF and leave the loop running on a dead connection.
                    let op = op.context("server stream unexpectedly finished")?;
                    let op = op.context("failed to receive server operation")?;
                    if self.handle_op(&mut subs, op).await? {
                        has_data = true;
                    }
                },
                Some(cmd) = cmds.recv() => {
                    if self.handle_command(&mut subs, cmd).await? {
                        has_data = true;
                    }
                },
                res = self.tx.flush(), if has_data => {
                    res?;
                    has_data = false;
                }
            };
        }
    }
}

/// Produces the options sent in the `CONNECT` operation once the server's `INFO` is known.
///
/// Returning `None` makes [`Client::new`] fail with "failed to construct connect options".
pub trait Connector {
    fn connect(&self, info: &InfoOptions) -> impl Future<Output = Option<ConnectOptions>>;
}

/// A [`Connector`] that ignores the server `INFO` and always uses
/// `ConnectOptions::default()`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct DefaultConnector;

impl Connector for DefaultConnector {
    async fn connect(&self, _: &InfoOptions) -> Option<ConnectOptions> {
        Some(ConnectOptions::default())
    }
}

impl Connector for ConnectOptions {
    async fn connect(&self, _: &InfoOptions) -> Option<ConnectOptions> {
        Some(self.clone())
    }
}

/// Handle used to issue [`Command`]s to a running connection.
///
/// Cloning is cheap (it clones the underlying channel sender) and allows several tasks to
/// share one connection.
#[derive(Clone, Debug)]
pub struct Client {
    commands: mpsc::Sender<Command>,
}

impl Client {
    pub async fn new(
        addr: impl ToSocketAddrs + Clone + Send + 'static,
        connector: impl Connector,
    ) -> anyhow::Result<(Self, impl Future<Output = anyhow::Result<()>>)> {
        let (cmd_tx, mut cmd_rx) = mpsc::channel(8192);
        let mut conn = Conn::connect(addr)
            .await
            .context("failed to connect to server")?;
        let info = Arc::clone(&conn.info.subscribe().borrow());
        let opts = connector
            .connect(&info)
            .await
            .context("failed to construct connect options")?;
        conn.tx
            .feed(ClientOp::Connect(opts))
            .await
            .context("failed to send connect operation")?;
        Ok((Self { commands: cmd_tx }, async move {
            conn.run(&mut cmd_rx).await
        }))
    }

    pub fn commands(&self) -> &mpsc::Sender<Command> {
        &self.commands
    }

    pub async fn server_info(&self) -> Option<watch::Receiver<Arc<InfoOptions>>> {
        let (tx, rx) = oneshot::channel();
        self.commands.send(Command::ServerInfo(tx)).await.ok()?;
        rx.await.ok()
    }

    pub async fn connect(
        &self,
        opts: ConnectOptions,
    ) -> Result<(), mpsc::error::SendError<Command>> {
        self.commands.send(Command::Connect(opts)).await
    }

    pub async fn subscribe(
        &self,
        subject: Bytes,
        group: Bytes,
        sid: Bytes,
        tx: mpsc::Sender<Message>,
    ) -> Result<(), mpsc::error::SendError<Command>> {
        self.commands
            .send(Command::Subscribe {
                subject,
                group,
                sid,
                tx,
            })
            .await
    }

    pub async fn publish(
        &self,
        subject: Bytes,
        respond: Bytes,
        headers: Bytes,
        payload: Bytes,
    ) -> Result<(), mpsc::error::SendError<Command>> {
        self.commands
            .send(Command::Publish {
                subject,
                respond,
                headers,
                payload,
            })
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end round trip against a NATS server listening on `localhost:4222`.
    #[test_log::test(tokio::test)]
    async fn connect() -> anyhow::Result<()> {
        let (client, io) = Client::new("localhost:4222", ConnectOptions::default()).await?;
        let io = tokio::spawn(io);

        let info_rx = client
            .server_info()
            .await
            .context("failed to get server info")?;
        let info = Arc::clone(&info_rx.borrow());
        eprintln!("info: {info:?}");

        client
            .connect(ConnectOptions::default())
            .await
            .context("failed to connect")?;

        let (msg_tx, mut msg_rx) = mpsc::channel(128);
        let subject = Bytes::from("test");
        client
            .subscribe(subject.clone(), Bytes::default(), Bytes::from("0"), msg_tx)
            .await
            .context("failed to subscribe")?;

        // Publish once through the convenience method and once through the raw channel.
        let raw = client.commands();
        client
            .publish(
                subject.clone(),
                Bytes::default(),
                Bytes::default(),
                Bytes::from("hello"),
            )
            .await
            .context("failed to publish `hello`")?;
        raw.send(Command::Publish {
            subject: subject.clone(),
            headers: Bytes::default(),
            respond: Bytes::default(),
            payload: Bytes::from("bye"),
        })
        .await
        .context("failed to publish `bye`")?;

        let expected = |body: &'static str| Message {
            subject: subject.clone(),
            respond: Bytes::default(),
            headers: Bytes::default(),
            payload: Bytes::from(body),
        };
        let first = msg_rx.recv().await.context("failed to receive `hello`")?;
        assert_eq!(first, expected("hello"));
        let second = msg_rx.recv().await.context("failed to receive `bye`")?;
        assert_eq!(second, expected("bye"));

        io.abort();
        Ok(())
    }
}