// async_io_typed/write.rs

1use crate::{
2    ChecksumEnabled, Error, CHECKSUM_DISABLED, CHECKSUM_ENABLED, PROTOCOL_VERSION, U16_MARKER,
3    U32_MARKER, U64_MARKER, ZST_MARKER,
4};
5use bincode::Options;
6use futures_io::AsyncWrite;
7use futures_util::{Sink, SinkExt};
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use siphasher::sip::SipHasher;
11use std::collections::VecDeque;
12use std::hash::Hasher;
13use std::mem::size_of;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17/// Provides the ability to write `serde` compatible types to any type that implements `futures::io::AsyncWrite`.
18#[derive(Debug)]
19pub struct AsyncWriteTyped<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> {
20    raw: Option<W>,
21    write_buffer: Vec<u8>,
22    state: AsyncWriteState,
23    primed_values: VecDeque<T>,
24    message_features: MessageFeatures,
25}
26
/// Where an `AsyncWriteTyped` is in its write protocol.
///
/// A new writer starts in `WritingVersion` and then moves to `WritingChecksumEnabled`.
/// After that it alternates between `Idle` and `WritingValue`. Closing moves it through
/// `Closing` to `Closed`.
#[derive(Debug)]
pub(crate) enum AsyncWriteState {
    /// Writing the 8 version bytes; `len_sent` of them have been accepted so far.
    WritingVersion {
        version: [u8; 8],
        len_sent: usize,
    },
    /// Writing the single byte that announces whether checksums are enabled.
    WritingChecksumEnabled,
    /// Nothing in flight; the next queued value is encoded from here.
    Idle,
    /// Writing the encoded value held in the write buffer; `bytes_sent` bytes are out.
    WritingValue {
        bytes_sent: usize,
    },
    /// Writing the single `0` byte that is sent when the stream is closed.
    Closing,
    /// The closing byte has been written; no further bytes will be sent.
    Closed,
}
36
/// Per-stream settings that control how each value is encoded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct MessageFeatures {
    /// Largest serialized payload size, in bytes, that may be sent.
    pub size_limit: u64,
    /// Whether a SipHash checksum of the payload is appended to every value.
    pub checksum_enabled: bool,
}
42
43impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Sink<T>
44    for AsyncWriteTyped<W, T>
45{
46    type Error = Error;
47
48    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
49        Poll::Ready(Ok(()))
50    }
51
52    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
53        self.primed_values.push_front(item);
54        Ok(())
55    }
56
57    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58        let Self {
59            ref mut raw,
60            ref mut write_buffer,
61            ref mut state,
62            ref mut primed_values,
63            ref message_features,
64        } = *self.as_mut();
65        match futures_core::ready!(Self::maybe_send(
66            raw.as_mut().expect("infallible"),
67            state,
68            write_buffer,
69            primed_values,
70            *message_features,
71            cx,
72            false,
73        ))? {
74            Some(()) => {
75                // Send successful, poll_flush now
76                Pin::new(raw.as_mut().expect("infallible"))
77                    .poll_flush(cx)
78                    .map(|r| r.map_err(Error::Io))
79            }
80            None => Poll::Ready(Ok(())),
81        }
82    }
83
84    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
85        let Self {
86            ref mut raw,
87            ref mut state,
88            ref mut write_buffer,
89            ref mut primed_values,
90            ref message_features,
91        } = *self.as_mut();
92        match futures_core::ready!(Self::maybe_send(
93            raw.as_mut().expect("infallible"),
94            state,
95            write_buffer,
96            primed_values,
97            *message_features,
98            cx,
99            true,
100        ))? {
101            Some(()) => {
102                // Send successful, poll_close now
103                Pin::new(raw.as_mut().expect("infallible"))
104                    .poll_close(cx)
105                    .map(|r| r.map_err(Error::Io))
106            }
107            None => Poll::Ready(Ok(())),
108        }
109    }
110}
111
112impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteTyped<W, T> {
113    pub(crate) fn maybe_send(
114        raw: &mut W,
115        state: &mut AsyncWriteState,
116        write_buffer: &mut Vec<u8>,
117        primed_values: &mut VecDeque<T>,
118        message_features: MessageFeatures,
119        cx: &mut Context<'_>,
120        closing: bool,
121    ) -> Poll<Result<Option<()>, Error>> {
122        let MessageFeatures {
123            checksum_enabled,
124            size_limit,
125        } = message_features;
126        loop {
127            return match state {
128                AsyncWriteState::WritingVersion { version, len_sent } => {
129                    while *len_sent < size_of::<u64>() {
130                        let len = futures_core::ready!(
131                            Pin::new(&mut *raw).poll_write(cx, &version[(*len_sent)..])
132                        )?;
133                        *len_sent += len;
134                    }
135                    *state = AsyncWriteState::WritingChecksumEnabled;
136                    continue;
137                }
138                AsyncWriteState::WritingChecksumEnabled => {
139                    let to_send = if checksum_enabled {
140                        CHECKSUM_ENABLED
141                    } else {
142                        CHECKSUM_DISABLED
143                    };
144                    if futures_core::ready!(Pin::new(&mut *raw).poll_write(cx, &[to_send]))? == 1 {
145                        *state = AsyncWriteState::Idle;
146                    }
147                    continue;
148                }
149                AsyncWriteState::Idle => {
150                    if let Some(item) = primed_values.pop_back() {
151                        write_buffer.clear();
152                        let length = crate::bincode_options(size_limit)
153                            .serialized_size(&item)
154                            .map_err(Error::Bincode)?;
155                        if length > size_limit {
156                            return Poll::Ready(Err(Error::SentMessageTooLarge));
157                        }
158                        if length == 0 {
159                            write_buffer.push(ZST_MARKER);
160                        } else if length < U16_MARKER as u64 {
161                            write_buffer.extend((length as u8).to_le_bytes());
162                        } else if length < 2_u64.pow(16) {
163                            write_buffer.push(U16_MARKER);
164                            write_buffer.extend((length as u16).to_le_bytes());
165                        } else if length < 2_u64.pow(32) {
166                            write_buffer.push(U32_MARKER);
167                            write_buffer.extend((length as u32).to_le_bytes());
168                        } else {
169                            write_buffer.push(U64_MARKER);
170                            write_buffer.extend(length.to_le_bytes());
171                        }
172                        // Save the length... of the length value.
173                        let length_length = write_buffer.len();
174                        crate::bincode_options(size_limit)
175                            .serialize_into(&mut *write_buffer, &item)
176                            .map_err(Error::Bincode)?;
177                        if checksum_enabled {
178                            let mut hasher = SipHasher::new();
179                            hasher.write(&write_buffer[length_length..]);
180                            let checksum = hasher.finish();
181                            write_buffer.extend(checksum.to_le_bytes());
182                        }
183                        *state = AsyncWriteState::WritingValue { bytes_sent: 0 };
184                        continue;
185                    } else if closing {
186                        *state = AsyncWriteState::Closing;
187                        continue;
188                    } else {
189                        Poll::Ready(Ok(Some(())))
190                    }
191                }
192                AsyncWriteState::WritingValue { bytes_sent } => {
193                    while *bytes_sent < write_buffer.len() {
194                        let len = futures_core::ready!(
195                            Pin::new(&mut *raw).poll_write(cx, &write_buffer[*bytes_sent..])
196                        )?;
197                        *bytes_sent += len;
198                    }
199                    *state = AsyncWriteState::Idle;
200                    if primed_values.is_empty() {
201                        return Poll::Ready(Ok(Some(())));
202                    }
203                    continue;
204                }
205                AsyncWriteState::Closing => {
206                    let len = futures_core::ready!(Pin::new(&mut *raw).poll_write(cx, &[0]))?;
207                    if len == 1 {
208                        *state = AsyncWriteState::Closed;
209                        Poll::Ready(Ok(Some(())))
210                    } else {
211                        continue;
212                    }
213                }
214                AsyncWriteState::Closed => Poll::Ready(Ok(None)),
215            };
216        }
217    }
218}
219
220impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteTyped<W, T> {
221    /// Creates a typed writer, initializing it with the given size limit specified in bytes.
222    /// Checksums are used to validate that messages arrived without corruption. **The checksum will only be used
223    /// if both the reader and the writer enable it. If either one disables it, then no checking is performed.**
224    ///
225    /// Be careful, large size limits might create a vulnerability to a Denial of Service attack.
226    pub fn new_with_limit(raw: W, size_limit: u64, checksum_enabled: ChecksumEnabled) -> Self {
227        Self {
228            raw: Some(raw),
229            write_buffer: Vec::new(),
230            state: AsyncWriteState::WritingVersion {
231                version: PROTOCOL_VERSION.to_le_bytes(),
232                len_sent: 0,
233            },
234            message_features: MessageFeatures {
235                size_limit,
236                checksum_enabled: checksum_enabled.into(),
237            },
238            primed_values: VecDeque::new(),
239        }
240    }
241
242    /// Creates a typed writer, initializing it with a default size limit of 1 MB per message.
243    /// Checksums are used to validate that messages arrived without corruption. **The checksum will only be used
244    /// if both the reader and the writer enable it. If either one disables it, then no checking is performed.**
245    pub fn new(raw: W, checksum_enabled: ChecksumEnabled) -> Self {
246        Self::new_with_limit(raw, 1024u64.pow(2), checksum_enabled)
247    }
248
249    /// Returns a reference to the raw I/O primitive that this type is using.
250    pub fn inner(&self) -> &W {
251        self.raw.as_ref().expect("infallible")
252    }
253
254    /// Consumes this `AsyncWriteTyped` and returns the raw I/O primitive that was being used.
255    pub fn into_inner(mut self) -> W {
256        self.raw.take().expect("infallible")
257    }
258
259    /// `AsyncWriteTyped` keeps a memory buffer for sending values which is the same size as the largest
260    /// message that's been sent. If the message size varies a lot, you might find yourself wasting
261    /// memory space. This function will reduce the memory usage as much as is possible without impeding
262    /// functioning. Overuse of this function may cause excessive memory allocations when the buffer
263    /// needs to grow.
264    pub fn optimize_memory_usage(&mut self) {
265        match self.state {
266            AsyncWriteState::WritingValue { .. } => self.write_buffer.shrink_to_fit(),
267            _ => {
268                self.write_buffer = Vec::new();
269            }
270        }
271    }
272
273    /// Reports the size of the memory buffer used for sending values. You can shrink this buffer as much as
274    /// possible with [`Self::optimize_memory_usage`].
275    pub fn current_memory_usage(&self) -> usize {
276        self.write_buffer.capacity()
277    }
278
279    /// Returns true if checksums are enabled for this channel. This does not guarantee that the reader is
280    /// actually using those checksum values, it only reflects whether checksums are being sent.
281    pub fn checksum_enabled(&self) -> bool {
282        self.message_features.checksum_enabled
283    }
284}
285
286impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Drop
287    for AsyncWriteTyped<W, T>
288{
289    fn drop(&mut self) {
290        // This will panic if raw was already taken.
291        if self.raw.is_some() {
292            let _ = futures_executor::block_on(SinkExt::close(self));
293        }
294    }
295}