use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use forge_core::config::cluster::{ClusterConfig, DiscoveryMethod};
use forge_core::{ForgeError, Result};
/// Network address of a discovered cluster peer: an IP address plus a port.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PeerAddress {
    /// IP address of the peer (IPv4 or IPv6).
    pub ip: IpAddr,
    /// Port number of the peer.
    pub port: u16,
}

impl std::fmt::Display for PeerAddress {
    /// Formats the peer as `ip:port`.
    ///
    /// IPv6 addresses are wrapped in brackets (`[::1]:9081`). Without the brackets the
    /// port would be indistinguishable from the address's own `:`-separated groups.
    /// With them, the output matches the [`SocketAddr`] textual form and parses back as one.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.ip {
            IpAddr::V4(v4) => write!(f, "{}:{}", v4, self.port),
            IpAddr::V6(v6) => write!(f, "[{}]:{}", v6, self.port),
        }
    }
}
/// Returns the initial peer addresses for this node, using the strategy selected by
/// `config.discovery`.
///
/// `default_port` is applied to any peer whose address does not specify a port.
/// `DiscoveryMethod::Postgres` yields an empty list from this function.
///
/// # Errors
///
/// Propagates whatever error the selected discovery strategy returns.
pub async fn discover_peers(config: &ClusterConfig, default_port: u16) -> Result<Vec<PeerAddress>> {
    let peers = match &config.discovery {
        // NOTE(review): no network peers are produced for Postgres discovery here;
        // presumably membership is tracked through the database elsewhere — confirm.
        DiscoveryMethod::Postgres => Vec::new(),
        DiscoveryMethod::Dns => discover_dns(config, default_port).await?,
        DiscoveryMethod::Kubernetes => discover_kubernetes(config, default_port).await?,
        DiscoveryMethod::Static => discover_static(config, default_port)?,
    };
    Ok(peers)
}
/// Resolves peers by looking up the configured `dns_name`.
///
/// A name that already contains `:` is resolved verbatim (expected to be `host:port`).
/// Otherwise `default_port` is appended. Every address returned by the lookup becomes a
/// peer. The lookup goes through `to_socket_addrs`, which may block, so it runs on
/// Tokio's blocking thread pool.
///
/// # Errors
///
/// Returns [`ForgeError::Config`] if `dns_name` is unset. Returns [`ForgeError::Cluster`]
/// if the blocking lookup task fails or the name cannot be resolved.
async fn discover_dns(config: &ClusterConfig, default_port: u16) -> Result<Vec<PeerAddress>> {
    let host = match config.dns_name.as_deref() {
        Some(name) => name,
        None => {
            return Err(ForgeError::Config(
                "DNS discovery requires 'dns_name' to be set in [cluster] config".to_string(),
            ));
        }
    };

    // Append the default port only when the name does not already carry one.
    let mut target = host.to_string();
    if !host.contains(':') {
        target.push(':');
        target.push_str(&default_port.to_string());
    }

    // `to_socket_addrs` can block on the system resolver; keep it off the async workers.
    let lookup = tokio::task::spawn_blocking(move || {
        target
            .to_socket_addrs()
            .map(|resolved| resolved.collect::<Vec<SocketAddr>>())
    })
    .await
    .map_err(|join_err| ForgeError::Cluster(format!("DNS lookup task failed: {}", join_err)))?;

    let resolved_addrs = lookup.map_err(|io_err| {
        ForgeError::Cluster(format!("DNS resolution failed for '{}': {}", host, io_err))
    })?;

    let mut peers = Vec::with_capacity(resolved_addrs.len());
    for socket_addr in resolved_addrs {
        peers.push(PeerAddress {
            ip: socket_addr.ip(),
            port: socket_addr.port(),
        });
    }

    tracing::debug!(dns_name = host, peer_count = peers.len(), "DNS discovery completed");
    Ok(peers)
}
/// Discovers peers when running inside Kubernetes.
///
/// If `dns_name` is configured it is used directly via DNS discovery. Otherwise the
/// service DNS name `<service>.<namespace>.svc.cluster.local` is built from the
/// environment:
///
/// * service: `SERVICE_NAME`; otherwise `HOSTNAME` with its last `-<suffix>` removed
///   (a pod named `forge-0` yields `forge`). NOTE(review): this heuristic suits
///   `<name>-<ordinal>` pod names. Pod names with several generated suffixes will not
///   map back to the service name. Confirm the intended deployment shape.
/// * namespace: `POD_NAMESPACE`, else `KUBERNETES_NAMESPACE`, else `default`.
///
/// Environment variables that are set but blank (or whitespace-only) are treated as unset.
///
/// # Errors
///
/// Returns [`ForgeError::Config`] when neither `dns_name` nor a usable service name is
/// available. Propagates any error from DNS discovery.
async fn discover_kubernetes(
    config: &ClusterConfig,
    default_port: u16,
) -> Result<Vec<PeerAddress>> {
    if config.dns_name.is_some() {
        return discover_dns(config, default_port).await;
    }

    // A variable that is present but blank (e.g. an empty Helm value) would otherwise
    // produce a malformed name such as `.default.svc.cluster.local`. Treat it as unset so
    // the next fallback source gets a chance.
    let non_blank_env = |key: &str| -> Option<String> {
        std::env::var(key)
            .ok()
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
    };

    let namespace = non_blank_env("POD_NAMESPACE")
        .or_else(|| non_blank_env("KUBERNETES_NAMESPACE"))
        .unwrap_or_else(|| "default".to_string());

    let service_name = non_blank_env("SERVICE_NAME").or_else(|| {
        non_blank_env("HOSTNAME").map(|host| {
            // Drop the trailing `-<suffix>`, but never reduce the name to an empty
            // string (e.g. a hostname of `-0`); keep the full hostname in that case.
            let stripped = host
                .rsplit_once('-')
                .map(|(prefix, _)| prefix)
                .filter(|prefix| !prefix.is_empty())
                .map(str::to_string);
            stripped.unwrap_or(host)
        })
    });

    match service_name {
        Some(svc) => {
            let dns_name = format!("{}.{}.svc.cluster.local", svc, namespace);
            tracing::info!(
                dns_name = %dns_name,
                "Kubernetes discovery: constructed service DNS from environment"
            );
            let k8s_config = ClusterConfig {
                dns_name: Some(dns_name),
                ..config.clone()
            };
            discover_dns(&k8s_config, default_port).await
        }
        None => Err(ForgeError::Config(
            "Kubernetes discovery requires either 'dns_name' in [cluster] config, \
             or SERVICE_NAME/HOSTNAME and POD_NAMESPACE environment variables"
                .to_string(),
        )),
    }
}
/// Builds the peer list from the statically configured `seed_nodes`.
///
/// Each entry is an IP literal, optionally followed by a port. Accepted forms:
/// `10.0.0.1`, `10.0.0.1:9081`, `::1`, `[::1]`, `[::1]:9081`. Entries without an
/// explicit port use `default_port`. Surrounding whitespace is ignored. Hostnames are
/// not resolved here; only IP literals are accepted.
///
/// # Errors
///
/// Returns [`ForgeError::Config`] if `seed_nodes` is empty, or if any entry has an
/// unparseable port or IP address. The first bad entry aborts the whole list.
fn discover_static(config: &ClusterConfig, default_port: u16) -> Result<Vec<PeerAddress>> {
    if config.seed_nodes.is_empty() {
        return Err(ForgeError::Config(
            "Static discovery requires 'seed_nodes' to be set in [cluster] config".to_string(),
        ));
    }

    let peers = config
        .seed_nodes
        .iter()
        .map(|node| parse_seed_node(node, default_port))
        .collect::<Result<Vec<_>>>()?;

    tracing::debug!(
        seed_count = peers.len(),
        "Static discovery loaded seed nodes"
    );
    Ok(peers)
}

/// Parses a single `seed_nodes` entry into a [`PeerAddress`], using `default_port` when
/// the entry carries no port.
fn parse_seed_node(node: &str, default_port: u16) -> Result<PeerAddress> {
    let entry = node.trim();

    // Treat the last ':' as a port separator only when the host part cannot be a bare
    // IPv6 literal: either it has no ':' of its own, or it is bracketed (`[::1]`).
    // Without this guard `::1` would be misread as host `::` with port 1.
    let (host, port) = match entry.rsplit_once(':') {
        Some((h, p)) if !h.contains(':') || (h.starts_with('[') && h.ends_with(']')) => {
            let port = p.parse::<u16>().map_err(|_| {
                ForgeError::Config(format!("Invalid port in seed node '{}': '{}'", node, p))
            })?;
            (h, port)
        }
        _ => (entry, default_port),
    };

    // `[::1]` / `[::1]:9081` → `::1`; unbracketed hosts pass through unchanged.
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);

    let ip: IpAddr = host.parse().map_err(|e| {
        ForgeError::Config(format!("Invalid IP address in seed node '{}': {}", node, e))
    })?;
    Ok(PeerAddress { ip, port })
}
#[cfg(test)]
// Unwrapping is acceptable in tests: a panic is exactly the failure signal we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_static_discovery_parses_addresses() {
        let config = ClusterConfig {
            seed_nodes: vec![
                "10.0.0.1:9081".to_string(),
                "10.0.0.2:9082".to_string(),
                "10.0.0.3".to_string(),
            ],
            ..Default::default()
        };
        // Use a default port distinct from the explicit ones so the fallback is observable.
        let peers = discover_static(&config, 7000).unwrap();
        assert_eq!(
            peers,
            vec![
                PeerAddress {
                    ip: "10.0.0.1".parse().unwrap(),
                    port: 9081,
                },
                PeerAddress {
                    ip: "10.0.0.2".parse().unwrap(),
                    port: 9082,
                },
                PeerAddress {
                    ip: "10.0.0.3".parse().unwrap(),
                    port: 7000,
                },
            ]
        );
    }

    #[test]
    fn test_static_discovery_empty_seed_nodes_errors() {
        let config = ClusterConfig::default();
        let result = discover_static(&config, 9081);
        assert!(matches!(result, Err(ForgeError::Config(_))));
    }

    #[test]
    fn test_static_discovery_invalid_port_errors() {
        for bad in ["10.0.0.1:not-a-port", "10.0.0.1:70000", "10.0.0.1:"] {
            let config = ClusterConfig {
                seed_nodes: vec![bad.to_string()],
                ..Default::default()
            };
            let result = discover_static(&config, 9081);
            assert!(
                matches!(result, Err(ForgeError::Config(_))),
                "expected config error for {bad:?}"
            );
        }
    }

    #[test]
    fn test_static_discovery_invalid_ip_errors() {
        let config = ClusterConfig {
            seed_nodes: vec!["not-an-ip:9081".to_string()],
            ..Default::default()
        };
        let result = discover_static(&config, 9081);
        assert!(matches!(result, Err(ForgeError::Config(_))));
    }

    #[test]
    fn test_peer_address_display() {
        let peer = PeerAddress {
            ip: "10.0.0.1".parse().unwrap(),
            port: 9081,
        };
        assert_eq!(peer.to_string(), "10.0.0.1:9081");
    }
}