1use std::io::{Read, Write};
23
24use super::compiled::{CompiledQuery, CompiledQueryBuffer, FORMAT_VERSION, MAGIC};
25
/// Size in bytes of the fixed header that precedes the query buffer in the
/// serialized form.
pub const HEADER_SIZE: usize = 64;
28
/// Errors reported while serializing or deserializing a compiled query.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly (for example with `assert_eq!`) instead of destructuring them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SerializeError {
    /// The first four header bytes were not the expected magic; carries the
    /// bytes that were found.
    InvalidMagic([u8; 4]),
    /// The header's format version differs from the supported one.
    VersionMismatch { expected: u32, found: u32 },
    /// The checksum stored in the header (`expected`) differs from the one
    /// recomputed over the data that was read (`found`).
    ChecksumMismatch { expected: u32, found: u32 },
    /// An I/O failure, stored as the error's display text.
    Io(String),
    /// The input is too short to contain a header.
    // NOTE(review): not constructed by the code visible in this file.
    HeaderTooShort,
    /// A buffer did not satisfy an alignment requirement.
    // NOTE(review): not constructed by the code visible in this file.
    AlignmentError,
}
45
46impl std::fmt::Display for SerializeError {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 match self {
49 SerializeError::InvalidMagic(m) => {
50 write!(f, "invalid magic: {:?}", m)
51 }
52 SerializeError::VersionMismatch { expected, found } => {
53 write!(
54 f,
55 "version mismatch: expected {}, found {}",
56 expected, found
57 )
58 }
59 SerializeError::ChecksumMismatch { expected, found } => {
60 write!(
61 f,
62 "checksum mismatch: expected {:08x}, found {:08x}",
63 expected, found
64 )
65 }
66 SerializeError::Io(msg) => write!(f, "io error: {}", msg),
67 SerializeError::HeaderTooShort => write!(f, "header too short"),
68 SerializeError::AlignmentError => write!(f, "buffer alignment error"),
69 }
70 }
71}
72
73impl std::error::Error for SerializeError {}
74
75impl From<std::io::Error> for SerializeError {
76 fn from(e: std::io::Error) -> Self {
77 SerializeError::Io(e.to_string())
78 }
79}
80
81pub type SerializeResult<T> = Result<T, SerializeError>;
83
84fn crc32(data: &[u8]) -> u32 {
86 const CRC32_TABLE: [u32; 256] = generate_crc32_table();
88
89 let mut crc: u32 = 0xFFFFFFFF;
90 for &byte in data {
91 let index = ((crc ^ byte as u32) & 0xFF) as usize;
92 crc = CRC32_TABLE[index] ^ (crc >> 8);
93 }
94 !crc
95}
96
/// Builds the 256-entry lookup table used by `crc32`.
///
/// Entry `i` is the result of pushing the byte value `i` through eight
/// shift-right steps, XOR-ing in the polynomial `0xEDB88320` whenever the bit
/// shifted out is set. Written with `while` loops so it stays a `const fn`.
const fn generate_crc32_table() -> [u32; 256] {
    const POLYNOMIAL: u32 = 0xEDB8_8320;

    let mut table = [0u32; 256];
    let mut index = 0usize;
    while index < 256 {
        let mut value = index as u32;
        let mut bit = 0;
        while bit < 8 {
            // Remember the bit about to be shifted out, then shift.
            let low_bit_set = value & 1 == 1;
            value >>= 1;
            if low_bit_set {
                value ^= POLYNOMIAL;
            }
            bit += 1;
        }
        table[index] = value;
        index += 1;
    }
    table
}
117
/// In-memory form of the fixed-size file header.
///
/// Field order mirrors the byte order used by `Header::to_bytes`. The fields
/// add up to 64 bytes (4 magic bytes, twelve `u32`s, six `u16`s).
#[repr(C)]
struct Header {
    /// Format magic bytes (wire bytes 0..4).
    magic: [u8; 4],
    /// Format version (wire bytes 4..8).
    version: u32,
    /// CRC-32 over header bytes 12..64 followed by the buffer (wire bytes 8..12).
    checksum: u32,
    /// Length in bytes of the buffer that follows the header.
    buffer_len: u32,

    // Section offsets. Presumably byte offsets into the buffer — confirm
    // against `CompiledQuery`.
    successors_offset: u32,
    effects_offset: u32,
    negated_fields_offset: u32,
    string_refs_offset: u32,
    string_bytes_offset: u32,
    type_defs_offset: u32,
    type_members_offset: u32,
    entrypoints_offset: u32,
    trivia_kinds_offset: u32,

    // Element counts for the sections whose counts are stored explicitly.
    negated_field_count: u16,
    string_ref_count: u16,
    type_def_count: u16,
    type_member_count: u16,
    entrypoint_count: u16,
    trivia_kind_count: u16,
}
146
147const _: () = assert!(std::mem::size_of::<Header>() == HEADER_SIZE);
148
149impl Header {
150 fn to_bytes(&self) -> [u8; HEADER_SIZE] {
151 let mut bytes = [0u8; HEADER_SIZE];
152 bytes[0..4].copy_from_slice(&self.magic);
153 bytes[4..8].copy_from_slice(&self.version.to_le_bytes());
154 bytes[8..12].copy_from_slice(&self.checksum.to_le_bytes());
155 bytes[12..16].copy_from_slice(&self.buffer_len.to_le_bytes());
156 bytes[16..20].copy_from_slice(&self.successors_offset.to_le_bytes());
157 bytes[20..24].copy_from_slice(&self.effects_offset.to_le_bytes());
158 bytes[24..28].copy_from_slice(&self.negated_fields_offset.to_le_bytes());
159 bytes[28..32].copy_from_slice(&self.string_refs_offset.to_le_bytes());
160 bytes[32..36].copy_from_slice(&self.string_bytes_offset.to_le_bytes());
161 bytes[36..40].copy_from_slice(&self.type_defs_offset.to_le_bytes());
162 bytes[40..44].copy_from_slice(&self.type_members_offset.to_le_bytes());
163 bytes[44..48].copy_from_slice(&self.entrypoints_offset.to_le_bytes());
164 bytes[48..52].copy_from_slice(&self.trivia_kinds_offset.to_le_bytes());
165 bytes[52..54].copy_from_slice(&self.negated_field_count.to_le_bytes());
167 bytes[54..56].copy_from_slice(&self.string_ref_count.to_le_bytes());
168 bytes[56..58].copy_from_slice(&self.type_def_count.to_le_bytes());
169 bytes[58..60].copy_from_slice(&self.type_member_count.to_le_bytes());
170 bytes[60..62].copy_from_slice(&self.entrypoint_count.to_le_bytes());
171 bytes[62..64].copy_from_slice(&self.trivia_kind_count.to_le_bytes());
172 bytes
173 }
174
175 fn from_bytes(bytes: &[u8; HEADER_SIZE]) -> Self {
176 Self {
177 magic: bytes[0..4].try_into().unwrap(),
178 version: u32::from_le_bytes(bytes[4..8].try_into().unwrap()),
179 checksum: u32::from_le_bytes(bytes[8..12].try_into().unwrap()),
180 buffer_len: u32::from_le_bytes(bytes[12..16].try_into().unwrap()),
181 successors_offset: u32::from_le_bytes(bytes[16..20].try_into().unwrap()),
182 effects_offset: u32::from_le_bytes(bytes[20..24].try_into().unwrap()),
183 negated_fields_offset: u32::from_le_bytes(bytes[24..28].try_into().unwrap()),
184 string_refs_offset: u32::from_le_bytes(bytes[28..32].try_into().unwrap()),
185 string_bytes_offset: u32::from_le_bytes(bytes[32..36].try_into().unwrap()),
186 type_defs_offset: u32::from_le_bytes(bytes[36..40].try_into().unwrap()),
187 type_members_offset: u32::from_le_bytes(bytes[40..44].try_into().unwrap()),
188 entrypoints_offset: u32::from_le_bytes(bytes[44..48].try_into().unwrap()),
189 trivia_kinds_offset: u32::from_le_bytes(bytes[48..52].try_into().unwrap()),
190 negated_field_count: u16::from_le_bytes(bytes[52..54].try_into().unwrap()),
191 string_ref_count: u16::from_le_bytes(bytes[54..56].try_into().unwrap()),
192 type_def_count: u16::from_le_bytes(bytes[56..58].try_into().unwrap()),
193 type_member_count: u16::from_le_bytes(bytes[58..60].try_into().unwrap()),
194 entrypoint_count: u16::from_le_bytes(bytes[60..62].try_into().unwrap()),
195 trivia_kind_count: u16::from_le_bytes(bytes[62..64].try_into().unwrap()),
196 }
197 }
198}
199
200pub fn serialize<W: Write>(query: &CompiledQuery, mut writer: W) -> SerializeResult<()> {
202 let offsets = query.offsets();
203 let buffer = query.buffer();
204
205 let mut header = Header {
207 magic: MAGIC,
208 version: FORMAT_VERSION,
209 checksum: 0, buffer_len: buffer.len() as u32,
211 successors_offset: offsets.successors_offset,
212 effects_offset: offsets.effects_offset,
213 negated_fields_offset: offsets.negated_fields_offset,
214 string_refs_offset: offsets.string_refs_offset,
215 string_bytes_offset: offsets.string_bytes_offset,
216 type_defs_offset: offsets.type_defs_offset,
217 type_members_offset: offsets.type_members_offset,
218 entrypoints_offset: offsets.entrypoints_offset,
219 trivia_kinds_offset: offsets.trivia_kinds_offset,
220 negated_field_count: query.negated_fields().len() as u16,
221 string_ref_count: query.string_refs().len() as u16,
222 type_def_count: query.type_defs().len() as u16,
223 type_member_count: query.type_members().len() as u16,
224 entrypoint_count: query.entrypoint_count(),
225 trivia_kind_count: query.trivia_kinds().len() as u16,
226 };
227
228 let header_bytes = header.to_bytes();
230 let mut checksum_data = Vec::with_capacity(52 + buffer.len());
231 checksum_data.extend_from_slice(&header_bytes[12..]);
232 checksum_data.extend_from_slice(buffer.as_slice());
233 header.checksum = crc32(&checksum_data);
234
235 writer.write_all(&header.to_bytes())?;
237 writer.write_all(buffer.as_slice())?;
238
239 Ok(())
240}
241
242pub fn to_bytes(query: &CompiledQuery) -> SerializeResult<Vec<u8>> {
244 let mut bytes = Vec::with_capacity(HEADER_SIZE + query.buffer().len());
245 serialize(query, &mut bytes)?;
246 Ok(bytes)
247}
248
249pub fn deserialize<R: Read>(mut reader: R) -> SerializeResult<CompiledQuery> {
251 let mut header_bytes = [0u8; HEADER_SIZE];
253 reader.read_exact(&mut header_bytes)?;
254
255 let header = Header::from_bytes(&header_bytes);
256
257 if header.magic != MAGIC {
259 return Err(SerializeError::InvalidMagic(header.magic));
260 }
261
262 if header.version != FORMAT_VERSION {
264 return Err(SerializeError::VersionMismatch {
265 expected: FORMAT_VERSION,
266 found: header.version,
267 });
268 }
269
270 let buffer_len = header.buffer_len as usize;
272 let mut buffer = CompiledQueryBuffer::allocate(buffer_len);
273 reader.read_exact(buffer.as_mut_slice())?;
274
275 let mut checksum_data = Vec::with_capacity(52 + buffer_len);
277 checksum_data.extend_from_slice(&header_bytes[12..]);
278 checksum_data.extend_from_slice(buffer.as_slice());
279 let computed_checksum = crc32(&checksum_data);
280
281 if header.checksum != computed_checksum {
282 return Err(SerializeError::ChecksumMismatch {
283 expected: header.checksum,
284 found: computed_checksum,
285 });
286 }
287
288 let transition_count = header.successors_offset / 64;
290 let successor_count = compute_count_from_offsets(
291 header.successors_offset,
292 header.effects_offset,
293 4, );
295 let effect_count = compute_count_from_offsets(
296 header.effects_offset,
297 header.negated_fields_offset,
298 4, );
300
301 let negated_field_count = header.negated_field_count;
303 let string_ref_count = header.string_ref_count;
304 let type_def_count = header.type_def_count;
305 let type_member_count = header.type_member_count;
306 let entrypoint_count = header.entrypoint_count;
307 let trivia_kind_count = header.trivia_kind_count;
308
309 Ok(CompiledQuery::new(
310 buffer,
311 header.successors_offset,
312 header.effects_offset,
313 header.negated_fields_offset,
314 header.string_refs_offset,
315 header.string_bytes_offset,
316 header.type_defs_offset,
317 header.type_members_offset,
318 header.entrypoints_offset,
319 header.trivia_kinds_offset,
320 transition_count,
321 successor_count,
322 effect_count,
323 negated_field_count,
324 string_ref_count,
325 type_def_count,
326 type_member_count,
327 entrypoint_count,
328 trivia_kind_count,
329 ))
330}
331
332pub fn from_bytes(bytes: &[u8]) -> SerializeResult<CompiledQuery> {
334 deserialize(std::io::Cursor::new(bytes))
335}
336
/// Returns how many whole `element_size`-byte elements fit between the
/// offsets `start` and `end`.
///
/// Yields `0` when `end` is not past `start`, and also when `element_size` is
/// zero (instead of panicking on a division by zero).
fn compute_count_from_offsets(start: u32, end: u32, element_size: u32) -> u32 {
    end.checked_sub(start)
        .and_then(|span| span.checked_div(element_size))
        .unwrap_or(0)
}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Header with a distinct value in every field, so a swapped or dropped
    /// field shows up as a mismatch.
    fn sample_header() -> Header {
        Header {
            magic: MAGIC,
            version: FORMAT_VERSION,
            checksum: 0x12345678,
            buffer_len: 1024,
            successors_offset: 64,
            effects_offset: 128,
            negated_fields_offset: 256,
            string_refs_offset: 300,
            string_bytes_offset: 400,
            type_defs_offset: 500,
            type_members_offset: 600,
            entrypoints_offset: 700,
            trivia_kinds_offset: 800,
            negated_field_count: 5,
            string_ref_count: 8,
            type_def_count: 3,
            type_member_count: 12,
            entrypoint_count: 2,
            trivia_kind_count: 1,
        }
    }

    #[test]
    fn crc32_known_value() {
        // Standard check string for CRC-32.
        assert_eq!(crc32(b"123456789"), 0xCBF43926);
    }

    #[test]
    fn crc32_empty_input() {
        assert_eq!(crc32(b""), 0);
    }

    #[test]
    fn count_from_offsets() {
        assert_eq!(compute_count_from_offsets(64, 128, 4), 16);
        assert_eq!(compute_count_from_offsets(128, 64, 4), 0);
        assert_eq!(compute_count_from_offsets(10, 10, 4), 0);
    }

    #[test]
    fn header_roundtrip() {
        let header = sample_header();

        let bytes = header.to_bytes();
        let parsed = Header::from_bytes(&bytes);

        // Every field must survive the round trip, not just a sample of them.
        assert_eq!(parsed.magic, header.magic);
        assert_eq!(parsed.version, header.version);
        assert_eq!(parsed.checksum, header.checksum);
        assert_eq!(parsed.buffer_len, header.buffer_len);
        assert_eq!(parsed.successors_offset, header.successors_offset);
        assert_eq!(parsed.effects_offset, header.effects_offset);
        assert_eq!(parsed.negated_fields_offset, header.negated_fields_offset);
        assert_eq!(parsed.string_refs_offset, header.string_refs_offset);
        assert_eq!(parsed.string_bytes_offset, header.string_bytes_offset);
        assert_eq!(parsed.type_defs_offset, header.type_defs_offset);
        assert_eq!(parsed.type_members_offset, header.type_members_offset);
        assert_eq!(parsed.entrypoints_offset, header.entrypoints_offset);
        assert_eq!(parsed.trivia_kinds_offset, header.trivia_kinds_offset);
        assert_eq!(parsed.negated_field_count, header.negated_field_count);
        assert_eq!(parsed.string_ref_count, header.string_ref_count);
        assert_eq!(parsed.type_def_count, header.type_def_count);
        assert_eq!(parsed.type_member_count, header.type_member_count);
        assert_eq!(parsed.entrypoint_count, header.entrypoint_count);
        assert_eq!(parsed.trivia_kind_count, header.trivia_kind_count);

        // Re-encoding the parsed header must reproduce the same bytes.
        assert_eq!(&parsed.to_bytes()[..], &bytes[..]);
    }

    #[test]
    fn header_wire_layout() {
        let bytes = sample_header().to_bytes();

        assert_eq!(&bytes[0..4], &MAGIC[..]);
        assert_eq!(&bytes[4..8], &FORMAT_VERSION.to_le_bytes()[..]);
        assert_eq!(&bytes[8..12], &0x12345678u32.to_le_bytes()[..]);
        assert_eq!(&bytes[12..16], &1024u32.to_le_bytes()[..]);
        assert_eq!(&bytes[48..52], &800u32.to_le_bytes()[..]);
        assert_eq!(&bytes[52..54], &5u16.to_le_bytes()[..]);
        assert_eq!(&bytes[62..64], &1u16.to_le_bytes()[..]);
    }

    #[test]
    fn truncated_header_rejected() {
        // One byte short of a full header: the first read fails.
        let data = [0u8; HEADER_SIZE - 1];

        let result = from_bytes(&data);
        assert!(matches!(result, Err(SerializeError::Io(_))));
    }

    #[test]
    fn invalid_magic_rejected() {
        let mut data = vec![0u8; HEADER_SIZE + 64];
        data[0..4].copy_from_slice(b"NOTM");

        let result = from_bytes(&data);
        assert!(matches!(
            result,
            Err(SerializeError::InvalidMagic(found)) if found == *b"NOTM"
        ));
    }

    #[test]
    fn version_mismatch_rejected() {
        let mut data = vec![0u8; HEADER_SIZE + 64];
        data[0..4].copy_from_slice(&MAGIC);
        data[4..8].copy_from_slice(&999u32.to_le_bytes());

        let result = from_bytes(&data);
        assert!(matches!(
            result,
            Err(SerializeError::VersionMismatch { expected, found })
                if expected == FORMAT_VERSION && found == 999
        ));
    }
}