1#![deny(unsafe_code)]
2
3use std::convert::From as _;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use bytes::Bytes;
9use futures::channel::{mpsc, oneshot};
10use futures::{SinkExt, StreamExt};
11use serde_json::{self, json};
12
13use rmqtt::{
14 fitter::Fitter,
15 hook::{Handler, HookResult, Parameter, Register, ReturnType, Type},
16 plugin::{PackageInfo, Plugin},
17 register,
18 session::Session,
19 types::DisconnectInfo,
20 types::{ClientId, From, Publish, SessionSubMap, SessionSubs, TimestampMillis},
21 utils::timestamp_millis,
22 Result,
23};
24
25use rmqtt_storage::{init_db, DefaultStorageDB, List, Map, StorageType};
26
27use config::PluginConfig;
28use rmqtt::context::ServerContext;
29use rmqtt::inflight::OutInflightMessage;
30use rmqtt::macros::Plugin;
31use rmqtt::session::SessionState;
32use session::{Basic, StorageSessionManager, StoredSessionInfo, StoredSessionInfos};
33use session::{StoredKey, BASIC, DISCONNECT_INFO, INFLIGHT_MESSAGES, LAST_TIME, SESSION_SUB_MAP};
34
35mod config;
36mod session;
37
38enum RebuildChanType {
39 Session(Session, Duration),
40 Done(oneshot::Sender<()>),
41}
42
43type OfflineMessageOptionType = Option<(ClientId, From, Publish)>;
44
45register!(StoragePlugin::new);
46
47#[derive(Plugin)]
48struct StoragePlugin {
49 scx: ServerContext,
50 cfg: Arc<PluginConfig>,
51 storage_db: DefaultStorageDB,
52 stored_session_infos: StoredSessionInfos,
53 register: Box<dyn Register>,
54 session_mgr: &'static StorageSessionManager,
55 rebuild_tx: mpsc::Sender<RebuildChanType>,
56}
57
58impl StoragePlugin {
59 #[inline]
60 async fn new<S: Into<String>>(scx: ServerContext, name: S) -> Result<Self> {
61 let name = name.into();
62 let mut cfg = scx.plugins.read_config_default::<PluginConfig>(&name)?;
63 match cfg.storage.typ {
64 StorageType::Sled => {
65 cfg.storage.sled.path =
66 cfg.storage.sled.path.replace("{node}", &format!("{}", scx.node.id()));
67 }
68 StorageType::Redis => {
69 cfg.storage.redis.prefix =
70 cfg.storage.redis.prefix.replace("{node}", &format!("{}", scx.node.id()));
71 }
72 StorageType::RedisCluster => {
73 cfg.storage.redis_cluster.prefix =
74 cfg.storage.redis_cluster.prefix.replace("{node}", &format!("{}", scx.node.id()));
75 }
76 }
77
78 log::info!("{} StoragePlugin cfg: {:?}", name, cfg);
79
80 let storage_db = init_db(&cfg.storage).await?;
81
82 let stored_session_infos = StoredSessionInfos::new();
83
84 let register = scx.extends.hook_mgr().register();
85 let session_mgr =
86 StorageSessionManager::get_or_init(storage_db.clone(), stored_session_infos.clone());
87
88 let cfg = Arc::new(cfg);
89 let rebuild_tx = Self::start_local_runtime(scx.clone());
90 Ok(Self { scx, cfg, storage_db, stored_session_infos, register, session_mgr, rebuild_tx })
91 }
92
93 async fn load_offline_session_infos(&mut self) -> Result<()> {
94 log::info!("{:?} load_offline_session_infos ...", self.name());
95 let storage_db = self.storage_db.clone();
96 let mut iter_storage_db = storage_db.clone();
97 let mut map_iter = iter_storage_db.map_iter().await?;
99 while let Some(m) = map_iter.next().await {
100 match m {
101 Ok(m) => {
102 let id_key = StoredKey::from(map_stored_key_to_id_bytes(m.name()).to_vec());
103 log::debug!("map_stored_key: {:?}", id_key);
104 let basic = match m.get::<_, Basic>(BASIC).await {
105 Err(e) => {
106 log::warn!("{:?} load offline session basic info error, {:?}", id_key, e);
107 if let Err(e) = storage_db.map_remove(m.name()).await {
108 log::warn!("{:?} remove offline session info error, {:?}", id_key, e);
109 }
110 continue;
111 }
112 Ok(None) => {
113 log::warn!("{:?} offline session basic info is None", id_key);
114 if let Err(e) = storage_db.map_remove(m.name()).await {
115 log::warn!("{:?} remove offline session info error, {:?}", id_key, e);
116 }
117 continue;
118 }
119 Ok(Some(basic)) => basic,
120 };
121
122 log::debug!("basic: {:?}", basic);
123 log::debug!("map key: {:?}", id_key);
124 let mut s_info = StoredSessionInfo::from(id_key.clone(), basic);
125
126 match m.get::<_, TimestampMillis>(LAST_TIME).await {
127 Ok(Some(last_time)) => {
128 log::debug!("last_time: {:?}", last_time);
129 s_info.set_last_time(last_time);
130 }
131 Ok(None) => {}
132 Err(e) => {
133 log::warn!("{:?} load offline session last time error, {:?}", id_key, e);
134 }
135 }
136
137 match m.get::<_, SessionSubMap>(SESSION_SUB_MAP).await {
138 Ok(Some(subs)) => {
139 log::debug!("subs: {:?}", subs);
140 s_info.set_subs(subs);
141 }
142 Ok(None) => {}
143 Err(e) => {
144 log::warn!("{:?} load offline session subscription info error, {:?}", id_key, e);
145 }
146 }
147
148 match m.get::<_, DisconnectInfo>(DISCONNECT_INFO).await {
149 Ok(Some(disc_info)) => {
150 log::debug!("disc_info: {:?}", disc_info);
151 s_info.set_disconnect_info(disc_info);
152 }
153 Ok(None) => {}
154 Err(e) => {
155 log::warn!("{:?} load offline session disconnect info error, {:?}", id_key, e);
156 }
157 }
158
159 match m.get::<_, Vec<OutInflightMessage>>(INFLIGHT_MESSAGES).await {
160 Ok(Some(inflights)) => {
161 log::debug!("inflights len: {:?}", inflights.len());
162 s_info.inflight_messages = inflights;
163 }
164 Ok(None) => {}
165 Err(e) => {
166 log::warn!("{:?} load offline session inflight messages error, {:?}", id_key, e);
167 }
168 }
169
170 self.stored_session_infos.add(s_info);
171 }
172 Err(e) => {
173 log::warn!("load offline session info error, {:?}", e);
174 }
175 }
176 }
177 drop(map_iter);
178
179 let mut list_iter = iter_storage_db.list_iter().await?;
180 while let Some(l) = list_iter.next().await {
181 match l {
182 Ok(l) => {
183 let id_key = StoredKey::from(list_stored_key_to_id_bytes(l.name()).to_vec());
184 log::debug!("list_stored_key, id_key: {:?}", id_key);
185 match l.all::<OfflineMessageOptionType>().await {
186 Ok(offline_msgs) => {
187 log::debug!("{:?} offline_msgs len: {}", id_key, offline_msgs.len(),);
188 let ok =
189 self.stored_session_infos.set_offline_messages(id_key.clone(), offline_msgs);
190 log::debug!(
191 "{:?} stored_session_infos, set_offline_messages res: {}",
192 id_key,
193 ok
194 );
195 if !ok {
196 if let Err(e) = storage_db.list_remove(l.name()).await {
197 log::warn!("{:?} remove offline messages error, {:?}", id_key, e);
198 }
199 }
200 }
201 Err(e) => {
202 log::warn!("{:?} load offline messages error, {:?}", id_key, e);
203 if let Err(e) = storage_db.list_remove(l.name()).await {
204 log::warn!("{:?} remove offline messages error, {:?}", id_key, e);
205 }
206 }
207 }
208 }
209 Err(e) => {
210 log::warn!("load offline messages error, {:?}", e);
211 }
212 }
213 }
214 drop(list_iter);
215
216 for removed_key in self.stored_session_infos.retain_latests() {
217 storage_db.map_remove(make_map_stored_key(removed_key.as_ref())).await?;
218 storage_db.list_remove(make_list_stored_key(removed_key.as_ref())).await?;
219 }
220 log::info!("stored_session_infos len: {:?}", self.stored_session_infos.len());
221
222 Ok(())
223 }
224
225 fn start_local_runtime(scx: ServerContext) -> mpsc::Sender<RebuildChanType> {
226 let (tx, mut rx) = futures::channel::mpsc::channel::<RebuildChanType>(100_000);
227 std::thread::spawn(move || {
228 let local_rt = tokio::runtime::Builder::new_current_thread()
229 .enable_all()
230 .build()
231 .expect("tokio runtime build failed");
232 let local_set = tokio::task::LocalSet::new();
233
234 local_set.block_on(&local_rt, async {
235 while let Some(msg) = rx.next().await {
236 match msg {
237 RebuildChanType::Session(session, session_expiry_interval) => {
238 match SessionState::offline_restart(session.clone(), session_expiry_interval).await {
239 Err(e) => {
240 log::warn!("Rebuild offline session error, {:?}", e);
241 },
242 Ok(msg_tx) => {
243 let mut session_entry =
244 scx.extends.shared().await.entry(session.id.clone());
245
246 let id = session_entry.id().clone();
247 let task_fut = async move {
248 if let Err(e) = session_entry.set(session, msg_tx).await {
249 log::warn!("{:?} Rebuild offline session error, {:?}", session_entry.id(), e);
250 }
251 };
252 let task_exec = &scx.global_exec;
253 if let Err(e) = task_exec.spawn(task_fut).await {
254 log::warn!("{:?} Rebuild offline session error, {:?}", id, e.to_string());
255 }
256
257 let completed_count = task_exec.completed_count().await;
258 if completed_count > 0 && completed_count % 5000 == 0 {
259 log::info!(
260 "{:?} Rebuild offline session, completed_count: {}, active_count: {}, waiting_count: {}, rate: {:?}",
261 id,
262 task_exec.completed_count().await, task_exec.active_count(), task_exec.waiting_count(), task_exec.rate().await
263 );
264 }
265 }
266 }
267 },
268 RebuildChanType::Done(done_tx) => {
269 let task_exec = &scx.global_exec;
270 let _ = task_exec.flush().await;
271 let _ = done_tx.send(());
272 log::info!(
273 "Rebuild offline session, completed_count: {}, active_count: {}, waiting_count: {}, rate: {:?}",
274 task_exec.completed_count().await, task_exec.active_count(), task_exec.waiting_count(), task_exec.rate().await
275 );
276 }
277 }
278 }
279 });
280 log::info!("Rebuild offline session ends");
281 });
282 tx
283 }
284}
285
286#[async_trait]
287impl Plugin for StoragePlugin {
288 #[inline]
289 async fn init(&mut self) -> Result<()> {
290 log::info!("{} init", self.name());
291 self.register
292 .add(
293 Type::BeforeStartup,
294 Box::new(StorageHandler::new(
295 self.scx.clone(),
296 self.storage_db.clone(),
297 self.cfg.clone(),
298 self.stored_session_infos.clone(),
299 self.rebuild_tx.clone(),
300 )),
301 )
302 .await;
303 self.register
304 .add(
305 Type::OfflineMessage,
306 Box::new(OfflineMessageHandler::new(self.cfg.clone(), self.storage_db.clone())),
307 )
308 .await;
309 self.register
310 .add(
311 Type::OfflineInflightMessages,
312 Box::new(OfflineMessageHandler::new(self.cfg.clone(), self.storage_db.clone())),
313 )
314 .await;
315
316 self.load_offline_session_infos().await?;
317
318 Ok(())
319 }
320
321 #[inline]
322 async fn get_config(&self) -> Result<serde_json::Value> {
323 Ok(self.cfg.to_json())
324 }
325
326 #[inline]
327 async fn start(&mut self) -> Result<()> {
328 log::info!("{} start", self.name());
329 *self.scx.extends.session_mgr_mut().await = Box::new(self.session_mgr);
330
331 self.register.start().await;
332 Ok(())
333 }
334
335 #[inline]
336 async fn stop(&mut self) -> Result<bool> {
337 log::warn!("{} stop, if the storage plugin is started, it cannot be stopped", self.name());
338 Ok(false)
339 }
340
341 #[inline]
342 async fn attrs(&self) -> serde_json::Value {
343 let max_limit = 100;
344 let mut map_count = 0;
345 {
346 let now = std::time::Instant::now();
347 let mut storage_db = self.storage_db.clone();
348 let iter = storage_db.map_iter().await;
349 if let Ok(mut iter) = iter {
350 while let Some(m) = iter.next().await {
351 if let Ok(m) = m {
352 log::debug!("map: {:?}", StoredKey::from(m.name().to_vec()));
353 }
354 map_count += 1;
355 if map_count >= max_limit {
356 break;
357 }
358 }
359 }
360 log::debug!("map_iter cost time: {:?}", now.elapsed());
361 }
362
363 let mut list_count = 0;
364 {
365 let now = std::time::Instant::now();
366 let mut storage_db = self.storage_db.clone();
367 let iter = storage_db.list_iter().await;
368 if let Ok(mut iter) = iter {
369 while let Some(l) = iter.next().await {
370 if let Ok(l) = l {
371 log::debug!("list: {:?}", StoredKey::from(l.name().to_vec()));
372 }
373 list_count += 1;
374 if list_count >= max_limit {
375 break;
376 }
377 }
378 }
379 log::debug!("list_iter cost time: {:?}", now.elapsed());
380 }
381 let map_count =
382 if map_count >= max_limit { format!("{}+", map_count) } else { format!("{}", map_count) };
383 let list_count =
384 if list_count >= max_limit { format!("{}+", list_count) } else { format!("{}", list_count) };
385
386 let storage_info = self.storage_db.info().await.unwrap_or_default();
387
388 json!({
389 "session_count": map_count,
390 "offline_messages_count": list_count,
391 "storage_info": storage_info
392 })
393 }
394}
395
396struct OfflineMessageHandler {
397 cfg: Arc<PluginConfig>,
398 storage_db: DefaultStorageDB,
399}
400
401impl OfflineMessageHandler {
402 fn new(cfg: Arc<PluginConfig>, storage_db: DefaultStorageDB) -> Self {
403 Self { cfg, storage_db }
404 }
405}
406
407#[async_trait]
408impl Handler for OfflineMessageHandler {
409 async fn hook(&self, param: &Parameter, acc: Option<HookResult>) -> ReturnType {
410 match param {
411 Parameter::OfflineMessage(s, f, p) => {
412 log::debug!(
413 "OfflineMessage storage_type: {:?}, from: {:?}, p: {:?}",
414 self.cfg.storage.typ,
415 f,
416 p
417 );
418 let list_stored_key = make_list_stored_key(s.id.to_string());
419 match self.storage_db.list(list_stored_key.as_ref(), None).await {
420 Ok(offlines_list) => {
421 let res = offlines_list
422 .push_limit::<OfflineMessageOptionType>(
423 &Some((s.id.client_id.clone(), f.clone(), (*p).clone())),
424 s.listen_cfg().max_mqueue_len,
425 true,
426 )
427 .await;
428 if let Err(e) = res {
429 log::warn!("{:?} save offline messages error, {:?}", s.id, e)
430 }
431 }
432 Err(e) => {
433 log::warn!("{:?} save offline messages error, {:?}", s.id, e)
434 }
435 }
436 }
437
438 Parameter::OfflineInflightMessages(s, inflight_messages) => {
439 log::debug!(
440 "OfflineInflightMessages storage_type: {:?}, inflight_messages len: {:?}",
441 self.cfg.storage.typ,
442 inflight_messages.len(),
443 );
444 let map_stored_key = make_map_stored_key(s.id.to_string());
445 log::debug!("{:?} map_stored_key: {:?}", s.id, map_stored_key);
446 match self.storage_db.map(map_stored_key.as_ref(), None).await {
447 Ok(m) => {
448 if let Err(e) = m.insert(INFLIGHT_MESSAGES, inflight_messages).await {
449 log::warn!("{:?} save offline inflight messages error, {:?}", s.id, e)
450 }
451 }
452 Err(e) => {
453 log::warn!("{:?} save offline inflight messages error, {:?}", s.id, e)
454 }
455 }
456 }
457
458 _ => {
459 log::error!("unimplemented, {:?}", param)
460 }
461 }
462 (true, acc)
463 }
464}
465
466struct StorageHandler {
467 scx: ServerContext,
468 storage_db: DefaultStorageDB,
469 cfg: Arc<PluginConfig>,
470 stored_session_infos: StoredSessionInfos,
471 rebuild_tx: mpsc::Sender<RebuildChanType>,
472}
473
474impl StorageHandler {
475 fn new(
476 scx: ServerContext,
477 storage_db: DefaultStorageDB,
478 cfg: Arc<PluginConfig>,
479 stored_session_infos: StoredSessionInfos,
480 rebuild_tx: mpsc::Sender<RebuildChanType>,
481 ) -> Self {
482 Self { scx, storage_db, cfg, stored_session_infos, rebuild_tx }
483 }
484
485 async fn rebuild_offline_sessions(&self, rebuild_done_tx: oneshot::Sender<()>) {
487 let mut offline_sessions_count = 0;
488 for mut entry in self.stored_session_infos.iter_mut() {
489 let (_, storeds) = entry.pair_mut();
490 if let Some(stored) = storeds.iter_mut().next() {
491 let id = stored.basic.id.clone();
492
493 let listen_cfg = if let Some(listen_cfg) = id
495 .local_addr
496 .and_then(|addr| self.scx.listen_cfgs.get(&addr.port()).map(|c| c.value().clone()))
497 {
498 listen_cfg
499 } else {
500 log::warn!("tcp listener config is not found, local addr is {:?}", id.local_addr);
501 continue;
502 };
503
504 let fitter = self.scx.extends.fitter_mgr().await.create(
506 stored.basic.conn_info.clone(),
507 id.clone(),
508 listen_cfg.clone(),
509 );
510
511 let session_expiry_interval = session_expiry_interval(
513 fitter.as_ref(),
514 stored.disconnect_info.as_ref(),
515 stored.last_time,
516 )
517 .await;
518 log::debug!("{:?} session_expiry_interval: {:?}", id, session_expiry_interval);
519 if session_expiry_interval <= 0 {
520 log::debug!(
521 "{:?} session is expiry, {:?}, id_key: {:?}, {:?}, {:?}",
522 id,
523 session_expiry_interval,
524 stored.id_key,
525 make_map_stored_key(stored.id_key.as_ref()),
526 make_list_stored_key(stored.id_key.as_ref())
527 );
528 let storage_db = self.storage_db.clone();
529 if let Err(e) = storage_db.map_remove(make_map_stored_key(stored.id_key.as_ref())).await {
530 log::warn!("{:?} remove map error, {:?}", id, e);
531 }
532 if let Err(e) = storage_db.list_remove(make_list_stored_key(stored.id_key.as_ref())).await
533 {
534 log::warn!("{:?} remove list error, {:?}", id, e);
535 }
536 continue;
538 }
539 offline_sessions_count += 1;
540
541 if stored.disconnect_info.is_none() {
542 stored.disconnect_info = Some(DisconnectInfo::new(stored.last_time));
543 }
544
545 let max_inflight = fitter.max_inflight();
546 let max_mqueue_len = fitter.max_mqueue_len();
547 let subs = stored.subs.take().map(SessionSubs::from).unwrap_or_else(SessionSubs::new);
548
549 let session = match Session::new(
550 id.clone(),
551 self.scx.clone(),
552 max_mqueue_len,
553 listen_cfg,
554 fitter,
555 None,
556 max_inflight,
557 stored.basic.created_at,
558 stored.basic.conn_info.clone(),
559 false,
560 false,
561 false,
562 stored.basic.connected_at,
563 subs,
564 stored.disconnect_info.take(),
565 None,
566 )
567 .await
568 {
569 Ok(s) => s,
570 Err(e) => {
571 log::warn!("rebuild session offline message error, create session error, {:?}", e);
572 continue;
573 }
574 };
575
576 let deliver_queue = session.deliver_queue();
577 for item in stored.offline_messages.drain(..) {
578 if let Err((f, p)) = deliver_queue.push(item) {
579 log::warn!("rebuild session offline message error, deliver queue is full, from: {:?}, publish: {:?}", f, p);
580 }
581 }
582
583 let out_inflight = session.out_inflight();
584 for item in stored.inflight_messages.drain(..) {
585 out_inflight.write().await.push_back(item);
586 }
587
588 if let Err(e) = self
589 .rebuild_tx
590 .clone()
591 .send(RebuildChanType::Session(
592 session,
593 Duration::from_millis(session_expiry_interval as u64),
594 ))
595 .await
596 {
597 log::error!("rebuild offline sessions error, {:?}", e);
598 }
599 }
600 }
601 log::info!("offline_sessions_count: {}", offline_sessions_count);
602 let _ = self.rebuild_tx.clone().send(RebuildChanType::Done(rebuild_done_tx)).await;
603 }
604}
605
606#[async_trait]
607impl Handler for StorageHandler {
608 async fn hook(&self, param: &Parameter, acc: Option<HookResult>) -> ReturnType {
609 match param {
610 Parameter::BeforeStartup => {
611 log::info!(
612 "BeforeStartup storage_type: {:?}, stored_session_infos len: {}",
613 self.cfg.storage.typ,
614 self.stored_session_infos.len()
615 );
616 let (rebuild_done_tx, rebuild_done_rx) = oneshot::channel::<()>();
617 self.rebuild_offline_sessions(rebuild_done_tx).await;
618 let _ = rebuild_done_rx.await;
619 }
620 _ => {
621 log::error!("unimplemented, {:?}", param)
622 }
623 }
624 (true, acc)
625 }
626}
627
628#[inline]
629async fn session_expiry_interval(
630 fitter: &dyn Fitter,
631 disconnect_info: Option<&DisconnectInfo>,
632 last_time: TimestampMillis,
633) -> TimestampMillis {
634 let disconnected_at = disconnect_info.map(|d| d.disconnected_at).unwrap_or_default();
635 let disconnected_at = if disconnected_at <= 0 { last_time } else { disconnected_at };
636 fitter.session_expiry_interval(disconnect_info.and_then(|d| d.mqtt_disconnect.as_ref())).as_millis()
637 as i64
638 - (timestamp_millis() - disconnected_at)
639}
640
641#[inline]
642pub(crate) fn make_map_stored_key<T: AsRef<[u8]>>(id: T) -> StoredKey {
643 let mut key = Vec::from("map-");
644 key.extend_from_slice(id.as_ref());
645 Bytes::from(key)
646}
647
/// Recovers the id bytes from a map storage key by stripping a leading `map-`.
/// A key without that prefix is returned unchanged.
#[inline]
pub(crate) fn map_stored_key_to_id_bytes(stored_key: &[u8]) -> &[u8] {
    match stored_key.strip_prefix(b"map-") {
        Some(id) => id,
        None => stored_key,
    }
}
656
657#[inline]
658pub(crate) fn make_list_stored_key<T: AsRef<[u8]>>(id: T) -> StoredKey {
659 let mut key = Vec::from("list-");
660 key.extend_from_slice(id.as_ref());
661 Bytes::from(key)
662}
663
/// Recovers the id bytes from a list storage key by stripping a leading `list-`.
/// A key without that prefix is returned unchanged.
#[inline]
pub(crate) fn list_stored_key_to_id_bytes(stored_key: &[u8]) -> &[u8] {
    stored_key.strip_prefix(b"list-").unwrap_or(stored_key)
}