use std::cmp::max;
use std::collections::HashMap;
use std::net::TcpStream;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use json::{JsonValue, object};
use log::info;
use crate::character_set::Charset;
use crate::client::Client;
use crate::{comm};
use crate::comm::{AuthenticationMethod, CapabilityFlags, Command, StatusFlags};
use crate::config::Config;
use crate::packet::Packet;
use crate::response::Response;
use crate::server::Server;
/// One MySQL client connection: the packet reader/writer over the TCP socket
/// plus the state exchanged during the connection-phase handshake.
///
/// NOTE(review): `Clone` copies `packet`, which `connect` builds from an
/// `Arc<TcpStream>`; clones presumably share one socket, so using them
/// concurrently would interleave protocol packets — confirm against `Packet`.
#[derive(Clone, Debug)]
pub struct Connect {
/// Details parsed from the server's initial handshake (version, connection id,
/// capability/status flags, charset, scramble, auth plugin).
pub server: Server,
/// Values this client sends in its handshake response (capability flags,
/// charset, authentication method).
pub client: Client,
/// Packet reader/writer wrapping the TCP stream; also carries `sequence_id`.
pub packet: Packet,
/// Connection settings (address, credentials, database, charset, debug logging).
pub config: Config,
/// Stack of savepoint names: `savepoint` pushes, `rollback` pops. While it is
/// non-empty, `transaction` opens a nested savepoint instead of `START TRANSACTION`.
pub transaction: Vec<String>,
}
impl Connect {
/// Opens a TCP connection to the server described by `config` and runs the
/// connection phase (initial handshake, then handshake response / authentication).
///
/// The socket gets fixed 10-second read and write timeouts. When `config.debug`
/// is set, each stage is logged with `info!`.
///
/// # Errors
/// Returns the error text if the TCP connect fails, the socket timeouts cannot be
/// set, the server handshake cannot be parsed, or the handshake response /
/// authentication fails.
pub fn connect(config: Config) -> Result<Self, String> {
    let stream = TcpStream::connect(config.clone().url()).map_err(|e| e.to_string())?;
    // Socket configuration can fail like any other I/O call (e.g. the peer reset
    // the connection immediately); report it instead of panicking.
    stream
        .set_read_timeout(Some(Duration::from_secs(10)))
        .map_err(|e| e.to_string())?;
    stream
        .set_write_timeout(Some(Duration::from_secs(10)))
        .map_err(|e| e.to_string())?;
    let mut mysql = Self {
        server: Server::default(),
        client: Client::default(),
        // Move the stream in directly: duplicating the handle with `try_clone`
        // only to drop the original gained nothing and could panic.
        packet: Packet::new(Arc::from(stream), config.clone()),
        config,
        transaction: vec![],
    };
    if mysql.config.debug {
        info!("连接阶段-开始");
        info!("接收-初始握手数据包");
    }
    mysql.receive_handshake_response()?;
    if mysql.config.debug {
        info!("服务器握手响应解析成功: {:?}", mysql.server.server_version);
        info!("服务器握手响应解析成功: {:?}", mysql.server.connection_id);
        info!("客户端响应");
    }
    mysql.handshake_response41()?;
    if mysql.config.debug {
        info!("连接阶段完成");
        info!("指挥阶段开始");
    }
    Ok(mysql)
}
/// Reads the server's first connection-phase packet and dispatches on its first byte.
///
/// `9` and `10` are handshake protocol versions and are handed to `handshake_v9` /
/// `handshake_v10` with that byte removed. A server that refuses the connection
/// sends an ERR packet instead, whose *first* byte is `0xFF`, followed by an
/// `int<2>` little-endian error code and a message.
///
/// # Errors
/// Returns `Err` for an empty packet, an ERR packet ("请求错误: code message"),
/// an unknown protocol version ("版本号错误: n"), or whatever the version-specific
/// parser returns.
fn receive_handshake_response(&mut self) -> Result<(), String> {
    let mut data = self.packet.connection_stage_read()?;
    if data.is_empty() {
        return Err("握手数据包为空".to_string());
    }
    let protocol_version = data.remove(0);
    match protocol_version {
        9 => self.handshake_v9(data),
        10 => self.handshake_v10(data),
        0xFF => {
            // ERR packet: the 0xFF marker occupies the protocol-version slot.
            let code = match data.get(0..2) {
                Some(bytes) => u16::from_le_bytes([bytes[0], bytes[1]]),
                None => 0,
            };
            // Lossy decoding: the payload comes from the network and may not be UTF-8.
            let message = String::from_utf8_lossy(data.get(2..).unwrap_or_default());
            Err(format!("请求错误: {} {}", code, message))
        }
        other => Err(format!("版本号错误: {}", other)),
    }
}
/// Records a protocol-version-9 handshake.
///
/// The payload is not parsed; every server field is reset to a neutral value and
/// only `protocol_version` (9) carries information. Always succeeds.
fn handshake_v9(&mut self, _response: Vec<u8>) -> Result<(), String> {
    let server = &mut self.server;
    server.protocol_version = 9;
    server.server_version = (0, 0, 0);
    server.connection_id = 0;
    server.capability_flags = 0;
    server.status_flags = StatusFlags::None;
    server.character_set = Charset::NONE;
    server.auth_plugin_data = Vec::new();
    server.authentication_method = AuthenticationMethod::None;
    Ok(())
}
/// Parses a protocol-version-10 initial handshake (the caller has already removed
/// the version byte) and stores the result in `self.server`.
///
/// Layout after the version byte (integers little-endian):
/// `string<NUL>` server version, `int<4>` connection id, `string[8]` scramble part 1,
/// `int<1>` filler, `int<2>` capability flags (lower half), `int<1>` charset,
/// `int<2>` status flags, `int<2>` capability flags (upper half), `int<1>` scramble
/// length (always present; 0 without CLIENT_PLUGIN_AUTH), `string[10]` reserved,
/// `string[max(13, len - 8)]` scramble part 2 whose last byte is a NUL terminator,
/// and, with CLIENT_PLUGIN_AUTH, `string<NUL>` auth plugin name.
///
/// # Errors
/// Returns `Err` when the payload is shorter than the fixed-size fields require
/// (previously this panicked).
fn handshake_v10(&mut self, response: Vec<u8>) -> Result<(), String> {
    /// Returns the next `n` bytes and advances `pos`, or errors if fewer remain.
    fn take<'a>(buf: &'a [u8], pos: &mut usize, n: usize) -> Result<&'a [u8], String> {
        let remaining = buf.len() - *pos;
        if n > remaining {
            return Err(format!(
                "握手数据包不完整: 需要 {} 字节, 剩余 {} 字节",
                n, remaining
            ));
        }
        let slice = &buf[*pos..*pos + n];
        *pos += n;
        Ok(slice)
    }

    /// Returns the bytes before the next NUL and advances past it; with no NUL,
    /// returns everything that is left.
    fn take_nul_terminated<'a>(buf: &'a [u8], pos: &mut usize) -> &'a [u8] {
        let rest = &buf[*pos..];
        match rest.iter().position(|&b| b == 0) {
            Some(end) => {
                *pos += end + 1;
                &rest[..end]
            }
            None => {
                *pos = buf.len();
                rest
            }
        }
    }

    /// Reads the leading digits of the first three dot-separated components so
    /// suffixes such as "8.0.33-0ubuntu0.22.04.2" or "5.7.44-log" do not panic;
    /// missing or non-numeric components become 0.
    fn parse_version(text: &str) -> (u16, u16, u16) {
        let mut parts = text.split('.').map(|part| {
            let digits: String = part.chars().take_while(|c| c.is_ascii_digit()).collect();
            digits.parse::<u16>().unwrap_or(0)
        });
        let major = parts.next().unwrap_or(0);
        let minor = parts.next().unwrap_or(0);
        let patch = parts.next().unwrap_or(0);
        (major, minor, patch)
    }

    self.server.protocol_version = 10;
    let buf = response.as_slice();
    let mut pos = 0usize;

    let version_text = String::from_utf8_lossy(take_nul_terminated(buf, &mut pos)).to_string();
    let server_version = parse_version(&version_text);

    let id = take(buf, &mut pos, 4)?;
    let connection_id = u32::from_le_bytes([id[0], id[1], id[2], id[3]]);

    let mut auth_plugin_data = take(buf, &mut pos, 8)?.to_vec();
    take(buf, &mut pos, 1)?; // filler

    let lower = take(buf, &mut pos, 2)?;
    let character_set = take(buf, &mut pos, 1)?[0];
    let status = take(buf, &mut pos, 2)?;
    let status_flags = comm::StatusFlags::from(u16::from_le_bytes([status[0], status[1]]));
    let upper = take(buf, &mut pos, 2)?;
    // Each half is a little-endian u16 and the upper half holds bits 16..32, so
    // the four bytes in wire order form a little-endian u32. (The previous
    // hex round-trip read them big-endian, producing byte-swapped flags.)
    let capability_flags = u32::from_le_bytes([lower[0], lower[1], upper[0], upper[1]]);
    let plugin_auth = (capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0;

    // The length byte is always on the wire; it is merely 0 without CLIENT_PLUGIN_AUTH.
    let scramble_len = take(buf, &mut pos, 1)?[0] as usize;
    take(buf, &mut pos, 10)?; // reserved
    // saturating_sub: a length below 8 (including 0) must not underflow.
    let part2_len = max(13, scramble_len.saturating_sub(8));
    let part2 = take(buf, &mut pos, part2_len)?;
    // Drop part 2's trailing NUL terminator (part2_len >= 13, so no underflow).
    auth_plugin_data.extend_from_slice(&part2[..part2_len - 1]);

    let authentication_method = if plugin_auth {
        String::from_utf8_lossy(take_nul_terminated(buf, &mut pos))
            .trim()
            .to_string()
    } else {
        String::new()
    };
    let authentication_method = AuthenticationMethod::from(authentication_method.as_str());

    self.server.server_version = server_version;
    self.server.connection_id = connection_id;
    self.server.auth_plugin_data = auth_plugin_data;
    self.server.character_set = character_set;
    self.server.authentication_method = authentication_method;
    self.server.status_flags = status_flags;
    self.server.capability_flags = capability_flags;
    Ok(())
}
/// Builds and sends the HandshakeResponse41 packet, then processes the server's reply.
///
/// Fields written, in order: `int<4>` client capability flags, `int<4>` max packet
/// size (0x00FF_FFFF), `int<1>` charset, 23 zero bytes, `string<NUL>` user name,
/// the length-prefixed auth response, then (flag-dependent) `string<NUL>` database,
/// `string<NUL>` auth plugin name and the length-encoded connection attributes.
///
/// # Errors
/// Returns `Err` if the packet cannot be sent or read, or if the server's reply
/// reports a failure. Failures on the non-plugin path used to be silently ignored,
/// letting `connect` "succeed" on a rejected login.
fn handshake_response41(&mut self) -> Result<(), String> {
    /// Appends `value` as a MySQL length-encoded integer.
    fn push_lenenc(buf: &mut Vec<u8>, value: usize) {
        if value < 251 {
            buf.push(value as u8);
        } else if value < 1 << 16 {
            buf.push(0xFC);
            buf.extend_from_slice(&(value as u16).to_le_bytes());
        } else if value < 1 << 24 {
            buf.push(0xFD);
            buf.extend_from_slice(&(value as u32).to_le_bytes()[..3]);
        } else {
            buf.push(0xFE);
            buf.extend_from_slice(&(value as u64).to_le_bytes());
        }
    }

    let mut attr: HashMap<&str, &str> = HashMap::new();
    attr.insert("carry xd", "rust blur");
    self.client.capability_flags =
        CapabilityFlags::get_capabilities(&*self.config.database, attr.clone());

    let mut buf: Vec<u8> = Vec::new();
    buf.extend(self.client.capability_flags.to_le_bytes());
    // Max packet size 16 MiB - 1, little-endian.
    buf.extend(16_777_215_u32.to_le_bytes());
    self.client.character_set = Charset::form_u8(&*self.config.charset);
    buf.push(self.client.character_set);
    buf.extend([0u8; 23]); // reserved filler
    buf.extend(self.config.username.as_bytes());
    buf.push(0);

    // Every capability/server-version combination previously selected this same
    // plugin, so the branching was removed.
    self.client.authentication_method = AuthenticationMethod::MysqlNativePassword;

    let scramble = String::from_utf8_lossy(&self.server.auth_plugin_data).to_string();
    let auth_response = self.authentication(scramble.as_str());
    if (self.client.capability_flags & CapabilityFlags::ClientPluginAuthLenencClientData.info()) > 0 {
        // This flag changes only how the length is encoded (lenenc int). It does not
        // mean "no data": the old bare `0` sent an empty password hash.
        push_lenenc(&mut buf, auth_response.len());
    } else {
        buf.push(auth_response.len() as u8);
    }
    buf.extend(auth_response);

    if (self.client.capability_flags & CapabilityFlags::ClientConnectWithDb.info()) > 0 {
        buf.extend(self.config.database.as_bytes());
        buf.push(0);
    }
    if (self.client.capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0 {
        buf.extend(self.client.authentication_method.clone().into().as_bytes());
        buf.push(0);
    }
    if (self.client.capability_flags & CapabilityFlags::ClientConnectAttrs.info()) > 0 {
        let mut attrs = Vec::new();
        for (key, value) in attr.iter() {
            push_lenenc(&mut attrs, key.len());
            attrs.extend_from_slice(key.as_bytes());
            push_lenenc(&mut attrs, value.len());
            attrs.extend_from_slice(value.as_bytes());
        }
        push_lenenc(&mut buf, attrs.len());
        buf.extend(attrs);
    }
    // NOTE(review): if CLIENT_ZSTD_COMPRESSION_ALGORITHM is ever negotiated, the
    // protocol expects an extra int<1> zstd level here; it is not written.

    self.packet.pack(buf)?;
    let res_data = self.packet.connection_stage_read()?;
    if (self.server.capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0
        && (self.client.capability_flags & CapabilityFlags::ClientPluginAuth.info()) > 0
    {
        if self.config.debug {
            info!("回执 auth_switch_request");
        }
        Response::new(res_data, self.server.capability_flags)?;
    } else {
        if self.config.debug {
            info!("原生身份验证");
        }
        if res_data.first() == Some(&0xFE) {
            // AuthSwitchRequest: let the dedicated handler answer it and propagate
            // its outcome instead of discarding it.
            self.authentication_native41(res_data)?;
        } else {
            // Final OK/ERR packet.
            Response::new(res_data, self.client.capability_flags)?;
        }
    }
    Ok(())
}
/// Inspects a connection-phase reply that is either an AuthSwitchRequest (`0xFE`)
/// or an OK packet (`0x00`).
///
/// For `0xFE`, the NUL-terminated plugin name is stored in
/// `self.server.authentication_method`, and up to 20 bytes of following plugin data
/// are treated as the scramble.
///
/// NOTE(review): this handler is incomplete. For mysql_native_password the scrambled
/// response is computed but never sent back, and nothing in this file calls the
/// function. It used to print that response to stdout, which leaked password-derived
/// data, so the print was removed.
///
/// # Errors
/// Returns `Err` for an empty packet or any first byte other than `0xFE` / `0x00`.
fn auth_switch_request(&mut self, data: Vec<u8>) -> Result<(), String> {
    let (status, body) = match data.split_first() {
        Some((&status, body)) => (status, body),
        None => return Err("未知错误: 空数据包".to_string()),
    };
    match status {
        0xFE => {
            // Tolerate a missing NUL terminator instead of panicking.
            let name_end = body.iter().position(|&b| b == 0).unwrap_or(body.len());
            let plugin_name = String::from_utf8_lossy(&body[..name_end]).to_string();
            self.server.authentication_method = AuthenticationMethod::from(plugin_name.as_str());
            let rest = body.get(name_end + 1..).unwrap_or_default();
            let pass_data = &rest[..rest.len().min(20)];
            match self.server.authentication_method {
                AuthenticationMethod::MysqlNativePassword => {
                    let scramble = String::from_utf8_lossy(pass_data).to_string();
                    let _auth_response = AuthenticationMethod::mysql_native_password(
                        scramble.as_str().as_ref(),
                        self.config.userpass.clone().as_ref(),
                    );
                }
                AuthenticationMethod::MysqlOldPassword
                | AuthenticationMethod::MysqlClearPassword
                | AuthenticationMethod::CachedSha2Password
                | AuthenticationMethod::None => {}
            }
            Ok(())
        }
        0x00 => Ok(()),
        other => Err(format!("未知错误: {:02X}", other)),
    }
}
/// Handles the server's reply to the handshake response on the non-plugin path.
///
/// * `0xFE` (AuthSwitchRequest): reads the NUL-terminated plugin name into
///   `self.client.authentication_method` and the following NUL-terminated plugin
///   data (scramble). It then computes the response with `authentication`, sends it,
///   and validates the server's answer with `Response::new`.
/// * `0x00` (OK): authentication already succeeded.
/// * `0xFF` (ERR): the server rejected the login.
///
/// Returns `Ok(String::new())` on success.
///
/// # Errors
/// Returns `Err` on an empty reply, an ERR packet, an unknown status byte, or any
/// send/read/parse failure. Malformed packets used to panic, and OK replies were
/// reported as an empty error.
fn authentication_native41(&mut self, data: Vec<u8>) -> Result<String, String> {
    match data.first().copied() {
        Some(0xFE) => {
            let body = &data[1..];
            let name_end = body.iter().position(|&b| b == 0).unwrap_or(body.len());
            let plugin_name = String::from_utf8_lossy(&body[..name_end]).to_string();
            self.client.authentication_method = AuthenticationMethod::from(plugin_name.as_str());
            let rest = body.get(name_end + 1..).unwrap_or_default();
            let scramble_end = rest.iter().position(|&b| b == 0).unwrap_or(rest.len());
            let scramble = String::from_utf8_lossy(&rest[..scramble_end]).to_string();
            let auth_response = self.authentication(&scramble);
            self.packet.pack(auth_response)?;
            let reply = self.packet.connection_stage_read()?;
            if self.config.debug {
                info!("验证密钥");
            }
            Response::new(reply, self.client.capability_flags)?;
            Ok(String::new())
        }
        Some(0x00) => Ok(String::new()),
        Some(0xFF) => {
            // Let the response parser surface the server's error message; if it
            // does not treat the packet as an error, still fail the login.
            Response::new(data, self.client.capability_flags)?;
            Err("身份验证失败".to_string())
        }
        Some(other) => Err(format!("未知的身份验证回执: {:02X}", other)),
        None => Err("身份验证回执为空".to_string()),
    }
}
/// Computes the client's authentication response for the scramble
/// `auth_plugin_data`. It uses the password in `self.config.userpass` and the
/// method currently selected in `self.client.authentication_method`.
///
/// Only mysql_native_password and caching_sha2_password produce bytes; every
/// other method yields an empty response.
///
/// # Panics
/// Panics if the selected hashing helper fails, because its result is `unwrap`ped.
fn authentication(&mut self, auth_plugin_data: &str) -> Vec<u8> {
    let password = &self.config.userpass;
    match self.client.authentication_method {
        AuthenticationMethod::MysqlNativePassword => {
            let scrambled = AuthenticationMethod::mysql_native_password(
                auth_plugin_data.as_ref(),
                password.clone().as_ref(),
            );
            scrambled.unwrap().to_vec()
        }
        AuthenticationMethod::CachedSha2Password => {
            let scrambled = AuthenticationMethod::cached_sha2_password(
                auth_plugin_data.as_ref(),
                password.clone().as_ref(),
            );
            scrambled.unwrap().to_vec()
        }
        AuthenticationMethod::MysqlOldPassword
        | AuthenticationMethod::MysqlClearPassword
        | AuthenticationMethod::None => Vec::new(),
    }
}
/// Sends `COM_QUIT` as a fresh command (sequence id reset to 0).
///
/// Returns `Ok(true)` when `pack_eof` reports a value greater than zero. No reply
/// is read.
///
/// # Errors
/// `Err("失败: …")` when the packet cannot be sent.
pub fn close(&mut self) -> Result<bool, String> {
    self.packet.sequence_id = 0;
    self.packet
        .pack_eof(vec![Command::COM_QUIT])
        .map(|written| written > 0)
        .map_err(|e| format!("失败: {e}"))
}
/// Sends `COM_STATISTICS` and reports whether the parsed reply has `error_code == 0`.
///
/// NOTE(review): the protocol answers COM_STATISTICS with a plain status string
/// rather than an OK packet; confirm `Response::new` handles that shape.
///
/// # Errors
/// `Err("失败: …")` if the command cannot be sent; read and parse errors propagate.
pub fn status(&mut self) -> Result<bool, String> {
    self.packet.sequence_id = 0;
    if let Err(e) = self.packet.pack_eof(vec![Command::COM_STATISTICS]) {
        return Err(format!("失败: {e}"));
    }
    let raw = self.packet.connection_stage_read()?;
    let reply = Response::new(raw, self.client.capability_flags)?;
    Ok(reply.error_code == 0)
}
/// Sends `COM_DEBUG` and reports whether the parsed reply has `error_code == 0`.
///
/// # Errors
/// `Err("失败: …")` if the command cannot be sent; read and parse errors propagate.
pub fn debug(&mut self) -> Result<bool, String> {
    let command = vec![Command::COM_DEBUG];
    self.packet.sequence_id = 0;
    match self.packet.pack_eof(command) {
        Err(e) => Err(format!("失败: {e}")),
        Ok(_) => {
            let raw = self.packet.connection_stage_read()?;
            let reply = Response::new(raw, self.client.capability_flags)?;
            Ok(reply.error_code == 0)
        }
    }
}
/// Switches the default schema with `COM_INIT_DB` followed by the raw bytes of `name`.
///
/// Returns whether the parsed reply has `error_code == 0`.
///
/// # Errors
/// `Err("失败: …")` if the command cannot be sent; read and parse errors propagate.
pub fn db(&mut self, name: &str) -> Result<bool, String> {
    let payload: Vec<u8> = std::iter::once(Command::COM_INIT_DB)
        .chain(name.bytes())
        .collect();
    self.packet.sequence_id = 0;
    if let Err(e) = self.packet.pack_eof(payload) {
        return Err(format!("失败: {e}"));
    }
    let raw = self.packet.connection_stage_read()?;
    let reply = Response::new(raw, self.client.capability_flags)?;
    Ok(reply.error_code == 0)
}
/// Runs `sql` as a text-protocol `COM_QUERY` and converts the result with
/// `Response::column_definition41`.
///
/// If CLIENT_PROTOCOL_41 was not negotiated, an empty JSON object is returned and
/// the server's reply is not read.
/// NOTE(review): that unread reply may desynchronise later commands; confirm it
/// is unreachable in practice.
///
/// # Errors
/// `Err("失败 …")` if the command cannot be sent; read and parse errors propagate.
pub fn query(&mut self, sql: &str) -> Result<JsonValue, String> {
    if self.config.debug {
        info!("query: {}", sql);
    }
    let mut payload = Vec::with_capacity(1 + sql.len());
    payload.push(Command::COM_QUERY);
    payload.extend_from_slice(sql.as_bytes());
    self.packet.sequence_id = 0;
    if let Err(e) = self.packet.pack_eof(payload) {
        return Err(format!("失败 {e}"));
    }
    let protocol41 = (CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags) > 0;
    if !protocol41 {
        return Ok(object! {});
    }
    let rows = self.packet.command_stage_read()?;
    if self.config.debug {
        info!("返回数据: {:?}", rows);
    }
    let json = Response::column_definition41(rows, self.config.clone(), self.client.capability_flags)?;
    Ok(json)
}
/// Prepares `sql` with `COM_STMT_PREPARE`, then runs it once with `COM_STMT_EXECUTE`.
///
/// The prepare reply is parsed by `Response::execute` to obtain `statement_id`.
/// The execute packet is the statement id, a flags byte of 0, and an iteration
/// count of 1 (4-byte little-endian). It binds no parameters.
///
/// NOTE(review): the prepared statement is never closed (no COM_STMT_CLOSE), and
/// only one packet is read after each command. Confirm `Packet`/`Response` consume
/// any parameter/column definition packets the server sends.
///
/// # Errors
/// `Err("失败 …")` if either command cannot be sent; read and parse errors propagate.
pub fn execute(&mut self, sql: &str) -> Result<Response, String> {
    if self.config.debug {
        info!("execute: {}", sql);
    }

    let mut prepare = vec![Command::COM_STMT_PREPARE];
    prepare.extend_from_slice(sql.as_bytes());
    self.packet.sequence_id = 0;
    if let Err(e) = self.packet.pack_eof(prepare) {
        return Err(format!("失败 {e}"));
    }
    let prepared_raw = self.packet.connection_stage_read()?;
    if self.config.debug {
        info!("返回: {:?}", prepared_raw);
    }
    let prepared = Response::execute(prepared_raw, self.client.capability_flags)?;

    let mut run = vec![Command::COM_STMT_EXECUTE];
    run.extend(prepared.statement_id.to_le_bytes());
    run.push(0);
    run.extend(1_i32.to_le_bytes());
    self.packet.sequence_id = 0;
    if let Err(e) = self.packet.pack_eof(run) {
        return Err(format!("失败 {e}"));
    }
    let result_raw = self.packet.connection_stage_read()?;
    if self.config.debug {
        info!("返回: {:?}", result_raw);
    }
    let outcome = Response::new(result_raw, self.client.capability_flags)?;
    Ok(outcome)
}
/// Begins a transaction level.
///
/// If a savepoint is already recorded in `self.transaction`, it only creates another
/// savepoint (a nested level). Otherwise it sends `START TRANSACTION` and, when the
/// server reports success, immediately creates a savepoint for the outermost level.
///
/// Returns `Ok(false)` when the server reports an error code, or when
/// CLIENT_PROTOCOL_41 was not negotiated (then no reply is read).
///
/// # Errors
/// `Err("失败 …")` if the command cannot be sent; read and parse errors propagate.
pub fn transaction(&mut self) -> Result<bool, String> {
    if !self.transaction.is_empty() {
        return self.savepoint();
    }
    let sql = String::from("START TRANSACTION");
    if self.config.debug {
        info!("transaction: {}", sql);
    }
    let mut payload = vec![Command::COM_QUERY];
    payload.extend_from_slice(sql.as_bytes());
    self.packet.sequence_id = 0;
    if let Err(e) = self.packet.pack_eof(payload) {
        return Err(format!("失败 {e}"));
    }
    let protocol41 = (CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags) > 0;
    if !protocol41 {
        return Ok(false);
    }
    let raw = self.packet.connection_stage_read()?;
    if self.config.debug {
        info!("回执: {:?}", raw);
    }
    let reply = Response::new(raw, self.client.capability_flags)?;
    if reply.error_code != 0 {
        return Ok(false);
    }
    self.savepoint()
}
/// Sends `SET TRANSACTION ISOLATION LEVEL <level>`.
///
/// `name` selects the level: 1 = SERIALIZABLE, 2 = READ COMMITTED,
/// 3 = READ UNCOMMITTED, and 0 or any other value = REPEATABLE READ.
///
/// NOTE(review): per the MySQL docs, this statement without GLOBAL/SESSION affects
/// only the next transaction; confirm that is the intended scope.
///
/// Returns whether the parsed reply has `error_code == 0`, or `Ok(false)` when
/// CLIENT_PROTOCOL_41 was not negotiated (then no reply is read).
///
/// # Errors
/// `Err("失败 …")` if the command cannot be sent; read and parse errors propagate.
pub fn set_transaction_level(&mut self, name: u8) -> Result<bool, String> {
    let level = match name {
        1 => "SERIALIZABLE",
        2 => "READ COMMITTED",
        3 => "READ UNCOMMITTED",
        _ => "REPEATABLE READ",
    };
    let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level);
    if self.config.debug {
        info!("set_transaction_level: {}", sql);
    }
    let mut payload = vec![Command::COM_QUERY];
    payload.extend_from_slice(sql.as_bytes());
    self.packet.sequence_id = 0;
    if let Err(e) = self.packet.pack_eof(payload) {
        return Err(format!("失败 {e}"));
    }
    let protocol41 = (CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags) > 0;
    if !protocol41 {
        return Ok(false);
    }
    let raw = self.packet.connection_stage_read()?;
    if self.config.debug {
        info!("回执: {:?}", raw);
    }
    let reply = Response::new(raw, self.client.capability_flags)?;
    Ok(reply.error_code == 0)
}
/// Creates a savepoint named from the current time and records it on the
/// `self.transaction` stack.
///
/// The name is pushed only after the server confirms success (`error_code == 0`).
/// Previously it was pushed before sending, so a failed SAVEPOINT left a name on the
/// stack that a later `rollback` would reference. The name carries a letter prefix
/// because an unquoted MySQL identifier may not consist solely of digits, and a bare
/// hex timestamp can.
///
/// Returns whether the server confirmed the savepoint, or `Ok(false)` when
/// CLIENT_PROTOCOL_41 was not negotiated (then no reply is read and nothing is pushed).
///
/// # Errors
/// Returns `Err` if the system clock is before the Unix epoch (previously a panic),
/// `Err("失败 …")` if the command cannot be sent; read and parse errors propagate.
fn savepoint(&mut self) -> Result<bool, String> {
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|e| e.to_string())?
        .as_nanos();
    let name = format!("sp_{:x}", timestamp);
    let sql = format!("SAVEPOINT {}", name);
    if self.config.debug {
        info!("savepoint: {}", sql);
    }
    let mut sql_data = vec![Command::COM_QUERY];
    sql_data.extend_from_slice(sql.as_bytes());
    self.packet.sequence_id = 0;
    if let Err(e) = self.packet.pack_eof(sql_data) {
        return Err(format!("失败 {e}"));
    }
    if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
        let res = self.packet.connection_stage_read()?;
        if self.config.debug {
            info!("回执: {:?}", res);
        }
        let response = Response::new(res, self.client.capability_flags)?;
        let created = response.error_code == 0;
        if created {
            self.transaction.push(name);
        }
        Ok(created)
    } else {
        Ok(false)
    }
}
/// Sends `COMMIT`, committing the whole transaction regardless of how many
/// savepoint levels are on `self.transaction`.
///
/// On success the local savepoint stack is cleared. COMMIT ends the transaction and
/// discards its savepoints on the server, and a stale stack would make the next
/// `transaction()` skip `START TRANSACTION` and leave later work in autocommit mode.
///
/// Returns whether the parsed reply has `error_code == 0`, or `Ok(false)` when
/// CLIENT_PROTOCOL_41 was not negotiated (then no reply is read).
///
/// # Errors
/// `Err("失败 …")` if the command cannot be sent; read and parse errors propagate.
pub fn commit(&mut self) -> Result<bool, String> {
    let sql = "COMMIT".to_string();
    if self.config.debug {
        info!("COMMIT: {}", sql);
    }
    let mut sql_data = vec![Command::COM_QUERY];
    sql_data.extend_from_slice(sql.as_bytes());
    self.packet.sequence_id = 0;
    if let Err(e) = self.packet.pack_eof(sql_data) {
        return Err(format!("失败 {e}"));
    }
    if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
        let res = self.packet.connection_stage_read()?;
        if self.config.debug {
            info!("回执: {:?}", res);
        }
        let response = Response::new(res, self.client.capability_flags)?;
        let committed = response.error_code == 0;
        if committed {
            self.transaction.clear();
        }
        Ok(committed)
    } else {
        Ok(false)
    }
}
/// Undoes the innermost transaction level.
///
/// With more than one recorded savepoint, it pops the newest one and sends
/// `ROLLBACK TO SAVEPOINT <name>`. With one recorded savepoint (the level created by
/// `transaction()` right after `START TRANSACTION`) or none, it sends a full
/// `ROLLBACK` and clears the stack. Previously the outermost level was only rolled
/// back to its savepoint, which left the server-side transaction open, so later
/// "autocommit" statements were silently never committed.
///
/// Returns whether the parsed reply has `error_code == 0`, or `Ok(false)` when
/// CLIENT_PROTOCOL_41 was not negotiated (then no reply is read).
///
/// # Errors
/// `Err("失败 …")` if the command cannot be sent; read and parse errors propagate.
pub fn rollback(&mut self) -> Result<bool, String> {
    let sql = if self.transaction.len() > 1 {
        let name = self
            .transaction
            .pop()
            .expect("stack has more than one entry, checked above");
        format!("ROLLBACK TO SAVEPOINT {}", name)
    } else {
        self.transaction.clear();
        "ROLLBACK".to_string()
    };
    if self.config.debug {
        info!("rollback: {}", sql);
    }
    let mut sql_data = vec![Command::COM_QUERY];
    sql_data.extend_from_slice(sql.as_bytes());
    self.packet.sequence_id = 0;
    if let Err(e) = self.packet.pack_eof(sql_data) {
        return Err(format!("失败 {e}"));
    }
    if CapabilityFlags::CLIENT_PROTOCOL_41 & self.client.capability_flags > 0 {
        let res = self.packet.connection_stage_read()?;
        if self.config.debug {
            info!("回执: {:?}", res);
        }
        let response = Response::new(res, self.client.capability_flags)?;
        Ok(response.error_code == 0)
    } else {
        Ok(false)
    }
}
/// Sends `COM_PING` and parses the server's reply.
///
/// Returns `Ok(true)` once the reply parses successfully; `error_code` is not
/// inspected.
///
/// # Errors
/// `Err("ping失败 …")` if the command cannot be sent; read and parse errors propagate.
pub fn ping(&mut self) -> Result<bool, String> {
    self.packet.sequence_id = 0;
    let sent = self.packet.pack_eof(vec![Command::COM_PING]);
    match sent {
        Err(e) => Err(format!("ping失败 {e}")),
        Ok(_) => {
            let raw = self.packet.connection_stage_read()?;
            Response::new(raw, self.client.capability_flags)?;
            Ok(true)
        }
    }
}
}