br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use std::io::{Read, Write};
4use std::net::TcpStream;
5use std::sync::Arc;
6use std::time::Duration;
7
/// One client connection to a PostgreSQL server.
///
/// `Clone` clones the `Arc` around the socket, so every clone shares the same
/// TCP stream rather than opening a new one.
/// NOTE(review): dropping a `Connect` may push a clone of it into
/// `crate::pools::DB_POOL` (see the `Drop` impl); confirm callers never drive
/// two clones of the same connection concurrently.
#[derive(Clone, Debug)]
pub struct Connect {
    /// TCP socket to the server, shared between clones via `Arc`.
    pub(crate) stream: Arc<TcpStream>,
    /// Builds outgoing protocol messages and unpacks the server's responses.
    packet: Packet,
    /// Authentication state of this connection.
    auth_status: AuthStatus,
}
16
17impl Connect {
18    
19    // 判断连接是否正常
20    pub fn is_valid(&self) -> bool {
21        true
22    }
23    
24    // 关闭程序和pgsql的连接
25    pub fn _close(&mut self) {
26        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
27        let _ = self.stream.shutdown(std::net::Shutdown::Both);
28    }
29
30    pub fn new(mut config: Config) -> Result<Connect, String> {
31        let stream = match TcpStream::connect(config.url()) {
32            Ok(stream) => stream,
33            Err(e) => return Err(e.to_string()),
34        };
35        stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
36        stream.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
37
38        // println!("[NEW] 创建连接: {}", stream.peer_addr().unwrap()); // 调试日志
39        let mut connect = Self {
40            stream: Arc::new(stream),
41            packet: Packet::new(config),
42            auth_status: AuthStatus::None,
43        };
44
45        connect.startup_message()?;
46        connect.sasl_initial_response_message()?;
47
48        Ok(connect)
49    }
50
51    fn read(&mut self) -> Result<Vec<u8>, String> {
52        let mut msg = vec![];
53        loop {
54            let mut response = [0u8; 1024];
55            match self.stream.try_clone().unwrap().read(&mut response) {
56                Ok(e) => {
57                    msg.extend(response[..e].to_vec());
58                }
59                Err(e) => return Err(format!("Error reading from stream: {e}")),
60            }
61            if msg.is_empty() {
62                continue;
63            }
64            if let AuthStatus::AuthenticationOk = self.auth_status {
65                if msg.ends_with(&[90, 0, 0, 0, 5, 73]) | msg.ends_with(&[90, 0, 0, 0, 5, 84]) | msg.ends_with(&[90, 0, 0, 0, 5, 69]) {
66                    break;
67                }
68                continue;
69            } else {
70                let t = &msg[1..=4];
71                let len = u32::from_be_bytes(t.try_into().unwrap());
72                if msg.len() < (len as usize) {
73                    continue;
74                }
75                break;
76            }
77        }
78        Ok(msg)
79    }
80    /// Startup Message
81    fn startup_message(&mut self) -> Result<(), String> {
82        self.stream.try_clone().unwrap().write_all(&self.packet.pack_first()).unwrap();
83        let data = self.read()?;
84        self.packet.unpack(data, 0)?;
85        Ok(())
86    }
87    /// `SASLInitialResponse` message
88    fn sasl_initial_response_message(&mut self) -> Result<(), String> {
89        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth()).unwrap();
90        let data = self.read()?;
91        self.packet.unpack(data, 0)?;
92        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth_verify()).unwrap();
93        let data = self.read()?;
94        self.packet.unpack(data, 0)?;
95        self.auth_status = AuthStatus::AuthenticationOk;
96        Ok(())
97    }
98    /// 查询
99    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
100        // self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
101        // let data = self.read()?;
102        // let mut packet = self.packet.clone();
103
104        // 直接使用原始stream,避免try_clone
105        self.stream.as_ref()
106            .write_all(&self.packet.pack_query(sql))
107            .map_err(|e| format!("query error: {}", e))?;
108
109        let data = self.read()
110            .map_err(|e| format!("query read error: {}", e))?;
111
112        self.packet.unpack(data, 0)
113
114        // std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
115        //     packet.unpack(data, 0)
116        // }).unwrap().join().unwrap()
117    }
118    /// 执行
119    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
120        // self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
121        // let data = self.read()?;
122        // let mut packet = self.packet.clone();
123        // 
124        // std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
125        //     packet.unpack(data, 0)
126        // }).unwrap().join().unwrap()
127
128        self.stream.as_ref()
129            .write_all(&self.packet.pack_execute(sql))
130            .map_err(|e| format!("execute error: {}", e))?;
131        let data = self.read()?;
132        self.packet.unpack(data, 0)
133    }
134}
135impl Drop for Connect {
136
137    fn drop(&mut self) {
138        crate::pools::DB_POOL.lock().unwrap().push_back(self.clone());
139    }
140    
141    // fn drop(&mut self) {
142    //     println!("[DROP] 丢弃连接: {}", self.id); // 调试日志
143    // 
144    //     let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
145    //     let _ = self.stream.shutdown(std::net::Shutdown::Both);
146    //     
147    //     // // 显式发送 PostgreSQL 协议 Terminate 包,告诉服务端“我要断开了”
148    //     // if let Err(e) = self.stream.as_ref().write_all(&Packet::pack_terminate()) {
149    //     //     println!("[DROP] 发送 Terminate 出错: {}", e);
150    //     // }
151    //     // 
152    //     // if self.is_valid() {
153    //     //     match crate::pools::DB_POOL.lock() {
154    //     //         Ok(mut pool) => {
155    //     //             if pool.len() <= 20 {
156    //     //                 pool.push(self.clone());
157    //     //                 println!("连接成功归还");
158    //     //             } else {
159    //     //                 println!("连接池已满,丢弃连接");
160    //     //             }
161    //     //         }
162    //     //         Err(e) => println!("获取连接池锁失败: {}", e),
163    //     //     }
164    //     // } else {
165    //     //     println!("丢弃无效连接");
166    //     // }
167    // }
168}