1use std::net::IpAddr;
2use std::time::Duration;
3
4use uuid::Uuid;
5
6use forge_core::cluster::{NodeId, NodeInfo, NodeRole, NodeStatus};
7
8pub struct NodeRegistry {
10 pool: sqlx::PgPool,
11 local_node: NodeInfo,
12}
13
14impl NodeRegistry {
15 pub fn new(pool: sqlx::PgPool, local_node: NodeInfo) -> Self {
17 Self { pool, local_node }
18 }
19
20 pub fn local_node(&self) -> &NodeInfo {
22 &self.local_node
23 }
24
25 pub fn local_id(&self) -> NodeId {
27 self.local_node.id
28 }
29
30 pub async fn register(&self) -> forge_core::Result<()> {
32 let roles: Vec<&str> = self.local_node.roles.iter().map(|r| r.as_str()).collect();
33
34 sqlx::query(
35 r#"
36 INSERT INTO forge_nodes (
37 id, hostname, ip_address, http_port, grpc_port,
38 roles, worker_capabilities, status, version, started_at, last_heartbeat
39 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW())
40 ON CONFLICT (id) DO UPDATE SET
41 hostname = EXCLUDED.hostname,
42 ip_address = EXCLUDED.ip_address,
43 http_port = EXCLUDED.http_port,
44 grpc_port = EXCLUDED.grpc_port,
45 roles = EXCLUDED.roles,
46 worker_capabilities = EXCLUDED.worker_capabilities,
47 status = EXCLUDED.status,
48 version = EXCLUDED.version,
49 last_heartbeat = NOW()
50 "#,
51 )
52 .bind(self.local_node.id.as_uuid())
53 .bind(&self.local_node.hostname)
54 .bind(self.local_node.ip_address.to_string())
55 .bind(self.local_node.http_port as i32)
56 .bind(self.local_node.grpc_port as i32)
57 .bind(&roles)
58 .bind(&self.local_node.worker_capabilities)
59 .bind(self.local_node.status.as_str())
60 .bind(&self.local_node.version)
61 .bind(self.local_node.started_at)
62 .execute(&self.pool)
63 .await
64 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
65
66 Ok(())
67 }
68
69 pub async fn set_status(&self, status: NodeStatus) -> forge_core::Result<()> {
71 sqlx::query(
72 r#"
73 UPDATE forge_nodes
74 SET status = $2
75 WHERE id = $1
76 "#,
77 )
78 .bind(self.local_node.id.as_uuid())
79 .bind(status.as_str())
80 .execute(&self.pool)
81 .await
82 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
83
84 Ok(())
85 }
86
87 pub async fn deregister(&self) -> forge_core::Result<()> {
89 sqlx::query(
90 r#"
91 DELETE FROM forge_nodes WHERE id = $1
92 "#,
93 )
94 .bind(self.local_node.id.as_uuid())
95 .execute(&self.pool)
96 .await
97 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
98
99 Ok(())
100 }
101
102 pub async fn get_active_nodes(&self) -> forge_core::Result<Vec<NodeInfo>> {
104 self.get_nodes_by_status(NodeStatus::Active).await
105 }
106
107 pub async fn get_nodes_by_status(
109 &self,
110 status: NodeStatus,
111 ) -> forge_core::Result<Vec<NodeInfo>> {
112 let rows = sqlx::query(
113 r#"
114 SELECT id, hostname, ip_address, http_port, grpc_port,
115 roles, worker_capabilities, status, version,
116 started_at, last_heartbeat, current_connections,
117 current_jobs, cpu_usage, memory_usage
118 FROM forge_nodes
119 WHERE status = $1
120 ORDER BY started_at
121 "#,
122 )
123 .bind(status.as_str())
124 .fetch_all(&self.pool)
125 .await
126 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
127
128 let mut nodes = Vec::new();
129 for row in rows {
130 use sqlx::Row;
131 let id: Uuid = row.get("id");
132 let ip_str: String = row.get("ip_address");
133 let ip_address: IpAddr = ip_str
134 .parse()
135 .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)));
136 let roles_str: Vec<String> = row.get("roles");
137 let roles: Vec<NodeRole> = roles_str.iter().filter_map(|s| s.parse().ok()).collect();
138
139 nodes.push(NodeInfo {
140 id: NodeId::from_uuid(id),
141 hostname: row.get("hostname"),
142 ip_address,
143 http_port: row.get::<i32, _>("http_port") as u16,
144 grpc_port: row.get::<i32, _>("grpc_port") as u16,
145 roles,
146 worker_capabilities: row.get("worker_capabilities"),
147 status: row.get::<String, _>("status").parse().unwrap(),
148 version: row.get("version"),
149 started_at: row.get("started_at"),
150 last_heartbeat: row.get("last_heartbeat"),
151 current_connections: row.get::<i32, _>("current_connections") as u32,
152 current_jobs: row.get::<i32, _>("current_jobs") as u32,
153 cpu_usage: row.get::<f32, _>("cpu_usage"),
154 memory_usage: row.get::<f32, _>("memory_usage"),
155 });
156 }
157
158 Ok(nodes)
159 }
160
161 pub async fn get_node(&self, node_id: NodeId) -> forge_core::Result<Option<NodeInfo>> {
163 let row = sqlx::query(
164 r#"
165 SELECT id, hostname, ip_address, http_port, grpc_port,
166 roles, worker_capabilities, status, version,
167 started_at, last_heartbeat, current_connections,
168 current_jobs, cpu_usage, memory_usage
169 FROM forge_nodes
170 WHERE id = $1
171 "#,
172 )
173 .bind(node_id.as_uuid())
174 .fetch_optional(&self.pool)
175 .await
176 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
177
178 match row {
179 Some(row) => {
180 use sqlx::Row;
181 let id: Uuid = row.get("id");
182 let ip_str: String = row.get("ip_address");
183 let ip_address: IpAddr = ip_str
184 .parse()
185 .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)));
186 let roles_str: Vec<String> = row.get("roles");
187 let roles: Vec<NodeRole> =
188 roles_str.iter().filter_map(|s| s.parse().ok()).collect();
189
190 Ok(Some(NodeInfo {
191 id: NodeId::from_uuid(id),
192 hostname: row.get("hostname"),
193 ip_address,
194 http_port: row.get::<i32, _>("http_port") as u16,
195 grpc_port: row.get::<i32, _>("grpc_port") as u16,
196 roles,
197 worker_capabilities: row.get("worker_capabilities"),
198 status: row.get::<String, _>("status").parse().unwrap(),
199 version: row.get("version"),
200 started_at: row.get("started_at"),
201 last_heartbeat: row.get("last_heartbeat"),
202 current_connections: row.get::<i32, _>("current_connections") as u32,
203 current_jobs: row.get::<i32, _>("current_jobs") as u32,
204 cpu_usage: row.get::<f32, _>("cpu_usage"),
205 memory_usage: row.get::<f32, _>("memory_usage"),
206 }))
207 }
208 None => Ok(None),
209 }
210 }
211
212 pub async fn count_by_status(&self) -> forge_core::Result<NodeCounts> {
214 let row = sqlx::query(
215 r#"
216 SELECT
217 COUNT(*) FILTER (WHERE status = 'active') as active,
218 COUNT(*) FILTER (WHERE status = 'draining') as draining,
219 COUNT(*) FILTER (WHERE status = 'dead') as dead,
220 COUNT(*) FILTER (WHERE status = 'joining') as joining,
221 COUNT(*) as total
222 FROM forge_nodes
223 "#,
224 )
225 .fetch_one(&self.pool)
226 .await
227 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
228
229 use sqlx::Row;
230 Ok(NodeCounts {
231 active: row.get::<i64, _>("active") as usize,
232 draining: row.get::<i64, _>("draining") as usize,
233 dead: row.get::<i64, _>("dead") as usize,
234 joining: row.get::<i64, _>("joining") as usize,
235 total: row.get::<i64, _>("total") as usize,
236 })
237 }
238
239 pub async fn mark_dead_nodes(&self, threshold: Duration) -> forge_core::Result<u64> {
241 let threshold_secs = threshold.as_secs() as i64;
242
243 let result = sqlx::query(
244 r#"
245 UPDATE forge_nodes
246 SET status = 'dead'
247 WHERE status = 'active'
248 AND last_heartbeat < NOW() - make_interval(secs => $1)
249 "#,
250 )
251 .bind(threshold_secs as f64)
252 .execute(&self.pool)
253 .await
254 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
255
256 Ok(result.rows_affected())
257 }
258
259 pub async fn cleanup_dead_nodes(&self, older_than: Duration) -> forge_core::Result<u64> {
261 let threshold_secs = older_than.as_secs() as i64;
262
263 let result = sqlx::query(
264 r#"
265 DELETE FROM forge_nodes
266 WHERE status = 'dead'
267 AND last_heartbeat < NOW() - make_interval(secs => $1)
268 "#,
269 )
270 .bind(threshold_secs as f64)
271 .execute(&self.pool)
272 .await
273 .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
274
275 Ok(result.rows_affected())
276 }
277}
278
/// Per-status node tallies, as produced by `NodeRegistry::count_by_status`.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NodeCounts {
    /// Number of rows whose status is `active`.
    pub active: usize,
    /// Number of rows whose status is `draining`.
    pub draining: usize,
    /// Number of rows whose status is `dead`.
    pub dead: usize,
    /// Number of rows whose status is `joining`.
    pub joining: usize,
    /// Number of rows regardless of status.
    ///
    /// Rows with a status other than the four above are counted only here.
    pub total: usize,
}
293
#[cfg(test)]
mod tests {
    use super::NodeCounts;

    /// A defaulted `NodeCounts` describes an empty cluster.
    #[test]
    fn test_node_counts_default() {
        let NodeCounts { active, total, .. } = NodeCounts::default();
        assert_eq!((active, total), (0, 0));
    }
}