Skip to main content

rmqtt_bridge_origin/
lib.rs

1#![deny(unsafe_code)]
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7
8use rmqtt::{
9    context::ServerContext,
10    hook::{Handler, HookResult, Parameter, Register, ReturnType, Type},
11    macros::Plugin,
12    plugin::{PackageInfo, Plugin as _PluginTrait},
13    register, Result,
14};
15
16use config::PluginConfig;
17
18mod config;
19
// Hands `BridgeOriginPlugin::new` to rmqtt's `register!` macro; presumably this exports it as
// the crate's plugin constructor — TODO confirm against the macro's expansion in rmqtt.
register!(BridgeOriginPlugin::new);
21
/// Which side of a bridge a connected client was classified as, based on
/// marker substrings found in its client id.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BridgeDirection {
    /// The client id matched the configured ingress marker.
    Ingress,
    /// The client id matched the configured egress marker.
    Egress,
}
27
28#[derive(Debug, Clone, Copy)]
29pub struct BridgeOrigin {
30    pub direction: BridgeDirection,
31}
32
33impl BridgeOrigin {
34    #[inline]
35    pub fn is_ingress(&self) -> bool {
36        matches!(self.direction, BridgeDirection::Ingress)
37    }
38
39    #[inline]
40    pub fn is_egress(&self) -> bool {
41        matches!(self.direction, BridgeDirection::Egress)
42    }
43}
44
45#[derive(Plugin)]
46struct BridgeOriginPlugin {
47    scx: ServerContext,
48    register: Box<dyn Register>,
49    cfg: Arc<RwLock<PluginConfig>>,
50}
51
52impl BridgeOriginPlugin {
53    #[inline]
54    async fn new<N: Into<String>>(scx: ServerContext, name: N) -> Result<Self> {
55        let name: String = name.into();
56        let cfg = scx.plugins.read_config_default::<PluginConfig>(&name);
57        let cfg = Arc::new(RwLock::new(cfg?));
58        let register = scx.extends.hook_mgr().register();
59        log::info!("{name} BridgeOriginPlugin cfg: {cfg:?}", cfg = cfg.read().await);
60        Ok(Self { scx, register, cfg })
61    }
62}
63
64#[async_trait]
65impl _PluginTrait for BridgeOriginPlugin {
66    #[inline]
67    async fn init(&mut self) -> Result<()> {
68        log::info!("{} init", self.name());
69        self.register.add(Type::ClientConnected, Box::new(BridgeOriginHandler::new(self.cfg.clone()))).await;
70        log::info!("{} registered ClientConnected hook", self.name());
71        Ok(())
72    }
73
74    #[inline]
75    async fn get_config(&self) -> Result<serde_json::Value> {
76        self.cfg.read().await.to_json()
77    }
78
79    #[inline]
80    async fn load_config(&mut self) -> Result<()> {
81        let new_cfg = self.scx.plugins.read_config::<PluginConfig>(self.name())?;
82        *self.cfg.write().await = new_cfg;
83        log::info!("{} load_config ok", self.name());
84        Ok(())
85    }
86
87    #[inline]
88    async fn start(&mut self) -> Result<()> {
89        log::info!("{} start", self.name());
90        self.register.start().await;
91        Ok(())
92    }
93
94    #[inline]
95    async fn stop(&mut self) -> Result<bool> {
96        log::info!("{} stop", self.name());
97        self.register.stop().await;
98        Ok(false)
99    }
100}
101
102struct BridgeOriginHandler {
103    cfg: Arc<RwLock<PluginConfig>>,
104}
105
106impl BridgeOriginHandler {
107    fn new(cfg: Arc<RwLock<PluginConfig>>) -> Self {
108        Self { cfg }
109    }
110}
111
112#[async_trait]
113impl Handler for BridgeOriginHandler {
114    async fn hook(&self, param: &Parameter, acc: Option<HookResult>) -> ReturnType {
115        match param {
116            Parameter::ClientConnected(session) => {
117                let client_id = &session.id.client_id;
118                let cfg = self.cfg.read().await;
119
120                let is_ingress = client_id.contains(&cfg.ingress_marker);
121                let is_egress = client_id.contains(&cfg.egress_marker);
122
123                let direction = if is_ingress {
124                    log::debug!("bridge-origin: detected ingress bridge client, client_id={}", client_id,);
125                    Some(BridgeDirection::Ingress)
126                } else if is_egress {
127                    log::debug!("bridge-origin: detected egress bridge client, client_id={}", client_id,);
128                    Some(BridgeDirection::Egress)
129                } else {
130                    None
131                };
132
133                if let Some(direction) = direction {
134                    let extra_attrs = session.extra_attrs.clone();
135                    let key = cfg.attr_key.clone();
136                    extra_attrs.write().await.insert::<BridgeOrigin>(key, BridgeOrigin { direction });
137                }
138            }
139            _ => {
140                log::error!("unimplemented, {param:?}")
141            }
142        }
143        (true, acc)
144    }
145}