rmqtt_cluster_broadcast/
lib.rs1#![deny(unsafe_code)]
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use serde_json::{self, json};
8use tokio::sync::RwLock;
9
10use rmqtt::{
11 grpc::{GrpcClients, Message, MessageReply, MessageType},
12 hook::{Register, Type},
13 macros::Plugin,
14 plugin::{PackageInfo, Plugin},
15 register,
16 types::{From, OfflineSession, Publish, Reason, To},
17 Result,
18};
19
20use config::PluginConfig;
21use handler::HookHandler;
22use rmqtt::context::ServerContext;
23use rmqtt::net::MqttError;
24use router::ClusterRouter;
25use shared::ClusterShared;
26
27mod config;
28mod handler;
29mod message;
30mod router;
31mod shared;
32
33type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
34
35register!(ClusterPlugin::new);
36
37#[derive(Plugin)]
38struct ClusterPlugin {
39 scx: ServerContext,
40 register: Box<dyn Register>,
41 cfg: Arc<RwLock<PluginConfig>>,
42 grpc_clients: GrpcClients,
43 shared: ClusterShared,
44 router: ClusterRouter,
45}
46
47impl ClusterPlugin {
48 #[inline]
49 async fn new<S: Into<String>>(scx: ServerContext, name: S) -> Result<Self> {
50 let name = name.into();
51 let cfg = scx.plugins.read_config_with::<PluginConfig>(&name, &["node_grpc_addrs"])?;
52 log::debug!("{name} ClusterPlugin cfg: {cfg:?}");
53
54 let register = scx.extends.hook_mgr().register();
55 let mut grpc_clients = HashMap::default();
56 let node_grpc_addrs = cfg.node_grpc_addrs.clone();
57 for node_addr in &node_grpc_addrs {
58 if node_addr.id != scx.node.id() {
59 let batch_size = cfg.node_grpc_batch_size;
60 let client_concurrency_limit = cfg.node_grpc_client_concurrency_limit;
61 let client_timeout = cfg.node_grpc_client_timeout;
62 grpc_clients.insert(
63 node_addr.id,
64 (
65 node_addr.addr.clone(),
66 scx.node
67 .new_grpc_client(
68 &node_addr.addr,
69 client_timeout,
70 client_concurrency_limit,
71 batch_size,
72 )
73 .await?,
74 ),
75 );
76 }
77 }
78 let grpc_clients = Arc::new(grpc_clients);
79 let message_type = cfg.message_type;
80 let router = ClusterRouter::new(scx.clone(), grpc_clients.clone(), message_type);
81 let shared = ClusterShared::new(scx.clone(), grpc_clients.clone(), message_type);
82 let cfg = Arc::new(RwLock::new(cfg));
83 Ok(Self { scx, register, cfg, grpc_clients, shared, router })
84 }
85}
86
87#[async_trait]
88impl Plugin for ClusterPlugin {
89 #[inline]
90 async fn init(&mut self) -> Result<()> {
91 log::info!("{} init", self.name());
92 self.register
93 .add(
94 Type::GrpcMessageReceived,
95 Box::new(HookHandler::new(self.scx.clone(), self.shared.clone(), self.router.clone())),
96 )
97 .await;
98 Ok(())
99 }
100
101 #[inline]
102 async fn get_config(&self) -> Result<serde_json::Value> {
103 self.cfg.read().await.to_json()
104 }
105
106 #[inline]
107 async fn start(&mut self) -> Result<()> {
108 log::info!("{} start", self.name());
109 self.register.start().await;
110 *self.scx.extends.shared_mut().await = Box::new(self.shared.clone());
111 *self.scx.extends.router_mut().await = Box::new(self.router.clone());
112 Ok(())
113 }
114
115 #[inline]
116 async fn stop(&mut self) -> Result<bool> {
117 log::warn!("{} stop, once the cluster is started, it cannot be stopped", self.name());
118 Ok(false)
119 }
120
121 #[inline]
122 async fn attrs(&self) -> serde_json::Value {
123 let mut nodes = HashMap::default();
124 for (id, (addr, c)) in self.grpc_clients.iter() {
125 let stats = json!({
126 "transfer_queue_len": c.transfer_queue_len(),
127 "active_tasks_count": c.active_tasks().count(),
128 "active_tasks_max": c.active_tasks().max(),
129 });
130 nodes.insert(format!("{id}-{addr}"), stats);
131 }
132 json!({
133 "grpc_clients": nodes,
134 })
135 }
136}
137
138#[inline]
139pub(crate) async fn kick(
140 grpc_clients: GrpcClients,
141 msg_type: MessageType,
142 msg: Message,
143) -> Result<OfflineSession> {
144 let reply =
145 rmqtt::grpc::MessageBroadcaster::new(grpc_clients, msg_type, msg, Some(Duration::from_secs(15)))
146 .select_ok(|reply: MessageReply| -> Result<MessageReply> {
147 log::debug!("reply: {reply:?}");
148 if let MessageReply::Kick(o) = reply {
149 Ok(MessageReply::Kick(o))
150 } else {
151 Err(MqttError::None.into())
152 }
153 })
154 .await?;
155 if let MessageReply::Kick(kicked) = reply {
156 Ok(kicked)
157 } else {
158 Err(MqttError::None.into())
159 }
160}
161
162pub(crate) async fn hook_message_dropped(scx: &ServerContext, droppeds: Vec<(To, From, Publish, Reason)>) {
163 for (to, from, publish, reason) in droppeds {
164 scx.extends.hook_mgr().message_dropped(Some(to), from, publish, reason).await;
166 }
167}