Skip to main content

br_pgsql/
lib.rs

1use crate::config::Config;
2use crate::connect::Connect;
3use crate::pools::Pools;
4use json::JsonValue;
5
6mod config;
7pub mod connect;
8mod error;
9mod format;
10mod packet;
11pub mod pools;
12
13pub use error::PgsqlError;
14
15#[derive(Clone, Debug)]
16pub struct Pgsql {
17    pub config: Config,
18}
19
20impl Pgsql {
21    pub fn new(config: &JsonValue) -> Result<Self, PgsqlError> {
22        Ok(Self {
23            config: Config::new(config),
24        })
25    }
26
27    pub fn connect(&mut self) -> Result<Connect, PgsqlError> {
28        Connect::new(self.config.clone())
29    }
30
31    pub fn pools(&mut self) -> Result<Pools, PgsqlError> {
32        let size = self.config.pool_max.max(1) as usize;
33        Pools::new(self.config.clone(), size)
34    }
35}
36
#[cfg(test)]
mod tests {
    use super::Pgsql;
    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::sync::{Mutex, MutexGuard};
    use std::thread;
    use std::time::{Duration, Instant};

    /// Serialises every test that resets or inspects the process-wide `DB_POOL`.
    ///
    /// The default test harness runs tests on parallel threads. Without this lock,
    /// one test's `DB_POOL.clear()` can run in the middle of another test's pool
    /// creation. NOTE(review): if `Pools::total_connections` reads the shared pool,
    /// `pgsql_pools_failure` could also observe connections opened by
    /// `pgsql_pools_success` — confirm against `pools.rs`.
    static DB_POOL_LOCK: Mutex<()> = Mutex::new(());

    /// Empties the shared `DB_POOL`, recovering the inner value if an earlier test
    /// panicked while holding its lock.
    fn clear_db_pool() {
        crate::pools::DB_POOL
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .clear();
    }

    /// Guard held for the duration of a pool test: owns `DB_POOL_LOCK` and clears
    /// `DB_POOL` both when acquired and when dropped, so cleanup also happens when
    /// an assertion fails and the test unwinds.
    struct DbPoolGuard {
        _serial: MutexGuard<'static, ()>,
    }

    impl DbPoolGuard {
        fn acquire() -> Self {
            // A poisoned lock only means an earlier test failed; the `()` payload
            // carries no state, so continuing is safe.
            let serial = DB_POOL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
            clear_db_pool();
            Self { _serial: serial }
        }
    }

    impl Drop for DbPoolGuard {
        fn drop(&mut self) {
            // `drop` runs before the `_serial` field is released, so this reset
            // still happens while `DB_POOL_LOCK` is held.
            clear_db_pool();
        }
    }

    /// Frames `payload` as a backend message: a 1-byte `tag`, then a big-endian
    /// `u32` length that counts itself (4 bytes) plus the payload, then the payload.
    fn pg_msg(tag: u8, payload: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(5 + payload.len());
        m.push(tag);
        m.extend_from_slice(&(payload.len() as u32 + 4).to_be_bytes());
        m.extend_from_slice(payload);
        m
    }

    /// Builds an `R` (Authentication) message: big-endian `auth_type` code
    /// (0 = AuthenticationOk, 3 = AuthenticationCleartextPassword) followed by `extra`.
    fn pg_auth(auth_type: u32, extra: &[u8]) -> Vec<u8> {
        let mut body = Vec::with_capacity(4 + extra.len());
        body.extend_from_slice(&auth_type.to_be_bytes());
        body.extend_from_slice(extra);
        pg_msg(b'R', &body)
    }

    /// Everything the mock server sends once the password has been accepted:
    /// AuthenticationOk, ParameterStatus `server_version=15.0`, BackendKeyData
    /// (process id 1, secret key 2) and ReadyForQuery in the idle (`I`) state.
    fn post_auth_ok() -> Vec<u8> {
        let mut v = Vec::new();
        v.extend(pg_auth(0, &[]));
        v.extend(pg_msg(b'S', b"server_version\x0015.0\x00"));
        let mut key_data = Vec::with_capacity(8);
        key_data.extend_from_slice(&1u32.to_be_bytes());
        key_data.extend_from_slice(&2u32.to_be_bytes());
        v.extend(pg_msg(b'K', &key_data));
        v.extend(pg_msg(b'Z', b"I"));
        v
    }

    /// Canned reply to one query round-trip returning a single `int4` column `c`
    /// with value `1`. Despite the name, it starts with ParseComplete (`1`) and
    /// BindComplete (`2`), i.e. the shape of an extended-protocol exchange.
    fn simple_query_response() -> Vec<u8> {
        let mut r = Vec::new();
        r.extend(pg_msg(b'1', &[]));
        r.extend(pg_msg(b'2', &[]));

        // RowDescription: one field.
        let mut rd = Vec::new();
        rd.extend_from_slice(&1u16.to_be_bytes()); // field count
        rd.extend_from_slice(b"c\x00"); // field name
        rd.extend_from_slice(&0u32.to_be_bytes()); // table OID
        rd.extend_from_slice(&1u16.to_be_bytes()); // column attribute number
        rd.extend_from_slice(&23u32.to_be_bytes()); // type OID 23 = int4
        rd.extend_from_slice(&4i16.to_be_bytes()); // type size
        rd.extend_from_slice(&(-1i32).to_be_bytes()); // type modifier
        rd.extend_from_slice(&0u16.to_be_bytes()); // format code: text
        r.extend(pg_msg(b'T', &rd));

        // DataRow: one column, 1 byte long, text "1".
        let mut dr = Vec::new();
        dr.extend_from_slice(&1u16.to_be_bytes());
        dr.extend_from_slice(&1u32.to_be_bytes());
        dr.push(b'1');
        r.extend(pg_msg(b'D', &dr));

        r.extend(pg_msg(b'C', b"SELECT 1\x00"));
        r.extend(pg_msg(b'Z', b"I"));
        r
    }

    /// Drives one mock client session for `spawn_cleartext_server`.
    ///
    /// Assumes each client message arrives in a single `read`, which is expected to
    /// hold for the small messages exchanged over loopback in these tests.
    fn serve_cleartext_session(mut stream: TcpStream) {
        stream.set_read_timeout(Some(Duration::from_secs(5))).ok();
        let mut buf = [0u8; 4096];

        // Startup message -> ask for a cleartext password.
        if stream.read(&mut buf).unwrap_or(0) == 0 {
            return;
        }
        let _ = stream.write_all(&pg_auth(3, &[]));

        // Password message -> accept it unconditionally.
        if stream.read(&mut buf).unwrap_or(0) == 0 {
            return;
        }
        let _ = stream.write_all(&post_auth_ok());

        // Every further read is answered with one canned query result until the
        // client disconnects or the read times out.
        while let Ok(n) = stream.read(&mut buf) {
            if n == 0 {
                break;
            }
            let _ = stream.write_all(&simple_query_response());
        }
    }

    /// Starts a throw-away mock PostgreSQL server on an ephemeral loopback port and
    /// returns that port. Each accepted connection is handled by
    /// `serve_cleartext_session` on its own thread; the accept loop gives up after
    /// 5 seconds so the listener thread does not linger.
    fn spawn_cleartext_server() -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            listener.set_nonblocking(true).unwrap();
            let deadline = Instant::now() + Duration::from_secs(5);
            while Instant::now() < deadline {
                match listener.accept() {
                    Ok((stream, _)) => {
                        // Accepted sockets may inherit non-blocking mode on some
                        // platforms; the session code expects blocking reads.
                        stream.set_nonblocking(false).ok();
                        thread::spawn(move || serve_cleartext_session(stream));
                    }
                    Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                        thread::sleep(Duration::from_millis(5));
                    }
                    Err(_) => break,
                }
            }
        });
        // No start-up delay is needed: the socket is already bound and listening,
        // so the kernel queues incoming connections until the loop accepts them.
        port
    }

    #[test]
    fn new_with_empty_config_uses_defaults() {
        let config = json::object! {};

        let pgsql = Pgsql::new(&config);
        assert!(pgsql.is_ok());

        let pgsql = pgsql.unwrap();
        assert!(!pgsql.config.debug);
        assert_eq!(pgsql.config.hostname, "localhost");
        assert_eq!(pgsql.config.hostport, 5432);
        assert_eq!(pgsql.config.username, "postgres");
        assert_eq!(pgsql.config.userpass, "111111");
        assert_eq!(pgsql.config.database, "postgres");
        assert_eq!(pgsql.config.charset, "utf8mb4");
        assert_eq!(pgsql.config.pool_max, 5);
    }

    #[test]
    fn new_with_full_config_uses_provided_values() {
        let config = json::object! {
            debug: true,
            hostname: "db.example.com",
            hostport: 15432,
            username: "admin",
            userpass: "secret",
            database: "app_db",
            charset: "utf8",
            pool_max: 30,
        };

        let pgsql = Pgsql::new(&config);
        assert!(pgsql.is_ok());

        let pgsql = pgsql.unwrap();
        assert!(pgsql.config.debug);
        assert_eq!(pgsql.config.hostname, "db.example.com");
        assert_eq!(pgsql.config.hostport, 15432);
        assert_eq!(pgsql.config.username, "admin");
        assert_eq!(pgsql.config.userpass, "secret");
        assert_eq!(pgsql.config.database, "app_db");
        assert_eq!(pgsql.config.charset, "utf8");
        assert_eq!(pgsql.config.pool_max, 30);
    }

    #[test]
    fn new_with_partial_config_mixes_values_and_defaults() {
        let config = json::object! {
            hostname: "127.0.0.1",
            username: "reader",
        };

        let pgsql = Pgsql::new(&config);
        assert!(pgsql.is_ok());

        let pgsql = pgsql.unwrap();
        assert!(!pgsql.config.debug);
        assert_eq!(pgsql.config.hostname, "127.0.0.1");
        assert_eq!(pgsql.config.hostport, 5432);
        assert_eq!(pgsql.config.username, "reader");
        assert_eq!(pgsql.config.userpass, "111111");
        assert_eq!(pgsql.config.database, "postgres");
        assert_eq!(pgsql.config.charset, "utf8mb4");
        assert_eq!(pgsql.config.pool_max, 5);
    }

    #[test]
    fn pgsql_connect_success() {
        let port = spawn_cleartext_server();
        let config = json::object! {
            hostname: "127.0.0.1",
            hostport: port as u32,
            username: "u",
            userpass: "p",
            database: "d",
        };
        let mut pgsql = Pgsql::new(&config).unwrap();
        let conn = pgsql.connect();
        assert!(conn.is_ok());
    }

    #[test]
    fn pgsql_connect_failure() {
        // Assumes nothing is listening on TCP port 1 of the loopback interface.
        let config = json::object! {
            hostname: "127.0.0.1",
            hostport: 1,
            username: "u",
            userpass: "p",
            database: "d",
        };
        let mut pgsql = Pgsql::new(&config).unwrap();
        let conn = pgsql.connect();
        assert!(conn.is_err());
    }

    #[test]
    fn pgsql_pools_success() {
        // Must be a named binding: `let _ = ...` would drop the guard immediately.
        let _pool_guard = DbPoolGuard::acquire();
        let port = spawn_cleartext_server();
        let config = json::object! {
            hostname: "127.0.0.1",
            hostport: port as u32,
            username: "u",
            userpass: "p",
            database: "d",
            pool_max: 2,
        };
        let mut pgsql = Pgsql::new(&config).unwrap();
        let pools = pgsql.pools();
        assert!(pools.is_ok());
    }

    #[test]
    fn pgsql_pools_failure() {
        let _pool_guard = DbPoolGuard::acquire();
        let config = json::object! {
            hostname: "127.0.0.1",
            hostport: 1,
            username: "u",
            userpass: "p",
            database: "d",
            pool_max: 2,
        };
        let mut pgsql = Pgsql::new(&config).unwrap();
        let pools = pgsql.pools();
        // An unreachable server does not fail pool creation; the pool just ends up
        // with no live connections.
        assert!(pools.is_ok());
        assert_eq!(pools.unwrap().total_connections(), 0);
    }
}