datafusion_dist_cluster_postgres/
cluster.rs1use std::collections::HashMap;
2
3use bb8::Pool;
4use bb8_postgres::PostgresConnectionManager;
5use bb8_postgres::tokio_postgres::NoTls;
6use datafusion_dist::{
7 DistResult,
8 cluster::{DistCluster, NodeId, NodeState, NodeStatus},
9 util::timestamp_ms,
10};
11use log::{debug, trace};
12
13use crate::PostgresClusterError;
14
15#[derive(Debug, Clone)]
16pub struct PostgresCluster {
17 pub(crate) table: String,
18 pub(crate) pool: Pool<PostgresConnectionManager<NoTls>>,
19 pub(crate) heartbeat_timeout_seconds: i32,
20}
21
22impl PostgresCluster {
23 pub async fn create_table_if_not_exists(&self) -> DistResult<()> {
24 let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
25
26 let create_table_sql = format!(
27 r#"
28 CREATE TABLE IF NOT EXISTS {} (
29 host TEXT NOT NULL,
30 port INTEGER NOT NULL,
31 status TEXT NOT NULL,
32 total_memory BIGINT NOT NULL,
33 used_memory BIGINT NOT NULL,
34 free_memory BIGINT NOT NULL,
35 available_memory BIGINT NOT NULL,
36 global_cpu_usage FLOAT4 NOT NULL,
37 num_running_tasks INTEGER NOT NULL,
38 num_pending_tasks INTEGER NOT NULL,
39 last_heartbeat BIGINT NOT NULL,
40 UNIQUE(host, port)
41 )
42 "#,
43 self.table
44 );
45
46 client
47 .execute(&create_table_sql, &[])
48 .await
49 .map_err(|e| PostgresClusterError::Query(format!("Failed to create table: {e:?}")))?;
50
51 Ok(())
52 }
53
54 fn calculate_cutoff_time(&self) -> i64 {
55 let timeout_ms = i64::from(self.heartbeat_timeout_seconds)
56 .checked_mul(1000)
57 .unwrap_or(60_000);
58 timestamp_ms() - timeout_ms
59 }
60}
61
62#[async_trait::async_trait]
63impl DistCluster for PostgresCluster {
64 async fn heartbeat(&self, node_id: NodeId, state: NodeState) -> DistResult<()> {
66 let timestamp = timestamp_ms();
68
69 let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
70
71 let query = format!(
72 r#"
73 INSERT INTO {} (
74 host, port, status, total_memory, used_memory, free_memory,
75 available_memory, global_cpu_usage, num_running_tasks, num_pending_tasks, last_heartbeat
76 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
77 ON CONFLICT (host, port)
78 DO UPDATE SET
79 status = EXCLUDED.status,
80 total_memory = EXCLUDED.total_memory,
81 used_memory = EXCLUDED.used_memory,
82 free_memory = EXCLUDED.free_memory,
83 available_memory = EXCLUDED.available_memory,
84 global_cpu_usage = EXCLUDED.global_cpu_usage,
85 num_running_tasks = EXCLUDED.num_running_tasks,
86 num_pending_tasks = EXCLUDED.num_pending_tasks,
87 last_heartbeat = EXCLUDED.last_heartbeat
88 "#,
89 self.table
90 );
91
92 client
93 .execute(
94 &query,
95 &[
96 &node_id.host,
97 &(node_id.port as i32),
98 &state.status.to_string(),
99 &(state.total_memory as i64),
100 &(state.used_memory as i64),
101 &(state.free_memory as i64),
102 &(state.available_memory as i64),
103 &state.global_cpu_usage,
104 &(state.num_running_tasks as i32),
105 &(state.num_pending_tasks as i32),
106 ×tamp,
107 ],
108 )
109 .await
110 .map_err(|e| {
111 PostgresClusterError::Query(format!("Failed to insert heartbeat: {e:?}"))
112 })?;
113
114 debug!("Heartbeat sent successfully");
115 Ok(())
116 }
117
118 async fn alive_nodes(&self) -> DistResult<HashMap<NodeId, NodeState>> {
120 trace!("Fetching alive nodes");
121
122 let cutoff_time = self.calculate_cutoff_time();
123
124 let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
125
126 let query = format!(
127 r#"SELECT host, port, status, total_memory, used_memory, free_memory, available_memory, global_cpu_usage, num_running_tasks, num_pending_tasks
128 FROM {} WHERE last_heartbeat >= $1"#,
129 self.table
130 );
131
132 let rows = client.query(&query, &[&cutoff_time]).await.map_err(|e| {
133 PostgresClusterError::Query(format!("Failed to query alive nodes: {}", e))
134 })?;
135
136 let mut result = HashMap::new();
137 for row in rows {
138 let host: String = row
139 .try_get(0)
140 .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
141 let port: i32 = row
142 .try_get(1)
143 .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
144
145 let node_id = NodeId {
146 host,
147 port: port as u16,
148 };
149
150 let status_str: String = row
151 .try_get(2)
152 .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
153
154 let status = status_str.parse::<NodeStatus>()?;
155
156 let node_state = NodeState {
157 status,
158 total_memory: row
159 .try_get::<_, i64>(3)
160 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
161 as u64,
162 used_memory: row
163 .try_get::<_, i64>(4)
164 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
165 as u64,
166 free_memory: row
167 .try_get::<_, i64>(5)
168 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
169 as u64,
170 available_memory: row
171 .try_get::<_, i64>(6)
172 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
173 as u64,
174 global_cpu_usage: row
175 .try_get(7)
176 .map_err(|e| PostgresClusterError::Query(e.to_string()))?,
177 num_running_tasks: row
178 .try_get::<_, i32>(8)
179 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
180 as u32,
181 num_pending_tasks: row
182 .try_get::<_, i32>(9)
183 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
184 as u32,
185 };
186
187 result.insert(node_id, node_state);
188 }
189
190 debug!("Found {} alive nodes", result.len());
191 Ok(result)
192 }
193}