br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use std::io::{Read, Write};
4use std::net::TcpStream;
5use std::sync::Arc;
6use std::time::Duration;
7
8#[derive(Clone, Debug)]
9pub struct Connect {
10    /// 基础配置
11    stream: Arc<TcpStream>,
12    packet: Packet,
13    /// 认证状态
14    auth_status: AuthStatus,
15}
16
17impl Connect {
18    pub fn new(mut config: Config) -> Result<Connect, String> {
19        let stream = match TcpStream::connect(config.url()) {
20            Ok(stream) => stream,
21            Err(e) => return Err(e.to_string()),
22        };
23        stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
24        stream.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
25
26        let mut connect = Self {
27            stream: Arc::new(stream),
28            packet: Packet::new(config),
29            auth_status: AuthStatus::None,
30        };
31
32        connect.startup_message()?;
33        connect.sasl_initial_response_message()?;
34
35        Ok(connect)
36    }
37
38    fn read(&mut self) -> Result<Vec<u8>, String> {
39        let mut msg = vec![];
40        loop {
41            let mut response = [0u8; 1024];
42            match self.stream.try_clone().unwrap().read(&mut response) {
43                Ok(e) => {
44                    msg.extend(response[..e].to_vec());
45                }
46                Err(e) => return Err(format!("Error reading from stream: {e}")),
47            }
48            if msg.is_empty() {
49                continue;
50            }
51            if let AuthStatus::AuthenticationOk = self.auth_status {
52                if msg.ends_with(&[90, 0, 0, 0, 5, 73]) | msg.ends_with(&[90, 0, 0, 0, 5, 84]) | msg.ends_with(&[90, 0, 0, 0, 5, 69]) {
53                    break;
54                }
55                continue;
56            } else {
57                let t = &msg[1..=4];
58                let len = u32::from_be_bytes(t.try_into().unwrap());
59                if msg.len() < (len as usize) {
60                    continue;
61                }
62                break;
63            }
64        }
65        Ok(msg)
66    }
67    /// Startup Message
68    fn startup_message(&mut self) -> Result<(), String> {
69        self.stream.try_clone().unwrap().write_all(&self.packet.pack_first()).unwrap();
70        let data = self.read()?;
71        self.packet.unpack(data, 0)?;
72        Ok(())
73    }
74    /// `SASLInitialResponse` message
75    fn sasl_initial_response_message(&mut self) -> Result<(), String> {
76        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth()).unwrap();
77        let data = self.read()?;
78        self.packet.unpack(data, 0)?;
79        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth_verify()).unwrap();
80        let data = self.read()?;
81        self.packet.unpack(data, 0)?;
82        self.auth_status = AuthStatus::AuthenticationOk;
83        Ok(())
84    }
85    /// 查询
86    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
87        self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
88        let data = self.read()?;
89        let mut packet = self.packet.clone();
90
91        std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
92            packet.unpack(data, 0)
93        }).unwrap().join().unwrap()
94    }
95    /// 执行
96    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
97        self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
98        let data = self.read()?;
99        let mut packet = self.packet.clone();
100
101        std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
102            packet.unpack(data, 0)
103        }).unwrap().join().unwrap()
104    }
105}