// forge_runtime/cluster/registry.rs

1use std::net::IpAddr;
2use std::time::Duration;
3
4use chrono::{DateTime, Utc};
5use uuid::Uuid;
6
7use forge_core::cluster::{NodeId, NodeInfo, NodeRole, NodeStatus};
8use forge_core::{ForgeError, Result};
9
10/// Node registry for cluster membership.
11pub struct NodeRegistry {
12    pool: sqlx::PgPool,
13    local_node: NodeInfo,
14}
15
16impl NodeRegistry {
17    /// Create a new node registry.
18    pub fn new(pool: sqlx::PgPool, local_node: NodeInfo) -> Self {
19        Self { pool, local_node }
20    }
21
22    /// Get the local node info.
23    pub fn local_node(&self) -> &NodeInfo {
24        &self.local_node
25    }
26
27    /// Get the local node ID.
28    pub fn local_id(&self) -> NodeId {
29        self.local_node.id
30    }
31
32    /// Register the local node in the cluster.
33    pub async fn register(&self) -> Result<()> {
34        let roles: Vec<String> = self
35            .local_node
36            .roles
37            .iter()
38            .map(|r| r.as_str().to_string())
39            .collect();
40
41        sqlx::query!(
42            r#"
43            INSERT INTO forge_nodes (
44                id, hostname, ip_address, http_port, grpc_port,
45                roles, worker_capabilities, status, version, started_at, last_heartbeat
46            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW())
47            ON CONFLICT (id) DO UPDATE SET
48                hostname = EXCLUDED.hostname,
49                ip_address = EXCLUDED.ip_address,
50                http_port = EXCLUDED.http_port,
51                grpc_port = EXCLUDED.grpc_port,
52                roles = EXCLUDED.roles,
53                worker_capabilities = EXCLUDED.worker_capabilities,
54                status = EXCLUDED.status,
55                version = EXCLUDED.version,
56                last_heartbeat = NOW()
57            "#,
58            self.local_node.id.as_uuid(),
59            &self.local_node.hostname,
60            self.local_node.ip_address.to_string(),
61            self.local_node.http_port as i32,
62            self.local_node.grpc_port as i32,
63            &roles,
64            &self.local_node.worker_capabilities,
65            self.local_node.status.as_str(),
66            &self.local_node.version,
67            self.local_node.started_at,
68        )
69        .execute(&self.pool)
70        .await
71        .map_err(|e| ForgeError::Database(e.to_string()))?;
72
73        Ok(())
74    }
75
76    /// Update node status.
77    pub async fn set_status(&self, status: NodeStatus) -> Result<()> {
78        sqlx::query!(
79            r#"
80            UPDATE forge_nodes
81            SET status = $2
82            WHERE id = $1
83            "#,
84            self.local_node.id.as_uuid(),
85            status.as_str(),
86        )
87        .execute(&self.pool)
88        .await
89        .map_err(|e| ForgeError::Database(e.to_string()))?;
90
91        Ok(())
92    }
93
94    /// Deregister the local node from the cluster.
95    pub async fn deregister(&self) -> Result<()> {
96        sqlx::query!(
97            r#"
98            DELETE FROM forge_nodes WHERE id = $1
99            "#,
100            self.local_node.id.as_uuid(),
101        )
102        .execute(&self.pool)
103        .await
104        .map_err(|e| ForgeError::Database(e.to_string()))?;
105
106        Ok(())
107    }
108
109    /// Get all active nodes in the cluster.
110    pub async fn get_active_nodes(&self) -> Result<Vec<NodeInfo>> {
111        self.get_nodes_by_status(NodeStatus::Active).await
112    }
113
114    /// Get nodes by status.
115    pub async fn get_nodes_by_status(&self, status: NodeStatus) -> Result<Vec<NodeInfo>> {
116        let rows = sqlx::query!(
117            r#"
118            SELECT id, hostname, ip_address, http_port, grpc_port,
119                   roles, worker_capabilities, status, version,
120                   started_at, last_heartbeat, current_connections,
121                   current_jobs, cpu_usage, memory_usage
122            FROM forge_nodes
123            WHERE status = $1
124            ORDER BY started_at
125            "#,
126            status.as_str(),
127        )
128        .fetch_all(&self.pool)
129        .await
130        .map_err(|e| ForgeError::Database(e.to_string()))?;
131
132        rows.into_iter()
133            .map(|row| {
134                parse_node_fields(
135                    row.id,
136                    row.hostname,
137                    row.ip_address,
138                    row.http_port,
139                    row.grpc_port,
140                    row.roles,
141                    row.worker_capabilities,
142                    row.status,
143                    row.version,
144                    row.started_at,
145                    row.last_heartbeat,
146                    row.current_connections,
147                    row.current_jobs,
148                    row.cpu_usage,
149                    row.memory_usage,
150                )
151            })
152            .collect()
153    }
154
155    /// Get a specific node by ID.
156    pub async fn get_node(&self, node_id: NodeId) -> Result<Option<NodeInfo>> {
157        let row = sqlx::query!(
158            r#"
159            SELECT id, hostname, ip_address, http_port, grpc_port,
160                   roles, worker_capabilities, status, version,
161                   started_at, last_heartbeat, current_connections,
162                   current_jobs, cpu_usage, memory_usage
163            FROM forge_nodes
164            WHERE id = $1
165            "#,
166            node_id.as_uuid(),
167        )
168        .fetch_optional(&self.pool)
169        .await
170        .map_err(|e| ForgeError::Database(e.to_string()))?;
171
172        row.map(|row| {
173            parse_node_fields(
174                row.id,
175                row.hostname,
176                row.ip_address,
177                row.http_port,
178                row.grpc_port,
179                row.roles,
180                row.worker_capabilities,
181                row.status,
182                row.version,
183                row.started_at,
184                row.last_heartbeat,
185                row.current_connections,
186                row.current_jobs,
187                row.cpu_usage,
188                row.memory_usage,
189            )
190        })
191        .transpose()
192    }
193
194    /// Count nodes by status.
195    pub async fn count_by_status(&self) -> Result<NodeCounts> {
196        let row = sqlx::query!(
197            r#"
198            SELECT
199                COUNT(*) FILTER (WHERE status = 'active') as "active!",
200                COUNT(*) FILTER (WHERE status = 'draining') as "draining!",
201                COUNT(*) FILTER (WHERE status = 'dead') as "dead!",
202                COUNT(*) FILTER (WHERE status = 'joining') as "joining!",
203                COUNT(*) as "total!"
204            FROM forge_nodes
205            "#,
206        )
207        .fetch_one(&self.pool)
208        .await
209        .map_err(|e| ForgeError::Database(e.to_string()))?;
210
211        Ok(NodeCounts {
212            active: row.active as usize,
213            draining: row.draining as usize,
214            dead: row.dead as usize,
215            joining: row.joining as usize,
216            total: row.total as usize,
217        })
218    }
219
220    /// Mark stale nodes as dead.
221    pub async fn mark_dead_nodes(&self, threshold: Duration) -> Result<u64> {
222        let threshold_secs = threshold.as_secs() as i64;
223
224        let result = sqlx::query!(
225            r#"
226            UPDATE forge_nodes
227            SET status = 'dead'
228            WHERE status = 'active'
229              AND last_heartbeat < NOW() - make_interval(secs => $1)
230            "#,
231            threshold_secs as f64,
232        )
233        .execute(&self.pool)
234        .await
235        .map_err(|e| ForgeError::Database(e.to_string()))?;
236
237        Ok(result.rows_affected())
238    }
239
240    /// Clean up old dead nodes.
241    pub async fn cleanup_dead_nodes(&self, older_than: Duration) -> Result<u64> {
242        let threshold_secs = older_than.as_secs() as i64;
243
244        let result = sqlx::query!(
245            r#"
246            DELETE FROM forge_nodes
247            WHERE status = 'dead'
248              AND last_heartbeat < NOW() - make_interval(secs => $1)
249            "#,
250            threshold_secs as f64,
251        )
252        .execute(&self.pool)
253        .await
254        .map_err(|e| ForgeError::Database(e.to_string()))?;
255
256        Ok(result.rows_affected())
257    }
258}
259
260#[allow(clippy::too_many_arguments)]
261fn parse_node_fields(
262    id: Uuid,
263    hostname: String,
264    ip_address: String,
265    http_port: i32,
266    grpc_port: i32,
267    roles: Vec<String>,
268    worker_capabilities: Vec<String>,
269    status: String,
270    version: Option<String>,
271    started_at: DateTime<Utc>,
272    last_heartbeat: DateTime<Utc>,
273    current_connections: i32,
274    current_jobs: i32,
275    cpu_usage: Option<f64>,
276    memory_usage: Option<f64>,
277) -> Result<NodeInfo> {
278    let ip_addr: IpAddr = ip_address.parse().map_err(|e| {
279        ForgeError::Validation(format!("invalid IP address '{}': {}", ip_address, e))
280    })?;
281
282    let node_roles: Vec<NodeRole> = roles
283        .iter()
284        .map(|s| {
285            s.parse()
286                .map_err(|e| ForgeError::Validation(format!("invalid role '{}': {}", s, e)))
287        })
288        .collect::<Result<Vec<_>>>()?;
289
290    let node_status: NodeStatus = status
291        .parse()
292        .map_err(|e| ForgeError::Validation(format!("invalid status '{}': {}", status, e)))?;
293
294    Ok(NodeInfo {
295        id: NodeId::from_uuid(id),
296        hostname,
297        ip_address: ip_addr,
298        http_port: http_port as u16,
299        grpc_port: grpc_port as u16,
300        roles: node_roles,
301        worker_capabilities,
302        status: node_status,
303        version: version.unwrap_or_default(),
304        started_at,
305        last_heartbeat,
306        current_connections: current_connections as u32,
307        current_jobs: current_jobs as u32,
308        cpu_usage: cpu_usage.unwrap_or(0.0) as f32,
309        memory_usage: memory_usage.unwrap_or(0.0) as f32,
310    })
311}
312
/// Number of registered nodes in each membership status.
///
/// Produced by `NodeRegistry::count_by_status`. `total` counts every row,
/// including rows whose status is not broken out into its own field, so it
/// may exceed the sum of the other fields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct NodeCounts {
    /// Nodes with status `active`.
    pub active: usize,
    /// Nodes with status `draining`.
    pub draining: usize,
    /// Nodes with status `dead`.
    pub dead: usize,
    /// Nodes with status `joining`.
    pub joining: usize,
    /// All nodes, regardless of status.
    pub total: usize,
}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// A defaulted `NodeCounts` describes an empty cluster.
    #[test]
    fn test_node_counts_default() {
        let NodeCounts { active, total, .. } = NodeCounts::default();
        assert_eq!(active, 0);
        assert_eq!(total, 0);
    }
}