datafusion_dist_cluster_postgres/
cluster.rs

1use std::collections::HashMap;
2
3use bb8::Pool;
4use bb8_postgres::PostgresConnectionManager;
5use bb8_postgres::tokio_postgres::NoTls;
6use datafusion_dist::{
7    DistResult,
8    cluster::{DistCluster, NodeId, NodeState},
9    util::timestamp_ms,
10};
11use log::{debug, trace};
12
13use crate::PostgresClusterError;
14
15#[derive(Debug, Clone)]
16pub struct PostgresCluster {
17    pool: Pool<PostgresConnectionManager<NoTls>>,
18    heartbeat_timeout_seconds: i32,
19}
20
21impl PostgresCluster {
22    pub fn new(pool: Pool<PostgresConnectionManager<NoTls>>) -> Self {
23        Self {
24            pool,
25            heartbeat_timeout_seconds: 60,
26        }
27    }
28
29    pub fn with_heartbeat_timeout(mut self, timeout_seconds: i32) -> Self {
30        self.heartbeat_timeout_seconds = timeout_seconds;
31        self
32    }
33
34    pub async fn ensure_schema(&self) -> DistResult<()> {
35        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
36
37        client
38            .execute(
39                r#"
40                     CREATE TABLE IF NOT EXISTS cluster_nodes (
41                         host TEXT NOT NULL,
42                         port INTEGER NOT NULL,
43                         total_memory BIGINT NOT NULL,
44                         used_memory BIGINT NOT NULL,
45                         free_memory BIGINT NOT NULL,
46                         available_memory BIGINT NOT NULL,
47                         global_cpu_usage FLOAT4 NOT NULL,
48                         num_running_tasks INTEGER NOT NULL,
49                         last_heartbeat BIGINT NOT NULL,
50                         UNIQUE(host, port)
51                     )
52                     "#,
53                &[],
54            )
55            .await
56            .map_err(|e| PostgresClusterError::Query(format!("Failed to create table: {e:?}")))?;
57
58        Ok(())
59    }
60
61    fn calculate_cutoff_time(&self) -> i64 {
62        let timeout_ms = i64::from(self.heartbeat_timeout_seconds)
63            .checked_mul(1000)
64            .unwrap_or(60_000);
65        timestamp_ms() - timeout_ms
66    }
67}
68
69#[async_trait::async_trait]
70impl DistCluster for PostgresCluster {
71    // Send heartbeat
72    async fn heartbeat(&self, node_id: NodeId, state: NodeState) -> DistResult<()> {
73        trace!("Sending heartbeat for node");
74
75        // Get current timestamp in milliseconds as i64
76        let timestamp = timestamp_ms();
77
78        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
79
80        let query = r#"
81                   INSERT INTO cluster_nodes (
82                       host, port, total_memory, used_memory, free_memory,
83                       available_memory, global_cpu_usage, num_running_tasks, last_heartbeat
84                   ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
85                   ON CONFLICT (host, port)
86                   DO UPDATE SET
87                       total_memory = EXCLUDED.total_memory,
88                       used_memory = EXCLUDED.used_memory,
89                       free_memory = EXCLUDED.free_memory,
90                       available_memory = EXCLUDED.available_memory,
91                       global_cpu_usage = EXCLUDED.global_cpu_usage,
92                       num_running_tasks = EXCLUDED.num_running_tasks,
93                       last_heartbeat = $9
94                   "#;
95
96        client
97            .execute(
98                query,
99                &[
100                    &node_id.host,
101                    &(node_id.port as i32),
102                    &(state.total_memory as i64),
103                    &(state.used_memory as i64),
104                    &(state.free_memory as i64),
105                    &(state.available_memory as i64),
106                    &state.global_cpu_usage,
107                    &(state.num_running_tasks as i32),
108                    &timestamp,
109                ],
110            )
111            .await
112            .map_err(|e| {
113                PostgresClusterError::Query(format!("Failed to insert heartbeat: {e:?}"))
114            })?;
115
116        debug!("Heartbeat sent successfully");
117        Ok(())
118    }
119
120    // Get alive nodes
121    async fn alive_nodes(&self) -> DistResult<HashMap<NodeId, NodeState>> {
122        trace!("Fetching alive nodes");
123
124        let cutoff_time = self.calculate_cutoff_time();
125
126        let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
127
128        let rows = client
129                        .query(
130                            r#"
131                            SELECT host, port, total_memory, used_memory,
132                                   free_memory, available_memory, global_cpu_usage, num_running_tasks
133                            FROM cluster_nodes
134                            WHERE last_heartbeat >= $1
135                            "#,
136                            &[&cutoff_time],
137                        )
138                        .await
139                        .map_err(|e| PostgresClusterError::Query(format!("Failed to query alive nodes: {}", e)))?;
140
141        let mut result = HashMap::new();
142        for row in rows {
143            let host: String = row
144                .try_get(0)
145                .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
146            let port: i32 = row
147                .try_get(1)
148                .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
149
150            let node_id = NodeId {
151                host,
152                port: port as u16,
153            };
154
155            let node_state = NodeState {
156                total_memory: row
157                    .try_get::<_, i64>(2)
158                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
159                    as u64,
160                used_memory: row
161                    .try_get::<_, i64>(3)
162                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
163                    as u64,
164                free_memory: row
165                    .try_get::<_, i64>(4)
166                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
167                    as u64,
168                available_memory: row
169                    .try_get::<_, i64>(5)
170                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
171                    as u64,
172                global_cpu_usage: row
173                    .try_get(6)
174                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?,
175                num_running_tasks: row
176                    .try_get::<_, i32>(7)
177                    .map_err(|e| PostgresClusterError::Query(e.to_string()))?
178                    as u32,
179            };
180
181            result.insert(node_id, node_state);
182        }
183
184        debug!("Found {} alive nodes", result.len());
185        Ok(result)
186    }
187}