1use std::cmp::max;
2use std::collections::HashMap;
3use std::net::TcpStream;
4use std::sync::Arc;
5use std::thread;
6use std::time::{Duration, SystemTime, UNIX_EPOCH};
7use json::{JsonValue, object};
8use log::info;
9use crate::character_set::Charset;
10use crate::client::Client;
11use crate::{comm};
12use crate::comm::{AuthenticationMethod, CapabilityFlags, Command, StatusFlags};
13use crate::config::Config;
14use crate::packet::Packet;
15use crate::response::Response;
16use crate::server::Server;
17
18#[derive(Clone, Debug)]
19pub struct Connect {
20 pub server: Server,
22 pub client: Client,
24 pub packet: Packet,
26 pub config: Config,
28 pub transaction: Vec<String>,
30}
31
32impl Connect {
33 pub fn connect(config: Config) -> Result<Self, String> {
34 match TcpStream::connect(config.clone().url()) {
35 Ok(stream) => {
36 stream.set_read_timeout(Some(Duration::from_secs(10))).unwrap();
37 stream.set_write_timeout(Some(Duration::from_secs(10))).unwrap();
38
39 let mut mysql = Self {
40 server: Server::default(),
41 client: Client::default(),
42 packet: Packet::new(Arc::from(stream.try_clone().unwrap()), config.clone()),
43 config,
44 transaction: vec![],
45 };
46
47 if mysql.config.debug {
48 info!("连接阶段-开始");
49 info!("接收-初始握手数据包");
50 }
51 match mysql.receive_handshake_response() {
52 Ok(_) => {
53 if mysql.config.debug {
54 info!("服务器握手响应解析成功: {:?}",mysql.server.server_version);
55 info!("服务器握手响应解析成功: {:?}",mysql.server.connection_id);
56 info!("客户端响应");
57 }
58 match mysql.handshake_response41() {
59 Ok(_) => {
60 if mysql.config.debug {
61 info!("连接阶段完成");
62 }
63 }
64 Err(e) => return Err(e),
65 };
66 }
67 Err(e) => return Err(e),
68 };
69 if mysql.config.debug {
70 info!("指挥阶段开始");
71 }
72 Ok(mysql)
73 }
74 Err(e) => return Err(e.to_string())
75 }
76 }
77 fn receive_handshake_response(&mut self) -> Result<(), String> {
79 let mut data = self.packet.connection_stage_read()?;
80 let protocol_version = data.remove(0);
81 match protocol_version {
82 9 => self.handshake_v9(data),
83 10 => self.handshake_v10(data),
84 _ => {
85 let status = format!("{:02X}", data.remove(0));
86 match status.as_str() {
87 "FF" => {
89 let error_code = unsafe { String::from_utf8_unchecked(data) };
90 return Err(format!("请求错误: {}", error_code));
91 }
92 "FE" => {}
93 "00" => {}
94 _ => {}
95 }
96 return Err(format!("版本号错误: {}", protocol_version));
97 }
98 }
99 }
100 fn handshake_v9(&mut self, _response: Vec<u8>) -> Result<(), String> {
102 self.server.protocol_version = 9;
103 self.server.server_version = (0, 0, 0);
104 self.server.connection_id = 0;
105 self.server.auth_plugin_data = vec![];
106 self.server.character_set = Charset::NONE;
107 self.server.authentication_method = AuthenticationMethod::None;
108 self.server.status_flags = StatusFlags::None;
109 self.server.capability_flags = 0;
110 Ok(())
111 }
112 fn handshake_v10(&mut self, mut response: Vec<u8>) -> Result<(), String> {
114 self.server.protocol_version = 10;
115
116 let index = response.iter().position(|&item| item == 0).unwrap_or(0);
118 let bytes = response.drain(0..index).collect::<Vec<u8>>();
119 response.remove(0);
120 let server_version = String::from_utf8_lossy(&*bytes.clone()).to_string();
121 let server_version = server_version.as_str().split(".").collect::<Vec<&str>>();
122 let server_version = (server_version[0].parse::<u16>().unwrap(), server_version[1].parse::<u16>().unwrap(), server_version[2].parse::<u16>().unwrap());
123
124 let mut connection_id = response.drain(0..4).collect::<Vec<u8>>();
126 connection_id.reverse();
127 let connection_id = hex::encode(connection_id.clone());
128 let connection_id = u32::from_str_radix(&*connection_id, 16).unwrap();
129
130 let mut auth_plugin_data = response.drain(0..8).collect::<Vec<u8>>();
132 response.remove(0);
134
135 let mut capability_flags = response.drain(0..2).collect::<Vec<u8>>();
137
138 let character_set = response.remove(0);
140 let character_set = u8::from_str_radix(&format!("{:02X}", character_set), 16).unwrap();
141
142 let mut status_flags = response.drain(0..2).collect::<Vec<u8>>();
144 status_flags.reverse();
145 let status_flags = hex::encode(status_flags);
146 let status_flags = u16::from_str_radix(&*status_flags, 16).unwrap();
147 let status_flags = comm::StatusFlags::from(status_flags);
148
149 capability_flags.extend(response.drain(0..2).collect::<Vec<u8>>());
151 let capability_flags = hex::encode(capability_flags);
152 let capability_flags = u32::from_str_radix(&*capability_flags, 16).unwrap();
153 let scramble_len = if (capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 {
154 let scramble_len = response.remove(0) as usize;
156 scramble_len
157 } else {
158 0
159 };
160 let _ = response.drain(0..10).collect::<Vec<u8>>();
162
163 let len = max(13, scramble_len - 8);
165
166 auth_plugin_data.extend(response.drain(0..len - 1).collect::<Vec<u8>>());
167 response.remove(0);
168
169 let authentication_method = if (capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 {
170 let index = response.iter().position(|&item| item == 0).unwrap_or(0);
171 let bytes = response.drain(0..index).collect::<Vec<u8>>();
172 String::from_utf8_lossy(&bytes).to_string().trim().to_string()
173 } else {
174 "".to_string()
175 };
176 let authentication_method = AuthenticationMethod::from(&*authentication_method);
177 self.server.server_version = server_version;
178 self.server.connection_id = connection_id;
179 self.server.auth_plugin_data = auth_plugin_data;
180 self.server.character_set = character_set;
181 self.server.authentication_method = authentication_method;
182 self.server.status_flags = status_flags;
183 self.server.capability_flags = capability_flags;
184 Ok(())
185 }
186 fn handshake_response41(&mut self) -> Result<(), String> {
188 let mut buf = vec![];
189
190 let mut attr: HashMap<&str, &str> = HashMap::new();
191 attr.insert("carry xd", "rust blur");
192 self.client.capability_flags = CapabilityFlags::get_capabilities(&*self.config.database.clone(), attr.clone());
194 buf.extend(self.client.capability_flags.to_le_bytes());
195
196 let mut max_packet_size = hex::decode(format!("{:08x}", 16777215)).unwrap();
198 max_packet_size.reverse();
199 buf.extend(max_packet_size);
200
201 self.client.character_set = Charset::form_u8(&*self.config.charset);
203 buf.push(self.client.character_set);
204
205 let pack_len = [0u8; 23];
207 buf.extend(pack_len);
208
209 buf.extend(self.config.username.as_bytes());
211 buf.push(0);
212
213 if (self.client.capability_flags & CapabilityFlags::ClientProtocol41.info()) > 0
215 && (self.client.capability_flags & CapabilityFlags::ClientSecureConnection.info()) > 0
216 && (self.client.capability_flags & CapabilityFlags::ClientPluginAuth.info()) == 0
217 {
218 self.client.authentication_method = AuthenticationMethod::MysqlNativePassword;
219 } else {
220 match self.server.server_version.0 {
221 8 => {
222 self.client.authentication_method = AuthenticationMethod::MysqlNativePassword
223 }
224 _ => {
225 self.client.authentication_method = AuthenticationMethod::MysqlNativePassword
226 }
227 }
228 }
229
230 if (self.client.capability_flags & CapabilityFlags::ClientPluginAuthLenencClientData.info()) > 0 {
231 buf.push(0);
232 } else {
233 let tt = String::from_utf8_lossy(&*self.server.auth_plugin_data).to_string();
234 let auth_response = self.authentication(tt.as_str());
235 buf.push(auth_response.len() as u8);
236 buf.extend(auth_response);
237 }
238 if (self.client.capability_flags & CapabilityFlags::ClientConnectWithDb.info()) > 0 {
239 buf.extend(self.config.database.as_bytes());
240 buf.push(0);
241 }
242 if (self.client.capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 {
243 buf.extend(self.client.authentication_method.clone().into().as_bytes());
244 buf.push(0);
245 }
246 if (self.client.capability_flags & CapabilityFlags::ClientConnectAttrs.info()) > 0 {
247 let mut list = vec![];
248 for (key, value) in attr.iter() {
249 list.push(key.len() as u8);
250 list.extend(key.as_bytes().to_vec());
251 list.push(value.len() as u8);
252 list.extend(value.as_bytes().to_vec());
253 }
254 buf.push(list.len() as u8);
255 buf.extend(list);
256 }
257 if (self.client.capability_flags & CapabilityFlags::ClientZstdCompressionAlgorithm.info()) > 0 {}
258
259 self.packet.pack(buf.clone())?;
260
261 let res_data = self.packet.connection_stage_read()?;
262
263 if (self.server.capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 && (self.client.capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 {
264 if self.config.debug {
265 info!("回执 auth_switch_request");
266 }
267 Response::new(res_data, self.server.capability_flags)?;
268 } else {
269 if self.config.debug {
270 info!("原生身份验证");
271 }
272 let _ = self.authentication_native41(res_data);
273 }
274 Ok(())
275 }
276 fn authentication_native41(&mut self, mut data: Vec<u8>) -> Result<String, String> {
278 let status = format!("{:02X}", data.remove(0));
279 match status.as_str() {
280 "FE" => {
281 let index = data.iter().position(|&item| item == 0).unwrap_or(0);
282 let bytes = data.drain(0..index).collect::<Vec<u8>>();
283 data.remove(0);
284 let msg = String::from_utf8_lossy(&*bytes).to_string();
285 self.client.authentication_method = AuthenticationMethod::from(msg.as_str());
286 let index = data.iter().position(|&item| item == 0).unwrap_or(0);
287 let bytes = data.drain(0..index).collect::<Vec<u8>>();
288 data.remove(0);
289 let pass = String::from_utf8_lossy(&*bytes).to_string();
290 let pass = self.authentication(&*pass);
291 let mut ttt = vec![];
292 ttt.extend(pass);
293 match self.packet.pack(ttt) {
294 Ok(_) => {
295 match self.packet.connection_stage_read() {
296 Ok(e) => {
297 info!("验证密钥");
298 Response::new(e.clone(), self.client.capability_flags)?;
299 Ok("".to_string())
300 }
301 Err(e) => Err(e)
302 }
303 }
304 Err(e) => Err(e)
305 }
306 }
307 _ => {
308 return Err("".to_string());
309 }
310 }
311 }
312 fn authentication(&mut self, auth_plugin_data: &str) -> Vec<u8> {
313 match self.client.authentication_method {
314 AuthenticationMethod::MysqlOldPassword => vec![],
315 AuthenticationMethod::MysqlNativePassword => {
316 let auth_response = AuthenticationMethod::mysql_native_password(format!("{}", auth_plugin_data).as_str().as_ref(), self.config.userpass.clone().as_ref());
317 return auth_response.unwrap().to_vec();
318 }
319 AuthenticationMethod::MysqlClearPassword => vec![],
320 AuthenticationMethod::CachedSha2Password => {
321 let auth_response = AuthenticationMethod::cached_sha2_password(format!("{}", auth_plugin_data).as_str().as_ref(), self.config.userpass.clone().as_ref());
322 return auth_response.unwrap().to_vec();
323 }
324 AuthenticationMethod::None => vec![]
325 }
326 }
327 pub fn close(&mut self) -> Result<bool, String> {
329 let sql = vec![Command::COM_QUIT];
330 self.packet.sequence_id = 0;
331 return match self.packet.pack_eof(sql) {
332 Ok(e) => Ok(e > 0),
333 Err(e) => Err(format!("失败: {e}"))
334 };
335 }
336 pub fn status(&mut self) -> Result<bool, String> {
338 let sql = vec![Command::COM_STATISTICS];
339 self.packet.sequence_id = 0;
340 return match self.packet.pack_eof(sql) {
341 Ok(_) => {
342 let res = self.packet.connection_stage_read()?;
343 let res = Response::new(res.clone(), self.client.capability_flags)?;
344 Ok(res.error_code == 0)
345 }
346 Err(e) => Err(format!("失败: {e}"))
347 };
348 }
349 pub fn debug(&mut self) -> Result<bool, String> {
351 let sql = vec![Command::COM_DEBUG];
352 self.packet.sequence_id = 0;
353 return match self.packet.pack_eof(sql) {
354 Ok(_) => {
355 let res = self.packet.connection_stage_read()?;
356 let res = Response::new(res.clone(), self.client.capability_flags)?;
357 Ok(res.error_code == 0)
358 }
359 Err(e) => Err(format!("失败: {e}"))
360 };
361 }
362 pub fn db(&mut self, name: &str) -> Result<bool, String> {
364 let mut sql = vec![Command::COM_INIT_DB];
365 sql.extend(name.as_bytes().to_vec());
366 self.packet.sequence_id = 0;
367 return match self.packet.pack_eof(sql) {
368 Ok(_) => {
369 let res = self.packet.connection_stage_read()?;
370 let res = Response::new(res.clone(), self.client.capability_flags)?;
371 Ok(res.error_code == 0)
372 }
373 Err(e) => Err(format!("失败: {e}"))
374 };
375 }
376 pub fn query(&mut self, sql: &str) -> Result<JsonValue, String> {
378 if self.config.debug {
379 let thread_id = thread::current().id();
380 info!("query: {:?} {}",thread_id,sql);
381 }
382 let mut sql_data = vec![Command::COM_QUERY];
383 sql_data.extend(sql.as_bytes().to_vec());
384 self.packet.sequence_id = 0;
385 return match self.packet.pack_eof(sql_data) {
386 Ok(_) => {
387 if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
388 let res = self.packet.com_query_read()?;
389 Ok(res.row.into())
390 } else {
391 Ok(object! {})
392 }
393 }
394 Err(e) => Err(format!("失败 {e}"))
395 };
396 }
397 pub fn execute(&mut self, sql: &str) -> Result<Response, String> {
398 if self.config.debug {
399 let thread_id = thread::current().id();
400 info!("execute: {:?} {}",thread_id,sql);
401 }
402 let mut sql_data = vec![Command::COM_STMT_PREPARE];
403 sql_data.extend(sql.as_bytes().to_vec());
404 self.packet.sequence_id = 0;
405 let response = match self.packet.pack_eof(sql_data) {
406 Ok(_) => {
407 let res = self.packet.connection_stage_read()?;
408 let response = Response::execute(res.clone(), self.client.capability_flags)?;
409 response
410 }
411 Err(e) => return Err(format!("失败 {e}"))
412 };
413 let mut sql_data = vec![Command::COM_STMT_EXECUTE];
414 sql_data.extend(response.statement_id.to_le_bytes());
415 sql_data.push(0);
416 sql_data.extend(1_i32.to_le_bytes());
417 self.packet.sequence_id = 0;
418 return match self.packet.pack_eof(sql_data) {
419 Ok(_) => {
420 let res = self.packet.connection_stage_read()?;
421 let response = Response::new(res.clone(), self.client.capability_flags)?;
422 Ok(response)
423 }
424 Err(e) => Err(format!("失败 {e}"))
425 };
426 }
427
428 pub fn transaction(&mut self) -> Result<bool, String> {
430 if self.transaction.len() > 0 {
431 return self.savepoint();
432 }
433 let mut sql_data = vec![Command::COM_QUERY];
434 let sql = format!("START TRANSACTION");
435 if self.config.debug {
436 let thread_id = thread::current().id();
437 info!("transaction: {:?} {}",thread_id,sql);
438 }
439 sql_data.extend(sql.as_bytes().to_vec());
440 self.packet.sequence_id = 0;
441 return match self.packet.pack_eof(sql_data) {
442 Ok(_) => {
443 if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
444 let res = self.packet.connection_stage_read()?;
445 if self.config.debug {
446 info!("回执: {:?}",res);
447 }
448 let response = Response::new(res.clone(), self.client.capability_flags)?;
449 if response.error_code == 0 {
450 return self.savepoint();
451 }
452 Ok(response.error_code == 0)
453 } else {
454 Ok(false)
455 }
456 }
457 Err(e) => Err(format!("失败 {e}"))
458 };
459 }
460 pub fn set_transaction_level(&mut self, name: u8) -> Result<bool, String> {
462 let mut sql_data = vec![Command::COM_QUERY];
463 let sql = match name {
464 0 => format!("REPEATABLE READ"),1 => format!("SERIALIZABLE"), 2 => format!("READ COMMITTED"),3 => format!("READ UNCOMMITTED"), _ => format!("REPEATABLE READ") };
470 let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", sql);
471 if self.config.debug {
472 let thread_id = thread::current().id();
473 info!("set_transaction_level: {:?} {}",thread_id,sql);
474 }
475 sql_data.extend(sql.as_bytes().to_vec());
476 self.packet.sequence_id = 0;
477 return match self.packet.pack_eof(sql_data) {
478 Ok(_) => {
479 if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
480 let res = self.packet.connection_stage_read()?;
481 if self.config.debug {
482 info!("回执: {:?}",res);
483 }
484 let response = Response::new(res.clone(), self.client.capability_flags)?;
485 Ok(response.error_code == 0)
486 } else {
487 Ok(false)
488 }
489 }
490 Err(e) => Err(format!("失败 {e}"))
491 };
492 }
493 fn savepoint(&mut self) -> Result<bool, String> {
495 let mut sql_data = vec![Command::COM_QUERY];
496
497 let timestamp = SystemTime::now()
499 .duration_since(UNIX_EPOCH)
500 .expect("Time went backwards")
501 .as_nanos();
502 self.transaction.push(format!("{:x}", timestamp));
503 let sql = format!("SAVEPOINT {}", format!("{:x}", timestamp));
504 if self.config.debug {
505 let thread_id = thread::current().id();
506 info!("savepoint: {:?} {}",thread_id,sql);
507 }
508 sql_data.extend(sql.as_bytes().to_vec());
509 self.packet.sequence_id = 0;
510 return match self.packet.pack_eof(sql_data) {
511 Ok(_) => {
512 if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
513 let res = self.packet.connection_stage_read()?;
514 let response = Response::new(res.clone(), self.client.capability_flags)?;
515 Ok(response.error_code == 0)
516 } else {
517 Ok(false)
518 }
519 }
520 Err(e) => Err(format!("失败 {e}"))
521 };
522 }
523 pub fn commit(&mut self) -> Result<bool, String> {
525 let mut sql_data = vec![Command::COM_QUERY];
526 let sql = format!("COMMIT");
527 if self.config.debug {
528 let thread_id = thread::current().id();
529 info!("COMMIT: {:?} {}",thread_id,sql);
530 }
531 sql_data.extend(sql.as_bytes().to_vec());
532 self.packet.sequence_id = 0;
533 return match self.packet.pack_eof(sql_data) {
534 Ok(_) => {
535 if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
536 let res = self.packet.connection_stage_read()?;
537 let response = Response::new(res.clone(), self.client.capability_flags)?;
538 Ok(response.error_code == 0)
539 } else {
540 Ok(false)
541 }
542 }
543 Err(e) => Err(format!("失败 {e}"))
544 };
545 }
546 pub fn rollback(&mut self) -> Result<bool, String> {
548 let mut sql_data = vec![Command::COM_QUERY];
549 let mut sql = format!("ROLLBACK");
550 if self.transaction.len() > 0 {
551 sql = format!("{} TO SAVEPOINT {}", sql, self.transaction.last().unwrap());
552 self.transaction.remove(self.transaction.len() - 1);
553 }
554 if self.config.debug {
555 let thread_id = thread::current().id();
556 info!("rollback: {:?} {}",thread_id,sql);
557 }
558 sql_data.extend(sql.as_bytes().to_vec());
559 self.packet.sequence_id = 0;
560 return match self.packet.pack_eof(sql_data) {
561 Ok(_) => {
562 if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
563 let res = self.packet.connection_stage_read()?;
564 let response = Response::new(res.clone(), self.client.capability_flags)?;
565 Ok(response.error_code == 0)
566 } else {
567 Ok(false)
568 }
569 }
570 Err(e) => Err(format!("失败 {e}"))
571 };
572 }
573 pub fn ping(&mut self) -> Result<bool, String> {
575 let sql = vec![Command::COM_PING];
576 self.packet.sequence_id = 0;
577 return match self.packet.pack_eof(sql) {
578 Ok(_) => {
579 let res = self.packet.connection_stage_read()?;
580 Response::new(res.clone(), self.client.capability_flags)?;
581 Ok(true)
582 }
583 Err(e) => Err(format!("ping失败 {e}"))
584 };
585 }
586}