async_nng/context.rs
//! Types and traits used for performing asynchronous socket operations that utilize contexts in
//! conjunction with NNG's AIO.

use super::cancel_guard::{Cancel, CancelGuard};
use async_channel::{bounded, Receiver};
use std::time::Duration;
use tracing::instrument;

/// A wrapper type around [`Context`](nng::Context) to enable async-await send and receive
/// operations.
///
/// This type allows for asynchronous context-based send and receive operations. One trouble with
/// `AsyncSocket` and `Socket`s in general is that managing concurrent sends and receives can be
/// tricky depending on which scalability protocol your socket implements. For example, consider
/// the following program:
///
/// ```ignore
/// use async_nng::AsyncSocket;
/// # use std::{io::Write, time::Duration};
///
/// struct MyMessageHandler {
///     // Assume a Rep0 socket
///     socket: AsyncSocket,
/// }
///
/// impl MyMessageHandler {
///     async fn some_api_call(&self, msg: nng::Message) -> Result<(), nng::Error> {
///         self.socket.send(msg, None).await;
///
///         // a really long operation
///         smol::Timer::after(Duration::from_millis(3000)).await;
///
///         let reply = self.socket.receive(None).await;
///         Ok(())
///     }
/// }
/// ```
///
/// We have some message-handling type, which may or may not yield at one of the await points
/// above. If it yields at the middle await point (our really long operation), then we might have
/// an issue depending on the protocol being used. Above, we're using req-rep, which generally has
/// the restriction (unless you're using sockets in raw mode) that you cannot send two different
/// requests before the reply is received.
///
/// Therefore, if one tries to run multiple message handlers or API calls from different tasks or
/// threads concurrently, they risk violating the req-rep cycle invariants because of
/// long-awaiting tasks in between the send and the receive.
///
/// Access is exclusive because `AsyncSocket::send` and `AsyncSocket::receive` take `&mut self`
/// instead of `&self`. The underlying `nng` crate does not actually require this, but it is done
/// to prevent multiple accesses of the same underlying AIO object in Rust. This still runs into
/// an issue: if users want to make concurrent API calls such as the above, they will typically
/// wrap the async socket inside something like an `Arc<Mutex<_>>`. If that mutex is asynchronous
/// and is released across await points, then concurrent calls to these APIs can still fail
/// because e.g. two sends were issued before a receive was.
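///
/// For illustration, the problematic pattern might look like the following sketch (the shared
/// handler type and the choice of async mutex are illustrative, not part of this crate):
///
/// ```ignore
/// use std::sync::Arc;
///
/// struct SharedHandler {
///     // Any async mutex works here; the point is that the lock is released across awaits.
///     socket: Arc<async_lock::Mutex<async_nng::AsyncSocket>>,
/// }
///
/// impl SharedHandler {
///     async fn some_api_call(&self, msg: nng::Message) -> Result<(), nng::Error> {
///         // The lock is dropped between the send and the receive, so another task can sneak
///         // in a second send and violate the req-rep invariants.
///         self.socket.lock().await.send(msg, None).await;
///         let _reply = self.socket.lock().await.receive(None).await;
///         Ok(())
///     }
/// }
/// ```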
///
/// ## Contexts
///
/// The way around this is to use contexts. Each context can be constructed locally on the stack,
/// and then independently manage sending and receiving in a safe way, without requiring a lock
/// on the socket as a whole:
///
/// ```no_run
/// use async_nng::AsyncContext;
/// use nng::Socket;
/// # use std::{io::Write, time::Duration};
///
/// struct MyMessageHandler {
///     // Assume a Rep0 socket
///     socket: Socket,
/// }
///
/// impl MyMessageHandler {
///     async fn some_api_call(&self, msg: nng::Message) -> Result<(), nng::Error> {
///         // Notice that we use a regular borrow on `&self.socket`, not `&mut self.socket`.
///         //
///         // Contexts defined locally will not conflict with one another, which makes writing
///         // asynchronous, concurrent programs easier.
///         let mut context = AsyncContext::try_from(&self.socket)?;
///
///         context.send(msg, None).await;
///
///         // a really long operation
///         smol::Timer::after(Duration::from_millis(3000)).await;
///
///         let reply = context.receive(None).await;
///         Ok(())
///     }
/// }
/// ```
///
/// Contexts cannot be used with raw-mode sockets; however, by borrowing the underlying socket,
/// each context can operate concurrently and independently without the aforementioned bugs that
/// can arise at runtime from using e.g. a req-rep socket concurrently.
///
/// In almost all cases, one should prefer the `AsyncContext` type when possible. It provides
/// better guarantees about concurrent access, which is necessarily a concern for asynchronous
/// code.
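///
/// Because the context borrows the socket (via the `PhantomData` field below), the borrow
/// checker rejects code that drops or closes the socket while a context is still in use. A
/// sketch of what that prevents (illustrative, not compiled as a doc test):
///
/// ```ignore
/// let socket = nng::Socket::new(nng::Protocol::Req0)?;
/// let mut context = AsyncContext::try_from(&socket)?;
/// drop(socket); // error[E0505]: cannot move out of `socket` because it is borrowed
/// context.receive(None).await?;
/// ```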
#[derive(Debug)]
pub struct AsyncContext<'a> {
    /// The underlying context to run async operations on.
    context: nng::Context,

    /// AIO object for managing async operations on the socket.
    aio: nng::Aio,

    /// Receiving end of a bounded channel for getting results back from the AIO object.
    receiver: Receiver<nng::AioResult>,

    /// The socket that the context is bound to.
    ///
    /// This is pretty much entirely unused but is here to tie the lifetime of the context to the
    /// socket. This is particularly useful because `nng::Context` doesn't forward this lifetime,
    /// and as a result you can end up getting `nng::Error::Closed` on operations. By tying this
    /// lifetime here artificially (because the context is closed if the socket is closed, but not
    /// vice-versa), we can guarantee that contexts do not outlive their sockets using the borrow
    /// checker instead of fishing for `nng::Error::Closed`.
    ///
    /// See <https://doc.rust-lang.org/nomicon/phantom-data.html#table-of-phantomdata-patterns> for
    /// more info on variance of the phantom. A reference to the socket was intentionally not
    /// chosen to be held here because contexts do not generally appreciate it if one touches the
    /// socket (e.g. calls `nng::Socket::close`) while borrowed, and this is an easy way to enforce
    /// the lifetime bound artificially without providing the footgun that is a reference to the
    /// actual socket.
    _phantom: std::marker::PhantomData<&'a nng::Socket>,
}

impl<'a> TryFrom<&'a nng::Socket> for AsyncContext<'a> {
    type Error = nng::Error;

    #[instrument]
    fn try_from(socket: &'a nng::Socket) -> Result<Self, Self::Error> {
        let context = nng::Context::new(socket)?;
        let (sender, receiver) = bounded(1);
        let aio = nng::Aio::new(move |aio, result| {
            if let Err(err) = sender.try_send(result) {
                tracing::warn!(?err, "Failed to forward IO event from nng::Aio.");
                aio.cancel();
            }
        })?;

        Ok(Self {
            context,
            aio,
            receiver,
            _phantom: std::marker::PhantomData,
        })
    }
}

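/// Cancellation support used by `CancelGuard` in `AsyncContext::send` and
/// `AsyncContext::receive`.
///
/// When one of those futures is dropped before it completes, the guard is not forgotten and is
/// expected to invoke `cancel` on drop. This aborts the outstanding AIO operation and drains the
/// result channel, leaving the context ready for the next operation (see the
/// `context_send_is_cancel_safe` test).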
impl<'a> Cancel for AsyncContext<'a> {
    fn cancel(&mut self) {
        self.aio.cancel();
        // Blocking call which isn't great, but because we've cancelled above this will either:
        //
        // 1. Return NNG_ECANCELED very quickly
        // 2. Return the actual IO event
        //
        // Either way, the result isn't useful to us because it has been cancelled. One trade-off
        // of this approach is that we have to accept that we are "blocking", but realistically we
        // aren't expecting to block for long, because the call to cancel the AIO object should
        // not take excessively long aside from potential context switching from the OS.
        let _ = self.receiver.recv_blocking();
    }
}

impl<'a> AsyncContext<'a> {
    /// Sends a [`Message`](nng::Message) to the socket asynchronously.
    ///
    /// # Errors
    ///
    /// - `IncorrectState`: The internal `Aio` is already running an operation, or the socket
    ///   cannot send messages in its current state.
    /// - `MessageTooLarge`: The message is too large.
    /// - `NotSupported`: The protocol does not support sending messages.
    /// - `OutOfMemory`: Insufficient memory available.
    /// - `TimedOut`: The operation timed out.
    /// - `Closed`: The context or socket has been closed and future operations will not work.
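    ///
    /// # Example
    ///
    /// A minimal sketch of sending with a timeout (socket setup elided; the payload and timeout
    /// are arbitrary):
    ///
    /// ```no_run
    /// # async fn example(socket: &nng::Socket) -> Result<(), nng::Error> {
    /// use async_nng::AsyncContext;
    /// use std::{io::Write, time::Duration};
    ///
    /// let mut context = AsyncContext::try_from(socket)?;
    ///
    /// let mut msg = nng::Message::new();
    /// msg.write_all(b"ping").expect("All bytes written.");
    ///
    /// // On failure, the message is handed back alongside the error.
    /// if let Err((_msg, err)) = context.send(msg, Some(Duration::from_millis(500))).await {
    ///     return Err(err);
    /// }
    /// # Ok(())
    /// # }
    /// ```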
    #[instrument(skip(msg))]
    pub async fn send<M>(
        &mut self,
        msg: M,
        timeout: Option<Duration>,
    ) -> Result<(), (nng::Message, nng::Error)>
    where
        M: Into<nng::Message>,
    {
        if let Err(err) = self.aio.set_timeout(timeout) {
            return Err((msg.into(), err));
        }
        self.context.send(&self.aio, msg.into())?;
        let receiver = self.receiver.clone();

        let guard = CancelGuard::new(self);
        // Unwrap is safe here because we know that the sender cannot be dropped unless `self.aio`
        // is dropped.
        let res = match receiver.recv().await.unwrap() {
            nng::AioResult::Send(res) => res,
            _ => {
                tracing::warn!("Reached invalid state in AsyncContext::send.");
                unreachable!();
            }
        };
        // The operation completed, so defuse the cancel guard; dropping it would cancel the
        // already-finished AIO operation.
        std::mem::forget(guard);
        res
    }

    /// Receives a [`Message`](nng::Message) from the socket asynchronously.
    ///
    /// # Errors
    ///
    /// - `IncorrectState`: The internal `Aio` is already running an operation, or the socket
    ///   cannot receive messages in its current state.
    /// - `MessageTooLarge`: The received message is too large.
    /// - `NotSupported`: The protocol does not support receiving messages.
    /// - `OutOfMemory`: Insufficient memory available.
    /// - `TimedOut`: The operation timed out.
    /// - `Closed`: The context or socket has been closed and future operations will not work.
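    ///
    /// # Example
    ///
    /// A minimal sketch of receiving with a timeout (socket setup elided; the timeout is
    /// arbitrary):
    ///
    /// ```no_run
    /// # async fn example(socket: &nng::Socket) -> Result<(), nng::Error> {
    /// use async_nng::AsyncContext;
    /// use std::time::Duration;
    ///
    /// let mut context = AsyncContext::try_from(socket)?;
    ///
    /// // Wait up to one second for an incoming message.
    /// let msg = context.receive(Some(Duration::from_secs(1))).await?;
    /// println!("received {} bytes", msg.as_slice().len());
    /// # Ok(())
    /// # }
    /// ```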
    #[instrument]
    pub async fn receive(&mut self, timeout: Option<Duration>) -> Result<nng::Message, nng::Error> {
        self.aio.set_timeout(timeout)?;
        self.context.recv(&self.aio)?;

        let receiver = self.receiver.clone();
        let guard = CancelGuard::new(self);
        // Unwrap is safe here because we know that the sender cannot be dropped unless `self.aio`
        // is dropped.
        let res = match receiver.recv().await.unwrap() {
            nng::AioResult::Recv(res) => res,
            _ => {
                tracing::warn!("Reached invalid state in AsyncContext::receive.");
                unreachable!();
            }
        };
        // The operation completed, so defuse the cancel guard; dropping it would cancel the
        // already-finished AIO operation.
        std::mem::forget(guard);
        res
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use macro_rules_attribute::apply;
    use nng::options::Options;
    use std::{io::Write, time::Duration};
    use tracing_test::traced_test;

    /// Bytes to send on the request
    const PING_BYTES: &[u8] = b"ping";

    /// Bytes to send on the reply
    const PONG_BYTES: &[u8] = b"pong";

    /// Helper function for constructing a configured Req0/Rep0 socket pair.
    #[instrument]
    fn make_req_rep() -> (nng::Socket, nng::Socket) {
        let addr = nng::SocketAddr::InProc(random_string::generate(128, "abcdef1234567890"));
        tracing::info!(addr = addr.to_string(), "Local address");

        let rep = nng::Socket::new(nng::Protocol::Rep0).expect("Construct Rep0");
        let _listener = {
            let builder = nng::ListenerBuilder::new(&rep, addr.to_string().as_str())
                .expect("Listener create succeeds.");
            // 1MiB RECVMAXSZ
            builder
                .set_opt::<nng::options::RecvMaxSize>(1024 * 1024)
                .expect("Set opt on listener builder.");
            builder.start().expect("Start listener.")
        };

        let req = nng::Socket::new(nng::Protocol::Req0).expect("Construct Req0");
        let _dialer = {
            let builder = nng::DialerBuilder::new(&req, addr.to_string().as_str())
                .expect("Dial create succeeds.");
            // 1MiB RECVMAXSZ
            builder
                .set_opt::<nng::options::RecvMaxSize>(1024 * 1024)
                .expect("Set opt on dialer builder.");
            builder.start(false).expect("Start dialer.")
        };

        (req, rep)
    }

    #[traced_test]
    #[apply(smol_macros::test)]
    async fn ping_pong_with_req_rep() {
        let (req, rep) = make_req_rep();

        let mut req_ctx = AsyncContext::try_from(&req).expect("Make async context for req.");
        let mut rep_ctx = AsyncContext::try_from(&rep).expect("Make async context for rep.");

        let mut ping = nng::Message::new();
        ping.write_all(PING_BYTES).expect("All bytes written.");
        req_ctx
            .send(ping, None)
            .await
            .expect("Request should send successfully");

        let request = rep_ctx
            .receive(None)
            .await
            .expect("Request should be received successfully.");
        assert_eq!(PING_BYTES, request.as_slice());

        let mut pong = nng::Message::new();
        pong.write_all(PONG_BYTES).expect("All bytes written.");

        rep_ctx
            .send(pong, None)
            .await
            .expect("Reply should send successfully.");

        let response = req_ctx
            .receive(None)
            .await
            .expect("Response should be received successfully.");
        assert_eq!(PONG_BYTES, response.as_slice());
    }

    #[traced_test]
    #[apply(smol_macros::test)]
    async fn drop_rep_after_request_sent_fails() {
        let (req, rep) = make_req_rep();

        let mut req_ctx = AsyncContext::try_from(&req).expect("Make async context for req.");
        let rep_ctx = AsyncContext::try_from(&rep).expect("Make async context for rep.");

        let mut ping = nng::Message::new();
        ping.write_all(PING_BYTES).expect("All bytes written.");
        req_ctx
            .send(ping, None)
            .await
            .expect("Request should send successfully");

        std::mem::drop(rep_ctx);

        let err = req_ctx
            .receive(Some(Duration::from_millis(20)))
            .await
            .expect_err("Request should timeout because the response will never come.");

        assert_eq!(err, nng::Error::TimedOut);
    }

    #[traced_test]
    #[apply(smol_macros::test)]
    async fn context_send_is_cancel_safe() {
        let (req, rep) = make_req_rep();

        let mut req_ctx = AsyncContext::try_from(&req).expect("Make async context for req.");
        let mut rep_ctx = AsyncContext::try_from(&rep).expect("Make async context for rep.");

        let mut ping = nng::Message::new();
        ping.write_all(PING_BYTES).expect("All bytes written.");

        smol::future::or(smol::future::ready(Ok(())), req_ctx.send(ping, None))
            .await
            .expect("Immediate future should return Ok.");

        let err = rep_ctx
            .receive(Some(Duration::from_millis(20)))
            .await
            .expect_err("Request should timeout because nothing was sent.");

        assert_eq!(err, nng::Error::TimedOut);

        // Now send it for real. The state of the req socket should be good to send again.
        let mut ping = nng::Message::new();
        ping.write_all(PING_BYTES).expect("All bytes written.");

        req_ctx
            .send(ping, None)
            .await
            .expect("Request should send successfully");

        let request = rep_ctx
            .receive(None)
            .await
            .expect("Request should be received successfully.");
        assert_eq!(PING_BYTES, request.as_slice());
    }
}