rmqtt_cluster_broadcast/
lib.rs

1#![deny(unsafe_code)]
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use serde_json::{self, json};
8use tokio::sync::RwLock;
9
10use rmqtt::{
11    grpc::{GrpcClients, Message, MessageReply, MessageType},
12    hook::{Register, Type},
13    macros::Plugin,
14    plugin::{PackageInfo, Plugin},
15    register,
16    types::{From, OfflineSession, Publish, Reason, To},
17    Result,
18};
19
20use config::PluginConfig;
21use handler::HookHandler;
22use rmqtt::context::ServerContext;
23use rmqtt::net::MqttError;
24use router::ClusterRouter;
25use shared::ClusterShared;
26
27mod config;
28mod handler;
29mod message;
30mod router;
31mod shared;
32
33type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;
34
35register!(ClusterPlugin::new);
36
37#[derive(Plugin)]
38struct ClusterPlugin {
39    scx: ServerContext,
40    register: Box<dyn Register>,
41    cfg: Arc<RwLock<PluginConfig>>,
42    grpc_clients: GrpcClients,
43    shared: ClusterShared,
44    router: ClusterRouter,
45}
46
47impl ClusterPlugin {
48    #[inline]
49    async fn new<S: Into<String>>(scx: ServerContext, name: S) -> Result<Self> {
50        let name = name.into();
51        let cfg = scx.plugins.read_config_with::<PluginConfig>(&name, &["node_grpc_addrs"])?;
52        log::debug!("{name} ClusterPlugin cfg: {cfg:?}");
53
54        let register = scx.extends.hook_mgr().register();
55        let mut grpc_clients = HashMap::default();
56        let node_grpc_addrs = cfg.node_grpc_addrs.clone();
57        for node_addr in &node_grpc_addrs {
58            if node_addr.id != scx.node.id() {
59                let batch_size = cfg.node_grpc_batch_size;
60                let client_concurrency_limit = cfg.node_grpc_client_concurrency_limit;
61                let client_timeout = cfg.node_grpc_client_timeout;
62                grpc_clients.insert(
63                    node_addr.id,
64                    (
65                        node_addr.addr.clone(),
66                        scx.node
67                            .new_grpc_client(
68                                &node_addr.addr,
69                                client_timeout,
70                                client_concurrency_limit,
71                                batch_size,
72                            )
73                            .await?,
74                    ),
75                );
76            }
77        }
78        let grpc_clients = Arc::new(grpc_clients);
79        let message_type = cfg.message_type;
80        let router = ClusterRouter::new(scx.clone(), grpc_clients.clone(), message_type);
81        let shared = ClusterShared::new(scx.clone(), grpc_clients.clone(), message_type);
82        let cfg = Arc::new(RwLock::new(cfg));
83        Ok(Self { scx, register, cfg, grpc_clients, shared, router })
84    }
85}
86
87#[async_trait]
88impl Plugin for ClusterPlugin {
89    #[inline]
90    async fn init(&mut self) -> Result<()> {
91        log::info!("{} init", self.name());
92        self.register
93            .add(
94                Type::GrpcMessageReceived,
95                Box::new(HookHandler::new(self.scx.clone(), self.shared.clone(), self.router.clone())),
96            )
97            .await;
98        Ok(())
99    }
100
101    #[inline]
102    async fn get_config(&self) -> Result<serde_json::Value> {
103        self.cfg.read().await.to_json()
104    }
105
106    #[inline]
107    async fn start(&mut self) -> Result<()> {
108        log::info!("{} start", self.name());
109        self.register.start().await;
110        *self.scx.extends.shared_mut().await = Box::new(self.shared.clone());
111        *self.scx.extends.router_mut().await = Box::new(self.router.clone());
112        Ok(())
113    }
114
115    #[inline]
116    async fn stop(&mut self) -> Result<bool> {
117        log::warn!("{} stop, once the cluster is started, it cannot be stopped", self.name());
118        Ok(false)
119    }
120
121    #[inline]
122    async fn attrs(&self) -> serde_json::Value {
123        let mut nodes = HashMap::default();
124        for (id, (addr, c)) in self.grpc_clients.iter() {
125            let stats = json!({
126                "transfer_queue_len": c.transfer_queue_len(),
127                "active_tasks_count": c.active_tasks().count(),
128                "active_tasks_max": c.active_tasks().max(),
129            });
130            nodes.insert(format!("{id}-{addr}"), stats);
131        }
132        json!({
133            "grpc_clients": nodes,
134        })
135    }
136}
137
138#[inline]
139pub(crate) async fn kick(
140    grpc_clients: GrpcClients,
141    msg_type: MessageType,
142    msg: Message,
143) -> Result<OfflineSession> {
144    let reply =
145        rmqtt::grpc::MessageBroadcaster::new(grpc_clients, msg_type, msg, Some(Duration::from_secs(15)))
146            .select_ok(|reply: MessageReply| -> Result<MessageReply> {
147                log::debug!("reply: {reply:?}");
148                if let MessageReply::Kick(o) = reply {
149                    Ok(MessageReply::Kick(o))
150                } else {
151                    Err(MqttError::None.into())
152                }
153            })
154            .await?;
155    if let MessageReply::Kick(kicked) = reply {
156        Ok(kicked)
157    } else {
158        Err(MqttError::None.into())
159    }
160}
161
162pub(crate) async fn hook_message_dropped(scx: &ServerContext, droppeds: Vec<(To, From, Publish, Reason)>) {
163    for (to, from, publish, reason) in droppeds {
164        //hook, message_dropped
165        scx.extends.hook_mgr().message_dropped(Some(to), from, publish, reason).await;
166    }
167}