// tokio-postgres-rustls 0.14.0
//
// Rustls integration for tokio-postgres.
// Documentation
use bollard::models::{ContainerCreateBody, HostConfig};
use bollard::query_parameters::{
    CreateContainerOptionsBuilder, CreateImageOptionsBuilder, InspectContainerOptionsBuilder,
    RemoveContainerOptionsBuilder, StartContainerOptionsBuilder, StopContainerOptionsBuilder,
};
use bollard::{API_DEFAULT_VERSION, Docker};
use futures_util::TryStreamExt;
use std::collections::HashMap;
use std::env::{self, current_dir};
use std::error::Error;
use std::path::{Path, PathBuf};
use tokio::time::{Duration, Instant, sleep, timeout};
use tokio_postgres::{Config, NoTls};

/// Image used for every test container; pulled on demand by `ensure_postgres_image`.
const POSTGRES_IMAGE: &str = "postgres:17";

/// Handle to a running PostgreSQL test container.
///
/// Call `cleanup` before dropping the handle; `closed` records whether that
/// happened so the `Drop` impl can warn about containers left behind.
pub(crate) struct PostgresContainer {
    /// Host port that Docker published for the container's PostgreSQL listener.
    pub port: u16,
    /// Container ID returned by Docker when the container was created.
    id: String,
    /// Docker client used to manage this container.
    docker: Docker,
    /// Set to `true` by `cleanup` once the container has been stopped and removed.
    closed: bool,
}

impl PostgresContainer {
    pub(crate) async fn new(
        test_name: &str,
        setup_script: impl AsRef<Path>,
        ca_cert: impl AsRef<Path>,
        server_cert: impl AsRef<Path>,
        server_key: impl AsRef<Path>,
    ) -> Result<Self, Box<dyn Error + Send + Sync>> {
        let docker = connect_docker()?;
        ensure_postgres_image(&docker).await?;
        let container_name = format!("test-pg-{}", test_name);
        let pwd = current_dir()?;

        let binds = vec![
            bind(
                &pwd,
                setup_script,
                "/docker-entrypoint-initdb.d/sql_setup.sh",
            ),
            bind(&pwd, ca_cert, "/etc/postgresql/certs/ca.crt"),
            bind(&pwd, server_cert, "/etc/postgresql/certs/server.crt"),
            bind(&pwd, server_key, "/etc/postgresql/certs/server.key"),
        ];

        let mut exposed: HashMap<String, HashMap<(), ()>> = HashMap::new();
        exposed.insert("5433/tcp".into(), HashMap::new());

        let host_config = HostConfig {
            binds: Some(binds),
            publish_all_ports: Some(true),
            ..Default::default()
        };

        let env = vec![
            "POSTGRES_PASSWORD=postgres".to_string(),
            "POSTGRES_USER=postgres".to_string(),
            "POSTGRES_DB=postgres".to_string(),
        ];

        let body = ContainerCreateBody {
            image: Some(POSTGRES_IMAGE.to_string()),
            env: Some(env),
            exposed_ports: Some(exposed),
            host_config: Some(host_config),
            ..Default::default()
        };

        let created = docker
            .create_container(
                Some(
                    CreateContainerOptionsBuilder::default()
                        .name(&container_name)
                        .build(),
                ),
                body,
            )
            .await?;
        let id = created.id;

        docker
            .start_container(&id, Some(StartContainerOptionsBuilder::default().build()))
            .await?;

        let inspect = docker
            .inspect_container(&id, Some(InspectContainerOptionsBuilder::default().build()))
            .await?;

        let host_port = inspect
            .network_settings
            .as_ref()
            .and_then(|ns| ns.ports.as_ref())
            .and_then(|ports| ports.get("5433/tcp"))
            .and_then(|opt| opt.as_ref())
            .and_then(|vec| vec.first())
            .and_then(|pb| pb.host_port.as_ref())
            .and_then(|hp| hp.parse::<u16>().ok())
            .ok_or("failed to resolve host port for 5433/tcp")?;

        if wait_for_pg(host_port, Duration::from_secs(30))
            .await
            .is_err()
        {
            cleanup(docker, id).await;
            panic!("postgres container startup probe failed");
        }

        Ok(Self {
            docker,
            id,
            port: host_port,
            closed: false,
        })
    }

    pub(crate) async fn cleanup(&mut self) {
        cleanup(self.docker.clone(), self.id.clone()).await;
        self.closed = true;
    }
}

impl Drop for PostgresContainer {
    /// Warns on stderr when the container was never cleaned up.
    ///
    /// `drop` cannot `.await`, so it cannot run the async Docker cleanup itself;
    /// it can only report the leak.
    fn drop(&mut self) {
        if self.closed {
            return;
        }
        eprintln!("postgres test container {} was not cleaned up", self.id);
    }
}

fn connect_docker() -> Result<Docker, bollard::errors::Error> {
    match Docker::connect_with_defaults() {
        Ok(docker) => Ok(docker),
        Err(error) if env::var_os("DOCKER_HOST").is_none() => {
            let Some(socket) = colima_socket() else {
                return Err(error);
            };
            Docker::connect_with_unix(
                socket
                    .to_str()
                    .expect("Docker socket path should be valid UTF-8"),
                120,
                API_DEFAULT_VERSION,
            )
            .or(Err(error))
        }
        Err(error) => Err(error),
    }
}

/// Returns Colima's default Docker socket path,
/// `$HOME/.colima/default/docker.sock`, if `HOME` is set and that path exists.
fn colima_socket() -> Option<PathBuf> {
    let home = env::var_os("HOME")?;
    let mut candidate = PathBuf::from(home);
    candidate.push(".colima/default/docker.sock");
    if candidate.exists() {
        Some(candidate)
    } else {
        None
    }
}

/// Makes sure `POSTGRES_IMAGE` is available locally, pulling it if needed.
///
/// Any failure to inspect the image, not only "image missing", triggers a pull.
/// The pull's progress stream is drained and discarded.
///
/// # Errors
///
/// Returns the first error reported by the pull stream.
async fn ensure_postgres_image(docker: &Docker) -> Result<(), Box<dyn Error + Send + Sync>> {
    let already_present = docker.inspect_image(POSTGRES_IMAGE).await.is_ok();
    if !already_present {
        let pull_options = CreateImageOptionsBuilder::default()
            .from_image(POSTGRES_IMAGE)
            .build();
        docker
            .create_image(Some(pull_options), None, None)
            .try_collect::<Vec<_>>()
            .await?;
    }
    Ok(())
}

/// Builds a read-only Docker bind-mount spec of the form
/// `"<host path>:<destination>:ro"`.
///
/// The host path is `root.join(source)`, so an absolute `source` is used as-is.
///
/// # Panics
///
/// Panics if the joined host path is not valid UTF-8.
fn bind(root: &Path, source: impl AsRef<Path>, destination: &str) -> String {
    let host_path = root.join(source);
    let host_path = host_path
        .to_str()
        .expect("test fixture path should be valid UTF-8");
    [host_path, destination, "ro"].join(":")
}

/// Best-effort stop-and-remove of container `id`.
///
/// The container gets 5 seconds to stop (`t`). It is then force-removed together
/// with its anonymous volumes (`v`).
async fn cleanup(docker: Docker, id: String) {
    let grace_period = StopContainerOptionsBuilder::default().t(5).build();
    let removal = RemoveContainerOptionsBuilder::default()
        .force(true)
        .v(true)
        .build();

    // Errors from both calls are ignored on purpose, because cleanup must never
    // fail a test by itself. A failed stop is also covered by `force` on removal.
    let _stopped = docker.stop_container(&id, Some(grace_period)).await;
    let _removed = docker.remove_container(&id, Some(removal)).await;
}

/// Polls PostgreSQL on `localhost:host_port` until `SELECT 1` succeeds or
/// `max_wait` has elapsed.
///
/// Each attempt works as follows:
/// - It connects without TLS as user `startup_probe` to database `postgres`,
///   with a 2-second connect timeout.
/// - It gives the query 2 seconds.
/// - It waits 500 ms before the next attempt.
///
/// The deadline is checked only after an attempt, so the call can overrun
/// `max_wait` by up to one attempt.
// NOTE(review): assumes the setup script lets `startup_probe` connect without a
// password — confirm.
///
/// # Errors
///
/// Returns `Err("postgres not ready in time")` once the deadline passes without a
/// successful query.
async fn wait_for_pg(host_port: u16, max_wait: Duration) -> Result<(), &'static str> {
    let mut probe_config = Config::new();
    probe_config
        .host("localhost")
        .port(host_port)
        .user("startup_probe")
        .dbname("postgres")
        .ssl_mode(tokio_postgres::config::SslMode::Disable)
        .connect_timeout(Duration::from_secs(2));

    let give_up_at = Instant::now() + max_wait;

    loop {
        let ready = match probe_config.connect(NoTls).await {
            Err(_) => false,
            Ok((client, connection)) => {
                // The connection future must be polled for queries to make progress.
                let driver = tokio::spawn(async move {
                    let _ = connection.await;
                });
                // A published port can accept TCP before the server is up, so only
                // a completed query counts as ready.
                let outcome =
                    timeout(Duration::from_secs(2), client.simple_query("SELECT 1")).await;
                driver.abort();
                matches!(outcome, Ok(Ok(_)))
            }
        };

        if ready {
            return Ok(());
        }
        if Instant::now() >= give_up_at {
            return Err("postgres not ready in time");
        }
        sleep(Duration::from_millis(500)).await;
    }
}