1use std::collections::HashMap;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use std::time::Duration;
16
17use base64::Engine;
18use chitchat::transport::UdpTransport;
19use chitchat::{
20 spawn_chitchat, Chitchat, ChitchatConfig, ChitchatHandle, ChitchatId, FailureDetectorConfig,
21};
22use serde::{Deserialize, Serialize};
23use tokio::sync::{broadcast, Mutex};
24use tracing::{info, warn};
25
26use crate::error::OverlayError;
27
/// Connection details a worker advertises to the rest of the cluster over gossip.
///
/// Stored in the node's chitchat state as JSON, base64url-encoded without padding,
/// under the key `worker:{node_id}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct PeerInfo {
    /// Worker id. Peers are tracked by this id and advertised under `worker:{node_id}`.
    pub node_id: u64,
    /// WireGuard public key of the worker.
    /// NOTE(review): presumably the base64 form printed by `wg` — confirm with the producer.
    pub wg_pubkey: String,
    /// Address at which the worker's WireGuard interface can be reached by peers.
    pub wg_endpoint: SocketAddr,
    /// Overlay-network address of the worker, kept as text (not validated here).
    pub overlay_ip: String,
    /// Free-form key/value labels; deserializes to an empty map when the field is absent.
    #[serde(default)]
    pub labels: HashMap<String, String>,
}
42
/// A change in the observed peer set, broadcast by the gossip pool's background watcher.
#[derive(Debug, Clone)]
pub enum TopologyEvent {
    /// A peer was seen that was absent on the previous watcher tick.
    Joined(PeerInfo),
    /// A peer seen on the previous tick now advertises a different `PeerInfo`.
    Updated(PeerInfo),
    /// A peer seen on the previous tick is no longer reported.
    Left {
        /// Id of the peer that is no longer reported.
        node_id: u64,
    },
}
50
/// Parameters for [`GossipPool::start`].
#[derive(Debug, Clone)]
pub struct GossipConfig {
    /// This worker's id; the chitchat node id and the gossip key are `worker:{node_id}`.
    pub node_id: u64,
    /// Local address chitchat listens on; also used as this node's advertised gossip address.
    pub gossip_listen: SocketAddr,
    /// Existing cluster members handed to chitchat as seed nodes (converted to strings).
    pub seeds: Vec<SocketAddr>,
    /// Cluster id passed to chitchat.
    /// NOTE(review): presumably nodes configured with different ids do not gossip together —
    /// confirm against chitchat's documentation.
    pub cluster_id: String,
    /// Info advertised for this node at startup. Its `node_id` is expected to equal
    /// `node_id` above, because the gossip key is built from `node_id`.
    pub self_info: PeerInfo,
}
67
/// A running chitchat gossip instance plus the channel on which topology changes
/// are broadcast.
///
/// Created by [`GossipPool::start`], which returns it wrapped in an `Arc`.
pub struct GossipPool {
    /// Held only so the chitchat handle lives as long as the pool; never read.
    /// NOTE(review): presumably dropping the handle stops the gossip server — confirm in chitchat.
    _handle: ChitchatHandle,
    /// Shared chitchat state obtained from the handle; locked by `peers` and `announce_self`.
    chitchat: Arc<Mutex<Chitchat>>,
    /// Cluster id the pool was started with.
    cluster_id: String,
    /// Sending side of the topology-event broadcast; receivers come from `subscribe_updates`.
    events_tx: broadcast::Sender<TopologyEvent>,
}
76
77impl std::fmt::Debug for GossipPool {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("GossipPool")
80 .field("cluster_id", &self.cluster_id)
81 .finish_non_exhaustive()
82 }
83}
84
85impl GossipPool {
86 pub async fn start(config: GossipConfig) -> Result<Arc<Self>, OverlayError> {
94 let chitchat_id = ChitchatId::new(
95 format!("worker:{}", config.node_id),
96 0,
97 config.gossip_listen,
98 );
99
100 let cfg = ChitchatConfig {
101 chitchat_id,
102 cluster_id: config.cluster_id.clone(),
103 gossip_interval: Duration::from_secs(1),
104 listen_addr: config.gossip_listen,
105 seed_nodes: config
106 .seeds
107 .iter()
108 .map(std::string::ToString::to_string)
109 .collect(),
110 failure_detector_config: FailureDetectorConfig::default(),
111 marked_for_deletion_grace_period: Duration::from_secs(60),
112 catchup_callback: None,
113 extra_liveness_predicate: None,
114 };
115
116 let self_info_bytes = serde_json::to_vec(&config.self_info)
118 .map_err(|e| OverlayError::NetworkConfig(format!("encode gossip self_info: {e}")))?;
119 let self_info_b64 =
120 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&self_info_bytes);
121
122 let initial_kvs = vec![(format!("worker:{}", config.node_id), self_info_b64)];
123
124 let handle = spawn_chitchat(cfg, initial_kvs, &UdpTransport)
125 .await
126 .map_err(|e| OverlayError::NetworkConfig(format!("spawn chitchat: {e}")))?;
127
128 let chitchat = handle.chitchat();
129 let (events_tx, _events_rx) = broadcast::channel(256);
130
131 let chitchat_for_watcher = chitchat.clone();
134 let events_for_watcher = events_tx.clone();
135 let cluster_for_watcher = config.cluster_id.clone();
136 tokio::spawn(async move {
137 let mut last_snapshot: HashMap<u64, PeerInfo> = HashMap::new();
138 let mut tick = tokio::time::interval(Duration::from_secs(1));
139 loop {
140 tick.tick().await;
141 let chitchat_guard = chitchat_for_watcher.lock().await;
142 let current = collect_peers(&chitchat_guard);
143 drop(chitchat_guard);
144
145 let mut next_snapshot = HashMap::new();
146 for peer in current {
147 next_snapshot.insert(peer.node_id, peer.clone());
148 match last_snapshot.get(&peer.node_id) {
149 None => {
150 let _ = events_for_watcher.send(TopologyEvent::Joined(peer));
151 }
152 Some(prev) if prev != &peer => {
153 let _ = events_for_watcher.send(TopologyEvent::Updated(peer));
154 }
155 _ => {}
156 }
157 }
158 for id in last_snapshot.keys() {
159 if !next_snapshot.contains_key(id) {
160 let _ = events_for_watcher.send(TopologyEvent::Left { node_id: *id });
161 }
162 }
163 last_snapshot = next_snapshot;
164
165 tracing::trace!(cluster = %cluster_for_watcher, "gossip watcher tick");
166 }
167 });
168
169 info!(
170 cluster_id = %config.cluster_id,
171 node_id = config.node_id,
172 seeds = ?config.seeds,
173 "gossip pool started"
174 );
175
176 Ok(Arc::new(Self {
177 _handle: handle,
178 chitchat,
179 cluster_id: config.cluster_id,
180 events_tx,
181 }))
182 }
183
184 pub async fn peers(&self) -> Vec<PeerInfo> {
186 let chitchat = self.chitchat.lock().await;
187 collect_peers(&chitchat)
188 }
189
190 #[must_use]
193 pub fn subscribe_updates(&self) -> broadcast::Receiver<TopologyEvent> {
194 self.events_tx.subscribe()
195 }
196
197 pub async fn announce_self(&self, info: &PeerInfo) -> Result<(), OverlayError> {
203 let bytes = serde_json::to_vec(info)
204 .map_err(|e| OverlayError::NetworkConfig(format!("encode self_info: {e}")))?;
205 let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes);
206 let key = format!("worker:{}", info.node_id);
207
208 let mut chitchat = self.chitchat.lock().await;
209 chitchat.self_node_state().set(key, b64);
210 Ok(())
211 }
212
213 #[must_use]
215 pub fn cluster_id(&self) -> &str {
216 &self.cluster_id
217 }
218}
219
220fn collect_peers(chitchat: &Chitchat) -> Vec<PeerInfo> {
223 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
224
225 let mut out = Vec::new();
226 let self_id = chitchat.self_chitchat_id().clone();
227
228 for (chitchat_id, node_state) in chitchat.node_states() {
229 if chitchat_id == &self_id {
230 continue;
231 }
232 for (key, value) in node_state.key_values() {
233 if let Some(node_id_str) = key.strip_prefix("worker:") {
234 if let Ok(node_id) = node_id_str.parse::<u64>() {
235 match URL_SAFE_NO_PAD.decode(value) {
236 Ok(bytes) => {
237 if let Ok(info) = serde_json::from_slice::<PeerInfo>(&bytes) {
238 out.push(info);
239 }
240 }
241 Err(e) => {
242 warn!(
243 ?chitchat_id,
244 key,
245 node_id,
246 error = %e,
247 "decode peer info failed"
248 );
249 }
250 }
251 }
252 }
253 }
254 }
255 out
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// `PeerInfo` fixture for `node_id`; every other field holds a fixed placeholder.
    fn make_self_info(node_id: u64) -> PeerInfo {
        let wg_endpoint = "127.0.0.1:51820".parse::<SocketAddr>().unwrap();
        PeerInfo {
            node_id,
            wg_pubkey: String::from("test-key"),
            wg_endpoint,
            overlay_ip: String::from("10.0.0.1"),
            labels: HashMap::new(),
        }
    }

    /// A pool started without seeds reports no peers (itself excluded) and echoes
    /// back the configured cluster id.
    #[tokio::test]
    async fn gossip_pool_starts_with_self_only() {
        let node_id = 42;
        let pool = GossipPool::start(GossipConfig {
            node_id,
            gossip_listen: "127.0.0.1:0".parse().unwrap(),
            seeds: Vec::new(),
            cluster_id: "test-cluster".into(),
            self_info: make_self_info(node_id),
        })
        .await
        .expect("start");

        let peers = pool.peers().await;
        assert!(peers.is_empty(), "expected no peers, got: {peers:?}");
        assert_eq!(pool.cluster_id(), "test-cluster");
    }
}
287}