1#![deny(unsafe_code)]
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::anyhow;
7use async_trait::async_trait;
8use bytestring::ByteString;
9use reqwest::{
10 header::{HeaderMap, CONTENT_TYPE},
11 Method, Response, Url,
12};
13use serde::ser::Serialize;
14use tokio::sync::RwLock;
15
16use rmqtt::{
17 acl::{AuthInfo, Rule},
18 codec::v5::SubscribeAckReason,
19 context::ServerContext,
20 hook::{Handler, HookResult, Parameter, Register, ReturnType, Type},
21 macros::Plugin,
22 plugin::{PackageInfo, Plugin},
23 register,
24 types::{
25 AuthResult, ConnectInfo, DashMap, Disconnect, Id, Message, Password, PublishAclResult, Reason,
26 SubscribeAclResult, Superuser, TimestampMillis, TopicName,
27 },
28 utils::timestamp_millis,
29 Error, Result,
30};
31
32use config::PluginConfig;
33
34mod config;
35
/// `std::collections::HashMap` parameterised with `ahash::RandomState` as its hasher.
type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;

/// Response header whose value is parsed as an `i64` ACL cache timeout.
const CACHEABLE: &str = "X-Cache";
/// Response header whose mere presence marks the client as a superuser.
const SUPERUSER: &str = "X-Superuser";
/// Decoded outcome of one HTTP auth/ACL request.
#[derive(Clone, Debug)]
struct ResponseResult {
    /// Verdict taken from the response body (or `Ignore` for non-2xx responses).
    permission: Permission,
    /// Superuser flag: the JSON `superuser` field when present, otherwise
    /// whether the `X-Superuser` header exists.
    superuser: Superuser,
    /// Parsed `X-Cache` header value, if any.
    cacheable: Cacheable,
    /// JSON `expire_at` field, interpreted as a number of seconds.
    expire_at: Option<Duration>,
    /// Raw JSON `acl` field removed from the response object, if any.
    acl_data: Option<serde_json::Value>,
}
50
51impl ResponseResult {
52 #[inline]
53 fn new(permission: Permission, superuser: Superuser, cacheable: Cacheable) -> ResponseResult {
54 ResponseResult { permission, superuser, cacheable, expire_at: None, acl_data: None }
55 }
56}
57
/// Verdict returned by the HTTP auth/ACL endpoint.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Permission {
    /// Access granted; carries the superuser flag that accompanied the verdict.
    Allow(Superuser),
    /// Access refused.
    Deny,
    /// No decision from this plugin; the hook then returns `(true, None)`.
    Ignore,
}
64
65impl TryFrom<(&str, Superuser)> for Permission {
66 type Error = Error;
67
68 #[inline]
69 fn try_from((s, superuser): (&str, Superuser)) -> std::result::Result<Self, Self::Error> {
70 match s {
71 "allow" => Ok(Permission::Allow(superuser)),
72 "deny" => Ok(Permission::Deny),
73 "ignore" => Ok(Permission::Ignore),
74 _ => Err(anyhow!(
75 "The authentication result is incorrect; only 'allow,' 'deny,' or 'ignore' are permitted.",
76 )),
77 }
78 }
79}
80
81impl Permission {
82 #[inline]
83 fn from(s: &str, superuser: Superuser) -> Self {
84 match s {
85 "allow" => Permission::Allow(superuser),
86 "deny" => Permission::Deny,
87 "ignore" => Permission::Ignore,
88 _ => Permission::Allow(superuser),
89 }
90 }
91}
92
/// Cache timeout parsed from the `X-Cache` response header; `None` means the
/// verdict is not cached. The publish ACL hook treats a negative value as
/// "never expires" and otherwise adds the value to `timestamp_millis()`.
type Cacheable = Option<i64>;
94
/// Kind of operation an ACL request is about; substituted for `%A` in request
/// parameters via [`ACLType::as_str`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy)]
enum ACLType {
    /// Subscribe check.
    Sub = 1,
    /// Publish check.
    Pub = 2,
}
100
101impl ACLType {
102 fn as_str(&self) -> &str {
103 match self {
104 Self::Sub => "1",
105 Self::Pub => "2",
106 }
107 }
108}
109
// NOTE(review): presumably registers `AuthHttpPlugin::new` as this plugin's
// constructor with the host — confirm against the `rmqtt::register!` macro.
register!(AuthHttpPlugin::new);

/// Shared publish-ACL cache: client id -> (topic -> (verdict, expiry timestamp)).
/// The expiry is compared against `timestamp_millis()`; a negative expiry is
/// treated as "never expires" by the publish hook.
type Caches = Arc<DashMap<Id, std::collections::BTreeMap<TopicName, (Permission, TimestampMillis)>>>;
113
/// Plugin that delegates client authentication and ACL checks to HTTP endpoints.
#[derive(Plugin)]
struct AuthHttpPlugin {
    /// Server context; used to read the plugin configuration and to reach the hook manager.
    scx: ServerContext,
    /// HTTP client cloned into every registered handler.
    httpc: reqwest::Client,
    /// Hook registration handle on which the handlers are added, started and stopped.
    register: Box<dyn Register>,
    /// Live configuration, shared with the handlers and replaced by `load_config`.
    cfg: Arc<RwLock<PluginConfig>>,
    /// Publish-ACL verdict cache shared with the handlers.
    caches: Caches,
}
122
123impl AuthHttpPlugin {
124 #[inline]
125 async fn new<S: Into<String>>(scx: ServerContext, name: S) -> Result<Self> {
126 let name = name.into();
127 let cfg = Arc::new(RwLock::new(scx.plugins.read_config::<PluginConfig>(&name)?));
128 log::debug!("{} AuthHttpPlugin cfg: {:?}", name, cfg.read().await);
129 let register = scx.extends.hook_mgr().register();
130 let caches = Arc::new(DashMap::default());
131 let httpc = new_reqwest_client()?;
132 Ok(Self { scx, httpc, register, cfg, caches })
133 }
134}
135
136#[async_trait]
137impl Plugin for AuthHttpPlugin {
138 #[inline]
139 async fn init(&mut self) -> Result<()> {
140 log::info!("{} init", self.name());
141 let cfg = &self.cfg;
142
143 let priority = cfg.read().await.priority;
144 self.register
145 .add_priority(
146 Type::ClientAuthenticate,
147 priority,
148 Box::new(AuthHandler::new(&self.scx, self.httpc.clone(), cfg, &self.caches)),
149 )
150 .await;
151 self.register
152 .add_priority(
153 Type::ClientSubscribeCheckAcl,
154 priority,
155 Box::new(AuthHandler::new(&self.scx, self.httpc.clone(), cfg, &self.caches)),
156 )
157 .await;
158 self.register
159 .add_priority(
160 Type::MessagePublishCheckAcl,
161 priority,
162 Box::new(AuthHandler::new(&self.scx, self.httpc.clone(), cfg, &self.caches)),
163 )
164 .await;
165 self.register
166 .add(
167 Type::ClientKeepalive,
168 Box::new(AuthHandler::new(&self.scx, self.httpc.clone(), cfg, &self.caches)),
169 )
170 .await;
171
172 self.register
173 .add(
174 Type::ClientDisconnected,
175 Box::new(AuthHandler::new(&self.scx, self.httpc.clone(), cfg, &self.caches)),
176 )
177 .await;
178 Ok(())
179 }
180
181 #[inline]
182 async fn get_config(&self) -> Result<serde_json::Value> {
183 self.cfg.read().await.to_json()
184 }
185
186 #[inline]
187 async fn load_config(&mut self) -> Result<()> {
188 let new_cfg = self.scx.plugins.read_config::<PluginConfig>(self.name())?;
189 *self.cfg.write().await = new_cfg;
190 log::debug!("load_config ok, {:?}", self.cfg);
191 Ok(())
192 }
193
194 #[inline]
195 async fn start(&mut self) -> Result<()> {
196 log::info!("{} start", self.name());
197 self.register.start().await;
198 Ok(())
199 }
200
201 #[inline]
202 async fn stop(&mut self) -> Result<bool> {
203 log::info!("{} stop", self.name());
204 self.register.stop().await;
205 Ok(true)
206 }
207
208 #[inline]
209 async fn attrs(&self) -> serde_json::Value {
210 let mut stats = HashMap::default();
211 for (i, c) in self.caches.iter().enumerate() {
212 if i < 1000 {
213 stats.insert(c.key().to_string(), c.value().len());
214 }
215 }
216
217 serde_json::json!({
218 "caches": self.caches.len(),
219 "stats": stats,
220 })
221 }
222}
223
/// Hook handler that performs the HTTP auth/ACL requests. One instance is
/// registered per hook type; all instances share the same config and cache.
struct AuthHandler {
    /// Server context; used to reach the shared session registry on auth expiry.
    scx: ServerContext,
    /// HTTP client used for every auth/ACL request.
    httpc: reqwest::Client,
    /// Live plugin configuration.
    cfg: Arc<RwLock<PluginConfig>>,
    /// Publish-ACL verdict cache.
    caches: Caches,
}
230
231impl AuthHandler {
232 fn new(
233 scx: &ServerContext,
234 httpc: reqwest::Client,
235 cfg: &Arc<RwLock<PluginConfig>>,
236 caches: &Caches,
237 ) -> Self {
238 Self { scx: scx.clone(), httpc, cfg: cfg.clone(), caches: caches.clone() }
239 }
240
241 async fn response_result(resp: Response) -> Result<ResponseResult> {
242 if resp.status().is_success() {
243 let content_type = resp.headers().get(CONTENT_TYPE);
244 let is_json_content_type =
245 content_type.map(|hv| hv.as_bytes().starts_with(b"application/json")).unwrap_or_default();
246 log::debug!("content_type: {content_type:?}");
247 log::debug!("is_json_content_type: {is_json_content_type}");
248 let superuser = resp.headers().contains_key(SUPERUSER);
249 let cache_timeout = if let Some(tm) = resp.headers().get(CACHEABLE).and_then(|v| v.to_str().ok())
251 {
252 match tm.parse::<i64>() {
253 Ok(tm) => Some(tm),
254 Err(e) => {
255 log::warn!("Parse X-Cache error, {e:?}");
256 None
257 }
258 }
259 } else {
260 None
261 };
262 log::debug!("Cache timeout is {cache_timeout:?}");
263 let resp = if is_json_content_type {
264 let mut body: serde_json::Value = resp.json().await?;
265 log::debug!("body: {body:?}");
266 if let Some(obj) = body.as_object_mut() {
267 let result = obj
268 .get("result")
269 .and_then(|res| res.as_str())
270 .ok_or_else(|| anyhow!("Authentication result does not exist"))?;
271 let superuser = obj.get("superuser").and_then(|res| res.as_bool()).unwrap_or(superuser);
272 let expire_at =
273 obj.get("expire_at").and_then(|res| res.as_u64().map(Duration::from_secs));
274 let permission = Permission::try_from((result, superuser))?;
275 let acl_data = obj.remove("acl");
276
277 ResponseResult { permission, superuser, cacheable: cache_timeout, expire_at, acl_data }
278 } else if let Some(body) = body.as_str() {
279 log::debug!("body: {body:?}");
280 ResponseResult::new(Permission::try_from((body, superuser))?, superuser, cache_timeout)
281 } else {
282 return Err(anyhow!(format!("The response result is incorrect, {}", body)));
283 }
284 } else {
285 let body = resp.text().await?;
286 log::debug!("body: {body:?}");
287 ResponseResult::new(Permission::from(body.as_str(), superuser), superuser, cache_timeout)
288 };
289 Ok(resp)
290 } else {
291 Ok(ResponseResult::new(Permission::Ignore, false, None))
292 }
293 }
294
295 async fn http_get_request<T: Serialize + ?Sized>(
296 httpc: &reqwest::Client,
297 url: Url,
298 body: &T,
299 headers: HeaderMap,
300 timeout: Duration,
301 ) -> Result<ResponseResult> {
302 log::debug!("http_get_request, timeout: {timeout:?}, url: {url}");
303 match httpc.get(url).headers(headers).timeout(timeout).query(body).send().await {
304 Err(e) => {
305 log::warn!("{e:?}");
306 Err(anyhow!(e))
307 }
308 Ok(resp) => Self::response_result(resp).await,
309 }
310 }
311
312 async fn http_form_request<T: Serialize + ?Sized>(
313 httpc: &reqwest::Client,
314 url: Url,
315 method: Method,
316 body: &T,
317 headers: HeaderMap,
318 timeout: Duration,
319 ) -> Result<ResponseResult> {
320 log::debug!("http_form_request, method: {method:?}, timeout: {timeout:?}, url: {url}");
321 match httpc.request(method, url).headers(headers).timeout(timeout).form(body).send().await {
322 Err(e) => {
323 log::warn!("{e:?}");
324 Err(anyhow!(e))
325 }
326 Ok(resp) => Self::response_result(resp).await,
327 }
328 }
329
330 async fn http_json_request<T: Serialize + ?Sized>(
331 httpc: &reqwest::Client,
332 url: Url,
333 method: Method,
334 body: &T,
335 headers: HeaderMap,
336 timeout: Duration,
337 ) -> Result<ResponseResult> {
338 log::debug!("http_json_request, method: {method:?}, timeout: {timeout:?}, url: {url}");
339 match httpc.request(method, url).headers(headers).timeout(timeout).json(body).send().await {
340 Err(e) => {
341 log::warn!("{e:?}");
342 Err(anyhow!(e))
343 }
344 Ok(resp) => Self::response_result(resp).await,
345 }
346 }
347
348 fn replaces(
349 params: &mut HashMap<String, String>,
350 id: &Id,
351 password: Option<&Password>,
352 protocol: Option<u8>,
353 sub_or_pub: Option<(ACLType, &TopicName)>,
354 ) -> Result<()> {
355 let password =
356 if let Some(p) = password { ByteString::try_from(p.clone())? } else { ByteString::default() };
357 let client_id = id.client_id.as_ref();
358 let username = id.username.as_ref().map(|n| n.as_ref()).unwrap_or("");
359 let remote_addr = id.remote_addr.map(|addr| addr.ip().to_string()).unwrap_or_default();
360 for v in params.values_mut() {
361 *v = v.replace("%u", username);
362 *v = v.replace("%c", client_id);
363 *v = v.replace("%a", &remote_addr);
364 *v = v.replace("%P", &password);
365 if let Some(protocol) = protocol {
366 let mut buffer = itoa::Buffer::new();
367 *v = v.replace("%r", buffer.format(protocol));
368 }
369 if let Some((ref acl_type, topic)) = sub_or_pub {
370 *v = v.replace("%A", acl_type.as_str());
371 *v = v.replace("%t", topic);
372 } else {
373 *v = v.replace("%A", "");
374 *v = v.replace("%t", "");
375 }
376 }
377 Ok(())
378 }
379
380 async fn request(
381 &self,
382 id: &Id,
383 mut req_cfg: config::Req,
384 password: Option<&Password>,
385 protocol: Option<u8>,
386 sub_or_pub: Option<(ACLType, &TopicName)>,
387 ) -> Result<ResponseResult> {
388 log::debug!("{:?} req_cfg.url.path(): {:?}", id, req_cfg.url.path());
389 let (headers, timeout) = {
390 let cfg = self.cfg.read().await;
391 let headers = match (cfg.headers(), req_cfg.headers()) {
392 (Some(def_headers), Some(req_headers)) => {
393 let mut headers = def_headers.clone();
394 headers.extend(req_headers.clone());
395 headers
396 }
397 (Some(def_headers), None) => def_headers.clone(),
398 (None, Some(req_headers)) => req_headers.clone(),
399 (None, None) => HeaderMap::new(),
400 };
401 (headers, cfg.http_timeout)
402 };
403
404 let auth_result = if req_cfg.is_get() {
405 let body = &mut req_cfg.params;
406 Self::replaces(body, id, password, protocol, sub_or_pub)?;
407 Self::http_get_request(&self.httpc, req_cfg.url, body, headers, timeout).await?
408 } else if req_cfg.json_body() {
409 let body = &mut req_cfg.params;
410 Self::replaces(body, id, password, protocol, sub_or_pub)?;
411 Self::http_json_request(&self.httpc, req_cfg.url, req_cfg.method, body, headers, timeout).await?
412 } else {
413 let body = &mut req_cfg.params;
415 Self::replaces(body, id, password, protocol, sub_or_pub)?;
416 Self::http_form_request(&self.httpc, req_cfg.url, req_cfg.method, body, headers, timeout).await?
417 };
418 log::debug!("auth_result: {auth_result:?}");
419 Ok(auth_result)
420 }
421
422 #[inline]
423 async fn auth(&self, connect_info: &ConnectInfo) -> (Permission, Option<AuthInfo>) {
424 if let Some(req) = { self.cfg.read().await.http_auth_req.clone() } {
425 match self
426 .request(
427 connect_info.id(),
428 req,
429 connect_info.password(),
430 Some(connect_info.proto_ver()),
431 None,
432 )
433 .await
434 {
435 Ok(auth_res) => {
436 log::debug!("auth result: {auth_res:?}");
437 let auth_info = if matches!(auth_res.permission, Permission::Allow(_)) {
438 if let Some(acl_data) =
439 auth_res.acl_data.as_ref().and_then(|acl_data| acl_data.as_array())
440 {
441 match acl_data
442 .iter()
443 .map(|acl| Rule::try_from((acl, connect_info)))
444 .collect::<Result<Vec<Rule>>>()
445 {
446 Ok(rules) => {
447 let auth_info = AuthInfo {
448 superuser: auth_res.superuser,
449 expire_at: auth_res.expire_at,
450 rules,
451 };
452 log::debug!("auth_info: {auth_info:?}");
453 Some(auth_info)
454 }
455 Err(e) => {
456 log::warn!("{} {}", connect_info.id(), e);
457 None
458 }
459 }
460 } else {
461 None
462 }
463 } else {
464 None
465 };
466 (auth_res.permission, auth_info)
467 }
468 Err(e) => {
469 log::warn!("{:?} auth error, {:?}", connect_info.id(), e);
470 if self.cfg.read().await.deny_if_error {
471 (Permission::Deny, None)
472 } else {
473 (Permission::Ignore, None)
474 }
475 }
476 }
477 } else {
478 (Permission::Ignore, None)
479 }
480 }
481
482 #[inline]
483 async fn acl(
484 &self,
485 id: &Id,
486 protocol: Option<u8>,
487 sub_or_pub: Option<(ACLType, &TopicName)>,
488 ) -> (Permission, Cacheable) {
489 if let Some(req) = { self.cfg.read().await.http_acl_req.clone() } {
490 match self.request(id, req, None, protocol, sub_or_pub).await {
491 Ok(acl_res) => {
492 log::debug!("acl result: {acl_res:?}");
493 (acl_res.permission, acl_res.cacheable)
494 }
495 Err(e) => {
496 log::warn!("{id:?} acl error, {e:?}");
497 if self.cfg.read().await.deny_if_error {
498 (Permission::Deny, None)
499 } else {
500 (Permission::Ignore, None)
501 }
502 }
503 }
504 } else {
505 (Permission::Ignore, None)
506 }
507 }
508
509 #[inline]
510 fn cache_set(&self, id: Id, topic: TopicName, perm: Permission, expire: TimestampMillis) {
511 self.caches.entry(id).or_default().insert(topic, (perm, expire));
512 }
513
514 #[inline]
515 fn cache_get(&self, id: &Id, topic: &TopicName) -> Option<(Permission, TimestampMillis)> {
516 self.caches.get(id).and_then(|c| c.get(topic).map(|(perm, expire)| (*perm, *expire)))
517 }
518
519 #[inline]
520 fn cache_remove(&self, id: &Id) {
521 self.caches.remove(id);
522 }
523}
524
525#[async_trait]
526impl Handler for AuthHandler {
527 async fn hook(&self, param: &Parameter, acc: Option<HookResult>) -> ReturnType {
528 match param {
529 Parameter::ClientAuthenticate(connect_info) => {
530 log::debug!("ClientAuthenticate auth-http");
531 if matches!(
532 acc,
533 Some(HookResult::AuthResult(AuthResult::BadUsernameOrPassword))
534 | Some(HookResult::AuthResult(AuthResult::NotAuthorized))
535 ) {
536 return (false, acc);
537 }
538
539 return match self.auth(connect_info).await {
540 (Permission::Allow(superuser), auth_info) => {
541 if auth_info.as_ref().map(|ai| ai.is_expired()).unwrap_or_default() {
542 log::warn!("{} authentication information has expired.", connect_info.id());
543 (false, Some(HookResult::AuthResult(AuthResult::NotAuthorized)))
544 } else {
545 (false, Some(HookResult::AuthResult(AuthResult::Allow(superuser, auth_info))))
546 }
547 }
548 (Permission::Deny, _) => {
549 (false, Some(HookResult::AuthResult(AuthResult::BadUsernameOrPassword)))
550 }
551 (Permission::Ignore, _) => (true, None),
552 };
553 }
554
555 Parameter::ClientSubscribeCheckAcl(session, subscribe) => {
556 if let Some(HookResult::SubscribeAclResult(acl_result)) = &acc {
557 if acl_result.failure() {
558 return (false, acc);
559 }
560 }
561
562 if let Some(auth_info) = &session.auth_info {
563 if let Some(acl_res) = auth_info.subscribe_acl(subscribe).await {
564 return acl_res;
565 }
566 }
567
568 let (acl_res, _) = self
570 .acl(
571 &session.id,
572 session.protocol().await.ok(),
573 Some((ACLType::Sub, &subscribe.topic_filter)),
574 )
575 .await;
576 return match acl_res {
577 Permission::Allow(_) => (
578 false,
579 Some(HookResult::SubscribeAclResult(SubscribeAclResult::new_success(
580 subscribe.opts.qos(),
581 None,
582 ))),
583 ),
584 Permission::Deny => (
585 false,
586 Some(HookResult::SubscribeAclResult(SubscribeAclResult::new_failure(
587 SubscribeAckReason::NotAuthorized,
588 ))),
589 ),
590 Permission::Ignore => (true, None),
591 };
592 }
593
594 Parameter::MessagePublishCheckAcl(session, publish) => {
595 log::debug!("MessagePublishCheckAcl");
596 if let Some(HookResult::PublishAclResult(PublishAclResult::Rejected(_))) = &acc {
597 return (false, acc);
598 }
599
600 if let Some(auth_info) = &session.auth_info {
601 if let Some(acl_res) =
602 auth_info.publish_acl(publish, self.cfg.read().await.disconnect_if_pub_rejected).await
603 {
604 return acl_res;
605 }
606 }
607
608 let acl_res = if let Some((acl_res, expire)) = self.cache_get(&session.id, &publish.topic) {
609 if expire < 0 || timestamp_millis() < expire {
610 Some(acl_res)
611 } else {
612 None
613 }
614 } else {
615 None
616 };
617
618 let acl_res = if let Some(acl_res) = acl_res {
619 acl_res
620 } else {
621 let (acl_res, cacheable) = self
623 .acl(&session.id, session.protocol().await.ok(), Some((ACLType::Pub, &publish.topic)))
624 .await;
625 if let Some(tm) = cacheable {
626 let expire = if tm < 0 { tm } else { timestamp_millis() + tm };
627
628 self.cache_set(session.id.clone(), publish.topic.clone(), acl_res, expire);
629 }
630 acl_res
631 };
632
633 return match acl_res {
634 Permission::Allow(_) => {
635 (false, Some(HookResult::PublishAclResult(PublishAclResult::Allow)))
636 }
637 Permission::Deny => (
638 false,
639 Some(HookResult::PublishAclResult(PublishAclResult::Rejected(
640 self.cfg.read().await.disconnect_if_pub_rejected,
641 ))),
642 ),
643 Permission::Ignore => (true, None),
644 };
645 }
646
647 Parameter::ClientKeepalive(s, _) => {
648 if let Some(auth) = &s.auth_info {
649 log::debug!("Keepalive auth-http, is_expired: {:?}", auth.is_expired());
650 if auth.is_expired() && self.cfg.read().await.disconnect_if_expiry {
651 if let Some(tx) = self.scx.extends.shared().await.entry(s.id().clone()).tx() {
652 if let Err(e) = tx.unbounded_send(Message::Closed(Reason::ConnectDisconnect(
653 Some(Disconnect::Other("Http Auth expired".into())),
654 ))) {
655 log::warn!("{} {}", s.id(), e);
656 }
657 }
658 }
659 }
660 }
661
662 Parameter::ClientDisconnected(s, _) => {
663 self.cache_remove(&s.id);
664 }
665
666 _ => {
667 log::error!("unimplemented, {param:?}")
668 }
669 }
670 (true, acc)
671 }
672}
673
674fn new_reqwest_client() -> Result<reqwest::Client> {
675 reqwest::Client::builder()
676 .connect_timeout(Duration::from_secs(10))
677 .timeout(Duration::from_secs(10))
678 .build()
679 .map_err(|e| anyhow!(e))
680}