Skip to main content

pg_ephemeral/
container.rs

1use rand::RngExt;
2
3use crate::LOCALHOST_HOST_ADDR;
4use crate::LOCALHOST_IP;
5use crate::UNSPECIFIED_IP;
6use crate::certificate;
7use crate::definition;
8
/// Path of the PostgreSQL data directory inside the container.
///
/// The container's `PGDATA` environment variable is set to this same value when
/// it is started.
pub const PGDATA: &str = "/var/lib/pg-ephemeral";
10
/// Errors produced while waiting for, or interacting with, a PostgreSQL container.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// PostgreSQL did not become ready before the configured timeout elapsed.
    ///
    /// `source` holds the last connection error observed by the client-side
    /// probe; it is `None` when the readiness check records no error (e.g. the
    /// in-container `pg_isready` probe).
    #[error("PostgreSQL did not become available within {timeout:?}")]
    ConnectionTimeout {
        timeout: std::time::Duration,
        #[source]
        source: Option<sqlx::Error>,
    },
    /// A command executed inside the container failed; the underlying
    /// [`cmd_proc::CommandError`] is available as the error source.
    #[error("Failed to execute command in container")]
    ContainerExec(#[from] cmd_proc::CommandError),
    /// Wraps [`crate::definition::SeedApplyError`]; its message is shown unchanged.
    #[error(transparent)]
    SeedApply(#[from] crate::definition::SeedApplyError),
    /// Wraps [`crate::seed::LoadError`]; its message is shown unchanged.
    #[error(transparent)]
    SeedLoad(#[from] crate::seed::LoadError),
}
/// Superuser password handed to the container (set in `Container::run_definition`).
const ENV_POSTGRES_PASSWORD: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_PASSWORD");
/// Superuser name handed to the container (set in `Container::run_definition`).
const ENV_POSTGRES_USER: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("POSTGRES_USER");
/// Data directory location inside the container (set in `run_container`).
const ENV_PGDATA: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PGDATA");
/// Directory into which `SSL_SETUP_SCRIPT` writes the certificate files.
const ENV_PG_EPHEMERAL_SSL_DIR: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SSL_DIR");
/// PEM-encoded CA certificate; `SSL_SETUP_SCRIPT` writes it to `root.crt`.
const ENV_PG_EPHEMERAL_CA_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_CA_CERT_PEM");
/// PEM-encoded server certificate; `SSL_SETUP_SCRIPT` writes it to `server.crt`.
const ENV_PG_EPHEMERAL_SERVER_CERT_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_CERT_PEM");
/// PEM-encoded server private key; `SSL_SETUP_SCRIPT` writes it to `server.key`.
const ENV_PG_EPHEMERAL_SERVER_KEY_PEM: cmd_proc::EnvVariableName<'static> =
    cmd_proc::EnvVariableName::from_static_or_panic("PG_EPHEMERAL_SERVER_KEY_PEM");
40
/// Shell script used as the container entrypoint when SSL is enabled.
///
/// It writes the CA certificate, server certificate and server key from the
/// `PG_EPHEMERAL_*_PEM` environment variables into `$PG_EPHEMERAL_SSL_DIR`,
/// makes the `postgres` user own them with mode `600`, and finally `exec`s
/// `docker-entrypoint.sh` with the script's positional arguments (the
/// `postgres --ssl=on ...` command line that `run_container` appends after `--`).
const SSL_SETUP_SCRIPT: &str = r#"
printf '%s' "$PG_EPHEMERAL_CA_CERT_PEM" > ${PG_EPHEMERAL_SSL_DIR}/root.crt
printf '%s' "$PG_EPHEMERAL_SERVER_CERT_PEM" > ${PG_EPHEMERAL_SSL_DIR}/server.crt
printf '%s' "$PG_EPHEMERAL_SERVER_KEY_PEM" > ${PG_EPHEMERAL_SSL_DIR}/server.key
chown postgres ${PG_EPHEMERAL_SSL_DIR}/root.crt
chown postgres ${PG_EPHEMERAL_SSL_DIR}/server.crt
chown postgres ${PG_EPHEMERAL_SSL_DIR}/server.key
chmod 600 ${PG_EPHEMERAL_SSL_DIR}/root.crt
chmod 600 ${PG_EPHEMERAL_SSL_DIR}/server.crt
chmod 600 ${PG_EPHEMERAL_SSL_DIR}/server.key
exec docker-entrypoint.sh "$@"
"#;
53
/// Low-level container definition for running a pre-initialized PostgreSQL image.
///
/// All fields are assumed to represent values already stored in the referenced image.
/// No validation is performed - the caller is responsible for ensuring the credentials
/// and configuration match what exists in the image.
///
/// Containers started from this definition (via `Container::run_container_definition`)
/// are always started with removal enabled.
#[derive(Debug)]
pub struct Definition {
    /// Image to run.
    pub image: ociman::image::Reference,
    /// Password of `user`, used unchanged in the resulting client configuration.
    pub password: pg_client::config::Password,
    /// Role the client configuration connects as.
    pub user: pg_client::User,
    /// Database the client configuration connects to.
    pub database: pg_client::Database,
    /// Container backend used to run the image and resolve container hosts.
    pub backend: ociman::Backend,
    /// When `true`, port 5432 is published on `UNSPECIFIED_IP` instead of
    /// `LOCALHOST_IP` (presumably so other containers can reach it).
    pub cross_container_access: bool,
    /// Optional `application_name` placed in the client session settings.
    pub application_name: Option<pg_client::config::ApplicationName>,
    /// When set, PostgreSQL is started with `--ssl=on` using a freshly generated
    /// certificate bundle and clients connect with `SslMode::VerifyFull`.
    pub ssl_config: Option<definition::SslConfig>,
    /// Upper bound for readiness polling (`Container::wait_available` and the
    /// in-container probe used by `Container::set_superuser_password`).
    pub wait_available_timeout: std::time::Duration,
}
71
/// A running PostgreSQL container plus the configuration needed to reach it.
#[derive(Debug)]
pub struct Container {
    /// Host-side port that container port 5432 is published on.
    host_port: pg_client::config::Port,
    /// Client configuration for connecting from the host.
    pub(crate) client_config: pg_client::Config,
    /// Handle to the underlying container, used for `exec`, stop and commit.
    container: ociman::Container,
    /// Backend the container runs on; used to resolve the container host address.
    backend: ociman::Backend,
    /// Upper bound for readiness polling.
    wait_available_timeout: std::time::Duration,
}
80
81impl Container {
82    pub(crate) async fn run_definition(definition: &crate::definition::Definition) -> Self {
83        let password = generate_password();
84
85        let ociman_definition = definition
86            .to_ociman_definition()
87            .environment_variable(ENV_POSTGRES_PASSWORD, password.as_ref())
88            .environment_variable(ENV_POSTGRES_USER, definition.superuser.as_ref());
89
90        run_container(
91            ociman_definition,
92            definition.cross_container_access,
93            &definition.ssl_config,
94            &definition.backend,
95            &definition.application_name,
96            &definition.database,
97            &password,
98            &definition.superuser,
99            definition.wait_available_timeout,
100            definition.remove,
101        )
102        .await
103    }
104
105    pub async fn run_container_definition(definition: &Definition) -> Self {
106        let ociman_definition =
107            ociman::Definition::new(definition.backend.clone(), definition.image.clone());
108
109        run_container(
110            ociman_definition,
111            definition.cross_container_access,
112            &definition.ssl_config,
113            &definition.backend,
114            &definition.application_name,
115            &definition.database,
116            &definition.password,
117            &definition.user,
118            definition.wait_available_timeout,
119            true, // Always remove containers when using low-level API
120        )
121        .await
122    }
123
124    pub async fn wait_available(&self) -> Result<(), Error> {
125        let config = self.client_config.to_sqlx_connect_options().unwrap();
126
127        let start = std::time::Instant::now();
128        let max_duration = self.wait_available_timeout;
129        let sleep_duration = std::time::Duration::from_millis(100);
130
131        let mut last_error: Option<sqlx::Error> = None;
132
133        while start.elapsed() <= max_duration {
134            log::trace!("connection attempt");
135            match sqlx::ConnectOptions::connect(&config).await {
136                Ok(connection) => {
137                    sqlx::Connection::close(connection)
138                        .await
139                        .expect("connection close failed");
140
141                    log::debug!(
142                        "pg is available on endpoint: {:#?}",
143                        self.client_config.endpoint
144                    );
145
146                    return Ok(());
147                }
148                Err(error) => {
149                    log::trace!("{error:#?}, retry in 100ms");
150                    last_error = Some(error);
151                }
152            }
153            tokio::time::sleep(sleep_duration).await;
154        }
155
156        Err(Error::ConnectionTimeout {
157            timeout: max_duration,
158            source: last_error,
159        })
160    }
161
162    pub async fn exec_schema_dump(&self, pg_schema_dump: &pg_client::PgSchemaDump) -> String {
163        let output = self
164            .container
165            .exec("pg_dump")
166            .arguments(pg_schema_dump.arguments())
167            .environment_variables(self.container_client_config().to_pg_env())
168            .build()
169            .stdout_capture()
170            .bytes()
171            .await
172            .unwrap();
173        crate::convert_schema(&output)
174    }
175
176    #[must_use]
177    pub fn client_config(&self) -> &pg_client::Config {
178        &self.client_config
179    }
180
181    pub async fn with_connection<T, F: AsyncFnMut(&mut sqlx::postgres::PgConnection) -> T>(
182        &self,
183        mut action: F,
184    ) -> T {
185        self.client_config
186            .with_sqlx_connection(async |connection| action(connection).await)
187            .await
188            .unwrap()
189    }
190
191    pub async fn apply_sql(&self, sql: &str) -> Result<(), sqlx::Error> {
192        self.with_connection(async |connection| {
193            log::debug!("Executing: {sql}");
194            sqlx::raw_sql(sqlx::AssertSqlSafe(sql))
195                .execute(connection)
196                .await
197                .map(|_| ())
198        })
199        .await
200    }
201
202    pub async fn apply_csv(
203        &self,
204        table: &pg_client::QualifiedTable,
205        content: &str,
206    ) -> Result<(), sqlx::Error> {
207        self.with_connection(async |connection| {
208            let row = sqlx::query(r#"SELECT format('%I.%I', $1, $2) AS table_identifier"#)
209                .bind(table.schema.as_ref())
210                .bind(table.table.as_ref())
211                .fetch_one(&mut *connection)
212                .await?;
213            let table_identifier: String = sqlx::Row::get(&row, "table_identifier");
214
215            let statement = format!("COPY {table_identifier} FROM STDIN WITH (FORMAT csv, HEADER)");
216            log::debug!("Executing: {statement}");
217            let mut copy = connection.copy_in_raw(&statement).await?;
218            copy.send(content.as_bytes()).await?;
219            copy.finish().await?;
220            Ok(())
221        })
222        .await
223    }
224
225    pub(crate) async fn exec_container_script(
226        &self,
227        script: &str,
228    ) -> Result<(), cmd_proc::CommandError> {
229        self.container
230            .exec("sh")
231            .arguments(["-e", "-c", script])
232            .build()
233            .status()
234            .await
235    }
236
237    pub(crate) async fn exec_container_shell(&self) {
238        self.container
239            .exec("sh")
240            .environment_variables(self.container_client_config().to_pg_env())
241            .interactive()
242            .status()
243            .await
244            .unwrap();
245    }
246
247    pub(crate) async fn exec_psql(&self) {
248        self.container
249            .exec("psql")
250            .environment_variables(self.container_client_config().to_pg_env())
251            .interactive()
252            .status()
253            .await
254            .unwrap();
255    }
256
257    fn container_client_config(&self) -> pg_client::Config {
258        let mut config = self.client_config.clone();
259        if let pg_client::config::Endpoint::Network {
260            ref host,
261            ref channel_binding,
262            ref host_addr,
263            ..
264        } = config.endpoint
265        {
266            config.endpoint = pg_client::config::Endpoint::Network {
267                host: host.clone(),
268                channel_binding: *channel_binding,
269                host_addr: host_addr.clone(),
270                port: Some(pg_client::config::Port::new(5432)),
271            };
272        }
273        config
274    }
275
276    pub async fn cross_container_client_config(&self) -> pg_client::Config {
277        // Resolve the container host from inside a container
278        // This DNS name only works from inside containers, not from the host
279        let ip_address = self
280            .backend
281            .resolve_container_host()
282            .await
283            .expect("Failed to resolve container host from container");
284
285        let channel_binding = match &self.client_config.endpoint {
286            pg_client::config::Endpoint::Network {
287                channel_binding, ..
288            } => *channel_binding,
289            pg_client::config::Endpoint::SocketPath(_) => None,
290        };
291
292        let endpoint = pg_client::config::Endpoint::Network {
293            host: pg_client::config::Host::IpAddr(ip_address),
294            channel_binding,
295            host_addr: None,
296            port: Some(self.host_port),
297        };
298
299        self.client_config.clone().endpoint(endpoint)
300    }
301
302    #[must_use]
303    pub fn pg_env(&self) -> std::collections::BTreeMap<cmd_proc::EnvVariableName<'static>, String> {
304        self.client_config.to_pg_env()
305    }
306
307    #[must_use]
308    pub fn database_url(&self) -> String {
309        self.client_config.to_url_string()
310    }
311
312    pub async fn stop(&mut self) {
313        self.container.stop().await
314    }
315
316    /// Stop the container (clean PostgreSQL shutdown), commit it to an image,
317    /// and remove the stopped container.
318    pub(crate) async fn stop_commit_remove(&mut self, reference: &ociman::Reference) {
319        self.container.stop().await;
320        self.container.commit(reference, false).await.unwrap();
321        self.container.remove().await;
322    }
323
324    async fn wait_for_container_socket(&self) -> Result<(), Error> {
325        let start = std::time::Instant::now();
326        let max_duration = self.wait_available_timeout;
327        let sleep_duration = std::time::Duration::from_millis(100);
328
329        while start.elapsed() <= max_duration {
330            if self
331                .container
332                .exec("pg_isready")
333                .argument("--host")
334                .argument("localhost")
335                .build()
336                .stdout_capture()
337                .bytes()
338                .await
339                .is_ok()
340            {
341                return Ok(());
342            }
343            tokio::time::sleep(sleep_duration).await;
344        }
345
346        Err(Error::ConnectionTimeout {
347            timeout: max_duration,
348            source: None,
349        })
350    }
351
352    /// Set the superuser password using peer authentication via Unix domain socket.
353    ///
354    /// This is useful when resuming from a cached image where the password
355    /// doesn't match the newly generated one.
356    pub async fn set_superuser_password(
357        &self,
358        password: &pg_client::config::Password,
359    ) -> Result<(), Error> {
360        self.wait_for_container_socket().await?;
361
362        self.container
363            .exec("psql")
364            .argument("--host")
365            .argument("/var/run/postgresql")
366            .argument("--username")
367            .argument(self.client_config.session.user.as_ref())
368            .argument("--dbname")
369            .argument("postgres")
370            .argument("--variable")
371            .argument(format!(
372                "target_user={}",
373                self.client_config.session.user.as_ref()
374            ))
375            .argument("--variable")
376            .argument(format!("new_password={}", password.as_ref()))
377            .stdin("ALTER USER :target_user WITH PASSWORD :'new_password'")
378            .build()
379            .stdout_capture()
380            .bytes()
381            .await?;
382
383        Ok(())
384    }
385}
386
387fn generate_password() -> pg_client::config::Password {
388    let value: String = rand::rng()
389        .sample_iter(rand::distr::Alphanumeric)
390        .take(32)
391        .map(char::from)
392        .collect();
393
394    <pg_client::config::Password as std::str::FromStr>::from_str(&value).unwrap()
395}
396
397#[allow(clippy::too_many_arguments)]
398async fn run_container(
399    ociman_definition: ociman::Definition,
400    cross_container_access: bool,
401    ssl_config: &Option<definition::SslConfig>,
402    backend: &ociman::Backend,
403    application_name: &Option<pg_client::config::ApplicationName>,
404    database: &pg_client::Database,
405    password: &pg_client::config::Password,
406    user: &pg_client::User,
407    wait_available_timeout: std::time::Duration,
408    remove: bool,
409) -> Container {
410    let backend = backend.clone();
411    let host_ip = if cross_container_access {
412        UNSPECIFIED_IP
413    } else {
414        LOCALHOST_IP
415    };
416
417    let mut ociman_definition = ociman_definition
418        .environment_variable(ENV_PGDATA, "/var/lib/pg-ephemeral")
419        .publish(ociman::Publish::tcp(5432).host_ip(host_ip));
420
421    if remove {
422        ociman_definition = ociman_definition.remove();
423    }
424
425    let ssl_bundle = if let Some(ssl_config) = ssl_config {
426        let hostname = match ssl_config {
427            definition::SslConfig::Generated { hostname } => hostname.as_str(),
428        };
429
430        let bundle = certificate::Bundle::generate(hostname)
431            .expect("Failed to generate SSL certificate bundle");
432
433        let ssl_dir = "/var/lib/postgresql";
434
435        ociman_definition = ociman_definition
436            .entrypoint("sh")
437            .argument("-e")
438            .argument("-c")
439            .argument(SSL_SETUP_SCRIPT)
440            .argument("--")
441            .argument("postgres")
442            .argument("--ssl=on")
443            .argument(format!("--ssl_cert_file={ssl_dir}/server.crt"))
444            .argument(format!("--ssl_key_file={ssl_dir}/server.key"))
445            .argument(format!("--ssl_ca_file={ssl_dir}/root.crt"))
446            .environment_variable(ENV_PG_EPHEMERAL_SSL_DIR, ssl_dir)
447            .environment_variable(ENV_PG_EPHEMERAL_CA_CERT_PEM, &bundle.ca_cert_pem)
448            .environment_variable(ENV_PG_EPHEMERAL_SERVER_CERT_PEM, &bundle.server_cert_pem)
449            .environment_variable(ENV_PG_EPHEMERAL_SERVER_KEY_PEM, &bundle.server_key_pem);
450
451        Some(bundle)
452    } else {
453        None
454    };
455
456    let container = ociman_definition.run_detached().await;
457
458    let port: pg_client::config::Port = container
459        .read_host_tcp_port(5432)
460        .await
461        .expect("port 5432 not published")
462        .into();
463
464    let (host, host_addr, ssl_mode, ssl_root_cert) = if let Some(ssl_config) = ssl_config {
465        let hostname = match ssl_config {
466            definition::SslConfig::Generated { hostname } => hostname.clone(),
467        };
468
469        let timestamp = std::time::SystemTime::now()
470            .duration_since(std::time::UNIX_EPOCH)
471            .unwrap()
472            .as_nanos();
473        let ca_cert_path = std::env::temp_dir().join(format!("pg_ephemeral_ca_{timestamp}.crt"));
474        std::fs::write(&ca_cert_path, &ssl_bundle.as_ref().unwrap().ca_cert_pem)
475            .expect("Failed to write CA certificate to temp file");
476
477        (
478            pg_client::config::Host::HostName(hostname),
479            Some(LOCALHOST_HOST_ADDR),
480            pg_client::config::SslMode::VerifyFull,
481            Some(pg_client::config::SslRootCert::File(ca_cert_path)),
482        )
483    } else {
484        (
485            pg_client::config::Host::IpAddr(LOCALHOST_IP),
486            None,
487            pg_client::config::SslMode::Disable,
488            None,
489        )
490    };
491
492    let client_config = pg_client::Config {
493        endpoint: pg_client::config::Endpoint::Network {
494            host,
495            channel_binding: None,
496            host_addr,
497            port: Some(port),
498        },
499        session: pg_client::config::Session {
500            application_name: application_name.clone(),
501            database: database.clone(),
502            password: Some(password.clone()),
503            user: user.clone(),
504        },
505        ssl_mode,
506        ssl_root_cert,
507        sqlx: Default::default(),
508    };
509
510    Container {
511        host_port: port,
512        container,
513        backend,
514        client_config,
515        wait_available_timeout,
516    }
517}