1use std::collections::HashMap;
28use std::net::SocketAddr;
29use std::sync::atomic::{AtomicU64, Ordering};
30use std::sync::Arc;
31use std::time::Duration;
32
33use parking_lot::RwLock;
34use tokio::net::{TcpListener, TcpStream};
35use tokio::time::interval;
36use tokio_util::sync::CancellationToken;
37
38use crate::gossip::fanout::GossipFanout;
39
40use super::codec::{ClusterMessage, MessageCodec};
41
/// Last-writer-wins key/value map that nodes exchange via gossip.
///
/// Every key maps to a `(value, version)` pair. A write is applied only when
/// its version is strictly greater than the version currently stored, so
/// replaying or reordering gossip messages cannot roll a key back.
#[derive(Debug, Default, Clone)]
pub struct GossipState {
    /// `key -> (value, version)`.
    pub entries: HashMap<String, (u64, u64)>,
}

impl GossipState {
    /// Applies `value` at `version` for `key` if that version is newer than
    /// the stored one.
    ///
    /// An absent key behaves as if it were stored at version 0. A version-0
    /// write is therefore always rejected, and it leaves no entry behind (no
    /// phantom `(0, 0)` pair that would later be reported or gossiped).
    ///
    /// Returns `true` when the write was applied, `false` when it was stale.
    pub fn set(&mut self, key: &str, value: u64, version: u64) -> bool {
        match self.entries.get_mut(key) {
            Some(entry) => {
                if version > entry.1 {
                    *entry = (value, version);
                    true
                } else {
                    false
                }
            }
            // Nothing stored yet: only a version above the implicit 0 wins.
            None if version == 0 => false,
            None => {
                // Allocate the owned key only when a new entry is actually created.
                self.entries.insert(key.to_owned(), (value, version));
                true
            }
        }
    }

    /// Returns the `(value, version)` pair stored for `key`, if any.
    pub fn get(&self, key: &str) -> Option<(u64, u64)> {
        self.entries.get(key).copied()
    }

    /// Number of keys currently held.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no key has been stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
87
88#[derive(Debug, Clone)]
94pub struct TcpNodeConfig {
95 pub node_id: String,
97 pub bind_addr: SocketAddr,
100 pub fanout: GossipFanout,
102 pub gossip_interval_ms: u64,
104}
105
106impl TcpNodeConfig {
107 pub fn localhost(node_id: &str, port: u16) -> Self {
111 Self {
112 node_id: node_id.to_owned(),
113 bind_addr: SocketAddr::from(([127, 0, 0, 1], port)),
114 fanout: GossipFanout::Unbounded,
115 gossip_interval_ms: 50,
116 }
117 }
118}
119
/// Failures reported by the TCP cluster node.
#[derive(Debug)]
pub enum TcpNodeError {
    /// Binding the listener, or reading back its local address, failed.
    BindError(std::io::Error),
    /// A message could not be sent; carries a description of the failure.
    SendError(String),
    /// The node has been shut down.
    Shutdown,
}

impl std::fmt::Display for TcpNodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BindError(io_err) => write!(f, "TCP bind failed: {io_err}"),
            Self::SendError(reason) => write!(f, "send error: {reason}"),
            Self::Shutdown => f.write_str("node has been shut down"),
        }
    }
}

impl std::error::Error for TcpNodeError {
    /// Only a bind failure wraps an underlying I/O error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::BindError(io_err) => Some(io_err),
            Self::SendError(_) | Self::Shutdown => None,
        }
    }
}
153
154pub struct TcpClusterNode {
164 config: TcpNodeConfig,
165 bound_addr: SocketAddr,
167 state: Arc<RwLock<GossipState>>,
168 peers: Arc<RwLock<Vec<SocketAddr>>>,
169 cancel: CancellationToken,
170 version: Arc<AtomicU64>,
172}
173
174impl TcpClusterNode {
175 pub async fn start(config: TcpNodeConfig) -> Result<Self, TcpNodeError> {
183 let listener = TcpListener::bind(config.bind_addr)
184 .await
185 .map_err(TcpNodeError::BindError)?;
186 let bound_addr = listener.local_addr().map_err(TcpNodeError::BindError)?;
187
188 let state: Arc<RwLock<GossipState>> = Arc::default();
189 let peers: Arc<RwLock<Vec<SocketAddr>>> = Arc::default();
190 let cancel = CancellationToken::new();
191 let version = Arc::new(AtomicU64::new(1));
192
193 let state_clone = Arc::clone(&state);
195 let cancel_clone = cancel.clone();
196 let node_id_clone = config.node_id.clone();
197 tokio::spawn(async move {
198 run_listener(listener, state_clone, node_id_clone, cancel_clone).await;
199 });
200
201 let state_gossip = Arc::clone(&state);
203 let peers_gossip = Arc::clone(&peers);
204 let cancel_gossip = cancel.clone();
205 let gossip_interval = config.gossip_interval_ms;
206 let fanout = config.fanout;
207 let node_id_gossip = config.node_id.clone();
208 tokio::spawn(async move {
209 run_gossip_loop(
210 node_id_gossip,
211 fanout,
212 gossip_interval,
213 state_gossip,
214 peers_gossip,
215 cancel_gossip,
216 )
217 .await;
218 });
219
220 Ok(Self {
221 config,
222 bound_addr,
223 state,
224 peers,
225 cancel,
226 version,
227 })
228 }
229
230 pub fn add_peer(&self, addr: SocketAddr) {
232 self.peers.write().push(addr);
233 }
234
235 pub fn set(&self, key: &str, value: u64) {
241 let ver = self.version.fetch_add(1, Ordering::Relaxed) + 1;
242 self.state.write().set(key, value, ver);
243 }
244
245 pub fn set_with_version(&self, key: &str, value: u64, version: u64) {
255 self.state.write().set(key, value, version);
256 let mut current = self.version.load(Ordering::Relaxed);
259 loop {
260 if current >= version {
261 break;
262 }
263 match self.version.compare_exchange_weak(
264 current,
265 version,
266 Ordering::Relaxed,
267 Ordering::Relaxed,
268 ) {
269 Ok(_) => break,
270 Err(v) => current = v,
271 }
272 }
273 }
274
275 pub fn get(&self, key: &str) -> Option<u64> {
279 self.state.read().get(key).map(|(v, _ver)| v)
280 }
281
282 pub fn state_len(&self) -> usize {
284 self.state.read().len()
285 }
286
287 pub fn shutdown(&self) {
292 self.cancel.cancel();
293 }
294
295 pub fn node_id(&self) -> &str {
297 &self.config.node_id
298 }
299
300 pub fn addr(&self) -> SocketAddr {
302 self.bound_addr
303 }
304}
305
306async fn run_listener(
312 listener: TcpListener,
313 state: Arc<RwLock<GossipState>>,
314 node_id: String,
315 cancel: CancellationToken,
316) {
317 loop {
318 tokio::select! {
319 biased;
320 _ = cancel.cancelled() => break,
321 result = listener.accept() => {
322 match result {
323 Ok((stream, _peer)) => {
324 let state_clone = Arc::clone(&state);
325 let node_id_clone = node_id.clone();
326 let cancel_clone = cancel.clone();
327 tokio::spawn(async move {
328 handle_connection(stream, state_clone, node_id_clone, cancel_clone).await;
329 });
330 }
331 Err(_) => break,
332 }
333 }
334 }
335 }
336}
337
338async fn handle_connection(
342 mut stream: TcpStream,
343 state: Arc<RwLock<GossipState>>,
344 node_id: String,
345 cancel: CancellationToken,
346) {
347 let (mut reader, mut writer) = stream.split();
348
349 loop {
350 let msg = tokio::select! {
351 biased;
352 _ = cancel.cancelled() => break,
353 result = MessageCodec::read(&mut reader) => {
354 match result {
355 Ok(m) => m,
356 Err(_) => break, }
358 }
359 };
360
361 match msg {
362 ClusterMessage::Gossip {
363 key,
364 value,
365 version,
366 ..
367 } => {
368 state.write().set(&key, value, version);
369 }
370 ClusterMessage::Ping { sender_id: _, seq } => {
371 let pong = ClusterMessage::Pong {
372 sender_id: node_id.clone(),
373 seq,
374 };
375 if MessageCodec::write(&mut writer, &pong).await.is_err() {
376 break;
377 }
378 }
379 ClusterMessage::Replicate { index, .. } => {
380 let ack = ClusterMessage::ReplicateAck {
381 follower_id: node_id.clone(),
382 index,
383 success: true,
384 };
385 if MessageCodec::write(&mut writer, &ack).await.is_err() {
386 break;
387 }
388 }
389 ClusterMessage::Pong { .. } | ClusterMessage::ReplicateAck { .. } => {}
392 }
393 }
394}
395
396async fn run_gossip_loop(
398 node_id: String,
399 fanout: GossipFanout,
400 interval_ms: u64,
401 state: Arc<RwLock<GossipState>>,
402 peers: Arc<RwLock<Vec<SocketAddr>>>,
403 cancel: CancellationToken,
404) {
405 let mut ticker = interval(Duration::from_millis(interval_ms));
406
407 loop {
408 tokio::select! {
409 biased;
410 _ = cancel.cancelled() => break,
411 _ = ticker.tick() => {}
412 }
413
414 let snapshot: Vec<(String, u64, u64)> = {
416 let st = state.read();
417 st.entries
418 .iter()
419 .map(|(k, (v, ver))| (k.clone(), *v, *ver))
420 .collect()
421 };
422
423 if snapshot.is_empty() {
424 continue;
425 }
426
427 let selected = {
428 let all_peers: Vec<SocketAddr> = peers.read().clone();
429 let count = fanout.resolve(all_peers.len());
430 select_random_peers(&all_peers, count)
431 };
432
433 for peer_addr in selected {
434 gossip_to_peer(&node_id, peer_addr, &snapshot).await;
435 }
436 }
437}
438
/// Picks up to `count` distinct peers at random.
///
/// Runs a partial Fisher–Yates shuffle over the peer indices, driven by an
/// xorshift64 generator seeded from the wall clock's sub-second nanoseconds.
/// This is not cryptographically secure; it only spreads gossip load.
///
/// Returns an empty vector when `count` is 0 or `peers` is empty, and never
/// more than `peers.len()` entries.
fn select_random_peers(peers: &[SocketAddr], count: usize) -> Vec<SocketAddr> {
    let total = peers.len();
    if total == 0 || count == 0 {
        return Vec::new();
    }
    let picks = count.min(total);

    // A clock reading before the epoch seeds with 0 nanoseconds.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.subsec_nanos());
    // Offset by the 64-bit golden-ratio constant; xorshift must not start at 0.
    let mut rng = match u64::from(nanos).wrapping_add(0x9e37_79b9_7f4a_7c15) {
        0 => 1,
        seeded => seeded,
    };
    let mut next_random = || {
        rng ^= rng << 13;
        rng ^= rng >> 7;
        rng ^= rng << 17;
        rng
    };

    let mut order: Vec<usize> = (0..total).collect();
    for slot in 0..picks {
        let offset = next_random() as usize % (total - slot);
        order.swap(slot, slot + offset);
    }
    order.truncate(picks);
    order.into_iter().map(|idx| peers[idx]).collect()
}
473
474async fn gossip_to_peer(node_id: &str, peer_addr: SocketAddr, snapshot: &[(String, u64, u64)]) {
476 let Ok(mut stream) = TcpStream::connect(peer_addr).await else {
477 return; };
479
480 for (key, value, version) in snapshot {
481 let msg = ClusterMessage::Gossip {
482 sender_id: node_id.to_owned(),
483 key: key.clone(),
484 value: *value,
485 version: *version,
486 };
487 if MessageCodec::write(&mut stream, &msg).await.is_err() {
488 break; }
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// `count` distinct loopback addresses on consecutive ports from 10000.
    fn loopback_peers(count: u16) -> Vec<SocketAddr> {
        (0..count)
            .map(|offset| SocketAddr::from(([127, 0, 0, 1], 10_000 + offset)))
            .collect()
    }

    #[test]
    fn test_gossip_state_lww() {
        let mut state = GossipState::default();
        assert!(state.set("k", 10, 1), "first write is accepted");
        assert!(!state.set("k", 99, 1), "equal version is rejected");
        assert!(state.set("k", 42, 2), "higher version wins");
        assert_eq!(state.get("k"), Some((42, 2)));
    }

    #[test]
    fn test_gossip_state_len() {
        let mut state = GossipState::default();
        for (key, value) in [("a", 1), ("b", 2)] {
            state.set(key, value, 1);
        }
        assert_eq!(state.len(), 2);
        assert!(!state.is_empty());
    }

    #[test]
    fn test_node_config_localhost() {
        let TcpNodeConfig {
            node_id, bind_addr, ..
        } = TcpNodeConfig::localhost("n1", 0);
        assert_eq!(node_id, "n1");
        assert_eq!(bind_addr.port(), 0);
    }

    #[test]
    fn test_select_random_peers_empty() {
        let no_peers: Vec<SocketAddr> = Vec::new();
        assert!(select_random_peers(&no_peers, 3).is_empty());
    }

    #[test]
    fn test_select_random_peers_count_capped() {
        let peers = loopback_peers(5);
        assert_eq!(select_random_peers(&peers, 100).len(), 5);
    }

    #[tokio::test]
    async fn test_start_and_addr() {
        let node = TcpClusterNode::start(TcpNodeConfig::localhost("test-node", 0))
            .await
            .expect("start");
        assert_eq!(node.node_id(), "test-node");
        assert_ne!(
            node.addr().port(),
            0,
            "OS should have assigned a non-zero port"
        );
        node.shutdown();
    }

    #[tokio::test]
    async fn test_set_and_get() {
        let node = TcpClusterNode::start(TcpNodeConfig::localhost("n1", 0))
            .await
            .expect("start");
        node.set("foo", 42);
        assert_eq!(node.get("foo"), Some(42));
        node.shutdown();
    }
}