nats_connect/
lib.rs

1use core::pin::Pin;
2use core::task::Poll;
3
4use anyhow::Context as _;
5use async_nats::subject::ToSubject as _;
6use bytes::{BufMut, Bytes};
7use futures::{pin_mut, Future, Stream, StreamExt};
8use tokio::io::{AsyncRead, AsyncWrite};
9
/// A byte-stream connection carried over NATS messages.
///
/// Outgoing bytes are published through `nats` to the `tx` subject; incoming
/// bytes are taken from the payloads of messages delivered on `rx`.
pub struct Connection {
    /// NATS client used to publish outgoing data.
    nats: async_nats::Client,
    /// Subject outgoing data is published to.
    tx: async_nats::Subject,
    /// Subscription that incoming messages are read from.
    rx: async_nats::Subscriber,
    /// Tail of a received payload that did not fit into the reader's buffer on
    /// an earlier `poll_read`; it is served before `rx` is polled again.
    rx_buffer: Option<Bytes>,
}
16
17impl Connection {
18    pub fn new(
19        nats: async_nats::Client,
20        tx: async_nats::Subject,
21        rx: async_nats::Subscriber,
22    ) -> Self {
23        Self {
24            nats,
25            tx,
26            rx,
27            rx_buffer: None,
28        }
29    }
30}
31
32fn process_message(
33    msg: async_nats::Message,
34) -> std::io::Result<(Bytes, Option<async_nats::Subject>)> {
35    match msg {
36        async_nats::Message {
37            reply,
38            payload,
39            status: None | Some(async_nats::StatusCode::OK),
40            ..
41        } => Ok((payload, reply)),
42        async_nats::Message {
43            status: Some(async_nats::StatusCode::NO_RESPONDERS),
44            ..
45        } => Err(std::io::ErrorKind::NotConnected.into()),
46        async_nats::Message {
47            status: Some(async_nats::StatusCode::TIMEOUT),
48            ..
49        } => Err(std::io::ErrorKind::TimedOut.into()),
50        async_nats::Message {
51            status: Some(async_nats::StatusCode::REQUEST_TERMINATED),
52            ..
53        } => Err(std::io::ErrorKind::UnexpectedEof.into()),
54        async_nats::Message {
55            status: Some(code),
56            description,
57            ..
58        } => Err(std::io::Error::new(
59            std::io::ErrorKind::Other,
60            if let Some(description) = description {
61                format!("received a response with code `{code}` ({description})")
62            } else {
63                format!("received a response with code `{code}`")
64            },
65        )),
66    }
67}
68
69// TODO: Use proper error types
70pub async fn connect(
71    nats: async_nats::Client,
72    subject: impl async_nats::subject::ToSubject,
73    payload: Bytes,
74) -> anyhow::Result<(Connection, Bytes)> {
75    let reply = nats.new_inbox().to_subject();
76    let mut rx = nats
77        .subscribe(reply.clone())
78        .await
79        .context("failed to subscribe to inbox")?;
80    nats.publish_with_reply(subject, reply, payload)
81        .await
82        .context("failed to connect to peer")?;
83    let msg = rx
84        .next()
85        .await
86        .context("failed to receive outbound subject from peer")?;
87    let (payload, tx) = process_message(msg)?;
88    let tx = tx.context("peer did not specify reply subject")?;
89    Ok((Connection::new(nats, tx, rx), payload))
90}
91
92// TODO: Use proper error types
93pub async fn accept(
94    nats: async_nats::Client,
95    sub: &mut async_nats::Subscriber,
96    handle: impl FnOnce(Bytes) -> std::io::Result<Bytes>,
97) -> anyhow::Result<Connection> {
98    let msg = sub.next().await.context("failed to accept connection")?;
99    let (payload, tx) = process_message(msg)?;
100    let tx = tx.context("peer did not specify reply subject")?;
101    let payload = handle(payload).context("failed to process handshake data")?;
102    let reply = nats.new_inbox().to_subject();
103    let rx = nats
104        .subscribe(reply.clone())
105        .await
106        .context("failed to subscribe to inbox")?;
107    nats.publish_with_reply(tx.clone(), reply, payload)
108        .await
109        .context("failed to connect to peer")?;
110    Ok(Connection::new(nats, tx, rx))
111}
112
113impl AsyncWrite for Connection {
114    fn poll_write(
115        self: Pin<&mut Self>,
116        cx: &mut core::task::Context<'_>,
117        buf: &[u8],
118    ) -> Poll<std::io::Result<usize>> {
119        let async_nats::ServerInfo { max_payload, .. } = self.nats.server_info();
120        let n = buf.len().min(max_payload);
121        let (buf, _) = buf.split_at(n);
122        let fut = self
123            .nats
124            .publish(self.tx.clone(), Bytes::copy_from_slice(buf));
125        pin_mut!(fut);
126        match fut.poll(cx) {
127            Poll::Pending => Poll::Pending,
128            Poll::Ready(Err(err)) => Poll::Ready(Err(std::io::Error::new(
129                std::io::ErrorKind::Other,
130                err.to_string(),
131            ))),
132            Poll::Ready(Ok(())) => Poll::Ready(Ok(n)),
133        }
134    }
135
136    fn poll_flush(
137        self: Pin<&mut Self>,
138        _cx: &mut core::task::Context<'_>,
139    ) -> Poll<std::io::Result<()>> {
140        Poll::Ready(Ok(()))
141    }
142
143    fn poll_shutdown(
144        self: Pin<&mut Self>,
145        _cx: &mut core::task::Context<'_>,
146    ) -> Poll<std::io::Result<()>> {
147        Poll::Ready(Ok(()))
148    }
149}
150
151impl AsyncRead for Connection {
152    fn poll_read(
153        mut self: Pin<&mut Self>,
154        cx: &mut core::task::Context<'_>,
155        buf: &mut tokio::io::ReadBuf<'_>,
156    ) -> Poll<std::io::Result<()>> {
157        let mut payload = if let Some(buffer) = self.rx_buffer.take() {
158            buffer
159        } else {
160            match Pin::new(&mut self.rx).poll_next(cx) {
161                Poll::Pending => return Poll::Pending,
162                Poll::Ready(None) => return Poll::Ready(Ok(())),
163                Poll::Ready(Some(msg)) => {
164                    let (payload, _) = process_message(msg)?;
165                    payload
166                }
167            }
168        };
169        let cap = buf.capacity();
170        if payload.len() > cap {
171            self.rx_buffer = Some(payload.split_off(cap));
172        }
173        buf.put(payload);
174        Poll::Ready(Ok(()))
175    }
176}