Skip to main content

forge_runtime/cluster/
discovery.rs

1//! Cluster discovery implementations.
2//!
3//! Supports multiple discovery methods for finding peer nodes:
4//! - **Postgres**: Default. Nodes register in the `forge_nodes` table.
5//! - **DNS**: Resolve a DNS name to find peer node IPs.
//! - **Kubernetes**: Resolve a headless service's DNS name (from `dns_name` or pod environment variables) to discover pod IPs.
7//! - **Static**: Use a fixed list of seed node addresses.
8
9use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
10
11use forge_core::config::cluster::{ClusterConfig, DiscoveryMethod};
12use forge_core::{ForgeError, Result};
13
/// Discovered peer node address.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PeerAddress {
    /// IP address of the peer (IPv4 or IPv6).
    pub ip: IpAddr,
    /// HTTP port of the peer (defaults to gateway port).
    pub port: u16,
}

/// Renders the peer as a socket-address string.
///
/// IPv4 peers are written as `ip:port`. IPv6 peers are written as
/// `[ip]:port`, so the output stays unambiguous and can be parsed back
/// with `str::parse::<SocketAddr>()`.
impl std::fmt::Display for PeerAddress {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.ip {
            IpAddr::V4(ip) => write!(f, "{}:{}", ip, self.port),
            // Without brackets the port cannot be told apart from the last
            // group of the address (e.g. `::1:9081`).
            IpAddr::V6(ip) => write!(f, "[{}]:{}", ip, self.port),
        }
    }
}
28
29/// Discover peer nodes using the configured discovery method.
30///
31/// For `Postgres` discovery, peers are found via the `forge_nodes` table
32/// (handled by `NodeRegistry`), so this function returns an empty list.
33/// For other methods, this resolves peers from the configured source.
34pub async fn discover_peers(config: &ClusterConfig, default_port: u16) -> Result<Vec<PeerAddress>> {
35    match config.discovery {
36        DiscoveryMethod::Postgres => {
37            // Postgres discovery is handled by NodeRegistry directly.
38            Ok(Vec::new())
39        }
40        DiscoveryMethod::Dns => discover_dns(config, default_port).await,
41        DiscoveryMethod::Kubernetes => discover_kubernetes(config, default_port).await,
42        DiscoveryMethod::Static => discover_static(config, default_port),
43    }
44}
45
46/// DNS-based discovery.
47///
48/// Resolves the configured `dns_name` to a set of IP addresses. This is
49/// commonly used with Kubernetes headless services, where the DNS name
50/// resolves to all pod IPs in the service.
51async fn discover_dns(config: &ClusterConfig, default_port: u16) -> Result<Vec<PeerAddress>> {
52    let dns_name = config.dns_name.as_deref().ok_or_else(|| {
53        ForgeError::Config(
54            "DNS discovery requires 'dns_name' to be set in [cluster] config".to_string(),
55        )
56    })?;
57
58    let lookup_name = if dns_name.contains(':') {
59        dns_name.to_string()
60    } else {
61        format!("{}:{}", dns_name, default_port)
62    };
63
64    // Perform DNS resolution (blocking, but fast for local DNS)
65    let addrs: Vec<SocketAddr> = tokio::task::spawn_blocking(move || {
66        lookup_name
67            .to_socket_addrs()
68            .map(|iter| iter.collect::<Vec<_>>())
69    })
70    .await
71    .map_err(|e| ForgeError::Cluster(format!("DNS lookup task failed: {}", e)))?
72    .map_err(|e| ForgeError::Cluster(format!("DNS resolution failed for '{}': {}", dns_name, e)))?;
73
74    let peers: Vec<PeerAddress> = addrs
75        .into_iter()
76        .map(|addr| PeerAddress {
77            ip: addr.ip(),
78            port: addr.port(),
79        })
80        .collect();
81
82    tracing::debug!(
83        dns_name,
84        peer_count = peers.len(),
85        "DNS discovery completed"
86    );
87    Ok(peers)
88}
89
90/// Kubernetes-based discovery.
91///
92/// Uses the Kubernetes downward API and DNS to discover peer pods.
93/// This expects the service to be a headless service (ClusterIP: None)
94/// so that DNS returns individual pod IPs.
95///
96/// Falls back to DNS discovery using the `dns_name` config, which should
97/// be set to the headless service FQDN (e.g., `my-app.default.svc.cluster.local`).
98async fn discover_kubernetes(
99    config: &ClusterConfig,
100    default_port: u16,
101) -> Result<Vec<PeerAddress>> {
102    // Kubernetes discovery uses DNS under the hood via headless services.
103    // The dns_name should point to the headless service FQDN.
104    if config.dns_name.is_some() {
105        return discover_dns(config, default_port).await;
106    }
107
108    // Attempt to construct the service DNS name from environment variables
109    // set by the Kubernetes downward API.
110    let namespace = std::env::var("POD_NAMESPACE")
111        .or_else(|_| std::env::var("KUBERNETES_NAMESPACE"))
112        .unwrap_or_else(|_| "default".to_string());
113
114    let service_name = std::env::var("SERVICE_NAME").or_else(|_| {
115        std::env::var("HOSTNAME").map(|h| {
116            // Extract service name from pod hostname (e.g., "my-app-0" -> "my-app")
117            h.rsplit_once('-')
118                .map(|(prefix, _)| prefix.to_string())
119                .unwrap_or(h)
120        })
121    });
122
123    match service_name {
124        Ok(svc) => {
125            let dns_name = format!("{}.{}.svc.cluster.local", svc, namespace);
126            tracing::info!(
127                dns_name = %dns_name,
128                "Kubernetes discovery: constructed service DNS from environment"
129            );
130
131            let k8s_config = ClusterConfig {
132                dns_name: Some(dns_name),
133                ..config.clone()
134            };
135            discover_dns(&k8s_config, default_port).await
136        }
137        Err(_) => Err(ForgeError::Config(
138            "Kubernetes discovery requires either 'dns_name' in [cluster] config, \
139             or SERVICE_NAME/HOSTNAME and POD_NAMESPACE environment variables"
140                .to_string(),
141        )),
142    }
143}
144
145/// Static discovery from configured seed nodes.
146///
147/// Parses the `seed_nodes` list from config. Each entry should be
148/// in the format `host:port` or just `host` (uses default port).
149fn discover_static(config: &ClusterConfig, default_port: u16) -> Result<Vec<PeerAddress>> {
150    if config.seed_nodes.is_empty() {
151        return Err(ForgeError::Config(
152            "Static discovery requires 'seed_nodes' to be set in [cluster] config".to_string(),
153        ));
154    }
155
156    let mut peers = Vec::with_capacity(config.seed_nodes.len());
157
158    for node in &config.seed_nodes {
159        let (host, port) = if let Some((h, p)) = node.rsplit_once(':') {
160            let port = p.parse::<u16>().map_err(|_| {
161                ForgeError::Config(format!("Invalid port in seed node '{}': '{}'", node, p))
162            })?;
163            (h, port)
164        } else {
165            (node.as_str(), default_port)
166        };
167
168        let ip: IpAddr = host.parse().map_err(|e| {
169            ForgeError::Config(format!("Invalid IP address in seed node '{}': {}", node, e))
170        })?;
171
172        peers.push(PeerAddress { ip, port });
173    }
174
175    tracing::debug!(
176        seed_count = peers.len(),
177        "Static discovery loaded seed nodes"
178    );
179    Ok(peers)
180}
181
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build a config whose only non-default field is `seed_nodes`.
    fn config_with_seeds(seeds: &[&str]) -> ClusterConfig {
        ClusterConfig {
            seed_nodes: seeds.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    #[test]
    fn test_static_discovery_parses_addresses() {
        let config = config_with_seeds(&["10.0.0.1:9081", "10.0.0.2:9081", "10.0.0.3"]);

        let peers = discover_static(&config, 9081).unwrap();
        assert_eq!(peers.len(), 3);
        let first = peers.first().expect("expected at least one peer");
        assert_eq!(first.ip, "10.0.0.1".parse::<IpAddr>().unwrap());
        assert_eq!(first.port, 9081);
        let third = peers.get(2).expect("expected three peers");
        assert_eq!(third.port, 9081); // default port used
    }

    #[test]
    fn test_static_discovery_empty_seed_nodes_errors() {
        let config = ClusterConfig::default();
        let result = discover_static(&config, 9081);
        assert!(result.is_err());
    }

    #[test]
    fn test_static_discovery_rejects_non_numeric_port() {
        let config = config_with_seeds(&["10.0.0.1:http"]);
        let result = discover_static(&config, 9081);
        assert!(matches!(
            result,
            Err(ForgeError::Config(ref msg)) if msg.contains("Invalid port")
        ));
    }

    #[test]
    fn test_static_discovery_rejects_out_of_range_port() {
        // 70000 does not fit in a u16.
        let config = config_with_seeds(&["10.0.0.1:70000"]);
        let result = discover_static(&config, 9081);
        assert!(matches!(
            result,
            Err(ForgeError::Config(ref msg)) if msg.contains("Invalid port")
        ));
    }

    #[test]
    fn test_static_discovery_rejects_invalid_ip() {
        for seed in ["not-an-ip", "not-an-ip:9081"] {
            let config = config_with_seeds(&[seed]);
            let result = discover_static(&config, 9081);
            assert!(matches!(
                result,
                Err(ForgeError::Config(ref msg)) if msg.contains("Invalid IP address")
            ));
        }
    }

    #[test]
    fn test_peer_address_display() {
        let peer = PeerAddress {
            ip: "10.0.0.1".parse().unwrap(),
            port: 9081,
        };
        assert_eq!(peer.to_string(), "10.0.0.1:9081");
    }
}