rmqtt_auth_http/
lib.rs

1#![deny(unsafe_code)]
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::anyhow;
7use async_trait::async_trait;
8use bytestring::ByteString;
9use reqwest::{
10    header::{HeaderMap, CONTENT_TYPE},
11    Method, Response, Url,
12};
13use serde::ser::Serialize;
14use tokio::sync::RwLock;
15
16use rmqtt::{
17    acl::{AuthInfo, Rule},
18    codec::v5::SubscribeAckReason,
19    context::ServerContext,
20    hook::{Handler, HookResult, Parameter, Register, ReturnType, Type},
21    macros::Plugin,
22    plugin::{PackageInfo, Plugin},
23    register,
24    types::{
25        AuthResult, ConnectInfo, DashMap, Disconnect, Id, Message, Password, PublishAclResult, Reason,
26        SubscribeAclResult, Superuser, TimestampMillis, TopicName,
27    },
28    utils::timestamp_millis,
29    Error, Result,
30};
31
32use config::PluginConfig;
33
34mod config;
35
/// `std::collections::HashMap` using `ahash::RandomState` as its hasher.
type HashMap<K, V> = std::collections::HashMap<K, V, ahash::RandomState>;

/// Response header carrying a cache TTL, parsed as an `i64` number of
/// milliseconds; it is used to cache publish ACL verdicts, and a negative
/// value means the cached verdict never expires.
const CACHEABLE: &str = "X-Cache";
/// Response header whose mere presence marks the client as a superuser
/// (a JSON `superuser` field in the body takes precedence when present).
const SUPERUSER: &str = "X-Superuser";
// NOTE(review): unused leftover; kept for reference only.
// const CACHE_KEY: &str = "ACL-CACHE-MAP";
41
42#[derive(Clone, Debug)]
43struct ResponseResult {
44    permission: Permission,
45    superuser: Superuser,
46    cacheable: Cacheable,
47    expire_at: Option<Duration>,
48    acl_data: Option<serde_json::Value>,
49}
50
51impl ResponseResult {
52    #[inline]
53    fn new(permission: Permission, superuser: Superuser, cacheable: Cacheable) -> ResponseResult {
54        ResponseResult { permission, superuser, cacheable, expire_at: None, acl_data: None }
55    }
56}
57
/// Verdict produced by the HTTP auth/ACL service (or by the error fallback).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Permission {
    /// Access granted; carries whether the client is treated as a superuser.
    Allow(Superuser),
    /// Access refused.
    Deny,
    /// No decision; the hook returns `(true, None)`, presumably deferring to
    /// other handlers.
    Ignore,
}
64
65impl TryFrom<(&str, Superuser)> for Permission {
66    type Error = Error;
67
68    #[inline]
69    fn try_from((s, superuser): (&str, Superuser)) -> std::result::Result<Self, Self::Error> {
70        match s {
71            "allow" => Ok(Permission::Allow(superuser)),
72            "deny" => Ok(Permission::Deny),
73            "ignore" => Ok(Permission::Ignore),
74            _ => Err(anyhow!(
75                "The authentication result is incorrect; only 'allow,' 'deny,' or 'ignore' are permitted.",
76            )),
77        }
78    }
79}
80
81impl Permission {
82    #[inline]
83    fn from(s: &str, superuser: Superuser) -> Self {
84        match s {
85            "allow" => Permission::Allow(superuser),
86            "deny" => Permission::Deny,
87            "ignore" => Permission::Ignore,
88            _ => Permission::Allow(superuser),
89        }
90    }
91}
92
/// Cache TTL in milliseconds taken from the `X-Cache` response header, if any.
type Cacheable = Option<i64>;

/// Kind of ACL check being performed; its string form is what the `%A`
/// request-parameter placeholder expands to.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy)]
enum ACLType {
    /// Subscribe check (`%A` expands to `"1"`).
    Sub = 1,
    /// Publish check (`%A` expands to `"2"`).
    Pub = 2,
}

impl ACLType {
    /// Returns the wire form of this ACL type: `"1"` or `"2"`.
    ///
    /// The result is a `'static` literal, so it does not keep `self` borrowed.
    fn as_str(&self) -> &'static str {
        match self {
            ACLType::Sub => "1",
            ACLType::Pub => "2",
        }
    }
}
109
// Hands `AuthHttpPlugin::new` to rmqtt's `register!` macro (presumably the
// plugin's constructor entry point — confirm against rmqtt's plugin docs).
register!(AuthHttpPlugin::new);

/// Per-client cache of publish ACL verdicts: `topic -> (permission, expiry)`.
/// The expiry is an absolute millisecond timestamp, or negative for "never
/// expires"; a client's entries are removed when it disconnects.
type Caches = Arc<DashMap<Id, std::collections::BTreeMap<TopicName, (Permission, TimestampMillis)>>>;

/// rmqtt plugin that authenticates clients and checks publish/subscribe ACLs
/// by calling a configurable HTTP service.
#[derive(Plugin)]
struct AuthHttpPlugin {
    scx: ServerContext,
    /// HTTP client; every registered handler receives a clone of it.
    httpc: reqwest::Client,
    /// Hook registration handle used to add, start and stop the handlers.
    register: Box<dyn Register>,
    /// Plugin configuration; replaced in place by `load_config`.
    cfg: Arc<RwLock<PluginConfig>>,
    /// ACL cache shared (via `Arc`) with every `AuthHandler`.
    caches: Caches,
}
122
123impl AuthHttpPlugin {
124    #[inline]
125    async fn new<S: Into<String>>(scx: ServerContext, name: S) -> Result<Self> {
126        let name = name.into();
127        let cfg = Arc::new(RwLock::new(scx.plugins.read_config::<PluginConfig>(&name)?));
128        log::debug!("{} AuthHttpPlugin cfg: {:?}", name, cfg.read().await);
129        let register = scx.extends.hook_mgr().register();
130        let caches = Arc::new(DashMap::default());
131        let httpc = new_reqwest_client()?;
132        Ok(Self { scx, httpc, register, cfg, caches })
133    }
134}
135
136#[async_trait]
137impl Plugin for AuthHttpPlugin {
138    #[inline]
139    async fn init(&mut self) -> Result<()> {
140        log::info!("{} init", self.name());
141        let cfg = &self.cfg;
142
143        let priority = cfg.read().await.priority;
144        self.register
145            .add_priority(
146                Type::ClientAuthenticate,
147                priority,
148                Box::new(AuthHandler::new(&self.scx, self.httpc.clone(), cfg, &self.caches)),
149            )
150            .await;
151        self.register
152            .add_priority(
153                Type::ClientSubscribeCheckAcl,
154                priority,
155                Box::new(AuthHandler::new(&self.scx, self.httpc.clone(), cfg, &self.caches)),
156            )
157            .await;
158        self.register
159            .add_priority(
160                Type::MessagePublishCheckAcl,
161                priority,
162                Box::new(AuthHandler::new(&self.scx, self.httpc.clone(), cfg, &self.caches)),
163            )
164            .await;
165        self.register
166            .add(
167                Type::ClientKeepalive,
168                Box::new(AuthHandler::new(&self.scx, self.httpc.clone(), cfg, &self.caches)),
169            )
170            .await;
171
172        self.register
173            .add(
174                Type::ClientDisconnected,
175                Box::new(AuthHandler::new(&self.scx, self.httpc.clone(), cfg, &self.caches)),
176            )
177            .await;
178        Ok(())
179    }
180
181    #[inline]
182    async fn get_config(&self) -> Result<serde_json::Value> {
183        self.cfg.read().await.to_json()
184    }
185
186    #[inline]
187    async fn load_config(&mut self) -> Result<()> {
188        let new_cfg = self.scx.plugins.read_config::<PluginConfig>(self.name())?;
189        *self.cfg.write().await = new_cfg;
190        log::debug!("load_config ok,  {:?}", self.cfg);
191        Ok(())
192    }
193
194    #[inline]
195    async fn start(&mut self) -> Result<()> {
196        log::info!("{} start", self.name());
197        self.register.start().await;
198        Ok(())
199    }
200
201    #[inline]
202    async fn stop(&mut self) -> Result<bool> {
203        log::info!("{} stop", self.name());
204        self.register.stop().await;
205        Ok(true)
206    }
207
208    #[inline]
209    async fn attrs(&self) -> serde_json::Value {
210        let mut stats = HashMap::default();
211        for (i, c) in self.caches.iter().enumerate() {
212            if i < 1000 {
213                stats.insert(c.key().to_string(), c.value().len());
214            }
215        }
216
217        serde_json::json!({
218            "caches": self.caches.len(),
219            "stats": stats,
220        })
221    }
222}
223
/// Hook handler that performs the HTTP authentication / ACL requests.
///
/// One instance is registered per hook; each holds clones of the plugin's
/// context, HTTP client, config `Arc` and cache `Arc`.
struct AuthHandler {
    scx: ServerContext,
    httpc: reqwest::Client,
    /// Shared plugin configuration (re-read on every request).
    cfg: Arc<RwLock<PluginConfig>>,
    /// Shared publish-ACL verdict cache.
    caches: Caches,
}
230
231impl AuthHandler {
232    fn new(
233        scx: &ServerContext,
234        httpc: reqwest::Client,
235        cfg: &Arc<RwLock<PluginConfig>>,
236        caches: &Caches,
237    ) -> Self {
238        Self { scx: scx.clone(), httpc, cfg: cfg.clone(), caches: caches.clone() }
239    }
240
241    async fn response_result(resp: Response) -> Result<ResponseResult> {
242        if resp.status().is_success() {
243            let content_type = resp.headers().get(CONTENT_TYPE);
244            let is_json_content_type =
245                content_type.map(|hv| hv.as_bytes().starts_with(b"application/json")).unwrap_or_default();
246            log::debug!("content_type: {content_type:?}");
247            log::debug!("is_json_content_type: {is_json_content_type}");
248            let superuser = resp.headers().contains_key(SUPERUSER);
249            // let acl = resp.headers().contains_key(ACL);
250            let cache_timeout = if let Some(tm) = resp.headers().get(CACHEABLE).and_then(|v| v.to_str().ok())
251            {
252                match tm.parse::<i64>() {
253                    Ok(tm) => Some(tm),
254                    Err(e) => {
255                        log::warn!("Parse X-Cache error, {e:?}");
256                        None
257                    }
258                }
259            } else {
260                None
261            };
262            log::debug!("Cache timeout is {cache_timeout:?}");
263            let resp = if is_json_content_type {
264                let mut body: serde_json::Value = resp.json().await?;
265                log::debug!("body: {body:?}");
266                if let Some(obj) = body.as_object_mut() {
267                    let result = obj
268                        .get("result")
269                        .and_then(|res| res.as_str())
270                        .ok_or_else(|| anyhow!("Authentication result does not exist"))?;
271                    let superuser = obj.get("superuser").and_then(|res| res.as_bool()).unwrap_or(superuser);
272                    let expire_at =
273                        obj.get("expire_at").and_then(|res| res.as_u64().map(Duration::from_secs));
274                    let permission = Permission::try_from((result, superuser))?;
275                    let acl_data = obj.remove("acl");
276
277                    ResponseResult { permission, superuser, cacheable: cache_timeout, expire_at, acl_data }
278                } else if let Some(body) = body.as_str() {
279                    log::debug!("body: {body:?}");
280                    ResponseResult::new(Permission::try_from((body, superuser))?, superuser, cache_timeout)
281                } else {
282                    return Err(anyhow!(format!("The response result is incorrect, {}", body)));
283                }
284            } else {
285                let body = resp.text().await?;
286                log::debug!("body: {body:?}");
287                ResponseResult::new(Permission::from(body.as_str(), superuser), superuser, cache_timeout)
288            };
289            Ok(resp)
290        } else {
291            Ok(ResponseResult::new(Permission::Ignore, false, None))
292        }
293    }
294
295    async fn http_get_request<T: Serialize + ?Sized>(
296        httpc: &reqwest::Client,
297        url: Url,
298        body: &T,
299        headers: HeaderMap,
300        timeout: Duration,
301    ) -> Result<ResponseResult> {
302        log::debug!("http_get_request, timeout: {timeout:?}, url: {url}");
303        match httpc.get(url).headers(headers).timeout(timeout).query(body).send().await {
304            Err(e) => {
305                log::warn!("{e:?}");
306                Err(anyhow!(e))
307            }
308            Ok(resp) => Self::response_result(resp).await,
309        }
310    }
311
312    async fn http_form_request<T: Serialize + ?Sized>(
313        httpc: &reqwest::Client,
314        url: Url,
315        method: Method,
316        body: &T,
317        headers: HeaderMap,
318        timeout: Duration,
319    ) -> Result<ResponseResult> {
320        log::debug!("http_form_request, method: {method:?}, timeout: {timeout:?}, url: {url}");
321        match httpc.request(method, url).headers(headers).timeout(timeout).form(body).send().await {
322            Err(e) => {
323                log::warn!("{e:?}");
324                Err(anyhow!(e))
325            }
326            Ok(resp) => Self::response_result(resp).await,
327        }
328    }
329
330    async fn http_json_request<T: Serialize + ?Sized>(
331        httpc: &reqwest::Client,
332        url: Url,
333        method: Method,
334        body: &T,
335        headers: HeaderMap,
336        timeout: Duration,
337    ) -> Result<ResponseResult> {
338        log::debug!("http_json_request, method: {method:?}, timeout: {timeout:?}, url: {url}");
339        match httpc.request(method, url).headers(headers).timeout(timeout).json(body).send().await {
340            Err(e) => {
341                log::warn!("{e:?}");
342                Err(anyhow!(e))
343            }
344            Ok(resp) => Self::response_result(resp).await,
345        }
346    }
347
348    fn replaces(
349        params: &mut HashMap<String, String>,
350        id: &Id,
351        password: Option<&Password>,
352        protocol: Option<u8>,
353        sub_or_pub: Option<(ACLType, &TopicName)>,
354    ) -> Result<()> {
355        let password =
356            if let Some(p) = password { ByteString::try_from(p.clone())? } else { ByteString::default() };
357        let client_id = id.client_id.as_ref();
358        let username = id.username.as_ref().map(|n| n.as_ref()).unwrap_or("");
359        let remote_addr = id.remote_addr.map(|addr| addr.ip().to_string()).unwrap_or_default();
360        for v in params.values_mut() {
361            *v = v.replace("%u", username);
362            *v = v.replace("%c", client_id);
363            *v = v.replace("%a", &remote_addr);
364            *v = v.replace("%P", &password);
365            if let Some(protocol) = protocol {
366                let mut buffer = itoa::Buffer::new();
367                *v = v.replace("%r", buffer.format(protocol));
368            }
369            if let Some((ref acl_type, topic)) = sub_or_pub {
370                *v = v.replace("%A", acl_type.as_str());
371                *v = v.replace("%t", topic);
372            } else {
373                *v = v.replace("%A", "");
374                *v = v.replace("%t", "");
375            }
376        }
377        Ok(())
378    }
379
380    async fn request(
381        &self,
382        id: &Id,
383        mut req_cfg: config::Req,
384        password: Option<&Password>,
385        protocol: Option<u8>,
386        sub_or_pub: Option<(ACLType, &TopicName)>,
387    ) -> Result<ResponseResult> {
388        log::debug!("{:?} req_cfg.url.path(): {:?}", id, req_cfg.url.path());
389        let (headers, timeout) = {
390            let cfg = self.cfg.read().await;
391            let headers = match (cfg.headers(), req_cfg.headers()) {
392                (Some(def_headers), Some(req_headers)) => {
393                    let mut headers = def_headers.clone();
394                    headers.extend(req_headers.clone());
395                    headers
396                }
397                (Some(def_headers), None) => def_headers.clone(),
398                (None, Some(req_headers)) => req_headers.clone(),
399                (None, None) => HeaderMap::new(),
400            };
401            (headers, cfg.http_timeout)
402        };
403
404        let auth_result = if req_cfg.is_get() {
405            let body = &mut req_cfg.params;
406            Self::replaces(body, id, password, protocol, sub_or_pub)?;
407            Self::http_get_request(&self.httpc, req_cfg.url, body, headers, timeout).await?
408        } else if req_cfg.json_body() {
409            let body = &mut req_cfg.params;
410            Self::replaces(body, id, password, protocol, sub_or_pub)?;
411            Self::http_json_request(&self.httpc, req_cfg.url, req_cfg.method, body, headers, timeout).await?
412        } else {
413            //form body
414            let body = &mut req_cfg.params;
415            Self::replaces(body, id, password, protocol, sub_or_pub)?;
416            Self::http_form_request(&self.httpc, req_cfg.url, req_cfg.method, body, headers, timeout).await?
417        };
418        log::debug!("auth_result: {auth_result:?}");
419        Ok(auth_result)
420    }
421
422    #[inline]
423    async fn auth(&self, connect_info: &ConnectInfo) -> (Permission, Option<AuthInfo>) {
424        if let Some(req) = { self.cfg.read().await.http_auth_req.clone() } {
425            match self
426                .request(
427                    connect_info.id(),
428                    req,
429                    connect_info.password(),
430                    Some(connect_info.proto_ver()),
431                    None,
432                )
433                .await
434            {
435                Ok(auth_res) => {
436                    log::debug!("auth result: {auth_res:?}");
437                    let auth_info = if matches!(auth_res.permission, Permission::Allow(_)) {
438                        if let Some(acl_data) =
439                            auth_res.acl_data.as_ref().and_then(|acl_data| acl_data.as_array())
440                        {
441                            match acl_data
442                                .iter()
443                                .map(|acl| Rule::try_from((acl, connect_info)))
444                                .collect::<Result<Vec<Rule>>>()
445                            {
446                                Ok(rules) => {
447                                    let auth_info = AuthInfo {
448                                        superuser: auth_res.superuser,
449                                        expire_at: auth_res.expire_at,
450                                        rules,
451                                    };
452                                    log::debug!("auth_info: {auth_info:?}");
453                                    Some(auth_info)
454                                }
455                                Err(e) => {
456                                    log::warn!("{} {}", connect_info.id(), e);
457                                    None
458                                }
459                            }
460                        } else {
461                            None
462                        }
463                    } else {
464                        None
465                    };
466                    (auth_res.permission, auth_info)
467                }
468                Err(e) => {
469                    log::warn!("{:?} auth error, {:?}", connect_info.id(), e);
470                    if self.cfg.read().await.deny_if_error {
471                        (Permission::Deny, None)
472                    } else {
473                        (Permission::Ignore, None)
474                    }
475                }
476            }
477        } else {
478            (Permission::Ignore, None)
479        }
480    }
481
482    #[inline]
483    async fn acl(
484        &self,
485        id: &Id,
486        protocol: Option<u8>,
487        sub_or_pub: Option<(ACLType, &TopicName)>,
488    ) -> (Permission, Cacheable) {
489        if let Some(req) = { self.cfg.read().await.http_acl_req.clone() } {
490            match self.request(id, req, None, protocol, sub_or_pub).await {
491                Ok(acl_res) => {
492                    log::debug!("acl result: {acl_res:?}");
493                    (acl_res.permission, acl_res.cacheable)
494                }
495                Err(e) => {
496                    log::warn!("{id:?} acl error, {e:?}");
497                    if self.cfg.read().await.deny_if_error {
498                        (Permission::Deny, None)
499                    } else {
500                        (Permission::Ignore, None)
501                    }
502                }
503            }
504        } else {
505            (Permission::Ignore, None)
506        }
507    }
508
509    #[inline]
510    fn cache_set(&self, id: Id, topic: TopicName, perm: Permission, expire: TimestampMillis) {
511        self.caches.entry(id).or_default().insert(topic, (perm, expire));
512    }
513
514    #[inline]
515    fn cache_get(&self, id: &Id, topic: &TopicName) -> Option<(Permission, TimestampMillis)> {
516        self.caches.get(id).and_then(|c| c.get(topic).map(|(perm, expire)| (*perm, *expire)))
517    }
518
519    #[inline]
520    fn cache_remove(&self, id: &Id) {
521        self.caches.remove(id);
522    }
523}
524
525#[async_trait]
526impl Handler for AuthHandler {
527    async fn hook(&self, param: &Parameter, acc: Option<HookResult>) -> ReturnType {
528        match param {
529            Parameter::ClientAuthenticate(connect_info) => {
530                log::debug!("ClientAuthenticate auth-http");
531                if matches!(
532                    acc,
533                    Some(HookResult::AuthResult(AuthResult::BadUsernameOrPassword))
534                        | Some(HookResult::AuthResult(AuthResult::NotAuthorized))
535                ) {
536                    return (false, acc);
537                }
538
539                return match self.auth(connect_info).await {
540                    (Permission::Allow(superuser), auth_info) => {
541                        if auth_info.as_ref().map(|ai| ai.is_expired()).unwrap_or_default() {
542                            log::warn!("{} authentication information has expired.", connect_info.id());
543                            (false, Some(HookResult::AuthResult(AuthResult::NotAuthorized)))
544                        } else {
545                            (false, Some(HookResult::AuthResult(AuthResult::Allow(superuser, auth_info))))
546                        }
547                    }
548                    (Permission::Deny, _) => {
549                        (false, Some(HookResult::AuthResult(AuthResult::BadUsernameOrPassword)))
550                    }
551                    (Permission::Ignore, _) => (true, None),
552                };
553            }
554
555            Parameter::ClientSubscribeCheckAcl(session, subscribe) => {
556                if let Some(HookResult::SubscribeAclResult(acl_result)) = &acc {
557                    if acl_result.failure() {
558                        return (false, acc);
559                    }
560                }
561
562                if let Some(auth_info) = &session.auth_info {
563                    if let Some(acl_res) = auth_info.subscribe_acl(subscribe).await {
564                        return acl_res;
565                    }
566                }
567
568                //Permission, Cacheable
569                let (acl_res, _) = self
570                    .acl(
571                        &session.id,
572                        session.protocol().await.ok(),
573                        Some((ACLType::Sub, &subscribe.topic_filter)),
574                    )
575                    .await;
576                return match acl_res {
577                    Permission::Allow(_) => (
578                        false,
579                        Some(HookResult::SubscribeAclResult(SubscribeAclResult::new_success(
580                            subscribe.opts.qos(),
581                            None,
582                        ))),
583                    ),
584                    Permission::Deny => (
585                        false,
586                        Some(HookResult::SubscribeAclResult(SubscribeAclResult::new_failure(
587                            SubscribeAckReason::NotAuthorized,
588                        ))),
589                    ),
590                    Permission::Ignore => (true, None),
591                };
592            }
593
594            Parameter::MessagePublishCheckAcl(session, publish) => {
595                log::debug!("MessagePublishCheckAcl");
596                if let Some(HookResult::PublishAclResult(PublishAclResult::Rejected(_))) = &acc {
597                    return (false, acc);
598                }
599
600                if let Some(auth_info) = &session.auth_info {
601                    if let Some(acl_res) =
602                        auth_info.publish_acl(publish, self.cfg.read().await.disconnect_if_pub_rejected).await
603                    {
604                        return acl_res;
605                    }
606                }
607
608                let acl_res = if let Some((acl_res, expire)) = self.cache_get(&session.id, &publish.topic) {
609                    if expire < 0 || timestamp_millis() < expire {
610                        Some(acl_res)
611                    } else {
612                        None
613                    }
614                } else {
615                    None
616                };
617
618                let acl_res = if let Some(acl_res) = acl_res {
619                    acl_res
620                } else {
621                    //Permission, Cacheable
622                    let (acl_res, cacheable) = self
623                        .acl(&session.id, session.protocol().await.ok(), Some((ACLType::Pub, &publish.topic)))
624                        .await;
625                    if let Some(tm) = cacheable {
626                        let expire = if tm < 0 { tm } else { timestamp_millis() + tm };
627
628                        self.cache_set(session.id.clone(), publish.topic.clone(), acl_res, expire);
629                    }
630                    acl_res
631                };
632
633                return match acl_res {
634                    Permission::Allow(_) => {
635                        (false, Some(HookResult::PublishAclResult(PublishAclResult::Allow)))
636                    }
637                    Permission::Deny => (
638                        false,
639                        Some(HookResult::PublishAclResult(PublishAclResult::Rejected(
640                            self.cfg.read().await.disconnect_if_pub_rejected,
641                        ))),
642                    ),
643                    Permission::Ignore => (true, None),
644                };
645            }
646
647            Parameter::ClientKeepalive(s, _) => {
648                if let Some(auth) = &s.auth_info {
649                    log::debug!("Keepalive auth-http, is_expired: {:?}", auth.is_expired());
650                    if auth.is_expired() && self.cfg.read().await.disconnect_if_expiry {
651                        if let Some(tx) = self.scx.extends.shared().await.entry(s.id().clone()).tx() {
652                            if let Err(e) = tx.unbounded_send(Message::Closed(Reason::ConnectDisconnect(
653                                Some(Disconnect::Other("Http Auth expired".into())),
654                            ))) {
655                                log::warn!("{} {}", s.id(), e);
656                            }
657                        }
658                    }
659                }
660            }
661
662            Parameter::ClientDisconnected(s, _) => {
663                self.cache_remove(&s.id);
664            }
665
666            _ => {
667                log::error!("unimplemented, {param:?}")
668            }
669        }
670        (true, acc)
671    }
672}
673
674fn new_reqwest_client() -> Result<reqwest::Client> {
675    reqwest::Client::builder()
676        .connect_timeout(Duration::from_secs(10))
677        .timeout(Duration::from_secs(10))
678        .build()
679        .map_err(|e| anyhow!(e))
680}