1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use std::io::{Read, Write};
4use std::net::TcpStream;
5use std::sync::Arc;
6use std::time::Duration;
7
8#[derive(Clone, Debug)]
9pub struct Connect {
10 stream: Arc<TcpStream>,
12 packet: Packet,
13 auth_status: AuthStatus,
15}
16
17impl Connect {
18 pub fn new(mut config: Config) -> Result<Connect, String> {
19 let stream = match TcpStream::connect(config.url()) {
20 Ok(stream) => stream,
21 Err(e) => return Err(e.to_string()),
22 };
23 stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
24 stream.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
25
26 let mut connect = Self {
27 stream: Arc::new(stream),
28 packet: Packet::new(config),
29 auth_status: AuthStatus::None,
30 };
31
32 connect.startup_message()?;
33 connect.sasl_initial_response_message()?;
34
35 Ok(connect)
36 }
37
38 fn read(&mut self) -> Result<Vec<u8>, String> {
39 let mut msg = vec![];
40 loop {
41 let mut response = [0u8; 1024];
42 match self.stream.try_clone().unwrap().read(&mut response) {
43 Ok(e) => {
44 msg.extend(response[..e].to_vec());
45 }
46 Err(e) => return Err(format!("Error reading from stream: {e}")),
47 }
48 if msg.is_empty() {
49 continue;
50 }
51 if let AuthStatus::AuthenticationOk = self.auth_status {
52 if msg.ends_with(&[90, 0, 0, 0, 5, 73]) | msg.ends_with(&[90, 0, 0, 0, 5, 84]) | msg.ends_with(&[90, 0, 0, 0, 5, 69]) {
53 break;
54 }
55 continue;
56 } else {
57 let t = &msg[1..=4];
58 let len = u32::from_be_bytes(t.try_into().unwrap());
59 if msg.len() < (len as usize) {
60 continue;
61 }
62 break;
63 }
64 }
65 Ok(msg)
66 }
67 fn startup_message(&mut self) -> Result<(), String> {
69 self.stream.try_clone().unwrap().write_all(&self.packet.pack_first()).unwrap();
70 let data = self.read()?;
71 self.packet.unpack(data, 0)?;
72 Ok(())
73 }
74 fn sasl_initial_response_message(&mut self) -> Result<(), String> {
76 self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth()).unwrap();
77 let data = self.read()?;
78 self.packet.unpack(data, 0)?;
79 self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth_verify()).unwrap();
80 let data = self.read()?;
81 self.packet.unpack(data, 0)?;
82 self.auth_status = AuthStatus::AuthenticationOk;
83 Ok(())
84 }
85 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
87 self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
88 let data = self.read()?;
89 let mut packet = self.packet.clone();
90
91 std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
92 packet.unpack(data, 0)
93 }).unwrap().join().unwrap()
94 }
95 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
97 self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
98 let data = self.read()?;
99 let mut packet = self.packet.clone();
100
101 std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
102 packet.unpack(data, 0)
103 }).unwrap().join().unwrap()
104 }
105}