1use std::net::IpAddr;
2use std::time::Duration;
3
4use sqlx::Row;
5use uuid::Uuid;
6
7use forge_core::cluster::{NodeId, NodeInfo, NodeRole, NodeStatus};
8use forge_core::{ForgeError, Result};
9
/// Registry of cluster nodes stored in the `forge_nodes` PostgreSQL table.
///
/// Holds the connection pool used for every query plus the description of
/// the node this registry registers, updates and deregisters.
pub struct NodeRegistry {
    /// Connection pool all registry queries are executed against.
    pool: sqlx::PgPool,
    /// The node this registry acts on behalf of (source of the `id` used by
    /// `register`, `set_status` and `deregister`).
    local_node: NodeInfo,
}
15
16impl NodeRegistry {
17 pub fn new(pool: sqlx::PgPool, local_node: NodeInfo) -> Self {
19 Self { pool, local_node }
20 }
21
22 pub fn local_node(&self) -> &NodeInfo {
24 &self.local_node
25 }
26
27 pub fn local_id(&self) -> NodeId {
29 self.local_node.id
30 }
31
32 pub async fn register(&self) -> Result<()> {
34 let roles: Vec<&str> = self.local_node.roles.iter().map(|r| r.as_str()).collect();
35
36 sqlx::query(
37 r#"
38 INSERT INTO forge_nodes (
39 id, hostname, ip_address, http_port, grpc_port,
40 roles, worker_capabilities, status, version, started_at, last_heartbeat
41 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW())
42 ON CONFLICT (id) DO UPDATE SET
43 hostname = EXCLUDED.hostname,
44 ip_address = EXCLUDED.ip_address,
45 http_port = EXCLUDED.http_port,
46 grpc_port = EXCLUDED.grpc_port,
47 roles = EXCLUDED.roles,
48 worker_capabilities = EXCLUDED.worker_capabilities,
49 status = EXCLUDED.status,
50 version = EXCLUDED.version,
51 last_heartbeat = NOW()
52 "#,
53 )
54 .bind(self.local_node.id.as_uuid())
55 .bind(&self.local_node.hostname)
56 .bind(self.local_node.ip_address.to_string())
57 .bind(self.local_node.http_port as i32)
58 .bind(self.local_node.grpc_port as i32)
59 .bind(&roles)
60 .bind(&self.local_node.worker_capabilities)
61 .bind(self.local_node.status.as_str())
62 .bind(&self.local_node.version)
63 .bind(self.local_node.started_at)
64 .execute(&self.pool)
65 .await
66 .map_err(|e| ForgeError::Database(e.to_string()))?;
67
68 Ok(())
69 }
70
71 pub async fn set_status(&self, status: NodeStatus) -> Result<()> {
73 sqlx::query(
74 r#"
75 UPDATE forge_nodes
76 SET status = $2
77 WHERE id = $1
78 "#,
79 )
80 .bind(self.local_node.id.as_uuid())
81 .bind(status.as_str())
82 .execute(&self.pool)
83 .await
84 .map_err(|e| ForgeError::Database(e.to_string()))?;
85
86 Ok(())
87 }
88
89 pub async fn deregister(&self) -> Result<()> {
91 sqlx::query(
92 r#"
93 DELETE FROM forge_nodes WHERE id = $1
94 "#,
95 )
96 .bind(self.local_node.id.as_uuid())
97 .execute(&self.pool)
98 .await
99 .map_err(|e| ForgeError::Database(e.to_string()))?;
100
101 Ok(())
102 }
103
104 pub async fn get_active_nodes(&self) -> Result<Vec<NodeInfo>> {
106 self.get_nodes_by_status(NodeStatus::Active).await
107 }
108
109 pub async fn get_nodes_by_status(&self, status: NodeStatus) -> Result<Vec<NodeInfo>> {
111 let rows = sqlx::query(
112 r#"
113 SELECT id, hostname, ip_address, http_port, grpc_port,
114 roles, worker_capabilities, status, version,
115 started_at, last_heartbeat, current_connections,
116 current_jobs, cpu_usage, memory_usage
117 FROM forge_nodes
118 WHERE status = $1
119 ORDER BY started_at
120 "#,
121 )
122 .bind(status.as_str())
123 .fetch_all(&self.pool)
124 .await
125 .map_err(|e| ForgeError::Database(e.to_string()))?;
126
127 rows.into_iter().map(|row| parse_node_row(&row)).collect()
128 }
129
130 pub async fn get_node(&self, node_id: NodeId) -> Result<Option<NodeInfo>> {
132 let row = sqlx::query(
133 r#"
134 SELECT id, hostname, ip_address, http_port, grpc_port,
135 roles, worker_capabilities, status, version,
136 started_at, last_heartbeat, current_connections,
137 current_jobs, cpu_usage, memory_usage
138 FROM forge_nodes
139 WHERE id = $1
140 "#,
141 )
142 .bind(node_id.as_uuid())
143 .fetch_optional(&self.pool)
144 .await
145 .map_err(|e| ForgeError::Database(e.to_string()))?;
146
147 row.map(|r| parse_node_row(&r)).transpose()
148 }
149
150 pub async fn count_by_status(&self) -> Result<NodeCounts> {
152 let row = sqlx::query(
153 r#"
154 SELECT
155 COUNT(*) FILTER (WHERE status = 'active') as active,
156 COUNT(*) FILTER (WHERE status = 'draining') as draining,
157 COUNT(*) FILTER (WHERE status = 'dead') as dead,
158 COUNT(*) FILTER (WHERE status = 'joining') as joining,
159 COUNT(*) as total
160 FROM forge_nodes
161 "#,
162 )
163 .fetch_one(&self.pool)
164 .await
165 .map_err(|e| ForgeError::Database(e.to_string()))?;
166
167 Ok(NodeCounts {
168 active: row.get::<i64, _>("active") as usize,
169 draining: row.get::<i64, _>("draining") as usize,
170 dead: row.get::<i64, _>("dead") as usize,
171 joining: row.get::<i64, _>("joining") as usize,
172 total: row.get::<i64, _>("total") as usize,
173 })
174 }
175
176 pub async fn mark_dead_nodes(&self, threshold: Duration) -> Result<u64> {
178 let threshold_secs = threshold.as_secs() as i64;
179
180 let result = sqlx::query(
181 r#"
182 UPDATE forge_nodes
183 SET status = 'dead'
184 WHERE status = 'active'
185 AND last_heartbeat < NOW() - make_interval(secs => $1)
186 "#,
187 )
188 .bind(threshold_secs as f64)
189 .execute(&self.pool)
190 .await
191 .map_err(|e| ForgeError::Database(e.to_string()))?;
192
193 Ok(result.rows_affected())
194 }
195
196 pub async fn cleanup_dead_nodes(&self, older_than: Duration) -> Result<u64> {
198 let threshold_secs = older_than.as_secs() as i64;
199
200 let result = sqlx::query(
201 r#"
202 DELETE FROM forge_nodes
203 WHERE status = 'dead'
204 AND last_heartbeat < NOW() - make_interval(secs => $1)
205 "#,
206 )
207 .bind(threshold_secs as f64)
208 .execute(&self.pool)
209 .await
210 .map_err(|e| ForgeError::Database(e.to_string()))?;
211
212 Ok(result.rows_affected())
213 }
214}
215
216fn parse_node_row(row: &sqlx::postgres::PgRow) -> Result<NodeInfo> {
217 let id: Uuid = row.get("id");
218
219 let ip_str: String = row.get("ip_address");
220 let ip_address: IpAddr = ip_str
221 .parse()
222 .map_err(|e| ForgeError::Validation(format!("invalid IP address '{}': {}", ip_str, e)))?;
223
224 let roles_str: Vec<String> = row.get("roles");
225 let roles: Vec<NodeRole> = roles_str
226 .iter()
227 .map(|s| {
228 s.parse()
229 .map_err(|e| ForgeError::Validation(format!("invalid role '{}': {}", s, e)))
230 })
231 .collect::<Result<Vec<_>>>()?;
232
233 let status_str: String = row.get("status");
234 let status: NodeStatus = status_str
235 .parse()
236 .map_err(|e| ForgeError::Validation(format!("invalid status '{}': {}", status_str, e)))?;
237
238 Ok(NodeInfo {
239 id: NodeId::from_uuid(id),
240 hostname: row.get("hostname"),
241 ip_address,
242 http_port: row.get::<i32, _>("http_port") as u16,
243 grpc_port: row.get::<i32, _>("grpc_port") as u16,
244 roles,
245 worker_capabilities: row.get("worker_capabilities"),
246 status,
247 version: row.get("version"),
248 started_at: row.get("started_at"),
249 last_heartbeat: row.get("last_heartbeat"),
250 current_connections: row.get::<i32, _>("current_connections") as u32,
251 current_jobs: row.get::<i32, _>("current_jobs") as u32,
252 cpu_usage: row.get::<f32, _>("cpu_usage"),
253 memory_usage: row.get::<f32, _>("memory_usage"),
254 })
255}
256
/// Per-status node tallies as produced by `NodeRegistry::count_by_status`.
///
/// The `Default` value has every counter at zero.
#[derive(Clone, Debug, Default)]
pub struct NodeCounts {
    /// Number of rows with status `'active'`.
    pub active: usize,
    /// Number of rows with status `'draining'`.
    pub draining: usize,
    /// Number of rows with status `'dead'`.
    pub dead: usize,
    /// Number of rows with status `'joining'`.
    pub joining: usize,
    /// Number of rows in the table regardless of status.
    pub total: usize,
}
271
#[cfg(test)]
mod tests {
    use super::*;

    // A default-constructed `NodeCounts` must start with zeroed counters.
    #[test]
    fn test_node_counts_default() {
        let NodeCounts { active, total, .. } = NodeCounts::default();
        assert_eq!(active, 0);
        assert_eq!(total, 0);
    }
}