1use super::cancel_guard::{Cancel, CancelGuard};
4use async_channel::{bounded, Receiver};
5use std::time::Duration;
6use tracing::instrument;
7
8#[derive(Debug)]
21pub struct AsyncSocket {
22 socket: nng::Socket,
24
25 aio: nng::Aio,
27
28 receiver: Receiver<nng::AioResult>,
30}
31
32impl TryFrom<nng::Socket> for AsyncSocket {
33 type Error = nng::Error;
34
35 #[instrument]
36 fn try_from(socket: nng::Socket) -> Result<Self, Self::Error> {
37 let (sender, receiver) = bounded(1);
38 let aio = nng::Aio::new(move |aio, result| {
39 if let Err(err) = sender.try_send(result) {
40 tracing::warn!(?err, "Failed to forward IO event from nng::Aio.");
41 aio.cancel();
42 }
43 })?;
44 Ok(Self {
45 socket,
46 aio,
47 receiver,
48 })
49 }
50}
51
52impl Cancel for AsyncSocket {
53 fn cancel(&mut self) {
54 self.aio.cancel();
55 let _ = self.receiver.recv_blocking();
65 }
66}
67
68impl AsyncSocket {
69 #[instrument(skip(msg))]
80 pub async fn send<M>(
81 &mut self,
82 msg: M,
83 timeout: Option<Duration>,
84 ) -> Result<(), (nng::Message, nng::Error)>
85 where
86 M: Into<nng::Message>,
87 {
88 if let Err(err) = self.aio.set_timeout(timeout) {
89 return Err((msg.into(), err));
90 }
91
92 self.socket.send_async(&self.aio, msg.into())?;
93 let receiver = self.receiver.clone();
94
95 let guard = CancelGuard::new(self);
96 let res = match receiver.recv().await.unwrap() {
99 nng::AioResult::Send(res) => res,
100 _ => {
101 tracing::warn!("Reached invalid state in AsyncSocket::send.");
102 unreachable!();
103 }
104 };
105 std::mem::forget(guard);
106 res
107 }
108
109 #[instrument]
120 pub async fn receive(&mut self, timeout: Option<Duration>) -> Result<nng::Message, nng::Error> {
121 self.aio.set_timeout(timeout)?;
122 self.socket.recv_async(&self.aio)?;
123 let receiver = self.receiver.clone();
124
125 let guard = CancelGuard::new(self);
126 let res = match receiver.recv().await.unwrap() {
129 nng::AioResult::Recv(res) => res,
130 _ => {
131 tracing::warn!("Reached invalid state in AsyncSocket::receive.");
132 unreachable!();
133 }
134 };
135 std::mem::forget(guard);
136 res
137 }
138
139 #[instrument]
141 pub fn into_inner(self) -> nng::Socket {
142 self.socket
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use macro_rules_attribute::apply;
    use nng::options::Options;
    use std::{io::Write, time::Duration};
    use tracing_test::traced_test;

    /// Payload carried by request messages.
    const PING_BYTES: &[u8] = b"ping";

    /// Payload carried by reply messages.
    const PONG_BYTES: &[u8] = b"pong";

    /// Receive size limit applied to both the listener and the dialer.
    const RECV_MAX_SIZE: usize = 1024 * 1024;

    /// Timeout used when a receive is expected to time out.
    const SHORT_TIMEOUT: Duration = Duration::from_millis(20);

    /// Builds a message whose body is `payload`.
    fn message_with(payload: &[u8]) -> nng::Message {
        let mut msg = nng::Message::new();
        msg.write_all(payload).expect("All bytes written.");
        msg
    }

    /// Creates a connected Req0/Rep0 pair over a randomly named in-process
    /// address and wraps both ends in [`AsyncSocket`].
    ///
    /// Returns `(req, rep)`.
    #[instrument]
    fn make_req_rep() -> (AsyncSocket, AsyncSocket) {
        let addr = nng::SocketAddr::InProc(random_string::generate(128, "abcdef1234567890"));
        let url = addr.to_string();
        tracing::info!(addr = url.as_str(), "Local address");

        // Replier side: listen on the address.
        let rep = nng::Socket::new(nng::Protocol::Rep0).expect("Construct Rep0");
        let listener_builder =
            nng::ListenerBuilder::new(&rep, &url).expect("Listener create succeeds.");
        listener_builder
            .set_opt::<nng::options::RecvMaxSize>(RECV_MAX_SIZE)
            .expect("Set opt on listener builder.");
        let _listener = listener_builder.start().expect("Start listener.");

        // Requester side: dial the same address.
        let req = nng::Socket::new(nng::Protocol::Req0).expect("Construct Req0");
        let dialer_builder = nng::DialerBuilder::new(&req, &url).expect("Dial create succeeds.");
        dialer_builder
            .set_opt::<nng::options::RecvMaxSize>(RECV_MAX_SIZE)
            .expect("Set opt on dialer builder.");
        let _dialer = dialer_builder.start(false).expect("Start dialer.");

        let req_async = AsyncSocket::try_from(req).expect("Async req.");
        let rep_async = AsyncSocket::try_from(rep).expect("Async rep.");
        (req_async, rep_async)
    }

    /// A request followed by a reply travels in both directions intact.
    #[traced_test]
    #[apply(smol_macros::test)]
    async fn ping_pong_with_req_rep() {
        let (mut requester, mut replier) = make_req_rep();

        let sent = requester.send(message_with(PING_BYTES), None).await;
        sent.expect("Request should send successfully");

        let incoming = replier.receive(None).await;
        let request = incoming.expect("Request should be received successfully.");
        assert_eq!(PING_BYTES, request.as_slice());

        let pong = message_with(PONG_BYTES);
        let replied = replier.send(pong, None).await;
        replied.expect("Reply should send successfully.");

        let answer = requester.receive(None).await;
        let response = answer.expect("Response should be received successfully.");
        assert_eq!(PONG_BYTES, response.as_slice());
    }

    /// Once the replier is gone, waiting for a response times out.
    #[traced_test]
    #[apply(smol_macros::test)]
    async fn drop_rep_after_request_sent_fails() {
        let (mut requester, replier) = make_req_rep();

        let sent = requester.send(message_with(PING_BYTES), None).await;
        sent.expect("Request should send successfully");

        drop(replier);

        let outcome = requester.receive(Some(SHORT_TIMEOUT)).await;
        let err =
            outcome.expect_err("Request should timeout because the response will never come.");

        assert_eq!(err, nng::Error::TimedOut);
    }

    /// A send future that loses a race sends nothing, and the socket remains
    /// usable for a later send.
    #[traced_test]
    #[apply(smol_macros::test)]
    async fn send_is_cancel_safe() {
        let (mut requester, mut replier) = make_req_rep();

        // Race the send against a future that is ready immediately.
        let first_ping = message_with(PING_BYTES);
        let raced =
            smol::future::or(smol::future::ready(Ok(())), requester.send(first_ping, None)).await;
        raced.expect("Immediate future should return Ok.");

        let outcome = replier.receive(Some(SHORT_TIMEOUT)).await;
        let err = outcome.expect_err("Request should timeout because nothing was sent.");
        assert_eq!(err, nng::Error::TimedOut);

        // A regular send afterwards still goes through.
        let second_ping = message_with(PING_BYTES);
        let sent = requester.send(second_ping, None).await;
        sent.expect("Request should send successfully");

        let incoming = replier.receive(None).await;
        let request = incoming.expect("Request should be received successfully.");
        assert_eq!(PING_BYTES, request.as_slice());
    }
}