1use crate::{
2 ChecksumEnabled, Error, CHECKSUM_DISABLED, CHECKSUM_ENABLED, PROTOCOL_VERSION, U16_MARKER,
3 U32_MARKER, U64_MARKER, ZST_MARKER,
4};
5use bincode::Options;
6use futures_io::AsyncWrite;
7use futures_util::{Sink, SinkExt};
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use siphasher::sip::SipHasher;
11use std::collections::VecDeque;
12use std::hash::Hasher;
13use std::mem::size_of;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
/// A [`Sink`] that serializes values of type `T` with bincode and writes them,
/// length-prefixed, to the wrapped writer `W`.
///
/// Values handed to the sink are only queued; they are serialized and written
/// out when the sink is flushed or closed.
#[derive(Debug)]
pub struct AsyncWriteTyped<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> {
    // The wrapped writer. `None` only after `into_inner` has taken it.
    raw: Option<W>,
    // Encoded form (length prefix, payload, optional checksum) of the message
    // currently being written.
    write_buffer: Vec<u8>,
    // Current position in the write state machine.
    state: AsyncWriteState,
    // Values accepted by `start_send` but not yet serialized. New values go in at
    // the front and are taken from the back, so they are written in send order.
    primed_values: VecDeque<T>,
    // Size limit and checksum setting applied to every message.
    message_features: MessageFeatures,
}
26
/// States of the write-side state machine, advanced by `AsyncWriteTyped::maybe_send`.
#[derive(Debug)]
pub(crate) enum AsyncWriteState {
    /// Writing the 8 protocol-version bytes; `len_sent` of them are already written.
    WritingVersion { version: [u8; 8], len_sent: usize },
    /// Writing the single byte that records whether checksums are enabled.
    WritingChecksumEnabled,
    /// Header written and no message in flight.
    Idle,
    /// Writing `write_buffer`; `bytes_sent` bytes of it have been accepted so far.
    WritingValue { bytes_sent: usize },
    /// Writing the single `0` byte emitted when the sink is closed.
    Closing,
    /// The closing byte has been written; no further output is produced.
    Closed,
}
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq)]
38pub(crate) struct MessageFeatures {
39 pub size_limit: u64,
40 pub checksum_enabled: bool,
41}
42
43impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Sink<T>
44 for AsyncWriteTyped<W, T>
45{
46 type Error = Error;
47
48 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
49 Poll::Ready(Ok(()))
50 }
51
52 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
53 self.primed_values.push_front(item);
54 Ok(())
55 }
56
57 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
58 let Self {
59 ref mut raw,
60 ref mut write_buffer,
61 ref mut state,
62 ref mut primed_values,
63 ref message_features,
64 } = *self.as_mut();
65 match futures_core::ready!(Self::maybe_send(
66 raw.as_mut().expect("infallible"),
67 state,
68 write_buffer,
69 primed_values,
70 *message_features,
71 cx,
72 false,
73 ))? {
74 Some(()) => {
75 Pin::new(raw.as_mut().expect("infallible"))
77 .poll_flush(cx)
78 .map(|r| r.map_err(Error::Io))
79 }
80 None => Poll::Ready(Ok(())),
81 }
82 }
83
84 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
85 let Self {
86 ref mut raw,
87 ref mut state,
88 ref mut write_buffer,
89 ref mut primed_values,
90 ref message_features,
91 } = *self.as_mut();
92 match futures_core::ready!(Self::maybe_send(
93 raw.as_mut().expect("infallible"),
94 state,
95 write_buffer,
96 primed_values,
97 *message_features,
98 cx,
99 true,
100 ))? {
101 Some(()) => {
102 Pin::new(raw.as_mut().expect("infallible"))
104 .poll_close(cx)
105 .map(|r| r.map_err(Error::Io))
106 }
107 None => Poll::Ready(Ok(())),
108 }
109 }
110}
111
112impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteTyped<W, T> {
113 pub(crate) fn maybe_send(
114 raw: &mut W,
115 state: &mut AsyncWriteState,
116 write_buffer: &mut Vec<u8>,
117 primed_values: &mut VecDeque<T>,
118 message_features: MessageFeatures,
119 cx: &mut Context<'_>,
120 closing: bool,
121 ) -> Poll<Result<Option<()>, Error>> {
122 let MessageFeatures {
123 checksum_enabled,
124 size_limit,
125 } = message_features;
126 loop {
127 return match state {
128 AsyncWriteState::WritingVersion { version, len_sent } => {
129 while *len_sent < size_of::<u64>() {
130 let len = futures_core::ready!(
131 Pin::new(&mut *raw).poll_write(cx, &version[(*len_sent)..])
132 )?;
133 *len_sent += len;
134 }
135 *state = AsyncWriteState::WritingChecksumEnabled;
136 continue;
137 }
138 AsyncWriteState::WritingChecksumEnabled => {
139 let to_send = if checksum_enabled {
140 CHECKSUM_ENABLED
141 } else {
142 CHECKSUM_DISABLED
143 };
144 if futures_core::ready!(Pin::new(&mut *raw).poll_write(cx, &[to_send]))? == 1 {
145 *state = AsyncWriteState::Idle;
146 }
147 continue;
148 }
149 AsyncWriteState::Idle => {
150 if let Some(item) = primed_values.pop_back() {
151 write_buffer.clear();
152 let length = crate::bincode_options(size_limit)
153 .serialized_size(&item)
154 .map_err(Error::Bincode)?;
155 if length > size_limit {
156 return Poll::Ready(Err(Error::SentMessageTooLarge));
157 }
158 if length == 0 {
159 write_buffer.push(ZST_MARKER);
160 } else if length < U16_MARKER as u64 {
161 write_buffer.extend((length as u8).to_le_bytes());
162 } else if length < 2_u64.pow(16) {
163 write_buffer.push(U16_MARKER);
164 write_buffer.extend((length as u16).to_le_bytes());
165 } else if length < 2_u64.pow(32) {
166 write_buffer.push(U32_MARKER);
167 write_buffer.extend((length as u32).to_le_bytes());
168 } else {
169 write_buffer.push(U64_MARKER);
170 write_buffer.extend(length.to_le_bytes());
171 }
172 let length_length = write_buffer.len();
174 crate::bincode_options(size_limit)
175 .serialize_into(&mut *write_buffer, &item)
176 .map_err(Error::Bincode)?;
177 if checksum_enabled {
178 let mut hasher = SipHasher::new();
179 hasher.write(&write_buffer[length_length..]);
180 let checksum = hasher.finish();
181 write_buffer.extend(checksum.to_le_bytes());
182 }
183 *state = AsyncWriteState::WritingValue { bytes_sent: 0 };
184 continue;
185 } else if closing {
186 *state = AsyncWriteState::Closing;
187 continue;
188 } else {
189 Poll::Ready(Ok(Some(())))
190 }
191 }
192 AsyncWriteState::WritingValue { bytes_sent } => {
193 while *bytes_sent < write_buffer.len() {
194 let len = futures_core::ready!(
195 Pin::new(&mut *raw).poll_write(cx, &write_buffer[*bytes_sent..])
196 )?;
197 *bytes_sent += len;
198 }
199 *state = AsyncWriteState::Idle;
200 if primed_values.is_empty() {
201 return Poll::Ready(Ok(Some(())));
202 }
203 continue;
204 }
205 AsyncWriteState::Closing => {
206 let len = futures_core::ready!(Pin::new(&mut *raw).poll_write(cx, &[0]))?;
207 if len == 1 {
208 *state = AsyncWriteState::Closed;
209 Poll::Ready(Ok(Some(())))
210 } else {
211 continue;
212 }
213 }
214 AsyncWriteState::Closed => Poll::Ready(Ok(None)),
215 };
216 }
217 }
218}
219
220impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncWriteTyped<W, T> {
221 pub fn new_with_limit(raw: W, size_limit: u64, checksum_enabled: ChecksumEnabled) -> Self {
227 Self {
228 raw: Some(raw),
229 write_buffer: Vec::new(),
230 state: AsyncWriteState::WritingVersion {
231 version: PROTOCOL_VERSION.to_le_bytes(),
232 len_sent: 0,
233 },
234 message_features: MessageFeatures {
235 size_limit,
236 checksum_enabled: checksum_enabled.into(),
237 },
238 primed_values: VecDeque::new(),
239 }
240 }
241
242 pub fn new(raw: W, checksum_enabled: ChecksumEnabled) -> Self {
246 Self::new_with_limit(raw, 1024u64.pow(2), checksum_enabled)
247 }
248
249 pub fn inner(&self) -> &W {
251 self.raw.as_ref().expect("infallible")
252 }
253
254 pub fn into_inner(mut self) -> W {
256 self.raw.take().expect("infallible")
257 }
258
259 pub fn optimize_memory_usage(&mut self) {
265 match self.state {
266 AsyncWriteState::WritingValue { .. } => self.write_buffer.shrink_to_fit(),
267 _ => {
268 self.write_buffer = Vec::new();
269 }
270 }
271 }
272
273 pub fn current_memory_usage(&self) -> usize {
276 self.write_buffer.capacity()
277 }
278
279 pub fn checksum_enabled(&self) -> bool {
282 self.message_features.checksum_enabled
283 }
284}
285
286impl<W: AsyncWrite + Unpin, T: Serialize + DeserializeOwned + Unpin> Drop
287 for AsyncWriteTyped<W, T>
288{
289 fn drop(&mut self) {
290 if self.raw.is_some() {
292 let _ = futures_executor::block_on(SinkExt::close(self));
293 }
294 }
295}