1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use std::io::{Read, Write};
4use std::net::TcpStream;
5use std::sync::Arc;
6use std::time::Duration;
7use log::info;
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11 pub(crate) stream: Arc<TcpStream>,
13 packet: Packet,
14 auth_status: AuthStatus,
16}
17
18impl Connect {
19
20 pub fn is_valid(&mut self) -> bool {
22 self.query("SELECT 1").is_ok()
24 }
25
26 pub fn _close(&mut self) {
28 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
29 let _ = self.stream.shutdown(std::net::Shutdown::Both);
30 }
31
32 pub fn new(mut config: Config) -> Result<Connect, String> {
33 let stream = match TcpStream::connect(config.url()) {
34 Ok(stream) => stream,
35 Err(e) => return Err(format!("数据库连接失败: {}", e)),
36 };
37 stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
38 stream.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
39
40 info!("[NEW] 创建连接: {}", stream.peer_addr().unwrap()); let mut connect = Self {
43 stream: Arc::new(stream),
44 packet: Packet::new(config),
45 auth_status: AuthStatus::None,
46 };
47
48 connect.startup_message()?;
49 connect.sasl_initial_response_message()?;
50
51 Ok(connect)
52 }
53
54 fn read(&mut self) -> Result<Vec<u8>, String> {
55 let mut msg = Vec::new();
56 let mut buf = [0u8; 1024];
57
58 loop {
59 let n = match self.stream.as_ref().read(&mut buf) {
60 Ok(0) => return Err("连接已关闭或服务端断开".into()),
61 Ok(n) => msg.extend_from_slice(&buf[..n]),
62 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock
63 || e.kind() == std::io::ErrorKind::TimedOut => {
64 std::thread::sleep(std::time::Duration::from_millis(5));
66 continue;
67 }
68 Err(e) => return Err(format!("读取失败: {}", e)),
69 };
70
71 if let AuthStatus::AuthenticationOk = self.auth_status {
72 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
73 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
74 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
75 {
76 break;
77 }
78 } else if msg.len() >= 5 {
79 let len_bytes = &msg[1..=4];
81 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
82 if msg.len() >= len as usize { break; }
83 }
84 }
85 }
86
87 Ok(msg)
88 }
89 fn startup_message(&mut self) -> Result<(), String> {
91 self.stream.try_clone().unwrap().write_all(&self.packet.pack_first()).unwrap();
92 let data = self.read()?;
93 self.packet.unpack(data, 0)?;
94 Ok(())
95 }
96 fn sasl_initial_response_message(&mut self) -> Result<(), String> {
98 self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth()).unwrap();
99 let data = self.read()?;
100 self.packet.unpack(data, 0)?;
101 self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth_verify()).unwrap();
102 let data = self.read()?;
103 self.packet.unpack(data, 0)?;
104 self.auth_status = AuthStatus::AuthenticationOk;
105 Ok(())
106 }
107 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
109 self.stream.as_ref()
115 .write_all(&self.packet.pack_query(sql))
116 .map_err(|e| format!("query error: {}", e))?;
117
118 let data = self.read()
119 .map_err(|e| format!("query read error: {}", e))?;
120
121 self.packet.unpack(data, 0)
122
123 }
127 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
129 self.stream.as_ref()
138 .write_all(&self.packet.pack_execute(sql))
139 .map_err(|e| format!("execute error: {}", e))?;
140 let data = self.read()?;
141 self.packet.unpack(data, 0)
142 }
143}
144impl Drop for Connect {
145
146 fn drop(&mut self) {
147 info!("[DROP] 关闭连接: {}", self.stream.peer_addr().unwrap());
149
150 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
152 let _ = self.stream.shutdown(std::net::Shutdown::Both);
153 }
154
155 }