1use rand::RngExt;
2
3use crate::LOCALHOST_HOST_ADDR;
4use crate::LOCALHOST_IP;
5use crate::UNSPECIFIED_IP;
6use crate::certificate;
7use crate::definition;
8
/// Location of the PostgreSQL data directory inside the container.
///
/// This is the same path that is exported as the `PGDATA` environment variable when a
/// container is started.
pub const PGDATA: &str = "/var/lib/pg-ephemeral";
10
/// Errors returned by this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The server could not be reached before `timeout` elapsed.
    ///
    /// `source` holds the last connection error observed while polling, when one was
    /// recorded (probes that do not produce an `sqlx::Error` leave it `None`).
    #[error("PostgreSQL did not become available within {timeout:?}")]
    ConnectionTimeout {
        timeout: std::time::Duration,
        #[source]
        source: Option<sqlx::Error>,
    },
    /// Running a command inside the container failed.
    #[error("Failed to execute command in container")]
    ContainerExec(#[from] cmd_proc::CommandError),
    /// Wraps `crate::definition::SeedApplyError` without adding context.
    #[error(transparent)]
    SeedApply(#[from] crate::definition::SeedApplyError),
    /// Wraps `crate::seed::LoadError` without adding context.
    #[error(transparent)]
    SeedLoad(#[from] crate::seed::LoadError),
}
// Names of the environment variables this module passes into the PostgreSQL container.

/// `POSTGRES_PASSWORD`: set to the freshly generated password by `Container::run_definition`.
const ENV_POSTGRES_PASSWORD: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_PASSWORD");
/// `POSTGRES_USER`: set to the definition's superuser by `Container::run_definition`.
const ENV_POSTGRES_USER: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_USER");
/// `PGDATA`: data directory exported for every container that is started.
const ENV_PGDATA: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGDATA");
/// Directory into which `SSL_SETUP_SCRIPT` writes the certificate and key files.
const ENV_PG_EPHEMERAL_SSL_DIR: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SSL_DIR");
/// PEM of the generated CA certificate, written to `root.crt` by `SSL_SETUP_SCRIPT`.
const ENV_PG_EPHEMERAL_CA_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_CA_CERT_PEM");
/// PEM of the generated server certificate, written to `server.crt` by `SSL_SETUP_SCRIPT`.
const ENV_PG_EPHEMERAL_SERVER_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_CERT_PEM");
/// PEM of the generated server private key, written to `server.key` by `SSL_SETUP_SCRIPT`.
const ENV_PG_EPHEMERAL_SERVER_KEY_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_KEY_PEM");
40
/// Shell script used as the container entrypoint when SSL is enabled.
///
/// It is run as `sh -e -c SSL_SETUP_SCRIPT -- postgres ...` and reads `PG_EPHEMERAL_SSL_DIR`
/// plus the three `PG_EPHEMERAL_*_PEM` variables from its environment. It writes the CA
/// certificate, server certificate and server key into the SSL directory, hands them to the
/// `postgres` user with mode 600, and finally `exec`s `docker-entrypoint.sh` with the
/// script's positional arguments.
///
/// The files are created inside a `umask 077` subshell, so the private key is never readable
/// by other users. Without it, the key would be readable in the window between being written
/// and the `chmod`. The explicit `chmod 600` is kept because a `>` redirect into an
/// already-existing file does not change that file's mode. Under `sh -e`, a failure inside
/// the subshell makes the subshell exit non-zero, which in turn aborts the script.
const SSL_SETUP_SCRIPT: &str = r#"
(
  umask 077
  printf '%s' "$PG_EPHEMERAL_CA_CERT_PEM" > "${PG_EPHEMERAL_SSL_DIR}/root.crt"
  printf '%s' "$PG_EPHEMERAL_SERVER_CERT_PEM" > "${PG_EPHEMERAL_SSL_DIR}/server.crt"
  printf '%s' "$PG_EPHEMERAL_SERVER_KEY_PEM" > "${PG_EPHEMERAL_SSL_DIR}/server.key"
)
chown postgres "${PG_EPHEMERAL_SSL_DIR}/root.crt"
chown postgres "${PG_EPHEMERAL_SSL_DIR}/server.crt"
chown postgres "${PG_EPHEMERAL_SSL_DIR}/server.key"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/root.crt"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/server.crt"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/server.key"
exec docker-entrypoint.sh "$@"
"#;
53
/// Settings for starting a PostgreSQL container through `Container::run_container_definition`.
#[derive(Debug)]
pub struct Definition {
    /// Image to run.
    pub image: ociman::image::Reference,
    /// Password placed into the resulting client configuration (it is not passed to the
    /// container's environment by `run_container_definition`).
    pub password: pg_client::Password,
    /// User placed into the resulting client configuration.
    pub user: pg_client::User,
    /// Database placed into the resulting client configuration.
    pub database: pg_client::Database,
    /// Container backend used to run the image.
    pub backend: ociman::Backend,
    /// When `true`, port 5432 is published on the unspecified address instead of localhost
    /// (presumably so that other containers can reach it).
    pub cross_container_access: bool,
    /// Optional application name for client connections.
    pub application_name: Option<pg_client::ApplicationName>,
    /// When set, the server is started with SSL using a freshly generated certificate bundle.
    pub ssl_config: Option<definition::SslConfig>,
    /// Upper bound on how long readiness polling waits before giving up.
    pub wait_available_timeout: std::time::Duration,
}
71
/// Handle to a running PostgreSQL container together with the client configuration that
/// reaches it through the published host port.
#[derive(Debug)]
pub struct Container {
    /// Host port that the container's port 5432 is published on.
    host_port: pg_client::Port,
    /// Client configuration pointing at the published host port.
    pub(crate) client_config: pg_client::Config,
    /// Underlying ociman container handle.
    container: ociman::Container,
    /// Backend the container runs on; used to resolve the container-facing host address.
    backend: ociman::Backend,
    /// Timeout applied by the readiness polling helpers.
    wait_available_timeout: std::time::Duration,
}
80
81impl Container {
82 pub(crate) async fn run_definition(definition: &crate::definition::Definition) -> Self {
83 let password = generate_password();
84
85 let ociman_definition = definition
86 .to_ociman_definition()
87 .environment_variable(ENV_POSTGRES_PASSWORD, password.as_ref())
88 .environment_variable(ENV_POSTGRES_USER, definition.superuser.as_ref());
89
90 run_container(
91 ociman_definition,
92 definition.cross_container_access,
93 &definition.ssl_config,
94 &definition.backend,
95 &definition.application_name,
96 &definition.database,
97 &password,
98 &definition.superuser,
99 definition.wait_available_timeout,
100 definition.remove,
101 )
102 .await
103 }
104
105 pub async fn run_container_definition(definition: &Definition) -> Self {
106 let ociman_definition =
107 ociman::Definition::new(definition.backend.clone(), definition.image.clone());
108
109 run_container(
110 ociman_definition,
111 definition.cross_container_access,
112 &definition.ssl_config,
113 &definition.backend,
114 &definition.application_name,
115 &definition.database,
116 &definition.password,
117 &definition.user,
118 definition.wait_available_timeout,
119 true, )
121 .await
122 }
123
124 pub async fn wait_available(&self) -> Result<(), Error> {
125 let config = self.client_config.to_sqlx_connect_options().unwrap();
126
127 let start = std::time::Instant::now();
128 let max_duration = self.wait_available_timeout;
129 let sleep_duration = std::time::Duration::from_millis(100);
130
131 let mut last_error: Option<sqlx::Error> = None;
132
133 while start.elapsed() <= max_duration {
134 log::trace!("connection attempt");
135 match sqlx::ConnectOptions::connect(&config).await {
136 Ok(connection) => {
137 sqlx::Connection::close(connection)
138 .await
139 .expect("connection close failed");
140
141 log::debug!(
142 "pg is available on endpoint: {:#?}",
143 self.client_config.endpoint
144 );
145
146 return Ok(());
147 }
148 Err(error) => {
149 log::trace!("{error:#?}, retry in 100ms");
150 last_error = Some(error);
151 }
152 }
153 tokio::time::sleep(sleep_duration).await;
154 }
155
156 Err(Error::ConnectionTimeout {
157 timeout: max_duration,
158 source: last_error,
159 })
160 }
161
162 pub async fn exec_schema_dump(&self, pg_schema_dump: &pg_client::PgSchemaDump) -> String {
163 let output = self
164 .container
165 .exec("pg_dump")
166 .arguments(pg_schema_dump.arguments())
167 .environment_variables(self.container_client_config().to_pg_env())
168 .build()
169 .stdout_capture()
170 .bytes()
171 .await
172 .unwrap();
173 crate::convert_schema(&output)
174 }
175
176 #[must_use]
177 pub fn client_config(&self) -> &pg_client::Config {
178 &self.client_config
179 }
180
181 pub async fn with_connection<T, F: AsyncFnMut(&mut sqlx::postgres::PgConnection) -> T>(
182 &self,
183 mut action: F,
184 ) -> T {
185 self.client_config
186 .with_sqlx_connection(async |connection| action(connection).await)
187 .await
188 .unwrap()
189 }
190
191 pub async fn apply_sql(&self, sql: &str) -> Result<(), sqlx::Error> {
192 self.with_connection(async |connection| {
193 log::debug!("Executing: {sql}");
194 sqlx::raw_sql(sqlx::AssertSqlSafe(sql))
195 .execute(connection)
196 .await
197 .map(|_| ())
198 })
199 .await
200 }
201
202 pub(crate) async fn exec_container_script(
203 &self,
204 script: &str,
205 ) -> Result<(), cmd_proc::CommandError> {
206 self.container
207 .exec("sh")
208 .arguments(["-e", "-c", script])
209 .build()
210 .status()
211 .await
212 }
213
214 pub(crate) async fn exec_container_shell(&self) {
215 self.container
216 .exec("sh")
217 .environment_variables(self.container_client_config().to_pg_env())
218 .interactive()
219 .status()
220 .await
221 .unwrap();
222 }
223
224 pub(crate) async fn exec_psql(&self) {
225 self.container
226 .exec("psql")
227 .environment_variables(self.container_client_config().to_pg_env())
228 .interactive()
229 .status()
230 .await
231 .unwrap();
232 }
233
234 fn container_client_config(&self) -> pg_client::Config {
235 let mut config = self.client_config.clone();
236 if let pg_client::Endpoint::Network {
237 ref host,
238 ref channel_binding,
239 ref host_addr,
240 ..
241 } = config.endpoint
242 {
243 config.endpoint = pg_client::Endpoint::Network {
244 host: host.clone(),
245 channel_binding: *channel_binding,
246 host_addr: host_addr.clone(),
247 port: Some(pg_client::Port::new(5432)),
248 };
249 }
250 config
251 }
252
253 pub async fn cross_container_client_config(&self) -> pg_client::Config {
254 let ip_address = self
257 .backend
258 .resolve_container_host()
259 .await
260 .expect("Failed to resolve container host from container");
261
262 let channel_binding = match &self.client_config.endpoint {
263 pg_client::Endpoint::Network {
264 channel_binding, ..
265 } => *channel_binding,
266 pg_client::Endpoint::SocketPath(_) => None,
267 };
268
269 let endpoint = pg_client::Endpoint::Network {
270 host: pg_client::Host::IpAddr(ip_address),
271 channel_binding,
272 host_addr: None,
273 port: Some(self.host_port),
274 };
275
276 self.client_config.clone().endpoint(endpoint)
277 }
278
279 #[must_use]
280 pub fn pg_env(&self) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
281 self.client_config.to_pg_env()
282 }
283
284 #[must_use]
285 pub fn database_url(&self) -> String {
286 self.client_config.to_url_string()
287 }
288
289 pub async fn stop(&mut self) {
290 self.container.stop().await
291 }
292
293 async fn terminate_connections(&self) {
294 let sql =
295 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid()";
296
297 if let Err(error) = self.apply_sql(sql).await {
298 log::debug!("Failed to terminate connections: {error}");
299 }
300 }
301
302 pub(crate) async fn stop_commit_remove(&mut self, reference: &ociman::Reference) {
305 self.terminate_connections().await;
306 self.container.stop().await;
307 self.container.commit(reference, false).await.unwrap();
308 self.container.remove().await;
309 }
310
311 async fn wait_for_container_socket(&self) -> Result<(), Error> {
312 let start = std::time::Instant::now();
313 let max_duration = self.wait_available_timeout;
314 let sleep_duration = std::time::Duration::from_millis(100);
315
316 while start.elapsed() <= max_duration {
317 if self
318 .container
319 .exec("pg_isready")
320 .argument("--host")
321 .argument("localhost")
322 .build()
323 .stdout_capture()
324 .bytes()
325 .await
326 .is_ok()
327 {
328 return Ok(());
329 }
330 tokio::time::sleep(sleep_duration).await;
331 }
332
333 Err(Error::ConnectionTimeout {
334 timeout: max_duration,
335 source: None,
336 })
337 }
338
339 pub async fn set_superuser_password(
344 &self,
345 password: &pg_client::Password,
346 ) -> Result<(), Error> {
347 self.wait_for_container_socket().await?;
348
349 self.container
350 .exec("psql")
351 .argument("--host")
352 .argument("/var/run/postgresql")
353 .argument("--username")
354 .argument(self.client_config.user.as_ref())
355 .argument("--dbname")
356 .argument("postgres")
357 .argument("--variable")
358 .argument(format!("target_user={}", self.client_config.user.as_ref()))
359 .argument("--variable")
360 .argument(format!("new_password={}", password.as_ref()))
361 .stdin("ALTER USER :target_user WITH PASSWORD :'new_password'")
362 .build()
363 .stdout_capture()
364 .bytes()
365 .await?;
366
367 Ok(())
368 }
369}
370
371fn generate_password() -> pg_client::Password {
372 let value: String = rand::rng()
373 .sample_iter(rand::distr::Alphanumeric)
374 .take(32)
375 .map(char::from)
376 .collect();
377
378 <pg_client::Password as std::str::FromStr>::from_str(&value).unwrap()
379}
380
381#[allow(clippy::too_many_arguments)]
382async fn run_container(
383 ociman_definition: ociman::Definition,
384 cross_container_access: bool,
385 ssl_config: &Option<definition::SslConfig>,
386 backend: &ociman::Backend,
387 application_name: &Option<pg_client::ApplicationName>,
388 database: &pg_client::Database,
389 password: &pg_client::Password,
390 user: &pg_client::User,
391 wait_available_timeout: std::time::Duration,
392 remove: bool,
393) -> Container {
394 let backend = backend.clone();
395 let host_ip = if cross_container_access {
396 UNSPECIFIED_IP
397 } else {
398 LOCALHOST_IP
399 };
400
401 let mut ociman_definition = ociman_definition
402 .environment_variable(ENV_PGDATA, "/var/lib/pg-ephemeral")
403 .publish(ociman::Publish::tcp(5432).host_ip(host_ip));
404
405 if remove {
406 ociman_definition = ociman_definition.remove();
407 }
408
409 let ssl_bundle = if let Some(ssl_config) = ssl_config {
410 let hostname = match ssl_config {
411 definition::SslConfig::Generated { hostname } => hostname.as_str(),
412 };
413
414 let bundle = certificate::Bundle::generate(hostname)
415 .expect("Failed to generate SSL certificate bundle");
416
417 let ssl_dir = "/var/lib/postgresql";
418
419 ociman_definition = ociman_definition
420 .entrypoint("sh")
421 .argument("-e")
422 .argument("-c")
423 .argument(SSL_SETUP_SCRIPT)
424 .argument("--")
425 .argument("postgres")
426 .argument("--ssl=on")
427 .argument(format!("--ssl_cert_file={ssl_dir}/server.crt"))
428 .argument(format!("--ssl_key_file={ssl_dir}/server.key"))
429 .argument(format!("--ssl_ca_file={ssl_dir}/root.crt"))
430 .environment_variable(ENV_PG_EPHEMERAL_SSL_DIR, ssl_dir)
431 .environment_variable(ENV_PG_EPHEMERAL_CA_CERT_PEM, &bundle.ca_cert_pem)
432 .environment_variable(ENV_PG_EPHEMERAL_SERVER_CERT_PEM, &bundle.server_cert_pem)
433 .environment_variable(ENV_PG_EPHEMERAL_SERVER_KEY_PEM, &bundle.server_key_pem);
434
435 Some(bundle)
436 } else {
437 None
438 };
439
440 let container = ociman_definition.run_detached().await;
441
442 let port: pg_client::Port = container
443 .read_host_tcp_port(5432)
444 .await
445 .expect("port 5432 not published")
446 .into();
447
448 let (host, host_addr, ssl_mode, ssl_root_cert) = if let Some(ssl_config) = ssl_config {
449 let hostname = match ssl_config {
450 definition::SslConfig::Generated { hostname } => hostname.clone(),
451 };
452
453 let timestamp = std::time::SystemTime::now()
454 .duration_since(std::time::UNIX_EPOCH)
455 .unwrap()
456 .as_nanos();
457 let ca_cert_path = std::env::temp_dir().join(format!("pg_ephemeral_ca_{timestamp}.crt"));
458 std::fs::write(&ca_cert_path, &ssl_bundle.as_ref().unwrap().ca_cert_pem)
459 .expect("Failed to write CA certificate to temp file");
460
461 (
462 pg_client::Host::HostName(hostname),
463 Some(LOCALHOST_HOST_ADDR),
464 pg_client::SslMode::VerifyFull,
465 Some(pg_client::SslRootCert::File(ca_cert_path)),
466 )
467 } else {
468 (
469 pg_client::Host::IpAddr(LOCALHOST_IP),
470 None,
471 pg_client::SslMode::Disable,
472 None,
473 )
474 };
475
476 let client_config = pg_client::Config {
477 application_name: application_name.clone(),
478 database: database.clone(),
479 endpoint: pg_client::Endpoint::Network {
480 host,
481 channel_binding: None,
482 host_addr,
483 port: Some(port),
484 },
485 password: Some(password.clone()),
486 ssl_mode,
487 ssl_root_cert,
488 user: user.clone(),
489 };
490
491 Container {
492 host_port: port,
493 container,
494 backend,
495 client_config,
496 wait_available_timeout,
497 }
498}