1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use std::io::{Read, Write};
4use std::net::TcpStream;
5use std::sync::Arc;
6use std::time::Duration;
7use log::{info, warn};
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11 pub(crate) stream: Arc<TcpStream>,
13 packet: Packet,
14 auth_status: AuthStatus,
16}
17
18impl Connect {
19
20 pub fn is_valid(&mut self) -> bool {
22 self.query("SELECT 1").is_ok()
24 }
25
26 pub fn _close(&mut self) {
28 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
29 let _ = self.stream.shutdown(std::net::Shutdown::Both);
30 }
31
32 pub fn new(mut config: Config) -> Result<Connect, String> {
33 let stream = match TcpStream::connect(config.url()) {
34 Ok(stream) => stream,
35 Err(e) => return Err(format!("数据库连接失败: {}", e)),
36 };
37 stream.set_read_timeout(Some(Duration::from_secs(5)))
38 .map_err(|e| format!("设置读取超时失败: {}", e))?;
39 stream.set_write_timeout(Some(Duration::from_secs(5)))
40 .map_err(|e| format!("设置写入超时失败: {}", e))?;
41
42 match stream.peer_addr() {
43 Ok(addr) => info!("[NEW] 创建连接: {}", addr),
44 Err(e) => warn!("无法获取对端地址: {}", e),
45 }
46
47 let mut connect = Self {
48 stream: Arc::new(stream),
49 packet: Packet::new(config),
50 auth_status: AuthStatus::None,
51 };
52
53 connect.startup_message()?;
54 connect.sasl_initial_response_message()?;
55
56 Ok(connect)
57 }
58
59 fn read(&mut self) -> Result<Vec<u8>, String> {
60 let mut msg = Vec::new();
61 let mut buf = [0u8; 1024];
62
63 loop {
64 match self.stream.as_ref().read(&mut buf) {
65 Ok(0) => return Err("连接已关闭或服务端断开".into()),
66 Ok(n) => msg.extend_from_slice(&buf[..n]),
67 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock
68 || e.kind() == std::io::ErrorKind::TimedOut => {
69 std::thread::sleep(Duration::from_millis(5));
71 continue;
72 }
73 Err(e) => return Err(format!("读取失败: {}", e)),
74 };
75
76 if let AuthStatus::AuthenticationOk = self.auth_status {
77 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
78 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
79 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
80 {
81 break;
82 }
83 } else if msg.len() >= 5 {
84 let len_bytes = &msg[1..=4];
86 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
87 if msg.len() >= len as usize { break; }
88 }
89 }
90 }
91
92 Ok(msg)
93 }
94 fn startup_message(&mut self) -> Result<(), String> {
96 self.stream.as_ref().write_all(&self.packet.pack_first())
98 .map_err(|e| format!("发送 startup message 失败: {}", e))?;
99 let data = self.read()?;
100 self.packet.unpack(data, 0)?;
101 Ok(())
102 }
103 fn sasl_initial_response_message(&mut self) -> Result<(), String> {
105 self.stream.as_ref().write_all(&self.packet.pack_auth())
106 .map_err(|e| format!("发送 SASL Initial Response 失败: {}", e))?;
107
108 let data = self.read()?;
109 self.packet.unpack(data, 0)?;
110
111 self.stream.as_ref().write_all(&self.packet.pack_auth_verify())
112 .map_err(|e| format!("发送 SASL Verify 失败: {}", e))?;
113
114 let data = self.read()?;
115 self.packet.unpack(data, 0)?;
116 self.auth_status = AuthStatus::AuthenticationOk;
117 Ok(())
118 }
119 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
121 self.stream.as_ref()
127 .write_all(&self.packet.pack_query(sql))
128 .map_err(|e| format!("query error: {}", e))?;
129
130 let data = self.read()
131 .map_err(|e| format!("query read error: {}", e))?;
132
133 self.packet.unpack(data, 0)
134
135 }
142 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
144 self.stream.as_ref()
153 .write_all(&self.packet.pack_execute(sql))
154 .map_err(|e| format!("execute error: {}", e))?;
155 let data = self.read()?;
156 self.packet.unpack(data, 0)
157 }
158}
159impl Drop for Connect {
160
161 fn drop(&mut self) {
162 match self.stream.peer_addr() {
165 Ok(addr) => info!("[DROP] 关闭连接: {}", addr),
166 Err(e) => warn!("[DROP] 获取 peer_addr 失败: {e}"),
167 }
168
169 if let Err(e) = self.stream.as_ref().write_all(&Packet::pack_terminate()) {
171 warn!("[DROP] 发送 Terminate 包失败: {e}");
172 }
173
174 if let Err(e) = self.stream.shutdown(std::net::Shutdown::Both) {
175 warn!("[DROP] shutdown 失败: {e}");
176 }
177 }
178
179 }