async_io_typed/
duplex.rs

1use crate::read::{AsyncReadState, AsyncReadTyped, ChecksumReadState};
2use crate::write::{AsyncWriteState, AsyncWriteTyped, MessageFeatures};
3use crate::{ChecksumEnabled, Error, PROTOCOL_VERSION};
4use futures_core::Stream;
5use futures_io::{AsyncRead, AsyncWrite};
6use futures_util::{Sink, SinkExt};
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9use std::collections::VecDeque;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13/// A duplex async connection for sending and receiving messages of a particular type.
14#[derive(Debug)]
15pub struct DuplexStreamTyped<
16    RW: AsyncRead + AsyncWrite + Unpin,
17    T: Serialize + DeserializeOwned + Unpin,
18> {
19    rw: Option<RW>,
20    read_state: AsyncReadState,
21    read_buffer: Vec<u8>,
22    write_state: AsyncWriteState,
23    write_buffer: Vec<u8>,
24    primed_values: VecDeque<T>,
25    checksum_read_state: ChecksumReadState,
26    message_features: MessageFeatures,
27}
28
29impl<RW: AsyncRead + AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin>
30    DuplexStreamTyped<RW, T>
31{
32    /// Creates a duplex typed reader and writer, initializing it with the given size limit specified in bytes.
33    /// Checksums are used to validate that messages arrived without corruption. **The checksum will only be used
34    /// if both the reader and the writer enable it. If either one disables it, then no checking is performed.**
35    ///
36    /// Be careful, large size limits might create a vulnerability to a Denial of Service attack.
37    pub fn new_with_limit(rw: RW, size_limit: u64, checksum_enabled: ChecksumEnabled) -> Self {
38        Self {
39            rw: Some(rw),
40            read_state: AsyncReadState::ReadingVersion {
41                version_in_progress: [0; 8],
42                version_in_progress_assigned: 0,
43            },
44            read_buffer: Vec::new(),
45            write_state: AsyncWriteState::WritingVersion {
46                version: PROTOCOL_VERSION.to_le_bytes(),
47                len_sent: 0,
48            },
49            write_buffer: Vec::new(),
50            primed_values: VecDeque::new(),
51            checksum_read_state: checksum_enabled.into(),
52            message_features: MessageFeatures {
53                size_limit,
54                checksum_enabled: checksum_enabled.into(),
55            },
56        }
57    }
58
59    /// Creates a duplex typed reader and writer, initializing it with a default size limit of 1 MB per message.
60    /// Checksums are used to validate that messages arrived without corruption. **The checksum will only be used
61    /// if both the reader and the writer enable it. If either one disables it, then no checking is performed.**
62    pub fn new(rw: RW, checksum_enabled: ChecksumEnabled) -> Self {
63        Self::new_with_limit(rw, 1024_u64.pow(2), checksum_enabled)
64    }
65
66    /// Returns a reference to the raw I/O primitive that this type is using.
67    pub fn inner(&self) -> &RW {
68        self.rw.as_ref().expect("infallible")
69    }
70
71    /// Consumes this `DuplexStreamTyped` and returns the raw I/O primitive that was being used.
72    pub fn into_inner(mut self) -> RW {
73        self.rw.take().expect("infallible")
74    }
75
76    /// `DuplexStreamTyped` keeps memory buffers for sending and receiving values which are the same size as the largest
77    /// message that's been sent or received. If the message size varies a lot, you might find yourself wasting
78    /// memory space. This function will reduce the memory usage as much as is possible without impeding
79    /// functioning. Overuse of this function may cause excessive memory allocations when the buffer
80    /// needs to grow.
81    pub fn optimize_memory_usage(&mut self) {
82        match self.read_state {
83            AsyncReadState::ReadingItem { .. } => self.read_buffer.shrink_to_fit(),
84            _ => {
85                self.read_buffer = Vec::new();
86            }
87        }
88        match self.write_state {
89            AsyncWriteState::WritingValue { .. } => self.write_buffer.shrink_to_fit(),
90            _ => {
91                self.write_buffer = Vec::new();
92            }
93        }
94    }
95
96    /// Reports the size of the memory buffers used for sending and receiving values. You can shrink these buffers as much as
97    /// possible with [`Self::optimize_memory_usage`].
98    pub fn current_memory_usage(&self) -> usize {
99        self.write_buffer.capacity() + self.read_buffer.capacity()
100    }
101
102    /// Returns true if checksums are enabled for this channel. This does not guarantee that the reader is
103    /// actually using those checksum values, it only reflects whether checksums are being sent.
104    pub fn checksum_send_enabled(&self) -> bool {
105        self.message_features.checksum_enabled
106    }
107
108    /// Returns true if checksums are enabled for this channel. This may become false after receiving the first value.
109    /// If that happens, the writer may have disabled checksums, so there is no checksum for the reader to check.
110    pub fn checksum_receive_enabled(&self) -> bool {
111        self.checksum_read_state == ChecksumReadState::Yes
112    }
113}
114
115impl<RW: AsyncRead + AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Stream
116    for DuplexStreamTyped<RW, T>
117{
118    type Item = Result<T, Error>;
119
120    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
121        let Self {
122            ref mut rw,
123            ref mut read_state,
124            ref mut read_buffer,
125            ref message_features,
126            ref mut checksum_read_state,
127            ..
128        } = *self.as_mut();
129        AsyncReadTyped::poll_next_impl(
130            read_state,
131            rw.as_mut().expect("infallible"),
132            read_buffer,
133            message_features.size_limit,
134            checksum_read_state,
135            cx,
136        )
137    }
138}
139
140impl<RW: AsyncRead + AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Sink<T>
141    for DuplexStreamTyped<RW, T>
142{
143    type Error = Error;
144
145    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146        Poll::Ready(Ok(()))
147    }
148
149    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
150        self.primed_values.push_front(item);
151        Ok(())
152    }
153
154    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
155        let Self {
156            ref mut rw,
157            ref mut write_state,
158            ref mut write_buffer,
159            ref mut primed_values,
160            ref message_features,
161            ..
162        } = *self.as_mut();
163        let rw = rw.as_mut().expect("infallible");
164        match futures_core::ready!(AsyncWriteTyped::maybe_send(
165            rw,
166            write_state,
167            write_buffer,
168            primed_values,
169            *message_features,
170            cx,
171            false,
172        ))? {
173            Some(()) => {
174                // Send successful, poll_flush now
175                Pin::new(rw).poll_flush(cx).map(|r| r.map_err(Error::Io))
176            }
177            None => Poll::Ready(Ok(())),
178        }
179    }
180
181    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
182        let Self {
183            ref mut rw,
184            ref mut write_state,
185            ref mut write_buffer,
186            ref mut primed_values,
187            ref message_features,
188            ..
189        } = *self.as_mut();
190        let rw = rw.as_mut().expect("infallible");
191        match futures_core::ready!(AsyncWriteTyped::maybe_send(
192            rw,
193            write_state,
194            write_buffer,
195            primed_values,
196            *message_features,
197            cx,
198            true,
199        ))? {
200            Some(()) => {
201                // Send successful, poll_close now
202                Pin::new(rw).poll_close(cx).map(|r| r.map_err(Error::Io))
203            }
204            None => Poll::Ready(Ok(())),
205        }
206    }
207}
208
209impl<RW: AsyncRead + AsyncWrite + Unpin, T: Serialize + Unpin + DeserializeOwned> Drop
210    for DuplexStreamTyped<RW, T>
211{
212    fn drop(&mut self) {
213        if self.rw.is_some() {
214            let _ = futures_executor::block_on(SinkExt::close(self));
215        }
216    }
217}