// br_pgsql/connect.rs

use std::io::{ErrorKind, Read, Write};
use std::net::TcpStream;
use std::sync::Arc;
use std::time::Duration;

use crate::config::Config;
use crate::packet::{AuthStatus, Packet, SuccessMessage};

8#[derive(Clone, Debug)]
9pub struct Connect {
10    /// 基础配置
11    stream: Arc<TcpStream>,
12    types: String,
13    packet: Packet,
14    /// 认证状态
15    auth_status: AuthStatus,
16}
17
18impl Connect {
19    pub fn new(mut config: Config) -> Result<Connect, String> {
20        let stream = match TcpStream::connect(config.url()) {
21            Ok(stream) => stream,
22            Err(e) => return Err(e.to_string()),
23        };
24        stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
25        stream.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
26
27        let mut connect = Self {
28            stream: Arc::new(stream),
29            types: String::new(),
30            packet: Packet::new(config),
31            auth_status: AuthStatus::None,
32        };
33
34        connect.startup_message()?;
35        connect.sasl_initial_response_message()?;
36
37        Ok(connect)
38    }
39
40    fn read(&mut self) -> Result<Vec<u8>, String> {
41        let mut msg = vec![];
42        loop {
43            let mut response = [0u8; 1024];
44            match self.stream.try_clone().unwrap().read(&mut response) {
45                Ok(e) => {
46                    msg.extend(response[..e].to_vec());
47                }
48                Err(e) => return Err(format!("Error reading from stream: {e}")),
49            }
50            if msg.is_empty() {
51                continue;
52            }
53            if let AuthStatus::AuthenticationOk = self.auth_status {
54                if msg.ends_with(&[90, 0, 0, 0, 5, 73]) | msg.ends_with(&[90, 0, 0, 0, 5, 84]) | msg.ends_with(&[90, 0, 0, 0, 5, 69]) {
55                    break;
56                }
57                continue;
58            } else {
59                let t = &msg[1..=4];
60                let len = u32::from_be_bytes(t.try_into().unwrap());
61                if msg.len() < (len as usize) {
62                    continue;
63                }
64                break;
65            }
66        }
67        Ok(msg)
68    }
69    /// Startup Message
70    fn startup_message(&mut self) -> Result<(), String> {
71        self.stream.try_clone().unwrap().write_all(&self.packet.pack_first()).unwrap();
72        let data = self.read()?;
73        self.packet.unpack(data, 0)?;
74        Ok(())
75    }
76    /// `SASLInitialResponse` message
77    fn sasl_initial_response_message(&mut self) -> Result<(), String> {
78        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth()).unwrap();
79        let data = self.read()?;
80        self.packet.unpack(data, 0)?;
81        self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth_verify()).unwrap();
82        let data = self.read()?;
83        self.packet.unpack(data, 0)?;
84        self.auth_status = AuthStatus::AuthenticationOk;
85        Ok(())
86    }
87    /// 查询
88    pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
89        self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
90        let data = self.read()?;
91        let mut packet = self.packet.clone();
92
93        std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
94            packet.unpack(data, 0)
95        }).unwrap().join().unwrap()
96    }
97    /// 执行
98    pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
99        self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
100        let data = self.read()?;
101        let mut packet = self.packet.clone();
102
103        std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
104            packet.unpack(data, 0)
105        }).unwrap().join().unwrap()
106    }
107}
108
109impl Drop for Connect {
110    fn drop(&mut self) {
111        println!("Dropping connection :{}", self.types);
112    }
113}