1use byteorder::{ByteOrder, NetworkEndian};
2use bytes::Bytes;
3use futures::AsyncRead;
4use smol_str::SmolStr;
5use transformable::{
6 BytesTransformError, NumberTransformError, StringTransformError, Transformable,
7};
8
9use super::*;
10
/// Declares a message enum whose variants each wrap exactly one payload type and carry a
/// fixed one-byte tag, then generates helpers for it.
///
/// For an enum `Name` containing a variant `Foo(FooTy) = N`, the expansion emits:
/// - the enum itself, with the `= N` tags removed from the declaration;
/// - an associated constant `Name::FOO_TAG: u8 = N` (the variant name upper-cased by
///   `paste`, e.g. `IndirectPing` becomes `INDIRECTPING_TAG`);
/// - `tag()` and `kind()`, returning the variant's tag and its name;
/// - per variant: `unwrap_foo`, `try_unwrap_foo`, and a `const fn foo(val) -> Self`
///   constructor (names snake-cased by `paste`).
macro_rules! enum_wrapper {
    (
        $(#[$outer:meta])*
        $vis:vis enum $name:ident $(<$($generic:tt),+>)? {
            $(
                $(#[$variant_meta:meta])*
                $variant:ident($variant_ty: ident $(<$($variant_generic:tt),+>)?) = $variant_tag:literal
            ), +$(,)?
        }
    ) => {
        $(#[$outer])*
        $vis enum $name $(< $($generic),+ >)? {
            $(
                $(#[$variant_meta])*
                $variant($variant_ty $(< $($variant_generic),+ >)?),
            )*
        }

        impl $(< $($generic),+ >)? $name $(< $($generic),+ >)? {
            paste::paste! {
                $(
                    #[doc = concat!("The tag of [`", stringify!($variant_ty), "`] message.")]
                    pub const [< $variant: upper _TAG >]: u8 = $variant_tag;
                )*
            }

            /// Returns the one-byte tag declared for this value's variant.
            #[inline]
            pub const fn tag(&self) -> u8 {
                match self {
                    $(
                        Self::$variant(_) => $variant_tag,
                    )*
                }
            }

            /// Returns the name of this value's variant (the identifier as written in the
            /// enum declaration), e.g. for use in diagnostics.
            #[inline]
            pub const fn kind(&self) -> &'static str {
                match self {
                    $(
                        Self::$variant(_) => stringify!($variant),
                    )*
                }
            }

            $(
                paste::paste! {
                    #[doc = concat!("Returns the contained [`", stringify!($variant_ty), "`] message, consuming the self value. Panics if the value is not [`", stringify!($variant_ty), "`].")]
                    // Report the panic at the caller's location rather than inside this expansion.
                    #[track_caller]
                    $vis fn [< unwrap_ $variant:snake>] (self) -> $variant_ty $(< $($variant_generic),+ >)? {
                        if let Self::$variant(val) = self {
                            val
                        } else {
                            panic!(concat!("expect ", stringify!($variant), ", but got {}"), self.kind())
                        }
                    }

                    #[doc = concat!("Returns the contained [`", stringify!($variant_ty), "`] message, consuming the self value. Returns `None` if the value is not [`", stringify!($variant_ty), "`].")]
                    $vis fn [< try_unwrap_ $variant:snake>] (self) -> ::std::option::Option<$variant_ty $(< $($variant_generic),+ >)?> {
                        if let Self::$variant(val) = self {
                            ::std::option::Option::Some(val)
                        } else {
                            ::std::option::Option::None
                        }
                    }

                    #[doc = concat!("Construct a [`", stringify!($name), "`] from [`", stringify!($variant_ty), "`].")]
                    pub const fn [< $variant:snake >](val: $variant_ty $(< $($variant_generic),+ >)?) -> Self {
                        Self::$variant(val)
                    }
                }
            )*
        }
    };
}
86
87enum_wrapper!(
88 #[derive(Debug, Clone, derive_more::From, PartialEq, Eq, Hash)]
90 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
91 #[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
92 #[non_exhaustive]
93 pub enum Message<I, A> {
94 Ping(Ping<I, A>) = 1,
99 IndirectPing(IndirectPing<I, A>) = 2,
101 Ack(Ack) = 3,
103 Suspect(Suspect<I>) = 4,
105 Alive(Alive<I, A>) = 5,
107 Dead(Dead<I>) = 6,
109 PushPull(PushPull<I, A>) = 7,
111 UserData(Bytes) = 8,
113 Nack(Nack) = 9,
115 ErrorResponse(ErrorResponse) = 10,
117 }
118);
119
120impl<I, A> Message<I, A> {
121 pub const RESERVED_TAG_RANGE: std::ops::RangeInclusive<u8> = (0..=128);
141
142 pub const COMPOUND_TAG: u8 = 0;
144}
145
146#[derive(Debug, thiserror::Error)]
148pub enum MessageTransformError<I: Transformable, A: Transformable> {
149 #[error("encode buffer too small")]
151 BufferTooSmall,
152 #[error("not enough bytes to decode message")]
154 NotEnoughBytes,
155 #[error(transparent)]
157 Ping(#[from] PingTransformError<I, A>),
158 #[error(transparent)]
160 IndirectPing(#[from] IndirectPingTransformError<I, A>),
161 #[error(transparent)]
163 Ack(#[from] AckTransformError),
164 #[error(transparent)]
166 Suspect(#[from] SuspectTransformError<I>),
167 #[error(transparent)]
169 Alive(#[from] AliveTransformError<I, A>),
170 #[error(transparent)]
172 Dead(#[from] DeadTransformError<I>),
173 #[error(transparent)]
175 PushPull(#[from] PushPullTransformError<I, A>),
176 #[error(transparent)]
178 UserData(#[from] BytesTransformError),
179 #[error(transparent)]
181 Nack(#[from] NumberTransformError),
182 #[error(transparent)]
184 ErrorResponse(#[from] StringTransformError),
185}
186
/// Width in bytes of the big-endian `u32` length prefix written before a `UserData` payload.
const USER_DATA_LEN_SIZE: usize = (u32::BITS / 8) as usize;
/// `UserData` payloads of at most this many bytes are read through a fixed-size stack
/// buffer by the reader-based decoders.
const INLINED_BYTES_SIZE: usize = 64;
189
190impl<I, A> Transformable for Message<I, A>
191where
192 I: Transformable + core::fmt::Debug,
193 A: Transformable + core::fmt::Debug,
194{
195 type Error = MessageTransformError<I, A>;
196
197 fn encode(&self, mut dst: &mut [u8]) -> Result<usize, Self::Error> {
198 let encoded_len = self.encoded_len();
199 if dst.len() < encoded_len {
200 return Err(Self::Error::BufferTooSmall);
201 }
202
203 dst[0] = self.tag();
204 dst = &mut dst[1..];
205
206 Ok(match self {
207 Self::Ping(msg) => msg.encode(dst).map(|w| w + 1)?,
208 Self::IndirectPing(msg) => msg.encode(dst).map(|w| w + 1)?,
209 Self::Ack(msg) => msg.encode(dst).map(|w| w + 1)?,
210 Self::Suspect(msg) => msg.encode(dst).map(|w| w + 1)?,
211 Self::Alive(msg) => msg.encode(dst).map(|w| w + 1)?,
212 Self::Dead(msg) => msg.encode(dst).map(|w| w + 1)?,
213 Self::PushPull(msg) => msg.encode(dst).map(|w| w + 1)?,
214 Self::UserData(msg) => {
215 let len = msg.len();
216 NetworkEndian::write_u32(dst, len as u32);
217 dst = &mut dst[USER_DATA_LEN_SIZE..];
218 dst[..len].copy_from_slice(msg);
219 1 + USER_DATA_LEN_SIZE + len
220 }
221 Self::Nack(msg) => msg.encode(dst).map(|w| w + 1)?,
222 Self::ErrorResponse(msg) => msg.encode(dst).map(|w| w + 1)?,
223 })
224 }
225
226 fn encoded_len(&self) -> usize {
227 1 + match self {
228 Self::Ping(msg) => msg.encoded_len(),
229 Self::IndirectPing(msg) => msg.encoded_len(),
230 Self::Ack(msg) => msg.encoded_len(),
231 Self::Suspect(msg) => msg.encoded_len(),
232 Self::Alive(msg) => msg.encoded_len(),
233 Self::Dead(msg) => msg.encoded_len(),
234 Self::PushPull(msg) => msg.encoded_len(),
235 Self::UserData(msg) => USER_DATA_LEN_SIZE + msg.len(),
236 Self::Nack(msg) => msg.encoded_len(),
237 Self::ErrorResponse(msg) => msg.encoded_len(),
238 }
239 }
240
241 fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
242 where
243 Self: Sized,
244 {
245 if src.is_empty() {
246 return Err(Self::Error::NotEnoughBytes);
247 }
248
249 let tag = src[0];
250 let src = &src[1..];
251
252 Ok(match tag {
253 Self::PING_TAG => {
254 let (len, msg) = Ping::decode(src)?;
255 (len + 1, Self::Ping(msg))
256 }
257 Self::INDIRECTPING_TAG => {
258 let (len, msg) = IndirectPing::decode(src)?;
259 (len + 1, Self::IndirectPing(msg))
260 }
261 Self::ACK_TAG => {
262 let (len, msg) = Ack::decode(src)?;
263 (len + 1, Self::Ack(msg))
264 }
265 Self::SUSPECT_TAG => {
266 let (len, msg) = Suspect::decode(src)?;
267 (len + 1, Self::Suspect(msg))
268 }
269 Self::ALIVE_TAG => {
270 let (len, msg) = Alive::decode(src)?;
271 (len + 1, Self::Alive(msg))
272 }
273 Self::DEAD_TAG => {
274 let (len, msg) = Dead::decode(src)?;
275 (len + 1, Self::Dead(msg))
276 }
277 Self::PUSHPULL_TAG => {
278 let (len, msg) = PushPull::decode(src)?;
279 (len + 1, Self::PushPull(msg))
280 }
281 Self::USERDATA_TAG => {
282 let len = NetworkEndian::read_u32(src) as usize;
283 let src = &src[USER_DATA_LEN_SIZE..];
284 let msg = Bytes::copy_from_slice(&src[..len]);
285 (len + 1 + USER_DATA_LEN_SIZE, Self::UserData(msg))
286 }
287 Self::NACK_TAG => {
288 let (len, msg) = u32::decode(src)?;
289 (len + 1, Self::Nack(Nack::new(msg)))
290 }
291 Self::ERRORRESPONSE_TAG => {
292 let (len, msg) = <SmolStr as Transformable>::decode(src)?;
293 (len + 1, Self::ErrorResponse(ErrorResponse::new(msg)))
294 }
295 _ => return Err(Self::Error::NotEnoughBytes),
296 })
297 }
298
299 fn decode_from_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<(usize, Self)>
300 where
301 Self: Sized,
302 {
303 let mut tag = [0u8; 1];
304 reader.read_exact(&mut tag)?;
305 let tag = tag[0];
306 let (len, msg) = match tag {
307 Self::PING_TAG => {
308 let (len, msg) = Ping::decode_from_reader(reader)?;
309 (len + 1, Self::Ping(msg))
310 }
311 Self::INDIRECTPING_TAG => {
312 let (len, msg) = IndirectPing::decode_from_reader(reader)?;
313 (len + 1, Self::IndirectPing(msg))
314 }
315 Self::ACK_TAG => {
316 let (len, msg) = Ack::decode_from_reader(reader)?;
317 (len + 1, Self::Ack(msg))
318 }
319 Self::SUSPECT_TAG => {
320 let (len, msg) = Suspect::decode_from_reader(reader)?;
321 (len + 1, Self::Suspect(msg))
322 }
323 Self::ALIVE_TAG => {
324 let (len, msg) = Alive::decode_from_reader(reader)?;
325 (len + 1, Self::Alive(msg))
326 }
327 Self::DEAD_TAG => {
328 let (len, msg) = Dead::decode_from_reader(reader)?;
329 (len + 1, Self::Dead(msg))
330 }
331 Self::PUSHPULL_TAG => {
332 let (len, msg) = PushPull::decode_from_reader(reader)?;
333 (len + 1, Self::PushPull(msg))
334 }
335 Self::USERDATA_TAG => {
336 let mut len_buf = [0u8; USER_DATA_LEN_SIZE];
337 reader.read_exact(&mut len_buf)?;
338 let len = NetworkEndian::read_u32(&len_buf) as usize;
339 if len <= INLINED_BYTES_SIZE {
340 let mut buf = [0u8; INLINED_BYTES_SIZE];
341 reader.read_exact(&mut buf[..len])?;
342 (
343 len + 1 + USER_DATA_LEN_SIZE,
344 Self::UserData(Bytes::copy_from_slice(&buf[..len])),
345 )
346 } else {
347 let mut buf = vec![0u8; len];
348 reader.read_exact(&mut buf)?;
349 (len + 1 + USER_DATA_LEN_SIZE, Self::UserData(buf.into()))
350 }
351 }
352 Self::NACK_TAG => {
353 let (len, msg) = Nack::decode_from_reader(reader)?;
354 (len + 1, Self::Nack(msg))
355 }
356 Self::ERRORRESPONSE_TAG => {
357 let (len, msg) = ErrorResponse::decode_from_reader(reader)?;
358 (len + 1, Self::ErrorResponse(msg))
359 }
360 _ => {
361 return Err(std::io::Error::new(
362 std::io::ErrorKind::InvalidData,
363 "unknown message",
364 ))
365 }
366 };
367 Ok((len, msg))
368 }
369
370 async fn decode_from_async_reader<R: AsyncRead + Send + Unpin>(
371 reader: &mut R,
372 ) -> std::io::Result<(usize, Self)>
373 where
374 Self: Sized,
375 {
376 use futures::io::AsyncReadExt;
377
378 let mut tag = [0u8; 1];
379 reader.read_exact(&mut tag).await?;
380 let tag = tag[0];
381 let (len, msg) = match tag {
382 Self::PING_TAG => {
383 let (len, msg) = Ping::decode_from_async_reader(reader).await?;
384 (len + 1, Self::Ping(msg))
385 }
386 Self::INDIRECTPING_TAG => {
387 let (len, msg) = IndirectPing::decode_from_async_reader(reader).await?;
388 (len + 1, Self::IndirectPing(msg))
389 }
390 Self::ACK_TAG => {
391 let (len, msg) = Ack::decode_from_async_reader(reader).await?;
392 (len + 1, Self::Ack(msg))
393 }
394 Self::SUSPECT_TAG => {
395 let (len, msg) = Suspect::decode_from_async_reader(reader).await?;
396 (len + 1, Self::Suspect(msg))
397 }
398 Self::ALIVE_TAG => {
399 let (len, msg) = Alive::decode_from_async_reader(reader).await?;
400 (len + 1, Self::Alive(msg))
401 }
402 Self::DEAD_TAG => {
403 let (len, msg) = Dead::decode_from_async_reader(reader).await?;
404 (len + 1, Self::Dead(msg))
405 }
406 Self::PUSHPULL_TAG => {
407 let (len, msg) = PushPull::decode_from_async_reader(reader).await?;
408 (len + 1, Self::PushPull(msg))
409 }
410 Self::USERDATA_TAG => {
411 let mut len_buf = [0u8; USER_DATA_LEN_SIZE];
412 reader.read_exact(&mut len_buf).await?;
413 let len = NetworkEndian::read_u32(&len_buf) as usize;
414 if len <= INLINED_BYTES_SIZE {
415 let mut buf = [0u8; INLINED_BYTES_SIZE];
416 reader.read_exact(&mut buf[..len]).await?;
417 (
418 len + 1 + USER_DATA_LEN_SIZE,
419 Self::UserData(Bytes::copy_from_slice(&buf[..len])),
420 )
421 } else {
422 let mut buf = vec![0u8; len];
423 reader.read_exact(&mut buf).await?;
424 (len + 1 + USER_DATA_LEN_SIZE, Self::UserData(buf.into()))
425 }
426 }
427 Self::NACK_TAG => {
428 let (len, msg) = Nack::decode_from_async_reader(reader).await?;
429 (len + 1, Self::Nack(msg))
430 }
431 Self::ERRORRESPONSE_TAG => {
432 let (len, msg) = ErrorResponse::decode_from_async_reader(reader).await?;
433 (len + 1, Self::ErrorResponse(msg))
434 }
435 _ => {
436 return Err(std::io::Error::new(
437 std::io::ErrorKind::InvalidData,
438 "unknown message",
439 ))
440 }
441 };
442 Ok((len, msg))
443 }
444}
445
#[cfg(test)]
mod test {
    use std::net::SocketAddr;

    use super::*;

    /// Encodes `$msg` into a buffer of exactly `encoded_len()` bytes and yields that buffer.
    ///
    /// A macro (not a fn) so each call site keeps its own type inference for `Message<_, _>`.
    macro_rules! encode_exact {
        ($msg:ident) => {{
            let mut buf = vec![0u8; $msg.encoded_len()];
            $msg.encode(&mut buf).unwrap();
            buf
        }};
    }

    /// Decodes `$buf` through both the blocking and the async reader, checking that each
    /// consumes the whole buffer and yields a value equal to `$msg`.
    macro_rules! assert_readers_round_trip {
        ($msg:ident, $buf:ident) => {{
            let (len, decoded) =
                Message::decode_from_reader(&mut std::io::Cursor::new(&$buf)).unwrap();
            assert_eq!(len, $buf.len());
            assert_eq!(decoded, $msg);

            let (len, decoded) =
                Message::decode_from_async_reader(&mut futures::io::Cursor::new(&$buf))
                    .await
                    .unwrap();
            assert_eq!(len, $buf.len());
            assert_eq!(decoded, $msg);
        }};
    }

    #[tokio::test]
    async fn test_ping_transformable_round_trip() {
        let msg = Message::Ping(Ping::generate(1));
        let buf = encode_exact!(msg);
        assert_readers_round_trip!(msg, buf);
    }

    #[tokio::test]
    async fn test_ack_transformable_round_trip() {
        let msg = Message::<SmolStr, SocketAddr>::Ack(Ack::random(10));
        let buf = encode_exact!(msg);
        assert_readers_round_trip!(msg, buf);
    }

    #[tokio::test]
    async fn test_indirect_ping_transformable_round_trip() {
        let msg = Message::IndirectPing(IndirectPing::generate(1));
        let buf = encode_exact!(msg);
        assert_readers_round_trip!(msg, buf);
    }

    #[tokio::test]
    async fn test_nack_transformable_round_trip() {
        let msg = Message::<SmolStr, SocketAddr>::Nack(Nack::new(10));
        let buf = encode_exact!(msg);
        assert_readers_round_trip!(msg, buf);
    }

    #[tokio::test]
    async fn test_suspect_transformable_round_trip() {
        let msg = Message::<SmolStr, SocketAddr>::Suspect(Suspect::generate(10));
        let buf = encode_exact!(msg);
        assert_readers_round_trip!(msg, buf);
    }

    #[tokio::test]
    async fn test_dead_transformable_round_trip() {
        let msg = Message::<SmolStr, SocketAddr>::Dead(Dead::generate(10));
        let buf = encode_exact!(msg);
        assert_readers_round_trip!(msg, buf);
    }

    #[tokio::test]
    async fn test_alive_transformable_round_trip() {
        let msg = Message::<SmolStr, SocketAddr>::Alive(Alive::random(128));
        let buf = encode_exact!(msg);
        assert_readers_round_trip!(msg, buf);
    }

    #[tokio::test]
    async fn test_push_pull_transformable_round_trip() {
        let msg = Message::<SmolStr, SocketAddr>::PushPull(PushPull::generate(10));
        let buf = encode_exact!(msg);
        assert_readers_round_trip!(msg, buf);
    }

    #[tokio::test]
    async fn test_user_data_transformable_round_trip() {
        let msg = Message::<SmolStr, SocketAddr>::UserData(Bytes::from_static(b"hello world"));
        let buf = encode_exact!(msg);
        assert_readers_round_trip!(msg, buf);
    }

    #[tokio::test]
    async fn test_error_response_transformable_round_trip() {
        let msg = Message::<SmolStr, SocketAddr>::ErrorResponse(ErrorResponse::new("hello world"));
        let buf = encode_exact!(msg);

        // This variant additionally exercises slice-based decoding.
        let (len, decoded) = Message::decode(&buf).unwrap();
        assert_eq!(len, buf.len());
        assert_eq!(decoded, msg);

        assert_readers_round_trip!(msg, buf);
    }
}