br_pgsql/
connect.rs

1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use std::io::{Read, Write};
4use std::net::TcpStream;
5use std::sync::Arc;
6use std::time::Duration;
7
8#[derive(Clone, Debug)]
9pub struct Connect {
10    /// 基础配置
11    pub config: Config,
12    pub stream: Arc<TcpStream>,
13    pub types: String,
14    pub packet: Packet,
15    /// 认证状态
16    auth_status: AuthStatus,
17}
18
19impl Connect {
20    pub fn new(mut config: Config) -> Result<Connect, String> {
21        let stream = match TcpStream::connect(config.url()) {
22            Ok(stream) => stream,
23            Err(e) => return Err(e.to_string()),
24        };
25        stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
26        stream.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
27
28        let mut connect = Self {
29            config: config.clone(),
30            stream: Arc::new(stream),
31            types: "".to_string(),
32            packet: Packet::new(config),
33            auth_status: AuthStatus::None,
34        };
35
36        connect.startup_message()?;
37        connect.sasl_initial_response_message()?;
38
39        //connect.auth=true;
40        Ok(connect)
41    }
42
43    fn read(&mut self) -> Result<Vec<u8>, String> {
44        let mut msg = vec![];
45        loop {
46            let mut response = [0u8; 1024];
47            match self.stream.try_clone().unwrap().read(&mut response) {
48                Ok(e) => {
49                    msg.extend(response[..e].to_vec());
50                }
51                Err(e) => return Err(format!("Error reading from stream: {}", e)),
52            };
53            if msg.is_empty() {
54                continue;
55            }
56
57            match self.auth_status {
58                AuthStatus::AuthenticationOk => {
59                    if msg.ends_with(&[90, 0, 0, 0, 5, 73]) {
60                        break;
61                    }
62                    continue;
63                }
64                _ => {
65                    let t = &msg[1..=4];
66                    let len = u32::from_be_bytes(t.try_into().unwrap());
67                    if msg.len() < (len as usize) {
68                        continue;
69                    }
70                    break;
71                }
72            }
73        }
74        Ok(msg)
75    }
76    /// Startup Message
77    fn startup_message(&mut self) -> Result<(), String> {
78        self.stream.try_clone().unwrap().write_all(&self.packet.pack_first()).unwrap();
79        self.packet.unpack(self.clone().read()?)?;
80        Ok(())
81    }
82    /// SASLInitialResponse message
83    fn sasl_initial_response_message(&mut self) -> Result<(), String> {
84        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth()).unwrap();
85        self.packet.unpack(self.clone().read()?)?;
86        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth_verify()).unwrap();
87        self.packet.unpack(self.clone().read()?)?;
88        self.auth_status = AuthStatus::AuthenticationOk;
89        Ok(())
90    }
91    /// 查询
92    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
93        self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
94        let res = self.packet.unpack(self.clone().read()?)?;
95        Ok(res)
96    }
97    /// 执行
98    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
99        self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
100        let res = self.packet.unpack(self.clone().read()?)?;
101        Ok(res)
102    }
103}