// br_pgsql/connect.rs

use std::io::{ErrorKind, Read, Write};
use std::net::{Shutdown, TcpStream};
use std::sync::Arc;
use std::time::Duration;

use log::info;

use crate::config::Config;
use crate::packet::{AuthStatus, Packet, SuccessMessage};

9#[derive(Clone, Debug)]
10pub struct Connect {
11    /// 基础配置
12    pub(crate) stream: Arc<TcpStream>,
13    packet: Packet,
14    /// 认证状态
15    auth_status: AuthStatus,
16}
17
18impl Connect {
19    
20    // 判断连接是否正常
21    pub fn is_valid(&mut self) -> bool {
22        // 发送一个简单心跳
23        self.query("SELECT 1").is_ok()
24    }
25    
26    // 关闭程序和pgsql的连接
27    pub fn _close(&mut self) {
28        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
29        let _ = self.stream.shutdown(std::net::Shutdown::Both);
30    }
31
32    pub fn new(mut config: Config) -> Result<Connect, String> {
33        let stream = match TcpStream::connect(config.url()) {
34            Ok(stream) => stream,
35            Err(e) => return Err(format!("数据库连接失败: {}", e)),
36        };
37        stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
38        stream.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
39
40        info!("[NEW] 创建连接: {}", stream.peer_addr().unwrap()); // 调试日志
41        
42        let mut connect = Self {
43            stream: Arc::new(stream),
44            packet: Packet::new(config),
45            auth_status: AuthStatus::None,
46        };
47
48        connect.startup_message()?;
49        connect.sasl_initial_response_message()?;
50
51        Ok(connect)
52    }
53
54    fn read(&mut self) -> Result<Vec<u8>, String> {
55        let mut msg = Vec::new();
56        let mut buf = [0u8; 1024];
57
58        loop {
59            let n = match self.stream.as_ref().read(&mut buf) {
60                Ok(0) => return Err("连接已关闭或服务端断开".into()),
61                Ok(n) => msg.extend_from_slice(&buf[..n]),
62                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock
63                    || e.kind() == std::io::ErrorKind::TimedOut => {
64                    // 避免空循环占用 CPU
65                    std::thread::sleep(std::time::Duration::from_millis(5));
66                    continue;
67                }
68                Err(e) => return Err(format!("读取失败: {}", e)),
69            };
70
71            if let AuthStatus::AuthenticationOk = self.auth_status {
72                if msg.ends_with(&[90, 0, 0, 0, 5, 73])
73                    || msg.ends_with(&[90, 0, 0, 0, 5, 84])
74                    || msg.ends_with(&[90, 0, 0, 0, 5, 69])
75                {
76                    break;
77                }
78            } else if msg.len() >= 5 {
79                // 取长度字节
80                let len_bytes = &msg[1..=4];
81                if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
82                    if msg.len() >= len as usize { break; }
83                }
84            }
85        }
86
87        Ok(msg)
88    }
89    /// Startup Message
90    fn startup_message(&mut self) -> Result<(), String> {
91        self.stream.try_clone().unwrap().write_all(&self.packet.pack_first()).unwrap();
92        let data = self.read()?;
93        self.packet.unpack(data, 0)?;
94        Ok(())
95    }
96    /// `SASLInitialResponse` message
97    fn sasl_initial_response_message(&mut self) -> Result<(), String> {
98        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth()).unwrap();
99        let data = self.read()?;
100        self.packet.unpack(data, 0)?;
101        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth_verify()).unwrap();
102        let data = self.read()?;
103        self.packet.unpack(data, 0)?;
104        self.auth_status = AuthStatus::AuthenticationOk;
105        Ok(())
106    }
107    /// 查询
108    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
109        // self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
110        // let data = self.read()?;
111        // let mut packet = self.packet.clone();
112
113        // 直接使用原始stream,避免try_clone
114        self.stream.as_ref()
115            .write_all(&self.packet.pack_query(sql))
116            .map_err(|e| format!("query error: {}", e))?;
117
118        let data = self.read()
119            .map_err(|e| format!("query read error: {}", e))?;
120
121        self.packet.unpack(data, 0)
122
123        // std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
124        //     packet.unpack(data, 0)
125        // }).unwrap().join().unwrap()
126    }
127    /// 执行
128    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
129        // self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
130        // let data = self.read()?;
131        // let mut packet = self.packet.clone();
132        // 
133        // std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
134        //     packet.unpack(data, 0)
135        // }).unwrap().join().unwrap()
136
137        self.stream.as_ref()
138            .write_all(&self.packet.pack_execute(sql))
139            .map_err(|e| format!("execute error: {}", e))?;
140        let data = self.read()?;
141        self.packet.unpack(data, 0)
142    }
143}
144impl Drop for Connect {
145
146    fn drop(&mut self) {
147        // 只做资源清理,不处理连接池归还
148        info!("[DROP] 关闭连接: {}", self.stream.peer_addr().unwrap());
149
150        // 发送 Terminate 包通知 PostgreSQL 服务端
151        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
152        let _ = self.stream.shutdown(std::net::Shutdown::Both);
153    }
154    
155    // fn drop(&mut self) {
156    //     println!("[DROP] 丢弃连接: {}", self.id); // 调试日志
157    // 
158    //     let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
159    //     let _ = self.stream.shutdown(std::net::Shutdown::Both);
160    //     
161    //     // // 显式发送 PostgreSQL 协议 Terminate 包,告诉服务端“我要断开了”
162    //     // if let Err(e) = self.stream.as_ref().write_all(&Packet::pack_terminate()) {
163    //     //     println!("[DROP] 发送 Terminate 出错: {}", e);
164    //     // }
165    //     // 
166    //     // if self.is_valid() {
167    //     //     match crate::pools::DB_POOL.lock() {
168    //     //         Ok(mut pool) => {
169    //     //             if pool.len() <= 20 {
170    //     //                 pool.push(self.clone());
171    //     //                 println!("连接成功归还");
172    //     //             } else {
173    //     //                 println!("连接池已满,丢弃连接");
174    //     //             }
175    //     //         }
176    //     //         Err(e) => println!("获取连接池锁失败: {}", e),
177    //     //     }
178    //     // } else {
179    //     //     println!("丢弃无效连接");
180    //     // }
181    // }
182}