async_hal/io/
async_write.rs

1use core::{
2    ops::DerefMut,
3    pin::Pin,
4    task::{Context, Poll},
5};
6use void::Void;
7
8use super::{write_all::write_all, WriteAll};
9
10pub trait AsyncWrite {
11    type Error;
12
13    fn poll_write(
14        self: Pin<&mut Self>,
15        cx: &mut Context,
16        buf: &[u8],
17    ) -> Poll<Result<usize, Self::Error>>;
18
19    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>>;
20
21    /// Attempts to write an entire buffer into this writer.
22    ///
23    /// Equivalent to:
24    ///
25    /// ```ignore
26    /// async fn write_all(&mut self, buf: &[u8]) -> io::Result<()>;
27    /// ```
28    ///
29    /// This method will continuously call [`write`] until there is no more data
30    /// to be written. This method will not return until the entire buffer
31    /// has been successfully written or such an error occurs. The first
32    /// error generated from this method will be returned.
33    ///
34    /// # Cancel safety
35    ///
36    /// This method is not cancellation safe. If it is used as the event
37    /// in a [`tokio::select!`](crate::select) statement and some other
38    /// branch completes first, then the provided buffer may have been
39    /// partially written, but future calls to `write_all` will start over
40    /// from the beginning of the buffer.
41    ///
42    /// # Errors
43    ///
44    /// This function will return the first error that [`write`] returns.
45    /// [`write`]: AsyncWrite::write
46    fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> WriteAll<'a, Self>
47    where
48        Self: Unpin,
49    {
50        write_all(self, buf)
51    }
52}
53
54impl<T: ?Sized + AsyncWrite + Unpin> AsyncWrite for &mut T {
55    type Error = T::Error;
56
57    fn poll_write(
58        mut self: Pin<&mut Self>,
59        cx: &mut Context,
60        buf: &[u8],
61    ) -> Poll<Result<usize, Self::Error>> {
62        Pin::new(&mut **self).poll_write(cx, buf)
63    }
64
65    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
66        Pin::new(&mut **self).poll_flush(cx)
67    }
68}
69
70impl<P> AsyncWrite for Pin<P>
71where
72    P: DerefMut + Unpin,
73    P::Target: AsyncWrite,
74{
75    type Error = <P::Target as AsyncWrite>::Error;
76
77    fn poll_write(
78        self: Pin<&mut Self>,
79        cx: &mut Context<'_>,
80        buf: &[u8],
81    ) -> Poll<Result<usize, Self::Error>> {
82        self.get_mut().as_mut().poll_write(cx, buf)
83    }
84
85    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        self.get_mut().as_mut().poll_flush(cx)
87    }
88}
89
90impl AsyncWrite for &'_ mut [u8] {
91    type Error = Void;
92
93    fn poll_write(
94        mut self: Pin<&mut Self>,
95        _cx: &mut Context,
96        buf: &[u8],
97    ) -> Poll<Result<usize, Self::Error>> {
98        let amt = core::cmp::min(buf.len(), self.len());
99        let (a, b) = core::mem::replace(&mut *self, &mut []).split_at_mut(amt);
100        a.copy_from_slice(&buf[..amt]);
101        *self = b;
102        Poll::Ready(Ok(amt))
103    }
104
105    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
106        Poll::Ready(Ok(()))
107    }
108}