datafusion_dist_cluster_postgres/
cluster.rs1use std::collections::HashMap;
2
3use bb8::Pool;
4use bb8_postgres::PostgresConnectionManager;
5use bb8_postgres::tokio_postgres::NoTls;
6use datafusion_dist::{
7 DistResult,
8 cluster::{DistCluster, NodeId, NodeState, NodeStatus},
9 util::timestamp_ms,
10};
11use log::{debug, trace};
12
13use crate::PostgresClusterError;
14
15#[derive(Debug, Clone)]
16pub struct PostgresCluster {
17 pub(crate) table: String,
18 pub(crate) pool: Pool<PostgresConnectionManager<NoTls>>,
19 pub(crate) heartbeat_timeout_seconds: i32,
20}
21
22impl PostgresCluster {
23 pub async fn create_table_if_not_exists(&self) -> DistResult<()> {
24 let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
25
26 let create_table_sql = format!(
27 r#"
28 CREATE TABLE IF NOT EXISTS {} (
29 host TEXT NOT NULL,
30 port INTEGER NOT NULL,
31 status TEXT NOT NULL,
32 total_memory BIGINT NOT NULL,
33 used_memory BIGINT NOT NULL,
34 free_memory BIGINT NOT NULL,
35 available_memory BIGINT NOT NULL,
36 global_cpu_usage FLOAT4 NOT NULL,
37 num_running_tasks INTEGER NOT NULL,
38 last_heartbeat BIGINT NOT NULL,
39 UNIQUE(host, port)
40 )
41 "#,
42 self.table
43 );
44
45 client
46 .execute(&create_table_sql, &[])
47 .await
48 .map_err(|e| PostgresClusterError::Query(format!("Failed to create table: {e:?}")))?;
49
50 Ok(())
51 }
52
53 fn calculate_cutoff_time(&self) -> i64 {
54 let timeout_ms = i64::from(self.heartbeat_timeout_seconds)
55 .checked_mul(1000)
56 .unwrap_or(60_000);
57 timestamp_ms() - timeout_ms
58 }
59}
60
61#[async_trait::async_trait]
62impl DistCluster for PostgresCluster {
63 async fn heartbeat(&self, node_id: NodeId, state: NodeState) -> DistResult<()> {
65 let timestamp = timestamp_ms();
67
68 let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
69
70 let query = format!(
71 r#"
72 INSERT INTO {} (
73 host, port, status, total_memory, used_memory, free_memory,
74 available_memory, global_cpu_usage, num_running_tasks, last_heartbeat
75 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
76 ON CONFLICT (host, port)
77 DO UPDATE SET
78 status = EXCLUDED.status,
79 total_memory = EXCLUDED.total_memory,
80 used_memory = EXCLUDED.used_memory,
81 free_memory = EXCLUDED.free_memory,
82 available_memory = EXCLUDED.available_memory,
83 global_cpu_usage = EXCLUDED.global_cpu_usage,
84 num_running_tasks = EXCLUDED.num_running_tasks,
85 last_heartbeat = EXCLUDED.last_heartbeat
86 "#,
87 self.table
88 );
89
90 client
91 .execute(
92 &query,
93 &[
94 &node_id.host,
95 &(node_id.port as i32),
96 &state.status.to_string(),
97 &(state.total_memory as i64),
98 &(state.used_memory as i64),
99 &(state.free_memory as i64),
100 &(state.available_memory as i64),
101 &state.global_cpu_usage,
102 &(state.num_running_tasks as i32),
103 ×tamp,
104 ],
105 )
106 .await
107 .map_err(|e| {
108 PostgresClusterError::Query(format!("Failed to insert heartbeat: {e:?}"))
109 })?;
110
111 debug!("Heartbeat sent successfully");
112 Ok(())
113 }
114
115 async fn alive_nodes(&self) -> DistResult<HashMap<NodeId, NodeState>> {
117 trace!("Fetching alive nodes");
118
119 let cutoff_time = self.calculate_cutoff_time();
120
121 let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
122
123 let query = format!(
124 r#"SELECT host, port, status, total_memory, used_memory, free_memory, available_memory, global_cpu_usage, num_running_tasks
125 FROM {} WHERE last_heartbeat >= $1"#,
126 self.table
127 );
128
129 let rows = client.query(&query, &[&cutoff_time]).await.map_err(|e| {
130 PostgresClusterError::Query(format!("Failed to query alive nodes: {}", e))
131 })?;
132
133 let mut result = HashMap::new();
134 for row in rows {
135 let host: String = row
136 .try_get(0)
137 .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
138 let port: i32 = row
139 .try_get(1)
140 .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
141
142 let node_id = NodeId {
143 host,
144 port: port as u16,
145 };
146
147 let status_str: String = row
148 .try_get(2)
149 .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
150
151 let status = status_str.parse::<NodeStatus>()?;
152
153 let node_state = NodeState {
154 status,
155 total_memory: row
156 .try_get::<_, i64>(3)
157 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
158 as u64,
159 used_memory: row
160 .try_get::<_, i64>(4)
161 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
162 as u64,
163 free_memory: row
164 .try_get::<_, i64>(5)
165 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
166 as u64,
167 available_memory: row
168 .try_get::<_, i64>(6)
169 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
170 as u64,
171 global_cpu_usage: row
172 .try_get(7)
173 .map_err(|e| PostgresClusterError::Query(e.to_string()))?,
174 num_running_tasks: row
175 .try_get::<_, i32>(8)
176 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
177 as u32,
178 };
179
180 result.insert(node_id, node_state);
181 }
182
183 debug!("Found {} alive nodes", result.len());
184 Ok(result)
185 }
186}