br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use log::warn;
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::sync::Arc;
7use std::time::Duration;
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11    /// 基础配置
12    pub(crate) stream: Arc<TcpStream>,
13    packet: Packet,
14    /// 认证状态
15    auth_status: AuthStatus,
16}
17
18impl Connect {
19    // 判断连接是否正常(完整检查,包括查询)
20    pub fn is_valid(&mut self) -> bool {
21        // 先检查 TCP 连接状态
22        if self.stream.peer_addr().is_err() {
23            return false;
24        }
25        // 发送一个简单心跳,使用更轻量的查询
26        self.query("SELECT 1").is_ok()
27    }
28
29    /// 快速检查连接是否可用(只检查 TCP 状态,不执行查询)
30    /// 用于连接池,避免执行查询的开销
31    pub fn is_quick_valid(&self) -> bool {
32        // 只检查 TCP 连接状态,避免执行查询的开销
33        self.stream.peer_addr().is_ok()
34    }
35
36    // 关闭程序和pgsql的连接
37    pub fn _close(&mut self) {
38        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
39        let _ = self.stream.shutdown(std::net::Shutdown::Both);
40    }
41
42    pub fn new(config: Config) -> Result<Connect, String> {
43        let stream = match TcpStream::connect(config.url()) {
44            Ok(stream) => stream,
45            Err(e) => return Err(format!("数据库连接失败: {}", e)),
46        };
47        stream
48            .set_read_timeout(Some(Duration::from_secs(5)))
49            .map_err(|e| format!("设置读取超时失败: {}", e))?;
50        stream
51            .set_write_timeout(Some(Duration::from_secs(5)))
52            .map_err(|e| format!("设置写入超时失败: {}", e))?;
53
54        match stream.peer_addr() {
55            Ok(_) => {}
56            Err(e) => warn!("无法获取对端地址: {}", e),
57        }
58
59        let mut connect = Self {
60            stream: Arc::new(stream),
61            packet: Packet::new(config),
62            auth_status: AuthStatus::None,
63        };
64
65        connect.startup_message()?;
66        connect.sasl_initial_response_message()?;
67
68        Ok(connect)
69    }
70
71    fn read(&mut self) -> Result<Vec<u8>, String> {
72        const BUFFER_SIZE: usize = 4096; // 增大缓冲区以提高性能
73        const SLEEP_DURATION: Duration = Duration::from_millis(5);
74        
75        let mut msg = Vec::new();
76        let mut buf = vec![0u8; BUFFER_SIZE];
77
78        loop {
79            match self.stream.as_ref().read(&mut buf) {
80                Ok(0) => return Err("连接已关闭或服务端断开".into()),
81                Ok(n) => {
82                    msg.extend_from_slice(&buf[..n]);
83                    // 如果读取的数据少于缓冲区大小,可能已经读取完毕
84                    if n < BUFFER_SIZE {
85                        // 检查是否已经读取完整消息
86                        if msg.len() >= 5 {
87                            let len_bytes = &msg[1..=4];
88                            if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
89                                if msg.len() >= len as usize {
90                                    break;
91                                }
92                            }
93                        }
94                    }
95                }
96                Err(ref e)
97                    if e.kind() == std::io::ErrorKind::WouldBlock
98                        || e.kind() == std::io::ErrorKind::TimedOut =>
99                {
100                    // 避免空循环占用 CPU
101                    std::thread::sleep(SLEEP_DURATION);
102                    continue;
103                }
104                Err(e) => return Err(format!("读取失败: {}", e)),
105            };
106
107            if let AuthStatus::AuthenticationOk = self.auth_status {
108                if msg.ends_with(&[90, 0, 0, 0, 5, 73])
109                    || msg.ends_with(&[90, 0, 0, 0, 5, 84])
110                    || msg.ends_with(&[90, 0, 0, 0, 5, 69])
111                {
112                    break;
113                }
114            } else if msg.len() >= 5 {
115                // 取长度字节
116                let len_bytes = &msg[1..=4];
117                if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
118                    if msg.len() >= len as usize {
119                        break;
120                    }
121                }
122            }
123        }
124
125        Ok(msg)
126    }
127    /// Startup Message
128    fn startup_message(&mut self) -> Result<(), String> {
129        // try_clone -> as_ref Arc<TcpStream> 是安全的
130        self.stream
131            .as_ref()
132            .write_all(&self.packet.pack_first())
133            .map_err(|e| format!("发送 startup message 失败: {}", e))?;
134        let data = self.read()?;
135        self.packet.unpack(data, 0)?;
136        Ok(())
137    }
138    /// `SASLInitialResponse` message
139    fn sasl_initial_response_message(&mut self) -> Result<(), String> {
140        self.stream
141            .as_ref()
142            .write_all(&self.packet.pack_auth())
143            .map_err(|e| format!("发送 SASL Initial Response 失败: {}", e))?;
144
145        let data = self.read()?;
146        self.packet.unpack(data, 0)?;
147
148        self.stream
149            .as_ref()
150            .write_all(&self.packet.pack_auth_verify())
151            .map_err(|e| format!("发送 SASL Verify 失败: {}", e))?;
152
153        let data = self.read()?;
154        self.packet.unpack(data, 0)?;
155        self.auth_status = AuthStatus::AuthenticationOk;
156        Ok(())
157    }
158    /// 查询
159    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
160        // self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
161        // let data = self.read()?;
162        // let mut packet = self.packet.clone();
163
164        // 直接使用原始stream,避免try_clone
165        self.stream
166            .as_ref()
167            .write_all(&self.packet.pack_query(sql))
168            .map_err(|e| format!("query error: {}", e))?;
169
170        let data = self
171            .read()
172            .map_err(|e| format!("query read error: {}", e))?;
173
174        self.packet.unpack(data, 0)
175
176        // 唯一的区别是 packet.unpack 运行在一个新栈上(你显式设置了 8 MB stack size)。可能是为了防止深递归或大数据解析时栈溢出。
177        // 如果你的 packet.unpack 是字节流解析(状态机 + 循环),那 2 MB 绰绰有余,不用担心。
178        // 如果它内部藏了大数组或递归,那才需要考虑开大栈。
179        // std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
180        //     packet.unpack(data, 0)
181        // }).unwrap().join().unwrap()
182    }
183    /// 执行
184    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
185        // self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
186        // let data = self.read()?;
187        // let mut packet = self.packet.clone();
188        //
189        // std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
190        //     packet.unpack(data, 0)
191        // }).unwrap().join().unwrap()
192
193        self.stream
194            .as_ref()
195            .write_all(&self.packet.pack_execute(sql))
196            .map_err(|e| format!("execute error: {}", e))?;
197        let data = self.read()?;
198        self.packet.unpack(data, 0)
199    }
200}
201impl Drop for Connect {
202    fn drop(&mut self) {
203        // 如果链接关闭的情况下,unwrap会失败,所以此处用match处理
204        // 只做资源清理,不处理连接池归还
205        match self.stream.peer_addr() {
206            Ok(_addr) => {
207                // 连接关闭(已移除正常日志)
208            },
209            Err(e) => warn!("[DROP] 获取 peer_addr 失败: {e}"),
210        }
211
212        // 尝试发送 Terminate 包,不强制 unwrap
213        if let Err(e) = self.stream.as_ref().write_all(&Packet::pack_terminate()) {
214            warn!("[DROP] 发送 Terminate 包失败: {e}");
215        }
216
217        if let Err(e) = self.stream.shutdown(std::net::Shutdown::Both) {
218            warn!("[DROP] shutdown 失败: {e}");
219        }
220    }
221
222    // fn drop(&mut self) {
223    //     println!("[DROP] 丢弃连接: {}", self.id); // 调试日志
224    //
225    //     let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
226    //     let _ = self.stream.shutdown(std::net::Shutdown::Both);
227    //
228    //     // // 显式发送 PostgreSQL 协议 Terminate 包,告诉服务端“我要断开了”
229    //     // if let Err(e) = self.stream.as_ref().write_all(&Packet::pack_terminate()) {
230    //     //     println!("[DROP] 发送 Terminate 出错: {}", e);
231    //     // }
232    //     //
233    //     // if self.is_valid() {
234    //     //     match crate::pools::DB_POOL.lock() {
235    //     //         Ok(mut pool) => {
236    //     //             if pool.len() <= 20 {
237    //     //                 pool.push(self.clone());
238    //     //                 println!("连接成功归还");
239    //     //             } else {
240    //     //                 println!("连接池已满,丢弃连接");
241    //     //             }
242    //     //         }
243    //     //         Err(e) => println!("获取连接池锁失败: {}", e),
244    //     //     }
245    //     // } else {
246    //     //     println!("丢弃无效连接");
247    //     // }
248    // }
249}