1use rand::RngExt;
2
3use crate::LOCALHOST_HOST_ADDR;
4use crate::LOCALHOST_IP;
5use crate::UNSPECIFIED_IP;
6use crate::certificate;
7use crate::definition;
8
/// Location of the PostgreSQL data directory inside the container.
pub const PGDATA: &str = "/var/lib/pg-ephemeral";
10
/// Errors returned by the container lifecycle and probing operations in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The server did not pass an availability probe before `timeout` elapsed.
    ///
    /// `source` holds the last client connection error when one was recorded;
    /// it is `None` when no such error is available (for example when the probe
    /// is `pg_isready` run inside the container).
    #[error("PostgreSQL did not become available within {timeout:?}")]
    ConnectionTimeout {
        // The timeout that was exceeded.
        timeout: std::time::Duration,
        // Last connection error observed, if any.
        #[source]
        source: Option<sqlx::Error>,
    },
    /// A command executed inside the container failed.
    #[error("Failed to execute command in container")]
    ContainerExec(#[from] cmd_proc::CommandError),
    /// Transparent wrapper around a seed application error.
    #[error(transparent)]
    SeedApply(#[from] crate::definition::SeedApplyError),
    /// Transparent wrapper around a seed loading error.
    #[error(transparent)]
    SeedLoad(#[from] crate::seed::LoadError),
    /// Terminating the other backend connections failed.
    #[error("Failed to terminate backend connections")]
    TerminateConnections(#[source] sqlx::Error),
    /// Issuing `CHECKPOINT` failed.
    #[error("Failed to checkpoint")]
    Checkpoint(#[source] sqlx::Error),
}
// Names of the environment variables this module sets on the container.

/// Superuser password handed to the image (set by `Container::run_definition`).
const ENV_POSTGRES_PASSWORD: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_PASSWORD");
/// Superuser name handed to the image (set by `Container::run_definition`).
const ENV_POSTGRES_USER: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_USER");
/// Data directory location inside the container.
const ENV_PGDATA: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGDATA");
/// Directory that `SSL_SETUP_SCRIPT` writes the certificate files into.
const ENV_PG_EPHEMERAL_SSL_DIR: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SSL_DIR");
/// PEM of the generated CA certificate; `SSL_SETUP_SCRIPT` writes it to `root.crt`.
const ENV_PG_EPHEMERAL_CA_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_CA_CERT_PEM");
/// PEM of the generated server certificate; `SSL_SETUP_SCRIPT` writes it to `server.crt`.
const ENV_PG_EPHEMERAL_SERVER_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_CERT_PEM");
/// PEM of the generated server private key; `SSL_SETUP_SCRIPT` writes it to `server.key`.
const ENV_PG_EPHEMERAL_SERVER_KEY_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_KEY_PEM");
44
/// Shell entrypoint installed when SSL is enabled.
///
/// The script is meant to run under `sh -e -c`, so any failing command aborts
/// container startup. It writes the CA certificate, server certificate and
/// server key from the `PG_EPHEMERAL_*_PEM` environment variables into
/// `$PG_EPHEMERAL_SSL_DIR`, makes the `postgres` user their owner with mode
/// 600, and then `exec`s `docker-entrypoint.sh` with the script's positional
/// arguments so that it replaces the shell process.
///
/// Every path expansion is double-quoted so that the directory value is never
/// subject to word splitting or glob expansion.
const SSL_SETUP_SCRIPT: &str = r#"
printf '%s' "$PG_EPHEMERAL_CA_CERT_PEM" > "${PG_EPHEMERAL_SSL_DIR}/root.crt"
printf '%s' "$PG_EPHEMERAL_SERVER_CERT_PEM" > "${PG_EPHEMERAL_SSL_DIR}/server.crt"
printf '%s' "$PG_EPHEMERAL_SERVER_KEY_PEM" > "${PG_EPHEMERAL_SSL_DIR}/server.key"
chown postgres "${PG_EPHEMERAL_SSL_DIR}/root.crt"
chown postgres "${PG_EPHEMERAL_SSL_DIR}/server.crt"
chown postgres "${PG_EPHEMERAL_SSL_DIR}/server.key"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/root.crt"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/server.crt"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/server.key"
exec docker-entrypoint.sh "$@"
"#;
57
/// Parameters for starting a PostgreSQL container from an existing image via
/// `Container::run_container_definition`.
#[derive(Debug)]
pub struct Definition {
    /// Image to run.
    pub image: ociman::image::Reference,
    /// Password placed in the client configuration. It is not passed to the
    /// container; presumably the image already has it configured (TODO confirm).
    pub password: pg_client::config::Password,
    /// Role the client configuration connects as.
    pub user: pg_client::User,
    /// Database the client configuration connects to.
    pub database: pg_client::Database,
    /// Container backend used to run the image.
    pub backend: ociman::Backend,
    /// When true, port 5432 is published on the unspecified address so that
    /// other containers can reach it; otherwise it is published on localhost.
    pub cross_container_access: bool,
    /// Optional `application_name` placed in the client session configuration.
    pub application_name: Option<pg_client::config::ApplicationName>,
    /// When set, certificates are generated and installed in the container and
    /// the client configuration uses `verify-full`.
    pub ssl_config: Option<definition::SslConfig>,
    /// Upper bound for the availability probes.
    pub wait_available_timeout: std::time::Duration,
}
75
/// A running PostgreSQL container together with the client configuration
/// used to reach it from the host.
#[derive(Debug)]
pub struct Container {
    // Host port that container port 5432 is published on.
    host_port: pg_client::config::Port,
    // Client configuration for connecting from the host.
    pub(crate) client_config: pg_client::Config,
    // Handle to the underlying container.
    container: ociman::Container,
    // Backend the container runs on; used to resolve the container host for
    // cross-container access.
    backend: ociman::Backend,
    // Upper bound for the availability probes.
    wait_available_timeout: std::time::Duration,
}
84
85impl Container {
86 pub(crate) async fn run_definition(definition: &crate::definition::Definition) -> Self {
87 let password = generate_password();
88
89 let ociman_definition = definition
90 .to_ociman_definition()
91 .environment_variable(ENV_POSTGRES_PASSWORD, password.as_ref())
92 .environment_variable(ENV_POSTGRES_USER, definition.superuser.as_ref());
93
94 run_container(
95 ociman_definition,
96 definition.cross_container_access,
97 &definition.ssl_config,
98 &definition.backend,
99 &definition.application_name,
100 &definition.database,
101 &password,
102 &definition.superuser,
103 definition.wait_available_timeout,
104 definition.remove,
105 )
106 .await
107 }
108
109 pub async fn run_container_definition(definition: &Definition) -> Self {
110 let ociman_definition =
111 ociman::Definition::new(definition.backend.clone(), definition.image.clone());
112
113 run_container(
114 ociman_definition,
115 definition.cross_container_access,
116 &definition.ssl_config,
117 &definition.backend,
118 &definition.application_name,
119 &definition.database,
120 &definition.password,
121 &definition.user,
122 definition.wait_available_timeout,
123 true, )
125 .await
126 }
127
128 pub async fn wait_available(&self) -> Result<(), Error> {
129 let config = self.client_config.to_sqlx_connect_options().unwrap();
130
131 let start = std::time::Instant::now();
132 let max_duration = self.wait_available_timeout;
133 let sleep_duration = std::time::Duration::from_millis(100);
134
135 let mut last_error: Option<sqlx::Error> = None;
136
137 while start.elapsed() <= max_duration {
138 log::trace!("connection attempt");
139 match sqlx::ConnectOptions::connect(&config).await {
140 Ok(connection) => {
141 sqlx::Connection::close(connection)
142 .await
143 .expect("connection close failed");
144
145 log::debug!(
146 "pg is available on endpoint: {:#?}",
147 self.client_config.endpoint
148 );
149
150 return Ok(());
151 }
152 Err(error) => {
153 log::trace!("{error:#?}, retry in 100ms");
154 last_error = Some(error);
155 }
156 }
157 tokio::time::sleep(sleep_duration).await;
158 }
159
160 Err(Error::ConnectionTimeout {
161 timeout: max_duration,
162 source: last_error,
163 })
164 }
165
166 pub async fn exec_schema_dump(&self, pg_schema_dump: &pg_client::PgSchemaDump) -> String {
167 let output = self
168 .container
169 .exec("pg_dump")
170 .arguments(pg_schema_dump.arguments())
171 .environment_variables(self.container_client_config().to_pg_env())
172 .build()
173 .stdout_capture()
174 .bytes()
175 .await
176 .unwrap();
177 crate::convert_schema(&output)
178 }
179
180 #[must_use]
181 pub fn client_config(&self) -> &pg_client::Config {
182 &self.client_config
183 }
184
185 pub async fn with_connection<T, F: AsyncFnMut(&mut sqlx::postgres::PgConnection) -> T>(
186 &self,
187 mut action: F,
188 ) -> T {
189 self.client_config
190 .with_sqlx_connection(async |connection| action(connection).await)
191 .await
192 .unwrap()
193 }
194
195 pub async fn apply_sql(&self, sql: &str) -> Result<(), sqlx::Error> {
196 self.with_connection(async |connection| {
197 log::debug!("Executing: {sql}");
198 sqlx::raw_sql(sqlx::AssertSqlSafe(sql))
199 .execute(connection)
200 .await
201 .map(|_| ())
202 })
203 .await
204 }
205
206 pub(crate) async fn exec_container_script(
207 &self,
208 script: &str,
209 ) -> Result<(), cmd_proc::CommandError> {
210 self.container
211 .exec("sh")
212 .arguments(["-e", "-c", script])
213 .build()
214 .status()
215 .await
216 }
217
218 pub(crate) async fn exec_container_shell(&self) {
219 self.container
220 .exec("sh")
221 .environment_variables(self.container_client_config().to_pg_env())
222 .interactive()
223 .status()
224 .await
225 .unwrap();
226 }
227
228 pub(crate) async fn exec_psql(&self) {
229 self.container
230 .exec("psql")
231 .environment_variables(self.container_client_config().to_pg_env())
232 .interactive()
233 .status()
234 .await
235 .unwrap();
236 }
237
238 fn container_client_config(&self) -> pg_client::Config {
239 let mut config = self.client_config.clone();
240 if let pg_client::config::Endpoint::Network {
241 ref host,
242 ref channel_binding,
243 ref host_addr,
244 ..
245 } = config.endpoint
246 {
247 config.endpoint = pg_client::config::Endpoint::Network {
248 host: host.clone(),
249 channel_binding: *channel_binding,
250 host_addr: host_addr.clone(),
251 port: Some(pg_client::config::Port::new(5432)),
252 };
253 }
254 config
255 }
256
257 pub async fn cross_container_client_config(&self) -> pg_client::Config {
258 let ip_address = self
261 .backend
262 .resolve_container_host()
263 .await
264 .expect("Failed to resolve container host from container");
265
266 let channel_binding = match &self.client_config.endpoint {
267 pg_client::config::Endpoint::Network {
268 channel_binding, ..
269 } => *channel_binding,
270 pg_client::config::Endpoint::SocketPath(_) => None,
271 };
272
273 let endpoint = pg_client::config::Endpoint::Network {
274 host: pg_client::config::Host::IpAddr(ip_address),
275 channel_binding,
276 host_addr: None,
277 port: Some(self.host_port),
278 };
279
280 self.client_config.clone().endpoint(endpoint)
281 }
282
283 #[must_use]
284 pub fn pg_env(&self) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
285 self.client_config.to_pg_env()
286 }
287
288 #[must_use]
289 pub fn database_url(&self) -> String {
290 self.client_config.to_url_string()
291 }
292
293 pub async fn stop(&mut self) {
294 self.container.stop().await
295 }
296
297 async fn terminate_connections(&self) -> Result<(), Error> {
298 self.apply_sql(
299 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid()",
300 )
301 .await
302 .map_err(Error::TerminateConnections)
303 }
304
305 async fn checkpoint(&self) -> Result<(), Error> {
306 self.apply_sql("CHECKPOINT")
307 .await
308 .map_err(Error::Checkpoint)
309 }
310
311 pub(crate) async fn stop_commit_remove(
314 &mut self,
315 reference: &ociman::Reference,
316 ) -> Result<(), Error> {
317 self.terminate_connections().await?;
318 self.checkpoint().await?;
319 self.container.stop().await;
320 self.container.commit(reference, false).await.unwrap();
321 self.container.remove().await;
322 Ok(())
323 }
324
325 async fn wait_for_container_socket(&self) -> Result<(), Error> {
326 let start = std::time::Instant::now();
327 let max_duration = self.wait_available_timeout;
328 let sleep_duration = std::time::Duration::from_millis(100);
329
330 while start.elapsed() <= max_duration {
331 if self
332 .container
333 .exec("pg_isready")
334 .argument("--host")
335 .argument("localhost")
336 .build()
337 .stdout_capture()
338 .bytes()
339 .await
340 .is_ok()
341 {
342 return Ok(());
343 }
344 tokio::time::sleep(sleep_duration).await;
345 }
346
347 Err(Error::ConnectionTimeout {
348 timeout: max_duration,
349 source: None,
350 })
351 }
352
353 pub async fn set_superuser_password(
358 &self,
359 password: &pg_client::config::Password,
360 ) -> Result<(), Error> {
361 self.wait_for_container_socket().await?;
362
363 self.container
364 .exec("psql")
365 .argument("--host")
366 .argument("/var/run/postgresql")
367 .argument("--username")
368 .argument(self.client_config.session.user.as_ref())
369 .argument("--dbname")
370 .argument("postgres")
371 .argument("--variable")
372 .argument(format!(
373 "target_user={}",
374 self.client_config.session.user.as_ref()
375 ))
376 .argument("--variable")
377 .argument(format!("new_password={}", password.as_ref()))
378 .stdin("ALTER USER :target_user WITH PASSWORD :'new_password'")
379 .build()
380 .stdout_capture()
381 .bytes()
382 .await?;
383
384 Ok(())
385 }
386}
387
388fn generate_password() -> pg_client::config::Password {
389 let value: String = rand::rng()
390 .sample_iter(rand::distr::Alphanumeric)
391 .take(32)
392 .map(char::from)
393 .collect();
394
395 <pg_client::config::Password as std::str::FromStr>::from_str(&value).unwrap()
396}
397
398#[allow(clippy::too_many_arguments)]
399async fn run_container(
400 ociman_definition: ociman::Definition,
401 cross_container_access: bool,
402 ssl_config: &Option<definition::SslConfig>,
403 backend: &ociman::Backend,
404 application_name: &Option<pg_client::config::ApplicationName>,
405 database: &pg_client::Database,
406 password: &pg_client::config::Password,
407 user: &pg_client::User,
408 wait_available_timeout: std::time::Duration,
409 remove: bool,
410) -> Container {
411 let backend = backend.clone();
412 let host_ip = if cross_container_access {
413 UNSPECIFIED_IP
414 } else {
415 LOCALHOST_IP
416 };
417
418 let mut ociman_definition = ociman_definition
419 .environment_variable(ENV_PGDATA, "/var/lib/pg-ephemeral")
420 .publish(ociman::Publish::tcp(5432).host_ip(host_ip));
421
422 if remove {
423 ociman_definition = ociman_definition.remove();
424 }
425
426 let ssl_bundle = if let Some(ssl_config) = ssl_config {
427 let hostname = match ssl_config {
428 definition::SslConfig::Generated { hostname } => hostname.as_str(),
429 };
430
431 let bundle = certificate::Bundle::generate(hostname)
432 .expect("Failed to generate SSL certificate bundle");
433
434 let ssl_dir = "/var/lib/postgresql";
435
436 ociman_definition = ociman_definition
437 .entrypoint("sh")
438 .argument("-e")
439 .argument("-c")
440 .argument(SSL_SETUP_SCRIPT)
441 .argument("--")
442 .argument("postgres")
443 .argument("--ssl=on")
444 .argument(format!("--ssl_cert_file={ssl_dir}/server.crt"))
445 .argument(format!("--ssl_key_file={ssl_dir}/server.key"))
446 .argument(format!("--ssl_ca_file={ssl_dir}/root.crt"))
447 .environment_variable(ENV_PG_EPHEMERAL_SSL_DIR, ssl_dir)
448 .environment_variable(ENV_PG_EPHEMERAL_CA_CERT_PEM, &bundle.ca_cert_pem)
449 .environment_variable(ENV_PG_EPHEMERAL_SERVER_CERT_PEM, &bundle.server_cert_pem)
450 .environment_variable(ENV_PG_EPHEMERAL_SERVER_KEY_PEM, &bundle.server_key_pem);
451
452 Some(bundle)
453 } else {
454 None
455 };
456
457 let container = ociman_definition.run_detached().await;
458
459 let port: pg_client::config::Port = container
460 .read_host_tcp_port(5432)
461 .await
462 .expect("port 5432 not published")
463 .into();
464
465 let (host, host_addr, ssl_mode, ssl_root_cert) = if let Some(ssl_config) = ssl_config {
466 let hostname = match ssl_config {
467 definition::SslConfig::Generated { hostname } => hostname.clone(),
468 };
469
470 let timestamp = std::time::SystemTime::now()
471 .duration_since(std::time::UNIX_EPOCH)
472 .unwrap()
473 .as_nanos();
474 let ca_cert_path = std::env::temp_dir().join(format!("pg_ephemeral_ca_{timestamp}.crt"));
475 std::fs::write(&ca_cert_path, &ssl_bundle.as_ref().unwrap().ca_cert_pem)
476 .expect("Failed to write CA certificate to temp file");
477
478 (
479 pg_client::config::Host::HostName(hostname),
480 Some(LOCALHOST_HOST_ADDR),
481 pg_client::config::SslMode::VerifyFull,
482 Some(pg_client::config::SslRootCert::File(ca_cert_path)),
483 )
484 } else {
485 (
486 pg_client::config::Host::IpAddr(LOCALHOST_IP),
487 None,
488 pg_client::config::SslMode::Disable,
489 None,
490 )
491 };
492
493 let client_config = pg_client::Config {
494 endpoint: pg_client::config::Endpoint::Network {
495 host,
496 channel_binding: None,
497 host_addr,
498 port: Some(port),
499 },
500 session: pg_client::config::Session {
501 application_name: application_name.clone(),
502 database: database.clone(),
503 password: Some(password.clone()),
504 user: user.clone(),
505 },
506 ssl_mode,
507 ssl_root_cert,
508 sqlx: Default::default(),
509 };
510
511 Container {
512 host_port: port,
513 container,
514 backend,
515 client_config,
516 wait_available_timeout,
517 }
518}