1use std::net::IpAddr;
2use std::time::Duration;
3
4use chrono::{DateTime, Utc};
5use uuid::Uuid;
6
7use forge_core::cluster::{NodeId, NodeInfo, NodeRole, NodeStatus};
8use forge_core::{ForgeError, Result};
9
10pub struct NodeRegistry {
12 pool: sqlx::PgPool,
13 local_node: NodeInfo,
14}
15
16impl NodeRegistry {
17 pub fn new(pool: sqlx::PgPool, local_node: NodeInfo) -> Self {
19 Self { pool, local_node }
20 }
21
22 pub fn local_node(&self) -> &NodeInfo {
24 &self.local_node
25 }
26
27 pub fn local_id(&self) -> NodeId {
29 self.local_node.id
30 }
31
32 pub async fn register(&self) -> Result<()> {
34 let roles: Vec<String> = self
35 .local_node
36 .roles
37 .iter()
38 .map(|r| r.as_str().to_string())
39 .collect();
40
41 sqlx::query!(
42 r#"
43 INSERT INTO forge_nodes (
44 id, hostname, ip_address, http_port, grpc_port,
45 roles, worker_capabilities, status, version, started_at, last_heartbeat
46 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW())
47 ON CONFLICT (id) DO UPDATE SET
48 hostname = EXCLUDED.hostname,
49 ip_address = EXCLUDED.ip_address,
50 http_port = EXCLUDED.http_port,
51 grpc_port = EXCLUDED.grpc_port,
52 roles = EXCLUDED.roles,
53 worker_capabilities = EXCLUDED.worker_capabilities,
54 status = EXCLUDED.status,
55 version = EXCLUDED.version,
56 last_heartbeat = NOW()
57 "#,
58 self.local_node.id.as_uuid(),
59 &self.local_node.hostname,
60 self.local_node.ip_address.to_string(),
61 self.local_node.http_port as i32,
62 self.local_node.grpc_port as i32,
63 &roles,
64 &self.local_node.worker_capabilities,
65 self.local_node.status.as_str(),
66 &self.local_node.version,
67 self.local_node.started_at,
68 )
69 .execute(&self.pool)
70 .await
71 .map_err(|e| ForgeError::Database(e.to_string()))?;
72
73 Ok(())
74 }
75
76 pub async fn set_status(&self, status: NodeStatus) -> Result<()> {
78 sqlx::query!(
79 r#"
80 UPDATE forge_nodes
81 SET status = $2
82 WHERE id = $1
83 "#,
84 self.local_node.id.as_uuid(),
85 status.as_str(),
86 )
87 .execute(&self.pool)
88 .await
89 .map_err(|e| ForgeError::Database(e.to_string()))?;
90
91 Ok(())
92 }
93
94 pub async fn deregister(&self) -> Result<()> {
96 sqlx::query!(
97 r#"
98 DELETE FROM forge_nodes WHERE id = $1
99 "#,
100 self.local_node.id.as_uuid(),
101 )
102 .execute(&self.pool)
103 .await
104 .map_err(|e| ForgeError::Database(e.to_string()))?;
105
106 Ok(())
107 }
108
109 pub async fn get_active_nodes(&self) -> Result<Vec<NodeInfo>> {
111 self.get_nodes_by_status(NodeStatus::Active).await
112 }
113
114 pub async fn get_nodes_by_status(&self, status: NodeStatus) -> Result<Vec<NodeInfo>> {
116 let rows = sqlx::query!(
117 r#"
118 SELECT id, hostname, ip_address, http_port, grpc_port,
119 roles, worker_capabilities, status, version,
120 started_at, last_heartbeat, current_connections,
121 current_jobs, cpu_usage, memory_usage
122 FROM forge_nodes
123 WHERE status = $1
124 ORDER BY started_at
125 "#,
126 status.as_str(),
127 )
128 .fetch_all(&self.pool)
129 .await
130 .map_err(|e| ForgeError::Database(e.to_string()))?;
131
132 rows.into_iter()
133 .map(|row| {
134 parse_node_fields(
135 row.id,
136 row.hostname,
137 row.ip_address,
138 row.http_port,
139 row.grpc_port,
140 row.roles,
141 row.worker_capabilities,
142 row.status,
143 row.version,
144 row.started_at,
145 row.last_heartbeat,
146 row.current_connections,
147 row.current_jobs,
148 row.cpu_usage,
149 row.memory_usage,
150 )
151 })
152 .collect()
153 }
154
155 pub async fn get_node(&self, node_id: NodeId) -> Result<Option<NodeInfo>> {
157 let row = sqlx::query!(
158 r#"
159 SELECT id, hostname, ip_address, http_port, grpc_port,
160 roles, worker_capabilities, status, version,
161 started_at, last_heartbeat, current_connections,
162 current_jobs, cpu_usage, memory_usage
163 FROM forge_nodes
164 WHERE id = $1
165 "#,
166 node_id.as_uuid(),
167 )
168 .fetch_optional(&self.pool)
169 .await
170 .map_err(|e| ForgeError::Database(e.to_string()))?;
171
172 row.map(|row| {
173 parse_node_fields(
174 row.id,
175 row.hostname,
176 row.ip_address,
177 row.http_port,
178 row.grpc_port,
179 row.roles,
180 row.worker_capabilities,
181 row.status,
182 row.version,
183 row.started_at,
184 row.last_heartbeat,
185 row.current_connections,
186 row.current_jobs,
187 row.cpu_usage,
188 row.memory_usage,
189 )
190 })
191 .transpose()
192 }
193
194 pub async fn count_by_status(&self) -> Result<NodeCounts> {
196 let row = sqlx::query!(
197 r#"
198 SELECT
199 COUNT(*) FILTER (WHERE status = 'active') as "active!",
200 COUNT(*) FILTER (WHERE status = 'draining') as "draining!",
201 COUNT(*) FILTER (WHERE status = 'dead') as "dead!",
202 COUNT(*) FILTER (WHERE status = 'joining') as "joining!",
203 COUNT(*) as "total!"
204 FROM forge_nodes
205 "#,
206 )
207 .fetch_one(&self.pool)
208 .await
209 .map_err(|e| ForgeError::Database(e.to_string()))?;
210
211 Ok(NodeCounts {
212 active: row.active as usize,
213 draining: row.draining as usize,
214 dead: row.dead as usize,
215 joining: row.joining as usize,
216 total: row.total as usize,
217 })
218 }
219
220 pub async fn mark_dead_nodes(&self, threshold: Duration) -> Result<u64> {
222 let threshold_secs = threshold.as_secs() as i64;
223
224 let result = sqlx::query!(
225 r#"
226 UPDATE forge_nodes
227 SET status = 'dead'
228 WHERE status = 'active'
229 AND last_heartbeat < NOW() - make_interval(secs => $1)
230 "#,
231 threshold_secs as f64,
232 )
233 .execute(&self.pool)
234 .await
235 .map_err(|e| ForgeError::Database(e.to_string()))?;
236
237 Ok(result.rows_affected())
238 }
239
240 pub async fn cleanup_dead_nodes(&self, older_than: Duration) -> Result<u64> {
242 let threshold_secs = older_than.as_secs() as i64;
243
244 let result = sqlx::query!(
245 r#"
246 DELETE FROM forge_nodes
247 WHERE status = 'dead'
248 AND last_heartbeat < NOW() - make_interval(secs => $1)
249 "#,
250 threshold_secs as f64,
251 )
252 .execute(&self.pool)
253 .await
254 .map_err(|e| ForgeError::Database(e.to_string()))?;
255
256 Ok(result.rows_affected())
257 }
258}
259
260#[allow(clippy::too_many_arguments)]
261fn parse_node_fields(
262 id: Uuid,
263 hostname: String,
264 ip_address: String,
265 http_port: i32,
266 grpc_port: i32,
267 roles: Vec<String>,
268 worker_capabilities: Vec<String>,
269 status: String,
270 version: Option<String>,
271 started_at: DateTime<Utc>,
272 last_heartbeat: DateTime<Utc>,
273 current_connections: i32,
274 current_jobs: i32,
275 cpu_usage: Option<f64>,
276 memory_usage: Option<f64>,
277) -> Result<NodeInfo> {
278 let ip_addr: IpAddr = ip_address.parse().map_err(|e| {
279 ForgeError::Validation(format!("invalid IP address '{}': {}", ip_address, e))
280 })?;
281
282 let node_roles: Vec<NodeRole> = roles
283 .iter()
284 .map(|s| {
285 s.parse()
286 .map_err(|e| ForgeError::Validation(format!("invalid role '{}': {}", s, e)))
287 })
288 .collect::<Result<Vec<_>>>()?;
289
290 let node_status: NodeStatus = status
291 .parse()
292 .map_err(|e| ForgeError::Validation(format!("invalid status '{}': {}", status, e)))?;
293
294 Ok(NodeInfo {
295 id: NodeId::from_uuid(id),
296 hostname,
297 ip_address: ip_addr,
298 http_port: http_port as u16,
299 grpc_port: grpc_port as u16,
300 roles: node_roles,
301 worker_capabilities,
302 status: node_status,
303 version: version.unwrap_or_default(),
304 started_at,
305 last_heartbeat,
306 current_connections: current_connections as u32,
307 current_jobs: current_jobs as u32,
308 cpu_usage: cpu_usage.unwrap_or(0.0) as f32,
309 memory_usage: memory_usage.unwrap_or(0.0) as f32,
310 })
311}
312
/// Snapshot of how many registered nodes are in each status, as returned by
/// `NodeRegistry::count_by_status`.
///
/// `total` counts every row in `forge_nodes`. It may exceed the sum of the
/// four per-status fields when some row carries a status not listed here.
#[derive(Debug, Clone, Default)]
pub struct NodeCounts {
    /// Rows whose status is `'active'`.
    pub active: usize,
    /// Rows whose status is `'draining'`.
    pub draining: usize,
    /// Rows whose status is `'dead'`.
    pub dead: usize,
    /// Rows whose status is `'joining'`.
    pub joining: usize,
    /// All rows in the table, regardless of status.
    pub total: usize,
}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// `NodeCounts::default()` must report zero for every status bucket.
    #[test]
    fn test_node_counts_default() {
        // Exhaustive destructuring (no `..`): adding a field to `NodeCounts`
        // breaks compilation here, forcing this test to be extended.
        let NodeCounts {
            active,
            draining,
            dead,
            joining,
            total,
        } = NodeCounts::default();

        assert_eq!(active, 0);
        assert_eq!(draining, 0);
        assert_eq!(dead, 0);
        assert_eq!(joining, 0);
        assert_eq!(total, 0);
    }
}
338}