Skip to main content

br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::sync::Arc;
7use std::time::Duration;
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11    pub(crate) stream: Arc<TcpStream>,
12    packet: Packet,
13    auth_status: AuthStatus,
14}
15
16impl Connect {
17    pub fn is_valid(&mut self) -> bool {
18        self.query("SELECT 1").is_ok()
19    }
20
21    pub fn _close(&mut self) {
22        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
23        let _ = self.stream.shutdown(std::net::Shutdown::Both);
24    }
25
26    pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
27        let stream =
28            TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
29
30        stream
31            .set_read_timeout(Some(Duration::from_secs(30)))
32            .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
33        stream
34            .set_write_timeout(Some(Duration::from_secs(30)))
35            .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
36
37        let _ = stream.peer_addr();
38
39        let mut connect = Self {
40            stream: Arc::new(stream),
41            packet: Packet::new(config),
42            auth_status: AuthStatus::None,
43        };
44
45        connect.authenticate()?;
46
47        Ok(connect)
48    }
49
50    fn authenticate(&mut self) -> Result<(), PgsqlError> {
51        self.stream
52            .as_ref()
53            .write_all(&self.packet.pack_first())
54            .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
55
56        let data = self.read()?;
57        self.packet.unpack(data, 0)?;
58
59        if !self.packet.md5_salt.is_empty() {
60            self.md5_auth()?;
61        } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
62            self.cleartext_auth()?;
63        } else {
64            self.scram_auth()?;
65        }
66
67        self.auth_status = AuthStatus::AuthenticationOk;
68        Ok(())
69    }
70
71    fn md5_auth(&mut self) -> Result<(), PgsqlError> {
72        self.stream
73            .as_ref()
74            .write_all(&self.packet.pack_md5_password())
75            .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
76
77        let data = self.read()?;
78        self.packet.unpack(data, 0)?;
79        Ok(())
80    }
81
82    fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
83        self.stream
84            .as_ref()
85            .write_all(&self.packet.pack_cleartext_password())
86            .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
87
88        let data = self.read()?;
89        self.packet.unpack(data, 0)?;
90        Ok(())
91    }
92
93    fn scram_auth(&mut self) -> Result<(), PgsqlError> {
94        self.stream
95            .as_ref()
96            .write_all(&self.packet.pack_auth())
97            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
98
99        let data = self.read()?;
100        self.packet.unpack(data, 0)?;
101
102        self.stream
103            .as_ref()
104            .write_all(&self.packet.pack_auth_verify())
105            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
106
107        let data = self.read()?;
108        self.packet.unpack(data, 0)?;
109        Ok(())
110    }
111
112    fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
113        let mut msg = Vec::new();
114        let mut buf = [0u8; 4096];
115        let mut retry_count = 0;
116        const MAX_RETRIES: u32 = 100;
117
118        loop {
119            match self.stream.as_ref().read(&mut buf) {
120                Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
121                Ok(n) => {
122                    msg.extend_from_slice(&buf[..n]);
123                    retry_count = 0;
124                }
125                Err(ref e)
126                    if e.kind() == std::io::ErrorKind::WouldBlock
127                        || e.kind() == std::io::ErrorKind::TimedOut =>
128                {
129                    retry_count += 1;
130                    if retry_count > MAX_RETRIES {
131                        return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
132                    }
133                    std::thread::sleep(Duration::from_millis(10));
134                    continue;
135                }
136                Err(e) => return Err(PgsqlError::Io(e)),
137            };
138
139            if let AuthStatus::AuthenticationOk = self.auth_status {
140                if msg.ends_with(&[90, 0, 0, 0, 5, 73])
141                    || msg.ends_with(&[90, 0, 0, 0, 5, 84])
142                    || msg.ends_with(&[90, 0, 0, 0, 5, 69])
143                {
144                    break;
145                }
146            } else if msg.len() >= 5 {
147                let len_bytes = &msg[1..=4];
148                if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
149                    if msg.len() >= len as usize {
150                        break;
151                    }
152                }
153            }
154        }
155
156        Ok(msg)
157    }
158
159    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
160        self.stream
161            .as_ref()
162            .write_all(&self.packet.pack_query(sql))
163            .map_err(PgsqlError::Io)?;
164
165        let data = self.read()?;
166
167        self.packet.unpack(data, 0)
168    }
169
170    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
171        self.stream
172            .as_ref()
173            .write_all(&self.packet.pack_execute(sql))
174            .map_err(PgsqlError::Io)?;
175        let data = self.read()?;
176        self.packet.unpack(data, 0)
177    }
178}
179
180impl Drop for Connect {
181    fn drop(&mut self) {
182        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
183        let _ = self.stream.shutdown(std::net::Shutdown::Both);
184    }
185}