1use rand::RngExt;
2
3use crate::LOCALHOST_HOST_ADDR;
4use crate::LOCALHOST_IP;
5use crate::UNSPECIFIED_IP;
6use crate::certificate;
7use crate::definition;
8
/// Path of the PostgreSQL data directory inside the container.
///
/// `run_container` exports this same path to the container as `PGDATA`.
pub const PGDATA: &str = "/var/lib/pg-ephemeral";
10
/// Errors produced while starting, probing, or configuring a PostgreSQL container.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The server did not become reachable before the configured timeout.
    ///
    /// `source` holds the error of the last failed client connection attempt
    /// when one was made; the `pg_isready`-based wait reports `None`.
    #[error("PostgreSQL did not become available within {timeout:?}")]
    ConnectionTimeout {
        timeout: std::time::Duration,
        #[source]
        source: Option<sqlx::Error>,
    },
    /// A command executed inside the container (e.g. `psql`) failed.
    #[error("Failed to execute command in container")]
    ContainerExec(#[from] cmd_proc::CommandError),
    /// Wraps `crate::definition::SeedApplyError`; its message is forwarded unchanged.
    #[error(transparent)]
    SeedApply(#[from] crate::definition::SeedApplyError),
    /// Wraps `crate::seed::LoadError`; its message is forwarded unchanged.
    #[error(transparent)]
    SeedLoad(#[from] crate::seed::LoadError),
}

/// Superuser password passed to the container by `Container::run_definition`.
const ENV_POSTGRES_PASSWORD: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_PASSWORD");
/// Superuser name passed to the container by `Container::run_definition`.
const ENV_POSTGRES_USER: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_USER");
/// Data directory passed to the container by `run_container`.
const ENV_PGDATA: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGDATA");
/// Directory into which `SSL_SETUP_SCRIPT` writes the certificate and key files.
const ENV_PG_EPHEMERAL_SSL_DIR: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SSL_DIR");
/// PEM CA certificate; `SSL_SETUP_SCRIPT` writes it to `root.crt`.
const ENV_PG_EPHEMERAL_CA_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_CA_CERT_PEM");
/// PEM server certificate; `SSL_SETUP_SCRIPT` writes it to `server.crt`.
const ENV_PG_EPHEMERAL_SERVER_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_CERT_PEM");
/// PEM server private key; `SSL_SETUP_SCRIPT` writes it to `server.key`.
const ENV_PG_EPHEMERAL_SERVER_KEY_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_KEY_PEM");
40
/// Container entrypoint script used when SSL is enabled.
///
/// It writes the CA certificate, server certificate and server key from the
/// `PG_EPHEMERAL_*_PEM` environment variables into `$PG_EPHEMERAL_SSL_DIR`.
/// It then makes them owned by `postgres` with mode 600 and finally `exec`s
/// `docker-entrypoint.sh` with the script's positional arguments.
///
/// `run_container` invokes it as `sh -e -c SCRIPT -- postgres --ssl=on ...`.
/// With `sh -c`, the first argument after the script (`--`) becomes `$0`,
/// so `"$@"` is exactly the `postgres ...` command line. `-e` aborts
/// container start-up if any file operation fails.
const SSL_SETUP_SCRIPT: &str = r#"
printf '%s' "$PG_EPHEMERAL_CA_CERT_PEM" > ${PG_EPHEMERAL_SSL_DIR}/root.crt
printf '%s' "$PG_EPHEMERAL_SERVER_CERT_PEM" > ${PG_EPHEMERAL_SSL_DIR}/server.crt
printf '%s' "$PG_EPHEMERAL_SERVER_KEY_PEM" > ${PG_EPHEMERAL_SSL_DIR}/server.key
chown postgres ${PG_EPHEMERAL_SSL_DIR}/root.crt
chown postgres ${PG_EPHEMERAL_SSL_DIR}/server.crt
chown postgres ${PG_EPHEMERAL_SSL_DIR}/server.key
chmod 600 ${PG_EPHEMERAL_SSL_DIR}/root.crt
chmod 600 ${PG_EPHEMERAL_SSL_DIR}/server.crt
chmod 600 ${PG_EPHEMERAL_SSL_DIR}/server.key
exec docker-entrypoint.sh "$@"
"#;
53
/// Parameters for starting a PostgreSQL container directly from an image via
/// [`Container::run_container_definition`].
#[derive(Debug)]
pub struct Definition {
    /// Image to run.
    pub image: ociman::image::Reference,
    /// Password placed in the returned client configuration.
    ///
    /// NOTE(review): `run_container_definition` does not pass this to the
    /// container (no `POSTGRES_PASSWORD`), so the image presumably already
    /// accepts it — confirm.
    pub password: pg_client::config::Password,
    /// User placed in the returned client configuration (not passed to the container).
    pub user: pg_client::User,
    /// Database the returned client configuration connects to.
    pub database: pg_client::Database,
    /// Container backend used to run the image.
    pub backend: ociman::Backend,
    /// When true, port 5432 is published on `UNSPECIFIED_IP` instead of `LOCALHOST_IP`.
    pub cross_container_access: bool,
    /// Optional `application_name` for client sessions.
    pub application_name: Option<pg_client::config::ApplicationName>,
    /// When set, the server is started with a freshly generated certificate bundle.
    pub ssl_config: Option<definition::SslConfig>,
    /// Time budget for the container's availability polling loops.
    pub wait_available_timeout: std::time::Duration,
}
71
/// A running PostgreSQL container together with the client configuration
/// used to reach it from the host.
#[derive(Debug)]
pub struct Container {
    /// Host port that container port 5432 is published on.
    host_port: pg_client::config::Port,
    /// Client configuration for connecting from the host. It uses the published
    /// port and, with SSL, the generated hostname with `verify-full`.
    pub(crate) client_config: pg_client::Config,
    /// Handle to the underlying container.
    container: ociman::Container,
    /// Backend the container runs on; used to resolve the container host for
    /// cross-container access.
    backend: ociman::Backend,
    /// Time budget for the availability polling loops.
    wait_available_timeout: std::time::Duration,
}
80
81impl Container {
82 pub(crate) async fn run_definition(definition: &crate::definition::Definition) -> Self {
83 let password = generate_password();
84
85 let ociman_definition = definition
86 .to_ociman_definition()
87 .environment_variable(ENV_POSTGRES_PASSWORD, password.as_ref())
88 .environment_variable(ENV_POSTGRES_USER, definition.superuser.as_ref());
89
90 run_container(
91 ociman_definition,
92 definition.cross_container_access,
93 &definition.ssl_config,
94 &definition.backend,
95 &definition.application_name,
96 &definition.database,
97 &password,
98 &definition.superuser,
99 definition.wait_available_timeout,
100 definition.remove,
101 )
102 .await
103 }
104
105 pub async fn run_container_definition(definition: &Definition) -> Self {
106 let ociman_definition =
107 ociman::Definition::new(definition.backend.clone(), definition.image.clone());
108
109 run_container(
110 ociman_definition,
111 definition.cross_container_access,
112 &definition.ssl_config,
113 &definition.backend,
114 &definition.application_name,
115 &definition.database,
116 &definition.password,
117 &definition.user,
118 definition.wait_available_timeout,
119 true, )
121 .await
122 }
123
124 pub async fn wait_available(&self) -> Result<(), Error> {
125 let config = self.client_config.to_sqlx_connect_options().unwrap();
126
127 let start = std::time::Instant::now();
128 let max_duration = self.wait_available_timeout;
129 let sleep_duration = std::time::Duration::from_millis(100);
130
131 let mut last_error: Option<sqlx::Error> = None;
132
133 while start.elapsed() <= max_duration {
134 log::trace!("connection attempt");
135 match sqlx::ConnectOptions::connect(&config).await {
136 Ok(connection) => {
137 sqlx::Connection::close(connection)
138 .await
139 .expect("connection close failed");
140
141 log::debug!(
142 "pg is available on endpoint: {:#?}",
143 self.client_config.endpoint
144 );
145
146 return Ok(());
147 }
148 Err(error) => {
149 log::trace!("{error:#?}, retry in 100ms");
150 last_error = Some(error);
151 }
152 }
153 tokio::time::sleep(sleep_duration).await;
154 }
155
156 Err(Error::ConnectionTimeout {
157 timeout: max_duration,
158 source: last_error,
159 })
160 }
161
162 pub async fn exec_schema_dump(&self, pg_schema_dump: &pg_client::PgSchemaDump) -> String {
163 let output = self
164 .container
165 .exec("pg_dump")
166 .arguments(pg_schema_dump.arguments())
167 .environment_variables(self.container_client_config().to_pg_env())
168 .build()
169 .stdout_capture()
170 .bytes()
171 .await
172 .unwrap();
173 crate::convert_schema(&output)
174 }
175
176 #[must_use]
177 pub fn client_config(&self) -> &pg_client::Config {
178 &self.client_config
179 }
180
181 pub async fn with_connection<T, F: AsyncFnMut(&mut sqlx::postgres::PgConnection) -> T>(
182 &self,
183 mut action: F,
184 ) -> T {
185 self.client_config
186 .with_sqlx_connection(async |connection| action(connection).await)
187 .await
188 .unwrap()
189 }
190
191 pub async fn apply_sql(&self, sql: &str) -> Result<(), sqlx::Error> {
192 self.with_connection(async |connection| {
193 log::debug!("Executing: {sql}");
194 sqlx::raw_sql(sqlx::AssertSqlSafe(sql))
195 .execute(connection)
196 .await
197 .map(|_| ())
198 })
199 .await
200 }
201
202 pub async fn apply_csv(
203 &self,
204 table: &pg_client::QualifiedTable,
205 content: &str,
206 ) -> Result<(), sqlx::Error> {
207 self.with_connection(async |connection| {
208 let row = sqlx::query(r#"SELECT format('%I.%I', $1, $2) AS table_identifier"#)
209 .bind(table.schema.as_ref())
210 .bind(table.table.as_ref())
211 .fetch_one(&mut *connection)
212 .await?;
213 let table_identifier: String = sqlx::Row::get(&row, "table_identifier");
214
215 let statement = format!("COPY {table_identifier} FROM STDIN WITH (FORMAT csv, HEADER)");
216 log::debug!("Executing: {statement}");
217 let mut copy = connection.copy_in_raw(&statement).await?;
218 copy.send(content.as_bytes()).await?;
219 copy.finish().await?;
220 Ok(())
221 })
222 .await
223 }
224
225 pub(crate) async fn exec_container_script(
226 &self,
227 script: &str,
228 ) -> Result<(), cmd_proc::CommandError> {
229 self.container
230 .exec("sh")
231 .arguments(["-e", "-c", script])
232 .build()
233 .status()
234 .await
235 }
236
237 pub(crate) async fn exec_container_shell(&self) {
238 self.container
239 .exec("sh")
240 .environment_variables(self.container_client_config().to_pg_env())
241 .interactive()
242 .status()
243 .await
244 .unwrap();
245 }
246
247 pub(crate) async fn exec_psql(&self) {
248 self.container
249 .exec("psql")
250 .environment_variables(self.container_client_config().to_pg_env())
251 .interactive()
252 .status()
253 .await
254 .unwrap();
255 }
256
257 fn container_client_config(&self) -> pg_client::Config {
258 let mut config = self.client_config.clone();
259 if let pg_client::config::Endpoint::Network {
260 ref host,
261 ref channel_binding,
262 ref host_addr,
263 ..
264 } = config.endpoint
265 {
266 config.endpoint = pg_client::config::Endpoint::Network {
267 host: host.clone(),
268 channel_binding: *channel_binding,
269 host_addr: host_addr.clone(),
270 port: Some(pg_client::config::Port::new(5432)),
271 };
272 }
273 config
274 }
275
276 pub async fn cross_container_client_config(&self) -> pg_client::Config {
277 let ip_address = self
280 .backend
281 .resolve_container_host()
282 .await
283 .expect("Failed to resolve container host from container");
284
285 let channel_binding = match &self.client_config.endpoint {
286 pg_client::config::Endpoint::Network {
287 channel_binding, ..
288 } => *channel_binding,
289 pg_client::config::Endpoint::SocketPath(_) => None,
290 };
291
292 let endpoint = pg_client::config::Endpoint::Network {
293 host: pg_client::config::Host::IpAddr(ip_address),
294 channel_binding,
295 host_addr: None,
296 port: Some(self.host_port),
297 };
298
299 self.client_config.clone().endpoint(endpoint)
300 }
301
302 #[must_use]
303 pub fn pg_env(&self) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
304 self.client_config.to_pg_env()
305 }
306
307 #[must_use]
308 pub fn database_url(&self) -> String {
309 self.client_config.to_url_string()
310 }
311
312 pub async fn stop(&mut self) {
313 self.container.stop().await
314 }
315
316 pub(crate) async fn stop_commit_remove(&mut self, reference: &ociman::Reference) {
319 self.container.stop().await;
320 self.container.commit(reference, false).await.unwrap();
321 self.container.remove().await;
322 }
323
324 async fn wait_for_container_socket(&self) -> Result<(), Error> {
325 let start = std::time::Instant::now();
326 let max_duration = self.wait_available_timeout;
327 let sleep_duration = std::time::Duration::from_millis(100);
328
329 while start.elapsed() <= max_duration {
330 if self
331 .container
332 .exec("pg_isready")
333 .argument("--host")
334 .argument("localhost")
335 .build()
336 .stdout_capture()
337 .bytes()
338 .await
339 .is_ok()
340 {
341 return Ok(());
342 }
343 tokio::time::sleep(sleep_duration).await;
344 }
345
346 Err(Error::ConnectionTimeout {
347 timeout: max_duration,
348 source: None,
349 })
350 }
351
352 pub async fn set_superuser_password(
357 &self,
358 password: &pg_client::config::Password,
359 ) -> Result<(), Error> {
360 self.wait_for_container_socket().await?;
361
362 self.container
363 .exec("psql")
364 .argument("--host")
365 .argument("/var/run/postgresql")
366 .argument("--username")
367 .argument(self.client_config.session.user.as_ref())
368 .argument("--dbname")
369 .argument("postgres")
370 .argument("--variable")
371 .argument(format!(
372 "target_user={}",
373 self.client_config.session.user.as_ref()
374 ))
375 .argument("--variable")
376 .argument(format!("new_password={}", password.as_ref()))
377 .stdin("ALTER USER :target_user WITH PASSWORD :'new_password'")
378 .build()
379 .stdout_capture()
380 .bytes()
381 .await?;
382
383 Ok(())
384 }
385}
386
387fn generate_password() -> pg_client::config::Password {
388 let value: String = rand::rng()
389 .sample_iter(rand::distr::Alphanumeric)
390 .take(32)
391 .map(char::from)
392 .collect();
393
394 <pg_client::config::Password as std::str::FromStr>::from_str(&value).unwrap()
395}
396
397#[allow(clippy::too_many_arguments)]
398async fn run_container(
399 ociman_definition: ociman::Definition,
400 cross_container_access: bool,
401 ssl_config: &Option<definition::SslConfig>,
402 backend: &ociman::Backend,
403 application_name: &Option<pg_client::config::ApplicationName>,
404 database: &pg_client::Database,
405 password: &pg_client::config::Password,
406 user: &pg_client::User,
407 wait_available_timeout: std::time::Duration,
408 remove: bool,
409) -> Container {
410 let backend = backend.clone();
411 let host_ip = if cross_container_access {
412 UNSPECIFIED_IP
413 } else {
414 LOCALHOST_IP
415 };
416
417 let mut ociman_definition = ociman_definition
418 .environment_variable(ENV_PGDATA, "/var/lib/pg-ephemeral")
419 .publish(ociman::Publish::tcp(5432).host_ip(host_ip));
420
421 if remove {
422 ociman_definition = ociman_definition.remove();
423 }
424
425 let ssl_bundle = if let Some(ssl_config) = ssl_config {
426 let hostname = match ssl_config {
427 definition::SslConfig::Generated { hostname } => hostname.as_str(),
428 };
429
430 let bundle = certificate::Bundle::generate(hostname)
431 .expect("Failed to generate SSL certificate bundle");
432
433 let ssl_dir = "/var/lib/postgresql";
434
435 ociman_definition = ociman_definition
436 .entrypoint("sh")
437 .argument("-e")
438 .argument("-c")
439 .argument(SSL_SETUP_SCRIPT)
440 .argument("--")
441 .argument("postgres")
442 .argument("--ssl=on")
443 .argument(format!("--ssl_cert_file={ssl_dir}/server.crt"))
444 .argument(format!("--ssl_key_file={ssl_dir}/server.key"))
445 .argument(format!("--ssl_ca_file={ssl_dir}/root.crt"))
446 .environment_variable(ENV_PG_EPHEMERAL_SSL_DIR, ssl_dir)
447 .environment_variable(ENV_PG_EPHEMERAL_CA_CERT_PEM, &bundle.ca_cert_pem)
448 .environment_variable(ENV_PG_EPHEMERAL_SERVER_CERT_PEM, &bundle.server_cert_pem)
449 .environment_variable(ENV_PG_EPHEMERAL_SERVER_KEY_PEM, &bundle.server_key_pem);
450
451 Some(bundle)
452 } else {
453 None
454 };
455
456 let container = ociman_definition.run_detached().await;
457
458 let port: pg_client::config::Port = container
459 .read_host_tcp_port(5432)
460 .await
461 .expect("port 5432 not published")
462 .into();
463
464 let (host, host_addr, ssl_mode, ssl_root_cert) = if let Some(ssl_config) = ssl_config {
465 let hostname = match ssl_config {
466 definition::SslConfig::Generated { hostname } => hostname.clone(),
467 };
468
469 let timestamp = std::time::SystemTime::now()
470 .duration_since(std::time::UNIX_EPOCH)
471 .unwrap()
472 .as_nanos();
473 let ca_cert_path = std::env::temp_dir().join(format!("pg_ephemeral_ca_{timestamp}.crt"));
474 std::fs::write(&ca_cert_path, &ssl_bundle.as_ref().unwrap().ca_cert_pem)
475 .expect("Failed to write CA certificate to temp file");
476
477 (
478 pg_client::config::Host::HostName(hostname),
479 Some(LOCALHOST_HOST_ADDR),
480 pg_client::config::SslMode::VerifyFull,
481 Some(pg_client::config::SslRootCert::File(ca_cert_path)),
482 )
483 } else {
484 (
485 pg_client::config::Host::IpAddr(LOCALHOST_IP),
486 None,
487 pg_client::config::SslMode::Disable,
488 None,
489 )
490 };
491
492 let client_config = pg_client::Config {
493 endpoint: pg_client::config::Endpoint::Network {
494 host,
495 channel_binding: None,
496 host_addr,
497 port: Some(port),
498 },
499 session: pg_client::config::Session {
500 application_name: application_name.clone(),
501 database: database.clone(),
502 password: Some(password.clone()),
503 user: user.clone(),
504 },
505 ssl_mode,
506 ssl_root_cert,
507 sqlx: Default::default(),
508 };
509
510 Container {
511 host_port: port,
512 container,
513 backend,
514 client_config,
515 wait_available_timeout,
516 }
517}