nirv_engine/protocol/mysql_protocol.rs

1use async_trait::async_trait;
2use tokio::io::{AsyncReadExt, AsyncWriteExt};
3use tokio::net::TcpStream;
4
5use crate::protocol::{ProtocolAdapter, ProtocolType, Connection, Credentials, ProtocolQuery, ProtocolResponse};
6use crate::utils::{NirvResult, ProtocolError, QueryResult, ColumnMetadata, Row, Value, DataType};
7
/// Protocol version byte that opens the initial handshake packet.
const MYSQL_PROTOCOL_VERSION: u8 = 10;

// Capability flags exchanged during the handshake. Each constant is a single
// bit of the 32-bit capability mask; writing them as shifts makes the bit
// index of every flag explicit.
const CLIENT_LONG_PASSWORD: u32 = 1 << 0;
const CLIENT_FOUND_ROWS: u32 = 1 << 1;
const CLIENT_LONG_FLAG: u32 = 1 << 2;
const CLIENT_CONNECT_WITH_DB: u32 = 1 << 3;
const CLIENT_NO_SCHEMA: u32 = 1 << 4;
const CLIENT_COMPRESS: u32 = 1 << 5;
const CLIENT_ODBC: u32 = 1 << 6;
const CLIENT_LOCAL_FILES: u32 = 1 << 7;
const CLIENT_IGNORE_SPACE: u32 = 1 << 8;
const CLIENT_PROTOCOL_41: u32 = 1 << 9;
const CLIENT_INTERACTIVE: u32 = 1 << 10;
const CLIENT_SSL: u32 = 1 << 11;
const CLIENT_IGNORE_SIGPIPE: u32 = 1 << 12;
const CLIENT_TRANSACTIONS: u32 = 1 << 13;
const CLIENT_RESERVED: u32 = 1 << 14;
const CLIENT_SECURE_CONNECTION: u32 = 1 << 15;
const CLIENT_MULTI_STATEMENTS: u32 = 1 << 16;
const CLIENT_MULTI_RESULTS: u32 = 1 << 17;
30
/// MySQL command types: the first payload byte of a command-phase packet.
///
/// `#[repr(u8)]` guarantees each discriminant is stored as exactly the byte
/// used on the wire, so `command as u8` yields the protocol value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum MySQLCommand {
    Sleep = 0x00,
    Quit = 0x01,
    InitDB = 0x02,
    Query = 0x03,
    FieldList = 0x04,
    CreateDB = 0x05,
    DropDB = 0x06,
    Refresh = 0x07,
    Shutdown = 0x08,
    Statistics = 0x09,
    ProcessInfo = 0x0a,
    Connect = 0x0b,
    ProcessKill = 0x0c,
    Debug = 0x0d,
    Ping = 0x0e,
    Time = 0x0f,
    DelayedInsert = 0x10,
    ChangeUser = 0x11,
    BinlogDump = 0x12,
    TableDump = 0x13,
    ConnectOut = 0x14,
    RegisterSlave = 0x15,
    StmtPrepare = 0x16,
    StmtExecute = 0x17,
    StmtSendLongData = 0x18,
    StmtClose = 0x19,
    StmtReset = 0x1a,
    SetOption = 0x1b,
    StmtFetch = 0x1c,
}
64
/// MySQL column types, as written in the type byte of a column definition.
///
/// `#[repr(u8)]` guarantees each discriminant is stored as exactly the byte
/// used on the wire, so `field_type as u8` yields the protocol value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum MySQLFieldType {
    Decimal = 0x00,
    Tiny = 0x01,
    Short = 0x02,
    Long = 0x03,
    Float = 0x04,
    Double = 0x05,
    Null = 0x06,
    Timestamp = 0x07,
    LongLong = 0x08,
    Int24 = 0x09,
    Date = 0x0a,
    Time = 0x0b,
    DateTime = 0x0c,
    Year = 0x0d,
    NewDate = 0x0e,
    VarChar = 0x0f,
    Bit = 0x10,
    NewDecimal = 0xf6,
    Enum = 0xf7,
    Set = 0xf8,
    TinyBlob = 0xf9,
    MediumBlob = 0xfa,
    LongBlob = 0xfb,
    Blob = 0xfc,
    VarString = 0xfd,
    String = 0xfe,
    Geometry = 0xff,
}
96
97/// MySQL protocol adapter implementation
98#[derive(Debug)]
99pub struct MySQLProtocolAdapter {
100    server_version: String,
101    connection_id: u32,
102    capabilities: u32,
103}
104
105impl MySQLProtocolAdapter {
106    /// Create a new MySQL protocol adapter
107    pub fn new() -> Self {
108        Self {
109            server_version: "8.0.0-NIRV".to_string(),
110            connection_id: 1,
111            capabilities: CLIENT_LONG_PASSWORD
112                | CLIENT_FOUND_ROWS
113                | CLIENT_LONG_FLAG
114                | CLIENT_CONNECT_WITH_DB
115                | CLIENT_NO_SCHEMA
116                | CLIENT_PROTOCOL_41
117                | CLIENT_TRANSACTIONS
118                | CLIENT_SECURE_CONNECTION
119                | CLIENT_MULTI_STATEMENTS
120                | CLIENT_MULTI_RESULTS,
121        }
122    }
123    
124    /// Create initial handshake packet
125    fn create_handshake_packet(&self) -> Vec<u8> {
126        let mut packet = Vec::new();
127        
128        // Protocol version
129        packet.push(MYSQL_PROTOCOL_VERSION);
130        
131        // Server version (null-terminated)
132        packet.extend_from_slice(self.server_version.as_bytes());
133        packet.push(0);
134        
135        // Connection ID (4 bytes, little-endian)
136        packet.extend_from_slice(&self.connection_id.to_le_bytes());
137        
138        // Auth plugin data part 1 (8 bytes)
139        packet.extend_from_slice(b"12345678");
140        
141        // Filler (1 byte)
142        packet.push(0);
143        
144        // Capability flags lower 2 bytes
145        packet.extend_from_slice(&(self.capabilities as u16).to_le_bytes());
146        
147        // Character set (1 byte) - UTF-8
148        packet.push(0x21);
149        
150        // Status flags (2 bytes)
151        packet.extend_from_slice(&0u16.to_le_bytes());
152        
153        // Capability flags upper 2 bytes
154        packet.extend_from_slice(&((self.capabilities >> 16) as u16).to_le_bytes());
155        
156        // Auth plugin data length (1 byte)
157        packet.push(21);
158        
159        // Reserved (10 bytes)
160        packet.extend_from_slice(&[0; 10]);
161        
162        // Auth plugin data part 2 (12 bytes + null terminator)
163        packet.extend_from_slice(b"123456789012");
164        packet.push(0);
165        
166        // Auth plugin name (null-terminated)
167        packet.extend_from_slice(b"mysql_native_password");
168        packet.push(0);
169        
170        self.wrap_packet(&packet, 0)
171    }
172    
173    /// Wrap data in MySQL packet format
174    fn wrap_packet(&self, data: &[u8], sequence_id: u8) -> Vec<u8> {
175        let mut packet = Vec::new();
176        
177        // Packet length (3 bytes, little-endian)
178        let length = data.len() as u32;
179        packet.push((length & 0xff) as u8);
180        packet.push(((length >> 8) & 0xff) as u8);
181        packet.push(((length >> 16) & 0xff) as u8);
182        
183        // Sequence ID (1 byte)
184        packet.push(sequence_id);
185        
186        // Packet data
187        packet.extend_from_slice(data);
188        
189        packet
190    }
191    
192    /// Parse handshake response from client
193    fn parse_handshake_response(&self, data: &[u8]) -> NirvResult<(String, String, String)> {
194        if data.len() < 32 {
195            return Err(ProtocolError::InvalidMessageFormat("Handshake response too short".to_string()).into());
196        }
197        
198        let mut pos = 4; // Skip packet header
199        
200        // Client capabilities (4 bytes)
201        let _client_capabilities = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
202        pos += 4;
203        
204        // Max packet size (4 bytes)
205        let _max_packet_size = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
206        pos += 4;
207        
208        // Character set (1 byte)
209        let _charset = data[pos];
210        pos += 1;
211        
212        // Reserved (23 bytes)
213        pos += 23;
214        
215        // Username (null-terminated)
216        let username_start = pos;
217        while pos < data.len() && data[pos] != 0 {
218            pos += 1;
219        }
220        let username = String::from_utf8_lossy(&data[username_start..pos]).to_string();
221        pos += 1; // Skip null terminator
222        
223        // Password length (1 byte)
224        if pos >= data.len() {
225            return Err(ProtocolError::InvalidMessageFormat("Missing password length".to_string()).into());
226        }
227        let password_len = data[pos] as usize;
228        pos += 1;
229        
230        // Password (password_len bytes)
231        let password = if password_len > 0 {
232            if pos + password_len > data.len() {
233                return Err(ProtocolError::InvalidMessageFormat("Password data truncated".to_string()).into());
234            }
235            String::from_utf8_lossy(&data[pos..pos + password_len]).to_string()
236        } else {
237            String::new()
238        };
239        pos += password_len;
240        
241        // Database (null-terminated, optional)
242        let database = if pos < data.len() {
243            let db_start = pos;
244            while pos < data.len() && data[pos] != 0 {
245                pos += 1;
246            }
247            String::from_utf8_lossy(&data[db_start..pos]).to_string()
248        } else {
249            String::new()
250        };
251        
252        Ok((username, password, database))
253    }
254    
255    /// Create OK packet
256    fn create_ok_packet(&self, affected_rows: u64, last_insert_id: u64) -> Vec<u8> {
257        let mut packet = Vec::new();
258        
259        // OK packet header
260        packet.push(0x00);
261        
262        // Affected rows (length-encoded integer)
263        self.write_length_encoded_integer(&mut packet, affected_rows);
264        
265        // Last insert ID (length-encoded integer)
266        self.write_length_encoded_integer(&mut packet, last_insert_id);
267        
268        // Status flags (2 bytes)
269        packet.extend_from_slice(&0u16.to_le_bytes());
270        
271        // Warnings (2 bytes)
272        packet.extend_from_slice(&0u16.to_le_bytes());
273        
274        self.wrap_packet(&packet, 2)
275    }
276    
277    /// Create error packet
278    fn create_error_packet(&self, error_code: u16, message: &str) -> Vec<u8> {
279        let mut packet = Vec::new();
280        
281        // Error packet header
282        packet.push(0xff);
283        
284        // Error code (2 bytes, little-endian)
285        packet.extend_from_slice(&error_code.to_le_bytes());
286        
287        // SQL state marker
288        packet.push(b'#');
289        
290        // SQL state (5 bytes)
291        packet.extend_from_slice(b"HY000");
292        
293        // Error message
294        packet.extend_from_slice(message.as_bytes());
295        
296        self.wrap_packet(&packet, 1)
297    }
298    
299    /// Create result set header
300    fn create_result_set_header(&self, column_count: usize) -> Vec<u8> {
301        let mut packet = Vec::new();
302        
303        // Column count (length-encoded integer)
304        self.write_length_encoded_integer(&mut packet, column_count as u64);
305        
306        self.wrap_packet(&packet, 1)
307    }
308    
309    /// Create column definition packet
310    fn create_column_definition(&self, column: &ColumnMetadata, sequence_id: u8) -> Vec<u8> {
311        let mut packet = Vec::new();
312        
313        // Catalog (length-encoded string)
314        self.write_length_encoded_string(&mut packet, "def");
315        
316        // Schema (length-encoded string)
317        self.write_length_encoded_string(&mut packet, "");
318        
319        // Table (length-encoded string)
320        self.write_length_encoded_string(&mut packet, "");
321        
322        // Original table (length-encoded string)
323        self.write_length_encoded_string(&mut packet, "");
324        
325        // Name (length-encoded string)
326        self.write_length_encoded_string(&mut packet, &column.name);
327        
328        // Original name (length-encoded string)
329        self.write_length_encoded_string(&mut packet, &column.name);
330        
331        // Length of fixed-length fields (1 byte)
332        packet.push(0x0c);
333        
334        // Character set (2 bytes)
335        packet.extend_from_slice(&0x21u16.to_le_bytes()); // UTF-8
336        
337        // Column length (4 bytes)
338        packet.extend_from_slice(&0u32.to_le_bytes());
339        
340        // Column type
341        let field_type = self.nirv_type_to_mysql_type(&column.data_type);
342        packet.push(field_type as u8);
343        
344        // Flags (2 bytes)
345        let flags: u16 = if column.nullable { 0 } else { 1 }; // NOT_NULL flag
346        packet.extend_from_slice(&flags.to_le_bytes());
347        
348        // Decimals (1 byte)
349        packet.push(0);
350        
351        // Reserved (2 bytes)
352        packet.extend_from_slice(&0u16.to_le_bytes());
353        
354        self.wrap_packet(&packet, sequence_id)
355    }
356    
357    /// Create EOF packet
358    fn create_eof_packet(&self, sequence_id: u8) -> Vec<u8> {
359        let mut packet = Vec::new();
360        
361        // EOF packet header
362        packet.push(0xfe);
363        
364        // Warnings (2 bytes)
365        packet.extend_from_slice(&0u16.to_le_bytes());
366        
367        // Status flags (2 bytes)
368        packet.extend_from_slice(&0u16.to_le_bytes());
369        
370        self.wrap_packet(&packet, sequence_id)
371    }
372    
373    /// Create row data packet
374    fn create_row_packet(&self, row: &Row, sequence_id: u8) -> Vec<u8> {
375        let mut packet = Vec::new();
376        
377        for value in &row.values {
378            match value {
379                Value::Null => {
380                    packet.push(0xfb); // NULL value
381                }
382                _ => {
383                    let value_str = self.value_to_string(value);
384                    self.write_length_encoded_string(&mut packet, &value_str);
385                }
386            }
387        }
388        
389        self.wrap_packet(&packet, sequence_id)
390    }
391    
392    /// Write length-encoded integer
393    fn write_length_encoded_integer(&self, buffer: &mut Vec<u8>, value: u64) {
394        if value < 251 {
395            buffer.push(value as u8);
396        } else if value < 65536 {
397            buffer.push(0xfc);
398            buffer.extend_from_slice(&(value as u16).to_le_bytes());
399        } else if value < 16777216 {
400            buffer.push(0xfd);
401            buffer.push((value & 0xff) as u8);
402            buffer.push(((value >> 8) & 0xff) as u8);
403            buffer.push(((value >> 16) & 0xff) as u8);
404        } else {
405            buffer.push(0xfe);
406            buffer.extend_from_slice(&value.to_le_bytes());
407        }
408    }
409    
410    /// Write length-encoded string
411    fn write_length_encoded_string(&self, buffer: &mut Vec<u8>, value: &str) {
412        let bytes = value.as_bytes();
413        self.write_length_encoded_integer(buffer, bytes.len() as u64);
414        buffer.extend_from_slice(bytes);
415    }
416    
417    /// Convert NIRV data type to MySQL field type
418    fn nirv_type_to_mysql_type(&self, data_type: &DataType) -> MySQLFieldType {
419        match data_type {
420            DataType::Text => MySQLFieldType::VarString,
421            DataType::Integer => MySQLFieldType::LongLong,
422            DataType::Float => MySQLFieldType::Double,
423            DataType::Boolean => MySQLFieldType::Tiny,
424            DataType::Date => MySQLFieldType::Date,
425            DataType::DateTime => MySQLFieldType::DateTime,
426            DataType::Json => MySQLFieldType::VarString,
427            DataType::Binary => MySQLFieldType::Blob,
428        }
429    }
430    
431    /// Convert NIRV Value to MySQL string representation
432    fn value_to_string(&self, value: &Value) -> String {
433        match value {
434            Value::Text(s) => s.clone(),
435            Value::Integer(i) => i.to_string(),
436            Value::Float(f) => f.to_string(),
437            Value::Boolean(b) => if *b { "1".to_string() } else { "0".to_string() },
438            Value::Date(d) => d.clone(),
439            Value::DateTime(dt) => dt.clone(),
440            Value::Json(j) => j.clone(),
441            Value::Binary(b) => {
442                // Simple hex encoding
443                let mut hex_string = String::with_capacity(b.len() * 2);
444                for byte in b {
445                    hex_string.push_str(&format!("{:02x}", byte));
446                }
447                hex_string
448            },
449            Value::Null => String::new(), // Should not be called for NULL values
450        }
451    }
452    
453    /// Parse MySQL command from packet
454    fn parse_command(&self, data: &[u8]) -> NirvResult<(MySQLCommand, Vec<u8>)> {
455        if data.len() < 5 {
456            return Err(ProtocolError::InvalidMessageFormat("Command packet too short".to_string()).into());
457        }
458        
459        // Skip packet header (4 bytes)
460        let command_byte = data[4];
461        let command_data = &data[5..];
462        
463        let command = match command_byte {
464            0x01 => MySQLCommand::Quit,
465            0x02 => MySQLCommand::InitDB,
466            0x03 => MySQLCommand::Query,
467            0x0e => MySQLCommand::Ping,
468            _ => return Err(ProtocolError::UnsupportedFeature(format!("Command {} not supported", command_byte)).into()),
469        };
470        
471        Ok((command, command_data.to_vec()))
472    }
473}
474
475impl Default for MySQLProtocolAdapter {
476    fn default() -> Self {
477        Self::new()
478    }
479}
480
481#[async_trait]
482impl ProtocolAdapter for MySQLProtocolAdapter {
483    async fn accept_connection(&self, stream: TcpStream) -> NirvResult<Connection> {
484        let mut connection = Connection::new(stream, ProtocolType::MySQL);
485        
486        // Send initial handshake packet
487        let handshake = self.create_handshake_packet();
488        connection.stream.write_all(&handshake).await
489            .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send handshake: {}", e)))?;
490        
491        Ok(connection)
492    }
493    
494    async fn authenticate(&self, conn: &mut Connection, credentials: Credentials) -> NirvResult<()> {
495        // Read handshake response
496        let mut buffer = vec![0u8; 8192];
497        let bytes_read = conn.stream.read(&mut buffer).await
498            .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to read handshake response: {}", e)))?;
499        
500        if bytes_read < 32 {
501            return Err(ProtocolError::InvalidMessageFormat("Handshake response too short".to_string()).into());
502        }
503        
504        // Parse handshake response
505        let (username, _password, database) = self.parse_handshake_response(&buffer[..bytes_read])?;
506        
507        // Validate credentials
508        if username != credentials.username {
509            let error_packet = self.create_error_packet(1045, "Access denied for user");
510            conn.stream.write_all(&error_packet).await
511                .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send error: {}", e)))?;
512            return Err(ProtocolError::AuthenticationFailed("Username mismatch".to_string()).into());
513        }
514        
515        if !database.is_empty() && database != credentials.database {
516            let error_packet = self.create_error_packet(1049, "Unknown database");
517            conn.stream.write_all(&error_packet).await
518                .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send error: {}", e)))?;
519            return Err(ProtocolError::AuthenticationFailed("Database mismatch".to_string()).into());
520        }
521        
522        // Send OK packet
523        let ok_packet = self.create_ok_packet(0, 0);
524        conn.stream.write_all(&ok_packet).await
525            .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send OK packet: {}", e)))?;
526        
527        // Update connection state
528        conn.authenticated = true;
529        conn.database = if database.is_empty() { credentials.database } else { database };
530        conn.parameters.insert("user".to_string(), username);
531        
532        Ok(())
533    }
534    
535    async fn handle_query(&self, _conn: &Connection, _query: ProtocolQuery) -> NirvResult<ProtocolResponse> {
536        // For now, create a mock response
537        // In the full implementation, this would parse the query and execute it
538        let columns = vec![
539            ColumnMetadata {
540                name: "id".to_string(),
541                data_type: DataType::Integer,
542                nullable: false,
543            },
544            ColumnMetadata {
545                name: "name".to_string(),
546                data_type: DataType::Text,
547                nullable: true,
548            },
549        ];
550        
551        let rows = vec![
552            Row::new(vec![Value::Integer(1), Value::Text("Test User".to_string())]),
553            Row::new(vec![Value::Integer(2), Value::Text("Another User".to_string())]),
554        ];
555        
556        let result = QueryResult {
557            columns,
558            rows,
559            affected_rows: Some(2),
560            execution_time: std::time::Duration::from_millis(10),
561        };
562        
563        Ok(ProtocolResponse::new(result, ProtocolType::MySQL))
564    }
565    
566    fn get_protocol_type(&self) -> ProtocolType {
567        ProtocolType::MySQL
568    }
569    
570    async fn parse_message(&self, _conn: &Connection, data: &[u8]) -> NirvResult<ProtocolQuery> {
571        let (command, command_data) = self.parse_command(data)?;
572        
573        match command {
574            MySQLCommand::Query => {
575                let query_string = String::from_utf8_lossy(&command_data).to_string();
576                Ok(ProtocolQuery::new(query_string, ProtocolType::MySQL))
577            }
578            MySQLCommand::Quit => {
579                Ok(ProtocolQuery::new("QUIT".to_string(), ProtocolType::MySQL))
580            }
581            MySQLCommand::Ping => {
582                Ok(ProtocolQuery::new("PING".to_string(), ProtocolType::MySQL))
583            }
584            MySQLCommand::InitDB => {
585                let db_name = String::from_utf8_lossy(&command_data).to_string();
586                Ok(ProtocolQuery::new(format!("USE {}", db_name), ProtocolType::MySQL))
587            }
588            _ => {
589                Err(ProtocolError::UnsupportedFeature(format!("Command {:?} not supported", command)).into())
590            }
591        }
592    }
593    
594    async fn format_response(&self, _conn: &Connection, result: QueryResult) -> NirvResult<Vec<u8>> {
595        let mut response = Vec::new();
596        
597        if result.columns.is_empty() {
598            // OK packet for non-SELECT queries
599            let ok_packet = self.create_ok_packet(result.affected_rows.unwrap_or(0), 0);
600            response.extend_from_slice(&ok_packet);
601        } else {
602            // Result set for SELECT queries
603            
604            // Result set header
605            let header = self.create_result_set_header(result.columns.len());
606            response.extend_from_slice(&header);
607            
608            // Column definitions
609            for (i, column) in result.columns.iter().enumerate() {
610                let col_def = self.create_column_definition(column, (i + 2) as u8);
611                response.extend_from_slice(&col_def);
612            }
613            
614            // EOF packet after column definitions
615            let eof1 = self.create_eof_packet((result.columns.len() + 2) as u8);
616            response.extend_from_slice(&eof1);
617            
618            // Row data
619            for (i, row) in result.rows.iter().enumerate() {
620                let row_packet = self.create_row_packet(row, (result.columns.len() + 3 + i) as u8);
621                response.extend_from_slice(&row_packet);
622            }
623            
624            // EOF packet after rows
625            let eof2 = self.create_eof_packet((result.columns.len() + 3 + result.rows.len()) as u8);
626            response.extend_from_slice(&eof2);
627        }
628        
629        Ok(response)
630    }
631    
632    async fn terminate_connection(&self, conn: &mut Connection) -> NirvResult<()> {
633        conn.stream.shutdown().await
634            .map_err(|_e| ProtocolError::ConnectionClosed)?;
635        Ok(())
636    }
637}