br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use std::io::{Read, Write};
4use std::net::TcpStream;
5use std::sync::Arc;
6use std::time::Duration;
7use log::{info, warn};
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11    /// 基础配置
12    pub(crate) stream: Arc<TcpStream>,
13    packet: Packet,
14    /// 认证状态
15    auth_status: AuthStatus,
16}
17
18impl Connect {
19    
20    // 判断连接是否正常
21    pub fn is_valid(&mut self) -> bool {
22        // 发送一个简单心跳
23        self.query("SELECT 1").is_ok()
24    }
25    
26    // 关闭程序和pgsql的连接
27    pub fn _close(&mut self) {
28        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
29        let _ = self.stream.shutdown(std::net::Shutdown::Both);
30    }
31
32    pub fn new(mut config: Config) -> Result<Connect, String> {
33        let stream = match TcpStream::connect(config.url()) {
34            Ok(stream) => stream,
35            Err(e) => return Err(format!("数据库连接失败: {}", e)),
36        };
37        stream.set_read_timeout(Some(Duration::from_secs(5)))
38            .map_err(|e| format!("设置读取超时失败: {}", e))?;
39        stream.set_write_timeout(Some(Duration::from_secs(5)))
40            .map_err(|e| format!("设置写入超时失败: {}", e))?;
41
42        match stream.peer_addr() {
43            Ok(addr) => info!("[NEW] 创建连接: {}", addr),
44            Err(e) => warn!("无法获取对端地址: {}", e),
45        }
46        
47        let mut connect = Self {
48            stream: Arc::new(stream),
49            packet: Packet::new(config),
50            auth_status: AuthStatus::None,
51        };
52
53        connect.startup_message()?;
54        connect.sasl_initial_response_message()?;
55
56        Ok(connect)
57    }
58
59    fn read(&mut self) -> Result<Vec<u8>, String> {
60        let mut msg = Vec::new();
61        let mut buf = [0u8; 1024];
62
63        loop {
64            match self.stream.as_ref().read(&mut buf) {
65                Ok(0) => return Err("连接已关闭或服务端断开".into()),
66                Ok(n) => msg.extend_from_slice(&buf[..n]),
67                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock
68                    || e.kind() == std::io::ErrorKind::TimedOut => {
69                    // 避免空循环占用 CPU
70                    std::thread::sleep(Duration::from_millis(5));
71                    continue;
72                }
73                Err(e) => return Err(format!("读取失败: {}", e)),
74            };
75
76            if let AuthStatus::AuthenticationOk = self.auth_status {
77                if msg.ends_with(&[90, 0, 0, 0, 5, 73])
78                    || msg.ends_with(&[90, 0, 0, 0, 5, 84])
79                    || msg.ends_with(&[90, 0, 0, 0, 5, 69])
80                {
81                    break;
82                }
83            } else if msg.len() >= 5 {
84                // 取长度字节
85                let len_bytes = &msg[1..=4];
86                if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
87                    if msg.len() >= len as usize { break; }
88                }
89            }
90        }
91
92        Ok(msg)
93    }
94    /// Startup Message
95    fn startup_message(&mut self) -> Result<(), String> {
96        // try_clone -> as_ref Arc<TcpStream> 是安全的
97        self.stream.as_ref().write_all(&self.packet.pack_first())
98            .map_err(|e| format!("发送 startup message 失败: {}", e))?;
99        let data = self.read()?;
100        self.packet.unpack(data, 0)?;
101        Ok(())
102    }
103    /// `SASLInitialResponse` message
104    fn sasl_initial_response_message(&mut self) -> Result<(), String> {
105        self.stream.as_ref().write_all(&self.packet.pack_auth())
106            .map_err(|e| format!("发送 SASL Initial Response 失败: {}", e))?;
107        
108        let data = self.read()?;
109        self.packet.unpack(data, 0)?;
110        
111        self.stream.as_ref().write_all(&self.packet.pack_auth_verify())
112            .map_err(|e| format!("发送 SASL Verify 失败: {}", e))?;
113        
114        let data = self.read()?;
115        self.packet.unpack(data, 0)?;
116        self.auth_status = AuthStatus::AuthenticationOk;
117        Ok(())
118    }
119    /// 查询
120    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
121        // self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
122        // let data = self.read()?;
123        // let mut packet = self.packet.clone();
124
125        // 直接使用原始stream,避免try_clone
126        self.stream.as_ref()
127            .write_all(&self.packet.pack_query(sql))
128            .map_err(|e| format!("query error: {}", e))?;
129
130        let data = self.read()
131            .map_err(|e| format!("query read error: {}", e))?;
132
133        self.packet.unpack(data, 0)
134
135        // 唯一的区别是 packet.unpack 运行在一个新栈上(你显式设置了 8 MB stack size)。可能是为了防止深递归或大数据解析时栈溢出。
136        // 如果你的 packet.unpack 是字节流解析(状态机 + 循环),那 2 MB 绰绰有余,不用担心。
137        // 如果它内部藏了大数组或递归,那才需要考虑开大栈。
138        // std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
139        //     packet.unpack(data, 0)
140        // }).unwrap().join().unwrap()
141    }
142    /// 执行
143    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
144        // self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
145        // let data = self.read()?;
146        // let mut packet = self.packet.clone();
147        // 
148        // std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
149        //     packet.unpack(data, 0)
150        // }).unwrap().join().unwrap()
151
152        self.stream.as_ref()
153            .write_all(&self.packet.pack_execute(sql))
154            .map_err(|e| format!("execute error: {}", e))?;
155        let data = self.read()?;
156        self.packet.unpack(data, 0)
157    }
158}
159impl Drop for Connect {
160
161    fn drop(&mut self) {
162        // 如果链接关闭的情况下,unwrap会失败,所以此处用match处理
163        // 只做资源清理,不处理连接池归还
164        match self.stream.peer_addr() {
165            Ok(addr) => info!("[DROP] 关闭连接: {}", addr),
166            Err(e) => warn!("[DROP] 获取 peer_addr 失败: {e}"),
167        }
168
169        // 尝试发送 Terminate 包,不强制 unwrap
170        if let Err(e) = self.stream.as_ref().write_all(&Packet::pack_terminate()) {
171            warn!("[DROP] 发送 Terminate 包失败: {e}");
172        }
173
174        if let Err(e) = self.stream.shutdown(std::net::Shutdown::Both) {
175            warn!("[DROP] shutdown 失败: {e}");
176        }
177    }
178    
179    // fn drop(&mut self) {
180    //     println!("[DROP] 丢弃连接: {}", self.id); // 调试日志
181    // 
182    //     let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
183    //     let _ = self.stream.shutdown(std::net::Shutdown::Both);
184    //     
185    //     // // 显式发送 PostgreSQL 协议 Terminate 包,告诉服务端“我要断开了”
186    //     // if let Err(e) = self.stream.as_ref().write_all(&Packet::pack_terminate()) {
187    //     //     println!("[DROP] 发送 Terminate 出错: {}", e);
188    //     // }
189    //     // 
190    //     // if self.is_valid() {
191    //     //     match crate::pools::DB_POOL.lock() {
192    //     //         Ok(mut pool) => {
193    //     //             if pool.len() <= 20 {
194    //     //                 pool.push(self.clone());
195    //     //                 println!("连接成功归还");
196    //     //             } else {
197    //     //                 println!("连接池已满,丢弃连接");
198    //     //             }
199    //     //         }
200    //     //         Err(e) => println!("获取连接池锁失败: {}", e),
201    //     //     }
202    //     // } else {
203    //     //     println!("丢弃无效连接");
204    //     // }
205    // }
206}