forge_runtime/cluster/
registry.rs

1use std::net::IpAddr;
2use std::time::Duration;
3
4use uuid::Uuid;
5
6use forge_core::cluster::{NodeId, NodeInfo, NodeRole, NodeStatus};
7
8/// Node registry for cluster membership.
9pub struct NodeRegistry {
10    pool: sqlx::PgPool,
11    local_node: NodeInfo,
12}
13
14impl NodeRegistry {
15    /// Create a new node registry.
16    pub fn new(pool: sqlx::PgPool, local_node: NodeInfo) -> Self {
17        Self { pool, local_node }
18    }
19
20    /// Get the local node info.
21    pub fn local_node(&self) -> &NodeInfo {
22        &self.local_node
23    }
24
25    /// Get the local node ID.
26    pub fn local_id(&self) -> NodeId {
27        self.local_node.id
28    }
29
30    /// Register the local node in the cluster.
31    pub async fn register(&self) -> forge_core::Result<()> {
32        let roles: Vec<&str> = self.local_node.roles.iter().map(|r| r.as_str()).collect();
33
34        sqlx::query(
35            r#"
36            INSERT INTO forge_nodes (
37                id, hostname, ip_address, http_port, grpc_port,
38                roles, worker_capabilities, status, version, started_at, last_heartbeat
39            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW())
40            ON CONFLICT (id) DO UPDATE SET
41                hostname = EXCLUDED.hostname,
42                ip_address = EXCLUDED.ip_address,
43                http_port = EXCLUDED.http_port,
44                grpc_port = EXCLUDED.grpc_port,
45                roles = EXCLUDED.roles,
46                worker_capabilities = EXCLUDED.worker_capabilities,
47                status = EXCLUDED.status,
48                version = EXCLUDED.version,
49                last_heartbeat = NOW()
50            "#,
51        )
52        .bind(self.local_node.id.as_uuid())
53        .bind(&self.local_node.hostname)
54        .bind(self.local_node.ip_address.to_string())
55        .bind(self.local_node.http_port as i32)
56        .bind(self.local_node.grpc_port as i32)
57        .bind(&roles)
58        .bind(&self.local_node.worker_capabilities)
59        .bind(self.local_node.status.as_str())
60        .bind(&self.local_node.version)
61        .bind(self.local_node.started_at)
62        .execute(&self.pool)
63        .await
64        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
65
66        Ok(())
67    }
68
69    /// Update node status.
70    pub async fn set_status(&self, status: NodeStatus) -> forge_core::Result<()> {
71        sqlx::query(
72            r#"
73            UPDATE forge_nodes
74            SET status = $2
75            WHERE id = $1
76            "#,
77        )
78        .bind(self.local_node.id.as_uuid())
79        .bind(status.as_str())
80        .execute(&self.pool)
81        .await
82        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
83
84        Ok(())
85    }
86
87    /// Deregister the local node from the cluster.
88    pub async fn deregister(&self) -> forge_core::Result<()> {
89        sqlx::query(
90            r#"
91            DELETE FROM forge_nodes WHERE id = $1
92            "#,
93        )
94        .bind(self.local_node.id.as_uuid())
95        .execute(&self.pool)
96        .await
97        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
98
99        Ok(())
100    }
101
102    /// Get all active nodes in the cluster.
103    pub async fn get_active_nodes(&self) -> forge_core::Result<Vec<NodeInfo>> {
104        self.get_nodes_by_status(NodeStatus::Active).await
105    }
106
107    /// Get nodes by status.
108    pub async fn get_nodes_by_status(
109        &self,
110        status: NodeStatus,
111    ) -> forge_core::Result<Vec<NodeInfo>> {
112        let rows = sqlx::query(
113            r#"
114            SELECT id, hostname, ip_address, http_port, grpc_port,
115                   roles, worker_capabilities, status, version,
116                   started_at, last_heartbeat, current_connections,
117                   current_jobs, cpu_usage, memory_usage
118            FROM forge_nodes
119            WHERE status = $1
120            ORDER BY started_at
121            "#,
122        )
123        .bind(status.as_str())
124        .fetch_all(&self.pool)
125        .await
126        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
127
128        let mut nodes = Vec::new();
129        for row in rows {
130            use sqlx::Row;
131            let id: Uuid = row.get("id");
132            let ip_str: String = row.get("ip_address");
133            let ip_address: IpAddr = ip_str
134                .parse()
135                .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)));
136            let roles_str: Vec<String> = row.get("roles");
137            let roles: Vec<NodeRole> = roles_str.iter().filter_map(|s| s.parse().ok()).collect();
138
139            nodes.push(NodeInfo {
140                id: NodeId::from_uuid(id),
141                hostname: row.get("hostname"),
142                ip_address,
143                http_port: row.get::<i32, _>("http_port") as u16,
144                grpc_port: row.get::<i32, _>("grpc_port") as u16,
145                roles,
146                worker_capabilities: row.get("worker_capabilities"),
147                status: row.get::<String, _>("status").parse().unwrap(),
148                version: row.get("version"),
149                started_at: row.get("started_at"),
150                last_heartbeat: row.get("last_heartbeat"),
151                current_connections: row.get::<i32, _>("current_connections") as u32,
152                current_jobs: row.get::<i32, _>("current_jobs") as u32,
153                cpu_usage: row.get::<f32, _>("cpu_usage"),
154                memory_usage: row.get::<f32, _>("memory_usage"),
155            });
156        }
157
158        Ok(nodes)
159    }
160
161    /// Get a specific node by ID.
162    pub async fn get_node(&self, node_id: NodeId) -> forge_core::Result<Option<NodeInfo>> {
163        let row = sqlx::query(
164            r#"
165            SELECT id, hostname, ip_address, http_port, grpc_port,
166                   roles, worker_capabilities, status, version,
167                   started_at, last_heartbeat, current_connections,
168                   current_jobs, cpu_usage, memory_usage
169            FROM forge_nodes
170            WHERE id = $1
171            "#,
172        )
173        .bind(node_id.as_uuid())
174        .fetch_optional(&self.pool)
175        .await
176        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
177
178        match row {
179            Some(row) => {
180                use sqlx::Row;
181                let id: Uuid = row.get("id");
182                let ip_str: String = row.get("ip_address");
183                let ip_address: IpAddr = ip_str
184                    .parse()
185                    .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)));
186                let roles_str: Vec<String> = row.get("roles");
187                let roles: Vec<NodeRole> =
188                    roles_str.iter().filter_map(|s| s.parse().ok()).collect();
189
190                Ok(Some(NodeInfo {
191                    id: NodeId::from_uuid(id),
192                    hostname: row.get("hostname"),
193                    ip_address,
194                    http_port: row.get::<i32, _>("http_port") as u16,
195                    grpc_port: row.get::<i32, _>("grpc_port") as u16,
196                    roles,
197                    worker_capabilities: row.get("worker_capabilities"),
198                    status: row.get::<String, _>("status").parse().unwrap(),
199                    version: row.get("version"),
200                    started_at: row.get("started_at"),
201                    last_heartbeat: row.get("last_heartbeat"),
202                    current_connections: row.get::<i32, _>("current_connections") as u32,
203                    current_jobs: row.get::<i32, _>("current_jobs") as u32,
204                    cpu_usage: row.get::<f32, _>("cpu_usage"),
205                    memory_usage: row.get::<f32, _>("memory_usage"),
206                }))
207            }
208            None => Ok(None),
209        }
210    }
211
212    /// Count nodes by status.
213    pub async fn count_by_status(&self) -> forge_core::Result<NodeCounts> {
214        let row = sqlx::query(
215            r#"
216            SELECT
217                COUNT(*) FILTER (WHERE status = 'active') as active,
218                COUNT(*) FILTER (WHERE status = 'draining') as draining,
219                COUNT(*) FILTER (WHERE status = 'dead') as dead,
220                COUNT(*) FILTER (WHERE status = 'joining') as joining,
221                COUNT(*) as total
222            FROM forge_nodes
223            "#,
224        )
225        .fetch_one(&self.pool)
226        .await
227        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
228
229        use sqlx::Row;
230        Ok(NodeCounts {
231            active: row.get::<i64, _>("active") as usize,
232            draining: row.get::<i64, _>("draining") as usize,
233            dead: row.get::<i64, _>("dead") as usize,
234            joining: row.get::<i64, _>("joining") as usize,
235            total: row.get::<i64, _>("total") as usize,
236        })
237    }
238
239    /// Mark stale nodes as dead.
240    pub async fn mark_dead_nodes(&self, threshold: Duration) -> forge_core::Result<u64> {
241        let threshold_secs = threshold.as_secs() as i64;
242
243        let result = sqlx::query(
244            r#"
245            UPDATE forge_nodes
246            SET status = 'dead'
247            WHERE status = 'active'
248              AND last_heartbeat < NOW() - make_interval(secs => $1)
249            "#,
250        )
251        .bind(threshold_secs as f64)
252        .execute(&self.pool)
253        .await
254        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
255
256        Ok(result.rows_affected())
257    }
258
259    /// Clean up old dead nodes.
260    pub async fn cleanup_dead_nodes(&self, older_than: Duration) -> forge_core::Result<u64> {
261        let threshold_secs = older_than.as_secs() as i64;
262
263        let result = sqlx::query(
264            r#"
265            DELETE FROM forge_nodes
266            WHERE status = 'dead'
267              AND last_heartbeat < NOW() - make_interval(secs => $1)
268            "#,
269        )
270        .bind(threshold_secs as f64)
271        .execute(&self.pool)
272        .await
273        .map_err(|e| forge_core::ForgeError::Database(e.to_string()))?;
274
275        Ok(result.rows_affected())
276    }
277}
278
/// Per-status node counts, as produced by `NodeRegistry::count_by_status`.
///
/// The default value has every counter at zero. Values can be compared with
/// `==`, which is convenient in tests and when detecting membership changes.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NodeCounts {
    /// Number of nodes with status `active`.
    pub active: usize,
    /// Number of nodes with status `draining`.
    pub draining: usize,
    /// Number of nodes with status `dead`.
    pub dead: usize,
    /// Number of nodes with status `joining`.
    pub joining: usize,
    /// Number of nodes regardless of status.
    pub total: usize,
}
293
#[cfg(test)]
mod tests {
    use super::NodeCounts;

    /// A default-constructed `NodeCounts` reports zero active and zero total nodes.
    #[test]
    fn test_node_counts_default() {
        let NodeCounts { active, total, .. } = NodeCounts::default();

        assert_eq!(active, 0);
        assert_eq!(total, 0);
    }
}