1use core::mem::size_of;
7
8use crate::error::{ProtocolError, Result};
9use crate::magic::{HEADER_SIZE, MAX_PAYLOAD_SIZE, MEMLINK_MAGIC, PROTOCOL_VERSION};
10use crate::types::{MethodHash, ModuleId, MessageType, RequestId};
11
12#[repr(C, packed)]
13#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
14pub struct MessageHeader {
15 magic: u32,
16 version: u8,
17 msg_type: u8,
18 features: u16,
19 request_id: u64,
20 module_id: u64,
21 method_hash: u32,
22 payload_len: u32,
23}
24
25const _: () = {
26 assert!(
27 size_of::<MessageHeader>() == HEADER_SIZE,
28 "MessageHeader must be exactly 32 bytes"
29 );
30};
31
32impl MessageHeader {
33 pub fn new(
34 msg_type: MessageType,
35 request_id: RequestId,
36 module_id: ModuleId,
37 method_hash: MethodHash,
38 payload_len: u32,
39 ) -> Self {
40 Self {
41 magic: MEMLINK_MAGIC,
42 version: PROTOCOL_VERSION,
43 msg_type: msg_type.as_u8(),
44 features: 0,
45 request_id,
46 module_id,
47 method_hash,
48 payload_len,
49 }
50 }
51
52 pub fn with_features(
53 msg_type: MessageType,
54 features: u16,
55 request_id: RequestId,
56 module_id: ModuleId,
57 method_hash: MethodHash,
58 payload_len: u32,
59 ) -> Self {
60 Self {
61 magic: MEMLINK_MAGIC,
62 version: PROTOCOL_VERSION,
63 msg_type: msg_type.as_u8(),
64 features,
65 request_id,
66 module_id,
67 method_hash,
68 payload_len,
69 }
70 }
71
72 pub fn magic(&self) -> u32 {
73 self.magic
74 }
75
76 pub fn version(&self) -> u8 {
77 self.version
78 }
79
80 pub fn msg_type(&self) -> u8 {
81 self.msg_type
82 }
83
84 pub fn message_type(&self) -> Option<MessageType> {
85 MessageType::from_u8(self.msg_type)
86 }
87
88 pub fn features(&self) -> u16 {
89 self.features
90 }
91
92 pub fn has_feature(&self, feature: u16) -> bool {
93 self.features & feature != 0
94 }
95
96 pub fn request_id(&self) -> RequestId {
97 self.request_id
98 }
99
100 pub fn module_id(&self) -> ModuleId {
101 self.module_id
102 }
103
104 pub fn method_hash(&self) -> MethodHash {
105 self.method_hash
106 }
107
108 pub fn payload_len(&self) -> u32 {
109 self.payload_len
110 }
111
112 pub fn validate(&self) -> Result<()> {
113 if self.magic != MEMLINK_MAGIC {
114 return Err(ProtocolError::InvalidMagic(self.magic));
115 }
116
117 if self.version != PROTOCOL_VERSION {
118 return Err(ProtocolError::UnsupportedVersion(self.version));
119 }
120
121 if MessageType::from_u8(self.msg_type).is_none() {
122 return Err(ProtocolError::InvalidMessageType(self.msg_type));
123 }
124
125 if self.payload_len as usize > MAX_PAYLOAD_SIZE {
126 return Err(ProtocolError::PayloadTooLarge(
127 self.payload_len as usize,
128 MAX_PAYLOAD_SIZE,
129 ));
130 }
131
132 Ok(())
133 }
134
135 pub fn as_bytes(&self) -> [u8; HEADER_SIZE] {
136 let mut bytes = [0u8; HEADER_SIZE];
137
138 bytes[0..4].copy_from_slice(&self.magic.to_be_bytes());
139 bytes[4] = self.version;
140 bytes[5] = self.msg_type;
141 bytes[6..8].copy_from_slice(&self.features.to_be_bytes());
142 bytes[8..16].copy_from_slice(&self.request_id.to_be_bytes());
143 bytes[16..24].copy_from_slice(&self.module_id.to_be_bytes());
144 bytes[24..28].copy_from_slice(&self.method_hash.to_be_bytes());
145 bytes[28..32].copy_from_slice(&self.payload_len.to_be_bytes());
146
147 bytes
148 }
149
150 pub fn from_bytes(bytes: &[u8; HEADER_SIZE]) -> Result<Self> {
151 let magic = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
152 let version = bytes[4];
153 let msg_type = bytes[5];
154 let features = u16::from_be_bytes([bytes[6], bytes[7]]);
155 let request_id =
156 u64::from_be_bytes([bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15]]);
157 let module_id =
158 u64::from_be_bytes([bytes[16], bytes[17], bytes[18], bytes[19], bytes[20], bytes[21], bytes[22], bytes[23]]);
159 let method_hash = u32::from_be_bytes([bytes[24], bytes[25], bytes[26], bytes[27]]);
160 let payload_len = u32::from_be_bytes([bytes[28], bytes[29], bytes[30], bytes[31]]);
161
162 let header = Self {
163 magic,
164 version,
165 msg_type,
166 features,
167 request_id,
168 module_id,
169 method_hash,
170 payload_len,
171 };
172
173 header.validate()?;
174
175 Ok(header)
176 }
177
178 pub fn request(
179 request_id: RequestId,
180 module_id: ModuleId,
181 method_hash: MethodHash,
182 payload_len: u32,
183 ) -> Self {
184 Self::new(
185 MessageType::Request,
186 request_id,
187 module_id,
188 method_hash,
189 payload_len,
190 )
191 }
192
193 pub fn response(
194 request_id: RequestId,
195 module_id: ModuleId,
196 method_hash: MethodHash,
197 payload_len: u32,
198 ) -> Self {
199 Self::new(
200 MessageType::Response,
201 request_id,
202 module_id,
203 method_hash,
204 payload_len,
205 )
206 }
207
208 pub fn notify(
209 module_id: ModuleId,
210 method_hash: MethodHash,
211 payload_len: u32,
212 features: u16,
213 ) -> Self {
214 Self::with_features(
215 MessageType::Event,
216 features,
217 0,
218 module_id,
219 method_hash,
220 payload_len,
221 )
222 }
223
224 pub fn heartbeat() -> Self {
225 Self::new(
226 MessageType::HealthCheck,
227 0,
228 0,
229 0,
230 0,
231 )
232 }
233}