rmqtt_bridge_origin/
lib.rs1#![deny(unsafe_code)]
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7
8use rmqtt::{
9 context::ServerContext,
10 hook::{Handler, HookResult, Parameter, Register, ReturnType, Type},
11 macros::Plugin,
12 plugin::{PackageInfo, Plugin as _PluginTrait},
13 register, Result,
14};
15
16use config::PluginConfig;
17
18mod config;
19
// Hands the plugin's async constructor `BridgeOriginPlugin::new` to rmqtt's
// `register!` macro. NOTE(review): the macro presumably generates the entry point
// the broker uses to instantiate this plugin by name — confirm against rmqtt::macros.
register!(BridgeOriginPlugin::new);
21
/// Which side of a bridge a connected client belongs to.
///
/// Decided from substrings of the client id when the client connects.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BridgeDirection {
    /// The client id matched the configured ingress marker.
    Ingress,
    /// The client id matched the configured egress marker.
    Egress,
}

/// Session attribute recording that a client was recognised as a bridge
/// connection, together with the direction that was detected.
#[derive(Debug, Clone, Copy)]
pub struct BridgeOrigin {
    /// Direction detected for this bridge client.
    pub direction: BridgeDirection,
}

impl BridgeOrigin {
    /// Returns `true` when the detected direction is [`BridgeDirection::Ingress`].
    #[inline]
    pub fn is_ingress(&self) -> bool {
        self.direction == BridgeDirection::Ingress
    }

    /// Returns `true` when the detected direction is [`BridgeDirection::Egress`].
    #[inline]
    pub fn is_egress(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces this to be revisited.
        match self.direction {
            BridgeDirection::Egress => true,
            BridgeDirection::Ingress => false,
        }
    }
}
44
45#[derive(Plugin)]
46struct BridgeOriginPlugin {
47 scx: ServerContext,
48 register: Box<dyn Register>,
49 cfg: Arc<RwLock<PluginConfig>>,
50}
51
52impl BridgeOriginPlugin {
53 #[inline]
54 async fn new<N: Into<String>>(scx: ServerContext, name: N) -> Result<Self> {
55 let name: String = name.into();
56 let cfg = scx.plugins.read_config_default::<PluginConfig>(&name);
57 let cfg = Arc::new(RwLock::new(cfg?));
58 let register = scx.extends.hook_mgr().register();
59 log::info!("{name} BridgeOriginPlugin cfg: {cfg:?}", cfg = cfg.read().await);
60 Ok(Self { scx, register, cfg })
61 }
62}
63
64#[async_trait]
65impl _PluginTrait for BridgeOriginPlugin {
66 #[inline]
67 async fn init(&mut self) -> Result<()> {
68 log::info!("{} init", self.name());
69 self.register.add(Type::ClientConnected, Box::new(BridgeOriginHandler::new(self.cfg.clone()))).await;
70 log::info!("{} registered ClientConnected hook", self.name());
71 Ok(())
72 }
73
74 #[inline]
75 async fn get_config(&self) -> Result<serde_json::Value> {
76 self.cfg.read().await.to_json()
77 }
78
79 #[inline]
80 async fn load_config(&mut self) -> Result<()> {
81 let new_cfg = self.scx.plugins.read_config::<PluginConfig>(self.name())?;
82 *self.cfg.write().await = new_cfg;
83 log::info!("{} load_config ok", self.name());
84 Ok(())
85 }
86
87 #[inline]
88 async fn start(&mut self) -> Result<()> {
89 log::info!("{} start", self.name());
90 self.register.start().await;
91 Ok(())
92 }
93
94 #[inline]
95 async fn stop(&mut self) -> Result<bool> {
96 log::info!("{} stop", self.name());
97 self.register.stop().await;
98 Ok(false)
99 }
100}
101
102struct BridgeOriginHandler {
103 cfg: Arc<RwLock<PluginConfig>>,
104}
105
106impl BridgeOriginHandler {
107 fn new(cfg: Arc<RwLock<PluginConfig>>) -> Self {
108 Self { cfg }
109 }
110}
111
112#[async_trait]
113impl Handler for BridgeOriginHandler {
114 async fn hook(&self, param: &Parameter, acc: Option<HookResult>) -> ReturnType {
115 match param {
116 Parameter::ClientConnected(session) => {
117 let client_id = &session.id.client_id;
118 let cfg = self.cfg.read().await;
119
120 let is_ingress = client_id.contains(&cfg.ingress_marker);
121 let is_egress = client_id.contains(&cfg.egress_marker);
122
123 let direction = if is_ingress {
124 log::debug!("bridge-origin: detected ingress bridge client, client_id={}", client_id,);
125 Some(BridgeDirection::Ingress)
126 } else if is_egress {
127 log::debug!("bridge-origin: detected egress bridge client, client_id={}", client_id,);
128 Some(BridgeDirection::Egress)
129 } else {
130 None
131 };
132
133 if let Some(direction) = direction {
134 let extra_attrs = session.extra_attrs.clone();
135 let key = cfg.attr_key.clone();
136 extra_attrs.write().await.insert::<BridgeOrigin>(key, BridgeOrigin { direction });
137 }
138 }
139 _ => {
140 log::error!("unimplemented, {param:?}")
141 }
142 }
143 (true, acc)
144 }
145}