1use crate::error::MessageError;
2
3use std::fmt::Debug;
4use std::hash::Hash;
5
6pub use lafere::packet::PacketBytes;
7use lafere::packet::{
8 self, BodyBytes, BodyBytesMut, Flags, Packet, PacketError, PacketHeader,
9};
10
11use bytes::{Bytes, BytesMut, BytesRead, BytesWrite};
12
13pub trait Action: Debug + Copy + Eq + Hash {
16 fn from_u16(num: u16) -> Option<Self>;
20
21 fn as_u16(&self) -> u16;
22
23 fn max_body_size(_header: &Header<Self>) -> Option<u32> {
24 None
25 }
26}
27
28pub trait IntoMessage<A, B> {
29 fn into_message(self) -> Result<Message<A, B>, MessageError>;
30}
31
32pub trait FromMessage<A, B>: Sized {
33 fn from_message(msg: Message<A, B>) -> Result<Self, MessageError>;
34}
35
36#[derive(Debug, Clone, PartialEq, Eq)]
37pub struct Message<A, B> {
38 header: Header<A>,
39 bytes: B,
40}
41
42impl<A, B> Message<A, B>
43where
44 A: Action,
45 B: PacketBytes,
46{
47 pub fn new() -> Self {
48 Self::empty()
49 }
50}
51
52impl<A, B> Message<A, B>
53where
54 B: PacketBytes,
55{
56 pub fn set_success(&mut self, success: bool) {
57 self.header.msg_flags.set_success(success);
58 }
59
60 pub fn is_success(&self) -> bool {
61 self.header.msg_flags.is_success()
62 }
63
64 pub fn body(&self) -> BodyBytes<'_> {
65 self.bytes.body()
66 }
67
68 pub fn body_mut(&mut self) -> BodyBytesMut<'_> {
69 self.bytes.body_mut()
70 }
71}
72
73impl<A, B> Message<A, B> {
74 pub fn action(&self) -> Option<&A> {
75 match &self.header.action {
76 MaybeAction::Action(a) => Some(a),
77 _ => None,
78 }
79 }
80}
81
82impl<A, B> IntoMessage<A, B> for Message<A, B> {
83 fn into_message(self) -> Result<Self, MessageError> {
84 Ok(self)
85 }
86}
87
88impl<A, B> FromMessage<A, B> for Message<A, B> {
89 fn from_message(me: Self) -> Result<Self, MessageError> {
90 Ok(me)
91 }
92}
93
94impl<A, B> Packet<B> for Message<A, B>
95where
96 A: Action,
97 B: PacketBytes,
98{
99 type Header = Header<A>;
100
101 fn header(&self) -> &Self::Header {
102 &self.header
103 }
104
105 fn header_mut(&mut self) -> &mut Self::Header {
106 &mut self.header
107 }
108
109 fn empty() -> Self {
110 Self {
111 header: Self::Header::empty(),
112 bytes: B::new(Self::Header::LEN as usize),
113 }
114 }
115
116 fn from_bytes_and_header(
117 bytes: B,
118 header: Self::Header,
119 ) -> packet::Result<Self> {
120 Ok(Self { header, bytes })
122 }
123
124 fn into_bytes(mut self) -> B {
125 let body_len = self.bytes.body().len();
126 self.header.body_len = body_len as u32;
127 self.header.to_bytes(self.bytes.header_mut());
128 self.bytes
129 }
130}
131
132#[derive(Debug, Clone, PartialEq, Eq)]
133pub struct Header<A> {
134 body_len: u32,
135 flags: Flags,
136 msg_flags: MessageFlags,
137 id: u32,
138 action: MaybeAction<A>,
139}
140
141impl<A> Header<A>
142where
143 A: Action,
144{
145 pub fn empty() -> Self {
146 Self {
147 body_len: 0,
148 flags: Flags::empty(),
149 msg_flags: MessageFlags::new(true),
150 id: 0,
151 action: MaybeAction::None,
152 }
153 }
154
155 pub fn to_bytes(&self, mut bytes: BytesMut) {
156 bytes.write_u32(self.body_len);
157 bytes.write_u8(self.flags.as_u8());
158 bytes.write_u8(self.msg_flags.as_u8());
159 bytes.write_u32(self.id);
160 bytes.write_u16(self.action.as_u16());
161 }
162
163 pub fn set_action(&mut self, action: A) {
164 self.action = MaybeAction::Action(action);
165 }
166}
167
168impl<A> PacketHeader for Header<A>
169where
170 A: Action,
171{
172 const LEN: u32 = 4 + 1 + 1 + 4 + 2;
173
174 fn from_bytes(mut bytes: Bytes) -> packet::Result<Self> {
175 let me = Self {
176 body_len: bytes.read_u32(),
177 flags: Flags::from_u8(bytes.read_u8())?,
178 msg_flags: MessageFlags::from_u8(bytes.read_u8()),
179 id: bytes.read_u32(),
180 action: {
181 let action_num = bytes.read_u16();
182 if action_num == 0 {
183 MaybeAction::None
184 } else if let Some(action) = A::from_u16(action_num) {
185 MaybeAction::Action(action)
191 } else {
192 MaybeAction::Unknown(action_num)
193 }
194 },
195 };
196
197 if let Some(max) = A::max_body_size(&me) {
198 if me.body_len > max {
199 return Err(PacketError::BodyLimitReached(max));
200 }
201 }
202
203 Ok(me)
204 }
205
206 fn body_len(&self) -> u32 {
207 self.body_len
208 }
209
210 fn flags(&self) -> &Flags {
211 &self.flags
212 }
213
214 fn set_flags(&mut self, flags: Flags) {
215 self.flags = flags;
216 }
217
218 fn id(&self) -> u32 {
219 self.id
220 }
221
222 fn set_id(&mut self, id: u32) {
223 self.id = id;
224 }
225}
226
/// Message-level flag byte stored in the header.
///
/// Only the lowest bit (success) is interpreted; all other bits are
/// carried through unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct MessageFlags {
    inner: u8,
}

impl MessageFlags {
    /// Bit that marks a message as successful.
    const SUCCESS_BIT: u8 = 0b0000_0001;

    /// Creates flags with only the success bit initialised to `success`.
    pub fn new(success: bool) -> Self {
        let inner = if success { Self::SUCCESS_BIT } else { 0 };
        Self { inner }
    }

    /// Wraps a raw flag byte without validating it.
    pub fn from_u8(inner: u8) -> Self {
        Self { inner }
    }

    /// Returns `true` when the success bit is set.
    pub fn is_success(&self) -> bool {
        (self.inner & Self::SUCCESS_BIT) == Self::SUCCESS_BIT
    }

    /// Sets or clears the success bit, leaving every other bit untouched.
    pub fn set_success(&mut self, success: bool) {
        self.inner = match success {
            true => self.inner | Self::SUCCESS_BIT,
            false => self.inner & !Self::SUCCESS_BIT,
        };
    }

    /// Returns the raw flag byte.
    pub fn as_u8(&self) -> u8 {
        self.inner
    }
}
268
269#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
270enum MaybeAction<A> {
271 Action(A),
272 Unknown(u16),
274 None,
276}
277
278impl<A> MaybeAction<A>
279where
280 A: Action,
281{
282 pub fn as_u16(&self) -> u16 {
283 match self {
284 Self::Action(a) => a.as_u16(),
285 Self::Unknown(u) => *u,
286 Self::None => 0,
287 }
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    use lafere::packet::PlainBytes;

    /// Minimal action type used to exercise header encoding.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    enum SomeAction {
        One,
        Two,
    }

    impl Action for SomeAction {
        fn from_u16(num: u16) -> Option<Self> {
            if num == 1 {
                Some(Self::One)
            } else if num == 2 {
                Some(Self::Two)
            } else {
                None
            }
        }

        fn as_u16(&self) -> u16 {
            match *self {
                Self::One => 1,
                Self::Two => 2,
            }
        }
    }

    /// Serializes a message and parses it back, then checks that the
    /// action, packet flags and body survive the round trip.
    #[test]
    fn msg_from_to_bytes() {
        let mut original = Message::<SomeAction, PlainBytes>::new();
        original.header_mut().set_action(SomeAction::Two);
        original.body_mut().write_u16(u16::MAX);

        // `original` is still needed for the comparisons below.
        let raw = original.clone().into_bytes();
        let parsed_header =
            Header::<SomeAction>::from_bytes(raw.header()).unwrap();
        let decoded =
            Message::from_bytes_and_header(raw, parsed_header).unwrap();

        assert_eq!(original.action(), decoded.action());
        assert_eq!(original.header().flags(), decoded.header().flags());

        let mut body = decoded.body();
        assert_eq!(body.read_u16(), u16::MAX);
    }
}