1#![allow(dead_code, unused_variables)]
12
13use bytes::{Buf, BufMut, Bytes, BytesMut};
14use sha2::{Digest, Sha256};
15use std::collections::HashMap;
16use std::io::ErrorKind;
17use std::sync::{Arc, OnceLock};
18use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
19use tokio::net::TcpStream;
20#[cfg(unix)]
21use tokio::net::UnixStream;
22use tracing::{debug, error, info, warn};
23
24use regex::Regex;
25
26use crate::{EmbeddedDatabase, Tuple, Value};
27
/// Protocol version byte that opens the initial handshake packet (HandshakeV10).
const PROTOCOL_VERSION: u8 = 0x0a;
/// Server version string advertised in the handshake and returned by `VERSION()`.
const SERVER_VERSION: &str = "8.0.35-HeliosDB-Nano";

/// Collation id advertised as the server's default character set
/// (45 is `utf8mb4_general_ci`).
const UTF8MB4_GENERAL_CI: u8 = 45;
37
/// Bit set of MySQL `CLIENT_*` capability flags.
///
/// The server starts from [`CapabilityFlags::server_default`] and, once the
/// client has answered the handshake, keeps only the bits both sides set.
#[derive(Debug, Clone, Copy)]
pub struct CapabilityFlags(u32);

impl CapabilityFlags {
    // Individual capability bits as defined by the MySQL client/server protocol.
    pub const CLIENT_LONG_PASSWORD: u32 = 0x0000_0001;
    pub const CLIENT_FOUND_ROWS: u32 = 0x0000_0002;
    pub const CLIENT_LONG_FLAG: u32 = 0x0000_0004;
    pub const CLIENT_CONNECT_WITH_DB: u32 = 0x0000_0008;
    pub const CLIENT_NO_SCHEMA: u32 = 0x0000_0010;
    pub const CLIENT_COMPRESS: u32 = 0x0000_0020;
    pub const CLIENT_ODBC: u32 = 0x0000_0040;
    pub const CLIENT_LOCAL_FILES: u32 = 0x0000_0080;
    pub const CLIENT_IGNORE_SPACE: u32 = 0x0000_0100;
    pub const CLIENT_PROTOCOL_41: u32 = 0x0000_0200;
    pub const CLIENT_INTERACTIVE: u32 = 0x0000_0400;
    pub const CLIENT_SSL: u32 = 0x0000_0800;
    pub const CLIENT_IGNORE_SIGPIPE: u32 = 0x0000_1000;
    pub const CLIENT_TRANSACTIONS: u32 = 0x0000_2000;
    pub const CLIENT_RESERVED: u32 = 0x0000_4000;
    pub const CLIENT_SECURE_CONNECTION: u32 = 0x0000_8000;
    pub const CLIENT_MULTI_STATEMENTS: u32 = 0x0001_0000;
    pub const CLIENT_MULTI_RESULTS: u32 = 0x0002_0000;
    pub const CLIENT_PS_MULTI_RESULTS: u32 = 0x0004_0000;
    pub const CLIENT_PLUGIN_AUTH: u32 = 0x0008_0000;
    pub const CLIENT_CONNECT_ATTRS: u32 = 0x0010_0000;
    pub const CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: u32 = 0x0020_0000;
    pub const CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS: u32 = 0x0040_0000;
    pub const CLIENT_SESSION_TRACK: u32 = 0x0080_0000;
    pub const CLIENT_DEPRECATE_EOF: u32 = 0x0100_0000;

    /// Wraps a raw capability bit mask.
    pub fn new(flags: u32) -> Self {
        Self(flags)
    }

    /// Returns `true` when at least one bit of `flag` is present in the set.
    pub fn has(&self, flag: u32) -> bool {
        self.0 & flag != 0
    }

    /// Adds every bit of `flag` to the set.
    pub fn set(&mut self, flag: u32) {
        self.0 = self.0 | flag;
    }

    /// Returns the raw bit mask.
    pub fn as_u32(&self) -> u32 {
        self.0
    }

    /// Capabilities this server advertises in its handshake.
    ///
    /// Compression, SSL, the reserved bit and expired-password handling are
    /// deliberately left out.
    pub fn server_default() -> Self {
        const ADVERTISED: &[u32] = &[
            CapabilityFlags::CLIENT_LONG_PASSWORD,
            CapabilityFlags::CLIENT_FOUND_ROWS,
            CapabilityFlags::CLIENT_LONG_FLAG,
            CapabilityFlags::CLIENT_CONNECT_WITH_DB,
            CapabilityFlags::CLIENT_NO_SCHEMA,
            CapabilityFlags::CLIENT_ODBC,
            CapabilityFlags::CLIENT_LOCAL_FILES,
            CapabilityFlags::CLIENT_IGNORE_SPACE,
            CapabilityFlags::CLIENT_PROTOCOL_41,
            CapabilityFlags::CLIENT_INTERACTIVE,
            CapabilityFlags::CLIENT_IGNORE_SIGPIPE,
            CapabilityFlags::CLIENT_TRANSACTIONS,
            CapabilityFlags::CLIENT_SECURE_CONNECTION,
            CapabilityFlags::CLIENT_MULTI_STATEMENTS,
            CapabilityFlags::CLIENT_MULTI_RESULTS,
            CapabilityFlags::CLIENT_PS_MULTI_RESULTS,
            CapabilityFlags::CLIENT_PLUGIN_AUTH,
            CapabilityFlags::CLIENT_CONNECT_ATTRS,
            CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA,
            CapabilityFlags::CLIENT_SESSION_TRACK,
            CapabilityFlags::CLIENT_DEPRECATE_EOF,
        ];
        Self(ADVERTISED.iter().fold(0, |mask, &bit| mask | bit))
    }
}
118
/// Bit set of MySQL `SERVER_STATUS_*` flags reported to the client.
#[derive(Debug, Clone, Copy)]
pub struct StatusFlags(u16);

impl StatusFlags {
    /// A transaction is open on this session.
    pub const SERVER_STATUS_IN_TRANS: u16 = 0x0001;
    /// Autocommit mode is enabled.
    pub const SERVER_STATUS_AUTOCOMMIT: u16 = 0x0002;
    /// Another result set follows the current one.
    pub const SERVER_MORE_RESULTS_EXISTS: u16 = 0x0008;

    /// Wraps a raw status bit mask.
    pub fn new(flags: u16) -> Self {
        Self(flags)
    }

    /// Returns `true` when at least one bit of `flag` is present.
    pub fn has(&self, flag: u16) -> bool {
        self.0 & flag != 0
    }

    /// Turns on every bit of `flag`.
    pub fn set(&mut self, flag: u16) {
        self.0 = self.0 | flag;
    }

    /// Turns off every bit of `flag`.
    pub fn clear(&mut self, flag: u16) {
        self.0 = self.0 & !flag;
    }

    /// Returns the raw bit mask.
    pub fn as_u16(&self) -> u16 {
        self.0
    }

    /// Status of a fresh session: autocommit on, no transaction.
    pub fn default_flags() -> Self {
        Self::new(Self::SERVER_STATUS_AUTOCOMMIT)
    }
}
155
156#[derive(Debug, Clone, Copy, PartialEq)]
161#[repr(u8)]
162pub enum ColumnType {
163 Decimal = 0x00,
164 Tiny = 0x01,
165 Short = 0x02,
166 Long = 0x03,
167 Float = 0x04,
168 Double = 0x05,
169 Null = 0x06,
170 Timestamp = 0x07,
171 LongLong = 0x08,
172 Int24 = 0x09,
173 Date = 0x0a,
174 Time = 0x0b,
175 DateTime = 0x0c,
176 Year = 0x0d,
177 VarChar = 0x0f,
178 Bit = 0x10,
179 Json = 0xf5,
180 NewDecimal = 0xf6,
181 Blob = 0xfc,
182 VarString = 0xfd,
183 String = 0xfe,
184}
185
186impl ColumnType {
187 fn from_value(v: &Value) -> Self {
189 match v {
190 Value::Null => ColumnType::Null,
191 Value::Boolean(_) => ColumnType::Tiny,
192 Value::Int2(_) => ColumnType::Short,
193 Value::Int4(_) => ColumnType::Long,
194 Value::Int8(_) => ColumnType::LongLong,
195 Value::Float4(_) => ColumnType::Float,
196 Value::Float8(_) => ColumnType::Double,
197 Value::Numeric(_) => ColumnType::NewDecimal,
198 Value::String(_) => ColumnType::VarString,
199 Value::Bytes(_) => ColumnType::Blob,
200 Value::Uuid(_) => ColumnType::VarString,
201 Value::Timestamp(_) => ColumnType::Timestamp,
202 Value::Date(_) => ColumnType::Date,
203 Value::Time(_) => ColumnType::Time,
204 Value::Interval(_) => ColumnType::VarString,
205 Value::Json(_) => ColumnType::Json,
206 Value::Array(_) => ColumnType::Json,
207 Value::Vector(_) => ColumnType::Json,
208 Value::DictRef { .. } => ColumnType::LongLong,
209 Value::CasRef { .. } => ColumnType::VarString,
210 Value::ColumnarRef => ColumnType::VarString,
211 }
212 }
213}
214
/// Client command bytes (`COM_*`) this server recognises.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
pub enum Command {
    ComQuit = 0x01,
    ComInitDb = 0x02,
    ComQuery = 0x03,
    ComFieldList = 0x04,
    ComStatistics = 0x09,
    ComPing = 0x0e,
    ComStmtPrepare = 0x16,
    ComStmtExecute = 0x17,
    ComStmtClose = 0x19,
    ComStmtReset = 0x1a,
    ComSetOption = 0x1b,
    ComResetConnection = 0x1f,
}

impl Command {
    /// Decodes a command byte, returning `None` for commands not listed above.
    pub fn from_u8(value: u8) -> Option<Self> {
        const KNOWN: &[Command] = &[
            Command::ComQuit,
            Command::ComInitDb,
            Command::ComQuery,
            Command::ComFieldList,
            Command::ComStatistics,
            Command::ComPing,
            Command::ComStmtPrepare,
            Command::ComStmtExecute,
            Command::ComStmtClose,
            Command::ComStmtReset,
            Command::ComSetOption,
            Command::ComResetConnection,
        ];
        // Each variant's discriminant is its wire byte, so match on that.
        KNOWN.iter().copied().find(|&cmd| cmd as u8 == value)
    }
}
255
256#[derive(Debug)]
261pub enum MySqlError {
262 Io(std::io::Error),
263 Protocol(String),
264 ConnectionClosed,
265 Unsupported(u8),
266 StatementNotFound(u32),
267 Db(crate::Error),
268}
269
270impl From<std::io::Error> for MySqlError {
271 fn from(e: std::io::Error) -> Self {
272 Self::Io(e)
273 }
274}
275
276impl From<crate::Error> for MySqlError {
277 fn from(e: crate::Error) -> Self {
278 Self::Db(e)
279 }
280}
281
282impl std::fmt::Display for MySqlError {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 match self {
285 Self::Io(e) => write!(f, "IO: {}", e),
286 Self::Protocol(msg) => write!(f, "Protocol: {}", msg),
287 Self::ConnectionClosed => write!(f, "Connection closed"),
288 Self::Unsupported(c) => write!(f, "Unsupported command: 0x{:02x}", c),
289 Self::StatementNotFound(id) => write!(f, "Statement {} not found", id),
290 Self::Db(e) => write!(f, "DB: {}", e),
291 }
292 }
293}
294
295pub type Result<T> = std::result::Result<T, MySqlError>;
296
297async fn read_packet<S: AsyncRead + Unpin>(stream: &mut S) -> Result<(u8, Bytes)> {
303 let mut hdr = [0u8; 4];
304 stream.read_exact(&mut hdr).await.map_err(|e| {
305 if e.kind() == ErrorKind::UnexpectedEof {
306 MySqlError::ConnectionClosed
307 } else {
308 MySqlError::Io(e)
309 }
310 })?;
311 let len = u32::from_le_bytes([hdr[0], hdr[1], hdr[2], 0]) as usize;
312 let seq = hdr[3];
313 let mut payload = vec![0u8; len];
314 stream.read_exact(&mut payload).await?;
315 Ok((seq, Bytes::from(payload)))
316}
317
318async fn write_packet<S: AsyncWrite + Unpin>(stream: &mut S, seq: u8, payload: &[u8]) -> Result<()> {
320 let len = payload.len() as u32;
321 let mut buf = BytesMut::with_capacity(4 + payload.len());
322 buf.put_u8((len & 0xFF) as u8);
323 buf.put_u8(((len >> 8) & 0xFF) as u8);
324 buf.put_u8(((len >> 16) & 0xFF) as u8);
325 buf.put_u8(seq);
326 buf.put_slice(payload);
327 stream.write_all(&buf).await?;
328 stream.flush().await?;
329 Ok(())
330}
331
332fn write_lenenc_int(buf: &mut BytesMut, value: u64) {
337 if value < 251 {
338 buf.put_u8(value as u8);
339 } else if value < 65536 {
340 buf.put_u8(0xFC);
341 buf.put_u16_le(value as u16);
342 } else if value < 16_777_216 {
343 buf.put_u8(0xFD);
344 buf.put_u8((value & 0xFF) as u8);
345 buf.put_u8(((value >> 8) & 0xFF) as u8);
346 buf.put_u8(((value >> 16) & 0xFF) as u8);
347 } else {
348 buf.put_u8(0xFE);
349 buf.put_u64_le(value);
350 }
351}
352
353fn write_lenenc_str(buf: &mut BytesMut, s: &str) {
354 write_lenenc_int(buf, s.len() as u64);
355 buf.put_slice(s.as_bytes());
356}
357
358fn read_lenenc_int(buf: &mut Bytes) -> Result<u64> {
359 if buf.is_empty() {
360 return Err(MySqlError::Protocol("empty buffer in lenenc_int".into()));
361 }
362 let first = buf.get_u8();
363 match first {
364 0xFB => Ok(0),
365 0xFC => {
366 if buf.remaining() < 2 {
367 return Err(MySqlError::Protocol("short lenenc_int (2)".into()));
368 }
369 Ok(u64::from(buf.get_u16_le()))
370 }
371 0xFD => {
372 if buf.remaining() < 3 {
373 return Err(MySqlError::Protocol("short lenenc_int (3)".into()));
374 }
375 let b1 = u64::from(buf.get_u8());
376 let b2 = u64::from(buf.get_u8());
377 let b3 = u64::from(buf.get_u8());
378 Ok(b1 | (b2 << 8) | (b3 << 16))
379 }
380 0xFE => {
381 if buf.remaining() < 8 {
382 return Err(MySqlError::Protocol("short lenenc_int (8)".into()));
383 }
384 Ok(buf.get_u64_le())
385 }
386 _ => Ok(u64::from(first)),
387 }
388}
389
390fn read_lenenc_str(buf: &mut Bytes) -> Result<String> {
391 let len = read_lenenc_int(buf)? as usize;
392 if buf.remaining() < len {
393 return Err(MySqlError::Protocol("short lenenc_str".into()));
394 }
395 let bytes = buf.copy_to_bytes(len);
396 String::from_utf8(bytes.to_vec())
397 .map_err(|e| MySqlError::Protocol(format!("invalid utf-8: {}", e)))
398}
399
400fn read_lenenc_bytes(buf: &mut Bytes) -> Result<Vec<u8>> {
401 let len = read_lenenc_int(buf)? as usize;
402 if buf.remaining() < len {
403 return Err(MySqlError::Protocol("short lenenc_bytes".into()));
404 }
405 Ok(buf.copy_to_bytes(len).to_vec())
406}
407
408fn read_null_terminated(buf: &mut Bytes) -> Result<String> {
409 let mut out = Vec::new();
410 loop {
411 if buf.is_empty() {
412 return Err(MySqlError::Protocol("unterminated null string".into()));
413 }
414 let b = buf.get_u8();
415 if b == 0 {
416 break;
417 }
418 out.push(b);
419 }
420 String::from_utf8(out)
421 .map_err(|e| MySqlError::Protocol(format!("invalid utf-8: {}", e)))
422}
423
424fn read_null_terminated_bytes(buf: &mut Bytes) -> Result<Vec<u8>> {
425 let mut out = Vec::new();
426 loop {
427 if buf.is_empty() {
428 return Err(MySqlError::Protocol("unterminated null bytes".into()));
429 }
430 let b = buf.get_u8();
431 if b == 0 {
432 break;
433 }
434 out.push(b);
435 }
436 Ok(out)
437}
438
439#[derive(Debug)]
444struct HandshakeResponse {
445 capability_flags: CapabilityFlags,
446 max_packet_size: u32,
447 character_set: u8,
448 username: String,
449 auth_response: Vec<u8>,
450 database: Option<String>,
451 auth_plugin_name: Option<String>,
452 connect_attrs: HashMap<String, String>,
453}
454
455impl HandshakeResponse {
456 fn decode(mut payload: Bytes, server_caps: &CapabilityFlags) -> Result<Self> {
457 if payload.remaining() < 4 {
458 return Err(MySqlError::Protocol("handshake response too short".into()));
459 }
460 let client_flags = CapabilityFlags::new(payload.get_u32_le());
461 let max_packet_size = payload.get_u32_le();
462 let character_set = payload.get_u8();
463
464 if payload.remaining() < 23 {
466 return Err(MySqlError::Protocol("handshake response too short (reserved)".into()));
467 }
468 payload.advance(23);
469
470 let username = read_null_terminated(&mut payload)?;
471
472 let auth_response =
473 if client_flags.has(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
474 read_lenenc_bytes(&mut payload)?
475 } else if client_flags.has(CapabilityFlags::CLIENT_SECURE_CONNECTION) {
476 let len = payload.get_u8() as usize;
477 if payload.remaining() < len {
478 return Err(MySqlError::Protocol("auth response truncated".into()));
479 }
480 payload.copy_to_bytes(len).to_vec()
481 } else {
482 read_null_terminated_bytes(&mut payload)?
483 };
484
485 let database =
486 if client_flags.has(CapabilityFlags::CLIENT_CONNECT_WITH_DB) && payload.has_remaining() {
487 Some(read_null_terminated(&mut payload)?)
488 } else {
489 None
490 };
491
492 let auth_plugin_name =
493 if client_flags.has(CapabilityFlags::CLIENT_PLUGIN_AUTH) && payload.has_remaining() {
494 Some(read_null_terminated(&mut payload)?)
495 } else {
496 None
497 };
498
499 let mut connect_attrs = HashMap::new();
500 if client_flags.has(CapabilityFlags::CLIENT_CONNECT_ATTRS) && payload.has_remaining() {
501 let attrs_len = read_lenenc_int(&mut payload)? as usize;
502 let mut attrs = payload.copy_to_bytes(attrs_len.min(payload.remaining()));
503 while attrs.has_remaining() {
504 let key = read_lenenc_str(&mut attrs)?;
505 let val = read_lenenc_str(&mut attrs)?;
506 connect_attrs.insert(key, val);
507 }
508 }
509
510 Ok(Self {
511 capability_flags: client_flags,
512 max_packet_size,
513 character_set,
514 username,
515 auth_response,
516 database,
517 auth_plugin_name,
518 connect_attrs,
519 })
520 }
521}
522
/// A prepared statement held by a connection in its `prepared_statements` map.
///
/// NOTE(review): the code that creates these (presumably `handle_stmt_prepare`)
/// is not visible here, so the field notes below are assumptions to confirm.
#[derive(Debug, Clone)]
struct PreparedStatement {
    /// Statement id; presumably the same value used as its map key.
    id: u32,
    /// SQL text of the statement (whether pre- or post-translation is not
    /// visible here — confirm).
    sql: String,
    /// Presumably the number of `?` placeholders the statement expects.
    num_params: u16,
}
533
/// Splits `sql` into individual statements on `;`, ignoring semicolons inside
/// quoted text.
///
/// Recognised quoting, following MySQL rules:
/// - `'…'` and `"…"` string literals, where `\x` escapes the next character.
/// - `` `…` `` identifiers, where backslash is not an escape.
/// - In all three forms, a doubled delimiter (`''`, `""`, ``` `` ```) is a
///   literal quote.
///
/// Each statement is trimmed, and empty statements are dropped. Comments are
/// not recognised.
fn split_sql_respecting_quotes(sql: &str) -> Vec<String> {
    // Pushes `stmt` trimmed, unless it is blank.
    fn push_trimmed(out: &mut Vec<String>, stmt: &str) {
        let stmt = stmt.trim();
        if !stmt.is_empty() {
            out.push(stmt.to_string());
        }
    }

    let mut statements = Vec::new();
    let mut current = String::new();
    let mut quote: Option<char> = None;
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        match quote {
            None if ch == ';' => {
                push_trimmed(&mut statements, &current);
                current.clear();
            }
            None => {
                if matches!(ch, '\'' | '"' | '`') {
                    quote = Some(ch);
                }
                current.push(ch);
            }
            // Backslash escapes the next char inside string literals only.
            Some(q) if ch == '\\' && q != '`' => {
                current.push(ch);
                if let Some(next) = chars.next() {
                    current.push(next);
                }
            }
            Some(q) if ch == q => {
                current.push(ch);
                if chars.peek() == Some(&q) {
                    // Doubled delimiter: a literal quote, still inside the quote.
                    current.push(q);
                    chars.next();
                } else {
                    quote = None;
                }
            }
            Some(_) => current.push(ch),
        }
    }

    push_trimmed(&mut statements, &current);
    statements
}
586
/// Returns `true` if `s` begins with `prefix`, comparing ASCII letters
/// case-insensitively. Non-ASCII bytes must match exactly.
#[inline]
fn starts_with_icase(s: &str, prefix: &str) -> bool {
    let needle = prefix.as_bytes();
    matches!(
        s.as_bytes().get(..needle.len()),
        Some(head) if head.eq_ignore_ascii_case(needle)
    )
}
594
/// Per-connection state for a MySQL-protocol client session.
///
/// Generic over the transport `S`; in this file it is constructed from a plain
/// stream in `handle_connection`.
pub struct MySqlHandler<S: AsyncRead + AsyncWrite + Unpin + Send> {
    /// Shared embedded database that executes all queries.
    database: Arc<EmbeddedDatabase>,
    /// Client transport.
    stream: S,
    /// Sequence id for the next outgoing packet. Reset to 0 before each command,
    /// set to the client's id + 1 after each read, and incremented per write.
    seq: u8,
    /// Connection id sent to the client in the handshake.
    connection_id: u32,
    /// Server defaults at first; narrowed to the client/server intersection in
    /// `authenticate`.
    capabilities: CapabilityFlags,
    /// Status flags reported to the client; `SERVER_STATUS_IN_TRANS` tracks
    /// explicit transactions.
    status_flags: StatusFlags,
    /// Collation id advertised in the handshake.
    character_set: u8,
    /// 20-byte random scramble sent in the handshake.
    auth_seed: [u8; 20],
    /// Authentication plugin name advertised in the handshake.
    auth_plugin: String,
    /// User name from the handshake response.
    username: Option<String>,
    /// Schema from the handshake response or `COM_INIT_DB`.
    current_database: Option<String>,
    /// Whether `database.begin()` has been issued without a matching commit or
    /// rollback.
    in_transaction: bool,
    /// Prepared statements on this connection, keyed by statement id.
    prepared_statements: HashMap<u32, PreparedStatement>,
    /// Starts at 1; presumably the id assigned to the next prepared statement
    /// (assignment code not visible here).
    next_stmt_id: u32,
    /// Row count of the last successful row-returning query; served by `FOUND_ROWS()`.
    last_row_count: u64,
    /// Value served by `LAST_INSERT_ID()`.
    last_insert_id: u64,
}
626
627impl<S: AsyncRead + AsyncWrite + Unpin + Send> MySqlHandler<S> {
628 fn new(database: Arc<EmbeddedDatabase>, stream: S, connection_id: u32) -> Self {
633 let mut auth_seed = [0u8; 20];
634 use rand::Rng;
635 rand::thread_rng().fill(&mut auth_seed);
636
637 Self {
638 database,
639 stream,
640 seq: 0,
641 connection_id,
642 capabilities: CapabilityFlags::server_default(),
643 status_flags: StatusFlags::default_flags(),
644 character_set: UTF8MB4_GENERAL_CI,
645 auth_seed,
646 auth_plugin: "mysql_native_password".into(),
647 username: None,
648 current_database: None,
649 in_transaction: false,
650 prepared_statements: HashMap::new(),
651 next_stmt_id: 1,
652 last_row_count: 0,
653 last_insert_id: 0,
654 }
655 }
656
657 fn next_seq(&mut self) -> u8 {
662 let s = self.seq;
663 self.seq = self.seq.wrapping_add(1);
664 s
665 }
666
667 fn reset_seq(&mut self) {
668 self.seq = 0;
669 }
670
671 async fn write_pkt(&mut self, payload: &[u8]) -> Result<()> {
677 let seq = self.next_seq();
678 write_packet(&mut self.stream, seq, payload).await
679 }
680
681 pub async fn handle_connection(
688 database: Arc<EmbeddedDatabase>,
689 stream: S,
690 connection_id: u32,
691 ) -> Result<()> {
692 let mut handler = Self::new(database, stream, connection_id);
693 info!("New MySQL connection: id={}", connection_id);
694
695 handler.send_handshake().await?;
697 let hs = handler.receive_handshake_response().await?;
698
699 handler.authenticate(&hs)?;
701 handler.send_ok(0, 0).await?;
702
703 loop {
705 handler.reset_seq();
706 match handler.receive_command().await {
707 Ok((cmd, payload)) => {
708 if let Err(e) = handler.dispatch_command(cmd, payload).await {
709 match e {
710 MySqlError::ConnectionClosed => {
711 info!("MySQL connection {} closed", connection_id);
712 break;
713 }
714 _ => {
715 error!("Command error: {}", e);
716 let msg = e.to_string();
717 let (code, state) = map_error_code(&msg);
718 let _ = handler
719 .send_error(code, state, &msg)
720 .await;
721 }
722 }
723 }
724 }
725 Err(MySqlError::ConnectionClosed) => {
726 info!("MySQL connection {} disconnected", connection_id);
727 break;
728 }
729 Err(e) => {
730 error!("Receive error: {}", e);
731 break;
732 }
733 }
734 }
735 Ok(())
736 }
737
738 async fn send_handshake(&mut self) -> Result<()> {
743 let mut p = BytesMut::new();
744
745 p.put_u8(PROTOCOL_VERSION);
747
748 p.put_slice(SERVER_VERSION.as_bytes());
750 p.put_u8(0);
751
752 p.put_u32_le(self.connection_id);
754
755 #[allow(clippy::indexing_slicing)]
757 p.put_slice(&self.auth_seed[0..8]);
758
759 p.put_u8(0);
761
762 p.put_u16_le((self.capabilities.as_u32() & 0xFFFF) as u16);
764
765 p.put_u8(self.character_set);
767
768 p.put_u16_le(self.status_flags.as_u16());
770
771 p.put_u16_le(((self.capabilities.as_u32() >> 16) & 0xFFFF) as u16);
773
774 p.put_u8(21);
776
777 p.put_bytes(0, 10);
779
780 #[allow(clippy::indexing_slicing)]
782 p.put_slice(&self.auth_seed[8..20]);
783 p.put_u8(0);
784
785 if self.capabilities.has(CapabilityFlags::CLIENT_PLUGIN_AUTH) {
787 p.put_slice(self.auth_plugin.as_bytes());
788 p.put_u8(0);
789 }
790
791 self.write_pkt(&p).await?;
792 debug!("Sent HandshakeV10");
793 Ok(())
794 }
795
796 async fn receive_handshake_response(&mut self) -> Result<HandshakeResponse> {
797 let (seq, payload) = read_packet(&mut self.stream).await?;
798 self.seq = seq.wrapping_add(1);
799 HandshakeResponse::decode(payload, &self.capabilities)
800 }
801
802 fn authenticate(&mut self, hs: &HandshakeResponse) -> Result<()> {
804 self.username = Some(hs.username.clone());
805 self.current_database = hs.database.clone();
806
807 self.capabilities = CapabilityFlags::new(
809 self.capabilities.as_u32() & hs.capability_flags.as_u32(),
810 );
811
812 let plugin = hs
813 .auth_plugin_name
814 .as_deref()
815 .unwrap_or("mysql_native_password");
816
817 debug!(
818 "Auth user='{}' plugin='{}' db={:?}",
819 hs.username, plugin, hs.database
820 );
821
822 info!("User '{}' authenticated (trust)", hs.username);
825 Ok(())
826 }
827
828 async fn receive_command(&mut self) -> Result<(Command, Bytes)> {
833 let (seq, mut payload) = read_packet(&mut self.stream).await?;
834 self.seq = seq.wrapping_add(1);
835
836 if payload.is_empty() {
837 return Err(MySqlError::Protocol("empty command packet".into()));
838 }
839
840 let cmd_byte = payload.get_u8();
841 let command = Command::from_u8(cmd_byte)
842 .ok_or(MySqlError::Unsupported(cmd_byte))?;
843
844 debug!("Received {:?}", command);
845 Ok((command, payload))
846 }
847
848 async fn dispatch_command(&mut self, cmd: Command, payload: Bytes) -> Result<()> {
849 match cmd {
850 Command::ComQuit => {
851 return Err(MySqlError::ConnectionClosed);
852 }
853 Command::ComPing => {
854 self.send_ok(0, 0).await?;
855 }
856 Command::ComInitDb => {
857 self.handle_init_db(payload).await?;
858 }
859 Command::ComQuery => {
860 self.handle_com_query(payload).await?;
861 }
862 Command::ComStmtPrepare => {
863 self.handle_stmt_prepare(payload).await?;
864 }
865 Command::ComStmtExecute => {
866 self.handle_stmt_execute(payload).await?;
867 }
868 Command::ComStmtClose => {
869 self.handle_stmt_close(payload);
870 }
871 Command::ComStmtReset => {
872 self.send_ok(0, 0).await?;
873 }
874 Command::ComResetConnection => {
875 self.status_flags = StatusFlags::default_flags();
876 self.in_transaction = false;
877 self.send_ok(0, 0).await?;
878 }
879 Command::ComStatistics => {
880 let stats = format!(
883 "Uptime: 0 Threads: 1 Questions: 0 Slow queries: 0 \
884 Opens: 0 Flush tables: 0 Open tables: 0 \
885 Queries per second avg: 0.000"
886 );
887 self.write_pkt(stats.as_bytes()).await?;
888 }
889 _ => {
890 warn!("Unsupported MySQL command: {:?}", cmd);
891 self.send_error(
892 1047,
893 "08S01",
894 &format!("Unsupported command: {:?}", cmd),
895 )
896 .await?;
897 }
898 }
899 Ok(())
900 }
901
902 async fn handle_init_db(&mut self, payload: Bytes) -> Result<()> {
907 let db_name = String::from_utf8_lossy(&payload).to_string();
908 debug!("COM_INIT_DB: {}", db_name);
909 self.current_database = Some(db_name);
910 self.send_ok(0, 0).await
911 }
912
913 async fn handle_com_query(&mut self, payload: Bytes) -> Result<()> {
918 let raw_bytes = payload.as_ref();
920 let trimmed_bytes = if raw_bytes.last() == Some(&0) {
921 &raw_bytes[..raw_bytes.len() - 1]
922 } else {
923 raw_bytes
924 };
925 let raw_sql = String::from_utf8_lossy(trimmed_bytes).to_string();
926
927 let translated = super::translator::translate(&raw_sql);
929 let sql = translated.as_str();
930
931 let is_create = raw_sql.to_ascii_uppercase().contains("CREATE TABLE");
933 if is_create {
934 info!("MySQL DDL in: {}", &raw_sql[..raw_sql.len().min(200)]);
935 info!("MySQL DDL out: {}", &sql[..sql.len().min(200)]);
936 } else if translated != raw_sql {
937 debug!("MySQL→PG: {}", &sql[..sql.len().min(200)]);
938 }
939 let trimmed = sql.trim();
940 if trimmed.is_empty() {
941 return self.send_ok(0, 0).await;
942 }
943
944 if starts_with_icase(trimmed, "SET ") {
946 return self.send_ok(0, 0).await;
947 }
948
949 if starts_with_icase(trimmed, "SHOW ") {
951 return self.handle_show(trimmed).await;
952 }
953
954 if starts_with_icase(trimmed, "DESCRIBE ") || starts_with_icase(trimmed, "DESC ") {
956 return self.handle_describe(trimmed).await;
957 }
958
959 if starts_with_icase(trimmed, "BEGIN")
961 || starts_with_icase(trimmed, "START TRANSACTION")
962 {
963 return self.handle_begin().await;
964 }
965 if trimmed.eq_ignore_ascii_case("COMMIT") {
966 return self.handle_commit().await;
967 }
968 if trimmed.eq_ignore_ascii_case("ROLLBACK") {
969 return self.handle_rollback().await;
970 }
971
972 {
974 let upper = trimmed.to_uppercase();
975 if upper.contains("FOUND_ROWS()") {
976 let cols = vec!["FOUND_ROWS()".to_string()];
977 let rows = vec![Tuple::new(vec![Value::Int8(self.last_row_count as i64)])];
978 return self.send_result_set(&cols, &rows).await;
979 }
980 }
981
982 {
984 let upper = trimmed.to_uppercase();
985 if upper.contains("LAST_INSERT_ID()") {
986 let cols = vec!["LAST_INSERT_ID()".to_string()];
987 let rows = vec![Tuple::new(vec![Value::Int8(self.last_insert_id as i64)])];
988 return self.send_result_set(&cols, &rows).await;
989 }
990 }
991
992 {
994 let upper = trimmed.to_uppercase();
995 if upper.contains("VERSION()") && !upper.contains("@@") {
996 let cols = vec!["VERSION()".to_string()];
997 let rows = vec![Tuple::new(vec![Value::String(SERVER_VERSION.to_string())])];
998 return self.send_result_set(&cols, &rows).await;
999 }
1000 }
1001
1002 if starts_with_icase(trimmed, "SELECT") && trimmed.contains("@@") {
1004 return self.handle_select_variable(trimmed).await;
1005 }
1006
1007 if starts_with_icase(trimmed, "USE ") {
1009 return self.send_ok(0, 0).await;
1010 }
1011
1012 {
1014 let lower = trimmed.to_lowercase();
1015 if lower.contains("information_schema") {
1016 return self.handle_information_schema(trimmed).await;
1017 }
1018 }
1019
1020 let is_select = starts_with_icase(trimmed, "SELECT")
1022 || starts_with_icase(trimmed, "WITH")
1023 || starts_with_icase(trimmed, "VALUES")
1024 || starts_with_icase(trimmed, "TABLE ");
1025
1026 if is_select {
1027 self.execute_query(trimmed).await
1028 } else {
1029 self.execute_dml(trimmed).await
1032 }
1033 }
1034
1035 async fn handle_begin(&mut self) -> Result<()> {
1040 if !self.in_transaction {
1041 self.database.begin()?;
1042 self.in_transaction = true;
1043 self.status_flags.set(StatusFlags::SERVER_STATUS_IN_TRANS);
1044 }
1045 self.send_ok(0, 0).await
1046 }
1047
1048 async fn handle_commit(&mut self) -> Result<()> {
1049 if self.in_transaction {
1050 self.database.commit()?;
1051 self.in_transaction = false;
1052 self.status_flags.clear(StatusFlags::SERVER_STATUS_IN_TRANS);
1053 }
1054 self.send_ok(0, 0).await
1055 }
1056
1057 async fn handle_rollback(&mut self) -> Result<()> {
1058 if self.in_transaction {
1059 self.database.rollback()?;
1060 self.in_transaction = false;
1061 self.status_flags.clear(StatusFlags::SERVER_STATUS_IN_TRANS);
1062 }
1063 self.send_ok(0, 0).await
1064 }
1065
1066 async fn execute_query(&mut self, sql: &str) -> Result<()> {
1071 match self.database.query_with_columns(sql) {
1072 Ok((rows, columns)) => {
1073 self.last_row_count = rows.len() as u64;
1074 self.send_result_set(&columns, &rows).await
1075 }
1076 Err(e) => {
1077 let msg = e.to_string();
1078 let (code, state) = map_error_code(&msg);
1079 self.send_error(code, state, &msg).await
1080 }
1081 }
1082 }
1083
1084 async fn execute_dml(&mut self, sql: &str) -> Result<()> {
1089 let statements = split_sql_respecting_quotes(sql);
1093
1094 let mut total_affected: u64 = 0;
1095 let mut last_insert_id: u64 = 0;
1096
1097 for stmt in &statements {
1098 let is_insert = starts_with_icase(stmt.trim(), "INSERT");
1100 let table_name = if is_insert {
1101 Self::extract_insert_table(stmt)
1102 } else {
1103 None
1104 };
1105
1106 match self.database.execute(stmt) {
1107 Ok(affected) => {
1108 total_affected += affected;
1109 if is_insert && affected > 0 {
1111 if let Some(ref tbl) = table_name {
1112 let id = self.query_last_serial_id(tbl);
1113 if id > 0 {
1114 last_insert_id = id;
1115 }
1116 }
1117 }
1118 }
1119 Err(e) => {
1120 let msg = e.to_string();
1121 let (code, state) = map_error_code(&msg);
1122 return self.send_error(code, state, &msg).await;
1123 }
1124 }
1125 }
1126
1127 if last_insert_id > 0 {
1128 self.last_insert_id = last_insert_id;
1129 }
1130 self.send_ok(total_affected, last_insert_id).await
1131 }
1132
1133 async fn handle_upsert_dml(&mut self, translated_sql: &str, raw_sql: &str) -> Result<()> {
1140 match self.database.execute(translated_sql) {
1142 Ok(affected) => {
1143 let table_name = Self::extract_insert_table(translated_sql);
1144 let insert_id = if affected > 0 {
1145 if let Some(ref tbl) = table_name {
1146 self.query_last_serial_id(tbl)
1147 } else {
1148 0
1149 }
1150 } else {
1151 0
1152 };
1153 if insert_id > 0 {
1154 self.last_insert_id = insert_id;
1155 }
1156 self.send_ok(affected, insert_id).await
1157 }
1158 Err(e) => {
1159 let msg = e.to_string();
1160 let msg_lower = msg.to_lowercase();
1161 if msg_lower.contains("duplicate key")
1163 || msg_lower.contains("unique constraint")
1164 || msg_lower.contains("primary key constraint")
1165 {
1166 if let Some(update_sql) = Self::build_upsert_update(raw_sql) {
1168 let translated_update = super::translator::translate(&update_sql);
1169 match self.database.execute(&translated_update) {
1170 Ok(affected) => self.send_ok(affected, 0).await,
1171 Err(ue) => {
1172 let umsg = ue.to_string();
1173 let (code, state) = map_error_code(&umsg);
1174 self.send_error(code, state, &umsg).await
1175 }
1176 }
1177 } else {
1178 let (code, state) = map_error_code(&msg);
1180 self.send_error(code, state, &msg).await
1181 }
1182 } else {
1183 let (code, state) = map_error_code(&msg);
1184 self.send_error(code, state, &msg).await
1185 }
1186 }
1187 }
1188 }
1189
1190 fn build_upsert_update(raw_sql: &str) -> Option<String> {
1196 let upper = raw_sql.to_uppercase();
1197 let odk_pos = upper.find("ON DUPLICATE KEY UPDATE")?;
1198
1199 let set_part = raw_sql.get(odk_pos + 23..)?.trim();
1201
1202 let table_name = Self::extract_insert_table(raw_sql)?;
1204
1205 let insert_part = &raw_sql[..odk_pos];
1207 let (columns, values) = Self::extract_insert_columns_values(insert_part)?;
1208
1209 let mut col_val_map = std::collections::HashMap::new();
1211 for (i, col) in columns.iter().enumerate() {
1212 if let Some(val) = values.get(i) {
1213 col_val_map.insert(col.to_uppercase(), val.clone());
1214 }
1215 }
1216
1217 let mut set_clauses = Vec::new();
1219 for assignment in set_part.split(',') {
1220 let parts: Vec<&str> = assignment.trim().splitn(2, '=').collect();
1221 if parts.len() != 2 {
1222 continue;
1223 }
1224 let col = parts[0].trim().trim_matches('`');
1225 let expr = parts[1].trim();
1226 let expr_upper = expr.to_uppercase();
1227
1228 if expr_upper.starts_with("VALUES(") || expr_upper.starts_with("VALUES (") {
1230 let inner = expr.trim_end_matches(')');
1231 let inner = inner.find('(').map(|p| &inner[p + 1..])?;
1232 let ref_col = inner.trim().trim_matches('`').to_uppercase();
1233 if let Some(val) = col_val_map.get(&ref_col) {
1234 set_clauses.push(format!("{} = {}", col, val));
1235 }
1236 } else {
1237 set_clauses.push(format!("{} = {}", col, expr));
1238 }
1239 }
1240
1241 if set_clauses.is_empty() {
1242 return None;
1243 }
1244
1245 let where_clause = if let (Some(pk_col), Some(pk_val)) = (columns.first(), values.first()) {
1249 format!("{} = {}", pk_col, pk_val)
1250 } else {
1251 return None;
1252 };
1253
1254 Some(format!(
1255 "UPDATE {} SET {} WHERE {}",
1256 table_name,
1257 set_clauses.join(", "),
1258 where_clause
1259 ))
1260 }
1261
1262 fn extract_insert_columns_values(insert_sql: &str) -> Option<(Vec<String>, Vec<String>)> {
1264 let first_paren = insert_sql.find('(')?;
1266 let first_close = insert_sql.find(')')?;
1267 let col_str = insert_sql.get(first_paren + 1..first_close)?;
1268 let columns: Vec<String> = col_str
1269 .split(',')
1270 .map(|c| c.trim().trim_matches('`').to_string())
1271 .collect();
1272
1273 let upper = insert_sql.to_uppercase();
1275 let values_pos = upper.find("VALUES")?;
1276 let rest = insert_sql.get(values_pos + 6..)?.trim();
1277 let val_open = rest.find('(')?;
1278 let inner = rest.get(val_open + 1..)?;
1280 let close_idx = Self::find_matching_close_paren(inner)?;
1281 let val_str = inner.get(..close_idx)?;
1282
1283 let values = Self::split_sql_values(val_str);
1285
1286 Some((columns, values))
1287 }
1288
1289 fn find_matching_close_paren(s: &str) -> Option<usize> {
1291 let mut depth = 0u32;
1292 let mut in_quote = false;
1293 for (i, ch) in s.char_indices() {
1294 if in_quote {
1295 if ch == '\'' {
1296 in_quote = false;
1297 }
1298 continue;
1299 }
1300 match ch {
1301 '\'' => in_quote = true,
1302 '(' => depth += 1,
1303 ')' => {
1304 if depth == 0 {
1305 return Some(i);
1306 }
1307 depth -= 1;
1308 }
1309 _ => {}
1310 }
1311 }
1312 None
1313 }
1314
1315 fn split_sql_values(s: &str) -> Vec<String> {
1317 let mut result = Vec::new();
1318 let mut current = String::new();
1319 let mut in_quote = false;
1320 let mut depth = 0u32;
1321
1322 for ch in s.chars() {
1323 if in_quote {
1324 current.push(ch);
1325 if ch == '\'' {
1326 in_quote = false;
1327 }
1328 continue;
1329 }
1330 match ch {
1331 '\'' => {
1332 in_quote = true;
1333 current.push(ch);
1334 }
1335 '(' => {
1336 depth += 1;
1337 current.push(ch);
1338 }
1339 ')' => {
1340 depth = depth.saturating_sub(1);
1341 current.push(ch);
1342 }
1343 ',' if depth == 0 => {
1344 result.push(current.trim().to_string());
1345 current.clear();
1346 }
1347 _ => current.push(ch),
1348 }
1349 }
1350 if !current.trim().is_empty() {
1351 result.push(current.trim().to_string());
1352 }
1353 result
1354 }
1355
1356 fn extract_insert_table(sql: &str) -> Option<String> {
1358 static INSERT_TABLE_RE: OnceLock<Regex> = OnceLock::new();
1359 let re = INSERT_TABLE_RE.get_or_init(|| {
1360 Regex::new(r#"(?i)\bINSERT\s+INTO\s+[`"]*(\w+)[`"]*"#)
1361 .unwrap_or_else(|_| Regex::new("^$").expect("static regex"))
1362 });
1363 re.captures(sql).and_then(|c| c.get(1).map(|m| m.as_str().to_string()))
1364 }
1365
1366 fn query_last_serial_id(&self, table_name: &str) -> u64 {
1372 let pk_col = match self.database.storage.catalog().get_table_schema(table_name) {
1374 Ok(schema) => {
1375 schema.columns.iter()
1376 .find(|c| c.primary_key)
1377 .map(|c| c.name.clone())
1378 }
1379 Err(_) => None,
1380 };
1381
1382 let pk_col = match pk_col {
1383 Some(c) => c,
1384 None => return 0,
1385 };
1386
1387 let query = format!("SELECT MAX({}) FROM {}", pk_col, table_name);
1389 match self.database.query_with_columns(&query) {
1390 Ok((rows, _)) => {
1391 let result = rows.first()
1392 .and_then(|r| r.values.first())
1393 .and_then(|v| match v {
1394 Value::Int4(n) => Some(*n as u64),
1395 Value::Int8(n) => Some(*n as u64),
1396 Value::Int2(n) => Some(*n as u64),
1397 _ => None,
1398 })
1399 .unwrap_or(0);
1400 tracing::debug!("query_last_serial_id({}): pk_col={}, result={}", table_name, pk_col, result);
1401 result
1402 }
1403 Err(e) => {
1404 tracing::debug!("query_last_serial_id({}) error: {}", table_name, e);
1405 0
1406 }
1407 }
1408 }
1409
1410 async fn handle_show(&mut self, trimmed: &str) -> Result<()> {
1415 let upper = trimmed.to_uppercase();
1416
1417 if upper.contains("DATABASES") || upper.contains("SCHEMAS") {
1418 return self.show_single_column("Database", &["heliosdb"]).await;
1419 }
1420
1421 if upper.contains("TABLE STATUS") {
1423 return self.handle_show_table_status(trimmed).await;
1424 }
1425
1426 if upper.contains("TABLES") {
1427 let mut tables = self
1429 .database
1430 .storage
1431 .catalog()
1432 .list_tables()
1433 .unwrap_or_default();
1434
1435 if let Some(like_pattern) = extract_like_pattern(trimmed) {
1437 tables.retain(|t| sql_like_match(t, &like_pattern));
1438 }
1439
1440 let refs: Vec<&str> = tables.iter().map(String::as_str).collect();
1441 return self.show_single_column("Tables_in_heliosdb", &refs).await;
1442 }
1443
1444 if upper.contains("INDEX") || upper.contains("INDEXES") || upper.contains("KEYS") {
1445 return self.handle_show_index(trimmed).await;
1446 }
1447
1448 if upper.contains("COLUMNS") || upper.contains("FIELDS") {
1449 return self.handle_show_columns(trimmed).await;
1450 }
1451
1452 if upper.contains("CREATE TABLE") {
1453 return self.handle_show_create_table(trimmed).await;
1454 }
1455
1456 if upper.contains("VARIABLES") || upper.contains("SESSION STATUS")
1457 || upper.contains("GLOBAL STATUS")
1458 {
1459 return self.handle_show_variables(&upper).await;
1460 }
1461
1462 if upper.contains("WARNINGS") {
1463 return self
1465 .show_three_columns("Level", "Code", "Message", &[])
1466 .await;
1467 }
1468
1469 if upper.contains("ERRORS") {
1470 return self
1471 .show_three_columns("Level", "Code", "Message", &[])
1472 .await;
1473 }
1474
1475 if upper.contains("COLLATION") {
1476 return self.handle_show_collation().await;
1477 }
1478
1479 if upper.contains("CHARACTER SET") || upper.contains("CHARSET") {
1480 return self.handle_show_character_set().await;
1481 }
1482
1483 if upper.contains("ENGINES") {
1484 return self.handle_show_engines().await;
1485 }
1486
1487 if upper.contains("PROCESSLIST") {
1488 return self.handle_show_processlist().await;
1489 }
1490
1491 if upper.contains("PLUGINS") {
1492 return self.handle_show_plugins().await;
1493 }
1494
1495 if upper.contains("PRIVILEGES") {
1496 return self.handle_show_privileges().await;
1497 }
1498
1499 if upper.contains("GRANTS") {
1500 let user = self.username.clone().unwrap_or_else(|| "root".to_string());
1502 let line = format!("GRANT ALL PRIVILEGES ON *.* TO '{}'@'%' WITH GRANT OPTION", user);
1503 return self.show_single_column(
1504 &format!("Grants for {}@%", user),
1505 &[&line],
1506 ).await;
1507 }
1508
1509 if upper.contains("MASTER STATUS") || upper.contains("BINARY LOGS")
1510 || upper.contains("REPLICA STATUS") || upper.contains("SLAVE STATUS")
1511 {
1512 return self.send_ok(0, 0).await;
1514 }
1515
1516 self.send_ok(0, 0).await
1518 }
1519
1520 async fn handle_show_engines(&mut self) -> Result<()> {
1522 let cols = vec![
1523 "Engine".to_string(), "Support".to_string(), "Comment".to_string(),
1524 "Transactions".to_string(), "XA".to_string(), "Savepoints".to_string(),
1525 ];
1526 let row = |e: &str, s: &str, c: &str, t: &str, x: &str, sv: &str| {
1527 Tuple::new(vec![
1528 Value::String(e.into()), Value::String(s.into()), Value::String(c.into()),
1529 Value::String(t.into()), Value::String(x.into()), Value::String(sv.into()),
1530 ])
1531 };
1532 let rows = vec![
1533 row("HeliosDB", "DEFAULT", "HeliosDB Nano RocksDB-backed LSM engine",
1534 "YES", "NO", "YES"),
1535 row("InnoDB", "YES", "Alias of HeliosDB (for client compatibility)",
1536 "YES", "NO", "YES"),
1537 row("MEMORY", "YES", "In-memory tables (via CREATE TABLE ... ENGINE=MEMORY)",
1538 "NO", "NO", "NO"),
1539 row("MyISAM", "YES", "Alias of HeliosDB (for client compatibility)",
1540 "NO", "NO", "NO"),
1541 ];
1542 self.send_result_set(&cols, &rows).await
1543 }
1544
1545 async fn handle_show_character_set(&mut self) -> Result<()> {
1547 let cols = vec![
1548 "Charset".to_string(), "Description".to_string(),
1549 "Default collation".to_string(), "Maxlen".to_string(),
1550 ];
1551 let row = |c: &str, d: &str, dc: &str, m: i64| {
1552 Tuple::new(vec![
1553 Value::String(c.into()), Value::String(d.into()),
1554 Value::String(dc.into()), Value::Int8(m),
1555 ])
1556 };
1557 let rows = vec![
1558 row("utf8mb4", "UTF-8 Unicode", "utf8mb4_general_ci", 4),
1559 row("utf8mb3", "UTF-8 Unicode (legacy)", "utf8mb3_general_ci", 3),
1560 row("utf8", "UTF-8 Unicode", "utf8_general_ci", 3),
1561 row("latin1", "cp1252 West European", "latin1_swedish_ci", 1),
1562 row("ascii", "US ASCII", "ascii_general_ci", 1),
1563 row("binary", "Binary pseudo charset", "binary", 1),
1564 ];
1565 self.send_result_set(&cols, &rows).await
1566 }
1567
1568 async fn handle_show_collation(&mut self) -> Result<()> {
1570 let cols = vec![
1571 "Collation".to_string(), "Charset".to_string(), "Id".to_string(),
1572 "Default".to_string(), "Compiled".to_string(), "Sortlen".to_string(),
1573 ];
1574 let row = |coll: &str, cs: &str, id: i64, d: &str| {
1575 Tuple::new(vec![
1576 Value::String(coll.into()), Value::String(cs.into()),
1577 Value::Int8(id), Value::String(d.into()),
1578 Value::String("Yes".into()), Value::Int8(1),
1579 ])
1580 };
1581 let rows = vec![
1582 row("utf8mb4_general_ci", "utf8mb4", 45, "Yes"),
1583 row("utf8mb4_unicode_ci", "utf8mb4", 224, ""),
1584 row("utf8mb4_bin", "utf8mb4", 46, ""),
1585 row("utf8_general_ci", "utf8", 33, "Yes"),
1586 row("latin1_swedish_ci", "latin1", 8, "Yes"),
1587 row("ascii_general_ci", "ascii", 11, "Yes"),
1588 row("binary", "binary", 63, "Yes"),
1589 ];
1590 self.send_result_set(&cols, &rows).await
1591 }
1592
1593 async fn handle_show_processlist(&mut self) -> Result<()> {
1595 let cols = vec![
1596 "Id".to_string(), "User".to_string(), "Host".to_string(),
1597 "db".to_string(), "Command".to_string(), "Time".to_string(),
1598 "State".to_string(), "Info".to_string(),
1599 ];
1600 let user = self.username.clone().unwrap_or_else(|| "root".to_string());
1601 let db = self.current_database.clone().unwrap_or_else(|| "heliosdb".to_string());
1602 let rows = vec![Tuple::new(vec![
1603 Value::Int8(self.connection_id as i64),
1604 Value::String(user),
1605 Value::String("localhost".to_string()),
1606 Value::String(db),
1607 Value::String("Query".to_string()),
1608 Value::Int8(0),
1609 Value::String("executing".to_string()),
1610 Value::String("SHOW PROCESSLIST".to_string()),
1611 ])];
1612 self.send_result_set(&cols, &rows).await
1613 }
1614
1615 async fn handle_show_plugins(&mut self) -> Result<()> {
1617 let cols = vec![
1618 "Name".to_string(), "Status".to_string(), "Type".to_string(),
1619 "Library".to_string(), "License".to_string(),
1620 ];
1621 let rows = vec![Tuple::new(vec![
1622 Value::String("mysql_native_password".into()),
1623 Value::String("ACTIVE".into()),
1624 Value::String("AUTHENTICATION".into()),
1625 Value::Null,
1626 Value::String("Apache-2.0".into()),
1627 ])];
1628 self.send_result_set(&cols, &rows).await
1629 }
1630
1631 async fn handle_show_privileges(&mut self) -> Result<()> {
1633 let cols = vec![
1634 "Privilege".to_string(), "Context".to_string(), "Comment".to_string(),
1635 ];
1636 let row = |p: &str, c: &str, d: &str| {
1637 Tuple::new(vec![
1638 Value::String(p.into()), Value::String(c.into()), Value::String(d.into()),
1639 ])
1640 };
1641 let rows = vec![
1642 row("ALL", "Server Admin", "All privileges (trust auth on local socket)"),
1643 row("SELECT", "Tables", "Read data"),
1644 row("INSERT", "Tables", "Insert rows"),
1645 row("UPDATE", "Tables", "Update rows"),
1646 row("DELETE", "Tables", "Delete rows"),
1647 row("CREATE", "Databases,Tables", "Create schemas and tables"),
1648 row("DROP", "Databases,Tables", "Drop schemas and tables"),
1649 ];
1650 self.send_result_set(&cols, &rows).await
1651 }
1652
1653 async fn handle_show_columns(&mut self, sql: &str) -> Result<()> {
1654 let upper = sql.to_uppercase();
1655 let is_full = upper.contains("FULL");
1656
1657 let table_name = upper
1659 .find("FROM ")
1660 .and_then(|pos| {
1661 let rest = sql.get(pos + 5..)?;
1662 let name = rest.trim().trim_end_matches(';').trim();
1663 let name = name.trim_matches('`').trim_matches('"');
1664 Some(name.to_string())
1665 });
1666
1667 let table_name = match table_name {
1668 Some(t) => t,
1669 None => return self.send_ok(0, 0).await,
1670 };
1671
1672 let schema = match self.database.storage.catalog().get_table_schema(&table_name) {
1674 Ok(s) => s,
1675 Err(_) => {
1676 return self.send_error(1146, "42S02",
1677 &format!("Table '{}' doesn't exist", table_name)).await;
1678 }
1679 };
1680
1681 if is_full {
1682 let cols = vec![
1684 "Field".to_string(), "Type".to_string(), "Collation".to_string(),
1685 "Null".to_string(), "Key".to_string(), "Default".to_string(),
1686 "Extra".to_string(), "Privileges".to_string(), "Comment".to_string(),
1687 ];
1688 let rows: Vec<Tuple> = schema.columns.iter().map(|c| {
1689 let type_str = datatype_to_mysql(&c.data_type);
1690 let null_str = if c.nullable { "YES" } else { "NO" };
1691 let key_str = if c.primary_key { "PRI" } else if c.unique { "UNI" } else { "" };
1692 let default_str = c.default_expr.as_deref().unwrap_or("NULL");
1693 let extra = if c.primary_key && matches!(c.data_type, crate::DataType::Int4 | crate::DataType::Int8) {
1694 "auto_increment"
1695 } else { "" };
1696 Tuple::new(vec![
1697 Value::String(c.name.clone()),
1698 Value::String(type_str),
1699 Value::String("utf8mb4_unicode_ci".to_string()),
1700 Value::String(null_str.to_string()),
1701 Value::String(key_str.to_string()),
1702 Value::String(default_str.to_string()),
1703 Value::String(extra.to_string()),
1704 Value::String("select,insert,update,references".to_string()),
1705 Value::String(String::new()),
1706 ])
1707 }).collect();
1708 self.send_result_set(&cols, &rows).await
1709 } else {
1710 let cols = vec![
1712 "Field".to_string(), "Type".to_string(), "Null".to_string(),
1713 "Key".to_string(), "Default".to_string(), "Extra".to_string(),
1714 ];
1715 let rows: Vec<Tuple> = schema.columns.iter().map(|c| {
1716 let type_str = datatype_to_mysql(&c.data_type);
1717 let null_str = if c.nullable { "YES" } else { "NO" };
1718 let key_str = if c.primary_key { "PRI" } else if c.unique { "UNI" } else { "" };
1719 let default_str = c.default_expr.as_deref().unwrap_or("NULL");
1720 let extra = if c.primary_key && matches!(c.data_type, crate::DataType::Int4 | crate::DataType::Int8) {
1721 "auto_increment"
1722 } else { "" };
1723 Tuple::new(vec![
1724 Value::String(c.name.clone()),
1725 Value::String(type_str),
1726 Value::String(null_str.to_string()),
1727 Value::String(key_str.to_string()),
1728 Value::String(default_str.to_string()),
1729 Value::String(extra.to_string()),
1730 ])
1731 }).collect();
1732 self.send_result_set(&cols, &rows).await
1733 }
1734 }
1735
1736 async fn handle_show_create_table(&mut self, sql: &str) -> Result<()> {
1737 let table_name = sql
1738 .to_uppercase()
1739 .find("TABLE ")
1740 .and_then(|pos| {
1741 let after_kw = sql.get(pos + 6..)?;
1742 let name = after_kw.trim().trim_end_matches(';').trim();
1743 let name = name.trim_matches('`');
1744 Some(name.to_string())
1745 });
1746
1747 let table_name = match table_name {
1748 Some(t) => t,
1749 None => return self.send_ok(0, 0).await,
1750 };
1751
1752 let ddl = self.generate_create_table_ddl(&table_name);
1753 let cols = vec!["Table".to_string(), "Create Table".to_string()];
1754 let row = Tuple::new(vec![
1755 Value::String(table_name),
1756 Value::String(ddl),
1757 ]);
1758 self.send_result_set(&cols, &[row]).await
1759 }
1760
1761 fn generate_create_table_ddl(&self, table_name: &str) -> String {
1763 let schema = match self.database.storage.catalog().get_table_schema(table_name) {
1764 Ok(s) => s,
1765 Err(_) => {
1766 return format!("CREATE TABLE `{}` (\n /* schema not available */\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4", table_name);
1767 }
1768 };
1769
1770 let mut col_defs = Vec::new();
1771 let mut pk_cols = Vec::new();
1772
1773 for col in &schema.columns {
1774 let mysql_type = datatype_to_mysql(&col.data_type);
1775 let nullable = if col.nullable { "" } else { " NOT NULL" };
1776 let default = col.default_expr.as_ref().map_or(String::new(), |d| format!(" DEFAULT {}", d));
1777 col_defs.push(format!(" `{}` {}{}{}", col.name, mysql_type, nullable, default));
1778 if col.primary_key {
1779 pk_cols.push(format!("`{}`", col.name));
1780 }
1781 }
1782
1783 if !pk_cols.is_empty() {
1784 col_defs.push(format!(" PRIMARY KEY ({})", pk_cols.join(",")));
1785 }
1786
1787 for col in &schema.columns {
1789 if col.unique && !col.primary_key {
1790 col_defs.push(format!(" UNIQUE KEY `idx_{}_unique` (`{}`)", col.name, col.name));
1791 }
1792 }
1793
1794 format!(
1795 "CREATE TABLE `{}` (\n{}\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4",
1796 table_name,
1797 col_defs.join(",\n")
1798 )
1799 }
1800
1801 async fn handle_show_variables(&mut self, upper: &str) -> Result<()> {
1802 let vars: Vec<(&str, &str)> = vec![
1804 ("character_set_client", "utf8mb4"),
1805 ("character_set_connection", "utf8mb4"),
1806 ("character_set_results", "utf8mb4"),
1807 ("character_set_server", "utf8mb4"),
1808 ("collation_connection", "utf8mb4_general_ci"),
1809 ("collation_server", "utf8mb4_general_ci"),
1810 ("version", SERVER_VERSION),
1811 ("version_comment", "HeliosDB Nano"),
1812 ("max_allowed_packet", "67108864"),
1813 ("system_time_zone", "UTC"),
1814 ("time_zone", "SYSTEM"),
1815 ("lower_case_table_names", "0"),
1816 ("sql_mode", "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION"),
1817 ("autocommit", "ON"),
1818 ("tx_isolation", "REPEATABLE-READ"),
1819 ("transaction_isolation", "REPEATABLE-READ"),
1820 ];
1821
1822 let filter = if let Some(pos) = upper.find("LIKE ") {
1824 let rest = upper.get(pos + 5..).unwrap_or("").trim();
1825 let pattern = rest.trim_matches('\'').trim_matches('%');
1826 if pattern.is_empty() {
1827 None
1828 } else {
1829 Some(pattern.to_lowercase())
1830 }
1831 } else {
1832 None
1833 };
1834
1835 let cols = vec!["Variable_name".to_string(), "Value".to_string()];
1836 let rows: Vec<Tuple> = vars
1837 .iter()
1838 .filter(|(name, _)| {
1839 if let Some(ref pat) = filter {
1840 name.to_lowercase().contains(pat.as_str())
1841 } else {
1842 true
1843 }
1844 })
1845 .map(|(name, val)| {
1846 Tuple::new(vec![
1847 Value::String((*name).to_string()),
1848 Value::String((*val).to_string()),
1849 ])
1850 })
1851 .collect();
1852
1853 self.send_result_set(&cols, &rows).await
1854 }
1855
1856 async fn handle_show_index(&mut self, sql: &str) -> Result<()> {
1861 let table_name = sql
1862 .to_uppercase()
1863 .find("FROM ")
1864 .and_then(|pos| {
1865 let rest = sql.get(pos + 5..)?;
1866 let name = rest.trim();
1869 let name = name.split_once(|c: char| c.is_whitespace() || c == ';')
1870 .map_or(name, |(first, _)| first);
1871 let name = name.trim_matches('`').trim_matches('"');
1872 if name.is_empty() { return None; }
1873 let name = name.rsplit('.').next().unwrap_or(name);
1875 Some(name.to_string())
1876 });
1877
1878 let table_name = match table_name {
1879 Some(t) => t,
1880 None => return self.send_ok(0, 0).await,
1881 };
1882 tracing::debug!("handle_show_index: resolved table_name = '{}'", table_name);
1883
1884 let cols = vec![
1885 "Table".to_string(),
1886 "Non_unique".to_string(),
1887 "Key_name".to_string(),
1888 "Seq_in_index".to_string(),
1889 "Column_name".to_string(),
1890 "Collation".to_string(),
1891 "Cardinality".to_string(),
1892 "Sub_part".to_string(),
1893 "Packed".to_string(),
1894 "Null".to_string(),
1895 "Index_type".to_string(),
1896 "Comment".to_string(),
1897 "Index_comment".to_string(),
1898 ];
1899
1900 let mut rows: Vec<Tuple> = Vec::new();
1901
1902 if let Ok(schema) = self.database.storage.catalog().get_table_schema(&table_name) {
1904 let mut seq = 1i64;
1905 for col in &schema.columns {
1906 if col.primary_key {
1907 rows.push(Tuple::new(vec![
1908 Value::String(table_name.clone()), Value::String("0".to_string()), Value::String("PRIMARY".to_string()), Value::String(seq.to_string()), Value::String(col.name.clone()), Value::String("A".to_string()), Value::String("0".to_string()), Value::Null, Value::Null, Value::String(String::new()), Value::String("BTREE".to_string()), Value::String(String::new()), Value::String(String::new()), ]));
1922 seq += 1;
1923 }
1924 }
1925
1926 let mut unique_seq = 1i64;
1928 let mut seen_unique_cols: std::collections::HashSet<String> = std::collections::HashSet::new();
1930 for col in &schema.columns {
1931 if col.unique && !col.primary_key {
1932 seen_unique_cols.insert(col.name.to_uppercase());
1933 rows.push(Tuple::new(vec![
1934 Value::String(table_name.clone()),
1935 Value::String("0".to_string()),
1936 Value::String(format!("idx_{}_unique", col.name)),
1937 Value::String(unique_seq.to_string()),
1938 Value::String(col.name.clone()),
1939 Value::String("A".to_string()),
1940 Value::String("0".to_string()),
1941 Value::Null,
1942 Value::Null,
1943 if col.nullable { Value::String("YES".to_string()) } else { Value::String(String::new()) },
1944 Value::String("BTREE".to_string()),
1945 Value::String(String::new()),
1946 Value::String(String::new()),
1947 ]));
1948 unique_seq += 1;
1949 }
1950 }
1951
1952 if let Ok(constraints) = self.database.storage.catalog().load_table_constraints(&table_name) {
1955 for uc in &constraints.unique_constraints {
1956 if uc.is_primary_key {
1957 continue; }
1959 if uc.columns.len() == 1 {
1961 if let Some(first) = uc.columns.first() {
1962 if seen_unique_cols.contains(&first.to_uppercase()) {
1963 continue;
1964 }
1965 }
1966 }
1967 let key_name = uc.name.clone();
1968 for (idx, col_name) in uc.columns.iter().enumerate() {
1969 let nullable = schema.columns.iter()
1970 .find(|c| c.name.eq_ignore_ascii_case(col_name))
1971 .map_or(false, |c| c.nullable);
1972 rows.push(Tuple::new(vec![
1973 Value::String(table_name.clone()),
1974 Value::String("0".to_string()),
1975 Value::String(key_name.clone()),
1976 Value::String((idx as i64 + 1).to_string()),
1977 Value::String(col_name.clone()),
1978 Value::String("A".to_string()),
1979 Value::String("0".to_string()),
1980 Value::Null,
1981 Value::Null,
1982 if nullable { Value::String("YES".to_string()) } else { Value::String(String::new()) },
1983 Value::String("BTREE".to_string()),
1984 Value::String(String::new()),
1985 Value::String(String::new()),
1986 ]));
1987 }
1988 }
1989 }
1990 }
1991
1992 self.send_result_set(&cols, &rows).await
1993 }
1994
1995 async fn handle_select_variable(&mut self, sql: &str) -> Result<()> {
2005 static VAR_RE: OnceLock<Regex> = OnceLock::new();
2008 let re = VAR_RE.get_or_init(|| {
2009 Regex::new(r"@@(?:session\.|global\.)?(\w+)")
2010 .unwrap_or_else(|_| Regex::new("^$").expect("static regex"))
2011 });
2012
2013 let mut col_names: Vec<String> = Vec::new();
2014 let mut values: Vec<Value> = Vec::new();
2015
2016 for cap in re.captures_iter(sql) {
2017 let full_match = cap.get(0).map_or("", |m| m.as_str());
2018 let var_name = cap.get(1).map_or("", |m| m.as_str()).to_lowercase();
2019
2020 let val = match var_name.as_str() {
2021 "version" => Value::String(SERVER_VERSION.to_string()),
2022 "version_comment" => Value::String("HeliosDB Nano".to_string()),
2023 "max_allowed_packet" => Value::Int8(67_108_864),
2024 "character_set_client" | "character_set_connection"
2025 | "character_set_results" | "character_set_server"
2026 | "character_set_database" => Value::String("utf8mb4".to_string()),
2027 "collation_connection" | "collation_server"
2028 | "collation_database" => Value::String("utf8mb4_general_ci".to_string()),
2029 "auto_increment_increment" | "auto_increment_offset" => Value::Int8(1),
2030 "interactive_timeout" | "wait_timeout" => Value::Int8(28800),
2031 "net_write_timeout" | "net_read_timeout" => Value::Int8(30),
2032 "sql_mode" => Value::String(
2033 "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION".to_string()
2034 ),
2035 "time_zone" | "system_time_zone" => Value::String("SYSTEM".to_string()),
2036 "tx_isolation" | "transaction_isolation" => Value::String("REPEATABLE-READ".to_string()),
2037 "autocommit" => Value::Int8(1),
2038 "have_ssl" | "have_openssl" => Value::String("YES".to_string()),
2039 "lower_case_table_names" => Value::Int8(0),
2040 "sql_auto_is_null" => Value::Int8(0),
2041 "last_insert_id" => Value::Int8(self.last_insert_id as i64),
2042 _ => Value::String(String::new()),
2043 };
2044 col_names.push(full_match.to_string());
2045 values.push(val);
2046 }
2047
2048 if col_names.is_empty() {
2049 return self.send_ok(0, 0).await;
2051 }
2052
2053 let row = Tuple::new(values);
2054 self.send_result_set(&col_names, &[row]).await
2055 }
2056
2057 async fn handle_information_schema(&mut self, sql: &str) -> Result<()> {
2068 use crate::protocol::postgres::catalog::PgCatalog;
2069
2070 let catalog = PgCatalog::with_database(Arc::clone(&self.database));
2071 match catalog.handle_query(sql) {
2072 Ok(Some((schema, rows))) => {
2073 let col_names: Vec<String> = schema.columns.iter()
2074 .map(|c| c.name.clone())
2075 .collect();
2076 self.send_result_set(&col_names, &rows).await
2077 }
2078 Ok(None) => {
2079 self.execute_query(sql).await
2081 }
2082 Err(e) => {
2083 debug!("information_schema catalog handler error: {}, falling back to SQL", e);
2085 self.execute_query(sql).await
2086 }
2087 }
2088 }
2089
2090 async fn handle_show_table_status(&mut self, sql: &str) -> Result<()> {
2096 let tables = self.database.storage.catalog().list_tables().unwrap_or_default();
2097
2098 let like_pattern = extract_like_pattern(sql);
2099
2100 let cols = vec![
2101 "Name".to_string(),
2102 "Engine".to_string(),
2103 "Version".to_string(),
2104 "Row_format".to_string(),
2105 "Rows".to_string(),
2106 "Avg_row_length".to_string(),
2107 "Data_length".to_string(),
2108 "Max_data_length".to_string(),
2109 "Index_length".to_string(),
2110 "Data_free".to_string(),
2111 "Auto_increment".to_string(),
2112 "Create_time".to_string(),
2113 "Update_time".to_string(),
2114 "Check_time".to_string(),
2115 "Collation".to_string(),
2116 "Checksum".to_string(),
2117 "Create_options".to_string(),
2118 "Comment".to_string(),
2119 ];
2120
2121 let mut rows: Vec<Tuple> = Vec::new();
2122 for table in &tables {
2123 if let Some(ref pat) = like_pattern {
2125 if !sql_like_match(table, pat) {
2126 continue;
2127 }
2128 }
2129
2130 rows.push(Tuple::new(vec![
2131 Value::String(table.clone()), Value::String("InnoDB".to_string()), Value::String("10".to_string()), Value::String("Dynamic".to_string()), Value::Int8(0), Value::Int8(0), Value::Int8(0), Value::Int8(0), Value::Int8(0), Value::Int8(0), Value::Null, Value::Null, Value::Null, Value::Null, Value::String("utf8mb4_general_ci".to_string()), Value::Null, Value::String(String::new()), Value::String(String::new()), ]));
2150 }
2151
2152 self.send_result_set(&cols, &rows).await
2153 }
2154
2155 async fn handle_describe(&mut self, sql: &str) -> Result<()> {
2157 let table_name = if starts_with_icase(sql, "DESCRIBE ") {
2159 sql.get(9..)
2160 } else {
2161 sql.get(5..)
2163 };
2164
2165 let table_name = match table_name {
2166 Some(rest) => {
2167 let name = rest.trim().trim_end_matches(';').trim().trim_matches('`');
2168 if name.is_empty() {
2169 return self.send_ok(0, 0).await;
2170 }
2171 name.to_string()
2172 }
2173 None => return self.send_ok(0, 0).await,
2174 };
2175
2176 let cols = vec![
2177 "Field".to_string(),
2178 "Type".to_string(),
2179 "Null".to_string(),
2180 "Key".to_string(),
2181 "Default".to_string(),
2182 "Extra".to_string(),
2183 ];
2184
2185 let mut rows: Vec<Tuple> = Vec::new();
2186
2187 if let Ok(schema) = self.database.storage.catalog().get_table_schema(&table_name) {
2188 for col in &schema.columns {
2189 let mysql_type = datatype_to_mysql(&col.data_type);
2190 let null_str = if col.nullable { "YES" } else { "NO" };
2191 let key_str = if col.primary_key {
2192 "PRI"
2193 } else if col.unique {
2194 "UNI"
2195 } else {
2196 ""
2197 };
2198 let default_val = col.default_expr.clone().unwrap_or_default();
2199
2200 rows.push(Tuple::new(vec![
2201 Value::String(col.name.clone()),
2202 Value::String(mysql_type),
2203 Value::String(null_str.to_string()),
2204 Value::String(key_str.to_string()),
2205 if default_val.is_empty() { Value::Null } else { Value::String(default_val) },
2206 Value::String(String::new()),
2207 ]));
2208 }
2209 } else {
2210 let msg = format!("Table '{}' doesn't exist", table_name);
2211 return self.send_error(1146, "42S02", &msg).await;
2212 }
2213
2214 self.send_result_set(&cols, &rows).await
2215 }
2216
2217 async fn show_single_column(&mut self, col_name: &str, values: &[&str]) -> Result<()> {
2219 let cols = vec![col_name.to_string()];
2220 let rows: Vec<Tuple> = values
2221 .iter()
2222 .map(|v| Tuple::new(vec![Value::String((*v).to_string())]))
2223 .collect();
2224 self.send_result_set(&cols, &rows).await
2225 }
2226
2227 async fn show_three_columns(
2229 &mut self,
2230 c1: &str,
2231 c2: &str,
2232 c3: &str,
2233 rows: &[(String, String, String)],
2234 ) -> Result<()> {
2235 let cols = vec![c1.to_string(), c2.to_string(), c3.to_string()];
2236 let tuples: Vec<Tuple> = rows
2237 .iter()
2238 .map(|(a, b, c)| {
2239 Tuple::new(vec![
2240 Value::String(a.clone()),
2241 Value::String(b.clone()),
2242 Value::String(c.clone()),
2243 ])
2244 })
2245 .collect();
2246 self.send_result_set(&cols, &tuples).await
2247 }
2248
2249 async fn handle_stmt_prepare(&mut self, payload: Bytes) -> Result<()> {
2254 let raw_sql = String::from_utf8_lossy(&payload).to_string();
2255 debug!("COM_STMT_PREPARE: {}", raw_sql);
2256 let sql = super::translator::translate(&raw_sql);
2257
2258 let stmt_id = self.next_stmt_id;
2259 self.next_stmt_id += 1;
2260
2261 let num_params = sql.matches('?').count() as u16;
2262
2263 self.prepared_statements.insert(
2264 stmt_id,
2265 PreparedStatement {
2266 id: stmt_id,
2267 sql,
2268 num_params,
2269 },
2270 );
2271
2272 let mut resp = BytesMut::new();
2274 resp.put_u8(0x00); resp.put_u32_le(stmt_id);
2276 resp.put_u16_le(0); resp.put_u16_le(num_params);
2278 resp.put_u8(0x00); resp.put_u16_le(0); self.write_pkt(&resp).await?;
2281
2282 for i in 0..num_params {
2284 self.send_column_def(&format!("?{}", i), ColumnType::VarString)
2285 .await?;
2286 }
2287
2288 if num_params > 0 && !self.capabilities.has(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
2289 self.send_eof().await?;
2290 }
2291
2292 Ok(())
2293 }
2294
2295 async fn handle_stmt_execute(&mut self, mut payload: Bytes) -> Result<()> {
2296 if payload.remaining() < 9 {
2297 return Err(MySqlError::Protocol("COM_STMT_EXECUTE too short".into()));
2298 }
2299
2300 let stmt_id = payload.get_u32_le();
2301 let _flags = payload.get_u8();
2302 let _iteration_count = payload.get_u32_le();
2303
2304 let stmt = self
2305 .prepared_statements
2306 .get(&stmt_id)
2307 .ok_or(MySqlError::StatementNotFound(stmt_id))?
2308 .clone();
2309
2310 debug!("COM_STMT_EXECUTE: id={} sql={}", stmt_id, stmt.sql);
2311
2312 let sql_bytes = Bytes::from(stmt.sql.clone());
2315 self.handle_com_query(sql_bytes).await
2316 }
2317
2318 fn handle_stmt_close(&mut self, mut payload: Bytes) {
2319 if payload.remaining() >= 4 {
2320 let stmt_id = payload.get_u32_le();
2321 self.prepared_statements.remove(&stmt_id);
2322 debug!("COM_STMT_CLOSE: id={}", stmt_id);
2323 }
2324 }
2326
2327 async fn send_result_set(
2333 &mut self,
2334 columns: &[String],
2335 rows: &[Tuple],
2336 ) -> Result<()> {
2337 let ncols = columns.len();
2338
2339 {
2341 let mut buf = BytesMut::new();
2342 write_lenenc_int(&mut buf, ncols as u64);
2343 self.write_pkt(&buf).await?;
2344 }
2345
2346 for (i, col_name) in columns.iter().enumerate() {
2350 let col_type = rows.iter()
2351 .filter_map(|r| r.values.get(i))
2352 .find(|v| !matches!(v, Value::Null))
2353 .map(ColumnType::from_value)
2354 .unwrap_or(ColumnType::VarString);
2355 self.send_column_def(col_name, col_type).await?;
2356 }
2357
2358 if !self.capabilities.has(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
2360 self.send_eof().await?;
2361 }
2362
2363 for row in rows {
2365 self.send_text_result_row(row).await?;
2366 }
2367
2368 if self.capabilities.has(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
2370 self.send_ok(0, 0).await
2371 } else {
2372 self.send_eof().await
2373 }
2374 }
2375
2376 async fn send_column_def(&mut self, name: &str, col_type: ColumnType) -> Result<()> {
2378 let mut p = BytesMut::new();
2379
2380 write_lenenc_str(&mut p, "def"); write_lenenc_str(&mut p, ""); write_lenenc_str(&mut p, ""); write_lenenc_str(&mut p, ""); write_lenenc_str(&mut p, name); write_lenenc_str(&mut p, name); write_lenenc_int(&mut p, 0x0c);
2389 p.put_u16_le(u16::from(UTF8MB4_GENERAL_CI)); p.put_u32_le(255); p.put_u8(col_type as u8); p.put_u16_le(0); p.put_u8(0); p.put_u16_le(0); self.write_pkt(&p).await
2397 }
2398
2399 async fn send_text_result_row(&mut self, row: &Tuple) -> Result<()> {
2401 let mut p = BytesMut::new();
2402 for val in &row.values {
2403 match val {
2404 Value::Null => {
2405 p.put_u8(0xFB); }
2407 _ => {
2408 let s = value_to_mysql_string(val);
2409 write_lenenc_str(&mut p, &s);
2410 }
2411 }
2412 }
2413 self.write_pkt(&p).await
2414 }
2415
2416 async fn send_ok(&mut self, affected_rows: u64, last_insert_id: u64) -> Result<()> {
2421 let mut p = BytesMut::new();
2422 p.put_u8(0x00); write_lenenc_int(&mut p, affected_rows);
2424 write_lenenc_int(&mut p, last_insert_id);
2425
2426 if self.capabilities.has(CapabilityFlags::CLIENT_PROTOCOL_41) {
2427 p.put_u16_le(self.status_flags.as_u16());
2428 p.put_u16_le(0); }
2430
2431 self.write_pkt(&p).await
2432 }
2433
2434 async fn send_error(&mut self, code: u16, state: &str, msg: &str) -> Result<()> {
2435 let mut p = BytesMut::new();
2436 p.put_u8(0xFF); p.put_u16_le(code);
2438
2439 if self.capabilities.has(CapabilityFlags::CLIENT_PROTOCOL_41) {
2440 p.put_u8(b'#');
2441 let state_bytes = state.as_bytes();
2443 #[allow(clippy::indexing_slicing)]
2444 for i in 0..5 {
2445 p.put_u8(if i < state_bytes.len() {
2446 state_bytes[i]
2447 } else {
2448 b' '
2449 });
2450 }
2451 }
2452
2453 p.put_slice(msg.as_bytes());
2454 self.write_pkt(&p).await
2455 }
2456
2457 async fn send_eof(&mut self) -> Result<()> {
2458 let mut p = BytesMut::new();
2459 p.put_u8(0xFE); if self.capabilities.has(CapabilityFlags::CLIENT_PROTOCOL_41) {
2462 p.put_u16_le(0); p.put_u16_le(self.status_flags.as_u16());
2464 }
2465
2466 self.write_pkt(&p).await
2467 }
2468}
2469
2470fn value_to_mysql_string(v: &Value) -> String {
2480 match v {
2481 Value::Null => String::new(), Value::Boolean(b) => if *b { "1" } else { "0" }.to_string(),
2483 Value::Int2(i) => i.to_string(),
2484 Value::Int4(i) => i.to_string(),
2485 Value::Int8(i) => i.to_string(),
2486 Value::Float4(f) => f.to_string(),
2487 Value::Float8(f) => f.to_string(),
2488 Value::Numeric(n) => n.clone(),
2489 Value::String(s) => s.clone(),
2490 Value::Bytes(b) => format!("0x{}", hex::encode(b)),
2491 Value::Uuid(u) => u.to_string(),
2492 Value::Timestamp(ts) => ts.format("%Y-%m-%d %H:%M:%S").to_string(),
2493 Value::Date(d) => d.format("%Y-%m-%d").to_string(),
2494 Value::Time(t) => t.format("%H:%M:%S").to_string(),
2495 Value::Interval(micros) => {
2496 let total_secs = micros / 1_000_000;
2497 let days = total_secs / 86400;
2498 let hours = (total_secs % 86400) / 3600;
2499 let mins = (total_secs % 3600) / 60;
2500 let secs = total_secs % 60;
2501 if days > 0 {
2502 format!("{} days {:02}:{:02}:{:02}", days, hours, mins, secs)
2503 } else {
2504 format!("{:02}:{:02}:{:02}", hours, mins, secs)
2505 }
2506 }
2507 Value::Json(j) => j.clone(),
2508 Value::Array(arr) => {
2509 let inner: Vec<String> = arr.iter().map(value_to_mysql_string).collect();
2510 format!("[{}]", inner.join(","))
2511 }
2512 Value::Vector(vec) => {
2513 let inner: Vec<String> = vec.iter().map(|f| f.to_string()).collect();
2514 format!("[{}]", inner.join(","))
2515 }
2516 Value::DictRef { dict_id } => dict_id.to_string(),
2517 Value::CasRef { hash } => hex::encode(hash),
2518 Value::ColumnarRef => "<columnar>".to_string(),
2519 }
2520}
2521
/// Maps an engine error message to a MySQL `(error code, SQLSTATE)` pair.
///
/// Classification is by case-insensitive keyword matching, tried in this order:
/// 1. duplicate / unique / already exists → `1062`, `23000`
/// 2. unknown column, or "column" + not-found wording → `1054`, `42S22`
/// 3. not-found wording → `1146`, `42S02`
/// 4. syntax / parse → `1064`, `42000`
/// 5. access denied → `1045`, `28000`
/// 6. NOT NULL violations → `1048`, `23000`
/// 7. foreign key / constraint → `1452`, `23000`
/// 8. anything else → `1105`, `HY000`
fn map_error_code(err_msg: &str) -> (u16, &'static str) {
    let lower = err_msg.to_lowercase();
    let has = |needle: &str| lower.contains(needle);
    let missing = has("does not exist") || has("doesn't exist") || has("not found");

    if has("duplicate") || has("unique") || has("already exists") {
        (1062, "23000")
    } else if has("unknown column") || (has("column") && missing) {
        // Must precede the generic "missing" check, otherwise
        // "column 'x' not found" would be reported as a missing table.
        (1054, "42S22")
    } else if missing {
        (1146, "42S02")
    } else if has("syntax") || has("parse") {
        (1064, "42000")
    } else if has("access denied") {
        (1045, "28000")
    } else if has("not null") || has("not-null") || has("cannot be null") {
        // Must precede the generic "constraint" check, which would otherwise
        // swallow "NOT NULL constraint ..." messages.
        (1048, "23000")
    } else if has("foreign key") || has("constraint") {
        (1452, "23000")
    } else {
        (1105, "HY000")
    }
}
2547
2548fn sql_like_match(value: &str, pattern: &str) -> bool {
2556 let mut regex_str = String::from("^");
2561 for ch in pattern.chars() {
2562 match ch {
2563 '%' => regex_str.push_str(".*"),
2564 '_' => regex_str.push('.'),
2565 _ => {
2566 let escaped = regex::escape(&ch.to_string());
2568 regex_str.push_str(&escaped);
2569 }
2570 }
2571 }
2572 regex_str.push('$');
2573 Regex::new(®ex_str)
2574 .map(|re| re.is_match(value))
2575 .unwrap_or(false)
2576}
2577
/// Extracts the pattern that follows the first `LIKE ` keyword in `sql`.
///
/// The keyword is located case-insensitively (ASCII only). It is not
/// word-bounded, so e.g. `ILIKE ` also matches.
///
/// - A pattern wrapped in single or double quotes is returned without its
///   quotes, or `None` if the closing quote is missing. Doubled or escaped
///   quotes inside the literal are not handled.
/// - An unquoted pattern runs up to the next whitespace or `;`.
///
/// Returns `None` when `sql` contains no `LIKE `.
fn extract_like_pattern(sql: &str) -> Option<String> {
    // ASCII-only uppercasing keeps every byte offset identical to `sql`, so the
    // position found in `upper` is valid in `sql`. Full Unicode `to_uppercase`
    // can change byte lengths (e.g. 'ı' -> 'I') and misalign the offset.
    let upper = sql.to_ascii_uppercase();
    let keyword = "LIKE ";
    let pos = upper.find(keyword)?;
    let rest = sql.get(pos + keyword.len()..)?.trim();

    for quote in ['\'', '"'] {
        if let Some(body) = rest.strip_prefix(quote) {
            return body.split_once(quote).map(|(pattern, _)| pattern.to_string());
        }
    }

    let end = rest
        .find(|c: char| c.is_whitespace() || c == ';')
        .unwrap_or(rest.len());
    rest.get(..end).map(String::from)
}
2593
2594fn datatype_to_mysql(dt: &crate::DataType) -> String {
2600 match dt {
2601 crate::DataType::Boolean => "tinyint(1)".to_string(),
2602 crate::DataType::Int2 => "smallint".to_string(),
2603 crate::DataType::Int4 => "int".to_string(),
2604 crate::DataType::Int8 => "bigint".to_string(),
2605 crate::DataType::Float4 => "float".to_string(),
2606 crate::DataType::Float8 => "double".to_string(),
2607 crate::DataType::Numeric => "decimal(65,30)".to_string(),
2608 crate::DataType::Varchar(Some(n)) => format!("varchar({})", n),
2609 crate::DataType::Varchar(None) => "varchar(255)".to_string(),
2610 crate::DataType::Text => "longtext".to_string(),
2611 crate::DataType::Char(n) => format!("char({})", n),
2612 crate::DataType::Bytea => "longblob".to_string(),
2613 crate::DataType::Date => "date".to_string(),
2614 crate::DataType::Time => "time".to_string(),
2615 crate::DataType::Timestamp | crate::DataType::Timestamptz => "datetime".to_string(),
2616 crate::DataType::Interval => "varchar(64)".to_string(),
2617 crate::DataType::Uuid => "char(36)".to_string(),
2618 crate::DataType::Json | crate::DataType::Jsonb => "json".to_string(),
2619 crate::DataType::Array(_) => "json".to_string(),
2620 _ => "varchar(255)".to_string(),
2621 }
2622}
2623
2624pub fn compute_caching_sha2_auth(password: &str, nonce: &[u8]) -> Vec<u8> {
2632 let stage1 = Sha256::digest(password.as_bytes());
2633 let stage2 = Sha256::digest(stage1);
2634 let mut h = Sha256::new();
2635 h.update(stage2);
2636 h.update(nonce);
2637 let stage3 = h.finalize();
2638 stage1
2639 .iter()
2640 .zip(stage3.iter())
2641 .map(|(a, b)| a ^ b)
2642 .collect()
2643}
2644
2645pub fn compute_native_auth(password: &str, nonce: &[u8]) -> Vec<u8> {
2652 let stage1 = Sha256::digest(password.as_bytes());
2654 let stage2 = Sha256::digest(stage1);
2655 let mut h = Sha256::new();
2656 h.update(stage2);
2657 h.update(nonce);
2658 let stage3 = h.finalize();
2659 stage1
2660 .iter()
2661 .zip(stage3.iter())
2662 .map(|(a, b)| a ^ b)
2663 .collect()
2664}
2665
2666pub async fn handle_mysql_connection(
2698 database: Arc<EmbeddedDatabase>,
2699 stream: TcpStream,
2700 connection_id: u32,
2701) -> Result<()> {
2702 MySqlHandler::handle_connection(database, stream, connection_id).await
2703}
2704
2705#[cfg(unix)]
2709pub async fn handle_mysql_connection_unix(
2710 database: Arc<EmbeddedDatabase>,
2711 stream: UnixStream,
2712 connection_id: u32,
2713) -> Result<()> {
2714 MySqlHandler::handle_connection(database, stream, connection_id).await
2715}
2716
// Unit tests for the pure helpers in this module: capability/status flag bit
// operations, length-encoded integer/string round-trips, value rendering,
// command decoding, error-code mapping, the SHOW ... LIKE helpers, and the
// DataType -> MySQL type-name mapping. None of these tests touch a socket.
#[cfg(test)]
mod tests {
    use super::*;

    // The server's default capability set must advertise 4.1 protocol and
    // secure-connection auth, but not SSL.
    #[test]
    fn test_capability_flags_default() {
        let caps = CapabilityFlags::server_default();
        assert!(caps.has(CapabilityFlags::CLIENT_PROTOCOL_41));
        assert!(caps.has(CapabilityFlags::CLIENT_SECURE_CONNECTION));
        assert!(!caps.has(CapabilityFlags::CLIENT_SSL));
    }

    #[test]
    fn test_capability_flags_set() {
        let mut caps = CapabilityFlags::server_default();
        caps.set(CapabilityFlags::CLIENT_SSL);
        assert!(caps.has(CapabilityFlags::CLIENT_SSL));
    }

    // Length-encoded integer round-trips. The values are chosen to fall into
    // different size classes of the MySQL length-encoded integer format
    // (< 251, 2-byte, 3-byte and 8-byte ranges per the protocol spec).
    #[test]
    fn test_lenenc_int_roundtrip_small() {
        let mut buf = BytesMut::new();
        write_lenenc_int(&mut buf, 42);
        let mut read = buf.freeze();
        assert_eq!(read_lenenc_int(&mut read).expect("read"), 42);
    }

    #[test]
    fn test_lenenc_int_roundtrip_medium() {
        let mut buf = BytesMut::new();
        write_lenenc_int(&mut buf, 1000);
        let mut read = buf.freeze();
        assert_eq!(read_lenenc_int(&mut read).expect("read"), 1000);
    }

    #[test]
    fn test_lenenc_int_roundtrip_large() {
        let mut buf = BytesMut::new();
        write_lenenc_int(&mut buf, 100_000);
        let mut read = buf.freeze();
        assert_eq!(read_lenenc_int(&mut read).expect("read"), 100_000);
    }

    #[test]
    fn test_lenenc_int_roundtrip_u64() {
        let mut buf = BytesMut::new();
        write_lenenc_int(&mut buf, u64::MAX);
        let mut read = buf.freeze();
        assert_eq!(read_lenenc_int(&mut read).expect("read"), u64::MAX);
    }

    #[test]
    fn test_lenenc_string_roundtrip() {
        let mut buf = BytesMut::new();
        write_lenenc_str(&mut buf, "hello");
        let mut read = buf.freeze();
        assert_eq!(read_lenenc_str(&mut read).expect("read"), "hello");
    }

    // Booleans render as "1"/"0"; integers and strings render verbatim.
    #[test]
    fn test_value_to_mysql_string() {
        assert_eq!(value_to_mysql_string(&Value::Boolean(true)), "1");
        assert_eq!(value_to_mysql_string(&Value::Boolean(false)), "0");
        assert_eq!(value_to_mysql_string(&Value::Int4(42)), "42");
        assert_eq!(
            value_to_mysql_string(&Value::String("abc".into())),
            "abc"
        );
    }

    #[test]
    fn test_status_flags_clear() {
        let mut sf = StatusFlags::default_flags();
        sf.set(StatusFlags::SERVER_STATUS_IN_TRANS);
        assert!(sf.has(StatusFlags::SERVER_STATUS_IN_TRANS));
        sf.clear(StatusFlags::SERVER_STATUS_IN_TRANS);
        assert!(!sf.has(StatusFlags::SERVER_STATUS_IN_TRANS));
    }

    // Known command bytes decode; an unknown byte yields None.
    #[test]
    fn test_command_from_u8() {
        assert_eq!(Command::from_u8(0x03), Some(Command::ComQuery));
        assert_eq!(Command::from_u8(0x01), Some(Command::ComQuit));
        assert_eq!(Command::from_u8(0xFF), None);
    }

    // Prefix match ignores case; an input shorter than the prefix never matches.
    #[test]
    fn test_starts_with_icase() {
        assert!(starts_with_icase("SELECT * FROM t", "SELECT"));
        assert!(starts_with_icase("select * FROM t", "SELECT"));
        assert!(!starts_with_icase("INS", "INSERT"));
    }

    // Error-message keyword mapping to (MySQL error code, SQLSTATE).
    #[test]
    fn test_map_error_code_duplicate() {
        let (code, state) = map_error_code("duplicate key value violates unique constraint");
        assert_eq!(code, 1062);
        assert_eq!(state, "23000");
    }

    #[test]
    fn test_map_error_code_not_found() {
        let (code, state) = map_error_code("Table 'users' does not exist");
        assert_eq!(code, 1146);
        assert_eq!(state, "42S02");
    }

    #[test]
    fn test_map_error_code_bad_field() {
        let (code, state) = map_error_code("unknown column 'foo'");
        assert_eq!(code, 1054);
        assert_eq!(state, "42S22");
    }

    #[test]
    fn test_map_error_code_syntax() {
        let (code, state) = map_error_code("syntax error at or near 'WHERE'");
        assert_eq!(code, 1064);
        assert_eq!(state, "42000");
    }

    // Messages with no recognised keyword fall back to the generic 1105/HY000.
    #[test]
    fn test_map_error_code_unknown() {
        let (code, state) = map_error_code("something went wrong");
        assert_eq!(code, 1105);
        assert_eq!(state, "HY000");
    }

    // SQL LIKE semantics: '%' = any run of characters, '_' = exactly one
    // character, and the match is anchored at both ends.
    #[test]
    fn test_sql_like_match_percent_wildcard() {
        assert!(sql_like_match("wp_users", "wp_%"));
        assert!(sql_like_match("wp_posts", "wp_%"));
        assert!(!sql_like_match("users", "wp_%"));
    }

    #[test]
    fn test_sql_like_match_underscore_wildcard() {
        assert!(sql_like_match("ab", "a_"));
        assert!(!sql_like_match("abc", "a_"));
    }

    #[test]
    fn test_sql_like_match_exact() {
        assert!(sql_like_match("users", "users"));
        assert!(!sql_like_match("users", "posts"));
    }

    #[test]
    fn test_sql_like_match_both_wildcards() {
        assert!(sql_like_match("wp_options", "%options"));
        assert!(sql_like_match("my_options", "%options"));
        assert!(!sql_like_match("my_posts", "%options"));
    }

    // Pattern extraction from SHOW ... LIKE statements, quoted and unquoted.
    #[test]
    fn test_extract_like_pattern_quoted() {
        let pat = extract_like_pattern("SHOW TABLES LIKE 'wp_%'");
        assert_eq!(pat, Some("wp_%".to_string()));
    }

    #[test]
    fn test_extract_like_pattern_none() {
        let pat = extract_like_pattern("SHOW TABLES");
        assert_eq!(pat, None);
    }

    #[test]
    fn test_extract_like_pattern_unquoted() {
        let pat = extract_like_pattern("SHOW TABLES LIKE wp_%");
        assert_eq!(pat, Some("wp_%".to_string()));
    }

    // Spot-checks of the engine DataType -> MySQL type-name mapping.
    #[test]
    fn test_datatype_to_mysql_coverage() {
        assert_eq!(datatype_to_mysql(&crate::DataType::Boolean), "tinyint(1)");
        assert_eq!(datatype_to_mysql(&crate::DataType::Int4), "int");
        assert_eq!(datatype_to_mysql(&crate::DataType::Int8), "bigint");
        assert_eq!(datatype_to_mysql(&crate::DataType::Text), "longtext");
        assert_eq!(datatype_to_mysql(&crate::DataType::Varchar(Some(100))), "varchar(100)");
        assert_eq!(datatype_to_mysql(&crate::DataType::Varchar(None)), "varchar(255)");
        assert_eq!(datatype_to_mysql(&crate::DataType::Json), "json");
        assert_eq!(datatype_to_mysql(&crate::DataType::Uuid), "char(36)");
        assert_eq!(datatype_to_mysql(&crate::DataType::Bytea), "longblob");
        assert_eq!(datatype_to_mysql(&crate::DataType::Timestamp), "datetime");
    }
}