br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use std::io::{Read, Write};
4use std::net::TcpStream;
5use std::sync::Arc;
6use std::time::Duration;
7
8#[derive(Clone, Debug)]
9pub struct Connect {
10    /// 基础配置
11    pub config: Config,
12    pub stream: Arc<TcpStream>,
13    pub types: String,
14    pub packet: Packet,
15    /// 认证状态
16    auth_status: AuthStatus,
17}
18
19impl Connect {
20    pub fn new(mut config: Config) -> Result<Connect, String> {
21        let stream = match TcpStream::connect(config.url()) {
22            Ok(stream) => stream,
23            Err(e) => return Err(e.to_string()),
24        };
25        stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
26        stream.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
27
28        let mut connect = Self {
29            config: config.clone(),
30            stream: Arc::new(stream),
31            types: "".to_string(),
32            packet: Packet::new(config),
33            auth_status: AuthStatus::None,
34        };
35
36        connect.startup_message()?;
37        connect.sasl_initial_response_message()?;
38
39        //connect.auth=true;
40        Ok(connect)
41    }
42
43    fn read(&mut self) -> Result<Vec<u8>, String> {
44        let mut msg = vec![];
45        loop {
46            let mut response = [0u8; 1024];
47            match self.stream.try_clone().unwrap().read(&mut response) {
48                Ok(e) => {
49                    msg.extend(response[..e].to_vec());
50                }
51                Err(e) => return Err(format!("Error reading from stream: {}", e)),
52            };
53            if msg.is_empty() {
54                continue;
55            }
56            match self.auth_status {
57                AuthStatus::AuthenticationOk => {
58                    if msg.ends_with(&[90, 0, 0, 0, 5, 73]) | msg.ends_with(&[90, 0, 0, 0, 5, 84]) | msg.ends_with(&[90, 0, 0, 0, 5, 69]) {
59                        break;
60                    }
61                    continue;
62                }
63                _ => {
64                    let t = &msg[1..=4];
65                    let len = u32::from_be_bytes(t.try_into().unwrap());
66                    if msg.len() < (len as usize) {
67                        continue;
68                    }
69                    break;
70                }
71            }
72        }
73        Ok(msg)
74    }
75    /// Startup Message
76    fn startup_message(&mut self) -> Result<(), String> {
77        self.stream.try_clone().unwrap().write_all(&self.packet.pack_first()).unwrap();
78        let data = self.read()?;
79        self.packet.unpack(data)?;
80        Ok(())
81    }
82    /// SASLInitialResponse message
83    fn sasl_initial_response_message(&mut self) -> Result<(), String> {
84        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth()).unwrap();
85        let data = self.read()?;
86        self.packet.unpack(data)?;
87        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth_verify()).unwrap();
88        let data = self.read()?;
89        self.packet.unpack(data)?;
90        self.auth_status = AuthStatus::AuthenticationOk;
91        Ok(())
92    }
93    /// 查询
94    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
95        self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
96        let data = self.read()?;
97        let res = self.packet.unpack(data)?;
98        Ok(res)
99    }
100    /// 执行
101    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
102        self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
103        let data = self.read()?;
104        let res = self.packet.unpack(data)?;
105        Ok(res)
106    }
107}
108
109
110impl Drop for Connect {
111    fn drop(&mut self) {
112        println!("Dropping connection {:?} {}", self.auth_status,self.packet.client_nonce);
113    }
114}