Skip to main content

forge_runtime/cluster/
registry.rs

1use std::net::IpAddr;
2use std::time::Duration;
3
4use sqlx::Row;
5use uuid::Uuid;
6
7use forge_core::cluster::{NodeId, NodeInfo, NodeRole, NodeStatus};
8use forge_core::{ForgeError, Result};
9
10/// Node registry for cluster membership.
11pub struct NodeRegistry {
12    pool: sqlx::PgPool,
13    local_node: NodeInfo,
14}
15
16impl NodeRegistry {
17    /// Create a new node registry.
18    pub fn new(pool: sqlx::PgPool, local_node: NodeInfo) -> Self {
19        Self { pool, local_node }
20    }
21
22    /// Get the local node info.
23    pub fn local_node(&self) -> &NodeInfo {
24        &self.local_node
25    }
26
27    /// Get the local node ID.
28    pub fn local_id(&self) -> NodeId {
29        self.local_node.id
30    }
31
32    /// Register the local node in the cluster.
33    pub async fn register(&self) -> Result<()> {
34        let roles: Vec<&str> = self.local_node.roles.iter().map(|r| r.as_str()).collect();
35
36        sqlx::query(
37            r#"
38            INSERT INTO forge_nodes (
39                id, hostname, ip_address, http_port, grpc_port,
40                roles, worker_capabilities, status, version, started_at, last_heartbeat
41            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW())
42            ON CONFLICT (id) DO UPDATE SET
43                hostname = EXCLUDED.hostname,
44                ip_address = EXCLUDED.ip_address,
45                http_port = EXCLUDED.http_port,
46                grpc_port = EXCLUDED.grpc_port,
47                roles = EXCLUDED.roles,
48                worker_capabilities = EXCLUDED.worker_capabilities,
49                status = EXCLUDED.status,
50                version = EXCLUDED.version,
51                last_heartbeat = NOW()
52            "#,
53        )
54        .bind(self.local_node.id.as_uuid())
55        .bind(&self.local_node.hostname)
56        .bind(self.local_node.ip_address.to_string())
57        .bind(self.local_node.http_port as i32)
58        .bind(self.local_node.grpc_port as i32)
59        .bind(&roles)
60        .bind(&self.local_node.worker_capabilities)
61        .bind(self.local_node.status.as_str())
62        .bind(&self.local_node.version)
63        .bind(self.local_node.started_at)
64        .execute(&self.pool)
65        .await
66        .map_err(|e| ForgeError::Database(e.to_string()))?;
67
68        Ok(())
69    }
70
71    /// Update node status.
72    pub async fn set_status(&self, status: NodeStatus) -> Result<()> {
73        sqlx::query(
74            r#"
75            UPDATE forge_nodes
76            SET status = $2
77            WHERE id = $1
78            "#,
79        )
80        .bind(self.local_node.id.as_uuid())
81        .bind(status.as_str())
82        .execute(&self.pool)
83        .await
84        .map_err(|e| ForgeError::Database(e.to_string()))?;
85
86        Ok(())
87    }
88
89    /// Deregister the local node from the cluster.
90    pub async fn deregister(&self) -> Result<()> {
91        sqlx::query(
92            r#"
93            DELETE FROM forge_nodes WHERE id = $1
94            "#,
95        )
96        .bind(self.local_node.id.as_uuid())
97        .execute(&self.pool)
98        .await
99        .map_err(|e| ForgeError::Database(e.to_string()))?;
100
101        Ok(())
102    }
103
104    /// Get all active nodes in the cluster.
105    pub async fn get_active_nodes(&self) -> Result<Vec<NodeInfo>> {
106        self.get_nodes_by_status(NodeStatus::Active).await
107    }
108
109    /// Get nodes by status.
110    pub async fn get_nodes_by_status(&self, status: NodeStatus) -> Result<Vec<NodeInfo>> {
111        let rows = sqlx::query(
112            r#"
113            SELECT id, hostname, ip_address, http_port, grpc_port,
114                   roles, worker_capabilities, status, version,
115                   started_at, last_heartbeat, current_connections,
116                   current_jobs, cpu_usage, memory_usage
117            FROM forge_nodes
118            WHERE status = $1
119            ORDER BY started_at
120            "#,
121        )
122        .bind(status.as_str())
123        .fetch_all(&self.pool)
124        .await
125        .map_err(|e| ForgeError::Database(e.to_string()))?;
126
127        rows.into_iter().map(|row| parse_node_row(&row)).collect()
128    }
129
130    /// Get a specific node by ID.
131    pub async fn get_node(&self, node_id: NodeId) -> Result<Option<NodeInfo>> {
132        let row = sqlx::query(
133            r#"
134            SELECT id, hostname, ip_address, http_port, grpc_port,
135                   roles, worker_capabilities, status, version,
136                   started_at, last_heartbeat, current_connections,
137                   current_jobs, cpu_usage, memory_usage
138            FROM forge_nodes
139            WHERE id = $1
140            "#,
141        )
142        .bind(node_id.as_uuid())
143        .fetch_optional(&self.pool)
144        .await
145        .map_err(|e| ForgeError::Database(e.to_string()))?;
146
147        row.map(|r| parse_node_row(&r)).transpose()
148    }
149
150    /// Count nodes by status.
151    pub async fn count_by_status(&self) -> Result<NodeCounts> {
152        let row = sqlx::query(
153            r#"
154            SELECT
155                COUNT(*) FILTER (WHERE status = 'active') as active,
156                COUNT(*) FILTER (WHERE status = 'draining') as draining,
157                COUNT(*) FILTER (WHERE status = 'dead') as dead,
158                COUNT(*) FILTER (WHERE status = 'joining') as joining,
159                COUNT(*) as total
160            FROM forge_nodes
161            "#,
162        )
163        .fetch_one(&self.pool)
164        .await
165        .map_err(|e| ForgeError::Database(e.to_string()))?;
166
167        Ok(NodeCounts {
168            active: row.get::<i64, _>("active") as usize,
169            draining: row.get::<i64, _>("draining") as usize,
170            dead: row.get::<i64, _>("dead") as usize,
171            joining: row.get::<i64, _>("joining") as usize,
172            total: row.get::<i64, _>("total") as usize,
173        })
174    }
175
176    /// Mark stale nodes as dead.
177    pub async fn mark_dead_nodes(&self, threshold: Duration) -> Result<u64> {
178        let threshold_secs = threshold.as_secs() as i64;
179
180        let result = sqlx::query(
181            r#"
182            UPDATE forge_nodes
183            SET status = 'dead'
184            WHERE status = 'active'
185              AND last_heartbeat < NOW() - make_interval(secs => $1)
186            "#,
187        )
188        .bind(threshold_secs as f64)
189        .execute(&self.pool)
190        .await
191        .map_err(|e| ForgeError::Database(e.to_string()))?;
192
193        Ok(result.rows_affected())
194    }
195
196    /// Clean up old dead nodes.
197    pub async fn cleanup_dead_nodes(&self, older_than: Duration) -> Result<u64> {
198        let threshold_secs = older_than.as_secs() as i64;
199
200        let result = sqlx::query(
201            r#"
202            DELETE FROM forge_nodes
203            WHERE status = 'dead'
204              AND last_heartbeat < NOW() - make_interval(secs => $1)
205            "#,
206        )
207        .bind(threshold_secs as f64)
208        .execute(&self.pool)
209        .await
210        .map_err(|e| ForgeError::Database(e.to_string()))?;
211
212        Ok(result.rows_affected())
213    }
214}
215
216fn parse_node_row(row: &sqlx::postgres::PgRow) -> Result<NodeInfo> {
217    let id: Uuid = row.get("id");
218
219    let ip_str: String = row.get("ip_address");
220    let ip_address: IpAddr = ip_str
221        .parse()
222        .map_err(|e| ForgeError::Validation(format!("invalid IP address '{}': {}", ip_str, e)))?;
223
224    let roles_str: Vec<String> = row.get("roles");
225    let roles: Vec<NodeRole> = roles_str
226        .iter()
227        .map(|s| {
228            s.parse()
229                .map_err(|e| ForgeError::Validation(format!("invalid role '{}': {}", s, e)))
230        })
231        .collect::<Result<Vec<_>>>()?;
232
233    let status_str: String = row.get("status");
234    let status: NodeStatus = status_str
235        .parse()
236        .map_err(|e| ForgeError::Validation(format!("invalid status '{}': {}", status_str, e)))?;
237
238    Ok(NodeInfo {
239        id: NodeId::from_uuid(id),
240        hostname: row.get("hostname"),
241        ip_address,
242        http_port: row.get::<i32, _>("http_port") as u16,
243        grpc_port: row.get::<i32, _>("grpc_port") as u16,
244        roles,
245        worker_capabilities: row.get("worker_capabilities"),
246        status,
247        version: row.get("version"),
248        started_at: row.get("started_at"),
249        last_heartbeat: row.get("last_heartbeat"),
250        current_connections: row.get::<i32, _>("current_connections") as u32,
251        current_jobs: row.get::<i32, _>("current_jobs") as u32,
252        cpu_usage: row.get::<f32, _>("cpu_usage"),
253        memory_usage: row.get::<f32, _>("memory_usage"),
254    })
255}
256
/// Node count statistics.
///
/// A plain value type holding one counter per node status plus the overall
/// total. `Default` yields all zeros, and values can be compared with `==`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NodeCounts {
    /// Active nodes.
    pub active: usize,
    /// Draining nodes.
    pub draining: usize,
    /// Dead nodes.
    pub dead: usize,
    /// Joining nodes.
    pub joining: usize,
    /// Total nodes.
    pub total: usize,
}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed `NodeCounts` reports zero active and zero total nodes.
    #[test]
    fn test_node_counts_default() {
        let NodeCounts { active, total, .. } = NodeCounts::default();
        assert_eq!(active, 0);
        assert_eq!(total, 0);
    }
}