async_dup/
lib.rs

1//! Duplicate an async I/O handle.
2//!
3//! This crate provides two tools, [`Arc`] and [`Mutex`]:
4//!
5//! * [`Arc`] implements [`AsyncRead`], [`AsyncWrite`], and [`AsyncSeek`] if a reference to the
6//!   inner type does.
//! * [`Mutex`] and a reference to it implement [`AsyncRead`], [`AsyncWrite`], and [`AsyncSeek`]
//!   if the inner type does and is [`Unpin`].
9//!
//! Wrap an async I/O handle in [`Arc`] or [`Mutex`] to clone it or share it among tasks.
11//!
12//! # Examples
13//!
14//! Clone an async I/O handle:
15//!
16//! ```no_run
17//! use async_dup::Arc;
18//! use futures::io;
19//! use smol::Async;
20//! use std::net::TcpStream;
21//!
22//! # fn main() -> std::io::Result<()> { smol::block_on(async {
23//! // A client that echoes messages back to the server.
24//! let stream = Async::<TcpStream>::connect(([127, 0, 0, 1], 8000)).await?;
25//!
26//! // Create two handles to the stream.
27//! let reader = Arc::new(stream);
28//! let mut writer = reader.clone();
29//!
30//! // Echo data received from the reader back into the writer.
31//! io::copy(reader, &mut writer).await?;
32//! # Ok(()) }) }
33//! ```
34//!
35//! Share an async I/O handle:
36//!
37//! ```
38//! use async_dup::Mutex;
39//! use futures::io;
40//! use futures::prelude::*;
41//!
42//! // Reads data from a stream and echoes it back.
43//! async fn echo(stream: impl AsyncRead + AsyncWrite + Unpin) -> io::Result<u64> {
44//!     let stream = Mutex::new(stream);
45//!     io::copy(&stream, &mut &stream).await
46//! }
47//! ```
48
49#![forbid(unsafe_code)]
50#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
51#![doc(
52    html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
53)]
54#![doc(
55    html_logo_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
56)]
57
58use std::fmt;
59use std::hash::{Hash, Hasher};
60use std::io::{self, IoSlice, IoSliceMut, SeekFrom};
61use std::ops::{Deref, DerefMut};
62use std::pin::Pin;
63use std::task::{Context, Poll};
64
65use futures_io::{AsyncRead, AsyncSeek, AsyncWrite};
66
/// A reference-counted pointer that implements async I/O traits.
///
/// `Arc<T>` is a thin newtype around [`std::sync::Arc`] (exposed as the public field `.0`).
/// On top of the usual smart-pointer impls it provides:
///
/// - `impl<T> AsyncRead for Arc<T> where &T: AsyncRead {}`
/// - `impl<T> AsyncWrite for Arc<T> where &T: AsyncWrite {}`
/// - `impl<T> AsyncSeek for Arc<T> where &T: AsyncSeek {}`
pub struct Arc<T>(pub std::sync::Arc<T>);

// The shared value lives behind the pointer, so moving the handle never moves `T`; the wrapper
// is therefore `Unpin` for every `T`, including types that are not `Unpin` themselves.
impl<T> Unpin for Arc<T> {}

impl<T> Arc<T> {
    /// Constructs a new `Arc<T>` that owns `data`.
    ///
    /// # Examples
    ///
    /// ```
    /// use async_dup::Arc;
    ///
    /// let a = Arc::new(7);
    /// ```
    pub fn new(data: T) -> Arc<T> {
        let shared = std::sync::Arc::new(data);
        Self(shared)
    }
}

impl<T> Clone for Arc<T> {
    /// Returns another handle to the same allocation (the reference count is incremented).
    fn clone(&self) -> Arc<T> {
        Self(std::sync::Arc::clone(&self.0))
    }
}

impl<T> Deref for Arc<T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

impl<T: fmt::Debug> fmt::Debug for Arc<T> {
    /// Formats the shared value, not the pointer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        fmt::Debug::fmt(value, f)
    }
}

impl<T: fmt::Display> fmt::Display for Arc<T> {
    /// Formats the shared value, not the pointer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        fmt::Display::fmt(value, f)
    }
}

impl<T: Hash> Hash for Arc<T> {
    /// Hashes the shared value, so equal values hash equally regardless of allocation.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let value: &T = self;
        Hash::hash(value, state);
    }
}

impl<T> fmt::Pointer for Arc<T> {
    /// Formats the address of the shared value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        let address = value as *const T;
        fmt::Pointer::fmt(&address, f)
    }
}

impl<T: Default> Default for Arc<T> {
    /// Wraps `T::default()` in a fresh allocation.
    fn default() -> Arc<T> {
        Self::new(T::default())
    }
}

impl<T> From<T> for Arc<T> {
    /// Moves `t` into a fresh allocation; equivalent to [`Arc::new`].
    fn from(t: T) -> Arc<T> {
        Self::new(t)
    }
}
143
144// NOTE(stjepang): It would also make sense to have the following impls:
145//
146// - `impl<T> AsyncRead for &Arc<T> where &T: AsyncRead {}`
147// - `impl<T> AsyncWrite for &Arc<T> where &T: AsyncWrite {}`
148// - `impl<T> AsyncSeek for &Arc<T> where &T: AsyncSeek {}`
149//
150// However, those impls sometimes make Rust's type inference try too hard when types cannot be
151// inferred. In the end, instead of complaining with a nice error message, the Rust compiler ends
152// up overflowing and dumping a very long error message spanning multiple screens.
153//
154// Since those impls are not essential, I decided to err on the safe side and not include them.
155
156impl<T> AsyncRead for Arc<T>
157where
158    for<'a> &'a T: AsyncRead,
159{
160    fn poll_read(
161        self: Pin<&mut Self>,
162        cx: &mut Context<'_>,
163        buf: &mut [u8],
164    ) -> Poll<io::Result<usize>> {
165        Pin::new(&mut &*self.0).poll_read(cx, buf)
166    }
167
168    fn poll_read_vectored(
169        self: Pin<&mut Self>,
170        cx: &mut Context<'_>,
171        bufs: &mut [IoSliceMut<'_>],
172    ) -> Poll<io::Result<usize>> {
173        Pin::new(&mut &*self.0).poll_read_vectored(cx, bufs)
174    }
175}
176
177impl<T> AsyncWrite for Arc<T>
178where
179    for<'a> &'a T: AsyncWrite,
180{
181    fn poll_write(
182        self: Pin<&mut Self>,
183        cx: &mut Context<'_>,
184        buf: &[u8],
185    ) -> Poll<io::Result<usize>> {
186        Pin::new(&mut &*self.0).poll_write(cx, buf)
187    }
188
189    fn poll_write_vectored(
190        self: Pin<&mut Self>,
191        cx: &mut Context<'_>,
192        bufs: &[IoSlice<'_>],
193    ) -> Poll<io::Result<usize>> {
194        Pin::new(&mut &*self.0).poll_write_vectored(cx, bufs)
195    }
196
197    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198        Pin::new(&mut &*self.0).poll_flush(cx)
199    }
200
201    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
202        Pin::new(&mut &*self.0).poll_close(cx)
203    }
204}
205
206impl<T> AsyncSeek for Arc<T>
207where
208    for<'a> &'a T: AsyncSeek,
209{
210    fn poll_seek(
211        self: Pin<&mut Self>,
212        cx: &mut Context<'_>,
213        pos: SeekFrom,
214    ) -> Poll<io::Result<u64>> {
215        Pin::new(&mut &*self.0).poll_seek(cx, pos)
216    }
217}
218
219/// A mutex that implements async I/O traits.
220///
221/// This is a blocking mutex that adds the following impls:
222///
223/// - `impl<T> AsyncRead for Mutex<T> where T: AsyncRead + Unpin {}`
224/// - `impl<T> AsyncRead for &Mutex<T> where T: AsyncRead + Unpin {}`
225/// - `impl<T> AsyncWrite for Mutex<T> where T: AsyncWrite + Unpin {}`
226/// - `impl<T> AsyncWrite for &Mutex<T> where T: AsyncWrite + Unpin {}`
227/// - `impl<T> AsyncSeek for Mutex<T> where T: AsyncSeek + Unpin {}`
228/// - `impl<T> AsyncSeek for &Mutex<T> where T: AsyncSeek + Unpin {}`
229pub struct Mutex<T>(async_lock::Mutex<T>);
230
231impl<T> Mutex<T> {
232    /// Creates a new mutex.
233    ///
234    /// # Examples
235    ///
236    /// ```
237    /// use async_dup::Mutex;
238    ///
239    /// let mutex = Mutex::new(10);
240    /// ```
241    pub fn new(data: T) -> Mutex<T> {
242        Mutex(data.into())
243    }
244
245    /// Acquires the mutex, blocking the current thread until it is able to do so.
246    ///
247    /// Returns a guard that releases the mutex when dropped.
248    ///
249    /// # Examples
250    ///
251    /// ```
252    /// use async_dup::Mutex;
253    ///
254    /// let mutex = Mutex::new(10);
255    /// let guard = mutex.lock();
256    /// assert_eq!(*guard, 10);
257    /// ```
258    pub fn lock(&self) -> MutexGuard<'_, T> {
259        MutexGuard(self.0.lock_blocking())
260    }
261
262    /// Attempts to acquire the mutex.
263    ///
264    /// If the mutex could not be acquired at this time, then [`None`] is returned. Otherwise, a
265    /// guard is returned that releases the mutex when dropped.
266    ///
267    /// [`None`]: https://doc.rust-lang.org/std/option/enum.Option.html#variant.None
268    ///
269    /// # Examples
270    ///
271    /// ```
272    /// use async_dup::Mutex;
273    ///
274    /// let mutex = Mutex::new(10);
275    /// if let Some(guard) = mutex.try_lock() {
276    ///     assert_eq!(*guard, 10);
277    /// }
278    /// # ;
279    /// ```
280    pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
281        self.0.try_lock().map(MutexGuard)
282    }
283
284    /// Consumes the mutex, returning the underlying data.
285    ///
286    /// # Examples
287    ///
288    /// ```
289    /// use async_dup::Mutex;
290    ///
291    /// let mutex = Mutex::new(10);
292    /// assert_eq!(mutex.into_inner(), 10);
293    /// ```
294    pub fn into_inner(self) -> T {
295        self.0.into_inner()
296    }
297
298    /// Returns a mutable reference to the underlying data.
299    ///
300    /// Since this call borrows the mutex mutably, no actual locking takes place -- the mutable
301    /// borrow statically guarantees the mutex is not already acquired.
302    ///
303    /// # Examples
304    ///
305    /// ```
306    /// use async_dup::Mutex;
307    ///
308    /// let mut mutex = Mutex::new(0);
309    /// *mutex.get_mut() = 10;
310    /// assert_eq!(*mutex.lock(), 10);
311    /// ```
312    pub fn get_mut(&mut self) -> &mut T {
313        self.0.get_mut()
314    }
315}
316
317impl<T: fmt::Debug> fmt::Debug for Mutex<T> {
318    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319        struct Locked;
320        impl fmt::Debug for Locked {
321            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322                f.write_str("<locked>")
323            }
324        }
325
326        match self.try_lock() {
327            None => f.debug_struct("Mutex").field("data", &Locked).finish(),
328            Some(guard) => f.debug_struct("Mutex").field("data", &&*guard).finish(),
329        }
330    }
331}
332
333impl<T> From<T> for Mutex<T> {
334    fn from(val: T) -> Mutex<T> {
335        Mutex::new(val)
336    }
337}
338
339impl<T: Default> Default for Mutex<T> {
340    fn default() -> Mutex<T> {
341        Mutex::new(Default::default())
342    }
343}
344
345impl<T: AsyncRead + Unpin> AsyncRead for Mutex<T> {
346    fn poll_read(
347        self: Pin<&mut Self>,
348        cx: &mut Context<'_>,
349        buf: &mut [u8],
350    ) -> Poll<io::Result<usize>> {
351        Pin::new(&mut *self.lock()).poll_read(cx, buf)
352    }
353
354    fn poll_read_vectored(
355        self: Pin<&mut Self>,
356        cx: &mut Context<'_>,
357        bufs: &mut [IoSliceMut<'_>],
358    ) -> Poll<io::Result<usize>> {
359        Pin::new(&mut *self.lock()).poll_read_vectored(cx, bufs)
360    }
361}
362
363impl<T: AsyncRead + Unpin> AsyncRead for &Mutex<T> {
364    fn poll_read(
365        self: Pin<&mut Self>,
366        cx: &mut Context<'_>,
367        buf: &mut [u8],
368    ) -> Poll<io::Result<usize>> {
369        Pin::new(&mut *self.lock()).poll_read(cx, buf)
370    }
371
372    fn poll_read_vectored(
373        self: Pin<&mut Self>,
374        cx: &mut Context<'_>,
375        bufs: &mut [IoSliceMut<'_>],
376    ) -> Poll<io::Result<usize>> {
377        Pin::new(&mut *self.lock()).poll_read_vectored(cx, bufs)
378    }
379}
380
381impl<T: AsyncWrite + Unpin> AsyncWrite for Mutex<T> {
382    fn poll_write(
383        self: Pin<&mut Self>,
384        cx: &mut Context<'_>,
385        buf: &[u8],
386    ) -> Poll<io::Result<usize>> {
387        Pin::new(&mut *self.lock()).poll_write(cx, buf)
388    }
389
390    fn poll_write_vectored(
391        self: Pin<&mut Self>,
392        cx: &mut Context<'_>,
393        bufs: &[IoSlice<'_>],
394    ) -> Poll<io::Result<usize>> {
395        Pin::new(&mut *self.lock()).poll_write_vectored(cx, bufs)
396    }
397
398    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
399        Pin::new(&mut *self.lock()).poll_flush(cx)
400    }
401
402    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
403        Pin::new(&mut *self.lock()).poll_close(cx)
404    }
405}
406
407impl<T: AsyncWrite + Unpin> AsyncWrite for &Mutex<T> {
408    fn poll_write(
409        self: Pin<&mut Self>,
410        cx: &mut Context<'_>,
411        buf: &[u8],
412    ) -> Poll<io::Result<usize>> {
413        Pin::new(&mut *self.lock()).poll_write(cx, buf)
414    }
415
416    fn poll_write_vectored(
417        self: Pin<&mut Self>,
418        cx: &mut Context<'_>,
419        bufs: &[IoSlice<'_>],
420    ) -> Poll<io::Result<usize>> {
421        Pin::new(&mut *self.lock()).poll_write_vectored(cx, bufs)
422    }
423
424    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
425        Pin::new(&mut *self.lock()).poll_flush(cx)
426    }
427
428    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
429        Pin::new(&mut *self.lock()).poll_close(cx)
430    }
431}
432
433impl<T: AsyncSeek + Unpin> AsyncSeek for Mutex<T> {
434    fn poll_seek(
435        self: Pin<&mut Self>,
436        cx: &mut Context<'_>,
437        pos: SeekFrom,
438    ) -> Poll<io::Result<u64>> {
439        Pin::new(&mut *self.lock()).poll_seek(cx, pos)
440    }
441}
442
443impl<T: AsyncSeek + Unpin> AsyncSeek for &Mutex<T> {
444    fn poll_seek(
445        self: Pin<&mut Self>,
446        cx: &mut Context<'_>,
447        pos: SeekFrom,
448    ) -> Poll<io::Result<u64>> {
449        Pin::new(&mut *self.lock()).poll_seek(cx, pos)
450    }
451}
452
453/// A guard that releases the mutex when dropped.
454pub struct MutexGuard<'a, T>(async_lock::MutexGuard<'a, T>);
455
456impl<T: fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
457    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458        fmt::Debug::fmt(&**self, f)
459    }
460}
461
462impl<T: fmt::Display> fmt::Display for MutexGuard<'_, T> {
463    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
464        (**self).fmt(f)
465    }
466}
467
468impl<T> Deref for MutexGuard<'_, T> {
469    type Target = T;
470
471    fn deref(&self) -> &T {
472        &self.0
473    }
474}
475
476impl<T> DerefMut for MutexGuard<'_, T> {
477    fn deref_mut(&mut self) -> &mut T {
478        &mut self.0
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time checks: these only type-check when `T` has the named auto trait.
    fn assert_send<T: Send>(_: &T) {}
    fn assert_sync<T: Sync>(_: &T) {}

    #[test]
    fn is_send_sync() {
        let shared = Arc::new(());
        assert_send(&shared);
        assert_sync(&shared);

        let mutex = Mutex::new(());
        assert_send(&mutex);
        assert_sync(&mutex);

        let guard = mutex.lock();
        assert_send(&guard);
        assert_sync(&guard);
    }
}