use std::net::IpAddr;
use std::time::Duration;
use chrono::{DateTime, Utc};
use uuid::Uuid;
use forge_core::cluster::{NodeId, NodeInfo, NodeRole, NodeStatus};
use forge_core::{ForgeError, Result};
/// Registry of cluster nodes stored in the `forge_nodes` table.
///
/// Holds the connection pool used for every query plus the description of
/// the node that `register`, `set_status` and `deregister` operate on.
pub struct NodeRegistry {
/// Postgres connection pool every query in this type is executed against.
pool: sqlx::PgPool,
/// Node whose id and attributes are written by `register` and whose id is
/// used by `set_status` and `deregister`.
local_node: NodeInfo,
}
impl NodeRegistry {
pub fn new(pool: sqlx::PgPool, local_node: NodeInfo) -> Self {
Self { pool, local_node }
}
pub fn local_node(&self) -> &NodeInfo {
&self.local_node
}
/// Returns the id of the node this registry was built with.
pub fn local_id(&self) -> NodeId {
    self.local_node().id
}
/// Inserts the local node into `forge_nodes`, or refreshes its row if one
/// with the same id already exists.
///
/// On conflict every column listed in the `DO UPDATE` clause is overwritten
/// with the new values; `started_at` is not in that list, so an existing
/// row keeps its stored value. `last_heartbeat` is set to the database's
/// `NOW()` in both cases.
///
/// # Errors
///
/// Returns [`ForgeError::Database`] carrying the driver's message when the
/// statement fails.
pub async fn register(&self) -> Result<()> {
    let node = &self.local_node;

    // The roles column is bound as a list of strings, so each role is
    // converted through its string form first.
    let mut roles: Vec<String> = Vec::with_capacity(node.roles.len());
    for role in node.roles.iter() {
        roles.push(role.as_str().to_string());
    }

    let ip_address = node.ip_address.to_string();
    let http_port = node.http_port as i32;
    let grpc_port = node.grpc_port as i32;

    sqlx::query!(
        r#"
INSERT INTO forge_nodes (
id, hostname, ip_address, http_port, grpc_port,
roles, worker_capabilities, status, version, started_at, last_heartbeat
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW())
ON CONFLICT (id) DO UPDATE SET
hostname = EXCLUDED.hostname,
ip_address = EXCLUDED.ip_address,
http_port = EXCLUDED.http_port,
grpc_port = EXCLUDED.grpc_port,
roles = EXCLUDED.roles,
worker_capabilities = EXCLUDED.worker_capabilities,
status = EXCLUDED.status,
version = EXCLUDED.version,
last_heartbeat = NOW()
"#,
        node.id.as_uuid(),
        &node.hostname,
        ip_address,
        http_port,
        grpc_port,
        &roles,
        &node.worker_capabilities,
        node.status.as_str(),
        &node.version,
        node.started_at,
    )
    .execute(&self.pool)
    .await
    .map_err(|err| ForgeError::Database(err.to_string()))?;

    Ok(())
}
/// Writes `status` into the `forge_nodes` row whose id matches the local node.
///
/// Only the database row is touched; the in-memory `local_node` is not
/// modified. The number of affected rows is not inspected, so the call
/// also succeeds when no row with that id exists.
///
/// # Errors
///
/// Returns [`ForgeError::Database`] carrying the driver's message when the
/// statement fails.
pub async fn set_status(&self, status: NodeStatus) -> Result<()> {
    let outcome = sqlx::query!(
        r#"
UPDATE forge_nodes
SET status = $2
WHERE id = $1
"#,
        self.local_node.id.as_uuid(),
        status.as_str(),
    )
    .execute(&self.pool)
    .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(db_err) => Err(ForgeError::Database(db_err.to_string())),
    }
}
/// Deletes the local node's row from `forge_nodes`.
///
/// The number of deleted rows is not inspected, so the call also succeeds
/// when the row is already gone.
///
/// # Errors
///
/// Returns [`ForgeError::Database`] carrying the driver's message when the
/// statement fails.
pub async fn deregister(&self) -> Result<()> {
    let node_uuid = self.local_node.id.as_uuid();

    sqlx::query!(
        r#"
DELETE FROM forge_nodes WHERE id = $1
"#,
        node_uuid,
    )
    .execute(&self.pool)
    .await
    .map(|_| ())
    .map_err(|db_err| ForgeError::Database(db_err.to_string()))
}
/// Returns every node whose stored status is [`NodeStatus::Active`].
///
/// Shorthand for `get_nodes_by_status(NodeStatus::Active)`; ordering and
/// error behaviour are those of that method.
pub async fn get_active_nodes(&self) -> Result<Vec<NodeInfo>> {
    let wanted = NodeStatus::Active;
    self.get_nodes_by_status(wanted).await
}
/// Loads all nodes whose `status` column equals `status`, ordered by
/// `started_at`.
///
/// # Errors
///
/// Returns [`ForgeError::Database`] when the query fails, or the error from
/// `parse_node_fields` for the first row that cannot be converted. In the
/// latter case no nodes are returned at all.
pub async fn get_nodes_by_status(&self, status: NodeStatus) -> Result<Vec<NodeInfo>> {
    let records = sqlx::query!(
        r#"
SELECT id, hostname, ip_address, http_port, grpc_port,
roles, worker_capabilities, status, version,
started_at, last_heartbeat, current_connections,
current_jobs, cpu_usage, memory_usage
FROM forge_nodes
WHERE status = $1
ORDER BY started_at
"#,
        status.as_str(),
    )
    .fetch_all(&self.pool)
    .await
    .map_err(|err| ForgeError::Database(err.to_string()))?;

    // Convert row by row, stopping at the first row that fails to parse.
    let mut nodes = Vec::with_capacity(records.len());
    for record in records {
        let node = parse_node_fields(
            record.id,
            record.hostname,
            record.ip_address,
            record.http_port,
            record.grpc_port,
            record.roles,
            record.worker_capabilities,
            record.status,
            record.version,
            record.started_at,
            record.last_heartbeat,
            record.current_connections,
            record.current_jobs,
            record.cpu_usage,
            record.memory_usage,
        )?;
        nodes.push(node);
    }
    Ok(nodes)
}
/// Looks up a single node by id.
///
/// Returns `Ok(None)` when no row with that id exists.
///
/// # Errors
///
/// Returns [`ForgeError::Database`] when the query fails, or the error from
/// `parse_node_fields` when the stored row cannot be converted.
pub async fn get_node(&self, node_id: NodeId) -> Result<Option<NodeInfo>> {
    let maybe_record = sqlx::query!(
        r#"
SELECT id, hostname, ip_address, http_port, grpc_port,
roles, worker_capabilities, status, version,
started_at, last_heartbeat, current_connections,
current_jobs, cpu_usage, memory_usage
FROM forge_nodes
WHERE id = $1
"#,
        node_id.as_uuid(),
    )
    .fetch_optional(&self.pool)
    .await
    .map_err(|err| ForgeError::Database(err.to_string()))?;

    match maybe_record {
        None => Ok(None),
        Some(record) => parse_node_fields(
            record.id,
            record.hostname,
            record.ip_address,
            record.http_port,
            record.grpc_port,
            record.roles,
            record.worker_capabilities,
            record.status,
            record.version,
            record.started_at,
            record.last_heartbeat,
            record.current_connections,
            record.current_jobs,
            record.cpu_usage,
            record.memory_usage,
        )
        .map(Some),
    }
}
/// Counts the rows of `forge_nodes` per status in a single query.
///
/// The `active`, `draining`, `dead` and `joining` counters match the
/// status strings of the same name; `total` counts every row regardless
/// of status.
///
/// # Errors
///
/// Returns [`ForgeError::Database`] carrying the driver's message when the
/// query fails.
pub async fn count_by_status(&self) -> Result<NodeCounts> {
    sqlx::query!(
        r#"
SELECT
COUNT(*) FILTER (WHERE status = 'active') as "active!",
COUNT(*) FILTER (WHERE status = 'draining') as "draining!",
COUNT(*) FILTER (WHERE status = 'dead') as "dead!",
COUNT(*) FILTER (WHERE status = 'joining') as "joining!",
COUNT(*) as "total!"
FROM forge_nodes
"#,
    )
    .fetch_one(&self.pool)
    .await
    .map(|tally| NodeCounts {
        active: tally.active as usize,
        draining: tally.draining as usize,
        dead: tally.dead as usize,
        joining: tally.joining as usize,
        total: tally.total as usize,
    })
    .map_err(|err| ForgeError::Database(err.to_string()))
}
/// Marks every `active` node whose last heartbeat is older than
/// `threshold` as `dead` and returns how many rows were changed.
///
/// Only rows currently in status `active` are considered. The age is
/// evaluated by the database against its own `NOW()`.
///
/// # Errors
///
/// Returns [`ForgeError::Database`] carrying the driver's message when the
/// statement fails.
pub async fn mark_dead_nodes(&self, threshold: Duration) -> Result<u64> {
    // Pass the threshold with its fractional part. Truncating to whole
    // seconds would turn a sub-second threshold into 0 s, making every
    // active node with a heartbeat in the past count as expired.
    let threshold_secs = threshold.as_secs_f64();
    let result = sqlx::query!(
        r#"
UPDATE forge_nodes
SET status = 'dead'
WHERE status = 'active'
AND last_heartbeat < NOW() - make_interval(secs => $1)
"#,
        threshold_secs,
    )
    .execute(&self.pool)
    .await
    .map_err(|e| ForgeError::Database(e.to_string()))?;
    Ok(result.rows_affected())
}
/// Deletes every `dead` node whose last heartbeat is older than
/// `older_than` and returns how many rows were removed.
///
/// Only rows currently in status `dead` are considered. The age is
/// evaluated by the database against its own `NOW()`.
///
/// # Errors
///
/// Returns [`ForgeError::Database`] carrying the driver's message when the
/// statement fails.
pub async fn cleanup_dead_nodes(&self, older_than: Duration) -> Result<u64> {
    // Keep the fractional seconds. Truncating to whole seconds would
    // shorten the retention window (a sub-second window would become 0 s)
    // and delete rows earlier than the caller asked for.
    let threshold_secs = older_than.as_secs_f64();
    let result = sqlx::query!(
        r#"
DELETE FROM forge_nodes
WHERE status = 'dead'
AND last_heartbeat < NOW() - make_interval(secs => $1)
"#,
        threshold_secs,
    )
    .execute(&self.pool)
    .await
    .map_err(|e| ForgeError::Database(e.to_string()))?;
    Ok(result.rows_affected())
}
}
/// Converts the raw column values of one `forge_nodes` row into a [`NodeInfo`].
///
/// A missing `version` becomes the empty string, and missing `cpu_usage` /
/// `memory_usage` values become `0.0`.
///
/// # Errors
///
/// Returns [`ForgeError::Validation`] when
/// - `ip_address` is not a valid IP address,
/// - any entry of `roles`, or `status`, fails to parse, or
/// - a port or counter does not fit its unsigned target type (negative, or
///   a port above `u16::MAX`).
// One argument per selected column: the row types produced by `sqlx::query!`
// are anonymous per call site, so the callers cannot hand over a shared struct.
#[allow(clippy::too_many_arguments)]
fn parse_node_fields(
    id: Uuid,
    hostname: String,
    ip_address: String,
    http_port: i32,
    grpc_port: i32,
    roles: Vec<String>,
    worker_capabilities: Vec<String>,
    status: String,
    version: Option<String>,
    started_at: DateTime<Utc>,
    last_heartbeat: DateTime<Utc>,
    current_connections: i32,
    current_jobs: i32,
    cpu_usage: Option<f64>,
    memory_usage: Option<f64>,
) -> Result<NodeInfo> {
    /// Converts a signed database integer into the unsigned type `T`,
    /// naming the offending column when the value is out of range.
    fn to_unsigned<T: TryFrom<i32>>(column: &str, value: i32) -> Result<T> {
        T::try_from(value).map_err(|_| {
            ForgeError::Validation(format!("invalid {} '{}': out of range", column, value))
        })
    }

    let ip_addr = ip_address.parse::<IpAddr>().map_err(|e| {
        ForgeError::Validation(format!("invalid IP address '{}': {}", ip_address, e))
    })?;

    // Stops at the first role that fails to parse.
    let node_roles = roles
        .iter()
        .map(|s| {
            s.parse::<NodeRole>()
                .map_err(|e| ForgeError::Validation(format!("invalid role '{}': {}", s, e)))
        })
        .collect::<Result<Vec<_>>>()?;

    let node_status = status
        .parse::<NodeStatus>()
        .map_err(|e| ForgeError::Validation(format!("invalid status '{}': {}", status, e)))?;

    // Checked conversions: a plain `as` cast would silently wrap a negative
    // or oversized stored value into a plausible-looking number.
    let http_port: u16 = to_unsigned("http_port", http_port)?;
    let grpc_port: u16 = to_unsigned("grpc_port", grpc_port)?;
    let current_connections: u32 = to_unsigned("current_connections", current_connections)?;
    let current_jobs: u32 = to_unsigned("current_jobs", current_jobs)?;

    Ok(NodeInfo {
        id: NodeId::from_uuid(id),
        hostname,
        ip_address: ip_addr,
        http_port,
        grpc_port,
        roles: node_roles,
        worker_capabilities,
        status: node_status,
        version: version.unwrap_or_default(),
        started_at,
        last_heartbeat,
        current_connections,
        current_jobs,
        // Narrowing f64 -> f32 only loses precision.
        cpu_usage: cpu_usage.unwrap_or(0.0) as f32,
        memory_usage: memory_usage.unwrap_or(0.0) as f32,
    })
}
/// Number of rows in `forge_nodes`, broken down by status.
///
/// Produced by `NodeRegistry::count_by_status`; the derived `Default` sets
/// every counter to zero.
#[derive(Debug, Clone, Default)]
pub struct NodeCounts {
/// Rows whose status is `'active'`.
pub active: usize,
/// Rows whose status is `'draining'`.
pub draining: usize,
/// Rows whose status is `'dead'`.
pub dead: usize,
/// Rows whose status is `'joining'`.
pub joining: usize,
/// All rows, regardless of status.
pub total: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed `NodeCounts` reports zero active and zero total nodes.
    #[test]
    fn test_node_counts_default() {
        let NodeCounts { active, total, .. } = NodeCounts::default();
        assert_eq!(active, 0);
        assert_eq!(total, 0);
    }
}