//! datafusion-dist-cluster-postgres 0.1.0
//!
//! A PostgreSQL cluster implementation for datafusion-dist.
use std::collections::HashMap;

use bb8::Pool;
use bb8_postgres::PostgresConnectionManager;
use bb8_postgres::tokio_postgres::NoTls;
use datafusion_dist::{
    DistResult,
    cluster::{DistCluster, NodeId, NodeState},
    util::timestamp_ms,
};
use log::{debug, trace};

use crate::PostgresClusterError;

/// [`DistCluster`] backend that tracks cluster membership through heartbeat rows
/// stored in a PostgreSQL `cluster_nodes` table.
#[derive(Clone, Debug)]
pub struct PostgresCluster {
    /// Pooled Postgres connections used for schema setup and heartbeat queries.
    pool: Pool<PostgresConnectionManager<NoTls>>,
    /// How many seconds a node may go without a heartbeat before it is no longer
    /// reported as alive.
    heartbeat_timeout_seconds: i32,
}

impl PostgresCluster {
    /// Heartbeat timeout applied by [`PostgresCluster::new`].
    const DEFAULT_HEARTBEAT_TIMEOUT_SECONDS: i32 = 60;

    /// Creates a cluster handle backed by `pool`, using a 60 second heartbeat timeout.
    pub fn new(pool: Pool<PostgresConnectionManager<NoTls>>) -> Self {
        Self {
            pool,
            heartbeat_timeout_seconds: Self::DEFAULT_HEARTBEAT_TIMEOUT_SECONDS,
        }
    }

    /// Sets how many seconds may elapse since a node's last heartbeat before
    /// `alive_nodes` stops reporting it.
    ///
    /// A zero or negative value places the liveness cutoff at or after "now", so
    /// effectively no previously recorded heartbeat will count as alive.
    pub fn with_heartbeat_timeout(mut self, timeout_seconds: i32) -> Self {
        self.heartbeat_timeout_seconds = timeout_seconds;
        self
    }

    /// Creates the `cluster_nodes` table if it does not already exist.
    ///
    /// An existing table is left untouched (`CREATE TABLE IF NOT EXISTS`); no
    /// migration is attempted if its columns differ.
    ///
    /// # Errors
    /// Returns an error if a pooled connection cannot be obtained or the
    /// `CREATE TABLE` statement fails.
    pub async fn ensure_schema(&self) -> DistResult<()> {
        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;

        client
            .execute(
                r#"
                     CREATE TABLE IF NOT EXISTS cluster_nodes (
                         host TEXT NOT NULL,
                         port INTEGER NOT NULL,
                         total_memory BIGINT NOT NULL,
                         used_memory BIGINT NOT NULL,
                         free_memory BIGINT NOT NULL,
                         available_memory BIGINT NOT NULL,
                         global_cpu_usage FLOAT4 NOT NULL,
                         num_running_tasks INTEGER NOT NULL,
                         last_heartbeat BIGINT NOT NULL,
                         UNIQUE(host, port)
                     )
                     "#,
                &[],
            )
            .await
            .map_err(|e| PostgresClusterError::Query(format!("Failed to create table: {e:?}")))?;

        Ok(())
    }

    /// Returns the oldest `last_heartbeat` value, in the same millisecond units as
    /// `timestamp_ms()`, that still counts as alive.
    ///
    /// NOTE(review): heartbeats are stamped with the writing node's local clock and
    /// this cutoff with the reading node's clock, so clock skew between nodes
    /// shifts the effective timeout.
    fn calculate_cutoff_time(&self) -> i64 {
        // An `i32` number of seconds times 1000 always fits in an `i64`, so this
        // multiplication cannot overflow and needs no fallback.
        let timeout_ms = i64::from(self.heartbeat_timeout_seconds) * 1000;
        // Saturate so an extreme (e.g. very negative) timeout cannot wrap around.
        timestamp_ms().saturating_sub(timeout_ms)
    }
}

#[async_trait::async_trait]
impl DistCluster for PostgresCluster {
    // Send heartbeat
    async fn heartbeat(&self, node_id: NodeId, state: NodeState) -> DistResult<()> {
        trace!("Sending heartbeat for node");

        // Get current timestamp in milliseconds as i64
        let timestamp = timestamp_ms();

        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;

        let query = r#"
                   INSERT INTO cluster_nodes (
                       host, port, total_memory, used_memory, free_memory,
                       available_memory, global_cpu_usage, num_running_tasks, last_heartbeat
                   ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
                   ON CONFLICT (host, port)
                   DO UPDATE SET
                       total_memory = EXCLUDED.total_memory,
                       used_memory = EXCLUDED.used_memory,
                       free_memory = EXCLUDED.free_memory,
                       available_memory = EXCLUDED.available_memory,
                       global_cpu_usage = EXCLUDED.global_cpu_usage,
                       num_running_tasks = EXCLUDED.num_running_tasks,
                       last_heartbeat = $9
                   "#;

        client
            .execute(
                query,
                &[
                    &node_id.host,
                    &(node_id.port as i32),
                    &(state.total_memory as i64),
                    &(state.used_memory as i64),
                    &(state.free_memory as i64),
                    &(state.available_memory as i64),
                    &state.global_cpu_usage,
                    &(state.num_running_tasks as i32),
                    &timestamp,
                ],
            )
            .await
            .map_err(|e| {
                PostgresClusterError::Query(format!("Failed to insert heartbeat: {e:?}"))
            })?;

        debug!("Heartbeat sent successfully");
        Ok(())
    }

    // Get alive nodes
    async fn alive_nodes(&self) -> DistResult<HashMap<NodeId, NodeState>> {
        trace!("Fetching alive nodes");

        let cutoff_time = self.calculate_cutoff_time();

        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;

        let rows = client
                        .query(
                            r#"
                            SELECT host, port, total_memory, used_memory,
                                   free_memory, available_memory, global_cpu_usage, num_running_tasks
                            FROM cluster_nodes
                            WHERE last_heartbeat >= $1
                            "#,
                            &[&cutoff_time],
                        )
                        .await
                        .map_err(|e| PostgresClusterError::Query(format!("Failed to query alive nodes: {}", e)))?;

        let mut result = HashMap::new();
        for row in rows {
            let host: String = row
                .try_get(0)
                .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
            let port: i32 = row
                .try_get(1)
                .map_err(|e| PostgresClusterError::Query(e.to_string()))?;

            let node_id = NodeId {
                host,
                port: port as u16,
            };

            let node_state = NodeState {
                total_memory: row
                    .try_get::<_, i64>(2)
                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
                    as u64,
                used_memory: row
                    .try_get::<_, i64>(3)
                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
                    as u64,
                free_memory: row
                    .try_get::<_, i64>(4)
                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
                    as u64,
                available_memory: row
                    .try_get::<_, i64>(5)
                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
                    as u64,
                global_cpu_usage: row
                    .try_get(6)
                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?,
                num_running_tasks: row
                    .try_get::<_, i32>(7)
                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
                    as u32,
            };

            result.insert(node_id, node_state);
        }

        debug!("Found {} alive nodes", result.len());
        Ok(result)
    }
}