1use crate::error::MessageError;
2
3use std::fmt::Debug;
4use std::hash::Hash;
5
6use stream::packet::{
7 self, Flags, Packet,
8 PacketHeader, BodyBytes, BodyBytesMut, PacketError
9};
10pub use stream::packet::PacketBytes;
11
12use bytes::{Bytes, BytesMut, BytesRead, BytesWrite};
13
14
15pub trait Action: Debug + Copy + Eq + Hash {
18 fn from_u16(num: u16) -> Option<Self>;
22
23 fn as_u16(&self) -> u16;
24
25 fn max_body_size(_header: &Header<Self>) -> Option<u32> {
26 None
27 }
28}
29
30pub trait IntoMessage<A, B> {
31 fn into_message(self) -> Result<Message<A, B>, MessageError>;
32}
33
34pub trait FromMessage<A, B>: Sized {
35 fn from_message(msg: Message<A, B>) -> Result<Self, MessageError>;
36}
37
38
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub struct Message<A, B> {
41 header: Header<A>,
42 bytes: B
43}
44
45impl<A, B> Message<A, B>
46where
47 A: Action,
48 B: PacketBytes
49{
50 pub fn new() -> Self {
51 Self::empty()
52 }
53}
54
55impl<A, B> Message<A, B>
56where B: PacketBytes {
57 pub fn set_success(&mut self, success: bool) {
58 self.header.msg_flags.set_success(success);
59 }
60
61 pub fn is_success(&self) -> bool {
62 self.header.msg_flags.is_success()
63 }
64
65 pub fn body(&self) -> BodyBytes<'_> {
66 self.bytes.body()
67 }
68
69 pub fn body_mut(&mut self) -> BodyBytesMut<'_> {
70 self.bytes.body_mut()
71 }
72}
73
74impl<A, B> Message<A, B> {
75 pub fn action(&self) -> Option<&A> {
76 match &self.header.action {
77 MaybeAction::Action(a) => Some(a),
78 _ => None
79 }
80 }
81}
82
83impl<A, B> IntoMessage<A, B> for Message<A, B> {
84 fn into_message(self) -> Result<Self, MessageError> {
85 Ok(self)
86 }
87}
88
89impl<A, B> FromMessage<A, B> for Message<A, B> {
90 fn from_message(me: Self) -> Result<Self, MessageError> {
91 Ok(me)
92 }
93}
94
95impl<A, B> Packet<B> for Message<A, B>
96where
97 A: Action,
98 B: PacketBytes
99{
100 type Header = Header<A>;
101
102 fn header(&self) -> &Self::Header {
103 &self.header
104 }
105
106 fn header_mut(&mut self) -> &mut Self::Header {
107 &mut self.header
108 }
109
110 fn empty() -> Self {
111 Self {
112 header: Self::Header::empty(),
113 bytes: B::new(Self::Header::LEN as usize)
114 }
115 }
116
117 fn from_bytes_and_header(
118 bytes: B,
119 header: Self::Header
120 ) -> packet::Result<Self> {
121 Ok(Self { header, bytes })
123 }
124
125 fn into_bytes(mut self) -> B {
126 let body_len = self.bytes.body().len();
127 self.header.body_len = body_len as u32;
128 self.header.to_bytes(self.bytes.header_mut());
129 self.bytes
130 }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq)]
134pub struct Header<A> {
135 body_len: u32,
136 flags: Flags,
137 msg_flags: MessageFlags,
138 id: u32,
139 action: MaybeAction<A>
140}
141
142impl<A> Header<A>
143where A: Action {
144 pub fn empty() -> Self {
145 Self {
146 body_len: 0,
147 flags: Flags::empty(),
148 msg_flags: MessageFlags::new(true),
149 id: 0,
150 action: MaybeAction::None
151 }
152 }
153
154 pub fn to_bytes(&self, mut bytes: BytesMut) {
155 bytes.write_u32(self.body_len);
156 bytes.write_u8(self.flags.as_u8());
157 bytes.write_u8(self.msg_flags.as_u8());
158 bytes.write_u32(self.id);
159 bytes.write_u16(self.action.as_u16());
160 }
161
162 pub fn set_action(&mut self, action: A) {
163 self.action = MaybeAction::Action(action);
164 }
165}
166
167impl<A> PacketHeader for Header<A>
168where A: Action {
169 const LEN: u32 = 4 + 1 + 1 + 4 + 2;
170
171 fn from_bytes(mut bytes: Bytes) -> packet::Result<Self> {
172 let me = Self {
173 body_len: bytes.read_u32(),
174 flags: Flags::from_u8(bytes.read_u8())?,
175 msg_flags: MessageFlags::from_u8(bytes.read_u8()),
176 id: bytes.read_u32(),
177 action: {
178 let action_num = bytes.read_u16();
179 if action_num == 0 {
180 MaybeAction::None
181 } else if let Some(action) = A::from_u16(action_num) {
182 MaybeAction::Action(action)
188 } else {
189 MaybeAction::Unknown(action_num)
190 }
191 }
192 };
193
194 if let Some(max) = A::max_body_size(&me) {
195 if me.body_len > max {
196 return Err(PacketError::BodyLimitReached(max))
197 }
198 }
199
200 Ok(me)
201 }
202
203 fn body_len(&self) -> u32 {
204 self.body_len
205 }
206
207 fn flags(&self) -> &Flags {
208 &self.flags
209 }
210
211 fn set_flags(&mut self, flags: Flags) {
212 self.flags = flags;
213 }
214
215 fn id(&self) -> u32 {
216 self.id
217 }
218
219 fn set_id(&mut self, id: u32) {
220 self.id = id;
221 }
222}
223
/// Message-level flag byte stored in the header.
///
/// Only bit 0 (success) has a meaning here; other bits read from the wire
/// are kept untouched.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct MessageFlags {
    inner: u8,
}

impl MessageFlags {
    const SUCCESS_BIT: u8 = 0b0000_0001;

    /// Creates a flag byte that has only the success bit set or cleared.
    pub fn new(success: bool) -> Self {
        let inner = if success { Self::SUCCESS_BIT } else { 0 };
        Self { inner }
    }

    /// Wraps a raw flag byte as-is.
    pub fn from_u8(raw: u8) -> Self {
        Self { inner: raw }
    }

    /// Returns `true` when the success bit is set.
    pub fn is_success(&self) -> bool {
        (self.inner & Self::SUCCESS_BIT) == Self::SUCCESS_BIT
    }

    /// Sets or clears the success bit, leaving every other bit unchanged.
    pub fn set_success(&mut self, success: bool) {
        self.inner = if success {
            self.inner | Self::SUCCESS_BIT
        } else {
            self.inner & !Self::SUCCESS_BIT
        };
    }

    /// Returns the raw flag byte.
    pub fn as_u8(&self) -> u8 {
        self.inner
    }
}
265
266#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
267enum MaybeAction<A> {
268 Action(A),
269 Unknown(u16),
271 None
273}
274
275impl<A> MaybeAction<A>
276where A: Action {
277 pub fn as_u16(&self) -> u16 {
278 match self {
279 Self::Action(a) => a.as_u16(),
280 Self::Unknown(u) => *u,
281 Self::None => 0
282 }
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    use stream::packet::PlainBytes;

    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum SomeAction {
        One,
        Two,
    }

    impl Action for SomeAction {
        fn from_u16(num: u16) -> Option<Self> {
            match num {
                1 => Some(Self::One),
                2 => Some(Self::Two),
                _ => None,
            }
        }

        fn as_u16(&self) -> u16 {
            match self {
                Self::One => 1,
                Self::Two => 2,
            }
        }
    }

    /// Action type whose message bodies may be at most one byte long.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    struct TinyAction;

    impl Action for TinyAction {
        fn from_u16(num: u16) -> Option<Self> {
            if num == 1 {
                Some(Self)
            } else {
                None
            }
        }

        fn as_u16(&self) -> u16 {
            1
        }

        fn max_body_size(_header: &Header<Self>) -> Option<u32> {
            Some(1)
        }
    }

    /// Encodes a copy of `msg` and decodes it again through the packet API.
    fn round_trip<A: Action>(msg: &Message<A, PlainBytes>) -> Message<A, PlainBytes> {
        let bytes = msg.clone().into_bytes();
        let header = Header::<A>::from_bytes(bytes.header()).unwrap();
        Message::from_bytes_and_header(bytes, header).unwrap()
    }

    #[test]
    fn msg_from_to_bytes() {
        let mut msg = Message::<_, PlainBytes>::new();
        msg.header_mut().set_action(SomeAction::Two);
        msg.body_mut().write_u16(u16::MAX);

        let n_msg = round_trip(&msg);
        assert_eq!(msg.action(), n_msg.action());
        assert_eq!(msg.header().flags(), n_msg.header().flags());

        assert_eq!(n_msg.body().read_u16(), u16::MAX);
    }

    #[test]
    fn header_fields_round_trip() {
        let mut msg = Message::<SomeAction, PlainBytes>::new();
        msg.header_mut().set_action(SomeAction::One);
        msg.header_mut().set_id(42);
        msg.set_success(false);
        msg.body_mut().write_u16(7);

        let bytes = msg.clone().into_bytes();
        let header = Header::<SomeAction>::from_bytes(bytes.header()).unwrap();
        // `into_bytes` must refresh the declared body length.
        assert_eq!(header.body_len(), 2);
        assert_eq!(header.id(), 42);

        let n_msg = Message::from_bytes_and_header(bytes, header).unwrap();
        assert!(!n_msg.is_success());
        assert_eq!(n_msg.action(), Some(&SomeAction::One));
    }

    #[test]
    fn unknown_action_is_preserved() {
        let mut msg = Message::<SomeAction, PlainBytes>::new();
        msg.header_mut().action = MaybeAction::Unknown(99);

        let n_msg = round_trip(&msg);
        assert_eq!(n_msg.action(), None);
        assert_eq!(n_msg.header().action, MaybeAction::Unknown(99));
    }

    #[test]
    fn body_limit_is_enforced() {
        let mut msg = Message::<TinyAction, PlainBytes>::new();
        msg.header_mut().set_action(TinyAction);
        msg.body_mut().write_u16(1);

        let bytes = msg.into_bytes();
        let res = Header::<TinyAction>::from_bytes(bytes.header());
        assert!(matches!(res, Err(PacketError::BodyLimitReached(1))));
    }
}