1#![warn(missing_docs)]
/// Logs the `Err` value of a `Result` expression at error level and discards it; `Ok` values
/// are ignored.
///
/// The expansion calls `error!` unqualified, so an `error!` macro (this crate imports
/// `log::error`) must be in scope where `log_error!` is invoked.
#[macro_export]
#[doc(hidden)]
macro_rules! log_error {
    ($result:expr) => {
        if let Err(e) = $result {
            // `{}` already formats through `Display`; no intermediate `String` is needed.
            error!("{}", e);
        }
    };
}
40
41use almost_raft::election::{raft_election, RaftElectionState};
42use almost_raft::{Message, Node};
43
44use async_trait::async_trait;
45use bytes::Bytes;
46use futures_util::stream::FuturesUnordered;
47use http::{Error, Request};
48use hyper::client::{Client, HttpConnector};
49use hyper::Body;
50use log::{debug, error, info, trace};
51
52use rust_cloud_discovery::{DiscoveryClient, DiscoveryService, ServiceInstance};
53use serde::{Deserialize, Serialize};
54
55use hyper_tls::HttpsConnector;
56use native_tls::TlsConnector;
57use std::collections::{HashMap, HashSet};
58use std::hash::{Hash, Hasher};
59use std::num::NonZeroUsize;
60use std::result::Result::Err;
61use std::sync::Arc;
62use std::time::{Duration, Instant};
63use tokio::sync::mpsc::Sender;
64use tokio::sync::{mpsc, RwLock};
65use tokio_stream::StreamExt;
66
/// Role of this instance within the cluster.
///
/// Starts as `Inactive` and is switched to `Primary` or `Secondary` whenever the raft
/// election reports a leader change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InstanceMode {
    /// Initial mode: no leader change has been applied yet.
    Inactive,
    /// This instance is the elected leader.
    Primary,
    /// Another instance is the elected leader.
    Secondary,
}
76
77#[allow(dead_code)]
78pub struct Cluster {
80 self_id: String,
82 mode: RwLock<InstanceMode>,
84 self_: RwLock<Option<ServiceInstance>>,
86 primaries: RwLock<HashSet<RestClusterNode>>,
88 secondaries: RwLock<Arc<HashSet<RestClusterNode>>>,
90 n_primary: usize,
92 raft_tx: RwLock<Option<Sender<Message<RestClusterNode>>>>,
94}
95
96impl Cluster {
97 pub fn new() -> Self {
101 Cluster {
102 ..Default::default()
103 }
104 }
105
106 #[doc(hidden)]
107 pub fn _new(mode: InstanceMode, secondaries: HashSet<RestClusterNode>) -> Self {
109 Cluster {
110 mode: RwLock::new(mode),
111 secondaries: RwLock::new(Arc::new(secondaries)),
112 ..Default::default()
113 }
114 }
115
116 pub async fn secondaries(&self) -> Option<Arc<HashSet<RestClusterNode>>> {
119 if self.is_primary().await {
120 let guard = self.secondaries.read().await;
121 Some(guard.clone())
122 } else {
123 info!("[node: {}] not a primary node", &self.self_id);
124 None
125 }
126 }
127
128 pub async fn primaries(&self) -> Option<HashSet<RestClusterNode>> {
131 if self.is_secondary().await {
132 let guard = self.primaries.read().await;
133 Some(guard.clone())
134 } else {
135 info!("[node: {}] not a secondary node", &self.self_id);
136 None
137 }
138 }
139
140 #[inline]
142 pub async fn is_primary(&self) -> bool {
143 let guard = self.mode.read().await;
144 matches!(*guard, InstanceMode::Primary)
145 }
146
147 #[inline]
149 pub async fn is_secondary(&self) -> bool {
150 let guard = self.mode.read().await;
151 matches!(*guard, InstanceMode::Secondary)
152 }
153
154 #[inline]
156 pub async fn is_active(&self) -> bool {
157 let guard = self.mode.read().await;
158 !matches!(*guard, InstanceMode::Inactive)
159 }
160
161 pub async fn accept_raft_request_vote(&self, requester_node_id: String, term: usize) {
163 self.send_message_to_raft(Message::RequestVote {
164 term,
165 node_id: requester_node_id,
166 })
167 .await;
168 }
169
170 pub async fn accept_raft_request_vote_resp(&self, term: usize, vote: bool) {
172 self.send_message_to_raft(Message::RequestVoteResponse { term, vote })
173 .await;
174 }
175
176 pub async fn accept_raft_heartbeat(&self, leader_node_id: String, term: usize) {
178 self.send_message_to_raft(Message::HeartBeat {
179 leader_node_id,
180 term,
181 })
182 .await;
183 }
184
185 async fn send_message_to_raft(&self, msg: Message<RestClusterNode>) {
186 trace!(
187 "[node: {}] sending messages to raft: {:?}",
188 &self.self_id,
189 &msg
190 );
191 let guard = self.raft_tx.read().await;
192 if let Some(tx) = guard.as_ref() {
193 let result = tx.send(msg).await;
194 log_error!(result);
195 }
196 }
197
198 pub async fn get_service_instance(&self) -> Option<ServiceInstance> {
200 self.self_.read().await.clone()
201 }
202}
203
204impl Default for Cluster {
205 fn default() -> Self {
206 Cluster {
207 self_id: uuid::Uuid::new_v4().to_string(),
208 mode: RwLock::from(InstanceMode::Inactive),
209 self_: Default::default(),
210 primaries: Default::default(),
211 secondaries: Default::default(),
212 n_primary: 1,
213 raft_tx: Default::default(),
214 }
215 }
216}
217
/// Tuning parameters for [`start_cluster`].
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// Connection timeout handed to the raft election state (default 10 000; presumably
    /// milliseconds — confirm against `almost_raft`).
    pub connection_timeout: u64,
    /// Election timeout handed to the raft election state (default 30 000; presumably
    /// milliseconds — confirm against `almost_raft`).
    pub election_timeout: u64,
    /// Milliseconds between two polls of the discovery service (default 10 000).
    pub update_interval: u64,
    /// Maximum number of nodes, handed to the raft election state (default 20).
    pub max_node: NonZeroUsize,
    /// Minimum number of nodes (default 4); `start_cluster` hands `min_node - 1` to the raft
    /// election state, presumably excluding this member.
    pub min_node: NonZeroUsize,
}
233
234impl Default for ClusterConfig {
235 fn default() -> Self {
236 ClusterConfig {
237 connection_timeout: 10 * 1000,
238 election_timeout: 30 * 1000,
239 update_interval: 10 * 1000,
240 max_node: NonZeroUsize::new(20).unwrap(),
241 min_node: NonZeroUsize::new(4).unwrap(),
242 }
243 }
244}
245
246pub async fn start_cluster<T: DiscoveryService>(
248 cluster: Arc<Cluster>,
249 discovery_service: DiscoveryClient<T>,
250 config: ClusterConfig,
251) {
252 info!("[node: {}] starting cluster...", &cluster.self_id);
253 let raft_tx_timeout = 15;
254
255 let (tx, mut raft_rx) = mpsc::channel::<Message<RestClusterNode>>(20);
256
257 let (raft, raft_tx) = RaftElectionState::init(
258 cluster.self_id.clone(),
259 config.election_timeout,
260 config.connection_timeout,
261 500,
262 vec![],
263 tx.clone(),
264 config.max_node.get(),
265 config.min_node.get() - 1,
268 );
269
270 {
271 let mut write_guard = cluster.raft_tx.write().await;
272 *write_guard = Some(raft_tx.clone());
273 }
274
275 info!("[node: {}] spawning raft election...", &cluster.self_id);
276 tokio::spawn(raft_election(raft));
277
278 let mut remaining_update_interval = config.update_interval;
279
280 let client = build_client();
281
282 let mut discovered: HashMap<String, RestClusterNode> = HashMap::new();
285
286 loop {
287 trace!(
288 "[node: {}] update timeout: {}",
289 &cluster.self_id,
290 &remaining_update_interval
291 );
292 let start_time = Instant::now();
294 let raft_msg = tokio::time::timeout(
295 Duration::from_millis(remaining_update_interval),
296 raft_rx.recv(),
297 )
298 .await;
299
300 if let Ok(msg) = raft_msg {
301 handle_control_message_from_raft(&cluster, &discovered, msg).await;
303 remaining_update_interval = unsigned_subtract(
304 remaining_update_interval,
305 start_time.elapsed().as_millis() as u64,
306 );
307 continue;
308 }
309 remaining_update_interval = config.update_interval;
310
311 trace!("[node: {}] calling discovery service.", &cluster.self_id);
312 let instances = if let Ok(instance) = discovery_service.get_instances().await {
313 instance
314 } else {
315 vec![]
316 };
317
318 debug!("discovered instances: {:?}", instances);
319
320 let mut requests = FuturesUnordered::new();
322 let mut current_instances = HashSet::new();
323 for instance in instances {
324 let id = if instance.instance_id().is_some() {
325 instance.instance_id().clone().unwrap()
326 } else {
327 continue;
329 };
330 if discovered.contains_key(&id) || instance.uri().is_none()
332 {
333 current_instances.insert(id);
334 continue;
335 }
336 current_instances.insert(id);
337
338 let request = Request::builder()
339 .uri(format!("{}{}", instance.uri().clone().unwrap(), PATH_INFO))
340 .body(Body::empty());
341 requests.push(send_request(&client, request, instance));
343 }
344
345 let mut new_nodes = HashSet::new();
346 while let Some(result) = requests.next().await {
347 match result {
348 Ok((resp, instance)) => {
349 let info = serde_json::from_slice::<ClusterInfo>(resp.as_ref());
350 trace!(
351 "[node: {}] cluster info {:?} from {:?}",
352 &cluster.self_id,
353 &info,
354 &instance
355 );
356 if let Ok(info) = info {
357 if info.node_id == cluster.self_id {
358 {
359 let mut guard = cluster.self_.write().await;
360 guard.replace(instance.clone());
361 }
362 }
365 let node = RestClusterNode::new(info.node_id, instance);
366 if cluster.self_id != node.node_id {
367 new_nodes.insert(node.inner.instance_id().clone().unwrap());
368 debug!("[node: {}] new node found: {:?}", &cluster.self_id, &node);
370 let result = raft_tx
371 .send_timeout(
372 Message::ControlAddNode(node.clone()),
373 Duration::from_millis(raft_tx_timeout),
374 )
375 .await;
376 log_error!(result);
377 }
378 discovered.insert(node.inner.instance_id().clone().unwrap(), node);
379 }
380 }
381 Err(err) => {
382 error!(
383 "[node: {}] error getting cluster info: {}",
384 &cluster.self_id,
385 err.to_string()
386 );
387 }
388 }
389 }
390
391 let mut removed_nodes = HashSet::new();
392 for (key, val) in discovered.iter() {
394 if !current_instances.contains(val.service_instance().instance_id().as_ref().unwrap()) {
395 removed_nodes.insert(key.clone());
396 }
397 }
398
399 if !new_nodes.is_empty() || !removed_nodes.is_empty() {
400 let mut current = {
402 let guard = cluster.secondaries.read().await;
403 guard.clone().as_ref().clone()
404 };
405 for node in removed_nodes {
406 let removed = discovered.remove(&node);
407 if let Some(removed) = removed {
408 debug!("removing node: {:?}", &removed);
409 current.remove(&removed);
410 let result = raft_tx
411 .send_timeout(
412 Message::ControlRemoveNode(removed),
413 Duration::from_millis(raft_tx_timeout),
414 )
415 .await;
416 log_error!(result);
417 }
418 }
419 for node in new_nodes {
420 if let Some(node) = discovered.get(&node) {
421 current.insert(node.clone());
422 }
423 }
424 {
425 trace!(
426 "[node: {}] updating secondaries to: {:?}",
427 &cluster.self_id,
428 ¤t
429 );
430 let mut write_guard = cluster.secondaries.write().await;
431 *write_guard = Arc::new(current);
432 }
433 }
434 }
435}
436
437fn build_client() -> Client<HttpsConnector<HttpConnector>> {
438 let tls = TlsConnector::builder()
439 .danger_accept_invalid_hostnames(true)
440 .danger_accept_invalid_certs(true)
441 .build()
442 .unwrap();
443 let mut http_connector = HttpConnector::new();
444 http_connector.enforce_http(false);
445 let connector = HttpsConnector::from((http_connector, tls.into()));
446 Client::builder().build(connector)
447}
448
449async fn send_request(
450 client: &Client<HttpsConnector<HttpConnector>>,
451 request: Result<Request<Body>, Error>,
452 instance: ServiceInstance,
453) -> anyhow::Result<(Bytes, ServiceInstance)> {
454 let request = request?;
455 let resp = client.request(request).await?;
456 let resp = hyper::body::to_bytes(resp).await?;
457 Ok((resp, instance))
458}
459
460#[inline]
461async fn handle_control_message_from_raft(
462 cluster: &Arc<Cluster>,
463 discovered: &HashMap<String, RestClusterNode>,
464 msg: Option<Message<RestClusterNode>>,
465) {
466 info!(
467 "[node: {}] control message from raft: {:?}",
468 cluster.self_id, &msg
469 );
470 if let Some(Message::ControlLeaderChanged(node_id)) = msg {
471 let mut node = None;
472 for discovered_node in discovered.values() {
473 if discovered_node.node_id == node_id {
474 node = Some(discovered_node);
475 }
476 }
477 if let Some(node) = node {
478 info!("new primary: {:?}", node);
479 let mode = if cluster.self_id == node_id {
480 InstanceMode::Primary
481 } else {
482 InstanceMode::Secondary
483 };
484 {
485 let mut write_guard = cluster.mode.write().await;
486 *write_guard = mode;
487 }
488 let node = node.clone();
489 let mut write_guard = cluster.primaries.write().await;
490 write_guard.insert(node);
491 } else {
492 error!("Node not found in discovered list");
493 }
494 }
495}
496
497pub async fn get_cluster_info(cluster: Arc<Cluster>) -> ClusterInfo {
499 let node = {
500 let guard = cluster.self_.read().await;
501 guard.as_ref().map(|x| x.to_owned())
502 };
503 ClusterInfo {
504 instance: node,
505 node_id: cluster.self_id.clone(),
506 }
507}
508
509#[derive(Debug, Serialize, Deserialize)]
511pub struct ClusterInfo {
512 pub node_id: String,
514 pub instance: Option<ServiceInstance>,
516}
517
518#[derive(Debug, Clone)]
526pub struct RestClusterNode {
527 pub(crate) node_id: String,
528 pub(crate) inner: ServiceInstance,
529}
530
531impl RestClusterNode {
532 pub fn new(node_id: String, instance: ServiceInstance) -> Self {
537 Self {
538 node_id,
539 inner: instance,
540 }
541 }
542
543 pub fn service_instance(&self) -> &ServiceInstance {
545 &self.inner
546 }
547
548 async fn send_request_vote(&self, node_id: String, term: usize) -> anyhow::Result<()> {
549 self.send_raft_request(format!(
550 "{}{}/{}/{}",
551 self.inner.uri().clone().unwrap(),
552 PATH_RAFT_REQUEST_VOTE,
553 node_id,
554 term
555 ))
556 .await
557 }
558
559 async fn send_request_vote_response(&self, vote: bool, term: usize) -> anyhow::Result<()> {
560 self.send_raft_request(format!(
561 "{}{}/{}/{}",
562 self.inner.uri().clone().unwrap(),
563 PATH_RAFT_VOTE,
564 term,
565 vote
566 ))
567 .await
568 }
569
570 async fn send_heartbeat(&self, leader_node_id: String, term: usize) -> anyhow::Result<()> {
571 self.send_raft_request(format!(
572 "{}{}/{}/{}",
573 self.inner.uri().clone().unwrap(),
574 PATH_RAFT_HEARTBEAT,
575 leader_node_id,
576 term
577 ))
578 .await
579 }
580
581 async fn send_raft_request(&self, uri: String) -> anyhow::Result<()> {
582 trace!(
583 "sending raft request to node: {}, path: {}",
584 &self.node_id,
585 &uri
586 );
587 let request = Request::builder().uri(uri).body(Body::empty())?;
588 let client = Client::new();
590 let resp = client.request(request).await?;
591 let resp = hyper::body::to_bytes(resp).await?;
592 trace!(
593 "raft request response: {:?}",
594 std::str::from_utf8(resp.as_ref())
595 );
596 Ok(())
597 }
598}
599
600#[async_trait]
601impl Node for RestClusterNode {
602 type NodeType = RestClusterNode;
603
604 async fn send_message(&self, msg: Message<Self::NodeType>) {
605 debug!(
606 "[RestClusterNode: {}] message from raft: {:?}",
607 &self.node_id, &msg
608 );
609 match msg {
610 Message::RequestVote { node_id, term } => {
611 let result = self.send_request_vote(node_id, term).await;
612 log_error!(result);
613 }
614 Message::RequestVoteResponse { vote, term } => {
615 let result = self.send_request_vote_response(vote, term).await;
616 log_error!(result);
617 }
618 Message::HeartBeat {
619 leader_node_id,
620 term,
621 } => {
622 let result = self.send_heartbeat(leader_node_id, term).await;
623 log_error!(result);
624 }
625 _ => {}
626 }
627 }
628
629 fn node_id(&self) -> &String {
630 &self.node_id
631 }
632}
633
634impl PartialEq for RestClusterNode {
635 fn eq(&self, other: &Self) -> bool {
636 self.node_id.eq(&other.node_id)
637 }
638}
639
640impl Eq for RestClusterNode {}
641
642impl Hash for RestClusterNode {
643 fn hash<H: Hasher>(&self, state: &mut H) {
644 self.node_id.hash(state);
645 }
646}
647
// HTTP endpoints, relative to a peer's base URI.

/// Serves the peer's `ClusterInfo` as JSON.
const PATH_INFO: &str = "/cluster/info";
/// Raft heartbeat; followed by `/{leader_node_id}/{term}`.
const PATH_RAFT_HEARTBEAT: &str = "/cluster/raft/beat";
/// Raft vote request; followed by `/{node_id}/{term}`.
const PATH_RAFT_REQUEST_VOTE: &str = "/cluster/raft/request-vote";
/// Raft vote response; followed by `/{term}/{vote}`.
const PATH_RAFT_VOTE: &str = "/cluster/raft/vote";
655
/// Computes `lhs - rhs`, clamped to zero when `rhs` exceeds `lhs`, so unsigned values never
/// underflow.
#[inline(always)]
fn unsigned_subtract<T>(lhs: T, rhs: T) -> T
where
    T: PartialEq + PartialOrd + std::ops::Sub<Output = T> + From<u64>,
{
    let would_underflow = rhs > lhs;
    if would_underflow {
        T::from(0u64)
    } else {
        lhs - rhs
    }
}
668
#[cfg(test)]
mod test {
    use crate::{build_client, start_cluster, Cluster, ClusterConfig};
    use cloud_discovery_kubernetes::KubernetesDiscoverService;
    use hyper::{Body, Request};
    use rust_cloud_discovery::DiscoveryClient;
    use std::sync::Arc;

    /// GETs `uri` through the cluster's HTTP(S) client and reports whether any response (of
    /// any status) came back. Requires outbound network access.
    async fn get_succeeds(uri: &str) -> bool {
        let http = build_client();
        let get = Request::builder()
            .uri(uri)
            .method("GET")
            .body(Body::empty())
            .unwrap();
        http.request(get).await.is_ok()
    }

    /// Smoke test: spawns the cluster loop against Kubernetes discovery (service "overload",
    /// namespace "default"). Does nothing when the Kubernetes client cannot be initialised.
    #[tokio::test]
    async fn test_cluster_impl() {
        let init =
            KubernetesDiscoverService::init("overload".to_string(), "default".to_string()).await;
        if let Ok(service) = init {
            let member = Arc::new(Cluster::default());
            let discovery = DiscoveryClient::new(service);
            let settings = ClusterConfig::default();
            tokio::spawn(start_cluster(member, discovery, settings));
        }
    }

    #[tokio::test]
    async fn client_test_http() {
        assert!(get_succeeds("http://httpbin.org/get").await);
    }

    #[tokio::test]
    async fn client_test_https() {
        assert!(get_succeeds("https://httpbin.org/get").await);
    }

    #[tokio::test]
    async fn client_test_self_signed() {
        assert!(get_succeeds("https://self-signed.badssl.com/").await);
    }
}