Skip to main content

pg_ephemeral/
container.rs

1use rand::RngExt;
2
3use crate::LOCALHOST_HOST_ADDR;
4use crate::LOCALHOST_IP;
5use crate::UNSPECIFIED_IP;
6use crate::certificate;
7use crate::definition;
8
/// Directory inside the container that is exported to PostgreSQL as `PGDATA`.
pub const PGDATA: &str = "/var/lib/pg-ephemeral";
10
11#[derive(Debug, thiserror::Error)]
12pub enum Error {
13    #[error("PostgreSQL did not become available within {timeout:?}")]
14    ConnectionTimeout {
15        timeout: std::time::Duration,
16        #[source]
17        source: Option<sqlx::Error>,
18    },
19    #[error("Failed to execute command in container")]
20    ContainerExec(#[from] cmd_proc::CommandError),
21    #[error(transparent)]
22    SeedApply(#[from] crate::definition::SeedApplyError),
23    #[error(transparent)]
24    SeedLoad(#[from] crate::seed::LoadError),
25    #[error("Failed to terminate backend connections")]
26    TerminateConnections(#[source] sqlx::Error),
27    #[error("Failed to checkpoint")]
28    Checkpoint(#[source] sqlx::Error),
29}
30const ENV_POSTGRES_PASSWORD: cmd_proc::EnvVariableName<'static> =
31    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_PASSWORD");
32const ENV_POSTGRES_USER: cmd_proc::EnvVariableName<'static> =
33    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_USER");
34const ENV_PGDATA: cmd_proc::EnvVariableName<'static> =
35    cmd_proc::EnvVariableName::from_static_or_panic("PGDATA");
36const ENV_PG_EPHEMERAL_SSL_DIR: cmd_proc::EnvVariableName<'static> =
37    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SSL_DIR");
38const ENV_PG_EPHEMERAL_CA_CERT_PEM: cmd_proc::EnvVariableName<'static> =
39    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_CA_CERT_PEM");
40const ENV_PG_EPHEMERAL_SERVER_CERT_PEM: cmd_proc::EnvVariableName<'static> =
41    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_CERT_PEM");
42const ENV_PG_EPHEMERAL_SERVER_KEY_PEM: cmd_proc::EnvVariableName<'static> =
43    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_KEY_PEM");
44
/// Shell script used as the container entrypoint wrapper when SSL is enabled.
///
/// It writes the CA certificate, server certificate and server key from the
/// `PG_EPHEMERAL_*_PEM` environment variables into `PG_EPHEMERAL_SSL_DIR`,
/// hands the files to the `postgres` user with mode `600`, and then replaces
/// itself with `docker-entrypoint.sh`, forwarding its positional arguments.
///
/// Every expansion of the target directory is double-quoted so the path is
/// never subject to word splitting or glob expansion.
const SSL_SETUP_SCRIPT: &str = r#"
printf '%s' "$PG_EPHEMERAL_CA_CERT_PEM" > "${PG_EPHEMERAL_SSL_DIR}/root.crt"
printf '%s' "$PG_EPHEMERAL_SERVER_CERT_PEM" > "${PG_EPHEMERAL_SSL_DIR}/server.crt"
printf '%s' "$PG_EPHEMERAL_SERVER_KEY_PEM" > "${PG_EPHEMERAL_SSL_DIR}/server.key"
chown postgres "${PG_EPHEMERAL_SSL_DIR}/root.crt"
chown postgres "${PG_EPHEMERAL_SSL_DIR}/server.crt"
chown postgres "${PG_EPHEMERAL_SSL_DIR}/server.key"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/root.crt"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/server.crt"
chmod 600 "${PG_EPHEMERAL_SSL_DIR}/server.key"
exec docker-entrypoint.sh "$@"
"#;
57
58/// Low-level container definition for running a pre-initialized PostgreSQL image.
59///
60/// All fields are assumed to represent values already stored in the referenced image.
61/// No validation is performed - the caller is responsible for ensuring the credentials
62/// and configuration match what exists in the image.
63#[derive(Debug)]
64pub struct Definition {
65    pub image: ociman::image::Reference,
66    pub password: pg_client::config::Password,
67    pub user: pg_client::User,
68    pub database: pg_client::Database,
69    pub backend: ociman::Backend,
70    pub cross_container_access: bool,
71    pub application_name: Option<pg_client::config::ApplicationName>,
72    pub ssl_config: Option<definition::SslConfig>,
73    pub wait_available_timeout: std::time::Duration,
74}
75
76#[derive(Debug)]
77pub struct Container {
78    host_port: pg_client::config::Port,
79    pub(crate) client_config: pg_client::Config,
80    container: ociman::Container,
81    backend: ociman::Backend,
82    wait_available_timeout: std::time::Duration,
83}
84
85impl Container {
86    pub(crate) async fn run_definition(definition: &crate::definition::Definition) -> Self {
87        let password = generate_password();
88
89        let ociman_definition = definition
90            .to_ociman_definition()
91            .environment_variable(ENV_POSTGRES_PASSWORD, password.as_ref())
92            .environment_variable(ENV_POSTGRES_USER, definition.superuser.as_ref());
93
94        run_container(
95            ociman_definition,
96            definition.cross_container_access,
97            &definition.ssl_config,
98            &definition.backend,
99            &definition.application_name,
100            &definition.database,
101            &password,
102            &definition.superuser,
103            definition.wait_available_timeout,
104            definition.remove,
105        )
106        .await
107    }
108
109    pub async fn run_container_definition(definition: &Definition) -> Self {
110        let ociman_definition =
111            ociman::Definition::new(definition.backend.clone(), definition.image.clone());
112
113        run_container(
114            ociman_definition,
115            definition.cross_container_access,
116            &definition.ssl_config,
117            &definition.backend,
118            &definition.application_name,
119            &definition.database,
120            &definition.password,
121            &definition.user,
122            definition.wait_available_timeout,
123            true, // Always remove containers when using low-level API
124        )
125        .await
126    }
127
128    pub async fn wait_available(&self) -> Result<(), Error> {
129        let config = self.client_config.to_sqlx_connect_options().unwrap();
130
131        let start = std::time::Instant::now();
132        let max_duration = self.wait_available_timeout;
133        let sleep_duration = std::time::Duration::from_millis(100);
134
135        let mut last_error: Option<sqlx::Error> = None;
136
137        while start.elapsed() <= max_duration {
138            log::trace!("connection attempt");
139            match sqlx::ConnectOptions::connect(&config).await {
140                Ok(connection) => {
141                    sqlx::Connection::close(connection)
142                        .await
143                        .expect("connection close failed");
144
145                    log::debug!(
146                        "pg is available on endpoint: {:#?}",
147                        self.client_config.endpoint
148                    );
149
150                    return Ok(());
151                }
152                Err(error) => {
153                    log::trace!("{error:#?}, retry in 100ms");
154                    last_error = Some(error);
155                }
156            }
157            tokio::time::sleep(sleep_duration).await;
158        }
159
160        Err(Error::ConnectionTimeout {
161            timeout: max_duration,
162            source: last_error,
163        })
164    }
165
166    pub async fn exec_schema_dump(&self, pg_schema_dump: &pg_client::PgSchemaDump) -> String {
167        let output = self
168            .container
169            .exec("pg_dump")
170            .arguments(pg_schema_dump.arguments())
171            .environment_variables(self.container_client_config().to_pg_env())
172            .build()
173            .stdout_capture()
174            .bytes()
175            .await
176            .unwrap();
177        crate::convert_schema(&output)
178    }
179
180    #[must_use]
181    pub fn client_config(&self) -> &pg_client::Config {
182        &self.client_config
183    }
184
185    pub async fn with_connection<T, F: AsyncFnMut(&mut sqlx::postgres::PgConnection) -> T>(
186        &self,
187        mut action: F,
188    ) -> T {
189        self.client_config
190            .with_sqlx_connection(async |connection| action(connection).await)
191            .await
192            .unwrap()
193    }
194
195    pub async fn apply_sql(&self, sql: &str) -> Result<(), sqlx::Error> {
196        self.with_connection(async |connection| {
197            log::debug!("Executing: {sql}");
198            sqlx::raw_sql(sqlx::AssertSqlSafe(sql))
199                .execute(connection)
200                .await
201                .map(|_| ())
202        })
203        .await
204    }
205
206    /// Apply CSV content to a table using PostgreSQL's COPY protocol.
207    ///
208    /// The line delimiter is hardcoded to `\n`.
209    pub async fn apply_csv(
210        &self,
211        table: &pg_client::QualifiedTable,
212        delimiter: char,
213        content: &str,
214    ) -> Result<(), sqlx::Error> {
215        self.with_connection(async |connection| {
216            let header_line = content.lines().next().unwrap_or_default();
217
218            let columns: Vec<&str> = header_line.split(delimiter).map(str::trim).collect();
219
220            let row = sqlx::query(
221                r#"SELECT 'COPY ' || format('%I.%I', $1, $2)
222                    || '(' || (SELECT string_agg(format('%I', "column"), ', ') FROM unnest($3::text[]) AS "column") || ')'
223                    || ' FROM STDIN WITH (FORMAT csv, HEADER MATCH, DELIMITER ' || quote_literal($4) || ')'
224                    AS statement"#,
225            )
226            .bind(table.schema.as_ref())
227            .bind(table.table.as_ref())
228            .bind(&columns)
229            .bind(delimiter.to_string())
230            .fetch_one(&mut *connection)
231            .await?;
232            let statement: String = sqlx::Row::get(&row, "statement");
233
234            log::debug!("Executing: {statement}");
235            let mut copy = connection.copy_in_raw(&statement).await?;
236            copy.send(content.as_bytes()).await?;
237            copy.finish().await?;
238            Ok(())
239        })
240        .await
241    }
242
243    pub(crate) async fn exec_container_script(
244        &self,
245        script: &str,
246    ) -> Result<(), cmd_proc::CommandError> {
247        self.container
248            .exec("sh")
249            .arguments(["-e", "-c", script])
250            .build()
251            .status()
252            .await
253    }
254
255    pub(crate) async fn exec_container_shell(&self) {
256        self.container
257            .exec("sh")
258            .environment_variables(self.container_client_config().to_pg_env())
259            .interactive()
260            .status()
261            .await
262            .unwrap();
263    }
264
265    pub(crate) async fn exec_psql(&self) {
266        self.container
267            .exec("psql")
268            .environment_variables(self.container_client_config().to_pg_env())
269            .interactive()
270            .status()
271            .await
272            .unwrap();
273    }
274
275    fn container_client_config(&self) -> pg_client::Config {
276        let mut config = self.client_config.clone();
277        if let pg_client::config::Endpoint::Network {
278            ref host,
279            ref channel_binding,
280            ref host_addr,
281            ..
282        } = config.endpoint
283        {
284            config.endpoint = pg_client::config::Endpoint::Network {
285                host: host.clone(),
286                channel_binding: *channel_binding,
287                host_addr: host_addr.clone(),
288                port: Some(pg_client::config::Port::new(5432)),
289            };
290        }
291        config
292    }
293
294    pub async fn cross_container_client_config(&self) -> pg_client::Config {
295        // Resolve the container host from inside a container
296        // This DNS name only works from inside containers, not from the host
297        let ip_address = self
298            .backend
299            .resolve_container_host()
300            .await
301            .expect("Failed to resolve container host from container");
302
303        let channel_binding = match &self.client_config.endpoint {
304            pg_client::config::Endpoint::Network {
305                channel_binding, ..
306            } => *channel_binding,
307            pg_client::config::Endpoint::SocketPath(_) => None,
308        };
309
310        let endpoint = pg_client::config::Endpoint::Network {
311            host: pg_client::config::Host::IpAddr(ip_address),
312            channel_binding,
313            host_addr: None,
314            port: Some(self.host_port),
315        };
316
317        self.client_config.clone().endpoint(endpoint)
318    }
319
320    #[must_use]
321    pub fn pg_env(&self) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
322        self.client_config.to_pg_env()
323    }
324
325    #[must_use]
326    pub fn database_url(&self) -> String {
327        self.client_config.to_url_string()
328    }
329
330    pub async fn stop(&mut self) {
331        self.container.stop().await
332    }
333
334    async fn terminate_connections(&self) -> Result<(), Error> {
335        self.apply_sql(
336            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE pid <> pg_backend_pid()",
337        )
338        .await
339        .map_err(Error::TerminateConnections)
340    }
341
342    async fn checkpoint(&self) -> Result<(), Error> {
343        self.apply_sql("CHECKPOINT")
344            .await
345            .map_err(Error::Checkpoint)
346    }
347
348    /// Stop the container (clean PostgreSQL shutdown), commit it to an image,
349    /// and remove the stopped container.
350    pub(crate) async fn stop_commit_remove(
351        &mut self,
352        reference: &ociman::Reference,
353    ) -> Result<(), Error> {
354        self.terminate_connections().await?;
355        self.checkpoint().await?;
356        self.container.stop().await;
357        self.container.commit(reference, false).await.unwrap();
358        self.container.remove().await;
359        Ok(())
360    }
361
362    async fn wait_for_container_socket(&self) -> Result<(), Error> {
363        let start = std::time::Instant::now();
364        let max_duration = self.wait_available_timeout;
365        let sleep_duration = std::time::Duration::from_millis(100);
366
367        while start.elapsed() <= max_duration {
368            if self
369                .container
370                .exec("pg_isready")
371                .argument("--host")
372                .argument("localhost")
373                .build()
374                .stdout_capture()
375                .bytes()
376                .await
377                .is_ok()
378            {
379                return Ok(());
380            }
381            tokio::time::sleep(sleep_duration).await;
382        }
383
384        Err(Error::ConnectionTimeout {
385            timeout: max_duration,
386            source: None,
387        })
388    }
389
390    /// Set the superuser password using peer authentication via Unix domain socket.
391    ///
392    /// This is useful when resuming from a cached image where the password
393    /// doesn't match the newly generated one.
394    pub async fn set_superuser_password(
395        &self,
396        password: &pg_client::config::Password,
397    ) -> Result<(), Error> {
398        self.wait_for_container_socket().await?;
399
400        self.container
401            .exec("psql")
402            .argument("--host")
403            .argument("/var/run/postgresql")
404            .argument("--username")
405            .argument(self.client_config.session.user.as_ref())
406            .argument("--dbname")
407            .argument("postgres")
408            .argument("--variable")
409            .argument(format!(
410                "target_user={}",
411                self.client_config.session.user.as_ref()
412            ))
413            .argument("--variable")
414            .argument(format!("new_password={}", password.as_ref()))
415            .stdin("ALTER USER :target_user WITH PASSWORD :'new_password'")
416            .build()
417            .stdout_capture()
418            .bytes()
419            .await?;
420
421        Ok(())
422    }
423}
424
425fn generate_password() -> pg_client::config::Password {
426    let value: String = rand::rng()
427        .sample_iter(rand::distr::Alphanumeric)
428        .take(32)
429        .map(char::from)
430        .collect();
431
432    <pg_client::config::Password as std::str::FromStr>::from_str(&value).unwrap()
433}
434
435#[allow(clippy::too_many_arguments)]
436async fn run_container(
437    ociman_definition: ociman::Definition,
438    cross_container_access: bool,
439    ssl_config: &Option<definition::SslConfig>,
440    backend: &ociman::Backend,
441    application_name: &Option<pg_client::config::ApplicationName>,
442    database: &pg_client::Database,
443    password: &pg_client::config::Password,
444    user: &pg_client::User,
445    wait_available_timeout: std::time::Duration,
446    remove: bool,
447) -> Container {
448    let backend = backend.clone();
449    let host_ip = if cross_container_access {
450        UNSPECIFIED_IP
451    } else {
452        LOCALHOST_IP
453    };
454
455    let mut ociman_definition = ociman_definition
456        .environment_variable(ENV_PGDATA, "/var/lib/pg-ephemeral")
457        .publish(ociman::Publish::tcp(5432).host_ip(host_ip));
458
459    if remove {
460        ociman_definition = ociman_definition.remove();
461    }
462
463    let ssl_bundle = if let Some(ssl_config) = ssl_config {
464        let hostname = match ssl_config {
465            definition::SslConfig::Generated { hostname } => hostname.as_str(),
466        };
467
468        let bundle = certificate::Bundle::generate(hostname)
469            .expect("Failed to generate SSL certificate bundle");
470
471        let ssl_dir = "/var/lib/postgresql";
472
473        ociman_definition = ociman_definition
474            .entrypoint("sh")
475            .argument("-e")
476            .argument("-c")
477            .argument(SSL_SETUP_SCRIPT)
478            .argument("--")
479            .argument("postgres")
480            .argument("--ssl=on")
481            .argument(format!("--ssl_cert_file={ssl_dir}/server.crt"))
482            .argument(format!("--ssl_key_file={ssl_dir}/server.key"))
483            .argument(format!("--ssl_ca_file={ssl_dir}/root.crt"))
484            .environment_variable(ENV_PG_EPHEMERAL_SSL_DIR, ssl_dir)
485            .environment_variable(ENV_PG_EPHEMERAL_CA_CERT_PEM, &bundle.ca_cert_pem)
486            .environment_variable(ENV_PG_EPHEMERAL_SERVER_CERT_PEM, &bundle.server_cert_pem)
487            .environment_variable(ENV_PG_EPHEMERAL_SERVER_KEY_PEM, &bundle.server_key_pem);
488
489        Some(bundle)
490    } else {
491        None
492    };
493
494    let container = ociman_definition.run_detached().await;
495
496    let port: pg_client::config::Port = container
497        .read_host_tcp_port(5432)
498        .await
499        .expect("port 5432 not published")
500        .into();
501
502    let (host, host_addr, ssl_mode, ssl_root_cert) = if let Some(ssl_config) = ssl_config {
503        let hostname = match ssl_config {
504            definition::SslConfig::Generated { hostname } => hostname.clone(),
505        };
506
507        let timestamp = std::time::SystemTime::now()
508            .duration_since(std::time::UNIX_EPOCH)
509            .unwrap()
510            .as_nanos();
511        let ca_cert_path = std::env::temp_dir().join(format!("pg_ephemeral_ca_{timestamp}.crt"));
512        std::fs::write(&ca_cert_path, &ssl_bundle.as_ref().unwrap().ca_cert_pem)
513            .expect("Failed to write CA certificate to temp file");
514
515        (
516            pg_client::config::Host::HostName(hostname),
517            Some(LOCALHOST_HOST_ADDR),
518            pg_client::config::SslMode::VerifyFull,
519            Some(pg_client::config::SslRootCert::File(ca_cert_path)),
520        )
521    } else {
522        (
523            pg_client::config::Host::IpAddr(LOCALHOST_IP),
524            None,
525            pg_client::config::SslMode::Disable,
526            None,
527        )
528    };
529
530    let client_config = pg_client::Config {
531        endpoint: pg_client::config::Endpoint::Network {
532            host,
533            channel_binding: None,
534            host_addr,
535            port: Some(port),
536        },
537        session: pg_client::config::Session {
538            application_name: application_name.clone(),
539            database: database.clone(),
540            password: Some(password.clone()),
541            user: user.clone(),
542        },
543        ssl_mode,
544        ssl_root_cert,
545        sqlx: Default::default(),
546    };
547
548    Container {
549        host_port: port,
550        container,
551        backend,
552        client_config,
553        wait_available_timeout,
554    }
555}