Skip to main content

datafusion_dist_cluster_postgres/
cluster.rs

1use std::collections::HashMap;
2
3use bb8::Pool;
4use bb8_postgres::PostgresConnectionManager;
5use bb8_postgres::tokio_postgres::NoTls;
6use datafusion_dist::{
7    DistResult,
8    cluster::{DistCluster, NodeId, NodeState, NodeStatus},
9    util::timestamp_ms,
10};
11use log::{debug, trace};
12
13use crate::PostgresClusterError;
14
15#[derive(Debug, Clone)]
16pub struct PostgresCluster {
17    pub(crate) table: String,
18    pub(crate) pool: Pool<PostgresConnectionManager<NoTls>>,
19    pub(crate) heartbeat_timeout_seconds: i32,
20}
21
22impl PostgresCluster {
23    pub async fn create_table_if_not_exists(&self) -> DistResult<()> {
24        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
25
26        let create_table_sql = format!(
27            r#"
28             CREATE TABLE IF NOT EXISTS {} (
29                 host TEXT NOT NULL,
30                 port INTEGER NOT NULL,
31                 status TEXT NOT NULL,
32                 total_memory BIGINT NOT NULL,
33                 used_memory BIGINT NOT NULL,
34                 free_memory BIGINT NOT NULL,
35                 available_memory BIGINT NOT NULL,
36                 global_cpu_usage FLOAT4 NOT NULL,
37                 num_running_tasks INTEGER NOT NULL,
38                 num_pending_tasks INTEGER NOT NULL,
39                 last_heartbeat BIGINT NOT NULL,
40                 UNIQUE(host, port)
41             )
42             "#,
43            self.table
44        );
45
46        client
47            .execute(&create_table_sql, &[])
48            .await
49            .map_err(|e| PostgresClusterError::Query(format!("Failed to create table: {e:?}")))?;
50
51        Ok(())
52    }
53
54    fn calculate_cutoff_time(&self) -> i64 {
55        let timeout_ms = i64::from(self.heartbeat_timeout_seconds)
56            .checked_mul(1000)
57            .unwrap_or(60_000);
58        timestamp_ms() - timeout_ms
59    }
60}
61
62#[async_trait::async_trait]
63impl DistCluster for PostgresCluster {
64    // Send heartbeat
65    async fn heartbeat(&self, node_id: NodeId, state: NodeState) -> DistResult<()> {
66        // Get current timestamp in milliseconds as i64
67        let timestamp = timestamp_ms();
68
69        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
70
71        let query = format!(
72            r#"
73                   INSERT INTO {} (
74                       host, port, status, total_memory, used_memory, free_memory,
75                       available_memory, global_cpu_usage, num_running_tasks, num_pending_tasks, last_heartbeat
76                   ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
77                   ON CONFLICT (host, port)
78                   DO UPDATE SET
79                       status = EXCLUDED.status,
80                       total_memory = EXCLUDED.total_memory,
81                       used_memory = EXCLUDED.used_memory,
82                       free_memory = EXCLUDED.free_memory,
83                       available_memory = EXCLUDED.available_memory,
84                       global_cpu_usage = EXCLUDED.global_cpu_usage,
85                       num_running_tasks = EXCLUDED.num_running_tasks,
86                       num_pending_tasks = EXCLUDED.num_pending_tasks,
87                       last_heartbeat = EXCLUDED.last_heartbeat
88                   "#,
89            self.table
90        );
91
92        client
93            .execute(
94                &query,
95                &[
96                    &node_id.host,
97                    &(node_id.port as i32),
98                    &state.status.to_string(),
99                    &(state.total_memory as i64),
100                    &(state.used_memory as i64),
101                    &(state.free_memory as i64),
102                    &(state.available_memory as i64),
103                    &state.global_cpu_usage,
104                    &(state.num_running_tasks as i32),
105                    &(state.num_pending_tasks as i32),
106                    &timestamp,
107                ],
108            )
109            .await
110            .map_err(|e| {
111                PostgresClusterError::Query(format!("Failed to insert heartbeat: {e:?}"))
112            })?;
113
114        debug!("Heartbeat sent successfully");
115        Ok(())
116    }
117
118    // Get alive nodes
119    async fn alive_nodes(&self) -> DistResult<HashMap<NodeId, NodeState>> {
120        trace!("Fetching alive nodes");
121
122        let cutoff_time = self.calculate_cutoff_time();
123
124        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
125
126        let query = format!(
127            r#"SELECT host, port, status, total_memory, used_memory, free_memory, available_memory, global_cpu_usage, num_running_tasks, num_pending_tasks
128                FROM {} WHERE last_heartbeat >= $1"#,
129            self.table
130        );
131
132        let rows = client.query(&query, &[&cutoff_time]).await.map_err(|e| {
133            PostgresClusterError::Query(format!("Failed to query alive nodes: {}", e))
134        })?;
135
136        let mut result = HashMap::new();
137        for row in rows {
138            let host: String = row
139                .try_get(0)
140                .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
141            let port: i32 = row
142                .try_get(1)
143                .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
144
145            let node_id = NodeId {
146                host,
147                port: port as u16,
148            };
149
150            let status_str: String = row
151                .try_get(2)
152                .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
153
154            let status = status_str.parse::<NodeStatus>()?;
155
156            let node_state = NodeState {
157                status,
158                total_memory: row
159                    .try_get::<_, i64>(3)
160                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
161                    as u64,
162                used_memory: row
163                    .try_get::<_, i64>(4)
164                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
165                    as u64,
166                free_memory: row
167                    .try_get::<_, i64>(5)
168                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
169                    as u64,
170                available_memory: row
171                    .try_get::<_, i64>(6)
172                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
173                    as u64,
174                global_cpu_usage: row
175                    .try_get(7)
176                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?,
177                num_running_tasks: row
178                    .try_get::<_, i32>(8)
179                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
180                    as u32,
181                num_pending_tasks: row
182                    .try_get::<_, i32>(9)
183                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
184                    as u32,
185            };
186
187            result.insert(node_id, node_state);
188        }
189
190        debug!("Found {} alive nodes", result.len());
191        Ok(result)
192    }
193}