1use super::message_ops::Sealed;
2use super::{
3 Message, MessageDeserializeError, MessageKind, MessageOps, MessageSerializeError,
4 MessageSerializer, MessageWithValueDeserializer,
5};
6use crate::{SerializedValue, SerializedValueSlice};
7use bytes::BytesMut;
8use num_enum::{IntoPrimitive, TryFromPrimitive};
9
/// Wire discriminant identifying which [`ConnectReply`] variant was encoded.
///
/// `serialize_message` writes one of these as a single `u8` via
/// `put_discriminant_u8`, and `deserialize_message` reads it back via
/// `try_get_discriminant_u8`. The explicit values are therefore part of the
/// encoded bytes (the tests pin them), so changing them changes the format.
#[derive(Debug, Copy, Clone, PartialEq, Eq, IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
enum ConnectReplyKind {
    /// Encodes [`ConnectReply::Ok`].
    Ok = 0,
    /// Encodes [`ConnectReply::IncompatibleVersion`].
    IncompatibleVersion = 1,
    /// Encodes [`ConnectReply::Rejected`].
    Rejected = 2,
}
17
/// Message of kind [`MessageKind::ConnectReply`].
///
/// `Ok` and `Rejected` each carry a [`SerializedValue`]. `IncompatibleVersion`
/// carries a `u32` (written as a little-endian varint) and no value.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "fuzzing", derive(arbitrary::Arbitrary))]
pub enum ConnectReply {
    /// Successful reply, carrying a serialized value.
    Ok(SerializedValue),

    /// Version mismatch; no value is serialized for this variant.
    ///
    /// NOTE(review): presumably the protocol version of the replying side —
    /// confirm against the connect handshake.
    IncompatibleVersion(u32),

    /// Rejection, carrying a serialized value (presumably a reason or
    /// payload supplied by the rejecting side — confirm).
    Rejected(SerializedValue),
}
25
26impl MessageOps for ConnectReply {
27 fn kind(&self) -> MessageKind {
28 MessageKind::ConnectReply
29 }
30
31 fn serialize_message(self) -> Result<BytesMut, MessageSerializeError> {
32 match self {
33 Self::Ok(value) => {
34 let mut serializer =
35 MessageSerializer::with_value(value, MessageKind::ConnectReply)?;
36 serializer.put_discriminant_u8(ConnectReplyKind::Ok);
37 serializer.finish()
38 }
39
40 Self::IncompatibleVersion(version) => {
41 let mut serializer = MessageSerializer::with_none_value(MessageKind::ConnectReply);
42 serializer.put_discriminant_u8(ConnectReplyKind::IncompatibleVersion);
43 serializer.put_varint_u32_le(version);
44 serializer.finish()
45 }
46
47 Self::Rejected(value) => {
48 let mut serializer =
49 MessageSerializer::with_value(value, MessageKind::ConnectReply)?;
50 serializer.put_discriminant_u8(ConnectReplyKind::Rejected);
51 serializer.finish()
52 }
53 }
54 }
55
56 fn deserialize_message(buf: BytesMut) -> Result<Self, MessageDeserializeError> {
57 let mut deserializer = MessageWithValueDeserializer::new(buf, MessageKind::ConnectReply)?;
58
59 match deserializer.try_get_discriminant_u8()? {
60 ConnectReplyKind::Ok => deserializer.finish().map(Self::Ok),
61
62 ConnectReplyKind::IncompatibleVersion => {
63 let version = deserializer.try_get_varint_u32_le()?;
64 deserializer.finish_discard_value()?;
65 Ok(Self::IncompatibleVersion(version))
66 }
67
68 ConnectReplyKind::Rejected => deserializer.finish().map(Self::Rejected),
69 }
70 }
71
72 fn value(&self) -> Option<&SerializedValueSlice> {
73 match self {
74 Self::Ok(value) | Self::Rejected(value) => Some(value),
75 Self::IncompatibleVersion(_) => None,
76 }
77 }
78
79 fn value_mut(&mut self) -> Option<&mut SerializedValue> {
80 match self {
81 Self::Ok(value) | Self::Rejected(value) => Some(value),
82 Self::IncompatibleVersion(_) => None,
83 }
84 }
85}
86
// Marker impl for `message_ops::Sealed`. NOTE(review): presumably the
// sealed-trait guard that limits which types may implement `MessageOps` —
// confirm in `message_ops`.
impl Sealed for ConnectReply {}
88
89impl From<ConnectReply> for Message {
90 fn from(msg: ConnectReply) -> Self {
91 Self::ConnectReply(msg)
92 }
93}
94
// Round-trip tests pinning the exact wire bytes of each `ConnectReply`
// variant. Each test checks the bare `ConnectReply` and then the same reply
// wrapped in `Message`, and expects identical bytes both times.
//
// NOTE(review): the header/value framing inside the fixtures is produced by
// `MessageSerializer`, which is not visible here. The final byte(s) are
// consistent with the `ConnectReplyKind` discriminant (0 = Ok,
// 1 = IncompatibleVersion, 2 = Rejected). For IncompatibleVersion that byte
// is followed by the version varint.
#[cfg(test)]
mod test {
    use super::super::test::{
        assert_deserialize_eq, assert_deserialize_eq_with_value, assert_serialize_eq,
    };
    use super::super::Message;
    use super::ConnectReply;
    use crate::{tags, SerializedValue};

    // `Ok` carrying a serialized `4u8`; the fixture ends with discriminant 0.
    #[test]
    fn ok() {
        let serialized = [12, 0, 0, 0, 1, 2, 0, 0, 0, 3, 4, 0];
        let value = 4u8;

        let msg = ConnectReply::Ok(SerializedValue::serialize(value).unwrap());
        assert_serialize_eq(&msg, serialized);
        assert_deserialize_eq_with_value::<_, _, tags::U8, _>(&msg, serialized, &value);

        let msg = Message::ConnectReply(msg);
        assert_serialize_eq(&msg, serialized);
        assert_deserialize_eq_with_value::<_, _, tags::U8, _>(&msg, serialized, &value);
    }

    // `IncompatibleVersion(2)` carries no value. The trailing `1, 2` are
    // consistent with discriminant 1 followed by the varint-encoded version 2.
    #[test]
    fn version_mismatch() {
        let serialized = [12, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 2];

        let msg = ConnectReply::IncompatibleVersion(2);
        assert_serialize_eq(&msg, serialized);
        assert_deserialize_eq(&msg, serialized);

        let msg = Message::ConnectReply(msg);
        assert_serialize_eq(&msg, serialized);
        assert_deserialize_eq(&msg, serialized);
    }

    // Same payload as `ok`; the fixture differs only in its last byte
    // (discriminant 2 = Rejected).
    #[test]
    fn rejected() {
        let serialized = [12, 0, 0, 0, 1, 2, 0, 0, 0, 3, 4, 2];
        let value = 4u8;

        let msg = ConnectReply::Rejected(SerializedValue::serialize(value).unwrap());
        assert_serialize_eq(&msg, serialized);
        assert_deserialize_eq_with_value::<_, _, tags::U8, _>(&msg, serialized, &value);

        let msg = Message::ConnectReply(msg);
        assert_serialize_eq(&msg, serialized);
        assert_deserialize_eq_with_value::<_, _, tags::U8, _>(&msg, serialized, &value);
    }
}