1#![doc = include_str!("../README.md")]
2#![cfg(not(target_arch = "wasm32"))]
3
4mod sql;
5mod types;
6
7use std::collections::HashMap;
8use std::str;
9
10use anyhow::{Context as _, Result, anyhow};
11use deadpool_postgres::{Pool, PoolConfig, Runtime};
12use omnia::Backend;
13use rustls::crypto::ring;
14use rustls::{ClientConfig, RootCertStore};
15use tokio_postgres::config::{Host, SslMode};
16use tokio_postgres_rustls::MakeRustlsConnect;
17use tracing::instrument;
18use webpki_roots::TLS_SERVER_ROOTS;
19
20#[derive(Clone, Debug)]
22pub struct Client(HashMap<String, Pool>);
23
24impl Backend for Client {
26 type ConnectOptions = ConnectOptions;
27
28 #[instrument]
30 async fn connect_with(options: Self::ConnectOptions) -> Result<Self> {
31 let mut pools = HashMap::new();
32 let runtime = Some(Runtime::Tokio1);
33 let mut tls_factory: Option<MakeRustlsConnect> = None; for entry in std::iter::once(&options.default_pool).chain(&options.additional_pools) {
36 let pool_config = deadpool_postgres::Config::try_from(entry)?;
37
38 let pool = if pool_config.ssl_mode.is_none() {
39 pool_config
41 .create_pool(runtime, tokio_postgres::NoTls)
42 .context(format!("failed to create postgres pool: '{}'", entry.name))?
43 } else {
44 let factory = if let Some(f) = &tls_factory {
46 f.clone()
47 } else {
48 ring::default_provider()
49 .install_default()
50 .map_err(|_e| anyhow!("Failed to install rustls crypto provider"))?;
51
52 let mut cert_store = RootCertStore::empty();
53 cert_store.extend(TLS_SERVER_ROOTS.iter().cloned());
54
55 let client_config = ClientConfig::builder()
56 .with_root_certificates(cert_store)
57 .with_no_client_auth();
58
59 let factory = MakeRustlsConnect::new(client_config);
60 tls_factory = Some(factory.clone());
61
62 factory
63 };
64
65 pool_config
66 .create_pool(runtime, factory) .context(format!("failed to create postgres pool: '{}'", entry.name))?
68 };
69
70 let cnn = pool.get().await;
72 if cnn.is_err() {
73 return Err(anyhow!("failed to get connection from pool: {:?}", cnn.err()));
74 }
75
76 tracing::info!(
77 "connected to Postgres database {:?}, with pool name '{}', tls '{}'",
78 pool_config.dbname.unwrap_or_default(),
79 entry.name,
80 pool_config.ssl_mode.is_none()
81 );
82 pools.insert(entry.name.clone(), pool);
83 }
84
85 Ok(Self(pools))
86 }
87}
88
/// Configuration for a single named Postgres connection pool.
///
/// `Debug` is implemented by hand so the connection URI, which normally embeds the
/// password, never ends up in logs or tracing spans.
#[derive(Clone)]
pub struct PoolEntry {
    /// Pool name; the pool is registered under this key.
    pub name: String,
    /// Postgres connection URI, e.g. `postgresql://user:pass@host:5432/db`.
    pub uri: String,
    /// Maximum number of connections the pool may hold.
    pub pool_size: usize,
}

impl std::fmt::Debug for PoolEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PoolEntry")
            .field("name", &self.name)
            // Credentials live in the URI, so it is never printed.
            .field("uri", &"<redacted>")
            .field("pool_size", &self.pool_size)
            .finish()
    }
}
99
// `missing_docs` is silenced for this module. Presumably this is because the `FromEnv`
// derive expands to public items that cannot carry doc comments. TODO(review): confirm.
#[allow(missing_docs)]
mod config {
    use fromenv::FromEnv;

    use super::PoolEntry;

    /// Connection options for [`Client`](super::Client): one default pool plus any number
    /// of additional named pools.
    #[derive(Debug, Clone, FromEnv)]
    pub struct ConnectOptions {
        /// The pool connected first. When loaded through `omnia::FromEnv`, it comes from
        /// `POSTGRES_URL` / `POSTGRES_POOL_SIZE`.
        pub default_pool: PoolEntry,
        /// Pools connected after the default one. When loaded through `omnia::FromEnv`,
        /// these are the pools named in `POSTGRES_POOLS`.
        pub additional_pools: Vec<PoolEntry>,
    }
}
pub use config::ConnectOptions;
116
117impl omnia::FromEnv for ConnectOptions {
118 fn from_env() -> Result<Self> {
119 let default_uri = std::env::var("POSTGRES_URL").context("POSTGRES_URL must be set");
121 let default_size =
122 std::env::var("POSTGRES_POOL_SIZE").unwrap_or_default().parse().unwrap_or(10);
123
124 let default = PoolEntry {
125 name: "default".to_ascii_uppercase(),
126 uri: default_uri?,
127 pool_size: default_size,
128 };
129
130 let extras = std::env::var("POSTGRES_POOLS")
132 .unwrap_or_default()
133 .split(',')
134 .map(str::trim)
135 .filter(|name| !name.is_empty())
136 .map(|name| -> anyhow::Result<PoolEntry> {
137 let name = name.to_ascii_uppercase();
138 let uri_key = format!("POSTGRES_URL__{name}");
139 let size_key = format!("POSTGRES_POOL_SIZE__{name}");
140
141 let uri = std::env::var(&uri_key)
142 .with_context(|| format!("missing {uri_key} for pool {name}"))?;
143 let pool_size = std::env::var(&size_key)
144 .ok()
145 .and_then(|v| v.parse().ok())
146 .unwrap_or(default.pool_size);
147
148 Ok(PoolEntry { name, uri, pool_size })
149 })
150 .collect::<Result<Vec<_>, _>>()?;
151
152 Ok(Self {
153 default_pool: default,
154 additional_pools: extras,
155 })
156 }
158}
159
160impl TryFrom<&PoolEntry> for deadpool_postgres::Config {
161 type Error = anyhow::Error;
162
163 fn try_from(options: &PoolEntry) -> Result<Self> {
164 let tokio: tokio_postgres::Config = options.uri.parse().context("parsing Postgres URI")?;
166 let host = tokio
167 .get_hosts()
168 .first()
169 .map(|h| match h {
170 Host::Tcp(name) => name.to_owned(),
171 Host::Unix(path) => path.to_string_lossy().to_string(),
172 })
173 .unwrap_or_default();
174 let port = tokio.get_ports().first().copied().ok_or_else(|| anyhow!("Port is missing"))?;
175 let username = tokio.get_user().ok_or_else(|| anyhow!("Username is missing"))?;
176 let password = tokio.get_password().ok_or_else(|| anyhow!("Password is missing"))?;
177 let database = tokio.get_dbname().ok_or_else(|| anyhow!("Database is missing"))?;
178 let password = str::from_utf8(password).context("Password contains invalid UTF-8")?;
179 let cli_options = tokio.get_options().unwrap_or_default();
180
181 let mut deadpool = Self::new();
183 deadpool.host = Some(host);
184 deadpool.dbname = Some(database.to_string());
185 deadpool.port = Some(port);
186 deadpool.user = Some(username.to_string());
187 deadpool.password = Some(password.to_owned());
188 deadpool.pool = Some(PoolConfig {
189 max_size: options.pool_size,
190 ..PoolConfig::default()
191 });
192 deadpool.ssl_mode = match tokio.get_ssl_mode() {
193 SslMode::Require => Some(deadpool_postgres::SslMode::Require),
194 SslMode::Prefer => Some(deadpool_postgres::SslMode::Prefer),
195 _ => None,
196 };
197 deadpool.options = Some(cli_options.to_string());
198
199 Ok(deadpool)
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `uri` in a `PoolEntry` named "test" with a pool size of 10.
    fn entry_for(uri: &str) -> PoolEntry {
        PoolEntry {
            name: "test".to_string(),
            uri: uri.to_string(),
            pool_size: 10,
        }
    }

    #[test]
    fn try_from_pool_entry_valid_uri() {
        let entry = entry_for("postgresql://user:pass@localhost:5432/mydb");
        let cfg = deadpool_postgres::Config::try_from(&entry).unwrap();

        assert_eq!(cfg.host.as_deref(), Some("localhost"));
        assert_eq!(cfg.port, Some(5432));
        assert_eq!(cfg.user.as_deref(), Some("user"));
        assert_eq!(cfg.password.as_deref(), Some("pass"));
        assert_eq!(cfg.dbname.as_deref(), Some("mydb"));
        assert_eq!(cfg.pool.map(|pool| pool.max_size), Some(10));
    }

    #[test]
    fn try_from_pool_entry_missing_password() {
        let entry = entry_for("postgresql://user@localhost/mydb");
        let message = deadpool_postgres::Config::try_from(&entry).unwrap_err().to_string();
        assert!(message.contains("Password is missing"));
    }

    #[test]
    fn try_from_pool_entry_invalid_uri() {
        let entry = entry_for("not-a-valid-uri");
        assert!(deadpool_postgres::Config::try_from(&entry).is_err());
    }
}