1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use std::io::{Read, Write};
4use std::net::TcpStream;
5use std::sync::Arc;
6use std::time::Duration;
7
8#[derive(Clone, Debug)]
9pub struct Connect {
10 stream: Arc<TcpStream>,
12 types: String,
13 packet: Packet,
14 auth_status: AuthStatus,
16}
17
18impl Connect {
19 pub fn new(mut config: Config) -> Result<Connect, String> {
20 let stream = match TcpStream::connect(config.url()) {
21 Ok(stream) => stream,
22 Err(e) => return Err(e.to_string()),
23 };
24 stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
25 stream.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
26
27 let mut connect = Self {
28 stream: Arc::new(stream),
29 types: String::new(),
30 packet: Packet::new(config),
31 auth_status: AuthStatus::None,
32 };
33
34 connect.startup_message()?;
35 connect.sasl_initial_response_message()?;
36
37 Ok(connect)
38 }
39
40 fn read(&mut self) -> Result<Vec<u8>, String> {
41 let mut msg = vec![];
42 loop {
43 let mut response = [0u8; 1024];
44 match self.stream.try_clone().unwrap().read(&mut response) {
45 Ok(e) => {
46 msg.extend(response[..e].to_vec());
47 }
48 Err(e) => return Err(format!("Error reading from stream: {e}")),
49 }
50 if msg.is_empty() {
51 continue;
52 }
53 if let AuthStatus::AuthenticationOk = self.auth_status {
54 if msg.ends_with(&[90, 0, 0, 0, 5, 73]) | msg.ends_with(&[90, 0, 0, 0, 5, 84]) | msg.ends_with(&[90, 0, 0, 0, 5, 69]) {
55 break;
56 }
57 continue;
58 } else {
59 let t = &msg[1..=4];
60 let len = u32::from_be_bytes(t.try_into().unwrap());
61 if msg.len() < (len as usize) {
62 continue;
63 }
64 break;
65 }
66 }
67 Ok(msg)
68 }
69 fn startup_message(&mut self) -> Result<(), String> {
71 self.stream.try_clone().unwrap().write_all(&self.packet.pack_first()).unwrap();
72 let data = self.read()?;
73 self.packet.unpack(data, 0)?;
74 Ok(())
75 }
76 fn sasl_initial_response_message(&mut self) -> Result<(), String> {
78 self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth()).unwrap();
79 let data = self.read()?;
80 self.packet.unpack(data, 0)?;
81 self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth_verify()).unwrap();
82 let data = self.read()?;
83 self.packet.unpack(data, 0)?;
84 self.auth_status = AuthStatus::AuthenticationOk;
85 Ok(())
86 }
87 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
89 self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
90 let data = self.read()?;
91 let mut packet = self.packet.clone();
92
93 std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
94 packet.unpack(data, 0)
95 }).unwrap().join().unwrap()
96 }
97 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
99 self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
100 let data = self.read()?;
101 let mut packet = self.packet.clone();
102
103 std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
104 packet.unpack(data, 0)
105 }).unwrap().join().unwrap()
106 }
107}
108
109impl Drop for Connect {
110 fn drop(&mut self) {
111 println!("Dropping connection :{}", self.types);
112 }
113}