datafusion_dist_cluster_postgres/
cluster.rs1use std::collections::HashMap;
2
3use bb8::Pool;
4use bb8_postgres::PostgresConnectionManager;
5use bb8_postgres::tokio_postgres::NoTls;
6use datafusion_dist::{
7 DistResult,
8 cluster::{DistCluster, NodeId, NodeState},
9 util::timestamp_ms,
10};
11use log::{debug, trace};
12
13use crate::PostgresClusterError;
14
15#[derive(Debug, Clone)]
16pub struct PostgresCluster {
17 pool: Pool<PostgresConnectionManager<NoTls>>,
18 heartbeat_timeout_seconds: i32,
19}
20
21impl PostgresCluster {
22 pub fn new(pool: Pool<PostgresConnectionManager<NoTls>>) -> Self {
23 Self {
24 pool,
25 heartbeat_timeout_seconds: 60,
26 }
27 }
28
29 pub fn with_heartbeat_timeout(mut self, timeout_seconds: i32) -> Self {
30 self.heartbeat_timeout_seconds = timeout_seconds;
31 self
32 }
33
34 pub async fn ensure_schema(&self) -> DistResult<()> {
35 let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
36
37 client
38 .execute(
39 r#"
40 CREATE TABLE IF NOT EXISTS cluster_nodes (
41 host TEXT NOT NULL,
42 port INTEGER NOT NULL,
43 total_memory BIGINT NOT NULL,
44 used_memory BIGINT NOT NULL,
45 free_memory BIGINT NOT NULL,
46 available_memory BIGINT NOT NULL,
47 global_cpu_usage FLOAT4 NOT NULL,
48 num_running_tasks INTEGER NOT NULL,
49 last_heartbeat BIGINT NOT NULL,
50 UNIQUE(host, port)
51 )
52 "#,
53 &[],
54 )
55 .await
56 .map_err(|e| PostgresClusterError::Query(format!("Failed to create table: {e:?}")))?;
57
58 Ok(())
59 }
60
61 fn calculate_cutoff_time(&self) -> i64 {
62 let timeout_ms = i64::from(self.heartbeat_timeout_seconds)
63 .checked_mul(1000)
64 .unwrap_or(60_000);
65 timestamp_ms() - timeout_ms
66 }
67}
68
69#[async_trait::async_trait]
70impl DistCluster for PostgresCluster {
71 async fn heartbeat(&self, node_id: NodeId, state: NodeState) -> DistResult<()> {
73 trace!("Sending heartbeat for node");
74
75 let timestamp = timestamp_ms();
77
78 let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
79
80 let query = r#"
81 INSERT INTO cluster_nodes (
82 host, port, total_memory, used_memory, free_memory,
83 available_memory, global_cpu_usage, num_running_tasks, last_heartbeat
84 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
85 ON CONFLICT (host, port)
86 DO UPDATE SET
87 total_memory = EXCLUDED.total_memory,
88 used_memory = EXCLUDED.used_memory,
89 free_memory = EXCLUDED.free_memory,
90 available_memory = EXCLUDED.available_memory,
91 global_cpu_usage = EXCLUDED.global_cpu_usage,
92 num_running_tasks = EXCLUDED.num_running_tasks,
93 last_heartbeat = $9
94 "#;
95
96 client
97 .execute(
98 query,
99 &[
100 &node_id.host,
101 &(node_id.port as i32),
102 &(state.total_memory as i64),
103 &(state.used_memory as i64),
104 &(state.free_memory as i64),
105 &(state.available_memory as i64),
106 &state.global_cpu_usage,
107 &(state.num_running_tasks as i32),
108 ×tamp,
109 ],
110 )
111 .await
112 .map_err(|e| {
113 PostgresClusterError::Query(format!("Failed to insert heartbeat: {e:?}"))
114 })?;
115
116 debug!("Heartbeat sent successfully");
117 Ok(())
118 }
119
120 async fn alive_nodes(&self) -> DistResult<HashMap<NodeId, NodeState>> {
122 trace!("Fetching alive nodes");
123
124 let cutoff_time = self.calculate_cutoff_time();
125
126 let client = self.pool.get().await.map_err(PostgresClusterError::Pool)?;
127
128 let rows = client
129 .query(
130 r#"
131 SELECT host, port, total_memory, used_memory,
132 free_memory, available_memory, global_cpu_usage, num_running_tasks
133 FROM cluster_nodes
134 WHERE last_heartbeat >= $1
135 "#,
136 &[&cutoff_time],
137 )
138 .await
139 .map_err(|e| PostgresClusterError::Query(format!("Failed to query alive nodes: {}", e)))?;
140
141 let mut result = HashMap::new();
142 for row in rows {
143 let host: String = row
144 .try_get(0)
145 .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
146 let port: i32 = row
147 .try_get(1)
148 .map_err(|e| PostgresClusterError::Query(e.to_string()))?;
149
150 let node_id = NodeId {
151 host,
152 port: port as u16,
153 };
154
155 let node_state = NodeState {
156 total_memory: row
157 .try_get::<_, i64>(2)
158 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
159 as u64,
160 used_memory: row
161 .try_get::<_, i64>(3)
162 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
163 as u64,
164 free_memory: row
165 .try_get::<_, i64>(4)
166 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
167 as u64,
168 available_memory: row
169 .try_get::<_, i64>(5)
170 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
171 as u64,
172 global_cpu_usage: row
173 .try_get(6)
174 .map_err(|e| PostgresClusterError::Query(e.to_string()))?,
175 num_running_tasks: row
176 .try_get::<_, i32>(7)
177 .map_err(|e| PostgresClusterError::Query(e.to_string()))?
178 as u32,
179 };
180
181 result.insert(node_id, node_state);
182 }
183
184 debug!("Found {} alive nodes", result.len());
185 Ok(result)
186 }
187}