cueball_postgres_connection/
lib.rs1use std::ops::{Deref, DerefMut};
6
7use native_tls::Certificate as NativeCertificate;
8use native_tls::Error as NativeError;
9use native_tls::TlsConnector;
10use postgres;
11use postgres::{Client, NoTls};
12use postgres_native_tls::MakeTlsConnector;
13use serde_derive::Deserialize;
14
15use cueball::backend::Backend;
16use cueball::connection::Connection;
17
18pub struct PostgresConnection {
19 pub connection: Option<Client>,
20 url: String,
21 tls_config: TlsConfig,
22 connected: bool,
23}
24
25impl PostgresConnection {
26 pub fn connection_creator<'a>(
27 mut config: PostgresConnectionConfig,
28 ) -> impl FnMut(&Backend) -> PostgresConnection + 'a {
29 move |b| {
30 config.host = Some(b.address.to_string());
31 config.port = Some(b.port);
32
33 let url = config.to_owned().into();
34
35 PostgresConnection {
36 connection: None,
37 url,
38 tls_config: config.tls_config.clone(),
39 connected: false,
40 }
41 }
42 }
43}
44
45impl Connection for PostgresConnection {
46 type Error = postgres::Error;
47
48 fn connect(&mut self) -> Result<(), Self::Error> {
49 let connection =
50 if let Some(tls_connector) = make_tls_connector(&self.tls_config) {
51 Client::connect(&self.url, tls_connector)?
52 } else {
53 Client::connect(&self.url, NoTls)?
54 };
55 self.connection = Some(connection);
56 self.connected = true;
57 Ok(())
58 }
59
60 fn is_valid(&mut self) -> bool {
61 self.connection
62 .as_mut()
63 .unwrap()
64 .simple_query("")
65 .map(|_| ())
66 .is_ok()
67 }
68
69 fn has_broken(&self) -> bool {
70 match &self.connection {
71 Some(conn) => conn.is_closed(),
72 None => false,
73 }
74 }
75
76 fn close(&mut self) -> Result<(), Self::Error> {
77 self.connection = None;
78 self.connected = false;
79 Ok(())
80 }
81}
82
83impl Deref for PostgresConnection {
84 type Target = Client;
85
86 fn deref(&self) -> &Client {
87 &self.connection.as_ref().unwrap()
88 }
89}
90
91impl DerefMut for PostgresConnection {
92 fn deref_mut(&mut self) -> &mut Client {
93 self.connection.as_mut().unwrap()
94 }
95}
96
97#[derive(Clone)]
98pub struct PostgresConnectionConfig {
99 pub user: Option<String>,
100 pub password: Option<String>,
101 pub host: Option<String>,
102 pub port: Option<u16>,
103 pub database: Option<String>,
104 pub application_name: Option<String>,
105 pub tls_config: TlsConfig,
106}
107
108impl From<PostgresConnectionConfig> for String {
109 fn from(config: PostgresConnectionConfig) -> Self {
110 let scheme = "postgresql://";
111 let user = config.user.unwrap_or_else(|| "".into());
112
113 let at = if user.is_empty() { "" } else { "@" };
114
115 let host = config.host.unwrap_or_else(|| String::from("localhost"));
116 let port = config
117 .port
118 .map(|p| p.to_string())
119 .unwrap_or_else(|| "".to_string());
120
121 let colon = if port.is_empty() { "" } else { ":" };
122
123 let database = config.database.unwrap_or_else(|| "".into());
124
125 let slash = if database.is_empty() { "" } else { "/" };
126
127 let application_name =
128 config.application_name.unwrap_or_else(|| "".into());
129 let question_mark = "?";
130
131 let app_name_param = if application_name.is_empty() {
132 ""
133 } else {
134 "application_name="
135 };
136
137 let ssl_mode = config.tls_config.mode.to_string();
138 let ssl_mode_param = if application_name.is_empty() {
139 "sslmode="
140 } else {
141 "&sslmode="
142 };
143
144 [
145 scheme,
146 user.as_str(),
147 at,
148 host.as_str(),
149 colon,
150 port.as_str(),
151 slash,
152 database.as_str(),
153 question_mark,
154 app_name_param,
155 application_name.as_str(),
156 ssl_mode_param,
157 ssl_mode.as_str(),
158 ]
159 .concat()
160 }
161}
162
163#[derive(Debug, Clone, Deserialize)]
164pub enum TlsConnectMode {
165 #[serde(alias = "disable")]
166 Disable,
167 #[serde(alias = "allow")]
168 Allow,
169 #[serde(alias = "prefer")]
170 Prefer,
171 #[serde(alias = "require")]
172 Require,
173 #[serde(alias = "verify-ca")]
174 VerifyCa,
175 #[serde(alias = "verify-full")]
176 VerifyFull,
177}
178
179impl ToString for TlsConnectMode {
180 fn to_string(&self) -> String {
181 match self {
182 TlsConnectMode::Disable => String::from("disable"),
183 TlsConnectMode::Allow => String::from("allow"),
184 TlsConnectMode::Prefer => String::from("prefer"),
185 TlsConnectMode::Require => String::from("require"),
186 TlsConnectMode::VerifyCa => String::from("verify-ca"),
187 TlsConnectMode::VerifyFull => String::from("verify-full"),
188 }
189 }
190}
191
192pub type Certificate = NativeCertificate;
194
195pub type CertificateError = NativeError;
197
198#[derive(Clone)]
199pub struct TlsConfig {
200 pub(self) mode: TlsConnectMode,
201 pub(self) certificate: Option<Certificate>,
202}
203
204impl TlsConfig {
205 pub fn disable() -> Self {
206 TlsConfig {
207 mode: TlsConnectMode::Disable,
208 certificate: None,
209 }
210 }
211
212 pub fn allow(certificate: Option<Certificate>) -> Self {
213 TlsConfig {
214 mode: TlsConnectMode::Allow,
215 certificate,
216 }
217 }
218
219 pub fn prefer(certificate: Option<Certificate>) -> Self {
220 TlsConfig {
221 mode: TlsConnectMode::Prefer,
222 certificate,
223 }
224 }
225
226 pub fn require(certificate: Option<Certificate>) -> Self {
227 TlsConfig {
228 mode: TlsConnectMode::Require,
229 certificate,
230 }
231 }
232
233 pub fn verify_ca(certificate: Certificate) -> Self {
234 TlsConfig {
235 mode: TlsConnectMode::VerifyCa,
236 certificate: Some(certificate),
237 }
238 }
239
240 pub fn verify_full(certificate: Certificate) -> Self {
241 TlsConfig {
242 mode: TlsConnectMode::VerifyFull,
243 certificate: Some(certificate),
244 }
245 }
246}
247
248fn make_tls_connector(tls_config: &TlsConfig) -> Option<MakeTlsConnector> {
249 let m_cert = tls_config.certificate.clone();
250 match tls_config.mode {
251 TlsConnectMode::Disable => None,
252 TlsConnectMode::Allow
253 | TlsConnectMode::Prefer
254 | TlsConnectMode::Require => {
255 if let Some(cert) = m_cert {
256 let connector = TlsConnector::builder()
258 .add_root_certificate(cert)
259 .build()
260 .unwrap();
261 let connector = MakeTlsConnector::new(connector);
262 Some(connector)
263 } else {
264 let connector = TlsConnector::builder()
267 .danger_accept_invalid_certs(true)
268 .build()
269 .unwrap();
270 let connector = MakeTlsConnector::new(connector);
271 Some(connector)
272 }
273 }
274 TlsConnectMode::VerifyCa | TlsConnectMode::VerifyFull => {
275 let cert = m_cert.expect(
276 "A certificate is required for \
277 verify-ca, and verify-full SSL modes",
278 );
279 let connector = TlsConnector::builder()
280 .add_root_certificate(cert)
281 .build()
282 .unwrap();
283 Some(MakeTlsConnector::new(connector))
284 }
285 }
286}