Skip to main content

pg_ephemeral/
container.rs

1use rand::RngExt;
2
3use crate::LOCALHOST_HOST_ADDR;
4use crate::LOCALHOST_IP;
5use crate::UNSPECIFIED_IP;
6use crate::certificate;
7use crate::definition;
8
/// Path of the PostgreSQL data directory inside the container.
///
/// `run_container` passes this path as the `PGDATA` environment variable when it starts
/// the container.
pub const PGDATA: &str = "/var/lib/pg-ephemeral";
10
/// Errors returned by the container operations in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// PostgreSQL did not become reachable before the configured timeout elapsed.
    ConnectionTimeout {
        /// The timeout that was exceeded.
        timeout: std::time::Duration,
        /// The last connection error observed while polling, if the probe produced one.
        #[source]
        source: Option<sqlx::Error>,
    },
    /// A command executed inside the container failed.
    #[error("Failed to execute command in container")]
    ContainerExec(#[from] cmd_proc::CommandError),
    /// Applying a seed failed.
    #[error(transparent)]
    SeedApply(#[from] crate::definition::SeedApplyError),
    /// Loading a seed failed.
    #[error(transparent)]
    SeedLoad(#[from] crate::seed::LoadError),
}
// Environment variable names set on the PostgreSQL container.

/// Superuser password handed to the image (set by `Container::run_definition`).
const ENV_POSTGRES_PASSWORD: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_PASSWORD");
/// Superuser name handed to the image (set by `Container::run_definition`).
const ENV_POSTGRES_USER: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_USER");
/// Data directory location inside the container (set by `run_container`).
const ENV_PGDATA: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGDATA");
/// Directory into which `SSL_SETUP_SCRIPT` writes the certificate and key files.
const ENV_PG_EPHEMERAL_SSL_DIR: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SSL_DIR");
/// PEM-encoded CA certificate; written to `root.crt` by `SSL_SETUP_SCRIPT`.
const ENV_PG_EPHEMERAL_CA_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_CA_CERT_PEM");
/// PEM-encoded server certificate; written to `server.crt` by `SSL_SETUP_SCRIPT`.
const ENV_PG_EPHEMERAL_SERVER_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_CERT_PEM");
/// PEM-encoded server private key; written to `server.key` by `SSL_SETUP_SCRIPT`.
const ENV_PG_EPHEMERAL_SERVER_KEY_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_KEY_PEM");
40
/// Shell script used as the container entrypoint when SSL is enabled.
///
/// `run_container` runs it as `sh -e -c SSL_SETUP_SCRIPT -- postgres --ssl=on ...`:
/// - `-e` makes any failing command abort container startup.
/// - `sh -c` binds the first argument after the script (`--`) to `$0`, so `"$@"` is exactly
///   the `postgres ...` command line.
///
/// The script:
/// 1. Writes the CA certificate, server certificate and server key from the
///    `PG_EPHEMERAL_*_PEM` variables into `$PG_EPHEMERAL_SSL_DIR`.
/// 2. Makes the files owned by `postgres` with mode 600.
/// 3. Hands over to `docker-entrypoint.sh`.
///
/// NOTE(review): the files are created with the shell's default umask and only restricted
/// afterwards by `chmod 600`.
const SSL_SETUP_SCRIPT: &str = r#"
printf '%s' "$PG_EPHEMERAL_CA_CERT_PEM" > ${PG_EPHEMERAL_SSL_DIR}/root.crt
printf '%s' "$PG_EPHEMERAL_SERVER_CERT_PEM" > ${PG_EPHEMERAL_SSL_DIR}/server.crt
printf '%s' "$PG_EPHEMERAL_SERVER_KEY_PEM" > ${PG_EPHEMERAL_SSL_DIR}/server.key
chown postgres ${PG_EPHEMERAL_SSL_DIR}/root.crt
chown postgres ${PG_EPHEMERAL_SSL_DIR}/server.crt
chown postgres ${PG_EPHEMERAL_SSL_DIR}/server.key
chmod 600 ${PG_EPHEMERAL_SSL_DIR}/root.crt
chmod 600 ${PG_EPHEMERAL_SSL_DIR}/server.crt
chmod 600 ${PG_EPHEMERAL_SSL_DIR}/server.key
exec docker-entrypoint.sh "$@"
"#;
53
/// Low-level container definition for running a pre-initialized PostgreSQL image.
///
/// All fields are assumed to represent values already stored in the referenced image.
/// No validation is performed - the caller is responsible for ensuring the credentials
/// and configuration match what exists in the image.
#[derive(Debug)]
pub struct Definition {
    /// Image to run.
    pub image: ociman::image::Reference,
    /// Password placed in the client configuration; must match the one stored in the image.
    pub password: pg_client::Password,
    /// User the client configuration connects as.
    pub user: pg_client::User,
    /// Database the client configuration connects to.
    pub database: pg_client::Database,
    /// Container backend used to run the image.
    pub backend: ociman::Backend,
    /// When `true`, container port 5432 is published on the unspecified address instead of
    /// localhost, so other containers can reach it.
    pub cross_container_access: bool,
    /// Optional `application_name` used by client connections.
    pub application_name: Option<pg_client::ApplicationName>,
    /// When set, a fresh certificate bundle is generated and the server is started with SSL.
    pub ssl_config: Option<definition::SslConfig>,
    /// Upper bound for polling until PostgreSQL accepts connections.
    pub wait_available_timeout: std::time::Duration,
}
71
/// A running PostgreSQL container together with the client configuration used to reach it.
#[derive(Debug)]
pub struct Container {
    /// Host port to which the container's port 5432 is published.
    host_port: pg_client::Port,
    /// Client configuration for connecting from the host (published port, optional SSL).
    pub(crate) client_config: pg_client::Config,
    /// Handle to the underlying container.
    container: ociman::Container,
    /// Backend the container runs on; used to resolve the container host address.
    backend: ociman::Backend,
    /// Upper bound used when polling for availability.
    wait_available_timeout: std::time::Duration,
}
80
81impl Container {
82    pub(crate) async fn run_definition(definition: &crate::definition::Definition) -> Self {
83        let password = generate_password();
84
85        let ociman_definition = definition
86            .to_ociman_definition()
87            .environment_variable(ENV_POSTGRES_PASSWORD, password.as_ref())
88            .environment_variable(ENV_POSTGRES_USER, definition.superuser.as_ref());
89
90        run_container(
91            ociman_definition,
92            definition.cross_container_access,
93            &definition.ssl_config,
94            &definition.backend,
95            &definition.application_name,
96            &definition.database,
97            &password,
98            &definition.superuser,
99            definition.wait_available_timeout,
100            definition.remove,
101        )
102        .await
103    }
104
105    pub async fn run_container_definition(definition: &Definition) -> Self {
106        let ociman_definition =
107            ociman::Definition::new(definition.backend.clone(), definition.image.clone());
108
109        run_container(
110            ociman_definition,
111            definition.cross_container_access,
112            &definition.ssl_config,
113            &definition.backend,
114            &definition.application_name,
115            &definition.database,
116            &definition.password,
117            &definition.user,
118            definition.wait_available_timeout,
119            true, // Always remove containers when using low-level API
120        )
121        .await
122    }
123
124    pub async fn wait_available(&self) -> Result<(), Error> {
125        let config = self.client_config.to_sqlx_connect_options().unwrap();
126
127        let start = std::time::Instant::now();
128        let max_duration = self.wait_available_timeout;
129        let sleep_duration = std::time::Duration::from_millis(100);
130
131        let mut last_error: Option<sqlx::Error> = None;
132
133        while start.elapsed() <= max_duration {
134            log::trace!("connection attempt");
135            match sqlx::ConnectOptions::connect(&config).await {
136                Ok(connection) => {
137                    sqlx::Connection::close(connection)
138                        .await
139                        .expect("connection close failed");
140
141                    log::debug!(
142                        "pg is available on endpoint: {:#?}",
143                        self.client_config.endpoint
144                    );
145
146                    return Ok(());
147                }
148                Err(error) => {
149                    log::trace!("{error:#?}, retry in 100ms");
150                    last_error = Some(error);
151                }
152            }
153            tokio::time::sleep(sleep_duration).await;
154        }
155
156        Err(Error::ConnectionTimeout {
157            timeout: max_duration,
158            source: last_error,
159        })
160    }
161
162    pub async fn exec_schema_dump(&self, pg_schema_dump: &pg_client::PgSchemaDump) -> String {
163        let output = self
164            .container
165            .exec("pg_dump")
166            .arguments(pg_schema_dump.arguments())
167            .environment_variables(self.container_client_config().to_pg_env())
168            .build()
169            .stdout_capture()
170            .bytes()
171            .await
172            .unwrap();
173        crate::convert_schema(&output)
174    }
175
176    #[must_use]
177    pub fn client_config(&self) -> &pg_client::Config {
178        &self.client_config
179    }
180
181    pub async fn with_connection<T, F: AsyncFnMut(&mut sqlx::postgres::PgConnection) -> T>(
182        &self,
183        mut action: F,
184    ) -> T {
185        self.client_config
186            .with_sqlx_connection(async |connection| action(connection).await)
187            .await
188            .unwrap()
189    }
190
191    pub async fn apply_sql(&self, sql: &str) -> Result<(), sqlx::Error> {
192        self.with_connection(async |connection| {
193            log::debug!("Executing: {sql}");
194            sqlx::raw_sql(sqlx::AssertSqlSafe(sql))
195                .execute(connection)
196                .await
197                .map(|_| ())
198        })
199        .await
200    }
201
202    pub(crate) async fn exec_container_script(
203        &self,
204        script: &str,
205    ) -> Result<(), cmd_proc::CommandError> {
206        self.container
207            .exec("sh")
208            .arguments(["-e", "-c", script])
209            .build()
210            .status()
211            .await
212    }
213
214    pub(crate) async fn exec_container_shell(&self) {
215        self.container
216            .exec("sh")
217            .environment_variables(self.container_client_config().to_pg_env())
218            .interactive()
219            .status()
220            .await
221            .unwrap();
222    }
223
224    pub(crate) async fn exec_psql(&self) {
225        self.container
226            .exec("psql")
227            .environment_variables(self.container_client_config().to_pg_env())
228            .interactive()
229            .status()
230            .await
231            .unwrap();
232    }
233
234    fn container_client_config(&self) -> pg_client::Config {
235        let mut config = self.client_config.clone();
236        if let pg_client::Endpoint::Network {
237            ref host,
238            ref channel_binding,
239            ref host_addr,
240            ..
241        } = config.endpoint
242        {
243            config.endpoint = pg_client::Endpoint::Network {
244                host: host.clone(),
245                channel_binding: *channel_binding,
246                host_addr: host_addr.clone(),
247                port: Some(pg_client::Port::new(5432)),
248            };
249        }
250        config
251    }
252
253    pub async fn cross_container_client_config(&self) -> pg_client::Config {
254        // Resolve the container host from inside a container
255        // This DNS name only works from inside containers, not from the host
256        let ip_address = self
257            .backend
258            .resolve_container_host()
259            .await
260            .expect("Failed to resolve container host from container");
261
262        let channel_binding = match &self.client_config.endpoint {
263            pg_client::Endpoint::Network {
264                channel_binding, ..
265            } => *channel_binding,
266            pg_client::Endpoint::SocketPath(_) => None,
267        };
268
269        let endpoint = pg_client::Endpoint::Network {
270            host: pg_client::Host::IpAddr(ip_address),
271            channel_binding,
272            host_addr: None,
273            port: Some(self.host_port),
274        };
275
276        self.client_config.clone().endpoint(endpoint)
277    }
278
279    #[must_use]
280    pub fn pg_env(&self) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
281        self.client_config.to_pg_env()
282    }
283
284    #[must_use]
285    pub fn database_url(&self) -> String {
286        self.client_config.to_url_string()
287    }
288
289    pub async fn stop(&mut self) {
290        self.container.stop().await
291    }
292
293    async fn terminate_connections(&self) {
294        let sql =
295            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid()";
296
297        if let Err(error) = self.apply_sql(sql).await {
298            log::debug!("Failed to terminate connections: {error}");
299        }
300    }
301
302    /// Stop the container (clean PostgreSQL shutdown), commit it to an image,
303    /// and remove the stopped container.
304    pub(crate) async fn stop_commit_remove(&mut self, reference: &ociman::Reference) {
305        self.terminate_connections().await;
306        self.container.stop().await;
307        self.container.commit(reference, false).await.unwrap();
308        self.container.remove().await;
309    }
310
311    async fn wait_for_container_socket(&self) -> Result<(), Error> {
312        let start = std::time::Instant::now();
313        let max_duration = self.wait_available_timeout;
314        let sleep_duration = std::time::Duration::from_millis(100);
315
316        while start.elapsed() <= max_duration {
317            if self
318                .container
319                .exec("pg_isready")
320                .argument("--host")
321                .argument("localhost")
322                .build()
323                .stdout_capture()
324                .bytes()
325                .await
326                .is_ok()
327            {
328                return Ok(());
329            }
330            tokio::time::sleep(sleep_duration).await;
331        }
332
333        Err(Error::ConnectionTimeout {
334            timeout: max_duration,
335            source: None,
336        })
337    }
338
339    /// Set the superuser password using peer authentication via Unix domain socket.
340    ///
341    /// This is useful when resuming from a cached image where the password
342    /// doesn't match the newly generated one.
343    pub async fn set_superuser_password(
344        &self,
345        password: &pg_client::Password,
346    ) -> Result<(), Error> {
347        self.wait_for_container_socket().await?;
348
349        self.container
350            .exec("psql")
351            .argument("--host")
352            .argument("/var/run/postgresql")
353            .argument("--username")
354            .argument(self.client_config.user.as_ref())
355            .argument("--dbname")
356            .argument("postgres")
357            .argument("--variable")
358            .argument(format!("target_user={}", self.client_config.user.as_ref()))
359            .argument("--variable")
360            .argument(format!("new_password={}", password.as_ref()))
361            .stdin("ALTER USER :target_user WITH PASSWORD :'new_password'")
362            .build()
363            .stdout_capture()
364            .bytes()
365            .await?;
366
367        Ok(())
368    }
369}
370
371fn generate_password() -> pg_client::Password {
372    let value: String = rand::rng()
373        .sample_iter(rand::distr::Alphanumeric)
374        .take(32)
375        .map(char::from)
376        .collect();
377
378    <pg_client::Password as std::str::FromStr>::from_str(&value).unwrap()
379}
380
381#[allow(clippy::too_many_arguments)]
382async fn run_container(
383    ociman_definition: ociman::Definition,
384    cross_container_access: bool,
385    ssl_config: &Option<definition::SslConfig>,
386    backend: &ociman::Backend,
387    application_name: &Option<pg_client::ApplicationName>,
388    database: &pg_client::Database,
389    password: &pg_client::Password,
390    user: &pg_client::User,
391    wait_available_timeout: std::time::Duration,
392    remove: bool,
393) -> Container {
394    let backend = backend.clone();
395    let host_ip = if cross_container_access {
396        UNSPECIFIED_IP
397    } else {
398        LOCALHOST_IP
399    };
400
401    let mut ociman_definition = ociman_definition
402        .environment_variable(ENV_PGDATA, "/var/lib/pg-ephemeral")
403        .publish(ociman::Publish::tcp(5432).host_ip(host_ip));
404
405    if remove {
406        ociman_definition = ociman_definition.remove();
407    }
408
409    let ssl_bundle = if let Some(ssl_config) = ssl_config {
410        let hostname = match ssl_config {
411            definition::SslConfig::Generated { hostname } => hostname.as_str(),
412        };
413
414        let bundle = certificate::Bundle::generate(hostname)
415            .expect("Failed to generate SSL certificate bundle");
416
417        let ssl_dir = "/var/lib/postgresql";
418
419        ociman_definition = ociman_definition
420            .entrypoint("sh")
421            .argument("-e")
422            .argument("-c")
423            .argument(SSL_SETUP_SCRIPT)
424            .argument("--")
425            .argument("postgres")
426            .argument("--ssl=on")
427            .argument(format!("--ssl_cert_file={ssl_dir}/server.crt"))
428            .argument(format!("--ssl_key_file={ssl_dir}/server.key"))
429            .argument(format!("--ssl_ca_file={ssl_dir}/root.crt"))
430            .environment_variable(ENV_PG_EPHEMERAL_SSL_DIR, ssl_dir)
431            .environment_variable(ENV_PG_EPHEMERAL_CA_CERT_PEM, &bundle.ca_cert_pem)
432            .environment_variable(ENV_PG_EPHEMERAL_SERVER_CERT_PEM, &bundle.server_cert_pem)
433            .environment_variable(ENV_PG_EPHEMERAL_SERVER_KEY_PEM, &bundle.server_key_pem);
434
435        Some(bundle)
436    } else {
437        None
438    };
439
440    let container = ociman_definition.run_detached().await;
441
442    let port: pg_client::Port = container
443        .read_host_tcp_port(5432)
444        .await
445        .expect("port 5432 not published")
446        .into();
447
448    let (host, host_addr, ssl_mode, ssl_root_cert) = if let Some(ssl_config) = ssl_config {
449        let hostname = match ssl_config {
450            definition::SslConfig::Generated { hostname } => hostname.clone(),
451        };
452
453        let timestamp = std::time::SystemTime::now()
454            .duration_since(std::time::UNIX_EPOCH)
455            .unwrap()
456            .as_nanos();
457        let ca_cert_path = std::env::temp_dir().join(format!("pg_ephemeral_ca_{timestamp}.crt"));
458        std::fs::write(&ca_cert_path, &ssl_bundle.as_ref().unwrap().ca_cert_pem)
459            .expect("Failed to write CA certificate to temp file");
460
461        (
462            pg_client::Host::HostName(hostname),
463            Some(LOCALHOST_HOST_ADDR),
464            pg_client::SslMode::VerifyFull,
465            Some(pg_client::SslRootCert::File(ca_cert_path)),
466        )
467    } else {
468        (
469            pg_client::Host::IpAddr(LOCALHOST_IP),
470            None,
471            pg_client::SslMode::Disable,
472            None,
473        )
474    };
475
476    let client_config = pg_client::Config {
477        application_name: application_name.clone(),
478        database: database.clone(),
479        endpoint: pg_client::Endpoint::Network {
480            host,
481            channel_binding: None,
482            host_addr,
483            port: Some(port),
484        },
485        password: Some(password.clone()),
486        ssl_mode,
487        ssl_root_cert,
488        user: user.clone(),
489    };
490
491    Container {
492        host_port: port,
493        container,
494        backend,
495        client_config,
496        wait_available_timeout,
497    }
498}