1use crate::config::Config;
2use crate::packet::{AuthStatus, Packet, SuccessMessage};
3use std::io::{Read, Write};
4use std::net::TcpStream;
5use std::sync::Arc;
6use std::time::Duration;
7
8#[derive(Clone, Debug)]
9pub struct Connect {
10 pub config: Config,
12 pub stream: Arc<TcpStream>,
13 pub types: String,
14 pub packet: Packet,
15 auth_status: AuthStatus,
17}
18
19impl Connect {
20 pub fn new(mut config: Config) -> Result<Connect, String> {
21 let stream = match TcpStream::connect(config.url()) {
22 Ok(stream) => stream,
23 Err(e) => return Err(e.to_string()),
24 };
25 stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
26 stream.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
27
28 let mut connect = Self {
29 config: config.clone(),
30 stream: Arc::new(stream),
31 types: "".to_string(),
32 packet: Packet::new(config),
33 auth_status: AuthStatus::None,
34 };
35
36 connect.startup_message()?;
37 connect.sasl_initial_response_message()?;
38
39 Ok(connect)
41 }
42
43 fn read(&mut self) -> Result<Vec<u8>, String> {
44 let mut msg = vec![];
45 loop {
46 let mut response = [0u8; 1024];
47 match self.stream.try_clone().unwrap().read(&mut response) {
48 Ok(e) => {
49 msg.extend(response[..e].to_vec());
50 }
51 Err(e) => return Err(format!("Error reading from stream: {}", e)),
52 };
53 if msg.is_empty() {
54 continue;
55 }
56 match self.auth_status {
57 AuthStatus::AuthenticationOk => {
58 if msg.ends_with(&[90, 0, 0, 0, 5, 73]) | msg.ends_with(&[90, 0, 0, 0, 5, 84]) | msg.ends_with(&[90, 0, 0, 0, 5, 69]) {
59 break;
60 }
61 continue;
62 }
63 _ => {
64 let t = &msg[1..=4];
65 let len = u32::from_be_bytes(t.try_into().unwrap());
66 if msg.len() < (len as usize) {
67 continue;
68 }
69 break;
70 }
71 }
72 }
73 Ok(msg)
74 }
75 fn startup_message(&mut self) -> Result<(), String> {
77 self.stream.try_clone().unwrap().write_all(&self.packet.pack_first()).unwrap();
78 let data = self.read()?;
79 self.packet.unpack(data, 0)?;
80 Ok(())
81 }
82 fn sasl_initial_response_message(&mut self) -> Result<(), String> {
84 self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth()).unwrap();
85 let data = self.read()?;
86 self.packet.unpack(data, 0)?;
87 self.stream.try_clone().unwrap().write_all(&self.packet.pack_auth_verify()).unwrap();
88 let data = self.read()?;
89 self.packet.unpack(data, 0)?;
90 self.auth_status = AuthStatus::AuthenticationOk;
91 Ok(())
92 }
93 pub fn query(&mut self, sql: &str) -> Result<SuccessMessage, String> {
95 self.stream.try_clone().unwrap().write_all(&self.packet.pack_query(sql)).unwrap();
96 let data = self.read()?;
97 let mut packet = self.packet.clone();
98
99 std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
100 packet.unpack(data, 0)
101 }).unwrap().join().unwrap()
102 }
103 pub fn execute(&mut self, sql: &str) -> Result<SuccessMessage, String> {
105 self.stream.try_clone().unwrap().write_all(&self.packet.pack_execute(sql)).unwrap();
106 let data = self.read()?;
107 let mut packet = self.packet.clone();
108
109 std::thread::Builder::new().stack_size(8 * 1024 * 1024).spawn(move || -> Result<SuccessMessage, String> {
110 packet.unpack(data, 0)
111 }).unwrap().join().unwrap()
112 }
113}