1use byteorder::{ByteOrder, NetworkEndian};
2use bytes::Bytes;
3use transformable::{utils::*, Transformable};
4
/// Largest `Ack` payload (in bytes) that the reader-based decoders stage in a
/// fixed-size stack buffer before copying it into `Bytes`; longer payloads are
/// read directly into a heap-allocated `Vec`.
const MAX_INLINED_BYTES: usize = 64;
6
7#[viewit::viewit(getters(vis_all = "pub"), setters(vis_all = "pub", prefix = "with"))]
9#[derive(Debug, Clone, PartialEq, Eq, Hash)]
10#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
11#[cfg_attr(
12 feature = "rkyv",
13 derive(::rkyv::Serialize, ::rkyv::Deserialize, ::rkyv::Archive)
14)]
15#[cfg_attr(
16 feature = "rkyv",
17 rkyv(compare(PartialEq), derive(Debug, PartialEq, Eq, Hash),)
18)]
19pub struct Ack {
20 #[viewit(
22 getter(const, attrs(doc = "Returns the sequence number of the ack")),
23 setter(
24 const,
25 attrs(doc = "Sets the sequence number of the ack (Builder pattern)")
26 )
27 )]
28 sequence_number: u32,
29 #[viewit(
31 getter(const, style = "ref", attrs(doc = "Returns the payload of the ack")),
32 setter(attrs(doc = "Sets the payload of the ack (Builder pattern)"))
33 )]
34 payload: Bytes,
35}
36
37impl Ack {
38 #[inline]
40 pub const fn new(sequence_number: u32) -> Self {
41 Self {
42 sequence_number,
43 payload: Bytes::new(),
44 }
45 }
46
47 #[inline]
49 pub fn set_sequence_number(&mut self, sequence_number: u32) -> &mut Self {
50 self.sequence_number = sequence_number;
51 self
52 }
53
54 #[inline]
56 pub fn set_payload(&mut self, payload: Bytes) -> &mut Self {
57 self.payload = payload;
58 self
59 }
60
61 #[inline]
63 pub fn into_components(self) -> (u32, Bytes) {
64 (self.sequence_number, self.payload)
65 }
66}
67
/// Errors produced by the `Transformable` implementation of [`Ack`].
#[derive(Debug, thiserror::Error)]
pub enum AckTransformError {
    /// Returned by `encode` when the destination buffer is shorter than
    /// `encoded_len()`; the wrapped value records the required and actual sizes.
    #[error("encode buffer too small")]
    InsufficientBuffer(#[from] InsufficientBuffer),
    /// Returned by `decode` when the source buffer is too short for the
    /// frame it is supposed to contain.
    #[error("the buffer did not contain enough bytes to decode Ack")]
    NotEnoughBytes,
    /// Wraps a varint decoding failure.
    ///
    /// NOTE(review): nothing in this file constructs this variant explicitly —
    /// the `Ack` decoder reads the sequence number as a fixed-width `u32`. It
    /// can still arise through the generated `From<DecodeVarintError>` impl;
    /// presumably kept for API compatibility — confirm before removing.
    #[error("fail to decode sequence number: {0}")]
    DecodeVarint(#[from] DecodeVarintError),
}
81
82impl Transformable for Ack {
83 type Error = AckTransformError;
84
85 fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
86 let encoded_len = self.encoded_len();
87
88 if encoded_len > dst.len() {
89 return Err(Self::Error::InsufficientBuffer(
90 InsufficientBuffer::with_information(encoded_len as u64, dst.len() as u64),
91 ));
92 }
93
94 let mut offset = 0;
95 NetworkEndian::write_u32(dst, encoded_len as u32);
96 offset += core::mem::size_of::<u32>();
97 NetworkEndian::write_u32(&mut dst[offset..], self.sequence_number);
98 offset += core::mem::size_of::<u32>();
99
100 let payload_size = self.payload.len();
101 if !self.payload.is_empty() {
102 dst[offset..offset + payload_size].copy_from_slice(&self.payload);
103 offset += payload_size;
104 }
105
106 debug_assert_eq!(
107 offset, encoded_len,
108 "expect bytes written ({encoded_len}) not match actual bytes writtend ({offset})"
109 );
110 Ok(offset)
111 }
112
113 fn encoded_len(&self) -> usize {
114 core::mem::size_of::<u32>() + core::mem::size_of::<u32>() + self.payload.len()
115 }
116
117 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
118 where
119 Self: Sized,
120 {
121 let mut offset = 0;
122 if core::mem::size_of::<u32>() > src.len() {
123 return Err(Self::Error::NotEnoughBytes);
124 }
125
126 let total_len = NetworkEndian::read_u32(&src[offset..]);
127 offset += core::mem::size_of::<u32>();
128 let sequence_number = NetworkEndian::read_u32(&src[offset..]);
129 offset += core::mem::size_of::<u32>();
130
131 if total_len as usize == 2 * core::mem::size_of::<u32>() {
132 return Ok((
133 offset,
134 Self {
135 sequence_number,
136 payload: Bytes::new(),
137 },
138 ));
139 }
140
141 if total_len as usize - core::mem::size_of::<u32>() > src.len() {
142 return Err(Self::Error::NotEnoughBytes);
143 }
144
145 let payload = Bytes::copy_from_slice(&src[offset..total_len as usize]);
146 Ok((
147 total_len as usize,
148 Self {
149 sequence_number,
150 payload,
151 },
152 ))
153 }
154
155 fn decode_from_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<(usize, Self)>
156 where
157 Self: Sized,
158 {
159 let mut buf = [0; 8];
160 reader.read_exact(&mut buf)?;
161 let total_len = NetworkEndian::read_u32(&buf) as usize;
162 let sequence_number = NetworkEndian::read_u32(&buf[core::mem::size_of::<u32>()..]);
163
164 if total_len == 2 * core::mem::size_of::<u32>() {
165 return Ok((
166 total_len,
167 Self {
168 sequence_number,
169 payload: Bytes::new(),
170 },
171 ));
172 }
173
174 let payload_len = total_len - core::mem::size_of::<u32>() * 2;
175 if payload_len <= MAX_INLINED_BYTES {
176 let mut buf = [0; MAX_INLINED_BYTES];
177 reader.read_exact(&mut buf[..payload_len])?;
178 let payload = Bytes::copy_from_slice(&buf[..payload_len]);
179 Ok((
180 total_len,
181 Self {
182 sequence_number,
183 payload,
184 },
185 ))
186 } else {
187 let mut payload = vec![0; payload_len];
188 reader.read_exact(&mut payload)?;
189 Ok((
190 total_len,
191 Self {
192 sequence_number,
193 payload: payload.into(),
194 },
195 ))
196 }
197 }
198
199 async fn decode_from_async_reader<R: futures::AsyncRead + Send + Unpin>(
200 reader: &mut R,
201 ) -> std::io::Result<(usize, Self)>
202 where
203 Self: Sized,
204 {
205 use futures::AsyncReadExt;
206
207 let mut buf = [0; 8];
208 reader.read_exact(&mut buf).await?;
209
210 let total_len = NetworkEndian::read_u32(&buf) as usize;
211 let sequence_number = NetworkEndian::read_u32(&buf[core::mem::size_of::<u32>()..]);
212
213 if total_len == 2 * core::mem::size_of::<u32>() {
214 return Ok((
215 total_len,
216 Self {
217 sequence_number,
218 payload: Bytes::new(),
219 },
220 ));
221 }
222
223 let payload_len = total_len - core::mem::size_of::<u32>() * 2;
224 if payload_len <= MAX_INLINED_BYTES {
225 let mut buf = [0; MAX_INLINED_BYTES];
226 reader.read_exact(&mut buf[..payload_len]).await?;
227 let payload = Bytes::copy_from_slice(&buf[..payload_len]);
228 Ok((
229 total_len,
230 Self {
231 sequence_number,
232 payload,
233 },
234 ))
235 } else {
236 let mut payload = vec![0; payload_len];
237 reader.read_exact(&mut payload).await?;
238 Ok((
239 total_len,
240 Self {
241 sequence_number,
242 payload: payload.into(),
243 },
244 ))
245 }
246 }
247}
248
249#[viewit::viewit(
253 vis_all = "pub(crate)",
254 getters(vis_all = "pub"),
255 setters(vis_all = "pub", prefix = "with")
256)]
257#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
258#[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
259#[cfg_attr(feature = "serde", serde(transparent))]
260#[cfg_attr(
261 feature = "rkyv",
262 derive(::rkyv::Serialize, ::rkyv::Deserialize, ::rkyv::Archive)
263)]
264#[cfg_attr(
265 feature = "rkyv",
266 rkyv(derive(Debug, Clone, PartialEq, Eq, Hash), compare(PartialEq))
267)]
268#[repr(transparent)]
269pub struct Nack {
270 #[viewit(
271 getter(const, attrs(doc = "Returns the sequence number of the nack")),
272 setter(
273 const,
274 attrs(doc = "Sets the sequence number of the nack (Builder pattern)")
275 )
276 )]
277 sequence_number: u32,
278}
279
280impl Nack {
281 #[inline]
283 pub const fn new(sequence_number: u32) -> Self {
284 Self { sequence_number }
285 }
286
287 #[inline]
289 pub fn set_sequence_number(&mut self, sequence_number: u32) -> &mut Self {
290 self.sequence_number = sequence_number;
291 self
292 }
293}
294
295impl Transformable for Nack {
296 type Error = <u32 as Transformable>::Error;
297
298 fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
299 <u32 as Transformable>::encode(&self.sequence_number, dst)
300 }
301
302 fn encoded_len(&self) -> usize {
303 <u32 as Transformable>::encoded_len(&self.sequence_number)
304 }
305
306 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
307 where
308 Self: Sized,
309 {
310 let (n, sequence_number) = <u32 as Transformable>::decode(src)?;
311 Ok((n, Self { sequence_number }))
312 }
313
314 async fn encode_to_async_writer<W: futures::io::AsyncWrite + Send + Unpin>(
315 &self,
316 writer: &mut W,
317 ) -> std::io::Result<usize> {
318 <u32 as Transformable>::encode_to_async_writer(&self.sequence_number, writer).await
319 }
320
321 fn encode_to_writer<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<usize> {
322 <u32 as Transformable>::encode_to_writer(&self.sequence_number, writer)
323 }
324
325 fn decode_from_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<(usize, Self)>
326 where
327 Self: Sized,
328 {
329 <u32 as Transformable>::decode_from_reader(reader)
330 .map(|(n, sequence_number)| (n, Self { sequence_number }))
331 }
332
333 async fn decode_from_async_reader<R: futures::io::AsyncRead + Send + Unpin>(
334 reader: &mut R,
335 ) -> std::io::Result<(usize, Self)>
336 where
337 Self: Sized,
338 {
339 <u32 as Transformable>::decode_from_async_reader(reader)
340 .await
341 .map(|(n, sequence_number)| (n, Self { sequence_number }))
342 }
343}
344
#[cfg(test)]
const _: () = {
    use rand::random;

    impl Ack {
        /// Test helper: an ack with a random sequence number and
        /// `payload_size` random payload bytes.
        #[inline]
        pub fn random(payload_size: usize) -> Self {
            let sequence_number = random();
            let bytes: Vec<u8> = core::iter::repeat_with(random)
                .take(payload_size)
                .collect();
            Self {
                sequence_number,
                payload: Bytes::from(bytes),
            }
        }
    }

    impl Nack {
        /// Test helper: a nack with a random sequence number.
        #[inline]
        pub fn random() -> Self {
            Self::new(random())
        }
    }
};
375
#[cfg(test)]
mod tests {
    use super::*;
    use futures::io::Cursor as FCursor;
    use std::io::Cursor;

    /// Asserts that both acks carry the same sequence number and payload.
    fn assert_same_ack(expected: &Ack, actual: &Ack) {
        assert_eq!(expected.sequence_number, actual.sequence_number);
        assert_eq!(expected.payload, actual.payload);
    }

    #[tokio::test]
    async fn test_ack_response_encode_decode() {
        for payload_size in 0..100 {
            let original = Ack::random(payload_size);

            // Encode into an exactly-sized slice, then decode it back.
            let mut encoded = vec![0; original.encoded_len()];
            let written = original.encode(&mut encoded).unwrap();
            assert_eq!(written, encoded.len());
            let (consumed, decoded) = Ack::decode(&encoded).unwrap();
            assert_eq!(consumed, encoded.len());
            assert_same_ack(&original, &decoded);

            // The same bytes must decode through the sync and async readers.
            let (_, decoded) = Ack::decode_from_reader(&mut Cursor::new(&encoded)).unwrap();
            assert_same_ack(&original, &decoded);
            let (_, decoded) = Ack::decode_from_async_reader(&mut FCursor::new(&encoded))
                .await
                .unwrap();
            assert_same_ack(&original, &decoded);

            // Round trip through the sync writer.
            let mut sink = Vec::new();
            original.encode_to_writer(&mut sink).unwrap();
            let (_, decoded) = Ack::decode_from_reader(&mut Cursor::new(sink)).unwrap();
            assert_same_ack(&original, &decoded);

            // Round trip through the async writer.
            let mut sink = Vec::new();
            original.encode_to_async_writer(&mut sink).await.unwrap();
            let (_, decoded) = Ack::decode_from_async_reader(&mut FCursor::new(sink))
                .await
                .unwrap();
            assert_same_ack(&original, &decoded);
        }
    }

    #[tokio::test]
    async fn test_nack_response_encode_decode() {
        for _ in 0..100 {
            let original = Nack::random();

            // Encode into an exactly-sized slice, then decode it back.
            let mut encoded = vec![0; original.encoded_len()];
            let written = original.encode(&mut encoded).unwrap();
            assert_eq!(written, encoded.len());
            let (consumed, decoded) = Nack::decode(&encoded).unwrap();
            assert_eq!(consumed, encoded.len());
            assert_eq!(original.sequence_number, decoded.sequence_number);

            // The same bytes must decode through the sync and async readers.
            let (_, decoded) = Nack::decode_from_reader(&mut Cursor::new(&encoded)).unwrap();
            assert_eq!(original.sequence_number, decoded.sequence_number);
            let (_, decoded) = Nack::decode_from_async_reader(&mut FCursor::new(&encoded))
                .await
                .unwrap();
            assert_eq!(original.sequence_number, decoded.sequence_number);

            // Round trip through the sync writer.
            let mut sink = Vec::new();
            original.encode_to_writer(&mut sink).unwrap();
            let (_, decoded) = Nack::decode_from_reader(&mut Cursor::new(sink)).unwrap();
            assert_eq!(original.sequence_number, decoded.sequence_number);

            // Round trip through the async writer.
            let mut sink = Vec::new();
            original.encode_to_async_writer(&mut sink).await.unwrap();
            let (_, decoded) = Nack::decode_from_async_reader(&mut FCursor::new(sink))
                .await
                .unwrap();
            assert_eq!(original.sequence_number, decoded.sequence_number);
        }
    }

    #[test]
    fn test_access() {
        let mut ack = Ack::random(100);
        ack.set_payload(Bytes::from_static(b"hello world"))
            .set_sequence_number(100);
        assert_eq!(ack.sequence_number(), 100);
        assert_eq!(ack.payload(), &Bytes::from_static(b"hello world"));

        let mut nack = Nack::random();
        nack.set_sequence_number(100);
        assert_eq!(nack.sequence_number(), 100);
    }
}