rmqtt_session_storage/
lib.rs

1#![deny(unsafe_code)]
2
3use std::convert::From as _;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use bytes::Bytes;
9use futures::channel::{mpsc, oneshot};
10use futures::{SinkExt, StreamExt};
11use serde_json::{self, json};
12
13use rmqtt::{
14    fitter::Fitter,
15    hook::{Handler, HookResult, Parameter, Register, ReturnType, Type},
16    plugin::{PackageInfo, Plugin},
17    register,
18    session::Session,
19    types::DisconnectInfo,
20    types::{ClientId, From, Publish, SessionSubMap, SessionSubs, TimestampMillis},
21    utils::timestamp_millis,
22    Result,
23};
24
25use rmqtt_storage::{init_db, DefaultStorageDB, List, Map, StorageType};
26
27use config::PluginConfig;
28use rmqtt::context::ServerContext;
29use rmqtt::inflight::OutInflightMessage;
30use rmqtt::macros::Plugin;
31use rmqtt::session::SessionState;
32use session::{Basic, StorageSessionManager, StoredSessionInfo, StoredSessionInfos};
33use session::{StoredKey, BASIC, DISCONNECT_INFO, INFLIGHT_MESSAGES, LAST_TIME, SESSION_SUB_MAP};
34
35mod config;
36mod session;
37
38enum RebuildChanType {
39    Session(Session, Duration),
40    Done(oneshot::Sender<()>),
41}
42
43type OfflineMessageOptionType = Option<(ClientId, From, Publish)>;
44
45register!(StoragePlugin::new);
46
47#[derive(Plugin)]
48struct StoragePlugin {
49    scx: ServerContext,
50    cfg: Arc<PluginConfig>,
51    storage_db: DefaultStorageDB,
52    stored_session_infos: StoredSessionInfos,
53    register: Box<dyn Register>,
54    session_mgr: &'static StorageSessionManager,
55    rebuild_tx: mpsc::Sender<RebuildChanType>,
56}
57
58impl StoragePlugin {
59    #[inline]
60    async fn new<S: Into<String>>(scx: ServerContext, name: S) -> Result<Self> {
61        let name = name.into();
62        let mut cfg = scx.plugins.read_config_default::<PluginConfig>(&name)?;
63        match cfg.storage.typ {
64            StorageType::Sled => {
65                cfg.storage.sled.path =
66                    cfg.storage.sled.path.replace("{node}", &format!("{}", scx.node.id()));
67            }
68            StorageType::Redis => {
69                cfg.storage.redis.prefix =
70                    cfg.storage.redis.prefix.replace("{node}", &format!("{}", scx.node.id()));
71            }
72            StorageType::RedisCluster => {
73                cfg.storage.redis_cluster.prefix =
74                    cfg.storage.redis_cluster.prefix.replace("{node}", &format!("{}", scx.node.id()));
75            }
76        }
77
78        log::info!("{} StoragePlugin cfg: {:?}", name, cfg);
79
80        let storage_db = init_db(&cfg.storage).await?;
81
82        let stored_session_infos = StoredSessionInfos::new();
83
84        let register = scx.extends.hook_mgr().register();
85        let session_mgr =
86            StorageSessionManager::get_or_init(storage_db.clone(), stored_session_infos.clone());
87
88        let cfg = Arc::new(cfg);
89        let rebuild_tx = Self::start_local_runtime(scx.clone());
90        Ok(Self { scx, cfg, storage_db, stored_session_infos, register, session_mgr, rebuild_tx })
91    }
92
93    async fn load_offline_session_infos(&mut self) -> Result<()> {
94        log::info!("{:?} load_offline_session_infos ...", self.name());
95        let storage_db = self.storage_db.clone();
96        let mut iter_storage_db = storage_db.clone();
97        //Load offline session information from the database
98        let mut map_iter = iter_storage_db.map_iter().await?;
99        while let Some(m) = map_iter.next().await {
100            match m {
101                Ok(m) => {
102                    let id_key = StoredKey::from(map_stored_key_to_id_bytes(m.name()).to_vec());
103                    log::debug!("map_stored_key: {:?}", id_key);
104                    let basic = match m.get::<_, Basic>(BASIC).await {
105                        Err(e) => {
106                            log::warn!("{:?} load offline session basic info error, {:?}", id_key, e);
107                            if let Err(e) = storage_db.map_remove(m.name()).await {
108                                log::warn!("{:?} remove offline session info error, {:?}", id_key, e);
109                            }
110                            continue;
111                        }
112                        Ok(None) => {
113                            log::warn!("{:?} offline session basic info is None", id_key);
114                            if let Err(e) = storage_db.map_remove(m.name()).await {
115                                log::warn!("{:?} remove offline session info error, {:?}", id_key, e);
116                            }
117                            continue;
118                        }
119                        Ok(Some(basic)) => basic,
120                    };
121
122                    log::debug!("basic: {:?}", basic);
123                    log::debug!("map key: {:?}", id_key);
124                    let mut s_info = StoredSessionInfo::from(id_key.clone(), basic);
125
126                    match m.get::<_, TimestampMillis>(LAST_TIME).await {
127                        Ok(Some(last_time)) => {
128                            log::debug!("last_time: {:?}", last_time);
129                            s_info.set_last_time(last_time);
130                        }
131                        Ok(None) => {}
132                        Err(e) => {
133                            log::warn!("{:?} load offline session last time error, {:?}", id_key, e);
134                        }
135                    }
136
137                    match m.get::<_, SessionSubMap>(SESSION_SUB_MAP).await {
138                        Ok(Some(subs)) => {
139                            log::debug!("subs: {:?}", subs);
140                            s_info.set_subs(subs);
141                        }
142                        Ok(None) => {}
143                        Err(e) => {
144                            log::warn!("{:?} load offline session subscription info error, {:?}", id_key, e);
145                        }
146                    }
147
148                    match m.get::<_, DisconnectInfo>(DISCONNECT_INFO).await {
149                        Ok(Some(disc_info)) => {
150                            log::debug!("disc_info: {:?}", disc_info);
151                            s_info.set_disconnect_info(disc_info);
152                        }
153                        Ok(None) => {}
154                        Err(e) => {
155                            log::warn!("{:?} load offline session disconnect info error, {:?}", id_key, e);
156                        }
157                    }
158
159                    match m.get::<_, Vec<OutInflightMessage>>(INFLIGHT_MESSAGES).await {
160                        Ok(Some(inflights)) => {
161                            log::debug!("inflights len: {:?}", inflights.len());
162                            s_info.inflight_messages = inflights;
163                        }
164                        Ok(None) => {}
165                        Err(e) => {
166                            log::warn!("{:?} load offline session inflight messages error, {:?}", id_key, e);
167                        }
168                    }
169
170                    self.stored_session_infos.add(s_info);
171                }
172                Err(e) => {
173                    log::warn!("load offline session info error, {:?}", e);
174                }
175            }
176        }
177        drop(map_iter);
178
179        let mut list_iter = iter_storage_db.list_iter().await?;
180        while let Some(l) = list_iter.next().await {
181            match l {
182                Ok(l) => {
183                    let id_key = StoredKey::from(list_stored_key_to_id_bytes(l.name()).to_vec());
184                    log::debug!("list_stored_key, id_key: {:?}", id_key);
185                    match l.all::<OfflineMessageOptionType>().await {
186                        Ok(offline_msgs) => {
187                            log::debug!("{:?} offline_msgs len: {}", id_key, offline_msgs.len(),);
188                            let ok =
189                                self.stored_session_infos.set_offline_messages(id_key.clone(), offline_msgs);
190                            log::debug!(
191                                "{:?} stored_session_infos, set_offline_messages res: {}",
192                                id_key,
193                                ok
194                            );
195                            if !ok {
196                                if let Err(e) = storage_db.list_remove(l.name()).await {
197                                    log::warn!("{:?} remove offline messages error, {:?}", id_key, e);
198                                }
199                            }
200                        }
201                        Err(e) => {
202                            log::warn!("{:?} load offline messages error, {:?}", id_key, e);
203                            if let Err(e) = storage_db.list_remove(l.name()).await {
204                                log::warn!("{:?} remove offline messages error, {:?}", id_key, e);
205                            }
206                        }
207                    }
208                }
209                Err(e) => {
210                    log::warn!("load offline messages error, {:?}", e);
211                }
212            }
213        }
214        drop(list_iter);
215
216        for removed_key in self.stored_session_infos.retain_latests() {
217            storage_db.map_remove(make_map_stored_key(removed_key.as_ref())).await?;
218            storage_db.list_remove(make_list_stored_key(removed_key.as_ref())).await?;
219        }
220        log::info!("stored_session_infos len: {:?}", self.stored_session_infos.len());
221
222        Ok(())
223    }
224
225    fn start_local_runtime(scx: ServerContext) -> mpsc::Sender<RebuildChanType> {
226        let (tx, mut rx) = futures::channel::mpsc::channel::<RebuildChanType>(100_000);
227        std::thread::spawn(move || {
228            let local_rt = tokio::runtime::Builder::new_current_thread()
229                .enable_all()
230                .build()
231                .expect("tokio runtime build failed");
232            let local_set = tokio::task::LocalSet::new();
233
234            local_set.block_on(&local_rt, async {
235                while let Some(msg) = rx.next().await {
236                    match msg {
237                        RebuildChanType::Session(session, session_expiry_interval)  => {
238                            match SessionState::offline_restart(session.clone(), session_expiry_interval).await {
239                                Err(e) => {
240                                    log::warn!("Rebuild offline session error, {:?}", e);
241                                },
242                                Ok(msg_tx) => {
243                                    let mut session_entry =
244                                        scx.extends.shared().await.entry(session.id.clone());
245
246                                    let id = session_entry.id().clone();
247                                    let task_fut = async move {
248                                        if let Err(e) = session_entry.set(session, msg_tx).await {
249                                            log::warn!("{:?} Rebuild offline session error, {:?}", session_entry.id(), e);
250                                        }
251                                    };
252                                    let task_exec = &scx.global_exec;
253                                    if let Err(e) = task_exec.spawn(task_fut).await {
254                                        log::warn!("{:?} Rebuild offline session error, {:?}", id, e.to_string());
255                                    }
256
257                                    let completed_count = task_exec.completed_count().await;
258                                    if completed_count > 0 && completed_count % 5000 == 0 {
259                                        log::info!(
260                                        "{:?} Rebuild offline session, completed_count: {}, active_count: {}, waiting_count: {}, rate: {:?}",
261                                        id,
262                                        task_exec.completed_count().await, task_exec.active_count(), task_exec.waiting_count(), task_exec.rate().await
263                                    );
264                                    }
265                                }
266                            }
267                        },
268                        RebuildChanType::Done(done_tx) => {
269                            let task_exec = &scx.global_exec;
270                            let _ = task_exec.flush().await;
271                            let _ = done_tx.send(());
272                            log::info!(
273                                "Rebuild offline session, completed_count: {}, active_count: {}, waiting_count: {}, rate: {:?}",
274                                task_exec.completed_count().await, task_exec.active_count(), task_exec.waiting_count(), task_exec.rate().await
275                            );
276                        }
277                    }
278                }
279            });
280            log::info!("Rebuild offline session ends");
281        });
282        tx
283    }
284}
285
286#[async_trait]
287impl Plugin for StoragePlugin {
288    #[inline]
289    async fn init(&mut self) -> Result<()> {
290        log::info!("{} init", self.name());
291        self.register
292            .add(
293                Type::BeforeStartup,
294                Box::new(StorageHandler::new(
295                    self.scx.clone(),
296                    self.storage_db.clone(),
297                    self.cfg.clone(),
298                    self.stored_session_infos.clone(),
299                    self.rebuild_tx.clone(),
300                )),
301            )
302            .await;
303        self.register
304            .add(
305                Type::OfflineMessage,
306                Box::new(OfflineMessageHandler::new(self.cfg.clone(), self.storage_db.clone())),
307            )
308            .await;
309        self.register
310            .add(
311                Type::OfflineInflightMessages,
312                Box::new(OfflineMessageHandler::new(self.cfg.clone(), self.storage_db.clone())),
313            )
314            .await;
315
316        self.load_offline_session_infos().await?;
317
318        Ok(())
319    }
320
321    #[inline]
322    async fn get_config(&self) -> Result<serde_json::Value> {
323        Ok(self.cfg.to_json())
324    }
325
326    #[inline]
327    async fn start(&mut self) -> Result<()> {
328        log::info!("{} start", self.name());
329        *self.scx.extends.session_mgr_mut().await = Box::new(self.session_mgr);
330
331        self.register.start().await;
332        Ok(())
333    }
334
335    #[inline]
336    async fn stop(&mut self) -> Result<bool> {
337        log::warn!("{} stop, if the storage plugin is started, it cannot be stopped", self.name());
338        Ok(false)
339    }
340
341    #[inline]
342    async fn attrs(&self) -> serde_json::Value {
343        let max_limit = 100;
344        let mut map_count = 0;
345        {
346            let now = std::time::Instant::now();
347            let mut storage_db = self.storage_db.clone();
348            let iter = storage_db.map_iter().await;
349            if let Ok(mut iter) = iter {
350                while let Some(m) = iter.next().await {
351                    if let Ok(m) = m {
352                        log::debug!("map: {:?}", StoredKey::from(m.name().to_vec()));
353                    }
354                    map_count += 1;
355                    if map_count >= max_limit {
356                        break;
357                    }
358                }
359            }
360            log::debug!("map_iter cost time: {:?}", now.elapsed());
361        }
362
363        let mut list_count = 0;
364        {
365            let now = std::time::Instant::now();
366            let mut storage_db = self.storage_db.clone();
367            let iter = storage_db.list_iter().await;
368            if let Ok(mut iter) = iter {
369                while let Some(l) = iter.next().await {
370                    if let Ok(l) = l {
371                        log::debug!("list: {:?}", StoredKey::from(l.name().to_vec()));
372                    }
373                    list_count += 1;
374                    if list_count >= max_limit {
375                        break;
376                    }
377                }
378            }
379            log::debug!("list_iter cost time: {:?}", now.elapsed());
380        }
381        let map_count =
382            if map_count >= max_limit { format!("{}+", map_count) } else { format!("{}", map_count) };
383        let list_count =
384            if list_count >= max_limit { format!("{}+", list_count) } else { format!("{}", list_count) };
385
386        let storage_info = self.storage_db.info().await.unwrap_or_default();
387
388        json!({
389            "session_count": map_count,
390            "offline_messages_count": list_count,
391            "storage_info": storage_info
392        })
393    }
394}
395
396struct OfflineMessageHandler {
397    cfg: Arc<PluginConfig>,
398    storage_db: DefaultStorageDB,
399}
400
401impl OfflineMessageHandler {
402    fn new(cfg: Arc<PluginConfig>, storage_db: DefaultStorageDB) -> Self {
403        Self { cfg, storage_db }
404    }
405}
406
407#[async_trait]
408impl Handler for OfflineMessageHandler {
409    async fn hook(&self, param: &Parameter, acc: Option<HookResult>) -> ReturnType {
410        match param {
411            Parameter::OfflineMessage(s, f, p) => {
412                log::debug!(
413                    "OfflineMessage storage_type: {:?}, from: {:?}, p: {:?}",
414                    self.cfg.storage.typ,
415                    f,
416                    p
417                );
418                let list_stored_key = make_list_stored_key(s.id.to_string());
419                match self.storage_db.list(list_stored_key.as_ref(), None).await {
420                    Ok(offlines_list) => {
421                        let res = offlines_list
422                            .push_limit::<OfflineMessageOptionType>(
423                                &Some((s.id.client_id.clone(), f.clone(), (*p).clone())),
424                                s.listen_cfg().max_mqueue_len,
425                                true,
426                            )
427                            .await;
428                        if let Err(e) = res {
429                            log::warn!("{:?} save offline messages error, {:?}", s.id, e)
430                        }
431                    }
432                    Err(e) => {
433                        log::warn!("{:?} save offline messages error, {:?}", s.id, e)
434                    }
435                }
436            }
437
438            Parameter::OfflineInflightMessages(s, inflight_messages) => {
439                log::debug!(
440                    "OfflineInflightMessages storage_type: {:?}, inflight_messages len: {:?}",
441                    self.cfg.storage.typ,
442                    inflight_messages.len(),
443                );
444                let map_stored_key = make_map_stored_key(s.id.to_string());
445                log::debug!("{:?} map_stored_key: {:?}", s.id, map_stored_key);
446                match self.storage_db.map(map_stored_key.as_ref(), None).await {
447                    Ok(m) => {
448                        if let Err(e) = m.insert(INFLIGHT_MESSAGES, inflight_messages).await {
449                            log::warn!("{:?} save offline inflight messages error, {:?}", s.id, e)
450                        }
451                    }
452                    Err(e) => {
453                        log::warn!("{:?} save offline inflight messages error, {:?}", s.id, e)
454                    }
455                }
456            }
457
458            _ => {
459                log::error!("unimplemented, {:?}", param)
460            }
461        }
462        (true, acc)
463    }
464}
465
466struct StorageHandler {
467    scx: ServerContext,
468    storage_db: DefaultStorageDB,
469    cfg: Arc<PluginConfig>,
470    stored_session_infos: StoredSessionInfos,
471    rebuild_tx: mpsc::Sender<RebuildChanType>,
472}
473
474impl StorageHandler {
475    fn new(
476        scx: ServerContext,
477        storage_db: DefaultStorageDB,
478        cfg: Arc<PluginConfig>,
479        stored_session_infos: StoredSessionInfos,
480        rebuild_tx: mpsc::Sender<RebuildChanType>,
481    ) -> Self {
482        Self { scx, storage_db, cfg, stored_session_infos, rebuild_tx }
483    }
484
485    //Rebuild offline session.
486    async fn rebuild_offline_sessions(&self, rebuild_done_tx: oneshot::Sender<()>) {
487        let mut offline_sessions_count = 0;
488        for mut entry in self.stored_session_infos.iter_mut() {
489            let (_, storeds) = entry.pair_mut();
490            if let Some(stored) = storeds.iter_mut().next() {
491                let id = stored.basic.id.clone();
492
493                //get listener config
494                let listen_cfg = if let Some(listen_cfg) = id
495                    .local_addr
496                    .and_then(|addr| self.scx.listen_cfgs.get(&addr.port()).map(|c| c.value().clone()))
497                {
498                    listen_cfg
499                } else {
500                    log::warn!("tcp listener config is not found, local addr is {:?}", id.local_addr);
501                    continue;
502                };
503
504                //create fitter
505                let fitter = self.scx.extends.fitter_mgr().await.create(
506                    stored.basic.conn_info.clone(),
507                    id.clone(),
508                    listen_cfg.clone(),
509                );
510
511                //check session expiry interval
512                let session_expiry_interval = session_expiry_interval(
513                    fitter.as_ref(),
514                    stored.disconnect_info.as_ref(),
515                    stored.last_time,
516                )
517                .await;
518                log::debug!("{:?} session_expiry_interval: {:?}", id, session_expiry_interval);
519                if session_expiry_interval <= 0 {
520                    log::debug!(
521                        "{:?} session is expiry, {:?}, id_key: {:?}, {:?}, {:?}",
522                        id,
523                        session_expiry_interval,
524                        stored.id_key,
525                        make_map_stored_key(stored.id_key.as_ref()),
526                        make_list_stored_key(stored.id_key.as_ref())
527                    );
528                    let storage_db = self.storage_db.clone();
529                    if let Err(e) = storage_db.map_remove(make_map_stored_key(stored.id_key.as_ref())).await {
530                        log::warn!("{:?} remove map error, {:?}", id, e);
531                    }
532                    if let Err(e) = storage_db.list_remove(make_list_stored_key(stored.id_key.as_ref())).await
533                    {
534                        log::warn!("{:?} remove list error, {:?}", id, e);
535                    }
536                    //session is expiry
537                    continue;
538                }
539                offline_sessions_count += 1;
540
541                if stored.disconnect_info.is_none() {
542                    stored.disconnect_info = Some(DisconnectInfo::new(stored.last_time));
543                }
544
545                let max_inflight = fitter.max_inflight();
546                let max_mqueue_len = fitter.max_mqueue_len();
547                let subs = stored.subs.take().map(SessionSubs::from).unwrap_or_else(SessionSubs::new);
548
549                let session = match Session::new(
550                    id.clone(),
551                    self.scx.clone(),
552                    max_mqueue_len,
553                    listen_cfg,
554                    fitter,
555                    None,
556                    max_inflight,
557                    stored.basic.created_at,
558                    stored.basic.conn_info.clone(),
559                    false,
560                    false,
561                    false,
562                    stored.basic.connected_at,
563                    subs,
564                    stored.disconnect_info.take(),
565                    None,
566                )
567                .await
568                {
569                    Ok(s) => s,
570                    Err(e) => {
571                        log::warn!("rebuild session offline message error, create session error, {:?}", e);
572                        continue;
573                    }
574                };
575
576                let deliver_queue = session.deliver_queue();
577                for item in stored.offline_messages.drain(..) {
578                    if let Err((f, p)) = deliver_queue.push(item) {
579                        log::warn!("rebuild session offline message error, deliver queue is full, from: {:?}, publish: {:?}", f, p);
580                    }
581                }
582
583                let out_inflight = session.out_inflight();
584                for item in stored.inflight_messages.drain(..) {
585                    out_inflight.write().await.push_back(item);
586                }
587
588                if let Err(e) = self
589                    .rebuild_tx
590                    .clone()
591                    .send(RebuildChanType::Session(
592                        session,
593                        Duration::from_millis(session_expiry_interval as u64),
594                    ))
595                    .await
596                {
597                    log::error!("rebuild offline sessions error, {:?}", e);
598                }
599            }
600        }
601        log::info!("offline_sessions_count: {}", offline_sessions_count);
602        let _ = self.rebuild_tx.clone().send(RebuildChanType::Done(rebuild_done_tx)).await;
603    }
604}
605
606#[async_trait]
607impl Handler for StorageHandler {
608    async fn hook(&self, param: &Parameter, acc: Option<HookResult>) -> ReturnType {
609        match param {
610            Parameter::BeforeStartup => {
611                log::info!(
612                    "BeforeStartup storage_type: {:?}, stored_session_infos len: {}",
613                    self.cfg.storage.typ,
614                    self.stored_session_infos.len()
615                );
616                let (rebuild_done_tx, rebuild_done_rx) = oneshot::channel::<()>();
617                self.rebuild_offline_sessions(rebuild_done_tx).await;
618                let _ = rebuild_done_rx.await;
619            }
620            _ => {
621                log::error!("unimplemented, {:?}", param)
622            }
623        }
624        (true, acc)
625    }
626}
627
628#[inline]
629async fn session_expiry_interval(
630    fitter: &dyn Fitter,
631    disconnect_info: Option<&DisconnectInfo>,
632    last_time: TimestampMillis,
633) -> TimestampMillis {
634    let disconnected_at = disconnect_info.map(|d| d.disconnected_at).unwrap_or_default();
635    let disconnected_at = if disconnected_at <= 0 { last_time } else { disconnected_at };
636    fitter.session_expiry_interval(disconnect_info.and_then(|d| d.mqtt_disconnect.as_ref())).as_millis()
637        as i64
638        - (timestamp_millis() - disconnected_at)
639}
640
641#[inline]
642pub(crate) fn make_map_stored_key<T: AsRef<[u8]>>(id: T) -> StoredKey {
643    let mut key = Vec::from("map-");
644    key.extend_from_slice(id.as_ref());
645    Bytes::from(key)
646}
647
/// Recovers the id bytes from a stored map key by removing its `map-` prefix.
///
/// Keys without the prefix are returned unchanged.
#[inline]
pub(crate) fn map_stored_key_to_id_bytes(stored_key: &[u8]) -> &[u8] {
    stored_key.strip_prefix(b"map-").unwrap_or(stored_key)
}
656
657#[inline]
658pub(crate) fn make_list_stored_key<T: AsRef<[u8]>>(id: T) -> StoredKey {
659    let mut key = Vec::from("list-");
660    key.extend_from_slice(id.as_ref());
661    Bytes::from(key)
662}
663
/// Recovers the id bytes from a stored list key by removing its `list-` prefix.
///
/// Keys without the prefix are returned unchanged.
#[inline]
pub(crate) fn list_stored_key_to_id_bytes(stored_key: &[u8]) -> &[u8] {
    stored_key.strip_prefix(b"list-").unwrap_or(stored_key)
}