1use async_trait::async_trait;
2use tokio::io::{AsyncReadExt, AsyncWriteExt};
3use tokio::net::TcpStream;
4
5use crate::protocol::{ProtocolAdapter, ProtocolType, Connection, Credentials, ProtocolQuery, ProtocolResponse};
6use crate::utils::{NirvResult, ProtocolError, QueryResult, ColumnMetadata, Row, Value, DataType};
7
/// Protocol version byte sent first in the initial handshake (HandshakeV10).
const MYSQL_PROTOCOL_VERSION: u8 = 10;

// Capability flags (`CLIENT_*`) exchanged during the connection handshake.
// Each constant is a single bit; the server advertises a union of them.

/// Improved version of the old password authentication.
const CLIENT_LONG_PASSWORD: u32 = 1;
/// Report found rows instead of affected rows.
const CLIENT_FOUND_ROWS: u32 = 1 << 1;
/// Request all column flags.
const CLIENT_LONG_FLAG: u32 = 1 << 2;
/// A database name may be supplied in the handshake response.
const CLIENT_CONNECT_WITH_DB: u32 = 1 << 3;
/// Disallow `database.table.column` syntax.
const CLIENT_NO_SCHEMA: u32 = 1 << 4;
/// Compression protocol supported.
const CLIENT_COMPRESS: u32 = 1 << 5;
/// ODBC-specific client behaviour.
const CLIENT_ODBC: u32 = 1 << 6;
/// `LOAD DATA LOCAL` may be used.
const CLIENT_LOCAL_FILES: u32 = 1 << 7;
/// Ignore spaces before `(`.
const CLIENT_IGNORE_SPACE: u32 = 1 << 8;
/// MySQL 4.1+ protocol.
const CLIENT_PROTOCOL_41: u32 = 1 << 9;
/// Interactive client.
const CLIENT_INTERACTIVE: u32 = 1 << 10;
/// Switch to TLS after the handshake.
const CLIENT_SSL: u32 = 1 << 11;
/// Client-only flag: do not raise SIGPIPE on network failures.
const CLIENT_IGNORE_SIGPIPE: u32 = 1 << 12;
/// Client understands transactions / status flags.
const CLIENT_TRANSACTIONS: u32 = 1 << 13;
/// Reserved (legacy 4.1 protocol flag).
const CLIENT_RESERVED: u32 = 1 << 14;
/// 4.1 authentication: the auth response in the handshake is length-prefixed.
const CLIENT_SECURE_CONNECTION: u32 = 1 << 15;
/// Multiple statements per `COM_QUERY`.
const CLIENT_MULTI_STATEMENTS: u32 = 1 << 16;
/// Multiple result sets per command.
const CLIENT_MULTI_RESULTS: u32 = 1 << 17;
30
/// Command code sent by the client as the first payload byte of every
/// command-phase packet (`COM_*` in the MySQL protocol documentation).
///
/// The discriminants are the on-wire byte values, so `cmd as u8` yields the
/// command byte and [`TryFrom<u8>`](std::convert::TryFrom) maps a byte back.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum MySQLCommand {
    Sleep = 0x00,
    Quit = 0x01,
    InitDB = 0x02,
    Query = 0x03,
    FieldList = 0x04,
    CreateDB = 0x05,
    DropDB = 0x06,
    Refresh = 0x07,
    Shutdown = 0x08,
    Statistics = 0x09,
    ProcessInfo = 0x0a,
    Connect = 0x0b,
    ProcessKill = 0x0c,
    Debug = 0x0d,
    Ping = 0x0e,
    Time = 0x0f,
    DelayedInsert = 0x10,
    ChangeUser = 0x11,
    BinlogDump = 0x12,
    TableDump = 0x13,
    ConnectOut = 0x14,
    RegisterSlave = 0x15,
    StmtPrepare = 0x16,
    StmtExecute = 0x17,
    StmtSendLongData = 0x18,
    StmtClose = 0x19,
    StmtReset = 0x1a,
    SetOption = 0x1b,
    StmtFetch = 0x1c,
}

impl std::convert::TryFrom<u8> for MySQLCommand {
    /// The unrecognised command byte.
    type Error = u8;

    /// Maps an on-wire command byte to its variant.
    ///
    /// # Errors
    /// Returns the byte itself when it does not correspond to any known command.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        let command = match byte {
            0x00 => MySQLCommand::Sleep,
            0x01 => MySQLCommand::Quit,
            0x02 => MySQLCommand::InitDB,
            0x03 => MySQLCommand::Query,
            0x04 => MySQLCommand::FieldList,
            0x05 => MySQLCommand::CreateDB,
            0x06 => MySQLCommand::DropDB,
            0x07 => MySQLCommand::Refresh,
            0x08 => MySQLCommand::Shutdown,
            0x09 => MySQLCommand::Statistics,
            0x0a => MySQLCommand::ProcessInfo,
            0x0b => MySQLCommand::Connect,
            0x0c => MySQLCommand::ProcessKill,
            0x0d => MySQLCommand::Debug,
            0x0e => MySQLCommand::Ping,
            0x0f => MySQLCommand::Time,
            0x10 => MySQLCommand::DelayedInsert,
            0x11 => MySQLCommand::ChangeUser,
            0x12 => MySQLCommand::BinlogDump,
            0x13 => MySQLCommand::TableDump,
            0x14 => MySQLCommand::ConnectOut,
            0x15 => MySQLCommand::RegisterSlave,
            0x16 => MySQLCommand::StmtPrepare,
            0x17 => MySQLCommand::StmtExecute,
            0x18 => MySQLCommand::StmtSendLongData,
            0x19 => MySQLCommand::StmtClose,
            0x1a => MySQLCommand::StmtReset,
            0x1b => MySQLCommand::SetOption,
            0x1c => MySQLCommand::StmtFetch,
            other => return Err(other),
        };
        Ok(command)
    }
}
64
/// Column type code carried in a ColumnDefinition41 packet (`MYSQL_TYPE_*`).
///
/// The discriminants are the on-wire byte values; `#[repr(u8)]` guarantees
/// every variant fits the single type byte, so `ty as u8` is lossless.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum MySQLFieldType {
    Decimal = 0x00,
    Tiny = 0x01,
    Short = 0x02,
    Long = 0x03,
    Float = 0x04,
    Double = 0x05,
    Null = 0x06,
    Timestamp = 0x07,
    LongLong = 0x08,
    Int24 = 0x09,
    Date = 0x0a,
    Time = 0x0b,
    DateTime = 0x0c,
    Year = 0x0d,
    NewDate = 0x0e,
    VarChar = 0x0f,
    Bit = 0x10,
    NewDecimal = 0xf6,
    Enum = 0xf7,
    Set = 0xf8,
    TinyBlob = 0xf9,
    MediumBlob = 0xfa,
    LongBlob = 0xfb,
    Blob = 0xfc,
    VarString = 0xfd,
    String = 0xfe,
    Geometry = 0xff,
}
96
97#[derive(Debug)]
99pub struct MySQLProtocolAdapter {
100 server_version: String,
101 connection_id: u32,
102 capabilities: u32,
103}
104
105impl MySQLProtocolAdapter {
106 pub fn new() -> Self {
108 Self {
109 server_version: "8.0.0-NIRV".to_string(),
110 connection_id: 1,
111 capabilities: CLIENT_LONG_PASSWORD
112 | CLIENT_FOUND_ROWS
113 | CLIENT_LONG_FLAG
114 | CLIENT_CONNECT_WITH_DB
115 | CLIENT_NO_SCHEMA
116 | CLIENT_PROTOCOL_41
117 | CLIENT_TRANSACTIONS
118 | CLIENT_SECURE_CONNECTION
119 | CLIENT_MULTI_STATEMENTS
120 | CLIENT_MULTI_RESULTS,
121 }
122 }
123
124 fn create_handshake_packet(&self) -> Vec<u8> {
126 let mut packet = Vec::new();
127
128 packet.push(MYSQL_PROTOCOL_VERSION);
130
131 packet.extend_from_slice(self.server_version.as_bytes());
133 packet.push(0);
134
135 packet.extend_from_slice(&self.connection_id.to_le_bytes());
137
138 packet.extend_from_slice(b"12345678");
140
141 packet.push(0);
143
144 packet.extend_from_slice(&(self.capabilities as u16).to_le_bytes());
146
147 packet.push(0x21);
149
150 packet.extend_from_slice(&0u16.to_le_bytes());
152
153 packet.extend_from_slice(&((self.capabilities >> 16) as u16).to_le_bytes());
155
156 packet.push(21);
158
159 packet.extend_from_slice(&[0; 10]);
161
162 packet.extend_from_slice(b"123456789012");
164 packet.push(0);
165
166 packet.extend_from_slice(b"mysql_native_password");
168 packet.push(0);
169
170 self.wrap_packet(&packet, 0)
171 }
172
173 fn wrap_packet(&self, data: &[u8], sequence_id: u8) -> Vec<u8> {
175 let mut packet = Vec::new();
176
177 let length = data.len() as u32;
179 packet.push((length & 0xff) as u8);
180 packet.push(((length >> 8) & 0xff) as u8);
181 packet.push(((length >> 16) & 0xff) as u8);
182
183 packet.push(sequence_id);
185
186 packet.extend_from_slice(data);
188
189 packet
190 }
191
192 fn parse_handshake_response(&self, data: &[u8]) -> NirvResult<(String, String, String)> {
194 if data.len() < 32 {
195 return Err(ProtocolError::InvalidMessageFormat("Handshake response too short".to_string()).into());
196 }
197
198 let mut pos = 4; let _client_capabilities = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
202 pos += 4;
203
204 let _max_packet_size = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
206 pos += 4;
207
208 let _charset = data[pos];
210 pos += 1;
211
212 pos += 23;
214
215 let username_start = pos;
217 while pos < data.len() && data[pos] != 0 {
218 pos += 1;
219 }
220 let username = String::from_utf8_lossy(&data[username_start..pos]).to_string();
221 pos += 1; if pos >= data.len() {
225 return Err(ProtocolError::InvalidMessageFormat("Missing password length".to_string()).into());
226 }
227 let password_len = data[pos] as usize;
228 pos += 1;
229
230 let password = if password_len > 0 {
232 if pos + password_len > data.len() {
233 return Err(ProtocolError::InvalidMessageFormat("Password data truncated".to_string()).into());
234 }
235 String::from_utf8_lossy(&data[pos..pos + password_len]).to_string()
236 } else {
237 String::new()
238 };
239 pos += password_len;
240
241 let database = if pos < data.len() {
243 let db_start = pos;
244 while pos < data.len() && data[pos] != 0 {
245 pos += 1;
246 }
247 String::from_utf8_lossy(&data[db_start..pos]).to_string()
248 } else {
249 String::new()
250 };
251
252 Ok((username, password, database))
253 }
254
255 fn create_ok_packet(&self, affected_rows: u64, last_insert_id: u64) -> Vec<u8> {
257 let mut packet = Vec::new();
258
259 packet.push(0x00);
261
262 self.write_length_encoded_integer(&mut packet, affected_rows);
264
265 self.write_length_encoded_integer(&mut packet, last_insert_id);
267
268 packet.extend_from_slice(&0u16.to_le_bytes());
270
271 packet.extend_from_slice(&0u16.to_le_bytes());
273
274 self.wrap_packet(&packet, 2)
275 }
276
277 fn create_error_packet(&self, error_code: u16, message: &str) -> Vec<u8> {
279 let mut packet = Vec::new();
280
281 packet.push(0xff);
283
284 packet.extend_from_slice(&error_code.to_le_bytes());
286
287 packet.push(b'#');
289
290 packet.extend_from_slice(b"HY000");
292
293 packet.extend_from_slice(message.as_bytes());
295
296 self.wrap_packet(&packet, 1)
297 }
298
299 fn create_result_set_header(&self, column_count: usize) -> Vec<u8> {
301 let mut packet = Vec::new();
302
303 self.write_length_encoded_integer(&mut packet, column_count as u64);
305
306 self.wrap_packet(&packet, 1)
307 }
308
309 fn create_column_definition(&self, column: &ColumnMetadata, sequence_id: u8) -> Vec<u8> {
311 let mut packet = Vec::new();
312
313 self.write_length_encoded_string(&mut packet, "def");
315
316 self.write_length_encoded_string(&mut packet, "");
318
319 self.write_length_encoded_string(&mut packet, "");
321
322 self.write_length_encoded_string(&mut packet, "");
324
325 self.write_length_encoded_string(&mut packet, &column.name);
327
328 self.write_length_encoded_string(&mut packet, &column.name);
330
331 packet.push(0x0c);
333
334 packet.extend_from_slice(&0x21u16.to_le_bytes()); packet.extend_from_slice(&0u32.to_le_bytes());
339
340 let field_type = self.nirv_type_to_mysql_type(&column.data_type);
342 packet.push(field_type as u8);
343
344 let flags: u16 = if column.nullable { 0 } else { 1 }; packet.extend_from_slice(&flags.to_le_bytes());
347
348 packet.push(0);
350
351 packet.extend_from_slice(&0u16.to_le_bytes());
353
354 self.wrap_packet(&packet, sequence_id)
355 }
356
357 fn create_eof_packet(&self, sequence_id: u8) -> Vec<u8> {
359 let mut packet = Vec::new();
360
361 packet.push(0xfe);
363
364 packet.extend_from_slice(&0u16.to_le_bytes());
366
367 packet.extend_from_slice(&0u16.to_le_bytes());
369
370 self.wrap_packet(&packet, sequence_id)
371 }
372
373 fn create_row_packet(&self, row: &Row, sequence_id: u8) -> Vec<u8> {
375 let mut packet = Vec::new();
376
377 for value in &row.values {
378 match value {
379 Value::Null => {
380 packet.push(0xfb); }
382 _ => {
383 let value_str = self.value_to_string(value);
384 self.write_length_encoded_string(&mut packet, &value_str);
385 }
386 }
387 }
388
389 self.wrap_packet(&packet, sequence_id)
390 }
391
392 fn write_length_encoded_integer(&self, buffer: &mut Vec<u8>, value: u64) {
394 if value < 251 {
395 buffer.push(value as u8);
396 } else if value < 65536 {
397 buffer.push(0xfc);
398 buffer.extend_from_slice(&(value as u16).to_le_bytes());
399 } else if value < 16777216 {
400 buffer.push(0xfd);
401 buffer.push((value & 0xff) as u8);
402 buffer.push(((value >> 8) & 0xff) as u8);
403 buffer.push(((value >> 16) & 0xff) as u8);
404 } else {
405 buffer.push(0xfe);
406 buffer.extend_from_slice(&value.to_le_bytes());
407 }
408 }
409
410 fn write_length_encoded_string(&self, buffer: &mut Vec<u8>, value: &str) {
412 let bytes = value.as_bytes();
413 self.write_length_encoded_integer(buffer, bytes.len() as u64);
414 buffer.extend_from_slice(bytes);
415 }
416
417 fn nirv_type_to_mysql_type(&self, data_type: &DataType) -> MySQLFieldType {
419 match data_type {
420 DataType::Text => MySQLFieldType::VarString,
421 DataType::Integer => MySQLFieldType::LongLong,
422 DataType::Float => MySQLFieldType::Double,
423 DataType::Boolean => MySQLFieldType::Tiny,
424 DataType::Date => MySQLFieldType::Date,
425 DataType::DateTime => MySQLFieldType::DateTime,
426 DataType::Json => MySQLFieldType::VarString,
427 DataType::Binary => MySQLFieldType::Blob,
428 }
429 }
430
431 fn value_to_string(&self, value: &Value) -> String {
433 match value {
434 Value::Text(s) => s.clone(),
435 Value::Integer(i) => i.to_string(),
436 Value::Float(f) => f.to_string(),
437 Value::Boolean(b) => if *b { "1".to_string() } else { "0".to_string() },
438 Value::Date(d) => d.clone(),
439 Value::DateTime(dt) => dt.clone(),
440 Value::Json(j) => j.clone(),
441 Value::Binary(b) => {
442 let mut hex_string = String::with_capacity(b.len() * 2);
444 for byte in b {
445 hex_string.push_str(&format!("{:02x}", byte));
446 }
447 hex_string
448 },
449 Value::Null => String::new(), }
451 }
452
453 fn parse_command(&self, data: &[u8]) -> NirvResult<(MySQLCommand, Vec<u8>)> {
455 if data.len() < 5 {
456 return Err(ProtocolError::InvalidMessageFormat("Command packet too short".to_string()).into());
457 }
458
459 let command_byte = data[4];
461 let command_data = &data[5..];
462
463 let command = match command_byte {
464 0x01 => MySQLCommand::Quit,
465 0x02 => MySQLCommand::InitDB,
466 0x03 => MySQLCommand::Query,
467 0x0e => MySQLCommand::Ping,
468 _ => return Err(ProtocolError::UnsupportedFeature(format!("Command {} not supported", command_byte)).into()),
469 };
470
471 Ok((command, command_data.to_vec()))
472 }
473}
474
475impl Default for MySQLProtocolAdapter {
476 fn default() -> Self {
477 Self::new()
478 }
479}
480
481#[async_trait]
482impl ProtocolAdapter for MySQLProtocolAdapter {
483 async fn accept_connection(&self, stream: TcpStream) -> NirvResult<Connection> {
484 let mut connection = Connection::new(stream, ProtocolType::MySQL);
485
486 let handshake = self.create_handshake_packet();
488 connection.stream.write_all(&handshake).await
489 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send handshake: {}", e)))?;
490
491 Ok(connection)
492 }
493
494 async fn authenticate(&self, conn: &mut Connection, credentials: Credentials) -> NirvResult<()> {
495 let mut buffer = vec![0u8; 8192];
497 let bytes_read = conn.stream.read(&mut buffer).await
498 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to read handshake response: {}", e)))?;
499
500 if bytes_read < 32 {
501 return Err(ProtocolError::InvalidMessageFormat("Handshake response too short".to_string()).into());
502 }
503
504 let (username, _password, database) = self.parse_handshake_response(&buffer[..bytes_read])?;
506
507 if username != credentials.username {
509 let error_packet = self.create_error_packet(1045, "Access denied for user");
510 conn.stream.write_all(&error_packet).await
511 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send error: {}", e)))?;
512 return Err(ProtocolError::AuthenticationFailed("Username mismatch".to_string()).into());
513 }
514
515 if !database.is_empty() && database != credentials.database {
516 let error_packet = self.create_error_packet(1049, "Unknown database");
517 conn.stream.write_all(&error_packet).await
518 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send error: {}", e)))?;
519 return Err(ProtocolError::AuthenticationFailed("Database mismatch".to_string()).into());
520 }
521
522 let ok_packet = self.create_ok_packet(0, 0);
524 conn.stream.write_all(&ok_packet).await
525 .map_err(|e| ProtocolError::ConnectionFailed(format!("Failed to send OK packet: {}", e)))?;
526
527 conn.authenticated = true;
529 conn.database = if database.is_empty() { credentials.database } else { database };
530 conn.parameters.insert("user".to_string(), username);
531
532 Ok(())
533 }
534
535 async fn handle_query(&self, _conn: &Connection, _query: ProtocolQuery) -> NirvResult<ProtocolResponse> {
536 let columns = vec![
539 ColumnMetadata {
540 name: "id".to_string(),
541 data_type: DataType::Integer,
542 nullable: false,
543 },
544 ColumnMetadata {
545 name: "name".to_string(),
546 data_type: DataType::Text,
547 nullable: true,
548 },
549 ];
550
551 let rows = vec![
552 Row::new(vec![Value::Integer(1), Value::Text("Test User".to_string())]),
553 Row::new(vec![Value::Integer(2), Value::Text("Another User".to_string())]),
554 ];
555
556 let result = QueryResult {
557 columns,
558 rows,
559 affected_rows: Some(2),
560 execution_time: std::time::Duration::from_millis(10),
561 };
562
563 Ok(ProtocolResponse::new(result, ProtocolType::MySQL))
564 }
565
566 fn get_protocol_type(&self) -> ProtocolType {
567 ProtocolType::MySQL
568 }
569
570 async fn parse_message(&self, _conn: &Connection, data: &[u8]) -> NirvResult<ProtocolQuery> {
571 let (command, command_data) = self.parse_command(data)?;
572
573 match command {
574 MySQLCommand::Query => {
575 let query_string = String::from_utf8_lossy(&command_data).to_string();
576 Ok(ProtocolQuery::new(query_string, ProtocolType::MySQL))
577 }
578 MySQLCommand::Quit => {
579 Ok(ProtocolQuery::new("QUIT".to_string(), ProtocolType::MySQL))
580 }
581 MySQLCommand::Ping => {
582 Ok(ProtocolQuery::new("PING".to_string(), ProtocolType::MySQL))
583 }
584 MySQLCommand::InitDB => {
585 let db_name = String::from_utf8_lossy(&command_data).to_string();
586 Ok(ProtocolQuery::new(format!("USE {}", db_name), ProtocolType::MySQL))
587 }
588 _ => {
589 Err(ProtocolError::UnsupportedFeature(format!("Command {:?} not supported", command)).into())
590 }
591 }
592 }
593
594 async fn format_response(&self, _conn: &Connection, result: QueryResult) -> NirvResult<Vec<u8>> {
595 let mut response = Vec::new();
596
597 if result.columns.is_empty() {
598 let ok_packet = self.create_ok_packet(result.affected_rows.unwrap_or(0), 0);
600 response.extend_from_slice(&ok_packet);
601 } else {
602 let header = self.create_result_set_header(result.columns.len());
606 response.extend_from_slice(&header);
607
608 for (i, column) in result.columns.iter().enumerate() {
610 let col_def = self.create_column_definition(column, (i + 2) as u8);
611 response.extend_from_slice(&col_def);
612 }
613
614 let eof1 = self.create_eof_packet((result.columns.len() + 2) as u8);
616 response.extend_from_slice(&eof1);
617
618 for (i, row) in result.rows.iter().enumerate() {
620 let row_packet = self.create_row_packet(row, (result.columns.len() + 3 + i) as u8);
621 response.extend_from_slice(&row_packet);
622 }
623
624 let eof2 = self.create_eof_packet((result.columns.len() + 3 + result.rows.len()) as u8);
626 response.extend_from_slice(&eof2);
627 }
628
629 Ok(response)
630 }
631
632 async fn terminate_connection(&self, conn: &mut Connection) -> NirvResult<()> {
633 conn.stream.shutdown().await
634 .map_err(|_e| ProtocolError::ConnectionClosed)?;
635 Ok(())
636 }
637}