1use rand::RngExt;
2
3use crate::LOCALHOST_HOST_ADDR;
4use crate::LOCALHOST_IP;
5use crate::UNSPECIFIED_IP;
6use crate::certificate;
7use crate::definition;
8
9pub const PGDATA: &str = "/var/lib/pg-ephemeral";
10
11#[derive(Debug, thiserror::Error)]
12pub enum Error {
13 #[error("PostgreSQL did not become available within {timeout:?}")]
14 ConnectionTimeout {
15 timeout: std::time::Duration,
16 #[source]
17 source: Option<sqlx::Error>,
18 },
19 #[error("Failed to execute command in container")]
20 ContainerExec(#[from] cmd_proc::CommandError),
21 #[error(transparent)]
22 SeedApply(#[from] crate::definition::SeedApplyError),
23 #[error(transparent)]
24 SeedLoad(#[from] crate::seed::LoadError),
25 #[error("Failed to terminate backend connections")]
26 TerminateConnections(#[source] sqlx::Error),
27 #[error("Failed to checkpoint")]
28 Checkpoint(#[source] sqlx::Error),
29}
30const ENV_POSTGRES_PASSWORD: cmd_proc::EnvVariableName<'static> =
31 cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_PASSWORD");
32const ENV_POSTGRES_USER: cmd_proc::EnvVariableName<'static> =
33 cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_USER");
34const ENV_PGDATA: cmd_proc::EnvVariableName<'static> =
35 cmd_proc::EnvVariableName::from_static_or_panic("PGDATA");
36const ENV_PG_EPHEMERAL_SSL_DIR: cmd_proc::EnvVariableName<'static> =
37 cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SSL_DIR");
38const ENV_PG_EPHEMERAL_CA_CERT_PEM: cmd_proc::EnvVariableName<'static> =
39 cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_CA_CERT_PEM");
40const ENV_PG_EPHEMERAL_SERVER_CERT_PEM: cmd_proc::EnvVariableName<'static> =
41 cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_CERT_PEM");
42const ENV_PG_EPHEMERAL_SERVER_KEY_PEM: cmd_proc::EnvVariableName<'static> =
43 cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_KEY_PEM");
44
45const SSL_SETUP_SCRIPT: &str = r#"
46printf '%s' "$PG_EPHEMERAL_CA_CERT_PEM" > ${PG_EPHEMERAL_SSL_DIR}/root.crt
47printf '%s' "$PG_EPHEMERAL_SERVER_CERT_PEM" > ${PG_EPHEMERAL_SSL_DIR}/server.crt
48printf '%s' "$PG_EPHEMERAL_SERVER_KEY_PEM" > ${PG_EPHEMERAL_SSL_DIR}/server.key
49chown postgres ${PG_EPHEMERAL_SSL_DIR}/root.crt
50chown postgres ${PG_EPHEMERAL_SSL_DIR}/server.crt
51chown postgres ${PG_EPHEMERAL_SSL_DIR}/server.key
52chmod 600 ${PG_EPHEMERAL_SSL_DIR}/root.crt
53chmod 600 ${PG_EPHEMERAL_SSL_DIR}/server.crt
54chmod 600 ${PG_EPHEMERAL_SSL_DIR}/server.key
55exec docker-entrypoint.sh "$@"
56"#;
57
58#[derive(Debug)]
64pub struct Definition {
65 pub image: ociman::image::Reference,
66 pub password: pg_client::config::Password,
67 pub user: pg_client::User,
68 pub database: pg_client::Database,
69 pub backend: ociman::Backend,
70 pub cross_container_access: bool,
71 pub application_name: Option<pg_client::config::ApplicationName>,
72 pub ssl_config: Option<definition::SslConfig>,
73 pub wait_available_timeout: std::time::Duration,
74}
75
76#[derive(Debug)]
77pub struct Container {
78 host_port: pg_client::config::Port,
79 pub(crate) client_config: pg_client::Config,
80 container: ociman::Container,
81 backend: ociman::Backend,
82 wait_available_timeout: std::time::Duration,
83}
84
85impl Container {
86 pub(crate) async fn run_definition(definition: &crate::definition::Definition) -> Self {
87 let password = generate_password();
88
89 let ociman_definition = definition
90 .to_ociman_definition()
91 .environment_variable(ENV_POSTGRES_PASSWORD, password.as_ref())
92 .environment_variable(ENV_POSTGRES_USER, definition.superuser.as_ref());
93
94 run_container(
95 ociman_definition,
96 definition.cross_container_access,
97 &definition.ssl_config,
98 &definition.backend,
99 &definition.application_name,
100 &definition.database,
101 &password,
102 &definition.superuser,
103 definition.wait_available_timeout,
104 definition.remove,
105 )
106 .await
107 }
108
109 pub async fn run_container_definition(definition: &Definition) -> Self {
110 let ociman_definition =
111 ociman::Definition::new(definition.backend.clone(), definition.image.clone());
112
113 run_container(
114 ociman_definition,
115 definition.cross_container_access,
116 &definition.ssl_config,
117 &definition.backend,
118 &definition.application_name,
119 &definition.database,
120 &definition.password,
121 &definition.user,
122 definition.wait_available_timeout,
123 true, )
125 .await
126 }
127
128 pub async fn wait_available(&self) -> Result<(), Error> {
129 let config = self.client_config.to_sqlx_connect_options().unwrap();
130
131 let start = std::time::Instant::now();
132 let max_duration = self.wait_available_timeout;
133 let sleep_duration = std::time::Duration::from_millis(100);
134
135 let mut last_error: Option<sqlx::Error> = None;
136
137 while start.elapsed() <= max_duration {
138 log::trace!("connection attempt");
139 match sqlx::ConnectOptions::connect(&config).await {
140 Ok(connection) => {
141 sqlx::Connection::close(connection)
142 .await
143 .expect("connection close failed");
144
145 log::debug!(
146 "pg is available on endpoint: {:#?}",
147 self.client_config.endpoint
148 );
149
150 return Ok(());
151 }
152 Err(error) => {
153 log::trace!("{error:#?}, retry in 100ms");
154 last_error = Some(error);
155 }
156 }
157 tokio::time::sleep(sleep_duration).await;
158 }
159
160 Err(Error::ConnectionTimeout {
161 timeout: max_duration,
162 source: last_error,
163 })
164 }
165
166 pub async fn exec_schema_dump(&self, pg_schema_dump: &pg_client::PgSchemaDump) -> String {
167 let output = self
168 .container
169 .exec("pg_dump")
170 .arguments(pg_schema_dump.arguments())
171 .environment_variables(self.container_client_config().to_pg_env())
172 .build()
173 .stdout_capture()
174 .bytes()
175 .await
176 .unwrap();
177 crate::convert_schema(&output)
178 }
179
180 #[must_use]
181 pub fn client_config(&self) -> &pg_client::Config {
182 &self.client_config
183 }
184
185 pub async fn with_connection<T, F: AsyncFnMut(&mut sqlx::postgres::PgConnection) -> T>(
186 &self,
187 mut action: F,
188 ) -> T {
189 self.client_config
190 .with_sqlx_connection(async |connection| action(connection).await)
191 .await
192 .unwrap()
193 }
194
195 pub async fn apply_sql(&self, sql: &str) -> Result<(), sqlx::Error> {
196 self.with_connection(async |connection| {
197 log::debug!("Executing: {sql}");
198 sqlx::raw_sql(sqlx::AssertSqlSafe(sql))
199 .execute(connection)
200 .await
201 .map(|_| ())
202 })
203 .await
204 }
205
206 pub async fn apply_csv(
210 &self,
211 table: &pg_client::QualifiedTable,
212 delimiter: char,
213 content: &str,
214 ) -> Result<(), sqlx::Error> {
215 self.with_connection(async |connection| {
216 let header_line = content.lines().next().unwrap_or_default();
217
218 let columns: Vec<&str> = header_line.split(delimiter).map(str::trim).collect();
219
220 let row = sqlx::query(
221 r#"SELECT 'COPY ' || format('%I.%I', $1, $2)
222 || '(' || (SELECT string_agg(format('%I', "column"), ', ') FROM unnest($3::text[]) AS "column") || ')'
223 || ' FROM STDIN WITH (FORMAT csv, HEADER MATCH, DELIMITER ' || quote_literal($4) || ')'
224 AS statement"#,
225 )
226 .bind(table.schema.as_ref())
227 .bind(table.table.as_ref())
228 .bind(&columns)
229 .bind(delimiter.to_string())
230 .fetch_one(&mut *connection)
231 .await?;
232 let statement: String = sqlx::Row::get(&row, "statement");
233
234 log::debug!("Executing: {statement}");
235 let mut copy = connection.copy_in_raw(&statement).await?;
236 copy.send(content.as_bytes()).await?;
237 copy.finish().await?;
238 Ok(())
239 })
240 .await
241 }
242
243 pub(crate) async fn exec_container_script(
244 &self,
245 script: &str,
246 ) -> Result<(), cmd_proc::CommandError> {
247 self.container
248 .exec("sh")
249 .arguments(["-e", "-c", script])
250 .build()
251 .status()
252 .await
253 }
254
255 pub(crate) async fn exec_container_shell(&self) {
256 self.container
257 .exec("sh")
258 .environment_variables(self.container_client_config().to_pg_env())
259 .interactive()
260 .status()
261 .await
262 .unwrap();
263 }
264
265 pub(crate) async fn exec_psql(&self) {
266 self.container
267 .exec("psql")
268 .environment_variables(self.container_client_config().to_pg_env())
269 .interactive()
270 .status()
271 .await
272 .unwrap();
273 }
274
275 fn container_client_config(&self) -> pg_client::Config {
276 let mut config = self.client_config.clone();
277 if let pg_client::config::Endpoint::Network {
278 ref host,
279 ref channel_binding,
280 ref host_addr,
281 ..
282 } = config.endpoint
283 {
284 config.endpoint = pg_client::config::Endpoint::Network {
285 host: host.clone(),
286 channel_binding: *channel_binding,
287 host_addr: host_addr.clone(),
288 port: Some(pg_client::config::Port::new(5432)),
289 };
290 }
291 config
292 }
293
294 pub async fn cross_container_client_config(&self) -> pg_client::Config {
295 let ip_address = self
298 .backend
299 .resolve_container_host()
300 .await
301 .expect("Failed to resolve container host from container");
302
303 let channel_binding = match &self.client_config.endpoint {
304 pg_client::config::Endpoint::Network {
305 channel_binding, ..
306 } => *channel_binding,
307 pg_client::config::Endpoint::SocketPath(_) => None,
308 };
309
310 let endpoint = pg_client::config::Endpoint::Network {
311 host: pg_client::config::Host::IpAddr(ip_address),
312 channel_binding,
313 host_addr: None,
314 port: Some(self.host_port),
315 };
316
317 self.client_config.clone().endpoint(endpoint)
318 }
319
320 #[must_use]
321 pub fn pg_env(&self) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
322 self.client_config.to_pg_env()
323 }
324
325 #[must_use]
326 pub fn database_url(&self) -> String {
327 self.client_config.to_url_string()
328 }
329
330 pub async fn stop(&mut self) {
331 self.container.stop().await
332 }
333
334 async fn terminate_connections(&self) -> Result<(), Error> {
335 self.apply_sql(
336 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid()",
337 )
338 .await
339 .map_err(Error::TerminateConnections)
340 }
341
342 async fn checkpoint(&self) -> Result<(), Error> {
343 self.apply_sql("CHECKPOINT")
344 .await
345 .map_err(Error::Checkpoint)
346 }
347
348 pub(crate) async fn stop_commit_remove(
351 &mut self,
352 reference: &ociman::Reference,
353 ) -> Result<(), Error> {
354 self.terminate_connections().await?;
355 self.checkpoint().await?;
356 self.container.stop().await;
357 self.container.commit(reference, false).await.unwrap();
358 self.container.remove().await;
359 Ok(())
360 }
361
362 async fn wait_for_container_socket(&self) -> Result<(), Error> {
363 let start = std::time::Instant::now();
364 let max_duration = self.wait_available_timeout;
365 let sleep_duration = std::time::Duration::from_millis(100);
366
367 while start.elapsed() <= max_duration {
368 if self
369 .container
370 .exec("pg_isready")
371 .argument("--host")
372 .argument("localhost")
373 .build()
374 .stdout_capture()
375 .bytes()
376 .await
377 .is_ok()
378 {
379 return Ok(());
380 }
381 tokio::time::sleep(sleep_duration).await;
382 }
383
384 Err(Error::ConnectionTimeout {
385 timeout: max_duration,
386 source: None,
387 })
388 }
389
390 pub async fn set_superuser_password(
395 &self,
396 password: &pg_client::config::Password,
397 ) -> Result<(), Error> {
398 self.wait_for_container_socket().await?;
399
400 self.container
401 .exec("psql")
402 .argument("--host")
403 .argument("/var/run/postgresql")
404 .argument("--username")
405 .argument(self.client_config.session.user.as_ref())
406 .argument("--dbname")
407 .argument("postgres")
408 .argument("--variable")
409 .argument(format!(
410 "target_user={}",
411 self.client_config.session.user.as_ref()
412 ))
413 .argument("--variable")
414 .argument(format!("new_password={}", password.as_ref()))
415 .stdin("ALTER USER :target_user WITH PASSWORD :'new_password'")
416 .build()
417 .stdout_capture()
418 .bytes()
419 .await?;
420
421 Ok(())
422 }
423}
424
425fn generate_password() -> pg_client::config::Password {
426 let value: String = rand::rng()
427 .sample_iter(rand::distr::Alphanumeric)
428 .take(32)
429 .map(char::from)
430 .collect();
431
432 <pg_client::config::Password as std::str::FromStr>::from_str(&value).unwrap()
433}
434
435#[allow(clippy::too_many_arguments)]
436async fn run_container(
437 ociman_definition: ociman::Definition,
438 cross_container_access: bool,
439 ssl_config: &Option<definition::SslConfig>,
440 backend: &ociman::Backend,
441 application_name: &Option<pg_client::config::ApplicationName>,
442 database: &pg_client::Database,
443 password: &pg_client::config::Password,
444 user: &pg_client::User,
445 wait_available_timeout: std::time::Duration,
446 remove: bool,
447) -> Container {
448 let backend = backend.clone();
449 let host_ip = if cross_container_access {
450 UNSPECIFIED_IP
451 } else {
452 LOCALHOST_IP
453 };
454
455 let mut ociman_definition = ociman_definition
456 .environment_variable(ENV_PGDATA, "/var/lib/pg-ephemeral")
457 .publish(ociman::Publish::tcp(5432).host_ip(host_ip));
458
459 if remove {
460 ociman_definition = ociman_definition.remove();
461 }
462
463 let ssl_bundle = if let Some(ssl_config) = ssl_config {
464 let hostname = match ssl_config {
465 definition::SslConfig::Generated { hostname } => hostname.as_str(),
466 };
467
468 let bundle = certificate::Bundle::generate(hostname)
469 .expect("Failed to generate SSL certificate bundle");
470
471 let ssl_dir = "/var/lib/postgresql";
472
473 ociman_definition = ociman_definition
474 .entrypoint("sh")
475 .argument("-e")
476 .argument("-c")
477 .argument(SSL_SETUP_SCRIPT)
478 .argument("--")
479 .argument("postgres")
480 .argument("--ssl=on")
481 .argument(format!("--ssl_cert_file={ssl_dir}/server.crt"))
482 .argument(format!("--ssl_key_file={ssl_dir}/server.key"))
483 .argument(format!("--ssl_ca_file={ssl_dir}/root.crt"))
484 .environment_variable(ENV_PG_EPHEMERAL_SSL_DIR, ssl_dir)
485 .environment_variable(ENV_PG_EPHEMERAL_CA_CERT_PEM, &bundle.ca_cert_pem)
486 .environment_variable(ENV_PG_EPHEMERAL_SERVER_CERT_PEM, &bundle.server_cert_pem)
487 .environment_variable(ENV_PG_EPHEMERAL_SERVER_KEY_PEM, &bundle.server_key_pem);
488
489 Some(bundle)
490 } else {
491 None
492 };
493
494 let container = ociman_definition.run_detached().await;
495
496 let port: pg_client::config::Port = container
497 .read_host_tcp_port(5432)
498 .await
499 .expect("port 5432 not published")
500 .into();
501
502 let (host, host_addr, ssl_mode, ssl_root_cert) = if let Some(ssl_config) = ssl_config {
503 let hostname = match ssl_config {
504 definition::SslConfig::Generated { hostname } => hostname.clone(),
505 };
506
507 let timestamp = std::time::SystemTime::now()
508 .duration_since(std::time::UNIX_EPOCH)
509 .unwrap()
510 .as_nanos();
511 let ca_cert_path = std::env::temp_dir().join(format!("pg_ephemeral_ca_{timestamp}.crt"));
512 std::fs::write(&ca_cert_path, &ssl_bundle.as_ref().unwrap().ca_cert_pem)
513 .expect("Failed to write CA certificate to temp file");
514
515 (
516 pg_client::config::Host::HostName(hostname),
517 Some(LOCALHOST_HOST_ADDR),
518 pg_client::config::SslMode::VerifyFull,
519 Some(pg_client::config::SslRootCert::File(ca_cert_path)),
520 )
521 } else {
522 (
523 pg_client::config::Host::IpAddr(LOCALHOST_IP),
524 None,
525 pg_client::config::SslMode::Disable,
526 None,
527 )
528 };
529
530 let client_config = pg_client::Config {
531 endpoint: pg_client::config::Endpoint::Network {
532 host,
533 channel_binding: None,
534 host_addr,
535 port: Some(port),
536 },
537 session: pg_client::config::Session {
538 application_name: application_name.clone(),
539 database: database.clone(),
540 password: Some(password.clone()),
541 user: user.clone(),
542 },
543 ssl_mode,
544 ssl_root_cert,
545 sqlx: Default::default(),
546 };
547
548 Container {
549 host_port: port,
550 container,
551 backend,
552 client_config,
553 wait_available_timeout,
554 }
555}