// athena_rs 3.3.0
//
// Database gateway API
// Documentation
#[cfg(feature = "deadpool_experimental")]
use deadpool_postgres::{ManagerConfig, Pool, RecyclingMethod, Runtime};

#[cfg(feature = "deadpool_experimental")]
use std::collections::HashMap;
#[cfg(feature = "deadpool_experimental")]
use std::sync::RwLock;
#[cfg(feature = "deadpool_experimental")]
use std::time::Duration;

#[cfg(feature = "deadpool_experimental")]
use tokio::time::timeout;
#[cfg(feature = "deadpool_experimental")]
use tokio_postgres::NoTls;

/// Registration record for one client known to the deadpool registry.
#[cfg(feature = "deadpool_experimental")]
#[derive(Debug, Clone)]
pub struct DeadpoolRegisteredClient {
    /// Client name exactly as it was supplied at registration (not normalized).
    pub client_name: String,
    /// Whether the warmup connection attempt made at registration succeeded.
    /// Also `false` when the pool itself could not be created.
    pub pool_connected: bool,
}

/// Deadpool-backed registry keyed by Athena client name.
///
/// Both maps use the normalized client name (trimmed and lowercased) as
/// their key, so lookups are insensitive to case and surrounding whitespace.
#[cfg(feature = "deadpool_experimental")]
pub struct DeadpoolPostgresRegistry {
    /// Connection pools by normalized client name. A client whose pool could
    /// not be created has no entry here.
    pools: RwLock<HashMap<String, Pool>>,
    /// Registration records by normalized client name, including clients
    /// whose pool creation or warmup failed.
    clients: RwLock<HashMap<String, DeadpoolRegisteredClient>>,
}

#[cfg(feature = "deadpool_experimental")]
impl DeadpoolPostgresRegistry {
    /// Creates a registry with no pools and no registered clients.
    pub fn empty() -> Self {
        Self {
            pools: RwLock::new(HashMap::new()),
            clients: RwLock::new(HashMap::new()),
        }
    }

    /// Returns a clone of the pool registered under `key`.
    ///
    /// `key` is normalized (trimmed, lowercased) before the lookup. Returns
    /// `None` when no pool is registered for the key or when the pool map's
    /// lock is poisoned.
    pub fn get_pool(&self, key: &str) -> Option<Pool> {
        // Normalize first so the read guard is held only for the map lookup.
        let normalized = normalize_client_key(key);
        let pools = self.pools.read().ok()?;
        pools.get(&normalized).cloned()
    }

    /// Lists the original names of all clients currently marked as connected.
    ///
    /// Names are sorted by their lowercased form. Returns an empty list when
    /// the client map's lock is poisoned.
    pub fn list_clients(&self) -> Vec<String> {
        let Ok(clients) = self.clients.read() else {
            return Vec::new();
        };
        let mut names = clients
            .values()
            .filter(|client| client.pool_connected)
            .map(|client| client.client_name.clone())
            .collect::<Vec<_>>();
        names.sort_by_cached_key(|value| value.to_lowercase());
        names
    }

    /// Builds a registry from `(client_name, postgres_uri)` pairs.
    ///
    /// For every entry a pool limited to `max_size` connections is created and
    /// one warmup connection is attempted, bounded by `warmup_timeout`.
    ///
    /// * If the pool cannot be created, the client is recorded with
    ///   `pool_connected: false` and no pool is stored for it.
    /// * If the warmup fails or times out, the pool is still stored but the
    ///   client is recorded with `pool_connected: false`.
    ///
    /// Entries are processed one after another, so each unreachable database
    /// can delay construction by up to `warmup_timeout`. When several entries
    /// normalize to the same key, the last one wins for both maps.
    pub async fn from_entries(
        entries: Vec<(String, String)>,
        max_size: usize,
        warmup_timeout: Duration,
    ) -> Self {
        let mut pools: HashMap<String, Pool> = HashMap::with_capacity(entries.len());
        let mut clients: HashMap<String, DeadpoolRegisteredClient> =
            HashMap::with_capacity(entries.len());

        for (client_name, uri) in entries {
            let normalized = normalize_client_key(&client_name);

            let mut cfg = deadpool_postgres::Config::new();
            cfg.url = Some(uri);
            cfg.manager = Some(ManagerConfig {
                recycling_method: RecyclingMethod::Fast,
            });
            cfg.pool = Some(deadpool_postgres::PoolConfig {
                max_size,
                ..Default::default()
            });

            let pool = match cfg.create_pool(Some(Runtime::Tokio1), NoTls) {
                Ok(pool) => pool,
                Err(err) => {
                    tracing::warn!(
                        client = %client_name,
                        error = %err,
                        "Failed to create deadpool Postgres pool config"
                    );
                    // Keep both maps in agreement: an earlier entry with the
                    // same normalized key may have stored a pool, which must
                    // not be served for this (failed) registration.
                    pools.remove(&normalized);
                    clients.insert(
                        normalized,
                        DeadpoolRegisteredClient {
                            client_name,
                            pool_connected: false,
                        },
                    );
                    continue;
                }
            };

            // Distinguish a connection error from a timeout so the log shows
            // why the client starts out offline.
            let warmup_ok = match timeout(warmup_timeout, pool.get()).await {
                Ok(Ok(_)) => true,
                Ok(Err(err)) => {
                    tracing::warn!(
                        client = %client_name,
                        error = %err,
                        "Deadpool warmup failed; will keep pool but mark as offline initially"
                    );
                    false
                }
                Err(elapsed) => {
                    tracing::warn!(
                        client = %client_name,
                        error = %elapsed,
                        warmup_timeout = ?warmup_timeout,
                        "Deadpool warmup failed; will keep pool but mark as offline initially"
                    );
                    false
                }
            };

            clients.insert(
                normalized.clone(),
                DeadpoolRegisteredClient {
                    client_name,
                    pool_connected: warmup_ok,
                },
            );
            pools.insert(normalized, pool);
        }

        Self {
            pools: RwLock::new(pools),
            clients: RwLock::new(clients),
        }
    }
}

#[cfg(feature = "deadpool_experimental")]
/// Canonical map key for a client name: surrounding whitespace removed,
/// then lowercased.
fn normalize_client_key(value: &str) -> String {
    let trimmed = value.trim();
    trimmed.to_lowercase()
}

#[cfg(all(test, feature = "deadpool_experimental"))]
mod tests {
    use super::*;

    /// Building from no entries must yield a registry that lists no clients.
    #[tokio::test]
    async fn from_entries_handles_empty_list() {
        let no_entries: Vec<(String, String)> = Vec::new();
        let warmup = Duration::from_millis(50);

        let registry = DeadpoolPostgresRegistry::from_entries(no_entries, 4, warmup).await;

        let listed = registry.list_clients();
        assert!(listed.is_empty());
    }
}