1use crate::read::{AsyncReadState, AsyncReadTyped, ChecksumReadState};
2use crate::write::{AsyncWriteState, AsyncWriteTyped, MessageFeatures};
3use crate::{ChecksumEnabled, Error, PROTOCOL_VERSION};
4use futures_core::Stream;
5use futures_io::{AsyncRead, AsyncWrite};
6use futures_util::{Sink, SinkExt};
7use serde::de::DeserializeOwned;
8use serde::Serialize;
9use std::collections::VecDeque;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13#[derive(Debug)]
15pub struct DuplexStreamTyped<
16 RW: AsyncRead + AsyncWrite + Unpin,
17 T: Serialize + DeserializeOwned + Unpin,
18> {
19 rw: Option<RW>,
20 read_state: AsyncReadState,
21 read_buffer: Vec<u8>,
22 write_state: AsyncWriteState,
23 write_buffer: Vec<u8>,
24 primed_values: VecDeque<T>,
25 checksum_read_state: ChecksumReadState,
26 message_features: MessageFeatures,
27}
28
29impl<RW: AsyncRead + AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin>
30 DuplexStreamTyped<RW, T>
31{
32 pub fn new_with_limit(rw: RW, size_limit: u64, checksum_enabled: ChecksumEnabled) -> Self {
38 Self {
39 rw: Some(rw),
40 read_state: AsyncReadState::ReadingVersion {
41 version_in_progress: [0; 8],
42 version_in_progress_assigned: 0,
43 },
44 read_buffer: Vec::new(),
45 write_state: AsyncWriteState::WritingVersion {
46 version: PROTOCOL_VERSION.to_le_bytes(),
47 len_sent: 0,
48 },
49 write_buffer: Vec::new(),
50 primed_values: VecDeque::new(),
51 checksum_read_state: checksum_enabled.into(),
52 message_features: MessageFeatures {
53 size_limit,
54 checksum_enabled: checksum_enabled.into(),
55 },
56 }
57 }
58
59 pub fn new(rw: RW, checksum_enabled: ChecksumEnabled) -> Self {
63 Self::new_with_limit(rw, 1024_u64.pow(2), checksum_enabled)
64 }
65
66 pub fn inner(&self) -> &RW {
68 self.rw.as_ref().expect("infallible")
69 }
70
71 pub fn into_inner(mut self) -> RW {
73 self.rw.take().expect("infallible")
74 }
75
76 pub fn optimize_memory_usage(&mut self) {
82 match self.read_state {
83 AsyncReadState::ReadingItem { .. } => self.read_buffer.shrink_to_fit(),
84 _ => {
85 self.read_buffer = Vec::new();
86 }
87 }
88 match self.write_state {
89 AsyncWriteState::WritingValue { .. } => self.write_buffer.shrink_to_fit(),
90 _ => {
91 self.write_buffer = Vec::new();
92 }
93 }
94 }
95
96 pub fn current_memory_usage(&self) -> usize {
99 self.write_buffer.capacity() + self.read_buffer.capacity()
100 }
101
102 pub fn checksum_send_enabled(&self) -> bool {
105 self.message_features.checksum_enabled
106 }
107
108 pub fn checksum_receive_enabled(&self) -> bool {
111 self.checksum_read_state == ChecksumReadState::Yes
112 }
113}
114
115impl<RW: AsyncRead + AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Stream
116 for DuplexStreamTyped<RW, T>
117{
118 type Item = Result<T, Error>;
119
120 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
121 let Self {
122 ref mut rw,
123 ref mut read_state,
124 ref mut read_buffer,
125 ref message_features,
126 ref mut checksum_read_state,
127 ..
128 } = *self.as_mut();
129 AsyncReadTyped::poll_next_impl(
130 read_state,
131 rw.as_mut().expect("infallible"),
132 read_buffer,
133 message_features.size_limit,
134 checksum_read_state,
135 cx,
136 )
137 }
138}
139
140impl<RW: AsyncRead + AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Sink<T>
141 for DuplexStreamTyped<RW, T>
142{
143 type Error = Error;
144
145 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146 Poll::Ready(Ok(()))
147 }
148
149 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
150 self.primed_values.push_front(item);
151 Ok(())
152 }
153
154 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
155 let Self {
156 ref mut rw,
157 ref mut write_state,
158 ref mut write_buffer,
159 ref mut primed_values,
160 ref message_features,
161 ..
162 } = *self.as_mut();
163 let rw = rw.as_mut().expect("infallible");
164 match futures_core::ready!(AsyncWriteTyped::maybe_send(
165 rw,
166 write_state,
167 write_buffer,
168 primed_values,
169 *message_features,
170 cx,
171 false,
172 ))? {
173 Some(()) => {
174 Pin::new(rw).poll_flush(cx).map(|r| r.map_err(Error::Io))
176 }
177 None => Poll::Ready(Ok(())),
178 }
179 }
180
181 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
182 let Self {
183 ref mut rw,
184 ref mut write_state,
185 ref mut write_buffer,
186 ref mut primed_values,
187 ref message_features,
188 ..
189 } = *self.as_mut();
190 let rw = rw.as_mut().expect("infallible");
191 match futures_core::ready!(AsyncWriteTyped::maybe_send(
192 rw,
193 write_state,
194 write_buffer,
195 primed_values,
196 *message_features,
197 cx,
198 true,
199 ))? {
200 Some(()) => {
201 Pin::new(rw).poll_close(cx).map(|r| r.map_err(Error::Io))
203 }
204 None => Poll::Ready(Ok(())),
205 }
206 }
207}
208
209impl<RW: AsyncRead + AsyncWrite + Unpin, T: Serialize + Unpin + DeserializeOwned> Drop
210 for DuplexStreamTyped<RW, T>
211{
212 fn drop(&mut self) {
213 if self.rw.is_some() {
214 let _ = futures_executor::block_on(SinkExt::close(self));
215 }
216 }
217}