Skip to main content

ocular_protocol/
mysql.rs

//! MySQL wire protocol parser (client command packets and server responses).
//!
//! MySQL packet format: [3-byte little-endian payload length][1-byte sequence id][payload]
//! For client command packets, the command byte is the first byte of the payload.
5
/// A client command packet decoded by `parse_mysql_request`.
#[derive(Clone, Debug)]
pub struct MysqlPacket {
    /// Which command the client sent.
    pub command: MysqlCommand,
    /// Text rendering of the command's data (SQL text, `stmt_id=N`, database name, ...).
    pub payload: String,
}

/// Client command kinds, keyed by the first payload byte of a command packet.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MysqlCommand {
    /// `COM_QUERY` (0x03): plain-text SQL.
    Query,
    /// `COM_STMT_PREPARE` (0x16): SQL text to prepare.
    StmtPrepare,
    /// `COM_STMT_EXECUTE` (0x17).
    StmtExecute,
    /// `COM_STMT_CLOSE` (0x19).
    StmtClose,
    /// `COM_PING` (0x0e).
    Ping,
    /// `COM_QUIT` (0x01).
    Quit,
    /// `COM_INIT_DB` (0x02): switch default database.
    InitDb,
    /// `COM_FIELD_LIST` (0x04).
    FieldList,
    /// Any other command byte, kept verbatim.
    Other(u8),
}
25
/// Simplified view of a MySQL server response.
#[derive(Clone, Debug)]
pub enum MysqlResponse {
    /// OK packet.
    Ok {
        /// Affected-row count reported by the server.
        affected_rows: u64,
        /// Preformatted one-line description.
        message: String,
    },
    /// ERR packet.
    Error {
        /// Server error code.
        code: u16,
        /// Preformatted one-line description (code and server message).
        message: String,
    },
    /// Result set.
    ResultSet {
        /// Column names, in order.
        columns: Vec<String>,
        /// Row values as text, one inner vector per row.
        rows: Vec<Vec<String>>,
        /// Number of rows; `parse_mysql_response` sets this to `rows.len()`.
        total_rows: usize,
    },
    /// Any response not modelled by the variants above.
    Other,
}
34
35/// Parse a MySQL client command packet. Returns summary string if parseable.
36/// Only returns Some for actual command packets (seq=0, known command byte).
37pub fn parse_mysql_request(buf: &[u8]) -> Option<MysqlPacket> {
38    // Need at least 4-byte header + 1-byte command
39    if buf.len() < 5 {
40        return None;
41    }
42    let payload_len = (buf[0] as usize) | (buf[1] as usize) << 8 | (buf[2] as usize) << 16;
43    let seq = buf[3];
44    // Command packets always start a new sequence (seq=0)
45    if seq != 0 {
46        return None;
47    }
48    if buf.len() < 4 + payload_len || payload_len == 0 {
49        return None;
50    }
51    let cmd_byte = buf[4];
52    let data = &buf[5..4 + payload_len];
53
54    let (command, payload) = match cmd_byte {
55        0x03 => (MysqlCommand::Query, String::from_utf8_lossy(data).replace(|c: char| c.is_control(), "")),
56        0x16 => (MysqlCommand::StmtPrepare, String::from_utf8_lossy(data).replace(|c: char| c.is_control(), "")),
57        0x17 => {
58            let stmt_id = if data.len() >= 4 {
59                u32::from_le_bytes([data[0], data[1], data[2], data[3]])
60            } else { 0 };
61            (MysqlCommand::StmtExecute, format!("stmt_id={}", stmt_id))
62        }
63        0x19 => {
64            let stmt_id = if data.len() >= 4 {
65                u32::from_le_bytes([data[0], data[1], data[2], data[3]])
66            } else { 0 };
67            (MysqlCommand::StmtClose, format!("stmt_id={}", stmt_id))
68        }
69        0x0e => {
70            // COM_PING: payload should be exactly 1 byte (just the command)
71            if payload_len != 1 { return None; }
72            (MysqlCommand::Ping, "PING".to_string())
73        }
74        0x01 => {
75            // COM_QUIT: payload should be exactly 1 byte
76            if payload_len != 1 { return None; }
77            (MysqlCommand::Quit, "QUIT".to_string())
78        }
79        0x02 => (MysqlCommand::InitDb, String::from_utf8_lossy(data).to_string()),
80        // 0x04 COM_FIELD_LIST: auto-completion noise from mysql CLI, skip
81        0x04 => return None,
82        // Unknown command bytes during handshake — skip them
83        _ => return None,
84    };
85
86    Some(MysqlPacket { command, payload })
87}
88
89/// Parse a MySQL server response packet (first packet of response).
90/// Only parses responses to commands (seq >= 1).
91pub fn parse_mysql_response(buf: &[u8]) -> Option<MysqlResponse> {
92    if buf.len() < 5 {
93        return None;
94    }
95    let payload_len = (buf[0] as usize) | (buf[1] as usize) << 8 | (buf[2] as usize) << 16;
96    let seq = buf[3];
97    // Response to a command has seq >= 1 (server increments from client's seq=0)
98    if seq == 0 {
99        return None;
100    }
101    if buf.len() < 4 + payload_len || payload_len == 0 {
102        return None;
103    }
104    let marker = buf[4];
105    match marker {
106        0x00 => {
107            // OK packet
108            let affected = read_lenenc(&buf[5..]).unwrap_or(0);
109            Some(MysqlResponse::Ok {
110                affected_rows: affected,
111                message: format!("OK ({} rows affected)", affected),
112            })
113        }
114        0xff => {
115            // ERR packet
116            let code = if buf.len() >= 7 {
117                u16::from_le_bytes([buf[5], buf[6]])
118            } else { 0 };
119            // Skip sql_state marker (#) and 5-byte state
120            let msg_start = if buf.len() > 13 && buf[7] == b'#' { 13 } else { 7 };
121            let msg = String::from_utf8_lossy(&buf[msg_start..4 + payload_len]).to_string();
122            Some(MysqlResponse::Error { code, message: format!("ERR {} {}", code, msg) })
123        }
124        _ => {
125            // Result set: first byte is column count
126            let col_count = marker as usize;
127            let (columns, rows) = parse_resultset_packets(buf, col_count);
128            let total_rows = rows.len();
129            Some(MysqlResponse::ResultSet { columns, rows, total_rows })
130        }
131    }
132}
133
134impl MysqlPacket {
135    pub fn to_summary(&self) -> String {
136        match self.command {
137            MysqlCommand::Query => {
138                // Show SQL directly, truncated
139                let truncated: String = self.payload.chars().take(120).collect();
140                if truncated.len() < self.payload.len() {
141                    format!("{}...", truncated)
142                } else {
143                    truncated
144                }
145            }
146            _ => {
147                let cmd = match self.command {
148                    MysqlCommand::StmtPrepare => "PREPARE",
149                    MysqlCommand::StmtExecute => "EXECUTE",
150                    MysqlCommand::StmtClose => "STMT_CLOSE",
151                    MysqlCommand::Ping => "PING",
152                    MysqlCommand::Quit => "QUIT",
153                    MysqlCommand::InitDb => "USE",
154                    MysqlCommand::FieldList => "FIELD_LIST",
155                    MysqlCommand::Other(c) => return format!("CMD(0x{:02x})", c),
156                    MysqlCommand::Query => "QUERY",
157                };
158                if self.payload.is_empty() || self.payload == cmd {
159                    cmd.to_string()
160                } else {
161                    format!("{} {}", cmd, self.payload)
162                }
163            }
164        }
165    }
166}
167
168impl MysqlResponse {
169    pub fn to_summary(&self) -> String {
170        match self {
171            MysqlResponse::Ok { message, .. } => message.clone(),
172            MysqlResponse::Error { message, .. } => message.clone(),
173            MysqlResponse::ResultSet { total_rows, columns, .. } => {
174                format!("ResultSet ({} rows, {} cols: {})", total_rows, columns.len(),
175                    columns.iter().take(5).cloned().collect::<Vec<_>>().join(", "))
176            }
177            MysqlResponse::Other => "...".to_string(),
178        }
179    }
180
181    /// Formatted display for detail panel
182    pub fn to_display(&self) -> String {
183        match self {
184            MysqlResponse::Ok { message, .. } => message.clone(),
185            MysqlResponse::Error { message, .. } => message.clone(),
186            MysqlResponse::ResultSet { columns, rows, total_rows } => {
187                let mut out = format!("ResultSet: {} rows\n", total_rows);
188                if !columns.is_empty() {
189                    out.push_str(&format!("Columns: {}\n", columns.join(" | ")));
190                    out.push_str(&"-".repeat(60));
191                    out.push('\n');
192                }
193                for row in rows.iter().take(20) {
194                    out.push_str(&row.join(" | "));
195                    out.push('\n');
196                }
197                if *total_rows > 20 {
198                    out.push_str(&format!("... ({} more rows)\n", total_rows - 20));
199                }
200                out
201            }
202            MysqlResponse::Other => "...".to_string(),
203        }
204    }
205}
206
/// Read a MySQL length-encoded integer from the start of `buf`.
///
/// All four encodings are supported: a single byte below `0xfb`, or a `0xfc` / `0xfd` /
/// `0xfe` prefix followed by a 2-, 3- or 8-byte little-endian value. Returns `None` for an
/// empty buffer, for prefixes that do not encode an integer (`0xfb` NULL marker, `0xff`),
/// and when `buf` is too short for the announced width, rather than a misleading `Some(0)`.
fn read_lenenc(buf: &[u8]) -> Option<u64> {
    let (&prefix, rest) = buf.split_first()?;
    let width = match prefix {
        0x00..=0xfa => return Some(u64::from(prefix)),
        0xfc => 2,
        0xfd => 3,
        0xfe => 8,
        _ => return None,
    };
    let bytes = rest.get(..width)?;
    // Zero-extend the little-endian bytes into a full u64.
    let mut le = [0u8; 8];
    le[..width].copy_from_slice(bytes);
    Some(u64::from_le_bytes(le))
}
218
/// Read a MySQL length-encoded integer and report how many bytes it occupied.
///
/// Returns `(bytes_consumed, value)`. A lead byte below `0xfb` is the value itself (1 byte);
/// `0xfc`, `0xfd` and `0xfe` announce a 2-, 3- or 8-byte little-endian value. Returns `None`
/// for an empty buffer, for lead bytes `0xfb` / `0xff`, or when the value is truncated.
fn read_lenenc_with_size(buf: &[u8]) -> Option<(usize, u64)> {
    let (&lead, tail) = buf.split_first()?;
    let width = match lead {
        0x00..=0xfa => return Some((1, u64::from(lead))),
        0xfc => 2,
        0xfd => 3,
        0xfe => 8,
        _ => return None,
    };
    let value_bytes = tail.get(..width)?;
    // Little-endian: fold from the most significant (last) byte down.
    let value = value_bytes.iter().rev().fold(0u64, |acc, &b| (acc << 8) | u64::from(b));
    Some((width + 1, value))
}
230
231/// Read a length-encoded string from buffer, returns (bytes_consumed, string)
232fn read_lenenc_str(buf: &[u8]) -> Option<(usize, String)> {
233    if buf.is_empty() { return None; }
234    if buf[0] == 0xfb {
235        return Some((1, "NULL".to_string()));
236    }
237    let (hdr_size, len) = read_lenenc_with_size(buf)?;
238    let len = len as usize;
239    if buf.len() < hdr_size + len { return None; }
240    let s = String::from_utf8_lossy(&buf[hdr_size..hdr_size + len]).to_string();
241    Some((hdr_size + len, s))
242}
243
/// Return the offset just past the packet whose 4-byte header starts at `pos`.
///
/// Returns `None` if the header or the full payload it announces is not inside `buf`.
fn skip_packet(buf: &[u8], pos: usize) -> Option<usize> {
    let header = buf.get(pos..pos + 4)?;
    let body_len =
        usize::from(header[0]) | usize::from(header[1]) << 8 | usize::from(header[2]) << 16;
    let next = pos + 4 + body_len;
    (next <= buf.len()).then_some(next)
}
251
252/// Parse a ResultSet from the full TCP buffer.
253/// Extracts column names and row values.
254fn parse_resultset_packets(buf: &[u8], col_count: usize) -> (Vec<String>, Vec<Vec<String>>) {
255    let mut columns = Vec::new();
256    let mut rows = Vec::new();
257
258    // Skip the first packet (column count packet, already parsed)
259    let Some(mut pos) = skip_packet(buf, 0) else { return (columns, rows) };
260
261    // Read column definition packets
262    for _ in 0..col_count {
263        if pos + 4 >= buf.len() { break; }
264        let pkt_len = (buf[pos] as usize) | (buf[pos+1] as usize) << 8 | (buf[pos+2] as usize) << 16;
265        let payload_start = pos + 4;
266        let payload_end = payload_start + pkt_len;
267        if payload_end > buf.len() { break; }
268        let payload = &buf[payload_start..payload_end];
269        // Column def: catalog(lenenc_str), schema, table, org_table, name, ...
270        // We want the 5th lenenc_str (name)
271        let mut p = 0;
272        for i in 0..5 {
273            if let Some((consumed, s)) = read_lenenc_str(&payload[p..]) {
274                if i == 4 { columns.push(s); }
275                p += consumed;
276            } else { break; }
277        }
278        pos = payload_end;
279    }
280
281    // Skip EOF packet (if present, marker 0xfe)
282    if pos + 4 < buf.len() {
283        let pkt_len = (buf[pos] as usize) | (buf[pos+1] as usize) << 8 | (buf[pos+2] as usize) << 16;
284        let marker = if pos + 4 < buf.len() { buf[pos + 4] } else { 0 };
285        if marker == 0xfe && pkt_len < 9 {
286            pos = pos + 4 + pkt_len;
287        }
288    }
289
290    // Read row packets (text protocol: each field is a lenenc_str)
291    let max_rows = 10000; // parse all, truncate at display time
292    loop {
293        if pos + 4 >= buf.len() { break; }
294        let pkt_len = (buf[pos] as usize) | (buf[pos+1] as usize) << 8 | (buf[pos+2] as usize) << 16;
295        let payload_start = pos + 4;
296        let payload_end = payload_start + pkt_len;
297        if payload_end > buf.len() { break; }
298        let marker = buf[payload_start];
299        // EOF or OK packet signals end
300        if (marker == 0xfe && pkt_len < 9) || marker == 0x00 { break; }
301        // ERR packet
302        if marker == 0xff { break; }
303
304        if rows.len() < max_rows {
305            let payload = &buf[payload_start..payload_end];
306            let mut row = Vec::new();
307            let mut p = 0;
308            for _ in 0..col_count {
309                if let Some((consumed, s)) = read_lenenc_str(&payload[p..]) {
310                    row.push(s);
311                    p += consumed;
312                } else { break; }
313            }
314            rows.push(row);
315        }
316        pos = payload_end;
317        if rows.len() >= max_rows { break; }
318    }
319
320    (columns, rows)
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_query() {
        // COM_QUERY "SELECT 1"
        let sql = b"SELECT 1";
        let mut pkt = vec![
            (sql.len() + 1) as u8, 0, 0, // 3-byte length
            0,                             // sequence
            0x03,                          // COM_QUERY
        ];
        pkt.extend_from_slice(sql);
        let result = parse_mysql_request(&pkt).unwrap();
        assert_eq!(result.command, MysqlCommand::Query);
        assert_eq!(result.to_summary(), "SELECT 1");
    }

    #[test]
    fn test_request_requires_sequence_zero() {
        // Same COM_QUERY but with seq = 1: not a command packet.
        let pkt = [9, 0, 0, 1, 0x03, b'S', b'E', b'L', b'E', b'C', b'T', b' ', b'1'];
        assert!(parse_mysql_request(&pkt).is_none());
    }

    #[test]
    fn test_parse_ping_and_stmt_execute() {
        let ping = parse_mysql_request(&[1, 0, 0, 0, 0x0e]).unwrap();
        assert_eq!(ping.command, MysqlCommand::Ping);
        assert_eq!(ping.to_summary(), "PING");

        // COM_STMT_EXECUTE: stmt_id = 7, flags = 0, iteration count = 1.
        let exec = parse_mysql_request(&[10, 0, 0, 0, 0x17, 7, 0, 0, 0, 0, 1, 0, 0, 0]).unwrap();
        assert_eq!(exec.command, MysqlCommand::StmtExecute);
        assert_eq!(exec.to_summary(), "EXECUTE stmt_id=7");
    }

    #[test]
    fn test_parse_ok_response() {
        // OK packet: 0 affected rows, 0 last-insert-id, status 0x0002, 0 warnings.
        let pkt = vec![7, 0, 0, 1, 0x00, 0, 0, 0x02, 0, 0, 0];
        match parse_mysql_response(&pkt) {
            Some(MysqlResponse::Ok { affected_rows, message }) => {
                assert_eq!(affected_rows, 0);
                assert_eq!(message, "OK (0 rows affected)");
            }
            other => panic!("expected OK packet, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_err_response() {
        // ERR 1045 with SQLSTATE 28000.
        let mut pkt = vec![0, 0, 0, 2, 0xff, 0x15, 0x04, b'#'];
        pkt.extend_from_slice(b"28000");
        pkt.extend_from_slice(b"Access denied");
        pkt[0] = (pkt.len() - 4) as u8;
        match parse_mysql_response(&pkt) {
            Some(MysqlResponse::Error { code, message }) => {
                assert_eq!(code, 1045);
                assert_eq!(message, "ERR 1045 Access denied");
            }
            other => panic!("expected ERR packet, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_text_resultset() {
        // One column named "id" and one row holding "7".
        let fields: [&[u8]; 6] = [b"def", b"db", b"t", b"t", b"id", b"id"];
        let mut coldef = Vec::new();
        for field in fields {
            coldef.push(field.len() as u8);
            coldef.extend_from_slice(field);
        }
        // Fixed-length tail: charset, column length, type, flags, decimals, filler.
        coldef.extend_from_slice(&[0x0c, 0x21, 0x00, 0x0b, 0, 0, 0, 0x03, 0, 0, 0, 0, 0]);

        let mut buf = vec![1, 0, 0, 1, 0x01];
        buf.extend_from_slice(&[coldef.len() as u8, 0, 0, 2]);
        buf.extend_from_slice(&coldef);
        buf.extend_from_slice(&[5, 0, 0, 3, 0xfe, 0, 0, 0x02, 0]);
        buf.extend_from_slice(&[2, 0, 0, 4, 0x01, b'7']);
        buf.extend_from_slice(&[5, 0, 0, 5, 0xfe, 0, 0, 0x02, 0]);

        assert!(mysql_response_complete(&buf));
        match parse_mysql_response(&buf) {
            Some(MysqlResponse::ResultSet { columns, rows, total_rows }) => {
                assert_eq!(columns, vec!["id".to_string()]);
                assert_eq!(rows, vec![vec!["7".to_string()]]);
                assert_eq!(total_rows, 1);
            }
            other => panic!("expected result set, got {:?}", other),
        }
    }
}
350
/// Check whether a buffered MySQL server response is complete.
///
/// `buf` must start at the first packet of the response. If that packet is an OK (`0x00`)
/// or ERR (`0xff`) packet, the response is that packet alone and is complete once its whole
/// payload has arrived. Otherwise (e.g. a result set) the fully received packets are scanned,
/// and the response counts as complete when the last non-empty one is an EOF packet (`0xfe`
/// with a payload under 9 bytes), or a short OK packet (`0x00`, payload under 16 bytes) that
/// ends exactly at the end of `buf`.
///
/// NOTE(review): an EOF packet that follows the column definitions also passes the EOF test,
/// so a buffer that stops right after it is reported complete — confirm callers tolerate this.
pub fn mysql_response_complete(buf: &[u8]) -> bool {
    if buf.len() < 5 {
        return false;
    }
    if matches!(buf[4], 0x00 | 0xff) {
        // Single-packet response: complete only once the announced payload is buffered.
        let payload_len =
            usize::from(buf[0]) | usize::from(buf[1]) << 8 | usize::from(buf[2]) << 16;
        return buf.len() >= 4 + payload_len;
    }
    let mut pos = 0;
    let mut last_marker = 0u8;
    let mut last_pkt_len = 0usize;
    while pos + 4 <= buf.len() {
        let pkt_len =
            usize::from(buf[pos]) | usize::from(buf[pos + 1]) << 8 | usize::from(buf[pos + 2]) << 16;
        let end = pos + 4 + pkt_len;
        if end > buf.len() {
            // Trailing packet still arriving.
            break;
        }
        if pkt_len > 0 {
            last_marker = buf[pos + 4];
            last_pkt_len = pkt_len;
        }
        pos = end;
    }
    (last_marker == 0xfe && last_pkt_len < 9)
        || (last_marker == 0x00 && last_pkt_len < 16 && pos == buf.len())
}