async_nng/
socket.rs

1//! Types and traits used for performing asynchronous socket operations based on NNG's AIO.
2
3use super::cancel_guard::{Cancel, CancelGuard};
4use async_channel::{bounded, Receiver};
5use std::time::Duration;
6use tracing::instrument;
7
8/// A wrapper type around [`Socket`](nng::Socket) to enable async-await send and receive
9/// operations.
10///
11/// This type allows for getting a single future (send or receive) from a socket at a time. This is
12/// useful for raw sockets and scalability protocols which cannot leverage
13/// [`Context`](nng::Context), and thus need to perform send and receive operations on an owned
14/// socket directly.
15///
16/// If you have set up your socket such that it can be used with contexts, it is instead
17/// recommended to use [`AsyncContext`](super::context::AsyncContext), which only needs to borrow
18/// the underlying socket. Contexts are more useful for concurrent operations where
19/// of futures for each independent operation on a socket.
20#[derive(Debug)]
21pub struct AsyncSocket {
22    /// The underlying socket
23    socket: nng::Socket,
24
25    /// AIO object for managing async operations on the socket
26    aio: nng::Aio,
27
28    /// Receiving end of a bounded channel for getting results back from the AIO object.
29    receiver: Receiver<nng::AioResult>,
30}
31
32impl TryFrom<nng::Socket> for AsyncSocket {
33    type Error = nng::Error;
34
35    #[instrument]
36    fn try_from(socket: nng::Socket) -> Result<Self, Self::Error> {
37        let (sender, receiver) = bounded(1);
38        let aio = nng::Aio::new(move |aio, result| {
39            if let Err(err) = sender.try_send(result) {
40                tracing::warn!(?err, "Failed to forward IO event from nng::Aio.");
41                aio.cancel();
42            }
43        })?;
44        Ok(Self {
45            socket,
46            aio,
47            receiver,
48        })
49    }
50}
51
52impl Cancel for AsyncSocket {
53    fn cancel(&mut self) {
54        self.aio.cancel();
55        // Blocking call which isn't great but because we've cancelled above this will either:
56        //
57        // 1. Return NNG_ECANCELED very quickly
58        // 2. Return the actual IO event
59        //
60        // Either way, the result isn't useful to us because it has been canceled. One trade-off of
61        // this approach is that we have to accept that we are "blocking" but realistically we
62        // aren't expecting to be blocking long, because the call to cancel the AIO object should
63        // not take excessively long aside from potential context switching from the OS.
64        let _ = self.receiver.recv_blocking();
65    }
66}
67
68impl AsyncSocket {
69    /// Sends a [`Message`](nng::Message) to the socket asynchronously.
70    ///
71    /// # Errors
72    ///
73    /// - `IncorrectState` if the internal `Aio` is already running an operation, or the socket
74    /// cannot send messages in its current state.
75    /// - `MessageTooLarge`: The message is too large.
76    /// - `NotSupported`: The protocol does not support sending messages.
77    /// - `OutOfMemory`: Insufficient memory available.
78    /// - `TimedOut`: The operation timed out.
79    #[instrument(skip(msg))]
80    pub async fn send<M>(
81        &mut self,
82        msg: M,
83        timeout: Option<Duration>,
84    ) -> Result<(), (nng::Message, nng::Error)>
85    where
86        M: Into<nng::Message>,
87    {
88        if let Err(err) = self.aio.set_timeout(timeout) {
89            return Err((msg.into(), err));
90        }
91
92        self.socket.send_async(&self.aio, msg.into())?;
93        let receiver = self.receiver.clone();
94
95        let guard = CancelGuard::new(self);
96        // Unwrap is safe here because we know that the sender cannot be dropped unless `self.aio`
97        // is dropped.
98        let res = match receiver.recv().await.unwrap() {
99            nng::AioResult::Send(res) => res,
100            _ => {
101                tracing::warn!("Reached invalid state in AsyncSocket::send.");
102                unreachable!();
103            }
104        };
105        std::mem::forget(guard);
106        res
107    }
108
109    /// Receives a [`Message`](nng::Message) from the socket asynchronously.
110    ///
111    /// # Errors
112    ///
113    /// - `IncorrectState` if the internal `Aio` is already running an operation, or the socket
114    /// cannot send messages in its current state.
115    /// - `MessageTooLarge`: The message is too large.
116    /// - `NotSupported`: The protocol does not support sending messages.
117    /// - `OutOfMemory`: Insufficient memory available.
118    /// - `TimedOut`: The operation timed out.
119    #[instrument]
120    pub async fn receive(&mut self, timeout: Option<Duration>) -> Result<nng::Message, nng::Error> {
121        self.aio.set_timeout(timeout)?;
122        self.socket.recv_async(&self.aio)?;
123        let receiver = self.receiver.clone();
124
125        let guard = CancelGuard::new(self);
126        // Unwrap is safe here because we know that the sender cannot be dropped unless `self.aio`
127        // is dropped.
128        let res = match receiver.recv().await.unwrap() {
129            nng::AioResult::Recv(res) => res,
130            _ => {
131                tracing::warn!("Reached invalid state in AsyncSocket::receive.");
132                unreachable!();
133            }
134        };
135        std::mem::forget(guard);
136        res
137    }
138
139    /// Grabs the inner [`Socket`](nng::Socket).
140    #[instrument]
141    pub fn into_inner(self) -> nng::Socket {
142        self.socket
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use macro_rules_attribute::apply;
    use nng::options::Options;
    use std::{io::Write, time::Duration};
    use tracing_test::traced_test;

    /// Bytes to send on the request
    const PING_BYTES: &[u8] = b"ping";

    /// Bytes to send on the reply
    const PONG_BYTES: &[u8] = b"pong";

    /// Helper function for getting a connected pair of Req/Rep async sockets.
    ///
    /// Both ends use a randomly named in-process address and a 1 MiB `RecvMaxSize`. The dialer is
    /// started synchronously (`start(false)`), so the pair is connected once this returns.
    #[instrument]
    fn make_req_rep() -> (AsyncSocket, AsyncSocket) {
        // Random name so that tests running in parallel do not share an inproc endpoint.
        let addr = nng::SocketAddr::InProc(random_string::generate(128, "abcdef1234567890"));
        tracing::info!(addr = addr.to_string(), "Local address");

        let rep = nng::Socket::new(nng::Protocol::Rep0).expect("Construct Rep0");
        let _listener = {
            let builder = nng::ListenerBuilder::new(&rep, addr.to_string().as_str())
                .expect("Listener create succeeds.");
            // 1MiB RECVMAXSZ
            builder
                .set_opt::<nng::options::RecvMaxSize>(1024 * 1024)
                .expect("Set opt on listener builder.");
            builder.start().expect("Start listener.")
        };

        let req = nng::Socket::new(nng::Protocol::Req0).expect("Construct Req0");
        let _dialer = {
            let builder = nng::DialerBuilder::new(&req, addr.to_string().as_str())
                .expect("Dial create succeeds.");
            // 1MiB RECVMAXSZ
            builder
                .set_opt::<nng::options::RecvMaxSize>(1024 * 1024)
                .expect("Set opt on dialer builder.");
            builder.start(false).expect("Start dialer.")
        };

        let req_async = AsyncSocket::try_from(req).expect("Async req.");
        let rep_async = AsyncSocket::try_from(rep).expect("Async rep.");

        (req_async, rep_async)
    }

    #[traced_test]
    #[apply(smol_macros::test)]
    async fn ping_pong_with_req_rep() {
        let (mut req, mut rep) = make_req_rep();

        let mut ping = nng::Message::new();
        ping.write_all(PING_BYTES).expect("All bytes written.");
        req.send(ping, None)
            .await
            .expect("Request should send successfully");

        let request = rep
            .receive(None)
            .await
            .expect("Request should be received successfully.");
        assert_eq!(PING_BYTES, request.as_slice());

        let mut pong = nng::Message::new();
        pong.write_all(PONG_BYTES).expect("All bytes written.");

        rep.send(pong, None)
            .await
            .expect("Reply should send successfully.");

        let response = req
            .receive(None)
            .await
            .expect("Response should be received successfully.");
        assert_eq!(PONG_BYTES, response.as_slice());
    }

    #[traced_test]
    #[apply(smol_macros::test)]
    async fn drop_rep_after_request_sent_fails() {
        let (mut req, rep) = make_req_rep();

        let mut ping = nng::Message::new();
        ping.write_all(PING_BYTES).expect("All bytes written.");
        req.send(ping, None)
            .await
            .expect("Request should send successfully");

        std::mem::drop(rep);

        let err = req
            .receive(Some(Duration::from_millis(20)))
            .await
            .expect_err("Request should timeout because the response will never come.");

        assert_eq!(err, nng::Error::TimedOut);
    }

    #[traced_test]
    #[apply(smol_macros::test)]
    async fn send_is_cancel_safe() {
        let (mut req, mut rep) = make_req_rep();

        let mut ping = nng::Message::new();
        ping.write_all(PING_BYTES).expect("All bytes written.");

        // `future::or` polls its first future first, so the already-ready future wins and the send
        // future is dropped without ever being polled. This therefore checks that a never-started
        // send leaves the socket usable. The in-flight cancellation path (`CancelGuard`) is
        // exercised by `receive_is_cancel_safe`.
        smol::future::or(smol::future::ready(Ok(())), req.send(ping, None))
            .await
            .expect("Immediate future should return Ok.");

        let err = rep
            .receive(Some(Duration::from_millis(20)))
            .await
            .expect_err("Request should timeout because nothing was sent.");

        assert_eq!(err, nng::Error::TimedOut);

        // Now send it for real. The state of the req socket should be good to send again.
        let mut ping = nng::Message::new();
        ping.write_all(PING_BYTES).expect("All bytes written.");

        req.send(ping, None)
            .await
            .expect("Request should send successfully");

        let request = rep
            .receive(None)
            .await
            .expect("Request should be received successfully.");
        assert_eq!(PING_BYTES, request.as_slice());
    }

    #[traced_test]
    #[apply(smol_macros::test)]
    async fn receive_is_cancel_safe() {
        let (mut req, mut rep) = make_req_rep();

        // Nothing has been sent, so the receive is still pending after a single poll. The
        // `poll_once` future (and the receive future inside it) is dropped at the end of this
        // statement, which runs the in-flight cancellation path.
        let polled = smol::future::poll_once(rep.receive(None)).await;
        assert!(
            polled.is_none(),
            "Receive should not complete before anything is sent."
        );

        // The rep socket should be usable again after the cancelled receive.
        let mut ping = nng::Message::new();
        ping.write_all(PING_BYTES).expect("All bytes written.");

        req.send(ping, None)
            .await
            .expect("Request should send successfully");

        let request = rep
            .receive(None)
            .await
            .expect("Request should be received successfully.");
        assert_eq!(PING_BYTES, request.as_slice());
    }
}