network_protocol/transport/
cluster.rs

1use std::collections::HashMap;
2// No need for Arc in this module
3use std::time::{Duration, Instant};
4use tokio::select;
5use tokio::sync::mpsc;
6use tokio::time::{sleep, interval};
7use tracing::{info, warn, debug, instrument};
8
9use crate::service::client::Client;
10use crate::protocol::message::Message;
11//use crate::error::Result;
12
/// A single peer in the cluster, as known to the local node.
#[derive(Clone, Debug)]
pub struct ClusterNode {
    /// Identifier of the peer; also used as the key in the cluster's peer map.
    pub id: String,
    /// Address string handed to `Client::connect` when pinging the peer.
    pub addr: String,
    /// Time the peer was last seen, or `None` if it has not been seen yet.
    pub last_seen: Option<Instant>,
}
19
20pub struct Cluster {
21    peers: HashMap<String, ClusterNode>,
22    shutdown_tx: Option<mpsc::Sender<()>>,
23}
24
25impl Cluster {
26    pub fn new(peers: Vec<(String, String)>) -> Self {
27        let peers = peers.into_iter().map(|(id, addr)| {
28            (id.clone(), ClusterNode { id, addr, last_seen: None })
29        }).collect();
30
31        Self { 
32            peers,
33            shutdown_tx: None
34        }
35    }
36
37    #[instrument(skip(self), fields(interval_ms = %heartbeat_interval.as_millis()))]
38    pub async fn start_heartbeat(&mut self, heartbeat_interval: Duration) -> mpsc::Sender<()> {
39        // Create shutdown channel
40        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
41        
42        // Clone necessary data for the heartbeat task
43        let peers = self.peers.clone();
44        
45        // Store the sender for shutdown
46        self.shutdown_tx = Some(shutdown_tx.clone());
47        
48        // Spawn the heartbeat task
49        tokio::spawn(async move {
50            let mut interval_timer = interval(heartbeat_interval);
51            
52            loop {
53                select! {
54                    // Check for shutdown signal
55                    _ = shutdown_rx.recv() => {
56                        info!("Received shutdown signal, stopping heartbeat");
57                        break;
58                    }
59                    
60                    // Run heartbeat on interval
61                    _ = interval_timer.tick() => {
62                        for (id, node) in peers.iter() {
63                            match Client::connect(&node.addr).await {
64                                Ok(mut client) => {
65                                    match client.send_and_wait(Message::Ping).await {
66                                        Ok(Message::Pong) => {
67                                            debug!(node_id = %id, "Peer alive");
68                                        }
69                                        _ => {
70                                            warn!(node_id = %id, "Peer timeout");
71                                        }
72                                    }
73                                }
74                                Err(e) => {
75                                    warn!(node_id = %id, error = ?e, "Peer unreachable");
76                                }
77                            }
78                        }
79                    }
80                }
81            }
82            
83            info!("Heartbeat task shut down gracefully");
84        });
85        
86        // Return the shutdown sender so the caller can trigger shutdown
87        shutdown_tx
88    }
89
90    pub fn get_peers(&self) -> Vec<&ClusterNode> {
91        self.peers.values().collect()
92    }
93    
94    /// Gracefully shut down the cluster's heartbeat task
95    #[instrument(skip(self))]
96    pub async fn shutdown(&mut self) {
97        if let Some(tx) = self.shutdown_tx.take() {
98            if tx.send(()).await.is_err() {
99                info!("Heartbeat task already stopped");
100            } else {
101                info!("Shutdown signal sent to heartbeat task");
102                // Give heartbeat task time to finish
103                sleep(Duration::from_millis(100)).await;
104            }
105        } else {
106            info!("No active heartbeat to shut down");
107        }
108    }
109}