Skip to main content

omnia_postgres/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg(not(target_arch = "wasm32"))]
3
4mod sql;
5mod types;
6
7use std::collections::HashMap;
8use std::str;
9
10use anyhow::{Context as _, Result, anyhow};
11use deadpool_postgres::{Pool, PoolConfig, Runtime};
12use omnia::Backend;
13use rustls::crypto::ring;
14use rustls::{ClientConfig, RootCertStore};
15use tokio_postgres::config::{Host, SslMode};
16use tokio_postgres_rustls::MakeRustlsConnect;
17use tracing::instrument;
18use webpki_roots::TLS_SERVER_ROOTS;
19
20/// Postgres client
21#[derive(Clone, Debug)]
22pub struct Client(HashMap<String, Pool>);
23
24/// Postgres resource builder
25impl Backend for Client {
26    type ConnectOptions = ConnectOptions;
27
28    /// Connect to `PostgreSQL` with provided options and return a connection pool
29    #[instrument]
30    async fn connect_with(options: Self::ConnectOptions) -> Result<Self> {
31        let mut pools = HashMap::new();
32        let runtime = Some(Runtime::Tokio1);
33        let mut tls_factory: Option<MakeRustlsConnect> = None; // factory is cheaper to clone
34
35        for entry in std::iter::once(&options.default_pool).chain(&options.additional_pools) {
36            let pool_config = deadpool_postgres::Config::try_from(entry)?;
37
38            let pool = if pool_config.ssl_mode.is_none() {
39                // Non-TLS mode
40                pool_config
41                    .create_pool(runtime, tokio_postgres::NoTls)
42                    .context(format!("failed to create postgres pool: '{}'", entry.name))?
43            } else {
44                // TLS mode
45                let factory = if let Some(f) = &tls_factory {
46                    f.clone()
47                } else {
48                    ring::default_provider()
49                        .install_default()
50                        .map_err(|_e| anyhow!("Failed to install rustls crypto provider"))?;
51
52                    let mut cert_store = RootCertStore::empty();
53                    cert_store.extend(TLS_SERVER_ROOTS.iter().cloned());
54
55                    let client_config = ClientConfig::builder()
56                        .with_root_certificates(cert_store)
57                        .with_no_client_auth();
58
59                    let factory = MakeRustlsConnect::new(client_config);
60                    tls_factory = Some(factory.clone());
61
62                    factory
63                };
64
65                pool_config
66                    .create_pool(runtime, factory) // unwrap is safe here
67                    .context(format!("failed to create postgres pool: '{}'", entry.name))?
68            };
69
70            // Check pool is usable
71            let cnn = pool.get().await;
72            if cnn.is_err() {
73                return Err(anyhow!("failed to get connection from pool: {:?}", cnn.err()));
74            }
75
76            tracing::info!(
77                "connected to Postgres database {:?}, with pool name '{}', tls '{}'",
78                pool_config.dbname.unwrap_or_default(),
79                entry.name,
80                pool_config.ssl_mode.is_none()
81            );
82            pools.insert(entry.name.clone(), pool);
83        }
84
85        Ok(Self(pools))
86    }
87}
88
/// A named `PostgreSQL` connection pool: where to connect and how large it may grow.
///
/// The derived `Debug` output includes `uri`, which may contain credentials.
#[derive(Clone, Debug)]
pub struct PoolEntry {
    /// Pool name (e.g. `"EVENTSTORE"`). It is both the lookup key and the suffix of the
    /// pool's environment variables.
    pub name: String,
    /// `PostgreSQL` connection string.
    pub uri: String,
    /// Upper bound on the number of connections held by the pool.
    pub pool_size: usize,
}
99
// NOTE(review): the `FromEnv` derive presumably generates public items without doc
// comments, which is why `missing_docs` is allowed for this module only — TODO confirm.
#[allow(missing_docs)]
mod config {
    use fromenv::FromEnv;

    use super::PoolEntry;

    /// Connection options for the `PostgreSQL` backend.
    // NOTE(review): the hand-written `omnia::FromEnv` impl in this file builds these
    // options directly rather than via the derived `FromEnv`; confirm whether the derive
    // is still required.
    // The derived `Debug` includes each pool's URI, which may embed credentials.
    #[derive(Debug, Clone, FromEnv)]
    pub struct ConnectOptions {
        /// Default connection pool (from `POSTGRES_URL`).
        pub default_pool: PoolEntry,
        /// Additional named pools discovered from `POSTGRES_POOLS`.
        pub additional_pools: Vec<PoolEntry>,
    }
}
// Expose the options at the crate root.
pub use config::ConnectOptions;
116
117impl omnia::FromEnv for ConnectOptions {
118    fn from_env() -> Result<Self> {
119        // default pool (required)
120        let default_uri = std::env::var("POSTGRES_URL").context("POSTGRES_URL must be set");
121        let default_size =
122            std::env::var("POSTGRES_POOL_SIZE").unwrap_or_default().parse().unwrap_or(10);
123
124        let default = PoolEntry {
125            name: "default".to_ascii_uppercase(),
126            uri: default_uri?,
127            pool_size: default_size,
128        };
129
130        // optional extra pools: POSTGRES_POOLS=eventstore
131        let extras = std::env::var("POSTGRES_POOLS")
132            .unwrap_or_default()
133            .split(',')
134            .map(str::trim)
135            .filter(|name| !name.is_empty())
136            .map(|name| -> anyhow::Result<PoolEntry> {
137                let name = name.to_ascii_uppercase();
138                let uri_key = format!("POSTGRES_URL__{name}");
139                let size_key = format!("POSTGRES_POOL_SIZE__{name}");
140
141                let uri = std::env::var(&uri_key)
142                    .with_context(|| format!("missing {uri_key} for pool {name}"))?;
143                let pool_size = std::env::var(&size_key)
144                    .ok()
145                    .and_then(|v| v.parse().ok())
146                    .unwrap_or(default.pool_size);
147
148                Ok(PoolEntry { name, uri, pool_size })
149            })
150            .collect::<Result<Vec<_>, _>>()?;
151
152        Ok(Self {
153            default_pool: default,
154            additional_pools: extras,
155        })
156        // Self::from_env().finalize().context("issue loading connection options")
157    }
158}
159
160impl TryFrom<&PoolEntry> for deadpool_postgres::Config {
161    type Error = anyhow::Error;
162
163    fn try_from(options: &PoolEntry) -> Result<Self> {
164        // parse postgres uri
165        let tokio: tokio_postgres::Config = options.uri.parse().context("parsing Postgres URI")?;
166        let host = tokio
167            .get_hosts()
168            .first()
169            .map(|h| match h {
170                Host::Tcp(name) => name.to_owned(),
171                Host::Unix(path) => path.to_string_lossy().to_string(),
172            })
173            .unwrap_or_default();
174        let port = tokio.get_ports().first().copied().ok_or_else(|| anyhow!("Port is missing"))?;
175        let username = tokio.get_user().ok_or_else(|| anyhow!("Username is missing"))?;
176        let password = tokio.get_password().ok_or_else(|| anyhow!("Password is missing"))?;
177        let database = tokio.get_dbname().ok_or_else(|| anyhow!("Database is missing"))?;
178        let password = str::from_utf8(password).context("Password contains invalid UTF-8")?;
179        let cli_options = tokio.get_options().unwrap_or_default();
180
181        // convert tokio_postgres::Config to deadpool_postgres::Config
182        let mut deadpool = Self::new();
183        deadpool.host = Some(host);
184        deadpool.dbname = Some(database.to_string());
185        deadpool.port = Some(port);
186        deadpool.user = Some(username.to_string());
187        deadpool.password = Some(password.to_owned());
188        deadpool.pool = Some(PoolConfig {
189            max_size: options.pool_size,
190            ..PoolConfig::default()
191        });
192        deadpool.ssl_mode = match tokio.get_ssl_mode() {
193            SslMode::Require => Some(deadpool_postgres::SslMode::Require),
194            SslMode::Prefer => Some(deadpool_postgres::SslMode::Prefer),
195            _ => None,
196        };
197        deadpool.options = Some(cli_options.to_string());
198
199        Ok(deadpool)
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ten-connection pool entry named `test` for the given connection string.
    fn entry_for(uri: &str) -> PoolEntry {
        PoolEntry {
            name: "test".to_string(),
            uri: uri.to_string(),
            pool_size: 10,
        }
    }

    #[test]
    fn try_from_pool_entry_valid_uri() {
        let entry = entry_for("postgresql://user:pass@localhost:5432/mydb");
        let cfg = deadpool_postgres::Config::try_from(&entry)
            .expect("a complete URI should convert");

        assert_eq!(cfg.host.as_deref(), Some("localhost"));
        assert_eq!(cfg.port, Some(5432));
        assert_eq!(cfg.user.as_deref(), Some("user"));
        assert_eq!(cfg.password.as_deref(), Some("pass"));
        assert_eq!(cfg.dbname.as_deref(), Some("mydb"));
        assert_eq!(cfg.pool.map(|pool| pool.max_size), Some(10));
    }

    #[test]
    fn try_from_pool_entry_missing_password() {
        let entry = entry_for("postgresql://user@localhost/mydb");
        let err = deadpool_postgres::Config::try_from(&entry)
            .expect_err("a URI without a password should be rejected");
        assert!(err.to_string().contains("Password is missing"));
    }

    #[test]
    fn try_from_pool_entry_invalid_uri() {
        let entry = entry_for("not-a-valid-uri");
        assert!(deadpool_postgres::Config::try_from(&entry).is_err());
    }
}