Skip to main content

pg_ephemeral/
container.rs

1use rand::RngExt;
2
3use crate::LOCALHOST_HOST_ADDR;
4use crate::LOCALHOST_IP;
5use crate::UNSPECIFIED_IP;
6use crate::certificate;
7use crate::definition;
8
/// PostgreSQL data directory inside the container (the value exported as `PGDATA`).
pub const PGDATA: &str = "/var/lib/pg-ephemeral";
10
/// Errors returned by [`Container`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// No connection succeeded before `timeout` elapsed.
    ///
    /// `source` carries the last connection error observed, when an attempt produced one.
    #[error("PostgreSQL did not become available within {timeout:?}")]
    ConnectionTimeout {
        timeout: std::time::Duration,
        #[source]
        source: Option<sqlx::Error>,
    },
    /// A command executed inside the container failed.
    #[error("Failed to execute command in container")]
    ContainerExec(#[from] cmd_proc::CommandError),
    /// Wraps a seed-application error from the definition module.
    #[error(transparent)]
    SeedApply(#[from] crate::definition::SeedApplyError),
    /// Wraps a seed-loading error from the seed module.
    #[error(transparent)]
    SeedLoad(#[from] crate::seed::LoadError),
    /// Terminating other backend connections (before a commit) failed.
    #[error("Failed to terminate backend connections")]
    TerminateConnections(#[source] sqlx::Error),
    /// Issuing `CHECKPOINT` failed.
    #[error("Failed to checkpoint")]
    Checkpoint(#[source] sqlx::Error),
}

/// Superuser password handed to the image's entrypoint.
const ENV_POSTGRES_PASSWORD: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_PASSWORD");
/// Superuser name handed to the image's entrypoint.
const ENV_POSTGRES_USER: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_USER");
/// Data directory location inside the container.
const ENV_PGDATA: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGDATA");
/// Directory `SSL_SETUP_SCRIPT` writes certificate and key files into.
const ENV_PG_EPHEMERAL_SSL_DIR: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SSL_DIR");
/// PEM-encoded CA certificate, written to `root.crt` by `SSL_SETUP_SCRIPT`.
const ENV_PG_EPHEMERAL_CA_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_CA_CERT_PEM");
/// PEM-encoded server certificate, written to `server.crt` by `SSL_SETUP_SCRIPT`.
const ENV_PG_EPHEMERAL_SERVER_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_CERT_PEM");
/// PEM-encoded server private key, written to `server.key` by `SSL_SETUP_SCRIPT`.
const ENV_PG_EPHEMERAL_SERVER_KEY_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_KEY_PEM");
44
/// Entrypoint wrapper used when SSL is enabled.
///
/// Writes the CA certificate, server certificate and server key taken from the
/// `PG_EPHEMERAL_*_PEM` environment variables into `$PG_EPHEMERAL_SSL_DIR`, hands them
/// to the `postgres` user with mode 600, then `exec`s the stock `docker-entrypoint.sh`
/// with the original arguments.
///
/// The files are created inside a subshell running under `umask 077`, so the private
/// key is never readable by other users (not even between the write and the `chmod`).
/// The subshell keeps the tightened umask from leaking into `docker-entrypoint.sh`.
/// The script runs under `sh -e`, so a failing write aborts startup.
const SSL_SETUP_SCRIPT: &str = r#"
(
    umask 077
    printf '%s' "$PG_EPHEMERAL_CA_CERT_PEM" > "${PG_EPHEMERAL_SSL_DIR}/root.crt"
    printf '%s' "$PG_EPHEMERAL_SERVER_CERT_PEM" > "${PG_EPHEMERAL_SSL_DIR}/server.crt"
    printf '%s' "$PG_EPHEMERAL_SERVER_KEY_PEM" > "${PG_EPHEMERAL_SSL_DIR}/server.key"
)
chown postgres "${PG_EPHEMERAL_SSL_DIR}/root.crt"
chown postgres "${PG_EPHEMERAL_SSL_DIR}/server.crt"
chown postgres "${PG_EPHEMERAL_SSL_DIR}/server.key"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/root.crt"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/server.crt"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/server.key"
exec docker-entrypoint.sh "$@"
"#;
57
/// Low-level container definition for running a pre-initialized PostgreSQL image.
///
/// All fields are assumed to represent values already stored in the referenced image.
/// No validation is performed - the caller is responsible for ensuring the credentials
/// and configuration match what exists in the image.
#[derive(Debug)]
pub struct Definition {
    /// Image to run.
    pub image: ociman::image::Reference,
    /// Password the client configuration uses to authenticate as `user`.
    pub password: pg_client::config::Password,
    /// Role the client configuration connects as.
    pub user: pg_client::User,
    /// Database the client configuration connects to.
    pub database: pg_client::Database,
    /// Container backend used to start the container and, later, to resolve the
    /// container host address.
    pub backend: ociman::Backend,
    /// When `true`, port 5432 is published on `UNSPECIFIED_IP` instead of `LOCALHOST_IP`,
    /// making it reachable beyond the host loopback (e.g. from other containers).
    pub cross_container_access: bool,
    /// Optional `application_name` stored in the client session configuration.
    pub application_name: Option<pg_client::config::ApplicationName>,
    /// When set, a certificate bundle is generated and the server is started with SSL on.
    pub ssl_config: Option<definition::SslConfig>,
    /// Upper bound on how long availability checks keep retrying.
    pub wait_available_timeout: std::time::Duration,
}
75
/// A running PostgreSQL container together with the client configuration used to reach it.
#[derive(Debug)]
pub struct Container {
    /// Host-side port that the container's TCP port 5432 is published on.
    host_port: pg_client::config::Port,
    /// Connection settings for reaching the server from the host.
    pub(crate) client_config: pg_client::Config,
    /// Handle to the underlying container (used for exec, stop, commit and remove).
    container: ociman::Container,
    /// Backend the container runs on; used to resolve the container host address.
    backend: ociman::Backend,
    /// How long availability polling keeps retrying before giving up.
    wait_available_timeout: std::time::Duration,
}
84
85impl Container {
86    pub(crate) async fn run_definition(definition: &crate::definition::Definition) -> Self {
87        let password = generate_password();
88
89        let ociman_definition = definition
90            .to_ociman_definition()
91            .environment_variable(ENV_POSTGRES_PASSWORD, password.as_ref())
92            .environment_variable(ENV_POSTGRES_USER, definition.superuser.as_ref());
93
94        run_container(
95            ociman_definition,
96            definition.cross_container_access,
97            &definition.ssl_config,
98            &definition.backend,
99            &definition.application_name,
100            &definition.database,
101            &password,
102            &definition.superuser,
103            definition.wait_available_timeout,
104            definition.remove,
105        )
106        .await
107    }
108
109    pub async fn run_container_definition(definition: &Definition) -> Self {
110        let ociman_definition =
111            ociman::Definition::new(definition.backend.clone(), definition.image.clone());
112
113        run_container(
114            ociman_definition,
115            definition.cross_container_access,
116            &definition.ssl_config,
117            &definition.backend,
118            &definition.application_name,
119            &definition.database,
120            &definition.password,
121            &definition.user,
122            definition.wait_available_timeout,
123            true, // Always remove containers when using low-level API
124        )
125        .await
126    }
127
128    pub async fn wait_available(&self) -> Result<(), Error> {
129        let config = self.client_config.to_sqlx_connect_options().unwrap();
130
131        let start = std::time::Instant::now();
132        let max_duration = self.wait_available_timeout;
133        let sleep_duration = std::time::Duration::from_millis(100);
134
135        let mut last_error: Option<sqlx::Error> = None;
136
137        while start.elapsed() <= max_duration {
138            log::trace!("connection attempt");
139            match sqlx::ConnectOptions::connect(&config).await {
140                Ok(connection) => {
141                    sqlx::Connection::close(connection)
142                        .await
143                        .expect("connection close failed");
144
145                    log::debug!(
146                        "pg is available on endpoint: {:#?}",
147                        self.client_config.endpoint
148                    );
149
150                    return Ok(());
151                }
152                Err(error) => {
153                    log::trace!("{error:#?}, retry in 100ms");
154                    last_error = Some(error);
155                }
156            }
157            tokio::time::sleep(sleep_duration).await;
158        }
159
160        Err(Error::ConnectionTimeout {
161            timeout: max_duration,
162            source: last_error,
163        })
164    }
165
166    pub async fn exec_schema_dump(&self, pg_schema_dump: &pg_client::PgSchemaDump) -> String {
167        let output = self
168            .container
169            .exec("pg_dump")
170            .arguments(pg_schema_dump.arguments())
171            .environment_variables(self.container_client_config().to_pg_env())
172            .build()
173            .stdout_capture()
174            .bytes()
175            .await
176            .unwrap();
177        crate::convert_schema(&output)
178    }
179
180    #[must_use]
181    pub fn client_config(&self) -> &pg_client::Config {
182        &self.client_config
183    }
184
185    pub async fn with_connection<T, F: AsyncFnMut(&mut sqlx::postgres::PgConnection) -> T>(
186        &self,
187        mut action: F,
188    ) -> T {
189        self.client_config
190            .with_sqlx_connection(async |connection| action(connection).await)
191            .await
192            .unwrap()
193    }
194
195    pub async fn apply_sql(&self, sql: &str) -> Result<(), sqlx::Error> {
196        self.with_connection(async |connection| {
197            log::debug!("Executing: {sql}");
198            sqlx::raw_sql(sqlx::AssertSqlSafe(sql))
199                .execute(connection)
200                .await
201                .map(|_| ())
202        })
203        .await
204    }
205
206    pub(crate) async fn exec_container_script(
207        &self,
208        script: &str,
209    ) -> Result<(), cmd_proc::CommandError> {
210        self.container
211            .exec("sh")
212            .arguments(["-e", "-c", script])
213            .build()
214            .status()
215            .await
216    }
217
218    pub(crate) async fn exec_container_shell(&self) {
219        self.container
220            .exec("sh")
221            .environment_variables(self.container_client_config().to_pg_env())
222            .interactive()
223            .status()
224            .await
225            .unwrap();
226    }
227
228    pub(crate) async fn exec_psql(&self) {
229        self.container
230            .exec("psql")
231            .environment_variables(self.container_client_config().to_pg_env())
232            .interactive()
233            .status()
234            .await
235            .unwrap();
236    }
237
238    fn container_client_config(&self) -> pg_client::Config {
239        let mut config = self.client_config.clone();
240        if let pg_client::config::Endpoint::Network {
241            ref host,
242            ref channel_binding,
243            ref host_addr,
244            ..
245        } = config.endpoint
246        {
247            config.endpoint = pg_client::config::Endpoint::Network {
248                host: host.clone(),
249                channel_binding: *channel_binding,
250                host_addr: host_addr.clone(),
251                port: Some(pg_client::config::Port::new(5432)),
252            };
253        }
254        config
255    }
256
257    pub async fn cross_container_client_config(&self) -> pg_client::Config {
258        // Resolve the container host from inside a container
259        // This DNS name only works from inside containers, not from the host
260        let ip_address = self
261            .backend
262            .resolve_container_host()
263            .await
264            .expect("Failed to resolve container host from container");
265
266        let channel_binding = match &self.client_config.endpoint {
267            pg_client::config::Endpoint::Network {
268                channel_binding, ..
269            } => *channel_binding,
270            pg_client::config::Endpoint::SocketPath(_) => None,
271        };
272
273        let endpoint = pg_client::config::Endpoint::Network {
274            host: pg_client::config::Host::IpAddr(ip_address),
275            channel_binding,
276            host_addr: None,
277            port: Some(self.host_port),
278        };
279
280        self.client_config.clone().endpoint(endpoint)
281    }
282
283    #[must_use]
284    pub fn pg_env(&self) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
285        self.client_config.to_pg_env()
286    }
287
288    #[must_use]
289    pub fn database_url(&self) -> String {
290        self.client_config.to_url_string()
291    }
292
293    pub async fn stop(&mut self) {
294        self.container.stop().await
295    }
296
297    async fn terminate_connections(&self) -> Result<(), Error> {
298        self.apply_sql(
299            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid()",
300        )
301        .await
302        .map_err(Error::TerminateConnections)
303    }
304
305    async fn checkpoint(&self) -> Result<(), Error> {
306        self.apply_sql("CHECKPOINT")
307            .await
308            .map_err(Error::Checkpoint)
309    }
310
311    /// Stop the container (clean PostgreSQL shutdown), commit it to an image,
312    /// and remove the stopped container.
313    pub(crate) async fn stop_commit_remove(
314        &mut self,
315        reference: &ociman::Reference,
316    ) -> Result<(), Error> {
317        self.terminate_connections().await?;
318        self.checkpoint().await?;
319        self.container.stop().await;
320        self.container.commit(reference, false).await.unwrap();
321        self.container.remove().await;
322        Ok(())
323    }
324
325    async fn wait_for_container_socket(&self) -> Result<(), Error> {
326        let start = std::time::Instant::now();
327        let max_duration = self.wait_available_timeout;
328        let sleep_duration = std::time::Duration::from_millis(100);
329
330        while start.elapsed() <= max_duration {
331            if self
332                .container
333                .exec("pg_isready")
334                .argument("--host")
335                .argument("localhost")
336                .build()
337                .stdout_capture()
338                .bytes()
339                .await
340                .is_ok()
341            {
342                return Ok(());
343            }
344            tokio::time::sleep(sleep_duration).await;
345        }
346
347        Err(Error::ConnectionTimeout {
348            timeout: max_duration,
349            source: None,
350        })
351    }
352
353    /// Set the superuser password using peer authentication via Unix domain socket.
354    ///
355    /// This is useful when resuming from a cached image where the password
356    /// doesn't match the newly generated one.
357    pub async fn set_superuser_password(
358        &self,
359        password: &pg_client::config::Password,
360    ) -> Result<(), Error> {
361        self.wait_for_container_socket().await?;
362
363        self.container
364            .exec("psql")
365            .argument("--host")
366            .argument("/var/run/postgresql")
367            .argument("--username")
368            .argument(self.client_config.session.user.as_ref())
369            .argument("--dbname")
370            .argument("postgres")
371            .argument("--variable")
372            .argument(format!(
373                "target_user={}",
374                self.client_config.session.user.as_ref()
375            ))
376            .argument("--variable")
377            .argument(format!("new_password={}", password.as_ref()))
378            .stdin("ALTER USER :target_user WITH PASSWORD :'new_password'")
379            .build()
380            .stdout_capture()
381            .bytes()
382            .await?;
383
384        Ok(())
385    }
386}
387
388fn generate_password() -> pg_client::config::Password {
389    let value: String = rand::rng()
390        .sample_iter(rand::distr::Alphanumeric)
391        .take(32)
392        .map(char::from)
393        .collect();
394
395    <pg_client::config::Password as std::str::FromStr>::from_str(&value).unwrap()
396}
397
398#[allow(clippy::too_many_arguments)]
399async fn run_container(
400    ociman_definition: ociman::Definition,
401    cross_container_access: bool,
402    ssl_config: &Option<definition::SslConfig>,
403    backend: &ociman::Backend,
404    application_name: &Option<pg_client::config::ApplicationName>,
405    database: &pg_client::Database,
406    password: &pg_client::config::Password,
407    user: &pg_client::User,
408    wait_available_timeout: std::time::Duration,
409    remove: bool,
410) -> Container {
411    let backend = backend.clone();
412    let host_ip = if cross_container_access {
413        UNSPECIFIED_IP
414    } else {
415        LOCALHOST_IP
416    };
417
418    let mut ociman_definition = ociman_definition
419        .environment_variable(ENV_PGDATA, "/var/lib/pg-ephemeral")
420        .publish(ociman::Publish::tcp(5432).host_ip(host_ip));
421
422    if remove {
423        ociman_definition = ociman_definition.remove();
424    }
425
426    let ssl_bundle = if let Some(ssl_config) = ssl_config {
427        let hostname = match ssl_config {
428            definition::SslConfig::Generated { hostname } => hostname.as_str(),
429        };
430
431        let bundle = certificate::Bundle::generate(hostname)
432            .expect("Failed to generate SSL certificate bundle");
433
434        let ssl_dir = "/var/lib/postgresql";
435
436        ociman_definition = ociman_definition
437            .entrypoint("sh")
438            .argument("-e")
439            .argument("-c")
440            .argument(SSL_SETUP_SCRIPT)
441            .argument("--")
442            .argument("postgres")
443            .argument("--ssl=on")
444            .argument(format!("--ssl_cert_file={ssl_dir}/server.crt"))
445            .argument(format!("--ssl_key_file={ssl_dir}/server.key"))
446            .argument(format!("--ssl_ca_file={ssl_dir}/root.crt"))
447            .environment_variable(ENV_PG_EPHEMERAL_SSL_DIR, ssl_dir)
448            .environment_variable(ENV_PG_EPHEMERAL_CA_CERT_PEM, &bundle.ca_cert_pem)
449            .environment_variable(ENV_PG_EPHEMERAL_SERVER_CERT_PEM, &bundle.server_cert_pem)
450            .environment_variable(ENV_PG_EPHEMERAL_SERVER_KEY_PEM, &bundle.server_key_pem);
451
452        Some(bundle)
453    } else {
454        None
455    };
456
457    let container = ociman_definition.run_detached().await;
458
459    let port: pg_client::config::Port = container
460        .read_host_tcp_port(5432)
461        .await
462        .expect("port 5432 not published")
463        .into();
464
465    let (host, host_addr, ssl_mode, ssl_root_cert) = if let Some(ssl_config) = ssl_config {
466        let hostname = match ssl_config {
467            definition::SslConfig::Generated { hostname } => hostname.clone(),
468        };
469
470        let timestamp = std::time::SystemTime::now()
471            .duration_since(std::time::UNIX_EPOCH)
472            .unwrap()
473            .as_nanos();
474        let ca_cert_path = std::env::temp_dir().join(format!("pg_ephemeral_ca_{timestamp}.crt"));
475        std::fs::write(&ca_cert_path, &ssl_bundle.as_ref().unwrap().ca_cert_pem)
476            .expect("Failed to write CA certificate to temp file");
477
478        (
479            pg_client::config::Host::HostName(hostname),
480            Some(LOCALHOST_HOST_ADDR),
481            pg_client::config::SslMode::VerifyFull,
482            Some(pg_client::config::SslRootCert::File(ca_cert_path)),
483        )
484    } else {
485        (
486            pg_client::config::Host::IpAddr(LOCALHOST_IP),
487            None,
488            pg_client::config::SslMode::Disable,
489            None,
490        )
491    };
492
493    let client_config = pg_client::Config {
494        endpoint: pg_client::config::Endpoint::Network {
495            host,
496            channel_binding: None,
497            host_addr,
498            port: Some(port),
499        },
500        session: pg_client::config::Session {
501            application_name: application_name.clone(),
502            database: database.clone(),
503            password: Some(password.clone()),
504            user: user.clone(),
505        },
506        ssl_mode,
507        ssl_root_cert,
508        sqlx: Default::default(),
509    };
510
511    Container {
512        host_port: port,
513        container,
514        backend,
515        client_config,
516        wait_available_timeout,
517    }
518}