Skip to main content

sqlmodel_mysql/protocol/
reader.rs

//! MySQL packet reading utilities.
//!
//! This module provides utilities for reading MySQL protocol data types,
//! including length-encoded integers and strings, and for parsing OK, ERR,
//! and EOF packets.
5
6#![allow(clippy::cast_possible_truncation)]
7
8use crate::protocol::{EofPacket, ErrPacket, OkPacket, PacketHeader};
9
10/// A reader for MySQL protocol data.
11#[derive(Debug)]
12pub struct PacketReader<'a> {
13    data: &'a [u8],
14    pos: usize,
15}
16
17impl<'a> PacketReader<'a> {
18    /// Create a new reader from a byte slice.
19    pub fn new(data: &'a [u8]) -> Self {
20        Self { data, pos: 0 }
21    }
22
23    /// Get remaining bytes in the buffer.
24    pub fn remaining(&self) -> usize {
25        self.data.len().saturating_sub(self.pos)
26    }
27
28    /// Check if we've reached the end of the data.
29    pub fn is_empty(&self) -> bool {
30        self.pos >= self.data.len()
31    }
32
33    /// Peek at the next byte without advancing.
34    pub fn peek(&self) -> Option<u8> {
35        self.data.get(self.pos).copied()
36    }
37
38    /// Read a single byte.
39    pub fn read_u8(&mut self) -> Option<u8> {
40        let byte = self.data.get(self.pos)?;
41        self.pos += 1;
42        Some(*byte)
43    }
44
45    /// Read a u16 (little-endian).
46    pub fn read_u16_le(&mut self) -> Option<u16> {
47        if self.remaining() < 2 {
48            return None;
49        }
50        let value = u16::from_le_bytes([self.data[self.pos], self.data[self.pos + 1]]);
51        self.pos += 2;
52        Some(value)
53    }
54
55    /// Read a u24 (little-endian, 3 bytes).
56    pub fn read_u24_le(&mut self) -> Option<u32> {
57        if self.remaining() < 3 {
58            return None;
59        }
60        let value = u32::from(self.data[self.pos])
61            | (u32::from(self.data[self.pos + 1]) << 8)
62            | (u32::from(self.data[self.pos + 2]) << 16);
63        self.pos += 3;
64        Some(value)
65    }
66
67    /// Read a u32 (little-endian).
68    pub fn read_u32_le(&mut self) -> Option<u32> {
69        if self.remaining() < 4 {
70            return None;
71        }
72        let value = u32::from_le_bytes([
73            self.data[self.pos],
74            self.data[self.pos + 1],
75            self.data[self.pos + 2],
76            self.data[self.pos + 3],
77        ]);
78        self.pos += 4;
79        Some(value)
80    }
81
82    /// Read a u64 (little-endian).
83    pub fn read_u64_le(&mut self) -> Option<u64> {
84        if self.remaining() < 8 {
85            return None;
86        }
87        let value = u64::from_le_bytes([
88            self.data[self.pos],
89            self.data[self.pos + 1],
90            self.data[self.pos + 2],
91            self.data[self.pos + 3],
92            self.data[self.pos + 4],
93            self.data[self.pos + 5],
94            self.data[self.pos + 6],
95            self.data[self.pos + 7],
96        ]);
97        self.pos += 8;
98        Some(value)
99    }
100
101    /// Read a length-encoded integer.
102    ///
103    /// MySQL uses a variable-length integer encoding:
104    /// - 0x00-0xFA: 1-byte value
105    /// - 0xFC: 2-byte value follows
106    /// - 0xFD: 3-byte value follows
107    /// - 0xFE: 8-byte value follows
108    /// - 0xFB: NULL (special case for length-encoded strings)
109    pub fn read_lenenc_int(&mut self) -> Option<u64> {
110        let first = self.read_u8()?;
111        match first {
112            0x00..=0xFA => Some(u64::from(first)),
113            0xFC => self.read_u16_le().map(u64::from),
114            0xFD => self.read_u24_le().map(u64::from),
115            0xFE => self.read_u64_le(),
116            0xFB => None, // NULL marker
117            0xFF => None, // Reserved/error
118        }
119    }
120
121    /// Read a length-encoded string.
122    pub fn read_lenenc_string(&mut self) -> Option<String> {
123        let len = self.read_lenenc_int()? as usize;
124        self.read_string(len)
125    }
126
127    /// Read a length-encoded byte slice.
128    pub fn read_lenenc_bytes(&mut self) -> Option<Vec<u8>> {
129        let len = self.read_lenenc_int()? as usize;
130        self.read_bytes(len).map(|b| b.to_vec())
131    }
132
133    /// Read a null-terminated string.
134    pub fn read_null_string(&mut self) -> Option<String> {
135        let start = self.pos;
136        while self.pos < self.data.len() && self.data[self.pos] != 0 {
137            self.pos += 1;
138        }
139        let s = String::from_utf8_lossy(&self.data[start..self.pos]).into_owned();
140        // Skip the null terminator
141        if self.pos < self.data.len() {
142            self.pos += 1;
143        }
144        Some(s)
145    }
146
147    /// Read a fixed-length string.
148    pub fn read_string(&mut self, len: usize) -> Option<String> {
149        let bytes = self.read_bytes(len)?;
150        Some(String::from_utf8_lossy(bytes).into_owned())
151    }
152
153    /// Read remaining data as a string.
154    pub fn read_rest_string(&mut self) -> String {
155        let s = String::from_utf8_lossy(&self.data[self.pos..]).into_owned();
156        self.pos = self.data.len();
157        s
158    }
159
160    /// Read a fixed number of bytes.
161    pub fn read_bytes(&mut self, len: usize) -> Option<&'a [u8]> {
162        if self.remaining() < len {
163            return None;
164        }
165        let bytes = &self.data[self.pos..self.pos + len];
166        self.pos += len;
167        Some(bytes)
168    }
169
170    /// Read remaining bytes.
171    pub fn read_rest(&mut self) -> &'a [u8] {
172        let rest = &self.data[self.pos..];
173        self.pos = self.data.len();
174        rest
175    }
176
177    /// Skip a number of bytes.
178    pub fn skip(&mut self, n: usize) -> bool {
179        if self.remaining() >= n {
180            self.pos += n;
181            true
182        } else {
183            false
184        }
185    }
186
187    /// Read a packet header from raw bytes.
188    pub fn read_packet_header(&mut self) -> Option<PacketHeader> {
189        if self.remaining() < 4 {
190            return None;
191        }
192        let mut header_bytes = [0u8; 4];
193        header_bytes.copy_from_slice(&self.data[self.pos..self.pos + 4]);
194        self.pos += 4;
195        Some(PacketHeader::from_bytes(&header_bytes))
196    }
197
198    /// Parse an OK packet from the current position.
199    ///
200    /// OK packet format (protocol 4.1+):
201    /// - 0x00 header (already consumed)
202    /// - affected_rows: lenenc int
203    /// - last_insert_id: lenenc int
204    /// - status_flags: 2 bytes
205    /// - warnings: 2 bytes
206    /// - info: rest of packet (optional)
207    pub fn parse_ok_packet(&mut self) -> Option<OkPacket> {
208        // Skip the 0x00 marker if present
209        if self.peek() == Some(0x00) {
210            self.skip(1);
211        }
212
213        let affected_rows = self.read_lenenc_int()?;
214        let last_insert_id = self.read_lenenc_int()?;
215        let status_flags = self.read_u16_le()?;
216        let warnings = self.read_u16_le()?;
217        let info = if self.remaining() > 0 {
218            self.read_rest_string()
219        } else {
220            String::new()
221        };
222
223        Some(OkPacket {
224            affected_rows,
225            last_insert_id,
226            status_flags,
227            warnings,
228            info,
229        })
230    }
231
232    /// Parse an Error packet from the current position.
233    ///
234    /// ERR packet format (protocol 4.1+):
235    /// - 0xFF header (already consumed)
236    /// - error_code: 2 bytes
237    /// - '#' marker
238    /// - sql_state: 5 bytes
239    /// - error_message: rest of packet
240    pub fn parse_err_packet(&mut self) -> Option<ErrPacket> {
241        // Skip the 0xFF marker if present
242        if self.peek() == Some(0xFF) {
243            self.skip(1);
244        }
245
246        let error_code = self.read_u16_le()?;
247
248        // Check for '#' marker (SQL state follows)
249        let sql_state = if self.peek() == Some(b'#') {
250            self.skip(1);
251            self.read_string(5)?
252        } else {
253            String::new()
254        };
255
256        let error_message = self.read_rest_string();
257
258        Some(ErrPacket {
259            error_code,
260            sql_state,
261            error_message,
262        })
263    }
264
265    /// Parse an EOF packet from the current position.
266    ///
267    /// EOF packet format:
268    /// - 0xFE header (already consumed)
269    /// - warnings: 2 bytes
270    /// - status_flags: 2 bytes
271    pub fn parse_eof_packet(&mut self) -> Option<EofPacket> {
272        // Skip the 0xFE marker if present
273        if self.peek() == Some(0xFE) {
274            self.skip(1);
275        }
276
277        let warnings = self.read_u16_le()?;
278        let status_flags = self.read_u16_le()?;
279
280        Some(EofPacket {
281            warnings,
282            status_flags,
283        })
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_read_u8() {
        let mut reader = PacketReader::new(&[0x42, 0x43]);
        assert_eq!(reader.read_u8(), Some(0x42));
        assert_eq!(reader.read_u8(), Some(0x43));
        assert_eq!(reader.read_u8(), None);
    }

    #[test]
    fn test_read_u16_le() {
        let mut reader = PacketReader::new(&[0x34, 0x12]);
        assert_eq!(reader.read_u16_le(), Some(0x1234));
    }

    #[test]
    fn test_read_u24_le() {
        let mut reader = PacketReader::new(&[0x56, 0x34, 0x12]);
        assert_eq!(reader.read_u24_le(), Some(0x0012_3456));
    }

    #[test]
    fn test_read_u32_le() {
        let mut reader = PacketReader::new(&[0x78, 0x56, 0x34, 0x12]);
        assert_eq!(reader.read_u32_le(), Some(0x1234_5678));
    }

    #[test]
    fn test_read_u64_le() {
        let mut reader = PacketReader::new(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
        assert_eq!(reader.read_u64_le(), Some(0x0807_0605_0403_0201));
    }

    #[test]
    fn test_truncated_fixed_width_reads_do_not_advance() {
        let mut reader = PacketReader::new(&[0x01, 0x02, 0x03]);
        assert_eq!(reader.read_u32_le(), None);
        assert_eq!(reader.read_u64_le(), None);
        assert_eq!(reader.remaining(), 3);
        assert_eq!(reader.read_u24_le(), Some(0x0003_0201));
        assert!(reader.is_empty());
    }

    #[test]
    fn test_read_lenenc_int() {
        // 1-byte value
        let mut reader = PacketReader::new(&[0x42]);
        assert_eq!(reader.read_lenenc_int(), Some(0x42));

        // 2-byte value
        let mut reader = PacketReader::new(&[0xFC, 0x34, 0x12]);
        assert_eq!(reader.read_lenenc_int(), Some(0x1234));

        // 3-byte value
        let mut reader = PacketReader::new(&[0xFD, 0x56, 0x34, 0x12]);
        assert_eq!(reader.read_lenenc_int(), Some(0x0012_3456));

        // 8-byte value
        let mut reader = PacketReader::new(&[0xFE, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
        assert_eq!(reader.read_lenenc_int(), Some(0x0807_0605_0403_0201));

        // Truncated 2-byte value
        let mut reader = PacketReader::new(&[0xFC, 0x01]);
        assert_eq!(reader.read_lenenc_int(), None);
    }

    #[test]
    fn test_lenenc_null_and_reserved_markers() {
        // 0xFB (NULL) is consumed, so the next field can be read afterwards.
        let mut reader = PacketReader::new(&[0xFB, 0x02, b'h', b'i']);
        assert_eq!(reader.read_lenenc_string(), None);
        assert_eq!(reader.read_lenenc_string(), Some("hi".to_string()));

        // 0xFF is not a valid prefix.
        let mut reader = PacketReader::new(&[0xFF]);
        assert_eq!(reader.read_lenenc_int(), None);
        assert!(reader.is_empty());
    }

    #[test]
    fn test_read_null_string() {
        let mut reader = PacketReader::new(b"hello\0world\0");
        assert_eq!(reader.read_null_string(), Some("hello".to_string()));
        assert_eq!(reader.read_null_string(), Some("world".to_string()));
    }

    #[test]
    fn test_read_null_string_without_terminator() {
        let mut reader = PacketReader::new(b"abc");
        assert_eq!(reader.read_null_string(), Some("abc".to_string()));
        assert!(reader.is_empty());
        assert_eq!(reader.read_null_string(), Some(String::new()));
    }

    #[test]
    fn test_read_lenenc_string() {
        // Length-prefixed string
        let mut reader = PacketReader::new(&[0x05, b'h', b'e', b'l', b'l', b'o']);
        assert_eq!(reader.read_lenenc_string(), Some("hello".to_string()));
    }

    #[test]
    fn test_read_lenenc_string_truncated() {
        let mut reader = PacketReader::new(&[0x05, b'h', b'i']);
        assert_eq!(reader.read_lenenc_string(), None);
    }

    #[test]
    fn test_read_lenenc_bytes() {
        let mut reader = PacketReader::new(&[0x03, 0x07, 0x08, 0x09]);
        assert_eq!(reader.read_lenenc_bytes(), Some(vec![0x07, 0x08, 0x09]));
        assert!(reader.is_empty());
    }

    #[test]
    fn test_skip_and_read_rest() {
        let mut reader = PacketReader::new(&[0x01, 0x02, 0x03]);
        assert!(!reader.skip(5));
        assert_eq!(reader.remaining(), 3);
        assert!(reader.skip(1));
        assert_eq!(reader.peek(), Some(0x02));
        assert_eq!(reader.read_rest(), &[0x02, 0x03][..]);
        assert!(reader.is_empty());
        assert_eq!(reader.read_rest_string(), "");
    }

    #[test]
    fn test_parse_ok_packet() {
        // OK packet: affected_rows=1, last_insert_id=42, status=2, warnings=0
        let data = [0x00, 0x01, 0x2A, 0x02, 0x00, 0x00, 0x00];
        let mut reader = PacketReader::new(&data);
        let ok = reader.parse_ok_packet().unwrap();
        assert_eq!(ok.affected_rows, 1);
        assert_eq!(ok.last_insert_id, 42);
        assert_eq!(ok.status_flags, 2);
        assert_eq!(ok.warnings, 0);
    }

    #[test]
    fn test_parse_ok_packet_with_info() {
        // Header included; affected_rows=0, last_insert_id=5, status=2, warnings=1, info="ok"
        let data = [0x00, 0x00, 0x05, 0x02, 0x00, 0x01, 0x00, b'o', b'k'];
        let mut reader = PacketReader::new(&data);
        let ok = reader.parse_ok_packet().unwrap();
        assert_eq!(ok.affected_rows, 0);
        assert_eq!(ok.last_insert_id, 5);
        assert_eq!(ok.status_flags, 2);
        assert_eq!(ok.warnings, 1);
        assert_eq!(ok.info, "ok");
    }

    #[test]
    fn test_parse_err_packet() {
        // ERR packet: error_code=1045, sql_state=28000, message="Access denied"
        let mut data = vec![0xFF, 0x15, 0x04, b'#'];
        data.extend_from_slice(b"28000");
        data.extend_from_slice(b"Access denied");
        let mut reader = PacketReader::new(&data);
        let err = reader.parse_err_packet().unwrap();
        assert_eq!(err.error_code, 1045);
        assert_eq!(err.sql_state, "28000");
        assert_eq!(err.error_message, "Access denied");
    }

    #[test]
    fn test_parse_err_packet_without_sql_state() {
        let data = [0xFF, 0x15, 0x04, b'n', b'o'];
        let mut reader = PacketReader::new(&data);
        let err = reader.parse_err_packet().unwrap();
        assert_eq!(err.error_code, 1045);
        assert_eq!(err.sql_state, "");
        assert_eq!(err.error_message, "no");
    }

    #[test]
    fn test_parse_eof_packet() {
        // EOF packet: warnings=0, status=2
        let data = [0xFE, 0x00, 0x00, 0x02, 0x00];
        let mut reader = PacketReader::new(&data);
        let eof = reader.parse_eof_packet().unwrap();
        assert_eq!(eof.warnings, 0);
        assert_eq!(eof.status_flags, 2);
    }
}