1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::sync::Arc;
7use std::time::Duration;
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11 pub(crate) stream: Arc<TcpStream>,
12 packet: Packet,
13 auth_status: AuthStatus,
14}
15
16impl Connect {
17 pub fn is_valid(&mut self) -> bool {
18 self.query("SELECT 1").is_ok()
19 }
20
21 pub fn _close(&mut self) {
22 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
23 let _ = self.stream.shutdown(std::net::Shutdown::Both);
24 }
25
26 pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
27 let stream =
28 TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
29
30 stream
31 .set_read_timeout(Some(Duration::from_secs(30)))
32 .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
33 stream
34 .set_write_timeout(Some(Duration::from_secs(30)))
35 .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
36
37 let _ = stream.peer_addr();
38
39 let mut connect = Self {
40 stream: Arc::new(stream),
41 packet: Packet::new(config),
42 auth_status: AuthStatus::None,
43 };
44
45 connect.authenticate()?;
46
47 Ok(connect)
48 }
49
50 fn authenticate(&mut self) -> Result<(), PgsqlError> {
51 self.stream
52 .as_ref()
53 .write_all(&self.packet.pack_first())
54 .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
55
56 let data = self.read()?;
57 self.packet.unpack(data, 0)?;
58
59 if !self.packet.md5_salt.is_empty() {
60 self.md5_auth()?;
61 } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
62 self.cleartext_auth()?;
63 } else {
64 self.scram_auth()?;
65 }
66
67 self.auth_status = AuthStatus::AuthenticationOk;
68 Ok(())
69 }
70
71 fn md5_auth(&mut self) -> Result<(), PgsqlError> {
72 self.stream
73 .as_ref()
74 .write_all(&self.packet.pack_md5_password())
75 .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
76
77 let data = self.read()?;
78 self.packet.unpack(data, 0)?;
79 Ok(())
80 }
81
82 fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
83 self.stream
84 .as_ref()
85 .write_all(&self.packet.pack_cleartext_password())
86 .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
87
88 let data = self.read()?;
89 self.packet.unpack(data, 0)?;
90 Ok(())
91 }
92
93 fn scram_auth(&mut self) -> Result<(), PgsqlError> {
94 self.stream
95 .as_ref()
96 .write_all(&self.packet.pack_auth())
97 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
98
99 let data = self.read()?;
100 self.packet.unpack(data, 0)?;
101
102 self.stream
103 .as_ref()
104 .write_all(&self.packet.pack_auth_verify())
105 .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
106
107 let data = self.read()?;
108 self.packet.unpack(data, 0)?;
109 Ok(())
110 }
111
112 fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
113 let mut msg = Vec::new();
114 let mut buf = [0u8; 4096];
115 let mut retry_count = 0;
116 const MAX_RETRIES: u32 = 100;
117
118 loop {
119 match self.stream.as_ref().read(&mut buf) {
120 Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
121 Ok(n) => {
122 msg.extend_from_slice(&buf[..n]);
123 retry_count = 0;
124 }
125 Err(ref e)
126 if e.kind() == std::io::ErrorKind::WouldBlock
127 || e.kind() == std::io::ErrorKind::TimedOut =>
128 {
129 retry_count += 1;
130 if retry_count > MAX_RETRIES {
131 return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
132 }
133 std::thread::sleep(Duration::from_millis(10));
134 continue;
135 }
136 Err(e) => return Err(PgsqlError::Io(e)),
137 };
138
139 if let AuthStatus::AuthenticationOk = self.auth_status {
140 if msg.ends_with(&[90, 0, 0, 0, 5, 73])
141 || msg.ends_with(&[90, 0, 0, 0, 5, 84])
142 || msg.ends_with(&[90, 0, 0, 0, 5, 69])
143 {
144 break;
145 }
146 } else if msg.len() >= 5 {
147 let len_bytes = &msg[1..=4];
148 if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
149 if msg.len() >= len as usize {
150 break;
151 }
152 }
153 }
154 }
155
156 Ok(msg)
157 }
158
159 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
160 self.stream
161 .as_ref()
162 .write_all(&self.packet.pack_query(sql))
163 .map_err(PgsqlError::Io)?;
164
165 let data = self.read()?;
166
167 self.packet.unpack(data, 0)
168 }
169
170 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
171 self.stream
172 .as_ref()
173 .write_all(&self.packet.pack_execute(sql))
174 .map_err(PgsqlError::Io)?;
175 let data = self.read()?;
176 self.packet.unpack(data, 0)
177 }
178}
179
180impl Drop for Connect {
181 fn drop(&mut self) {
182 let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
183 let _ = self.stream.shutdown(std::net::Shutdown::Both);
184 }
185}