// datafusion_dist_cluster_postgres/cluster.rs

1use std::collections::HashMap;
2
3use bb8::Pool;
4use bb8_postgres::PostgresConnectionManager;
5use bb8_postgres::tokio_postgres::NoTls;
6use datafusion_dist::{
7    DistResult,
8    cluster::{DistCluster, NodeId, NodeState, NodeStatus},
9    util::timestamp_ms,
10};
11use log::{debug, trace};
12
13use crate::PostgresClusterError;
14
15#[derive(Debug, Clone)]
16pub struct PostgresCluster {
17    pub(crate) table: String,
18    pub(crate) pool: Pool<PostgresConnectionManager<NoTls>>,
19    pub(crate) heartbeat_timeout_seconds: i32,
20}
21
22impl PostgresCluster {
23    pub async fn create_table_if_not_exists(&self) -> DistResult<()> {
24        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
25
26        let create_table_sql = format!(
27            r#"
28             CREATE TABLE IF NOT EXISTS {} (
29                 host TEXT NOT NULL,
30                 port INTEGER NOT NULL,
31                 status TEXT NOT NULL,
32                 total_memory BIGINT NOT NULL,
33                 used_memory BIGINT NOT NULL,
34                 free_memory BIGINT NOT NULL,
35                 available_memory BIGINT NOT NULL,
36                 global_cpu_usage FLOAT4 NOT NULL,
37                 num_running_tasks INTEGER NOT NULL,
38                 last_heartbeat BIGINT NOT NULL,
39                 UNIQUE(host, port)
40             )
41             "#,
42            self.table
43        );
44
45        client
46            .execute(&create_table_sql, &[])
47            .await
48            .map_err(|e| PostgresClusterError::Query(format!("Failed to create table: {e:?}")))?;
49
50        Ok(())
51    }
52
53    fn calculate_cutoff_time(&self) -> i64 {
54        let timeout_ms = i64::from(self.heartbeat_timeout_seconds)
55            .checked_mul(1000)
56            .unwrap_or(60_000);
57        timestamp_ms() - timeout_ms
58    }
59}
60
61#[async_trait::async_trait]
62impl DistCluster for PostgresCluster {
63    // Send heartbeat
64    async fn heartbeat(&self, node_id: NodeId, state: NodeState) -> DistResult<()> {
65        // Get current timestamp in milliseconds as i64
66        let timestamp = timestamp_ms();
67
68        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
69
70        let query = format!(
71            r#"
72                   INSERT INTO {} (
73                       host, port, status, total_memory, used_memory, free_memory,
74                       available_memory, global_cpu_usage, num_running_tasks, last_heartbeat
75                   ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
76                   ON CONFLICT (host, port)
77                   DO UPDATE SET
78                       status = EXCLUDED.status,
79                       total_memory = EXCLUDED.total_memory,
80                       used_memory = EXCLUDED.used_memory,
81                       free_memory = EXCLUDED.free_memory,
82                       available_memory = EXCLUDED.available_memory,
83                       global_cpu_usage = EXCLUDED.global_cpu_usage,
84                       num_running_tasks = EXCLUDED.num_running_tasks,
85                       last_heartbeat = EXCLUDED.last_heartbeat
86                   "#,
87            self.table
88        );
89
90        client
91            .execute(
92                &query,
93                &[
94                    &node_id.host,
95                    &(node_id.port as i32),
96                    &state.status.to_string(),
97                    &(state.total_memory as i64),
98                    &(state.used_memory as i64),
99                    &(state.free_memory as i64),
100                    &(state.available_memory as i64),
101                    &state.global_cpu_usage,
102                    &(state.num_running_tasks as i32),
103                    &timestamp,
104                ],
105            )
106            .await
107            .map_err(|e| {
108                PostgresClusterError::Query(format!("Failed to insert heartbeat: {e:?}"))
109            })?;
110
111        debug!("Heartbeat sent successfully");
112        Ok(())
113    }
114
115    // Get alive nodes
116    async fn alive_nodes(&self) -> DistResult<HashMap<NodeId, NodeState>> {
117        trace!("Fetching alive nodes");
118
119        let cutoff_time = self.calculate_cutoff_time();
120
121        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
122
123        let query = format!(
124            r#"SELECT host, port, status, total_memory, used_memory, free_memory, available_memory, global_cpu_usage, num_running_tasks
125                FROM {} WHERE last_heartbeat >= $1"#,
126            self.table
127        );
128
129        let rows = client.query(&query, &[&cutoff_time]).await.map_err(|e| {
130            PostgresClusterError::Query(format!("Failed to query alive nodes: {}", e))
131        })?;
132
133        let mut result = HashMap::new();
134        for row in rows {
135            let host: String = row
136                .try_get(0)
137                .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
138            let port: i32 = row
139                .try_get(1)
140                .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
141
142            let node_id = NodeId {
143                host,
144                port: port as u16,
145            };
146
147            let status_str: String = row
148                .try_get(2)
149                .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
150
151            let status = status_str.parse::<NodeStatus>()?;
152
153            let node_state = NodeState {
154                status,
155                total_memory: row
156                    .try_get::<_, i64>(3)
157                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
158                    as u64,
159                used_memory: row
160                    .try_get::<_, i64>(4)
161                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
162                    as u64,
163                free_memory: row
164                    .try_get::<_, i64>(5)
165                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
166                    as u64,
167                available_memory: row
168                    .try_get::<_, i64>(6)
169                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
170                    as u64,
171                global_cpu_usage: row
172                    .try_get(7)
173                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?,
174                num_running_tasks: row
175                    .try_get::<_, i32>(8)
176                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
177                    as u32,
178            };
179
180            result.insert(node_id, node_state);
181        }
182
183        debug!("Found {} alive nodes", result.len());
184        Ok(result)
185    }
186}