async_nng/
context.rs

1//! Types and traits used for performing asynchronous socket operations that utilize contexts in
2//! conjunction with NNG's AIO.
3
4use super::cancel_guard::{Cancel, CancelGuard};
5use async_channel::{bounded, Receiver};
6use std::time::Duration;
7use tracing::instrument;
8
9/// A wrapper type around [`Context`](nng::Context) to enable async-await send and receive
10/// operations.
11///
12/// This type allows for asynchronous context-based send and receive operations. One trouble with
13/// `AsyncSocket` and `Socket`s in general is that managing concurrent sends and receives can be
14/// troublesome depending on which scalability protocol your socket implements. For example,
15/// consider the following program:
16///
17/// ```ignore
18/// use async_nng::AsyncSocket;
19/// # use std::{io::Write, time::Duration};
20///
21/// struct MyMessageHandler {
22///     // Assume a Rep0 socket
23///     socket: AsyncSocket,
24/// }
25///
26/// impl MyMessageHandler {
27///     async fn some_api_call(&self, msg: nng::Message) -> Result<(), nng::Error> {
28///         self.socket.send(msg, None).await;
29///
30///         // a really long operation
31///         smol::Timer::after(Duration::from_millis(3000)).await;
32///
33///         let reply = self.socket.receive(None).await;
34///         Ok(())
35///     }
36/// }
37/// ```
38///
39/// We have some message handling type, which may or may not break at one of the await points
40/// above. If it breaks ono the middle await point (our really long operation), then we might have
41/// an issue depending on the protocol being used. Above, we're using req-rep, which has the
42/// restriction that generally (unless you're using sockets in raw mode) that you cannot send two
43/// different requests before the reply is received.
44///
45/// Therefore, if one tries to run multiple message handlers or API calls from different tasks or
46/// threads concurrently they are at risk to run into an issue where the req-rep cycle invariants
47/// are violated because of long-awaiting tasks in between the send and receive.
48///
49/// This is exclusive because `AsyncSocket::send` and `AsyncSocket::receive` take `&mut self`
50/// instead of `&self`. The underlying `nng` crate does not actually require this, but is done to
51/// prevent multiple accesses of the same underlying AIO object in Rust. This still runs into an
52/// issue. If users still want to concurrently construct API calls such as the above, they will
53/// then choose to wrap the async socket inside of something like an `Arc<Mutex<_>>`. If that mutex
54/// is asynchronous and releases across await points, then you still have the issue that concurrent
55/// calls to these APIs can fail because e.g. two sends were called before a receive was.
56///
57/// ## Contexts
58///
59/// The way around this is to use contexts. Each context can be constructed locally on the stack,
60/// and then indepedently manage sending and receiving in a safe way, without requiring a lock on
61/// the socket as a whole:
62///
63/// ```no_run
64/// use async_nng::AsyncContext;
65/// use nng::Socket;
66/// # use std::{io::Write, time::Duration};
67///
68/// struct MyMessageHandler {
69///     // Assume a Rep0 socket
70///     socket: Socket,
71/// }
72///
73/// impl MyMessageHandler {
74///     async fn some_api_call(&self, msg: nng::Message) -> Result<(), nng::Error> {
75///         // Notice that we use a regular borrow on `&self.socket`, not `&mut self.socket`.
76///         //
77///         // Contexts defined locally will not conflict with one another, which makes writing
78///         // asynchronous, concurrent programs easier.
79///         let context = AsyncContext::try_from(&self.socket)?;
80///
81///         context.send(msg, None).await;
82///
83///         // a really long operation
84///         smol::Timer::after(Duration::from_millis(3000)).await;
85///
86///         let reply = context.receive(None).await;
87///         Ok(())
88///     }
89/// }
90/// ```
91///
92/// Contexts cannot be used with raw-mode sockets; however, by borrowing the underlying socket they
93/// are able to create an object that can concurrently and independently operate without the
94/// aforementioned bugs that can arise at runtime due to trying to use e.g. a req-rep socket
95/// concurrently.
96///
97/// In almost all cases, one should prefer to use the `AsyncContext` type when possible. It
98/// provides better guarantees about concurrent access, which for asynchronous code is necessarily
99/// a concern.
100#[derive(Debug)]
101pub struct AsyncContext<'a> {
102    /// The underlying context to run async operations on.
103    context: nng::Context,
104
105    /// AIO object for managing async operations on the socket
106    aio: nng::Aio,
107
108    /// Receiving end of a bounded channel for getting results back from the AIO object.
109    receiver: Receiver<nng::AioResult>,
110
111    /// The socket that the context is bound to.
112    ///
113    /// This is pretty much entirely unused but is here to tie the lifetime of the context to the
114    /// socket. This is particularly useful because `nng::Context` doesn't forward this lifetime,
115    /// and as a result you can end up getting `nng::Error::Closed` on operations. By tying this
116    /// lifetime here artificially (because the context is closed if the socket is closed, but not
117    /// vice-versa), we can guarantee that contexts do not outlive their sockets using the borrow
118    /// checker instead of fishing for `nng::Error::Closed`.
119    ///
120    /// See <https://doc.rust-lang.org/nomicon/phantom-data.html#table-of-phantomdata-patterns> for
121    /// more info on variance of the phantom. The reference to the socket was intentionally not
122    /// chosen to be held here because contexts do not generally appreciate if one touches the
123    /// socket (e.g. call `nng::Socket::close`) while borrowed, and this is an easy way to enforce
124    /// the lifetime bound artificially without providing a footgun that is a reference to the
125    /// actual socket.
126    _phantom: std::marker::PhantomData<&'a nng::Socket>,
127}
128
129impl<'a> TryFrom<&'a nng::Socket> for AsyncContext<'a> {
130    type Error = nng::Error;
131
132    #[instrument]
133    fn try_from(socket: &'a nng::Socket) -> Result<Self, Self::Error> {
134        let context = nng::Context::new(socket)?;
135        let (sender, receiver) = bounded(1);
136        let aio = nng::Aio::new(move |aio, result| {
137            if let Err(err) = sender.try_send(result) {
138                tracing::warn!(?err, "Failed to forward IO event from nng::Aio.");
139                aio.cancel();
140            }
141        })?;
142
143        Ok(Self {
144            context,
145            aio,
146            receiver,
147            _phantom: std::marker::PhantomData,
148        })
149    }
150}
151
152impl<'a> Cancel for AsyncContext<'a> {
153    fn cancel(&mut self) {
154        self.aio.cancel();
155        // Blocking call which isn't great but because we've cancelled above this will either:
156        //
157        // 1. Return NNG_ECANCELED very quickly
158        // 2. Return the actual IO event
159        //
160        // Either way, the result isn't useful to us because it has been canceled. One trade-off of
161        // this approach is that we have to accept that we are "blocking" but realistically we
162        // aren't expecting to be blocking long, because the call to cancel the AIO object should
163        // not take excessively long aside from potential context switching from the OS.
164        let _ = self.receiver.recv_blocking();
165    }
166}
167
168impl<'a> AsyncContext<'a> {
169    /// Sends a [`Message`](nng::Message) to the socket asynchronously.
170    ///
171    /// # Errors
172    ///
173    /// - `IncorrectState` if the internal `Aio` is already running an operation, or the socket
174    /// cannot send messages in its current state.
175    /// - `MessageTooLarge`: The message is too large.
176    /// - `NotSupported`: The protocol does not support sending messages.
177    /// - `OutOfMemory`: Insufficient memory available.
178    /// - `TimedOut`: The operation timed out.
179    /// - `Closed`: The context or socket has been closed and future operations will not work.
180    #[instrument(skip(msg))]
181    pub async fn send<M>(
182        &mut self,
183        msg: M,
184        timeout: Option<Duration>,
185    ) -> Result<(), (nng::Message, nng::Error)>
186    where
187        M: Into<nng::Message>,
188    {
189        if let Err(err) = self.aio.set_timeout(timeout) {
190            return Err((msg.into(), err));
191        }
192        self.context.send(&self.aio, msg.into())?;
193        let receiver = self.receiver.clone();
194
195        let guard = CancelGuard::new(self);
196        // Unwrap is safe here because we know that the sender cannot be dropped unless `self.aio`
197        // is dropped.
198        let res = match receiver.recv().await.unwrap() {
199            nng::AioResult::Send(res) => res,
200            _ => {
201                tracing::warn!("Reached invalid state in AsyncContext::send.");
202                unreachable!();
203            }
204        };
205        std::mem::forget(guard);
206        res
207    }
208
209    /// Receives a [`Message`](nng::Message) from the socket asynchronously.
210    ///
211    /// # Errors
212    ///
213    /// - `IncorrectState` if the internal `Aio` is already running an operation, or the socket
214    /// cannot receive messages in its current state.
215    /// - `MessageTooLarge`: The message is too large.
216    /// - `NotSupported`: The protocol does not support sending messages.
217    /// - `OutOfMemory`: Insufficient memory available.
218    /// - `TimedOut`: The operation timed out.
219    /// - `Closed`: The context or socket has been closed and future operations will not work.
220    #[instrument]
221    pub async fn receive(&mut self, timeout: Option<Duration>) -> Result<nng::Message, nng::Error> {
222        self.aio.set_timeout(timeout)?;
223        self.context.recv(&self.aio)?;
224
225        let receiver = self.receiver.clone();
226        let guard = CancelGuard::new(self);
227        // Unwrap is safe here because we know that the sender cannot be dropped unless `self.aio`
228        // is dropped.
229        let res = match receiver.recv().await.unwrap() {
230            nng::AioResult::Recv(res) => res,
231            _ => {
232                tracing::warn!("Reached invalid state in AsyncContext::receive.");
233                unreachable!();
234            }
235        };
236        std::mem::forget(guard);
237        res
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use macro_rules_attribute::apply;
    use nng::options::Options;
    use std::{io::Write, time::Duration};
    use tracing_test::traced_test;

    /// Payload written into every request message.
    const PING_BYTES: &[u8] = b"ping";

    /// Payload written into every reply message.
    const PONG_BYTES: &[u8] = b"pong";

    /// Builds a fresh message whose body is exactly `payload`.
    fn message_from(payload: &[u8]) -> nng::Message {
        let mut message = nng::Message::new();
        message.write_all(payload).expect("All bytes written.");
        message
    }

    /// Creates a connected `(Req0, Rep0)` socket pair over a randomly named in-process address.
    ///
    /// The listener (rep side) and the dialer (req side) are both configured with a 1MiB
    /// maximum receive size before being started.
    #[instrument]
    fn make_req_rep() -> (nng::Socket, nng::Socket) {
        let addr = nng::SocketAddr::InProc(random_string::generate(128, "abcdef1234567890"));
        tracing::info!(addr = addr.to_string(), "Local address");
        let url = addr.to_string();

        // Reply side: listen on the address.
        let rep = nng::Socket::new(nng::Protocol::Rep0).expect("Construct Rep0");
        let listener_builder =
            nng::ListenerBuilder::new(&rep, url.as_str()).expect("Listener create succeeds.");
        // RECVMAXSZ of 1MiB
        listener_builder
            .set_opt::<nng::options::RecvMaxSize>(1024 * 1024)
            .expect("Set opt on listener builder.");
        let _listener = listener_builder.start().expect("Start listener.");

        // Request side: dial the same address.
        let req = nng::Socket::new(nng::Protocol::Req0).expect("Construct Req0");
        let dialer_builder =
            nng::DialerBuilder::new(&req, url.as_str()).expect("Dial create succeeds.");
        // RECVMAXSZ of 1MiB
        dialer_builder
            .set_opt::<nng::options::RecvMaxSize>(1024 * 1024)
            .expect("Set opt on dialer builder.");
        let _dialer = dialer_builder.start(false).expect("Start dialer.");

        (req, rep)
    }

    // A full request/reply round trip: ping goes req -> rep, pong comes back rep -> req.
    #[traced_test]
    #[apply(smol_macros::test)]
    async fn ping_pong_with_req_rep() {
        let (req, rep) = make_req_rep();

        let mut req_ctx = AsyncContext::try_from(&req).expect("Make async context for req.");
        let mut rep_ctx = AsyncContext::try_from(&rep).expect("Make async context for rep.");

        req_ctx
            .send(message_from(PING_BYTES), None)
            .await
            .expect("Request should send successfully");

        let received_ping = rep_ctx
            .receive(None)
            .await
            .expect("Request should be received successfully.");
        assert_eq!(PING_BYTES, received_ping.as_slice());

        rep_ctx
            .send(message_from(PONG_BYTES), None)
            .await
            .expect("Reply should send successfully.");

        let received_pong = req_ctx
            .receive(None)
            .await
            .expect("Response should be received successfully.");
        assert_eq!(PONG_BYTES, received_pong.as_slice());
    }

    // Once the rep-side context is dropped nobody answers, so the requester's receive is
    // expected to time out.
    #[traced_test]
    #[apply(smol_macros::test)]
    async fn drop_rep_after_request_sent_fails() {
        let (req, rep) = make_req_rep();

        let mut req_ctx = AsyncContext::try_from(&req).expect("Make async context for req.");
        let rep_ctx = AsyncContext::try_from(&rep).expect("Make async context for rep.");

        req_ctx
            .send(message_from(PING_BYTES), None)
            .await
            .expect("Request should send successfully");

        drop(rep_ctx);

        let wait = Duration::from_millis(20);
        let err = req_ctx
            .receive(Some(wait))
            .await
            .expect_err("Request should timeout because the response will never come.");

        assert_eq!(err, nng::Error::TimedOut);
    }

    // Races a send against an already-ready future, then checks that the context can still
    // perform a normal send afterwards.
    #[traced_test]
    #[apply(smol_macros::test)]
    async fn context_send_is_cancel_safe() {
        let (req, rep) = make_req_rep();

        let mut req_ctx = AsyncContext::try_from(&req).expect("Make async context for req.");
        let mut rep_ctx = AsyncContext::try_from(&rep).expect("Make async context for rep.");

        let abandoned_send = req_ctx.send(message_from(PING_BYTES), None);
        smol::future::or(smol::future::ready(Ok(())), abandoned_send)
            .await
            .expect("Immediate future should return Ok.");

        let wait = Duration::from_millis(20);
        let err = rep_ctx
            .receive(Some(wait))
            .await
            .expect_err("Request should timeout because nothing was sent.");

        assert_eq!(err, nng::Error::TimedOut);

        // Second attempt, this time awaited to completion. The req context should be in a state
        // where sending works again.
        req_ctx
            .send(message_from(PING_BYTES), None)
            .await
            .expect("Request should send successfully");

        let received_ping = rep_ctx
            .receive(None)
            .await
            .expect("Request should be received successfully.");
        assert_eq!(PING_BYTES, received_ping.as_slice());
    }
}