Skip to main content

asteroid_mq/protocol/node/raft/
cluster.rs

1use std::{
2    borrow::Cow,
3    collections::{BTreeMap, BTreeSet},
4    future::Future,
5};
6
7use crate::prelude::NodeId;
8#[cfg(feature = "cluster-k8s")]
9pub(crate) mod k8s;
10#[cfg(feature = "cluster-k8s")]
11pub use k8s::{this_pod_id, K8sClusterProvider};
12pub(crate) mod r#static;
13pub use r#static::StaticClusterProvider;
14use tokio_util::sync::CancellationToken;
15use tracing::instrument;
16
17use super::{network_factory::TcpNetworkService, raft_node::TcpNode};
18pub trait ClusterProvider: Send + 'static {
19    fn pristine_nodes(
20        &mut self,
21    ) -> impl Future<Output = crate::Result<BTreeMap<NodeId, String>>> + Send;
22    fn next_update(
23        &mut self,
24    ) -> impl Future<Output = crate::Result<BTreeMap<NodeId, String>>> + Send;
25    fn name(&self) -> Cow<'static, str> {
26        std::any::type_name::<Self>().into()
27    }
28}
29
30pub struct DynClusterProvider {
31    inner: Box<dyn sealed::ClusterProviderObjectTrait + Send>,
32    name: Cow<'static, str>,
33}
34
35impl std::fmt::Debug for DynClusterProvider {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("DynClusterProvider")
38            .field("name", &self.name)
39            .finish()
40    }
41}
42
43impl DynClusterProvider {
44    pub fn new<T>(inner: T) -> Self
45    where
46        T: ClusterProvider,
47    {
48        let provider_name = inner.name();
49        Self {
50            inner: Box::new(inner),
51            name: provider_name,
52        }
53    }
54    pub fn with_name(self, name: impl Into<Cow<'static, str>>) -> Self {
55        Self {
56            name: name.into(),
57            ..self
58        }
59    }
60    pub fn name(&self) -> &str {
61        &self.name
62    }
63    pub async fn pristine_nodes(&mut self) -> crate::Result<BTreeMap<NodeId, String>> {
64        self.inner.pristine_nodes().await
65    }
66    pub async fn next_update(&mut self) -> crate::Result<BTreeMap<NodeId, String>> {
67        self.inner.next_update().await
68    }
69}
70
71mod sealed {
72    use super::ClusterProvider;
73    use crate::prelude::NodeId;
74    use std::{collections::BTreeMap, future::Future, pin::Pin};
75    type DynUpdate<'a> = dyn Future<Output = crate::Result<BTreeMap<NodeId, String>>> + Send + 'a;
76    pub trait ClusterProviderObjectTrait {
77        fn pristine_nodes(&mut self) -> Pin<Box<DynUpdate<'_>>>;
78        fn next_update(&mut self) -> Pin<Box<DynUpdate<'_>>>;
79    }
80
81    impl<T> ClusterProviderObjectTrait for T
82    where
83        T: ClusterProvider,
84    {
85        fn pristine_nodes(&mut self) -> Pin<Box<DynUpdate<'_>>> {
86            Box::pin(self.pristine_nodes())
87        }
88        fn next_update(&mut self) -> Pin<Box<DynUpdate<'_>>> {
89            Box::pin(self.next_update())
90        }
91    }
92}
93
94pub struct ClusterService {
95    provider: DynClusterProvider,
96    tcp_network_service: TcpNetworkService,
97    ct: CancellationToken,
98}
99
100impl ClusterService {
101    pub fn new(
102        provider: impl ClusterProvider,
103        tcp_network_service: TcpNetworkService,
104        ct: CancellationToken,
105    ) -> Self {
106        Self {
107            provider: DynClusterProvider::new(provider),
108            tcp_network_service,
109            ct,
110        }
111    }
112    #[instrument(name="cluster_service", skip(self), fields(cluster_provider_name=%self.provider.name(), local_id=%self.tcp_network_service.info.id))]
113    pub async fn run(self) -> Result<(), crate::Error> {
114        tracing::info!("cluster service started");
115        let Self {
116            mut provider,
117            tcp_network_service,
118            ct,
119        } = self;
120        let local_id = tcp_network_service.info.id;
121        // 3. listen cluster update
122        loop {
123            let nodes = tokio::select! {
124                _ = ct.cancelled() => break,
125                nodes = provider.next_update() => {
126                    nodes?
127                }
128            };
129            tracing::trace!(?nodes, "nodes update received");
130
131            // ensure connections to all nodes
132            let mut ensured_nodes = BTreeMap::new();
133            for (peer_id, peer_addr) in nodes.clone() {
134                if local_id == peer_id {
135                    ensured_nodes.insert(peer_id, TcpNode::new(peer_addr));
136                } else {
137                    tracing::trace!("ensuring connection to {}", peer_id);
138                    let ensure_result = tcp_network_service
139                        .ensure_connection(peer_id, peer_addr.clone())
140                        .await;
141                    if let Err(e) = ensure_result {
142                        tracing::warn!("failed to ensure connection to {}: {}", peer_id, e);
143                    } else {
144                        tracing::trace!("connection to {} ensured", peer_id);
145                        ensured_nodes.insert(peer_id, TcpNode::new(peer_addr.clone()));
146                    }
147                }
148            }
149            // raft update members
150            let raft = tcp_network_service.raft.get().await;
151            let Ok(current_members) = raft.with_raft_state(|r| r.membership_state.clone()).await
152            else {
153                continue;
154            };
155            let current_nodes = current_members
156                .committed()
157                .nodes()
158                .map(|(k, v)| (*k, v.clone()))
159                .collect::<BTreeMap<_, _>>();
160            let to_remove = current_nodes
161                .keys()
162                .filter(|k| !ensured_nodes.contains_key(k))
163                .cloned()
164                .collect::<BTreeSet<_>>();
165            let to_add = ensured_nodes
166                .iter()
167                .filter_map(|(k, v)| {
168                    if !current_nodes.contains_key(k) {
169                        Some((*k, v.clone()))
170                    } else {
171                        None
172                    }
173                })
174                .collect::<BTreeMap<_, _>>();
175            let leader_node = raft.current_leader().await;
176            if to_remove.is_empty() && to_add.is_empty() {
177                tracing::trace!(leader=?leader_node, "no change in nodes");
178            } else {
179                tracing::info!(ensured = ?ensured_nodes, remove = ?to_remove, add = ?to_add, leader=?leader_node, "updating raft members");
180            }
181            if let Some(leader_node) = leader_node {
182                if to_remove.contains(&leader_node) && local_id != leader_node {
183                    tracing::warn!("leader {} is removed from cluster", leader_node);
184                    let trigger_elect_result = raft.trigger().elect().await;
185                    match trigger_elect_result {
186                        Ok(_) => {
187                            tracing::info!("leader removed, trigger election");
188                        }
189                        Err(e) => {
190                            tracing::warn!("failed to trigger election: {}", e);
191                        }
192                    }
193                }
194            }
195            if to_remove.contains(&local_id) {
196                tracing::warn!("local node {} is removed from cluster", local_id);
197                break;
198            }
199
200            if Some(local_id) == leader_node && (!to_add.is_empty() || !to_remove.is_empty()) {
201                let raft = raft.clone();
202                for (id, node) in to_add {
203                    let add_result = raft.add_learner(id, node, true).await;
204                    match add_result {
205                        Ok(resp) => {
206                            tracing::debug!(?resp, "learner {} added", id);
207                        }
208                        Err(e) => {
209                            tracing::warn!("failed to add learner {}: {}", id, e);
210                        }
211                    }
212                }
213                let add_voters_result = raft
214                    .change_membership(
215                        ensured_nodes.keys().cloned().collect::<BTreeSet<_>>(),
216                        false,
217                    )
218                    .await;
219                match add_voters_result {
220                    Ok(resp) => {
221                        tracing::debug!(?resp, "voters added");
222                    }
223                    Err(e) => {
224                        tracing::warn!("failed to add voters: {}", e);
225                    }
226                }
227            }
228        }
229        Ok(())
230    }
231    pub fn spawn(self) {
232        tokio::spawn(self.run());
233    }
234}