// br_pgsql/connect.rs

1use crate::config::Config;
2use crate::error::PgsqlError;
3use crate::packet::{AuthStatus, Packet, SuccessMessage};
4use std::io::{Read, Write};
5use std::net::TcpStream;
6use std::sync::Arc;
7use std::time::Duration;
8
9#[derive(Clone, Debug)]
10pub struct Connect {
11    pub(crate) stream: Arc<TcpStream>,
12    packet: Packet,
13    auth_status: AuthStatus,
14}
15
16impl Connect {
17    pub fn is_valid(&mut self) -> bool {
18        self.query("SELECT 1").is_ok()
19    }
20
21    pub fn _close(&mut self) {
22        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
23        let _ = self.stream.shutdown(std::net::Shutdown::Both);
24    }
25
26    pub fn new(mut config: Config) -> Result<Connect, PgsqlError> {
27        let stream =
28            TcpStream::connect(config.url()).map_err(|e| PgsqlError::Connection(e.to_string()))?;
29
30        stream
31            .set_read_timeout(Some(Duration::from_secs(30)))
32            .map_err(|e| PgsqlError::Connection(format!("设置读取超时失败: {}", e)))?;
33        stream
34            .set_write_timeout(Some(Duration::from_secs(30)))
35            .map_err(|e| PgsqlError::Connection(format!("设置写入超时失败: {}", e)))?;
36
37        let _ = stream.peer_addr();
38
39        let mut connect = Self {
40            stream: Arc::new(stream),
41            packet: Packet::new(config),
42            auth_status: AuthStatus::None,
43        };
44
45        connect.authenticate()?;
46
47        Ok(connect)
48    }
49
50    fn authenticate(&mut self) -> Result<(), PgsqlError> {
51        self.stream
52            .as_ref()
53            .write_all(&self.packet.pack_first())
54            .map_err(|e| PgsqlError::Auth(format!("发送 startup message 失败: {}", e)))?;
55
56        let data = self.read()?;
57        self.packet.unpack(data, 0)?;
58
59        if !self.packet.md5_salt.is_empty() {
60            self.md5_auth()?;
61        } else if self.packet.auth_mechanism.is_empty() && self.packet.md5_salt.is_empty() {
62            self.cleartext_auth()?;
63        } else {
64            self.scram_auth()?;
65        }
66
67        self.auth_status = AuthStatus::AuthenticationOk;
68        Ok(())
69    }
70
71    fn md5_auth(&mut self) -> Result<(), PgsqlError> {
72        self.stream
73            .as_ref()
74            .write_all(&self.packet.pack_md5_password())
75            .map_err(|e| PgsqlError::Auth(format!("发送 MD5 密码失败: {}", e)))?;
76
77        let data = self.read()?;
78        self.packet.unpack(data, 0)?;
79        Ok(())
80    }
81
82    fn cleartext_auth(&mut self) -> Result<(), PgsqlError> {
83        self.stream
84            .as_ref()
85            .write_all(&self.packet.pack_cleartext_password())
86            .map_err(|e| PgsqlError::Auth(format!("发送明文密码失败: {}", e)))?;
87
88        let data = self.read()?;
89        self.packet.unpack(data, 0)?;
90        Ok(())
91    }
92
93    fn scram_auth(&mut self) -> Result<(), PgsqlError> {
94        self.stream
95            .as_ref()
96            .write_all(&self.packet.pack_auth())
97            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Initial Response 失败: {}", e)))?;
98
99        let data = self.read()?;
100        self.packet.unpack(data, 0)?;
101
102        self.stream
103            .as_ref()
104            .write_all(&self.packet.pack_auth_verify())
105            .map_err(|e| PgsqlError::Auth(format!("发送 SASL Verify 失败: {}", e)))?;
106
107        let data = self.read()?;
108        self.packet.unpack(data, 0)?;
109        Ok(())
110    }
111
112    fn read(&mut self) -> Result<Vec<u8>, PgsqlError> {
113        let mut msg = Vec::new();
114        let mut buf = [0u8; 4096];
115        let mut retry_count = 0;
116
117        #[cfg(not(test))]
118        const MAX_RETRIES: u32 = 100;
119        #[cfg(test)]
120        const MAX_RETRIES: u32 = 3;
121
122        #[cfg(not(test))]
123        const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;
124        #[cfg(test)]
125        const MAX_MESSAGE_SIZE: usize = 128;
126
127        #[cfg(not(test))]
128        let deadline = std::time::Instant::now() + Duration::from_secs(300);
129        #[cfg(test)]
130        let deadline = std::time::Instant::now() + Duration::from_millis(200);
131
132        loop {
133            if std::time::Instant::now() >= deadline {
134                return Err(PgsqlError::Timeout("读取总超时".into()));
135            }
136
137            match self.stream.as_ref().read(&mut buf) {
138                Ok(0) => return Err(PgsqlError::Connection("连接已关闭或服务端断开".into())),
139                Ok(n) => {
140                    if msg.len() + n > MAX_MESSAGE_SIZE {
141                        return Err(PgsqlError::Protocol("消息超过最大允许大小".into()));
142                    }
143                    msg.extend_from_slice(&buf[..n]);
144                    retry_count = 0;
145                }
146                Err(ref e)
147                    if e.kind() == std::io::ErrorKind::WouldBlock
148                        || e.kind() == std::io::ErrorKind::TimedOut =>
149                {
150                    retry_count += 1;
151                    if retry_count > MAX_RETRIES {
152                        return Err(PgsqlError::Timeout("读取超时,已达最大重试次数".into()));
153                    }
154                    std::thread::sleep(Duration::from_millis(10));
155                    continue;
156                }
157                Err(e) => return Err(PgsqlError::Io(e)),
158            };
159
160            if let AuthStatus::AuthenticationOk = self.auth_status {
161                if msg.ends_with(&[90, 0, 0, 0, 5, 73])
162                    || msg.ends_with(&[90, 0, 0, 0, 5, 84])
163                    || msg.ends_with(&[90, 0, 0, 0, 5, 69])
164                {
165                    break;
166                }
167            } else if msg.len() >= 5 {
168                let len_bytes = &msg[1..=4];
169                if let Ok(len) = len_bytes.try_into().map(u32::from_be_bytes) {
170                    if msg.len() > len as usize {
171                        break;
172                    }
173                }
174            }
175        }
176
177        Ok(msg)
178    }
179
180    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
181        self.stream
182            .as_ref()
183            .write_all(&self.packet.pack_query(sql))
184            .map_err(PgsqlError::Io)?;
185
186        let data = self.read()?;
187
188        self.packet.unpack(data, 0)
189    }
190
191    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, PgsqlError> {
192        self.stream
193            .as_ref()
194            .write_all(&self.packet.pack_execute(sql))
195            .map_err(PgsqlError::Io)?;
196        let data = self.read()?;
197        self.packet.unpack(data, 0)
198    }
199
200    /// 参数化查询
201    pub fn query_params(
202        &mut self,
203        sql: &str,
204        params: &[Option<&str>],
205    ) -> Result<SuccessMessage, PgsqlError> {
206        self.stream
207            .as_ref()
208            .write_all(&self.packet.pack_query_params(sql, params))
209            .map_err(PgsqlError::Io)?;
210
211        let data = self.read()?;
212        self.packet.unpack(data, 0)
213    }
214
215    /// 参数化执行
216    pub fn execute_params(
217        &mut self,
218        sql: &str,
219        params: &[Option<&str>],
220    ) -> Result<SuccessMessage, PgsqlError> {
221        self.stream
222            .as_ref()
223            .write_all(&self.packet.pack_execute_params(sql, params))
224            .map_err(PgsqlError::Io)?;
225        let data = self.read()?;
226        self.packet.unpack(data, 0)
227    }
228
229    /// 参数化查询(便捷版,所有参数非 NULL)
230    pub fn query_str(&mut self, sql: &str, params: &[&str]) -> Result<SuccessMessage, PgsqlError> {
231        let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
232        self.query_params(sql, &opts)
233    }
234
235    /// 参数化执行(便捷版,所有参数非 NULL)
236    pub fn execute_str(
237        &mut self,
238        sql: &str,
239        params: &[&str],
240    ) -> Result<SuccessMessage, PgsqlError> {
241        let opts: Vec<Option<&str>> = params.iter().map(|s| Some(*s)).collect();
242        self.execute_params(sql, &opts)
243    }
244}
245
246impl Drop for Connect {
247    fn drop(&mut self) {
248        let _ = self.stream.as_ref().write_all(&Packet::pack_terminate());
249        let _ = self.stream.shutdown(std::net::Shutdown::Both);
250    }
251}
252
253#[cfg(test)]
254mod tests {
255    use super::*;
256    use std::net::TcpListener;
257    use std::thread;
258
259    // ── wire-protocol helpers ──────────────────────────────────────────
260
261    /// Build a single PG backend message: type_byte | len(4) | payload
262    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
263        let mut m = Vec::with_capacity(5 + payload.len());
264        m.push(tag);
265        m.extend(&((payload.len() as u32 + 4).to_be_bytes()));
266        m.extend_from_slice(payload);
267        m
268    }
269
270    /// Build an Authentication message (tag 'R') with given auth_type + extra payload
271    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
272        let mut body = Vec::new();
273        body.extend(&auth_type.to_be_bytes());
274        body.extend_from_slice(extra);
275        pg_msg(b'R', &body)
276    }
277
278    /// AuthenticationOk (R, type=0)
279    fn auth_ok() -> Vec<u8> {
280        pg_auth(0, &[])
281    }
282
283    /// ParameterStatus for server_version=15.0
284    fn param_status() -> Vec<u8> {
285        pg_msg(b'S', b"server_version\x0015.0\x00")
286    }
287
288    /// BackendKeyData (process_id=1, secret_key=2)
289    fn backend_key() -> Vec<u8> {
290        let mut p = Vec::new();
291        p.extend(&1u32.to_be_bytes());
292        p.extend(&2u32.to_be_bytes());
293        pg_msg(b'K', &p)
294    }
295
296    /// ReadyForQuery (status = Idle)
297    fn ready_for_query() -> Vec<u8> {
298        pg_msg(b'Z', b"I")
299    }
300
301    /// The standard tail sent after auth succeeds: AuthOk + param + key + ready
302    fn post_auth_ok() -> Vec<u8> {
303        let mut v = Vec::new();
304        v.extend(auth_ok());
305        v.extend(param_status());
306        v.extend(backend_key());
307        v.extend(ready_for_query());
308        v
309    }
310
311    /// Build a simple query response: ParseComplete + BindComplete +
312    /// RowDescription(1 int4 col "c") + DataRow("1") + CommandComplete("SELECT 1") + ReadyForQuery
313    fn simple_query_response() -> Vec<u8> {
314        let mut r = Vec::new();
315        // ParseComplete
316        r.extend(pg_msg(b'1', &[]));
317        // BindComplete
318        r.extend(pg_msg(b'2', &[]));
319        // RowDescription – 1 field "c", type_oid=23 (int4)
320        let mut rd = Vec::new();
321        rd.extend(&1u16.to_be_bytes()); // field count
322        rd.extend(b"c\x00"); // name
323        rd.extend(&0u32.to_be_bytes()); // table oid
324        rd.extend(&1u16.to_be_bytes()); // column index
325        rd.extend(&23u32.to_be_bytes()); // type oid (int4)
326        rd.extend(&4i16.to_be_bytes()); // column length
327        rd.extend(&(-1i32).to_be_bytes()); // type modifier
328        rd.extend(&0u16.to_be_bytes()); // format (text)
329        r.extend(pg_msg(b'T', &rd));
330        // DataRow – 1 field, value "1"
331        let mut dr = Vec::new();
332        dr.extend(&1u16.to_be_bytes());
333        dr.extend(&1u32.to_be_bytes()); // length of value
334        dr.push(b'1');
335        r.extend(pg_msg(b'D', &dr));
336        // CommandComplete
337        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
338        // ReadyForQuery
339        r.extend(ready_for_query());
340        r
341    }
342
343    /// Build an execute response (no rows): ParseComplete + BindComplete +
344    /// NoData + CommandComplete("UPDATE 3") + ReadyForQuery
345    fn execute_response() -> Vec<u8> {
346        let mut r = Vec::new();
347        r.extend(pg_msg(b'1', &[]));
348        r.extend(pg_msg(b'2', &[]));
349        r.extend(pg_msg(b'n', &[])); // NoData
350        r.extend(pg_msg(b'C', b"UPDATE 3\x00"));
351        r.extend(ready_for_query());
352        r
353    }
354
355    /// Build a parameterized query response: ParseComplete + ParameterDescription + BindComplete +
356    /// RowDescription(1 int4 col "p") + DataRow("42") + CommandComplete("SELECT 1") + ReadyForQuery
357    fn query_params_response() -> Vec<u8> {
358        let mut r = Vec::new();
359        r.extend(pg_msg(b'1', &[]));
360
361        let mut pd = Vec::new();
362        pd.extend(&1u16.to_be_bytes());
363        pd.extend(&23u32.to_be_bytes());
364        r.extend(pg_msg(b't', &pd));
365
366        r.extend(pg_msg(b'2', &[]));
367
368        let mut rd = Vec::new();
369        rd.extend(&1u16.to_be_bytes());
370        rd.extend(b"p\x00");
371        rd.extend(&0u32.to_be_bytes());
372        rd.extend(&1u16.to_be_bytes());
373        rd.extend(&23u32.to_be_bytes());
374        rd.extend(&4i16.to_be_bytes());
375        rd.extend(&(-1i32).to_be_bytes());
376        rd.extend(&0u16.to_be_bytes());
377        r.extend(pg_msg(b'T', &rd));
378
379        let mut dr = Vec::new();
380        dr.extend(&1u16.to_be_bytes());
381        dr.extend(&2u32.to_be_bytes());
382        dr.extend(b"42");
383        r.extend(pg_msg(b'D', &dr));
384
385        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
386        r.extend(ready_for_query());
387        r
388    }
389
390    /// Build a parameterized execute response: ParseComplete + ParameterDescription + BindComplete +
391    /// NoData + CommandComplete("UPDATE 1") + ReadyForQuery
392    fn execute_params_response() -> Vec<u8> {
393        let mut r = Vec::new();
394        r.extend(pg_msg(b'1', &[]));
395
396        let mut pd = Vec::new();
397        pd.extend(&1u16.to_be_bytes());
398        pd.extend(&23u32.to_be_bytes());
399        r.extend(pg_msg(b't', &pd));
400
401        r.extend(pg_msg(b'2', &[]));
402        r.extend(pg_msg(b'n', &[]));
403        r.extend(pg_msg(b'C', b"UPDATE 1\x00"));
404        r.extend(ready_for_query());
405        r
406    }
407
408    /// Build a parameterized query response with NULL row value.
409    fn query_params_null_response() -> Vec<u8> {
410        let mut r = Vec::new();
411        r.extend(pg_msg(b'1', &[]));
412
413        let mut pd = Vec::new();
414        pd.extend(&1u16.to_be_bytes());
415        pd.extend(&25u32.to_be_bytes());
416        r.extend(pg_msg(b't', &pd));
417
418        r.extend(pg_msg(b'2', &[]));
419
420        let mut rd = Vec::new();
421        rd.extend(&1u16.to_be_bytes());
422        rd.extend(b"n\x00");
423        rd.extend(&0u32.to_be_bytes());
424        rd.extend(&1u16.to_be_bytes());
425        rd.extend(&25u32.to_be_bytes());
426        rd.extend(&(-1i16).to_be_bytes());
427        rd.extend(&(-1i32).to_be_bytes());
428        rd.extend(&0u16.to_be_bytes());
429        r.extend(pg_msg(b'T', &rd));
430
431        let mut dr = Vec::new();
432        dr.extend(&1u16.to_be_bytes());
433        dr.extend(&(-1i32).to_be_bytes());
434        r.extend(pg_msg(b'D', &dr));
435
436        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
437        r.extend(ready_for_query());
438        r
439    }
440
441    /// Build an ErrorResponse for query phase
442    fn error_response() -> Vec<u8> {
443        let mut payload = Vec::new();
444        payload.push(b'C');
445        payload.extend(b"42601\x00");
446        payload.push(b'M');
447        payload.extend(b"syntax error\x00");
448        payload.push(0);
449        let mut r = Vec::new();
450        r.extend(pg_msg(b'1', &[]));
451        r.extend(pg_msg(b'2', &[]));
452        r.extend(pg_msg(b'E', &payload));
453        r.extend(ready_for_query());
454        r
455    }
456
457    // ── mock server spawners ───────────────────────────────────────────
458
459    /// Config pointing at 127.0.0.1:<port>
460    fn mock_config(port: u16) -> Config {
461        Config {
462            debug: false,
463            hostname: "127.0.0.1".into(),
464            hostport: port as i32,
465            username: "u".into(),
466            userpass: "p".into(),
467            database: "d".into(),
468            charset: "utf8".into(),
469            pool_max: 5,
470        }
471    }
472
473    /// Spawn a mock PG server that does **cleartext** auth.
474    /// Returns the port.  The server handles one connection.
475    fn spawn_cleartext_server() -> u16 {
476        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
477        let port = listener.local_addr().unwrap().port();
478        thread::spawn(move || {
479            let (mut s, _) = listener.accept().unwrap();
480            let mut buf = [0u8; 4096];
481            // 1. read startup
482            let _ = s.read(&mut buf).unwrap();
483            // 2. send CleartextPassword request (auth_type=3)
484            let _ = s.write_all(&pg_auth(3, &[]));
485            // 3. read password message
486            let _ = s.read(&mut buf).unwrap();
487            // 4. send AuthOk + params + ready
488            let _ = s.write_all(&post_auth_ok());
489            // keep connection alive for queries
490            loop {
491                match s.read(&mut buf) {
492                    Ok(0) | Err(_) => break,
493                    Ok(_) => {
494                        let _ = s.write_all(&simple_query_response());
495                    }
496                }
497            }
498        });
499        thread::sleep(Duration::from_millis(30));
500        port
501    }
502
503    /// Spawn a mock PG server that does **MD5** auth.
504    fn spawn_md5_server() -> u16 {
505        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
506        let port = listener.local_addr().unwrap().port();
507        thread::spawn(move || {
508            let (mut s, _) = listener.accept().unwrap();
509            let mut buf = [0u8; 4096];
510            // 1. read startup
511            let _ = s.read(&mut buf).unwrap();
512            // 2. send MD5Password request (auth_type=5) + 4-byte salt
513            let _ = s.write_all(&pg_auth(5, &[0xAA, 0xBB, 0xCC, 0xDD]));
514            // 3. read md5 password
515            let _ = s.read(&mut buf).unwrap();
516            // 4. send AuthOk + params + ready
517            let _ = s.write_all(&post_auth_ok());
518            loop {
519                match s.read(&mut buf) {
520                    Ok(0) | Err(_) => break,
521                    Ok(_) => {
522                        let _ = s.write_all(&simple_query_response());
523                    }
524                }
525            }
526        });
527        thread::sleep(Duration::from_millis(30));
528        port
529    }
530
531    /// Spawn a mock PG server that does **SCRAM-SHA-256** auth.
532    fn spawn_scram_server() -> u16 {
533        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
534        let port = listener.local_addr().unwrap().port();
535        thread::spawn(move || {
536            let (mut s, _) = listener.accept().unwrap();
537            let mut buf = [0u8; 4096];
538            // 1. read startup
539            let _ = s.read(&mut buf).unwrap();
540            // 2. send SASL auth (auth_type=10, mechanism)
541            let _ = s.write_all(&pg_auth(10, b"SCRAM-SHA-256\x00\x00"));
542            // 3. read SASLInitialResponse – extract client nonce
543            let n = s.read(&mut buf).unwrap();
544            let payload = &buf[..n];
545            // find "n=,r=" in the payload to extract client nonce
546            let text = String::from_utf8_lossy(payload);
547            let client_nonce = text.split("r=").nth(1).unwrap_or("clientnonce").to_string();
548            // 4. send SCRAM challenge (auth_type=11)
549            let challenge = format!("r={client_nonce}SERVERNONCE,s=c2FsdA==,i=4096");
550            let _ = s.write_all(&pg_auth(11, challenge.as_bytes()));
551            // 5. read SCRAM client final
552            let _ = s.read(&mut buf).unwrap();
553            // 6. send SCRAM complete (auth_type=12) + AuthOk + params + ready
554            let mut resp = Vec::new();
555            resp.extend(pg_auth(12, b"v=dummyproof"));
556            resp.extend(auth_ok());
557            resp.extend(param_status());
558            resp.extend(backend_key());
559            resp.extend(ready_for_query());
560            let _ = s.write_all(&resp);
561            loop {
562                match s.read(&mut buf) {
563                    Ok(0) | Err(_) => break,
564                    Ok(_) => {
565                        let _ = s.write_all(&simple_query_response());
566                    }
567                }
568            }
569        });
570        thread::sleep(Duration::from_millis(30));
571        port
572    }
573
574    /// Spawn a server that accepts connection then immediately closes (EOF).
575    fn spawn_eof_server() -> u16 {
576        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
577        let port = listener.local_addr().unwrap().port();
578        thread::spawn(move || {
579            let (s, _) = listener.accept().unwrap();
580            drop(s); // close immediately
581        });
582        thread::sleep(Duration::from_millis(30));
583        port
584    }
585
586    /// Spawn a server that sends an ErrorResponse after startup.
587    fn spawn_auth_error_server() -> u16 {
588        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
589        let port = listener.local_addr().unwrap().port();
590        thread::spawn(move || {
591            let (mut s, _) = listener.accept().unwrap();
592            let mut buf = [0u8; 4096];
593            let _ = s.read(&mut buf).unwrap();
594            // Send ErrorResponse
595            let mut payload = Vec::new();
596            payload.push(b'C');
597            payload.extend(b"28P01\x00");
598            payload.push(b'M');
599            payload.extend(b"password authentication failed\x00");
600            payload.push(0);
601            let _ = s.write_all(&pg_msg(b'E', &payload));
602        });
603        thread::sleep(Duration::from_millis(30));
604        port
605    }
606
607    /// Spawn a cleartext server that responds with error on query.
608    fn spawn_query_error_server() -> u16 {
609        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
610        let port = listener.local_addr().unwrap().port();
611        thread::spawn(move || {
612            let (mut s, _) = listener.accept().unwrap();
613            let mut buf = [0u8; 4096];
614            // auth
615            let _ = s.read(&mut buf).unwrap();
616            let _ = s.write_all(&pg_auth(3, &[]));
617            let _ = s.read(&mut buf).unwrap();
618            let _ = s.write_all(&post_auth_ok());
619            // query → error
620            loop {
621                match s.read(&mut buf) {
622                    Ok(0) | Err(_) => break,
623                    Ok(_) => {
624                        let _ = s.write_all(&error_response());
625                    }
626                }
627            }
628        });
629        thread::sleep(Duration::from_millis(30));
630        port
631    }
632
633    /// Spawn a cleartext server that responds with execute response.
634    fn spawn_execute_server() -> u16 {
635        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
636        let port = listener.local_addr().unwrap().port();
637        thread::spawn(move || {
638            let (mut s, _) = listener.accept().unwrap();
639            let mut buf = [0u8; 4096];
640            // auth
641            let _ = s.read(&mut buf).unwrap();
642            let _ = s.write_all(&pg_auth(3, &[]));
643            let _ = s.read(&mut buf).unwrap();
644            let _ = s.write_all(&post_auth_ok());
645            // queries
646            loop {
647                match s.read(&mut buf) {
648                    Ok(0) | Err(_) => break,
649                    Ok(_) => {
650                        let _ = s.write_all(&execute_response());
651                    }
652                }
653            }
654        });
655        thread::sleep(Duration::from_millis(30));
656        port
657    }
658
659    /// Spawn a cleartext server that responds with parameterized query response.
660    fn spawn_query_params_server() -> u16 {
661        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
662        let port = listener.local_addr().unwrap().port();
663        thread::spawn(move || {
664            let (mut s, _) = listener.accept().unwrap();
665            let mut buf = [0u8; 4096];
666            let _ = s.read(&mut buf).unwrap();
667            let _ = s.write_all(&pg_auth(3, &[]));
668            let _ = s.read(&mut buf).unwrap();
669            let _ = s.write_all(&post_auth_ok());
670            loop {
671                match s.read(&mut buf) {
672                    Ok(0) | Err(_) => break,
673                    Ok(_) => {
674                        let _ = s.write_all(&query_params_response());
675                    }
676                }
677            }
678        });
679        thread::sleep(Duration::from_millis(30));
680        port
681    }
682
683    fn spawn_params_server() -> u16 {
684        spawn_query_params_server()
685    }
686
687    /// Spawn a cleartext server that responds with parameterized execute response.
688    fn spawn_execute_params_server() -> u16 {
689        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
690        let port = listener.local_addr().unwrap().port();
691        thread::spawn(move || {
692            let (mut s, _) = listener.accept().unwrap();
693            let mut buf = [0u8; 4096];
694            let _ = s.read(&mut buf).unwrap();
695            let _ = s.write_all(&pg_auth(3, &[]));
696            let _ = s.read(&mut buf).unwrap();
697            let _ = s.write_all(&post_auth_ok());
698            loop {
699                match s.read(&mut buf) {
700                    Ok(0) | Err(_) => break,
701                    Ok(_) => {
702                        let _ = s.write_all(&execute_params_response());
703                    }
704                }
705            }
706        });
707        thread::sleep(Duration::from_millis(30));
708        port
709    }
710
711    /// Spawn a cleartext server that returns NULL in parameterized query result.
712    fn spawn_query_params_null_server() -> u16 {
713        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
714        let port = listener.local_addr().unwrap().port();
715        thread::spawn(move || {
716            let (mut s, _) = listener.accept().unwrap();
717            let mut buf = [0u8; 4096];
718            let _ = s.read(&mut buf).unwrap();
719            let _ = s.write_all(&pg_auth(3, &[]));
720            let _ = s.read(&mut buf).unwrap();
721            let _ = s.write_all(&post_auth_ok());
722            loop {
723                match s.read(&mut buf) {
724                    Ok(0) | Err(_) => break,
725                    Ok(_) => {
726                        let _ = s.write_all(&query_params_null_response());
727                    }
728                }
729            }
730        });
731        thread::sleep(Duration::from_millis(30));
732        port
733    }
734
735    // ── tests ──────────────────────────────────────────────────────────
736
737    #[test]
738    fn connect_cleartext_auth_success() {
739        let port = spawn_cleartext_server();
740        let conn = Connect::new(mock_config(port));
741        assert!(conn.is_ok());
742    }
743
744    #[test]
745    fn connect_md5_auth_success() {
746        let port = spawn_md5_server();
747        let conn = Connect::new(mock_config(port));
748        assert!(conn.is_ok());
749    }
750
751    #[test]
752    fn connect_scram_auth_success() {
753        let port = spawn_scram_server();
754        let conn = Connect::new(mock_config(port));
755        assert!(conn.is_ok());
756    }
757
758    #[test]
759    fn connect_connection_refused() {
760        // port 1 is almost certainly not listening
761        let cfg = mock_config(1);
762        let result = Connect::new(cfg);
763        assert!(result.is_err());
764        match result.unwrap_err() {
765            PgsqlError::Connection(_) => {}
766            other => panic!("expected Connection error, got {other:?}"),
767        }
768    }
769
770    #[test]
771    fn connect_server_closes_immediately() {
772        let port = spawn_eof_server();
773        let result = Connect::new(mock_config(port));
774        assert!(result.is_err());
775    }
776
777    #[test]
778    fn connect_auth_error_from_server() {
779        let port = spawn_auth_error_server();
780        let result = Connect::new(mock_config(port));
781        assert!(result.is_err());
782    }
783
784    #[test]
785    fn connect_query_success() {
786        let port = spawn_cleartext_server();
787        let mut conn = Connect::new(mock_config(port)).unwrap();
788        let result = conn.query("SELECT 1");
789        assert!(result.is_ok());
790        let msg = result.unwrap();
791        assert_eq!(msg.rows.len(), 1);
792        assert_eq!(msg.rows[0]["c"].as_i32(), Some(1));
793    }
794
795    #[test]
796    fn connect_execute_success() {
797        let port = spawn_execute_server();
798        let mut conn = Connect::new(mock_config(port)).unwrap();
799        let result = conn.execute("UPDATE t SET x=1");
800        assert!(result.is_ok());
801        let msg = result.unwrap();
802        assert_eq!(msg.affect_count, 3);
803        assert_eq!(msg.tag, "UPDATE 3");
804    }
805
806    #[test]
807    fn connect_query_params_success() {
808        let port = spawn_query_params_server();
809        let mut conn = Connect::new(mock_config(port)).unwrap();
810        let result = conn.query_params("SELECT $1::int", &[Some("42")]);
811        assert!(result.is_ok());
812        let msg = result.unwrap();
813        assert!(!msg.param_oids.is_empty());
814        assert_eq!(msg.rows.len(), 1);
815        assert_eq!(msg.rows[0]["p"].as_i32(), Some(42));
816    }
817
818    #[test]
819    fn connect_execute_params_success() {
820        let port = spawn_execute_params_server();
821        let mut conn = Connect::new(mock_config(port)).unwrap();
822        let result = conn.execute_params("UPDATE t SET x=$1", &[Some("42")]);
823        assert!(result.is_ok());
824        let msg = result.unwrap();
825        assert!(!msg.param_oids.is_empty());
826        assert_eq!(msg.affect_count, 1);
827        assert_eq!(msg.tag, "UPDATE 1");
828    }
829
830    #[test]
831    fn connect_query_str_success() {
832        let port = spawn_params_server();
833        let mut conn = Connect::new(mock_config(port)).unwrap();
834        let result = conn.query_str("SELECT $1::int", &["42"]);
835        assert!(result.is_ok());
836        let msg = result.unwrap();
837        assert!(!msg.param_oids.is_empty());
838        assert_eq!(msg.rows.len(), 1);
839    }
840
841    #[test]
842    fn connect_execute_str_success() {
843        let port = spawn_execute_params_server();
844        let mut conn = Connect::new(mock_config(port)).unwrap();
845        let result = conn.execute_str("UPDATE t SET x=$1", &["1"]);
846        assert!(result.is_ok());
847        let msg = result.unwrap();
848        assert!(!msg.param_oids.is_empty());
849        assert_eq!(msg.affect_count, 1);
850    }
851
852    #[test]
853    fn connect_query_params_with_null() {
854        let port = spawn_query_params_null_server();
855        let mut conn = Connect::new(mock_config(port)).unwrap();
856        let result = conn.query_params("SELECT $1::text", &[None]);
857        assert!(result.is_ok());
858        let msg = result.unwrap();
859        assert!(!msg.param_oids.is_empty());
860        assert_eq!(msg.rows.len(), 1);
861        assert!(msg.rows[0]["n"].is_null());
862    }
863
864    #[test]
865    fn connect_query_params_empty_string_vs_null() {
866        let port = spawn_params_server();
867        let mut conn = Connect::new(mock_config(port)).unwrap();
868
869        // 空字符串参数
870        let r1 = conn.query_params("SELECT $1::text", &[Some("")]);
871        assert!(r1.is_ok());
872
873        // NULL 参数
874        let r2 = conn.query_params("SELECT $1::text", &[None]);
875        assert!(r2.is_ok());
876    }
877
878    #[test]
879    fn connect_query_returns_error() {
880        let port = spawn_query_error_server();
881        let mut conn = Connect::new(mock_config(port)).unwrap();
882        let result = conn.query("BAD SQL");
883        assert!(result.is_err());
884    }
885
886    #[test]
887    fn connect_is_valid_true() {
888        let port = spawn_cleartext_server();
889        let mut conn = Connect::new(mock_config(port)).unwrap();
890        assert!(conn.is_valid());
891    }
892
893    #[test]
894    fn connect_is_valid_false_after_close() {
895        let port = spawn_cleartext_server();
896        let mut conn = Connect::new(mock_config(port)).unwrap();
897        conn._close();
898        // After closing, is_valid should return false
899        assert!(!conn.is_valid());
900    }
901
902    #[test]
903    fn connect_close_does_not_panic() {
904        let port = spawn_cleartext_server();
905        let mut conn = Connect::new(mock_config(port)).unwrap();
906        conn._close();
907        // calling close again should not panic
908        conn._close();
909    }
910
911    #[test]
912    fn connect_drop_does_not_panic() {
913        let port = spawn_cleartext_server();
914        let conn = Connect::new(mock_config(port)).unwrap();
915        drop(conn);
916    }
917
918    fn spawn_transaction_status_server() -> u16 {
919        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
920        let port = listener.local_addr().unwrap().port();
921        thread::spawn(move || {
922            let (mut s, _) = listener.accept().unwrap();
923            let mut buf = [0u8; 4096];
924            let _ = s.read(&mut buf).unwrap();
925            let _ = s.write_all(&pg_auth(3, &[]));
926            let _ = s.read(&mut buf).unwrap();
927            let _ = s.write_all(&post_auth_ok());
928            loop {
929                match s.read(&mut buf) {
930                    Ok(0) | Err(_) => break,
931                    Ok(_) => {
932                        let mut r = Vec::new();
933                        r.extend(pg_msg(b'1', &[]));
934                        r.extend(pg_msg(b'2', &[]));
935                        let mut rd = Vec::new();
936                        rd.extend(&1u16.to_be_bytes());
937                        rd.extend(b"c\x00");
938                        rd.extend(&0u32.to_be_bytes());
939                        rd.extend(&1u16.to_be_bytes());
940                        rd.extend(&23u32.to_be_bytes());
941                        rd.extend(&4i16.to_be_bytes());
942                        rd.extend(&(-1i32).to_be_bytes());
943                        rd.extend(&0u16.to_be_bytes());
944                        r.extend(pg_msg(b'T', &rd));
945                        let mut dr = Vec::new();
946                        dr.extend(&1u16.to_be_bytes());
947                        dr.extend(&1u32.to_be_bytes());
948                        dr.push(b'1');
949                        r.extend(pg_msg(b'D', &dr));
950                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
951                        r.extend(pg_msg(b'Z', b"T"));
952                        let _ = s.write_all(&r);
953                    }
954                }
955            }
956        });
957        thread::sleep(Duration::from_millis(30));
958        port
959    }
960
961    fn spawn_error_status_server() -> u16 {
962        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
963        let port = listener.local_addr().unwrap().port();
964        thread::spawn(move || {
965            let (mut s, _) = listener.accept().unwrap();
966            let mut buf = [0u8; 4096];
967            let _ = s.read(&mut buf).unwrap();
968            let _ = s.write_all(&pg_auth(3, &[]));
969            let _ = s.read(&mut buf).unwrap();
970            let _ = s.write_all(&post_auth_ok());
971            loop {
972                match s.read(&mut buf) {
973                    Ok(0) | Err(_) => break,
974                    Ok(_) => {
975                        let mut r = Vec::new();
976                        r.extend(pg_msg(b'1', &[]));
977                        r.extend(pg_msg(b'2', &[]));
978                        let mut rd = Vec::new();
979                        rd.extend(&1u16.to_be_bytes());
980                        rd.extend(b"c\x00");
981                        rd.extend(&0u32.to_be_bytes());
982                        rd.extend(&1u16.to_be_bytes());
983                        rd.extend(&23u32.to_be_bytes());
984                        rd.extend(&4i16.to_be_bytes());
985                        rd.extend(&(-1i32).to_be_bytes());
986                        rd.extend(&0u16.to_be_bytes());
987                        r.extend(pg_msg(b'T', &rd));
988                        let mut dr = Vec::new();
989                        dr.extend(&1u16.to_be_bytes());
990                        dr.extend(&1u32.to_be_bytes());
991                        dr.push(b'1');
992                        r.extend(pg_msg(b'D', &dr));
993                        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
994                        r.extend(pg_msg(b'Z', b"E"));
995                        let _ = s.write_all(&r);
996                    }
997                }
998            }
999        });
1000        thread::sleep(Duration::from_millis(30));
1001        port
1002    }
1003
1004    fn spawn_slow_partial_server() -> u16 {
1005        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1006        let port = listener.local_addr().unwrap().port();
1007        thread::spawn(move || {
1008            let (mut s, _) = listener.accept().unwrap();
1009            let mut buf = [0u8; 4096];
1010            let _ = s.read(&mut buf).unwrap();
1011            let _ = s.write_all(&pg_auth(3, &[]));
1012            let _ = s.read(&mut buf).unwrap();
1013            let _ = s.write_all(&post_auth_ok());
1014            match s.read(&mut buf) {
1015                Ok(0) | Err(_) => {}
1016                Ok(_) => {
1017                    let _ = s.write_all(&simple_query_response());
1018                }
1019            }
1020        });
1021        thread::sleep(Duration::from_millis(30));
1022        port
1023    }
1024
1025    fn spawn_rst_on_query_server() -> u16 {
1026        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1027        let port = listener.local_addr().unwrap().port();
1028        thread::spawn(move || {
1029            let (mut s, _) = listener.accept().unwrap();
1030            let mut buf = [0u8; 4096];
1031            let _ = s.read(&mut buf).unwrap();
1032            let _ = s.write_all(&pg_auth(3, &[]));
1033            let _ = s.read(&mut buf).unwrap();
1034            let _ = s.write_all(&post_auth_ok());
1035            match s.read(&mut buf) {
1036                Ok(0) | Err(_) => {}
1037                Ok(_) => {
1038                    drop(s);
1039                }
1040            }
1041        });
1042        thread::sleep(Duration::from_millis(30));
1043        port
1044    }
1045
1046    #[test]
1047    fn connect_query_ready_for_query_transaction_status() {
1048        let port = spawn_transaction_status_server();
1049        let mut conn = Connect::new(mock_config(port)).unwrap();
1050        let result = conn.query("SELECT 1");
1051        assert!(result.is_ok());
1052    }
1053
1054    #[test]
1055    fn connect_query_ready_for_query_error_status() {
1056        let port = spawn_error_status_server();
1057        let mut conn = Connect::new(mock_config(port)).unwrap();
1058        let result = conn.query("SELECT 1");
1059        assert!(result.is_ok());
1060    }
1061
1062    #[test]
1063    fn connect_query_server_closes_after_partial() {
1064        let port = spawn_slow_partial_server();
1065        let mut conn = Connect::new(mock_config(port)).unwrap();
1066        let r1 = conn.query("SELECT 1");
1067        assert!(r1.is_ok());
1068        let r2 = conn.query("SELECT 1");
1069        assert!(r2.is_err());
1070    }
1071
1072    #[test]
1073    fn connect_query_server_rst_returns_io_or_connection_error() {
1074        let port = spawn_rst_on_query_server();
1075        let mut conn = Connect::new(mock_config(port)).unwrap();
1076        let result = conn.query("SELECT 1");
1077        assert!(result.is_err());
1078    }
1079
1080    #[test]
1081    fn connect_read_would_block_max_retries() {
1082        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1083        let port = listener.local_addr().unwrap().port();
1084        thread::spawn(move || {
1085            let (mut s, _) = listener.accept().unwrap();
1086            let mut buf = [0u8; 4096];
1087            let _ = s.read(&mut buf);
1088            let _ = s.write_all(&pg_auth(3, &[]));
1089            let _ = s.read(&mut buf);
1090            let _ = s.write_all(&post_auth_ok());
1091            let _ = s.read(&mut buf);
1092            thread::sleep(Duration::from_secs(5));
1093        });
1094        thread::sleep(Duration::from_millis(30));
1095
1096        let mut conn = Connect::new(mock_config(port)).unwrap();
1097        conn.stream
1098            .set_read_timeout(Some(Duration::from_millis(1)))
1099            .ok();
1100        let result = conn.query("SELECT 1");
1101        assert!(result.is_err());
1102        let err_str = result.unwrap_err().to_string();
1103        assert!(
1104            err_str.contains("超时") || err_str.contains("Timeout") || err_str.contains("重试"),
1105            "expected timeout error, got: {err_str}"
1106        );
1107    }
1108
1109    #[test]
1110    fn connect_read_exceeds_max_message_size() {
1111        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1112        let port = listener.local_addr().unwrap().port();
1113        thread::spawn(move || {
1114            let (mut s, _) = listener.accept().unwrap();
1115            let mut buf = [0u8; 4096];
1116            let _ = s.read(&mut buf);
1117            let _ = s.write_all(&pg_auth(3, &[]));
1118            let _ = s.read(&mut buf);
1119            let _ = s.write_all(&post_auth_ok());
1120            let _ = s.read(&mut buf);
1121            let big = vec![b'X'; 256];
1122            let _ = s.write_all(&big);
1123            thread::sleep(Duration::from_secs(2));
1124        });
1125        thread::sleep(Duration::from_millis(30));
1126
1127        let mut conn = Connect::new(mock_config(port)).unwrap();
1128        let result = conn.query("SELECT 1");
1129        assert!(result.is_err());
1130        let err_str = result.unwrap_err().to_string();
1131        assert!(
1132            err_str.contains("最大") || err_str.contains("大小") || err_str.contains("size"),
1133            "expected max message size error, got: {err_str}"
1134        );
1135    }
1136
1137    #[test]
1138    fn connect_read_deadline_timeout() {
1139        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1140        let port = listener.local_addr().unwrap().port();
1141        thread::spawn(move || {
1142            let (mut s, _) = listener.accept().unwrap();
1143            let mut buf = [0u8; 4096];
1144            let _ = s.read(&mut buf);
1145            let _ = s.write_all(&pg_auth(3, &[]));
1146            let _ = s.read(&mut buf);
1147            let _ = s.write_all(&post_auth_ok());
1148            let _ = s.read(&mut buf);
1149            for _ in 0..200 {
1150                let _ = s.write_all(b"X");
1151                thread::sleep(Duration::from_millis(5));
1152            }
1153        });
1154        thread::sleep(Duration::from_millis(30));
1155
1156        let mut conn = Connect::new(mock_config(port)).unwrap();
1157        let result = conn.query("SELECT 1");
1158        assert!(result.is_err());
1159    }
1160
1161    #[test]
1162    fn connect_read_partial_auth_frame() {
1163        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1164        let port = listener.local_addr().unwrap().port();
1165        thread::spawn(move || {
1166            let (mut s, _) = listener.accept().unwrap();
1167            let mut buf = [0u8; 4096];
1168            let _ = s.read(&mut buf);
1169            let auth = pg_auth(3, &[]);
1170            let _ = s.write_all(&auth[..5]);
1171            thread::sleep(Duration::from_millis(50));
1172            let _ = s.write_all(&auth[5..]);
1173            let _ = s.read(&mut buf);
1174            let _ = s.write_all(&post_auth_ok());
1175            loop {
1176                match s.read(&mut buf) {
1177                    Ok(0) | Err(_) => break,
1178                    Ok(_) => {
1179                        let _ = s.write_all(&simple_query_response());
1180                    }
1181                }
1182            }
1183        });
1184        thread::sleep(Duration::from_millis(30));
1185
1186        let mut conn = Connect::new(mock_config(port)).unwrap();
1187        let result = conn.query("SELECT 1");
1188        assert!(result.is_ok());
1189    }
1190}