1use core::pin::Pin;
2use core::task::Poll;
3
4use anyhow::Context as _;
5use async_nats::subject::ToSubject as _;
6use bytes::{BufMut, Bytes};
7use futures::{pin_mut, Future, Stream, StreamExt};
8use tokio::io::{AsyncRead, AsyncWrite};
9
10pub struct Connection {
11 nats: async_nats::Client,
12 tx: async_nats::Subject,
13 rx: async_nats::Subscriber,
14 rx_buffer: Option<Bytes>,
15}
16
17impl Connection {
18 pub fn new(
19 nats: async_nats::Client,
20 tx: async_nats::Subject,
21 rx: async_nats::Subscriber,
22 ) -> Self {
23 Self {
24 nats,
25 tx,
26 rx,
27 rx_buffer: None,
28 }
29 }
30}
31
32fn process_message(
33 msg: async_nats::Message,
34) -> std::io::Result<(Bytes, Option<async_nats::Subject>)> {
35 match msg {
36 async_nats::Message {
37 reply,
38 payload,
39 status: None | Some(async_nats::StatusCode::OK),
40 ..
41 } => Ok((payload, reply)),
42 async_nats::Message {
43 status: Some(async_nats::StatusCode::NO_RESPONDERS),
44 ..
45 } => Err(std::io::ErrorKind::NotConnected.into()),
46 async_nats::Message {
47 status: Some(async_nats::StatusCode::TIMEOUT),
48 ..
49 } => Err(std::io::ErrorKind::TimedOut.into()),
50 async_nats::Message {
51 status: Some(async_nats::StatusCode::REQUEST_TERMINATED),
52 ..
53 } => Err(std::io::ErrorKind::UnexpectedEof.into()),
54 async_nats::Message {
55 status: Some(code),
56 description,
57 ..
58 } => Err(std::io::Error::new(
59 std::io::ErrorKind::Other,
60 if let Some(description) = description {
61 format!("received a response with code `{code}` ({description})")
62 } else {
63 format!("received a response with code `{code}`")
64 },
65 )),
66 }
67}
68
69pub async fn connect(
71 nats: async_nats::Client,
72 subject: impl async_nats::subject::ToSubject,
73 payload: Bytes,
74) -> anyhow::Result<(Connection, Bytes)> {
75 let reply = nats.new_inbox().to_subject();
76 let mut rx = nats
77 .subscribe(reply.clone())
78 .await
79 .context("failed to subscribe to inbox")?;
80 nats.publish_with_reply(subject, reply, payload)
81 .await
82 .context("failed to connect to peer")?;
83 let msg = rx
84 .next()
85 .await
86 .context("failed to receive outbound subject from peer")?;
87 let (payload, tx) = process_message(msg)?;
88 let tx = tx.context("peer did not specify reply subject")?;
89 Ok((Connection::new(nats, tx, rx), payload))
90}
91
92pub async fn accept(
94 nats: async_nats::Client,
95 sub: &mut async_nats::Subscriber,
96 handle: impl FnOnce(Bytes) -> std::io::Result<Bytes>,
97) -> anyhow::Result<Connection> {
98 let msg = sub.next().await.context("failed to accept connection")?;
99 let (payload, tx) = process_message(msg)?;
100 let tx = tx.context("peer did not specify reply subject")?;
101 let payload = handle(payload).context("failed to process handshake data")?;
102 let reply = nats.new_inbox().to_subject();
103 let rx = nats
104 .subscribe(reply.clone())
105 .await
106 .context("failed to subscribe to inbox")?;
107 nats.publish_with_reply(tx.clone(), reply, payload)
108 .await
109 .context("failed to connect to peer")?;
110 Ok(Connection::new(nats, tx, rx))
111}
112
113impl AsyncWrite for Connection {
114 fn poll_write(
115 self: Pin<&mut Self>,
116 cx: &mut core::task::Context<'_>,
117 buf: &[u8],
118 ) -> Poll<std::io::Result<usize>> {
119 let async_nats::ServerInfo { max_payload, .. } = self.nats.server_info();
120 let n = buf.len().min(max_payload);
121 let (buf, _) = buf.split_at(n);
122 let fut = self
123 .nats
124 .publish(self.tx.clone(), Bytes::copy_from_slice(buf));
125 pin_mut!(fut);
126 match fut.poll(cx) {
127 Poll::Pending => Poll::Pending,
128 Poll::Ready(Err(err)) => Poll::Ready(Err(std::io::Error::new(
129 std::io::ErrorKind::Other,
130 err.to_string(),
131 ))),
132 Poll::Ready(Ok(())) => Poll::Ready(Ok(n)),
133 }
134 }
135
136 fn poll_flush(
137 self: Pin<&mut Self>,
138 _cx: &mut core::task::Context<'_>,
139 ) -> Poll<std::io::Result<()>> {
140 Poll::Ready(Ok(()))
141 }
142
143 fn poll_shutdown(
144 self: Pin<&mut Self>,
145 _cx: &mut core::task::Context<'_>,
146 ) -> Poll<std::io::Result<()>> {
147 Poll::Ready(Ok(()))
148 }
149}
150
151impl AsyncRead for Connection {
152 fn poll_read(
153 mut self: Pin<&mut Self>,
154 cx: &mut core::task::Context<'_>,
155 buf: &mut tokio::io::ReadBuf<'_>,
156 ) -> Poll<std::io::Result<()>> {
157 let mut payload = if let Some(buffer) = self.rx_buffer.take() {
158 buffer
159 } else {
160 match Pin::new(&mut self.rx).poll_next(cx) {
161 Poll::Pending => return Poll::Pending,
162 Poll::Ready(None) => return Poll::Ready(Ok(())),
163 Poll::Ready(Some(msg)) => {
164 let (payload, _) = process_message(msg)?;
165 payload
166 }
167 }
168 };
169 let cap = buf.capacity();
170 if payload.len() > cap {
171 self.rx_buffer = Some(payload.split_off(cap));
172 }
173 buf.put(payload);
174 Poll::Ready(Ok(()))
175 }
176}