1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use log::warn;
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::sync::Arc;
7use std::time::Duration;
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11 pub(crate) stream: Arc<TcpStream>,
13 packet: Packet,
14 auth_status: AuthStatus,
16}
17
18impl Connect {
19 pub fn is_valid(&mut self) -> bool {
21 if self.stream.peer_addr().is_err() {
23 return false;
24 }
25 self.query("SELECT 1").is_ok()
27 }
28
29 pub fn is_quick_valid(&self) -> bool {
32 self.stream.peer_addr().is_ok()
34 }
35
36 pub fn _close(&mut self) {
38 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
39 let _ = self.stream.shutdown(std::net::Shutdown::Both);
40 }
41
42 pub fn new(config: Config) -> Result<Connect, String> {
43 let stream = match TcpStream::connect(config.url()) {
44 Ok(stream) => stream,
45 Err(e) => return Err(format!("数据库连接失败: {}", e)),
46 };
47 stream
48 .set_read_timeout(Some(Duration::from_secs(5)))
49 .map_err(|e| format!("设置读取超时失败: {}", e))?;
50 stream
51 .set_write_timeout(Some(Duration::from_secs(5)))
52 .map_err(|e| format!("设置写入超时失败: {}", e))?;
53
54 match stream.peer_addr() {
55 Ok(_) => {}
56 Err(e) => warn!("无法获取对端地址: {}", e),
57 }
58
59 let mut connect = Self {
60 stream: Arc::new(stream),
61 packet: Packet::new(config),
62 auth_status: AuthStatus::None,
63 };
64
65 connect.startup_message()?;
66 connect.sasl_initial_response_message()?;
67
68 Ok(connect)
69 }
70
71 fn read(&mut self) -> Result<Vec<u8>, String> {
72 const BUFFER_SIZE: usize = 4096; const SLEEP_DURATION: Duration = Duration::from_millis(5);
74
75 let mut msg = Vec::new();
76 let mut buf = vec![0u8; BUFFER_SIZE];
77
78 loop {
79 match self.stream.as_ref().read(&mut buf) {
80 Ok(0) => return Err("连接已关闭或服务端断开".into()),
81 Ok(n) => {
82 msg.extend_from_slice(&buf[..n]);
83 if n < BUFFER_SIZE {
85 if msg.len() >= 5 {
87 let len_bytes = &msg[1..=4];
88 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
89 if msg.len() >= len as usize {
90 break;
91 }
92 }
93 }
94 }
95 }
96 Err(ref e)
97 if e.kind() == std::io::ErrorKind::WouldBlock
98 || e.kind() == std::io::ErrorKind::TimedOut =>
99 {
100 std::thread::sleep(SLEEP_DURATION);
102 continue;
103 }
104 Err(e) => return Err(format!("读取失败: {}", e)),
105 };
106
107 if let AuthStatus::AuthenticationOk = self.auth_status {
108 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
109 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
110 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
111 {
112 break;
113 }
114 } else if msg.len() >= 5 {
115 let len_bytes = &msg[1..=4];
117 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
118 if msg.len() >= len as usize {
119 break;
120 }
121 }
122 }
123 }
124
125 Ok(msg)
126 }
127 fn startup_message(&mut self) -> Result<(), String> {
129 self.stream
131 .as_ref()
132 .write_all(&self.packet.pack_first())
133 .map_err(|e| format!("发送 startup message 失败: {}", e))?;
134 let data = self.read()?;
135 self.packet.unpack(data, 0)?;
136 Ok(())
137 }
138 fn sasl_initial_response_message(&mut self) -> Result<(), String> {
140 self.stream
141 .as_ref()
142 .write_all(&self.packet.pack_auth())
143 .map_err(|e| format!("发送 SASL Initial Response 失败: {}", e))?;
144
145 let data = self.read()?;
146 self.packet.unpack(data, 0)?;
147
148 self.stream
149 .as_ref()
150 .write_all(&self.packet.pack_auth_verify())
151 .map_err(|e| format!("发送 SASL Verify 失败: {}", e))?;
152
153 let data = self.read()?;
154 self.packet.unpack(data, 0)?;
155 self.auth_status = AuthStatus::AuthenticationOk;
156 Ok(())
157 }
158 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
160 self.stream
166 .as_ref()
167 .write_all(&self.packet.pack_query(sql))
168 .map_err(|e| format!("query error: {}", e))?;
169
170 let data = self
171 .read()
172 .map_err(|e| format!("query read error: {}", e))?;
173
174 self.packet.unpack(data, 0)
175
176 }
183 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
185 self.stream
194 .as_ref()
195 .write_all(&self.packet.pack_execute(sql))
196 .map_err(|e| format!("execute error: {}", e))?;
197 let data = self.read()?;
198 self.packet.unpack(data, 0)
199 }
200}
201impl Drop for Connect {
202 fn drop(&mut self) {
203 match self.stream.peer_addr() {
206 Ok(_addr) => {
207 },
209 Err(e) => warn!("[DROP] 获取 peer_addr 失败: {e}"),
210 }
211
212 if let Err(e) = self.stream.as_ref().write_all(&Packet::pack_terminate()) {
214 warn!("[DROP] 发送 Terminate 包失败: {e}");
215 }
216
217 if let Err(e) = self.stream.shutdown(std::net::Shutdown::Both) {
218 warn!("[DROP] shutdown 失败: {e}");
219 }
220 }
221
222 }