br_mysql/
connect.rs

1use std::cmp::max;
2use std::collections::HashMap;
3use std::net::TcpStream;
4use std::sync::Arc;
5use std::thread;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use json::{JsonValue, object};
8use log::info;
9use crate::character_set::Charset;
10use crate::client::Client;
11use crate::{comm};
12use crate::comm::{AuthenticationMethod, CapabilityFlags, Command, StatusFlags};
13use crate::config::Config;
14use crate::packet::Packet;
15use crate::response::Response;
16use crate::server::Server;
17
18#[derive(Clone, Debug)]
19pub struct Connect {
20    /// 服务端
21    pub server: Server,
22    /// 客户端
23    pub client: Client,
24    /// 响应包
25    pub packet: Packet,
26    /// 基础配置
27    pub config: Config,
28    /// 事务队列
29    pub transaction: Vec<String>,
30}
31
32impl Connect {
33    pub fn connect(config: Config) -> Result<Self, String> {
34        match TcpStream::connect(config.clone().url()) {
35            Ok(stream) => {
36                stream.set_read_timeout(Some(Duration::from_secs(10))).unwrap();
37                stream.set_write_timeout(Some(Duration::from_secs(10))).unwrap();
38
39                let mut mysql = Self {
40                    server: Server::default(),
41                    client: Client::default(),
42                    packet: Packet::new(Arc::from(stream.try_clone().unwrap()), config.clone()),
43                    config,
44                    transaction: vec![],
45                };
46
47                if mysql.config.debug {
48                    info!("连接阶段-开始");
49                    info!("接收-初始握手数据包");
50                }
51                match mysql.receive_handshake_response() {
52                    Ok(_) => {
53                        if mysql.config.debug {
54                            info!("服务器握手响应解析成功: {:?}",mysql.server.server_version);
55                            info!("服务器握手响应解析成功: {:?}",mysql.server.connection_id);
56                            info!("客户端响应");
57                        }
58                        match mysql.handshake_response41() {
59                            Ok(_) => {
60                                if mysql.config.debug {
61                                    info!("连接阶段完成");
62                                }
63                            }
64                            Err(e) => return Err(e),
65                        };
66                    }
67                    Err(e) => return Err(e),
68                };
69                if mysql.config.debug {
70                    info!("指挥阶段开始");
71                }
72                Ok(mysql)
73            }
74            Err(e) => return Err(e.to_string())
75        }
76    }
77    /// 握手响应
78    fn receive_handshake_response(&mut self) -> Result<(), String> {
79        let mut data = self.packet.connection_stage_read()?;
80        let protocol_version = data.remove(0);
81        match protocol_version {
82            9 => self.handshake_v9(data),
83            10 => self.handshake_v10(data),
84            _ => {
85                let status = format!("{:02X}", data.remove(0));
86                match status.as_str() {
87                    // 错误包
88                    "FF" => {
89                        let error_code = unsafe { String::from_utf8_unchecked(data) };
90                        return Err(format!("请求错误: {}", error_code));
91                    }
92                    "FE" => {}
93                    "00" => {}
94                    _ => {}
95                }
96                return Err(format!("版本号错误: {}", protocol_version));
97            }
98        }
99    }
100    /// 握手协议 v9
101    fn handshake_v9(&mut self, _response: Vec<u8>) -> Result<(), String> {
102        self.server.protocol_version = 9;
103        self.server.server_version = (0, 0, 0);
104        self.server.connection_id = 0;
105        self.server.auth_plugin_data = vec![];
106        self.server.character_set = Charset::NONE;
107        self.server.authentication_method = AuthenticationMethod::None;
108        self.server.status_flags = StatusFlags::None;
109        self.server.capability_flags = 0;
110        Ok(())
111    }
112    /// 握手协议 v10
113    fn handshake_v10(&mut self, mut response: Vec<u8>) -> Result<(), String> {
114        self.server.protocol_version = 10;
115
116        // 服务器版本
117        let index = response.iter().position(|&item| item == 0).unwrap_or(0);
118        let bytes = response.drain(0..index).collect::<Vec<u8>>();
119        response.remove(0);
120        let server_version = String::from_utf8_lossy(&*bytes.clone()).to_string();
121        let server_version = server_version.as_str().split(".").collect::<Vec<&str>>();
122        let server_version = (server_version[0].parse::<u16>().unwrap(), server_version[1].parse::<u16>().unwrap(), server_version[2].parse::<u16>().unwrap());
123
124        // 线程ID
125        let mut connection_id = response.drain(0..4).collect::<Vec<u8>>();
126        connection_id.reverse();
127        let connection_id = hex::encode(connection_id.clone());
128        let connection_id = u32::from_str_radix(&*connection_id, 16).unwrap();
129
130        // 密码加密部分1
131        let mut auth_plugin_data = response.drain(0..8).collect::<Vec<u8>>();
132        // 填充
133        response.remove(0);
134
135        // 能力 标签 Capability Flags Lower
136        let mut capability_flags = response.drain(0..2).collect::<Vec<u8>>();
137
138        // 数据库的编码
139        let character_set = response.remove(0);
140        let character_set = u8::from_str_radix(&format!("{:02X}", character_set), 16).unwrap();
141
142        // 状态标签
143        let mut status_flags = response.drain(0..2).collect::<Vec<u8>>();
144        status_flags.reverse();
145        let status_flags = hex::encode(status_flags);
146        let status_flags = u16::from_str_radix(&*status_flags, 16).unwrap();
147        let status_flags = comm::StatusFlags::from(status_flags);
148
149        // 能力 标签 Capability Flags Upper
150        capability_flags.extend(response.drain(0..2).collect::<Vec<u8>>());
151        let capability_flags = hex::encode(capability_flags);
152        let capability_flags = u32::from_str_radix(&*capability_flags, 16).unwrap();
153        let scramble_len = if (capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 {
154            // 挑战数据长度
155            let scramble_len = response.remove(0) as usize;
156            scramble_len
157        } else {
158            0
159        };
160        // 保留
161        let _ = response.drain(0..10).collect::<Vec<u8>>();
162
163        // 密码加密部分2
164        let len = max(13, scramble_len - 8);
165
166        auth_plugin_data.extend(response.drain(0..len - 1).collect::<Vec<u8>>());
167        response.remove(0);
168
169        let authentication_method = if (capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 {
170            let index = response.iter().position(|&item| item == 0).unwrap_or(0);
171            let bytes = response.drain(0..index).collect::<Vec<u8>>();
172            String::from_utf8_lossy(&bytes).to_string().trim().to_string()
173        } else {
174            "".to_string()
175        };
176        let authentication_method = AuthenticationMethod::from(&*authentication_method);
177        self.server.server_version = server_version;
178        self.server.connection_id = connection_id;
179        self.server.auth_plugin_data = auth_plugin_data;
180        self.server.character_set = character_set;
181        self.server.authentication_method = authentication_method;
182        self.server.status_flags = status_flags;
183        self.server.capability_flags = capability_flags;
184        Ok(())
185    }
186    /// 握手响应客户端响应
187    fn handshake_response41(&mut self) -> Result<(), String> {
188        let mut buf = vec![];
189
190        let mut attr: HashMap<&str, &str> = HashMap::new();
191        attr.insert("carry xd", "rust blur");
192        // 客户端_flag 4
193        self.client.capability_flags = CapabilityFlags::get_capabilities(&*self.config.database.clone(), attr.clone());
194        buf.extend(self.client.capability_flags.to_le_bytes());
195
196        // 最大数据包大小 4
197        let mut max_packet_size = hex::decode(format!("{:08x}", 16777215)).unwrap();
198        max_packet_size.reverse();
199        buf.extend(max_packet_size);
200
201        // 客户端字符集 1
202        self.client.character_set = Charset::form_u8(&*self.config.charset);
203        buf.push(self.client.character_set);
204
205        // 填充
206        let pack_len = [0u8; 23];
207        buf.extend(pack_len);
208
209        // 用户名 5
210        buf.extend(self.config.username.as_bytes());
211        buf.push(0);
212
213        // 认证模式设置
214        if (self.client.capability_flags & CapabilityFlags::ClientProtocol41.info()) > 0
215            && (self.client.capability_flags & CapabilityFlags::ClientSecureConnection.info()) > 0
216            && (self.client.capability_flags & CapabilityFlags::ClientPluginAuth.info()) == 0
217        {
218            self.client.authentication_method = AuthenticationMethod::MysqlNativePassword;
219        } else {
220            match self.server.server_version.0 {
221                8 => {
222                    self.client.authentication_method = AuthenticationMethod::MysqlNativePassword
223                }
224                _ => {
225                    self.client.authentication_method = AuthenticationMethod::MysqlNativePassword
226                }
227            }
228        }
229
230        if (self.client.capability_flags & CapabilityFlags::ClientPluginAuthLenencClientData.info()) > 0 {
231            buf.push(0);
232        } else {
233            let tt = String::from_utf8_lossy(&*self.server.auth_plugin_data).to_string();
234            let auth_response = self.authentication(tt.as_str());
235            buf.push(auth_response.len() as u8);
236            buf.extend(auth_response);
237        }
238        if (self.client.capability_flags & CapabilityFlags::ClientConnectWithDb.info()) > 0 {
239            buf.extend(self.config.database.as_bytes());
240            buf.push(0);
241        }
242        if (self.client.capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 {
243            buf.extend(self.client.authentication_method.clone().into().as_bytes());
244            buf.push(0);
245        }
246        if (self.client.capability_flags & CapabilityFlags::ClientConnectAttrs.info()) > 0 {
247            let mut list = vec![];
248            for (key, value) in attr.iter() {
249                list.push(key.len() as u8);
250                list.extend(key.as_bytes().to_vec());
251                list.push(value.len() as u8);
252                list.extend(value.as_bytes().to_vec());
253            }
254            buf.push(list.len() as u8);
255            buf.extend(list);
256        }
257        if (self.client.capability_flags & CapabilityFlags::ClientZstdCompressionAlgorithm.info()) > 0 {}
258
259        self.packet.pack(buf.clone())?;
260
261        let res_data = self.packet.connection_stage_read()?;
262
263        if (self.server.capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 && (self.client.capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 {
264            if self.config.debug {
265                info!("回执 auth_switch_request");
266            }
267            Response::new(res_data, self.server.capability_flags)?;
268        } else {
269            if self.config.debug {
270                info!("原生身份验证");
271            }
272            let _ = self.authentication_native41(res_data);
273        }
274        Ok(())
275    }
276    /// 原生身份验证
277    fn authentication_native41(&mut self, mut data: Vec<u8>) -> Result<String, String> {
278        let status = format!("{:02X}", data.remove(0));
279        match status.as_str() {
280            "FE" => {
281                let index = data.iter().position(|&item| item == 0).unwrap_or(0);
282                let bytes = data.drain(0..index).collect::<Vec<u8>>();
283                data.remove(0);
284                let msg = String::from_utf8_lossy(&*bytes).to_string();
285                self.client.authentication_method = AuthenticationMethod::from(msg.as_str());
286                let index = data.iter().position(|&item| item == 0).unwrap_or(0);
287                let bytes = data.drain(0..index).collect::<Vec<u8>>();
288                data.remove(0);
289                let pass = String::from_utf8_lossy(&*bytes).to_string();
290                let pass = self.authentication(&*pass);
291                let mut ttt = vec![];
292                ttt.extend(pass);
293                match self.packet.pack(ttt) {
294                    Ok(_) => {
295                        match self.packet.connection_stage_read() {
296                            Ok(e) => {
297                                info!("验证密钥");
298                                Response::new(e.clone(), self.client.capability_flags)?;
299                                Ok("".to_string())
300                            }
301                            Err(e) => Err(e)
302                        }
303                    }
304                    Err(e) => Err(e)
305                }
306            }
307            _ => {
308                return Err("".to_string());
309            }
310        }
311    }
312    fn authentication(&mut self, auth_plugin_data: &str) -> Vec<u8> {
313        match self.client.authentication_method {
314            AuthenticationMethod::MysqlOldPassword => vec![],
315            AuthenticationMethod::MysqlNativePassword => {
316                let auth_response = AuthenticationMethod::mysql_native_password(format!("{}", auth_plugin_data).as_str().as_ref(), self.config.userpass.clone().as_ref());
317                return auth_response.unwrap().to_vec();
318            }
319            AuthenticationMethod::MysqlClearPassword => vec![],
320            AuthenticationMethod::CachedSha2Password => {
321                let auth_response = AuthenticationMethod::cached_sha2_password(format!("{}", auth_plugin_data).as_str().as_ref(), self.config.userpass.clone().as_ref());
322                return auth_response.unwrap().to_vec();
323            }
324            AuthenticationMethod::None => vec![]
325        }
326    }
327    /// 关闭连接
328    pub fn close(&mut self) -> Result<bool, String> {
329        let sql = vec![Command::COM_QUIT];
330        self.packet.sequence_id = 0;
331        return match self.packet.pack_eof(sql) {
332            Ok(e) => Ok(e > 0),
333            Err(e) => Err(format!("失败: {e}"))
334        };
335    }
336    /// 获取一些内部状态变量
337    pub fn status(&mut self) -> Result<bool, String> {
338        let sql = vec![Command::COM_STATISTICS];
339        self.packet.sequence_id = 0;
340        return match self.packet.pack_eof(sql) {
341            Ok(_) => {
342                let res = self.packet.connection_stage_read()?;
343                let res = Response::new(res.clone(), self.client.capability_flags)?;
344                Ok(res.error_code == 0)
345            }
346            Err(e) => Err(format!("失败: {e}"))
347        };
348    }
349    /// 将调试信息转储到服务器的stdout
350    pub fn debug(&mut self) -> Result<bool, String> {
351        let sql = vec![Command::COM_DEBUG];
352        self.packet.sequence_id = 0;
353        return match self.packet.pack_eof(sql) {
354            Ok(_) => {
355                let res = self.packet.connection_stage_read()?;
356                let res = Response::new(res.clone(), self.client.capability_flags)?;
357                Ok(res.error_code == 0)
358            }
359            Err(e) => Err(format!("失败: {e}"))
360        };
361    }
362    /// 更换连接
363    pub fn db(&mut self, name: &str) -> Result<bool, String> {
364        let mut sql = vec![Command::COM_INIT_DB];
365        sql.extend(name.as_bytes().to_vec());
366        self.packet.sequence_id = 0;
367        return match self.packet.pack_eof(sql) {
368            Ok(_) => {
369                let res = self.packet.connection_stage_read()?;
370                let res = Response::new(res.clone(), self.client.capability_flags)?;
371                Ok(res.error_code == 0)
372            }
373            Err(e) => Err(format!("失败: {e}"))
374        };
375    }
376    /// 查询
377    pub fn query(&mut self, sql: &str) -> Result<JsonValue, String> {
378        if self.config.debug {
379            let thread_id = thread::current().id();
380            info!("query: {:?} {}",thread_id,sql);
381        }
382        let mut sql_data = vec![Command::COM_QUERY];
383        sql_data.extend(sql.as_bytes().to_vec());
384        self.packet.sequence_id = 0;
385        return match self.packet.pack_eof(sql_data) {
386            Ok(_) => {
387                if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
388                    let res = self.packet.com_query_read()?;
389                    Ok(res.row.into())
390                } else {
391                    Ok(object! {})
392                }
393            }
394            Err(e) => Err(format!("失败 {e}"))
395        };
396    }
397    pub fn execute(&mut self, sql: &str) -> Result<Response, String> {
398        if self.config.debug {
399            let thread_id = thread::current().id();
400            info!("execute: {:?} {}",thread_id,sql);
401        }
402        let mut sql_data = vec![Command::COM_STMT_PREPARE];
403        sql_data.extend(sql.as_bytes().to_vec());
404        self.packet.sequence_id = 0;
405        let response = match self.packet.pack_eof(sql_data) {
406            Ok(_) => {
407                let res = self.packet.connection_stage_read()?;
408                let response = Response::execute(res.clone(), self.client.capability_flags)?;
409                response
410            }
411            Err(e) => return Err(format!("失败 {e}"))
412        };
413        let mut sql_data = vec![Command::COM_STMT_EXECUTE];
414        sql_data.extend(response.statement_id.to_le_bytes());
415        sql_data.push(0);
416        sql_data.extend(1_i32.to_le_bytes());
417        self.packet.sequence_id = 0;
418        return match self.packet.pack_eof(sql_data) {
419            Ok(_) => {
420                let res = self.packet.connection_stage_read()?;
421                let response = Response::new(res.clone(), self.client.capability_flags)?;
422                Ok(response)
423            }
424            Err(e) => Err(format!("失败 {e}"))
425        };
426    }
427
428    /// 开始事务
429    pub fn transaction(&mut self) -> Result<bool, String> {
430        if self.transaction.len() > 0 {
431            return self.savepoint();
432        }
433        let mut sql_data = vec![Command::COM_QUERY];
434        let sql = format!("START TRANSACTION");
435        if self.config.debug {
436            let thread_id = thread::current().id();
437            info!("transaction: {:?} {}",thread_id,sql);
438        }
439        sql_data.extend(sql.as_bytes().to_vec());
440        self.packet.sequence_id = 0;
441        return match self.packet.pack_eof(sql_data) {
442            Ok(_) => {
443                if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
444                    let res = self.packet.connection_stage_read()?;
445                    if self.config.debug {
446                        info!("回执: {:?}",res);
447                    }
448                    let response = Response::new(res.clone(), self.client.capability_flags)?;
449                    if response.error_code == 0 {
450                        return self.savepoint();
451                    }
452                    Ok(response.error_code == 0)
453                } else {
454                    Ok(false)
455                }
456            }
457            Err(e) => Err(format!("失败 {e}"))
458        };
459    }
460    /// 设置事务等级
461    pub fn set_transaction_level(&mut self, name: u8) -> Result<bool, String> {
462        let mut sql_data = vec![Command::COM_QUERY];
463        let sql = match name {
464            0 => format!("REPEATABLE READ"),// 可重复读
465            1 => format!("SERIALIZABLE"), // 可串行化
466            2 => format!("READ COMMITTED"),// 读已提交
467            3 => format!("READ UNCOMMITTED"), // 读未提交
468            _ => format!("REPEATABLE READ") // 可重复读
469        };
470        let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", sql);
471        if self.config.debug {
472            let thread_id = thread::current().id();
473            info!("set_transaction_level: {:?} {}",thread_id,sql);
474        }
475        sql_data.extend(sql.as_bytes().to_vec());
476        self.packet.sequence_id = 0;
477        return match self.packet.pack_eof(sql_data) {
478            Ok(_) => {
479                if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
480                    let res = self.packet.connection_stage_read()?;
481                    if self.config.debug {
482                        info!("回执: {:?}",res);
483                    }
484                    let response = Response::new(res.clone(), self.client.capability_flags)?;
485                    Ok(response.error_code == 0)
486                } else {
487                    Ok(false)
488                }
489            }
490            Err(e) => Err(format!("失败 {e}"))
491        };
492    }
493    /// 事务保存点
494    fn savepoint(&mut self) -> Result<bool, String> {
495        let mut sql_data = vec![Command::COM_QUERY];
496
497        // 获取当前时间戳
498        let timestamp = SystemTime::now()
499            .duration_since(UNIX_EPOCH)
500            .expect("Time went backwards")
501            .as_nanos();
502        self.transaction.push(format!("{:x}", timestamp));
503        let sql = format!("SAVEPOINT {}", format!("{:x}", timestamp));
504        if self.config.debug {
505            let thread_id = thread::current().id();
506            info!("savepoint: {:?} {}",thread_id,sql);
507        }
508        sql_data.extend(sql.as_bytes().to_vec());
509        self.packet.sequence_id = 0;
510        return match self.packet.pack_eof(sql_data) {
511            Ok(_) => {
512                if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
513                    let res = self.packet.connection_stage_read()?;
514                    let response = Response::new(res.clone(), self.client.capability_flags)?;
515                    Ok(response.error_code == 0)
516                } else {
517                    Ok(false)
518                }
519            }
520            Err(e) => Err(format!("失败 {e}"))
521        };
522    }
523    /// 事务提交
524    pub fn commit(&mut self) -> Result<bool, String> {
525        let mut sql_data = vec![Command::COM_QUERY];
526        let sql = format!("COMMIT");
527        if self.config.debug {
528            let thread_id = thread::current().id();
529            info!("COMMIT: {:?} {}",thread_id,sql);
530        }
531        sql_data.extend(sql.as_bytes().to_vec());
532        self.packet.sequence_id = 0;
533        return match self.packet.pack_eof(sql_data) {
534            Ok(_) => {
535                if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
536                    let res = self.packet.connection_stage_read()?;
537                    let response = Response::new(res.clone(), self.client.capability_flags)?;
538                    Ok(response.error_code == 0)
539                } else {
540                    Ok(false)
541                }
542            }
543            Err(e) => Err(format!("失败 {e}"))
544        };
545    }
546    /// 回滚事务
547    pub fn rollback(&mut self) -> Result<bool, String> {
548        let mut sql_data = vec![Command::COM_QUERY];
549        let mut sql = format!("ROLLBACK");
550        if self.transaction.len() > 0 {
551            sql = format!("{} TO SAVEPOINT {}", sql, self.transaction.last().unwrap());
552            self.transaction.remove(self.transaction.len() - 1);
553        }
554        if self.config.debug {
555            let thread_id = thread::current().id();
556            info!("rollback: {:?} {}",thread_id,sql);
557        }
558        sql_data.extend(sql.as_bytes().to_vec());
559        self.packet.sequence_id = 0;
560        return match self.packet.pack_eof(sql_data) {
561            Ok(_) => {
562                if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
563                    let res = self.packet.connection_stage_read()?;
564                    let response = Response::new(res.clone(), self.client.capability_flags)?;
565                    Ok(response.error_code == 0)
566                } else {
567                    Ok(false)
568                }
569            }
570            Err(e) => Err(format!("失败 {e}"))
571        };
572    }
573    /// ping
574    pub fn ping(&mut self) -> Result<bool, String> {
575        let sql = vec![Command::COM_PING];
576        self.packet.sequence_id = 0;
577        return match self.packet.pack_eof(sql) {
578            Ok(_) => {
579                let res = self.packet.connection_stage_read()?;
580                Response::new(res.clone(), self.client.capability_flags)?;
581                Ok(true)
582            }
583            Err(e) => Err(format!("ping失败 {e}"))
584        };
585    }
586}