cueball_postgres_connection/
lib.rs

1/*
2 * Copyright 2020 Joyent, Inc.
3 */
4
5use std::ops::{Deref, DerefMut};
6
7use native_tls::Certificate as NativeCertificate;
8use native_tls::Error as NativeError;
9use native_tls::TlsConnector;
10use postgres;
11use postgres::{Client, NoTls};
12use postgres_native_tls::MakeTlsConnector;
13use serde_derive::Deserialize;
14
15use cueball::backend::Backend;
16use cueball::connection::Connection;
17
18pub struct PostgresConnection {
19    pub connection: Option<Client>,
20    url: String,
21    tls_config: TlsConfig,
22    connected: bool,
23}
24
25impl PostgresConnection {
26    pub fn connection_creator<'a>(
27        mut config: PostgresConnectionConfig,
28    ) -> impl FnMut(&Backend) -> PostgresConnection + 'a {
29        move |b| {
30            config.host = Some(b.address.to_string());
31            config.port = Some(b.port);
32
33            let url = config.to_owned().into();
34
35            PostgresConnection {
36                connection: None,
37                url,
38                tls_config: config.tls_config.clone(),
39                connected: false,
40            }
41        }
42    }
43}
44
45impl Connection for PostgresConnection {
46    type Error = postgres::Error;
47
48    fn connect(&mut self) -> Result<(), Self::Error> {
49        let connection =
50            if let Some(tls_connector) = make_tls_connector(&self.tls_config) {
51                Client::connect(&self.url, tls_connector)?
52            } else {
53                Client::connect(&self.url, NoTls)?
54            };
55        self.connection = Some(connection);
56        self.connected = true;
57        Ok(())
58    }
59
60    fn is_valid(&mut self) -> bool {
61        self.connection
62            .as_mut()
63            .unwrap()
64            .simple_query("")
65            .map(|_| ())
66            .is_ok()
67    }
68
69    fn has_broken(&self) -> bool {
70        match &self.connection {
71            Some(conn) => conn.is_closed(),
72            None => false,
73        }
74    }
75
76    fn close(&mut self) -> Result<(), Self::Error> {
77        self.connection = None;
78        self.connected = false;
79        Ok(())
80    }
81}
82
83impl Deref for PostgresConnection {
84    type Target = Client;
85
86    fn deref(&self) -> &Client {
87        &self.connection.as_ref().unwrap()
88    }
89}
90
91impl DerefMut for PostgresConnection {
92    fn deref_mut(&mut self) -> &mut Client {
93        self.connection.as_mut().unwrap()
94    }
95}
96
97#[derive(Clone)]
98pub struct PostgresConnectionConfig {
99    pub user: Option<String>,
100    pub password: Option<String>,
101    pub host: Option<String>,
102    pub port: Option<u16>,
103    pub database: Option<String>,
104    pub application_name: Option<String>,
105    pub tls_config: TlsConfig,
106}
107
108impl From<PostgresConnectionConfig> for String {
109    fn from(config: PostgresConnectionConfig) -> Self {
110        let scheme = "postgresql://";
111        let user = config.user.unwrap_or_else(|| "".into());
112
113        let at = if user.is_empty() { "" } else { "@" };
114
115        let host = config.host.unwrap_or_else(|| String::from("localhost"));
116        let port = config
117            .port
118            .map(|p| p.to_string())
119            .unwrap_or_else(|| "".to_string());
120
121        let colon = if port.is_empty() { "" } else { ":" };
122
123        let database = config.database.unwrap_or_else(|| "".into());
124
125        let slash = if database.is_empty() { "" } else { "/" };
126
127        let application_name =
128            config.application_name.unwrap_or_else(|| "".into());
129        let question_mark = "?";
130
131        let app_name_param = if application_name.is_empty() {
132            ""
133        } else {
134            "application_name="
135        };
136
137        let ssl_mode = config.tls_config.mode.to_string();
138        let ssl_mode_param = if application_name.is_empty() {
139            "sslmode="
140        } else {
141            "&sslmode="
142        };
143
144        [
145            scheme,
146            user.as_str(),
147            at,
148            host.as_str(),
149            colon,
150            port.as_str(),
151            slash,
152            database.as_str(),
153            question_mark,
154            app_name_param,
155            application_name.as_str(),
156            ssl_mode_param,
157            ssl_mode.as_str(),
158        ]
159        .concat()
160    }
161}
162
163#[derive(Debug, Clone, Deserialize)]
164pub enum TlsConnectMode {
165    #[serde(alias = "disable")]
166    Disable,
167    #[serde(alias = "allow")]
168    Allow,
169    #[serde(alias = "prefer")]
170    Prefer,
171    #[serde(alias = "require")]
172    Require,
173    #[serde(alias = "verify-ca")]
174    VerifyCa,
175    #[serde(alias = "verify-full")]
176    VerifyFull,
177}
178
179impl ToString for TlsConnectMode {
180    fn to_string(&self) -> String {
181        match self {
182            TlsConnectMode::Disable => String::from("disable"),
183            TlsConnectMode::Allow => String::from("allow"),
184            TlsConnectMode::Prefer => String::from("prefer"),
185            TlsConnectMode::Require => String::from("require"),
186            TlsConnectMode::VerifyCa => String::from("verify-ca"),
187            TlsConnectMode::VerifyFull => String::from("verify-full"),
188        }
189    }
190}
191
192/// An X509 certificate.
193pub type Certificate = NativeCertificate;
194
195/// An error returned from the TLS implementation.
196pub type CertificateError = NativeError;
197
198#[derive(Clone)]
199pub struct TlsConfig {
200    pub(self) mode: TlsConnectMode,
201    pub(self) certificate: Option<Certificate>,
202}
203
204impl TlsConfig {
205    pub fn disable() -> Self {
206        TlsConfig {
207            mode: TlsConnectMode::Disable,
208            certificate: None,
209        }
210    }
211
212    pub fn allow(certificate: Option<Certificate>) -> Self {
213        TlsConfig {
214            mode: TlsConnectMode::Allow,
215            certificate,
216        }
217    }
218
219    pub fn prefer(certificate: Option<Certificate>) -> Self {
220        TlsConfig {
221            mode: TlsConnectMode::Prefer,
222            certificate,
223        }
224    }
225
226    pub fn require(certificate: Option<Certificate>) -> Self {
227        TlsConfig {
228            mode: TlsConnectMode::Require,
229            certificate,
230        }
231    }
232
233    pub fn verify_ca(certificate: Certificate) -> Self {
234        TlsConfig {
235            mode: TlsConnectMode::VerifyCa,
236            certificate: Some(certificate),
237        }
238    }
239
240    pub fn verify_full(certificate: Certificate) -> Self {
241        TlsConfig {
242            mode: TlsConnectMode::VerifyFull,
243            certificate: Some(certificate),
244        }
245    }
246}
247
248fn make_tls_connector(tls_config: &TlsConfig) -> Option<MakeTlsConnector> {
249    let m_cert = tls_config.certificate.clone();
250    match tls_config.mode {
251        TlsConnectMode::Disable => None,
252        TlsConnectMode::Allow
253        | TlsConnectMode::Prefer
254        | TlsConnectMode::Require => {
255            if let Some(cert) = m_cert {
256                // root cert supplied, use it to verify server certs
257                let connector = TlsConnector::builder()
258                    .add_root_certificate(cert)
259                    .build()
260                    .unwrap();
261                let connector = MakeTlsConnector::new(connector);
262                Some(connector)
263            } else {
264                // no cert is given, disable certificate verification
265                // should we emit a warning to stderr since the function has "danger" in it?
266                let connector = TlsConnector::builder()
267                    .danger_accept_invalid_certs(true)
268                    .build()
269                    .unwrap();
270                let connector = MakeTlsConnector::new(connector);
271                Some(connector)
272            }
273        }
274        TlsConnectMode::VerifyCa | TlsConnectMode::VerifyFull => {
275            let cert = m_cert.expect(
276                "A certificate is required for \
277                 verify-ca, and verify-full SSL modes",
278            );
279            let connector = TlsConnector::builder()
280                .add_root_certificate(cert)
281                .build()
282                .unwrap();
283            Some(MakeTlsConnector::new(connector))
284        }
285    }
286}