1#![allow(clippy::cast_possible_truncation)]
7
8use crate::protocol::{EofPacket, ErrPacket, OkPacket, PacketHeader};
9
/// Forward-only cursor over a borrowed byte buffer holding packet data.
///
/// The reader never copies the underlying buffer. Slices it hands out borrow from the
/// original data for lifetime `'a`. Cloning is cheap (a slice reference plus an offset)
/// and gives callers an independent checkpoint they can fall back to after a failed
/// speculative read.
#[derive(Debug, Clone)]
pub struct PacketReader<'a> {
    /// The complete buffer being parsed.
    data: &'a [u8],
    /// Offset of the next unread byte; kept `<= data.len()` by the reading methods.
    pos: usize,
}
16
17impl<'a> PacketReader<'a> {
18 pub fn new(data: &'a [u8]) -> Self {
20 Self { data, pos: 0 }
21 }
22
23 pub fn remaining(&self) -> usize {
25 self.data.len().saturating_sub(self.pos)
26 }
27
28 pub fn is_empty(&self) -> bool {
30 self.pos >= self.data.len()
31 }
32
33 pub fn peek(&self) -> Option<u8> {
35 self.data.get(self.pos).copied()
36 }
37
38 pub fn read_u8(&mut self) -> Option<u8> {
40 let byte = self.data.get(self.pos)?;
41 self.pos += 1;
42 Some(*byte)
43 }
44
45 pub fn read_u16_le(&mut self) -> Option<u16> {
47 if self.remaining() < 2 {
48 return None;
49 }
50 let value = u16::from_le_bytes([self.data[self.pos], self.data[self.pos + 1]]);
51 self.pos += 2;
52 Some(value)
53 }
54
55 pub fn read_u24_le(&mut self) -> Option<u32> {
57 if self.remaining() < 3 {
58 return None;
59 }
60 let value = u32::from(self.data[self.pos])
61 | (u32::from(self.data[self.pos + 1]) << 8)
62 | (u32::from(self.data[self.pos + 2]) << 16);
63 self.pos += 3;
64 Some(value)
65 }
66
67 pub fn read_u32_le(&mut self) -> Option<u32> {
69 if self.remaining() < 4 {
70 return None;
71 }
72 let value = u32::from_le_bytes([
73 self.data[self.pos],
74 self.data[self.pos + 1],
75 self.data[self.pos + 2],
76 self.data[self.pos + 3],
77 ]);
78 self.pos += 4;
79 Some(value)
80 }
81
82 pub fn read_u64_le(&mut self) -> Option<u64> {
84 if self.remaining() < 8 {
85 return None;
86 }
87 let value = u64::from_le_bytes([
88 self.data[self.pos],
89 self.data[self.pos + 1],
90 self.data[self.pos + 2],
91 self.data[self.pos + 3],
92 self.data[self.pos + 4],
93 self.data[self.pos + 5],
94 self.data[self.pos + 6],
95 self.data[self.pos + 7],
96 ]);
97 self.pos += 8;
98 Some(value)
99 }
100
101 pub fn read_lenenc_int(&mut self) -> Option<u64> {
110 let first = self.read_u8()?;
111 match first {
112 0x00..=0xFA => Some(u64::from(first)),
113 0xFC => self.read_u16_le().map(u64::from),
114 0xFD => self.read_u24_le().map(u64::from),
115 0xFE => self.read_u64_le(),
116 0xFB => None, 0xFF => None, }
119 }
120
121 pub fn read_lenenc_string(&mut self) -> Option<String> {
123 let len = self.read_lenenc_int()? as usize;
124 self.read_string(len)
125 }
126
127 pub fn read_lenenc_bytes(&mut self) -> Option<Vec<u8>> {
129 let len = self.read_lenenc_int()? as usize;
130 self.read_bytes(len).map(|b| b.to_vec())
131 }
132
133 pub fn read_null_string(&mut self) -> Option<String> {
135 let start = self.pos;
136 while self.pos < self.data.len() && self.data[self.pos] != 0 {
137 self.pos += 1;
138 }
139 let s = String::from_utf8_lossy(&self.data[start..self.pos]).into_owned();
140 if self.pos < self.data.len() {
142 self.pos += 1;
143 }
144 Some(s)
145 }
146
147 pub fn read_string(&mut self, len: usize) -> Option<String> {
149 let bytes = self.read_bytes(len)?;
150 Some(String::from_utf8_lossy(bytes).into_owned())
151 }
152
153 pub fn read_rest_string(&mut self) -> String {
155 let s = String::from_utf8_lossy(&self.data[self.pos..]).into_owned();
156 self.pos = self.data.len();
157 s
158 }
159
160 pub fn read_bytes(&mut self, len: usize) -> Option<&'a [u8]> {
162 if self.remaining() < len {
163 return None;
164 }
165 let bytes = &self.data[self.pos..self.pos + len];
166 self.pos += len;
167 Some(bytes)
168 }
169
170 pub fn read_rest(&mut self) -> &'a [u8] {
172 let rest = &self.data[self.pos..];
173 self.pos = self.data.len();
174 rest
175 }
176
177 pub fn skip(&mut self, n: usize) -> bool {
179 if self.remaining() >= n {
180 self.pos += n;
181 true
182 } else {
183 false
184 }
185 }
186
187 pub fn read_packet_header(&mut self) -> Option<PacketHeader> {
189 if self.remaining() < 4 {
190 return None;
191 }
192 let mut header_bytes = [0u8; 4];
193 header_bytes.copy_from_slice(&self.data[self.pos..self.pos + 4]);
194 self.pos += 4;
195 Some(PacketHeader::from_bytes(&header_bytes))
196 }
197
198 pub fn parse_ok_packet(&mut self) -> Option<OkPacket> {
208 if self.peek() == Some(0x00) {
210 self.skip(1);
211 }
212
213 let affected_rows = self.read_lenenc_int()?;
214 let last_insert_id = self.read_lenenc_int()?;
215 let status_flags = self.read_u16_le()?;
216 let warnings = self.read_u16_le()?;
217 let info = if self.remaining() > 0 {
218 self.read_rest_string()
219 } else {
220 String::new()
221 };
222
223 Some(OkPacket {
224 affected_rows,
225 last_insert_id,
226 status_flags,
227 warnings,
228 info,
229 })
230 }
231
232 pub fn parse_err_packet(&mut self) -> Option<ErrPacket> {
241 if self.peek() == Some(0xFF) {
243 self.skip(1);
244 }
245
246 let error_code = self.read_u16_le()?;
247
248 let sql_state = if self.peek() == Some(b'#') {
250 self.skip(1);
251 self.read_string(5)?
252 } else {
253 String::new()
254 };
255
256 let error_message = self.read_rest_string();
257
258 Some(ErrPacket {
259 error_code,
260 sql_state,
261 error_message,
262 })
263 }
264
265 pub fn parse_eof_packet(&mut self) -> Option<EofPacket> {
272 if self.peek() == Some(0xFE) {
274 self.skip(1);
275 }
276
277 let warnings = self.read_u16_le()?;
278 let status_flags = self.read_u16_le()?;
279
280 Some(EofPacket {
281 warnings,
282 status_flags,
283 })
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_read_u8() {
        let mut reader = PacketReader::new(&[0x42, 0x43]);
        assert_eq!(reader.read_u8(), Some(0x42));
        assert_eq!(reader.read_u8(), Some(0x43));
        assert_eq!(reader.read_u8(), None);
    }

    #[test]
    fn test_read_u16_le() {
        let mut reader = PacketReader::new(&[0x34, 0x12]);
        assert_eq!(reader.read_u16_le(), Some(0x1234));
    }

    #[test]
    fn test_read_u24_le() {
        let mut reader = PacketReader::new(&[0x56, 0x34, 0x12]);
        assert_eq!(reader.read_u24_le(), Some(0x0012_3456));
    }

    #[test]
    fn test_read_u32_le() {
        let mut reader = PacketReader::new(&[0x78, 0x56, 0x34, 0x12]);
        assert_eq!(reader.read_u32_le(), Some(0x1234_5678));
    }

    #[test]
    fn test_read_u64_le() {
        let mut reader = PacketReader::new(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
        assert_eq!(reader.read_u64_le(), Some(0x0807_0605_0403_0201));
    }

    /// Fixed-width reads that run out of data must leave the cursor where it was.
    #[test]
    fn test_truncated_reads_do_not_advance() {
        let mut reader = PacketReader::new(&[0x01, 0x02]);
        assert_eq!(reader.read_u24_le(), None);
        assert_eq!(reader.read_u32_le(), None);
        assert_eq!(reader.read_u64_le(), None);
        assert_eq!(reader.remaining(), 2);
        assert_eq!(reader.read_u16_le(), Some(0x0201));
        assert_eq!(reader.read_u16_le(), None);
    }

    #[test]
    fn test_read_lenenc_int() {
        // 1-byte encoding
        let mut reader = PacketReader::new(&[0x42]);
        assert_eq!(reader.read_lenenc_int(), Some(0x42));

        // 0xFC + 2 bytes
        let mut reader = PacketReader::new(&[0xFC, 0x34, 0x12]);
        assert_eq!(reader.read_lenenc_int(), Some(0x1234));

        // 0xFD + 3 bytes
        let mut reader = PacketReader::new(&[0xFD, 0x56, 0x34, 0x12]);
        assert_eq!(reader.read_lenenc_int(), Some(0x0012_3456));

        // 0xFE + 8 bytes
        let mut reader =
            PacketReader::new(&[0xFE, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
        assert_eq!(reader.read_lenenc_int(), Some(0x0807_0605_0403_0201));
    }

    /// 0xFB / 0xFF are not lengths; the marker byte is still consumed so parsing can
    /// continue with the next field.
    #[test]
    fn test_read_lenenc_int_markers_consume_prefix() {
        let mut reader = PacketReader::new(&[0xFB, 0x07]);
        assert_eq!(reader.read_lenenc_int(), None);
        assert_eq!(reader.read_u8(), Some(0x07));

        let mut reader = PacketReader::new(&[0xFF, 0x07]);
        assert_eq!(reader.read_lenenc_int(), None);
        assert_eq!(reader.read_u8(), Some(0x07));
    }

    #[test]
    fn test_read_null_string() {
        let mut reader = PacketReader::new(b"hello\0world\0");
        assert_eq!(reader.read_null_string(), Some("hello".to_string()));
        assert_eq!(reader.read_null_string(), Some("world".to_string()));
    }

    /// A missing terminator yields the rest of the buffer rather than `None`.
    #[test]
    fn test_read_null_string_unterminated() {
        let mut reader = PacketReader::new(b"abc");
        assert_eq!(reader.read_null_string(), Some("abc".to_string()));
        assert!(reader.is_empty());
        assert_eq!(reader.read_null_string(), Some(String::new()));
    }

    #[test]
    fn test_read_lenenc_string() {
        // length 5, then "hello"
        let mut reader = PacketReader::new(&[0x05, b'h', b'e', b'l', b'l', b'o']);
        assert_eq!(reader.read_lenenc_string(), Some("hello".to_string()));
    }

    #[test]
    fn test_read_lenenc_bytes() {
        let mut reader = PacketReader::new(&[0x03, 0x01, 0x02, 0x03]);
        assert_eq!(reader.read_lenenc_bytes(), Some(vec![0x01, 0x02, 0x03]));
        assert!(reader.is_empty());
    }

    /// A declared length longer than the payload fails after consuming only the prefix.
    #[test]
    fn test_read_lenenc_too_long() {
        let mut reader = PacketReader::new(&[0x05, b'h', b'i']);
        assert_eq!(reader.read_lenenc_string(), None);
        assert_eq!(reader.remaining(), 2);

        let mut reader =
            PacketReader::new(&[0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, b'x']);
        assert_eq!(reader.read_lenenc_bytes(), None);
        assert_eq!(reader.remaining(), 1);
    }

    #[test]
    fn test_skip_rest_and_header() {
        let mut reader = PacketReader::new(&[1, 2, 3, 4, 5, 6]);
        assert!(!reader.skip(7));
        assert!(reader.skip(1));
        assert!(reader.read_packet_header().is_some());
        assert_eq!(reader.remaining(), 1);
        assert!(reader.read_packet_header().is_none());
        assert_eq!(reader.read_rest(), &[6u8][..]);
        assert_eq!(reader.read_rest_string(), "");
        assert_eq!(reader.peek(), None);
    }

    #[test]
    fn test_parse_ok_packet() {
        // header, affected_rows=1, last_insert_id=42, status=2, warnings=0
        let data = [0x00, 0x01, 0x2A, 0x02, 0x00, 0x00, 0x00];
        let mut reader = PacketReader::new(&data);
        let ok = reader.parse_ok_packet().unwrap();
        assert_eq!(ok.affected_rows, 1);
        assert_eq!(ok.last_insert_id, 42);
        assert_eq!(ok.status_flags, 2);
        assert_eq!(ok.warnings, 0);
    }

    #[test]
    fn test_parse_ok_packet_with_info() {
        let data = [0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0x00, b'h', b'i'];
        let ok = PacketReader::new(&data).parse_ok_packet().unwrap();
        assert_eq!(ok.affected_rows, 0);
        assert_eq!(ok.warnings, 1);
        assert_eq!(ok.info, "hi");
    }

    #[test]
    fn test_parse_err_packet() {
        // header, error_code=1045, '#', sql_state, message
        let mut data = vec![0xFF, 0x15, 0x04, b'#'];
        data.extend_from_slice(b"28000");
        data.extend_from_slice(b"Access denied");
        let mut reader = PacketReader::new(&data);
        let err = reader.parse_err_packet().unwrap();
        assert_eq!(err.error_code, 1045);
        assert_eq!(err.sql_state, "28000");
        assert_eq!(err.error_message, "Access denied");
    }

    #[test]
    fn test_parse_err_packet_without_sql_state() {
        let err = PacketReader::new(&[0xFF, 0x15, 0x04, b'x'])
            .parse_err_packet()
            .unwrap();
        assert_eq!(err.error_code, 1045);
        assert_eq!(err.sql_state, "");
        assert_eq!(err.error_message, "x");
    }

    #[test]
    fn test_parse_eof_packet() {
        // header, warnings=0, status=2
        let data = [0xFE, 0x00, 0x00, 0x02, 0x00];
        let mut reader = PacketReader::new(&data);
        let eof = reader.parse_eof_packet().unwrap();
        assert_eq!(eof.warnings, 0);
        assert_eq!(eof.status_flags, 2);
    }

    #[test]
    fn test_parse_eof_packet_truncated() {
        assert!(PacketReader::new(&[0xFE, 0x00]).parse_eof_packet().is_none());
    }
}