mqtt5 0.31.2

Complete MQTT v5.0 platform with high-performance async client and full-featured broker supporting TCP, TLS, WebSocket, authentication, bridging, and resource monitoring
Documentation
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;

use crate::error::Result;
use crate::packet::connect::ConnectPacket;
use crate::protocol::v5::reason_codes::ReasonCode;

use super::{AuthProvider, AuthResult, EnhancedAuthResult};

/// Strategy for combining the primary and fallback providers' decisions when
/// authorizing publish and subscribe requests.
///
/// The default is [`AuthorizationMode::PrimaryOnly`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AuthorizationMode {
    /// Only the primary provider's decision is used; the fallback is never asked.
    #[default]
    PrimaryOnly,
    /// Access is granted when either provider allows it. The fallback is only
    /// asked after the primary has denied.
    Or,
    /// Access is granted only when both providers allow it. The fallback is
    /// only asked after the primary has allowed.
    And,
}

/// An [`AuthProvider`] that wraps two other providers, a primary and a
/// fallback, and presents them as a single provider.
pub struct CompositeAuthProvider {
    /// Provider consulted first for every operation.
    primary: Arc<dyn AuthProvider>,
    /// Secondary provider. It is used for authentication when the primary
    /// rejects with `ReasonCode::BadAuthenticationMethod`, for authorization
    /// as dictated by `authorization_mode`, and for session cleanup.
    fallback: Arc<dyn AuthProvider>,
    /// How the two providers' publish/subscribe authorization decisions are
    /// combined.
    authorization_mode: AuthorizationMode,
}

impl CompositeAuthProvider {
    /// Creates a composite that consults `primary` first and keeps `fallback`
    /// as the secondary provider.
    ///
    /// The authorization mode starts out as `AuthorizationMode::default()`;
    /// chain [`Self::with_authorization_mode`] to pick a different one.
    #[must_use]
    pub fn new(primary: Arc<dyn AuthProvider>, fallback: Arc<dyn AuthProvider>) -> Self {
        let authorization_mode = AuthorizationMode::default();
        Self {
            primary,
            fallback,
            authorization_mode,
        }
    }

    /// Builder-style setter: consumes the composite and returns it with its
    /// authorization mode replaced by `mode`. Both providers are carried over
    /// unchanged.
    #[must_use]
    pub fn with_authorization_mode(self, mode: AuthorizationMode) -> Self {
        Self {
            authorization_mode: mode,
            ..self
        }
    }
}

impl AuthProvider for CompositeAuthProvider {
    fn authenticate<'a>(
        &'a self,
        connect: &'a ConnectPacket,
        client_addr: SocketAddr,
    ) -> Pin<Box<dyn Future<Output = Result<AuthResult>> + Send + 'a>> {
        Box::pin(async move {
            let primary_result = self.primary.authenticate(connect, client_addr).await?;
            if primary_result.authenticated {
                return Ok(primary_result);
            }
            if primary_result.reason_code == ReasonCode::BadAuthenticationMethod {
                return self.fallback.authenticate(connect, client_addr).await;
            }
            Ok(primary_result)
        })
    }

    fn authorize_publish<'a>(
        &'a self,
        client_id: &str,
        user_id: Option<&'a str>,
        topic: &'a str,
    ) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
        let client_id = client_id.to_string();
        let mode = self.authorization_mode;
        Box::pin(async move {
            let primary = self
                .primary
                .authorize_publish(&client_id, user_id, topic)
                .await;
            match mode {
                AuthorizationMode::PrimaryOnly => primary,
                AuthorizationMode::Or => {
                    primary
                        || self
                            .fallback
                            .authorize_publish(&client_id, user_id, topic)
                            .await
                }
                AuthorizationMode::And => {
                    primary
                        && self
                            .fallback
                            .authorize_publish(&client_id, user_id, topic)
                            .await
                }
            }
        })
    }

    fn authorize_subscribe<'a>(
        &'a self,
        client_id: &str,
        user_id: Option<&'a str>,
        topic_filter: &'a str,
    ) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
        let client_id = client_id.to_string();
        let mode = self.authorization_mode;
        Box::pin(async move {
            let primary = self
                .primary
                .authorize_subscribe(&client_id, user_id, topic_filter)
                .await;
            match mode {
                AuthorizationMode::PrimaryOnly => primary,
                AuthorizationMode::Or => {
                    primary
                        || self
                            .fallback
                            .authorize_subscribe(&client_id, user_id, topic_filter)
                            .await
                }
                AuthorizationMode::And => {
                    primary
                        && self
                            .fallback
                            .authorize_subscribe(&client_id, user_id, topic_filter)
                            .await
                }
            }
        })
    }

    fn supports_enhanced_auth(&self) -> bool {
        self.primary.supports_enhanced_auth()
    }

    fn authenticate_enhanced<'a>(
        &'a self,
        auth_method: &'a str,
        auth_data: Option<&'a [u8]>,
        client_id: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<EnhancedAuthResult>> + Send + 'a>> {
        self.primary
            .authenticate_enhanced(auth_method, auth_data, client_id)
    }

    fn reauthenticate<'a>(
        &'a self,
        auth_method: &'a str,
        auth_data: Option<&'a [u8]>,
        client_id: &'a str,
        user_id: Option<&'a str>,
    ) -> Pin<Box<dyn Future<Output = Result<EnhancedAuthResult>> + Send + 'a>> {
        self.primary
            .reauthenticate(auth_method, auth_data, client_id, user_id)
    }

    fn cleanup_session<'a>(
        &'a self,
        user_id: &'a str,
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
        Box::pin(async move {
            self.primary.cleanup_session(user_id).await;
            self.fallback.cleanup_session(user_id).await;
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::broker::auth::AllowAllAuthProvider;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Stub provider: authentication is rejected with
    /// `ReasonCode::BadAuthenticationMethod` and every authorization request
    /// is denied.
    struct RejectWithBadAuthMethod;

    impl AuthProvider for RejectWithBadAuthMethod {
        fn authenticate<'a>(
            &'a self,
            _connect: &'a ConnectPacket,
            _client_addr: SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<AuthResult>> + Send + 'a>> {
            Box::pin(async { Ok(AuthResult::fail(ReasonCode::BadAuthenticationMethod)) })
        }

        fn authorize_publish<'a>(
            &'a self,
            _client_id: &str,
            _user_id: Option<&'a str>,
            _topic: &'a str,
        ) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
            Box::pin(async { false })
        }

        fn authorize_subscribe<'a>(
            &'a self,
            _client_id: &str,
            _user_id: Option<&'a str>,
            _topic_filter: &'a str,
        ) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
            Box::pin(async { false })
        }
    }

    /// Stub provider: authentication is rejected with
    /// `ReasonCode::NotAuthorized` and every authorization request is denied.
    struct RejectWithNotAuthorized;

    impl AuthProvider for RejectWithNotAuthorized {
        fn authenticate<'a>(
            &'a self,
            _connect: &'a ConnectPacket,
            _client_addr: SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<AuthResult>> + Send + 'a>> {
            Box::pin(async { Ok(AuthResult::fail(ReasonCode::NotAuthorized)) })
        }

        fn authorize_publish<'a>(
            &'a self,
            _client_id: &str,
            _user_id: Option<&'a str>,
            _topic: &'a str,
        ) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
            Box::pin(async { false })
        }

        fn authorize_subscribe<'a>(
            &'a self,
            _client_id: &str,
            _user_id: Option<&'a str>,
            _topic_filter: &'a str,
        ) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
            Box::pin(async { false })
        }
    }

    /// Stub provider that counts how many times `cleanup_session` is invoked,
    /// so a test can observe that the composite really forwarded the call.
    /// Authentication and authorization always reject.
    #[derive(Default)]
    struct CleanupCounter {
        cleanups: AtomicUsize,
    }

    impl AuthProvider for CleanupCounter {
        fn authenticate<'a>(
            &'a self,
            _connect: &'a ConnectPacket,
            _client_addr: SocketAddr,
        ) -> Pin<Box<dyn Future<Output = Result<AuthResult>> + Send + 'a>> {
            Box::pin(async { Ok(AuthResult::fail(ReasonCode::NotAuthorized)) })
        }

        fn authorize_publish<'a>(
            &'a self,
            _client_id: &str,
            _user_id: Option<&'a str>,
            _topic: &'a str,
        ) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
            Box::pin(async { false })
        }

        fn authorize_subscribe<'a>(
            &'a self,
            _client_id: &str,
            _user_id: Option<&'a str>,
            _topic_filter: &'a str,
        ) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
            Box::pin(async { false })
        }

        fn cleanup_session<'a>(
            &'a self,
            _user_id: &'a str,
        ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
            Box::pin(async move {
                self.cleanups.fetch_add(1, Ordering::SeqCst);
            })
        }
    }

    /// A successful primary authentication is returned even though the
    /// fallback would have rejected.
    #[tokio::test]
    async fn primary_success_skips_fallback() {
        let composite = CompositeAuthProvider::new(
            Arc::new(AllowAllAuthProvider),
            Arc::new(RejectWithNotAuthorized),
        );
        let connect = ConnectPacket::new(mqtt5_protocol::ConnectOptions::new("test"));
        let addr = "127.0.0.1:1234".parse().unwrap();

        let result = composite.authenticate(&connect, addr).await.unwrap();
        assert!(result.authenticated);
    }

    /// A `BadAuthenticationMethod` rejection from the primary hands the
    /// decision to the fallback.
    #[tokio::test]
    async fn bad_auth_method_falls_through_to_fallback() {
        let composite = CompositeAuthProvider::new(
            Arc::new(RejectWithBadAuthMethod),
            Arc::new(AllowAllAuthProvider),
        );
        let connect = ConnectPacket::new(mqtt5_protocol::ConnectOptions::new("test"));
        let addr = "127.0.0.1:1234".parse().unwrap();

        let result = composite.authenticate(&connect, addr).await.unwrap();
        assert!(result.authenticated);
    }

    /// Any other rejection from the primary is final, even if the fallback
    /// would have accepted.
    #[tokio::test]
    async fn non_bad_auth_rejection_does_not_fall_through() {
        let composite = CompositeAuthProvider::new(
            Arc::new(RejectWithNotAuthorized),
            Arc::new(AllowAllAuthProvider),
        );
        let connect = ConnectPacket::new(mqtt5_protocol::ConnectOptions::new("test"));
        let addr = "127.0.0.1:1234".parse().unwrap();

        let result = composite.authenticate(&connect, addr).await.unwrap();
        assert!(!result.authenticated);
        assert_eq!(result.reason_code, ReasonCode::NotAuthorized);
    }

    /// In `Or` mode a publish denied by the primary is allowed by the fallback.
    #[tokio::test]
    async fn authorize_publish_or_mode_falls_through() {
        let composite = CompositeAuthProvider::new(
            Arc::new(RejectWithBadAuthMethod),
            Arc::new(AllowAllAuthProvider),
        )
        .with_authorization_mode(AuthorizationMode::Or);

        let result = composite
            .authorize_publish("client", Some("user"), "topic")
            .await;
        assert!(result);
    }

    /// In `Or` mode a subscribe denied by the primary is allowed by the
    /// fallback.
    #[tokio::test]
    async fn authorize_subscribe_or_mode_falls_through() {
        let composite = CompositeAuthProvider::new(
            Arc::new(RejectWithBadAuthMethod),
            Arc::new(AllowAllAuthProvider),
        )
        .with_authorization_mode(AuthorizationMode::Or);

        let result = composite
            .authorize_subscribe("client", Some("user"), "topic/#")
            .await;
        assert!(result);
    }

    /// In the default mode a publish denied by the primary stays denied.
    #[tokio::test]
    async fn authorize_publish_primary_only_ignores_fallback() {
        let composite = CompositeAuthProvider::new(
            Arc::new(RejectWithBadAuthMethod),
            Arc::new(AllowAllAuthProvider),
        );

        let result = composite
            .authorize_publish("client", Some("user"), "topic")
            .await;
        assert!(!result);
    }

    /// In the default mode a subscribe denied by the primary stays denied.
    #[tokio::test]
    async fn authorize_subscribe_primary_only_ignores_fallback() {
        let composite = CompositeAuthProvider::new(
            Arc::new(RejectWithBadAuthMethod),
            Arc::new(AllowAllAuthProvider),
        );

        let result = composite
            .authorize_subscribe("client", Some("user"), "topic/#")
            .await;
        assert!(!result);
    }

    /// In `And` mode a publish is allowed only when both providers allow it;
    /// a denial from either side is enough to reject.
    #[tokio::test]
    async fn authorize_publish_and_mode_requires_both() {
        let both_allow = CompositeAuthProvider::new(
            Arc::new(AllowAllAuthProvider),
            Arc::new(AllowAllAuthProvider),
        )
        .with_authorization_mode(AuthorizationMode::And);
        assert!(
            both_allow
                .authorize_publish("client", Some("user"), "topic")
                .await
        );

        let primary_rejects = CompositeAuthProvider::new(
            Arc::new(RejectWithBadAuthMethod),
            Arc::new(AllowAllAuthProvider),
        )
        .with_authorization_mode(AuthorizationMode::And);
        assert!(
            !primary_rejects
                .authorize_publish("client", Some("user"), "topic")
                .await
        );

        let fallback_rejects = CompositeAuthProvider::new(
            Arc::new(AllowAllAuthProvider),
            Arc::new(RejectWithNotAuthorized),
        )
        .with_authorization_mode(AuthorizationMode::And);
        assert!(
            !fallback_rejects
                .authorize_publish("client", Some("user"), "topic")
                .await
        );
    }

    /// In `And` mode a subscribe is allowed only when both providers allow
    /// it; a denial from either side is enough to reject.
    #[tokio::test]
    async fn authorize_subscribe_and_mode_requires_both() {
        let both_allow = CompositeAuthProvider::new(
            Arc::new(AllowAllAuthProvider),
            Arc::new(AllowAllAuthProvider),
        )
        .with_authorization_mode(AuthorizationMode::And);
        assert!(
            both_allow
                .authorize_subscribe("client", Some("user"), "topic/#")
                .await
        );

        let primary_rejects = CompositeAuthProvider::new(
            Arc::new(RejectWithBadAuthMethod),
            Arc::new(AllowAllAuthProvider),
        )
        .with_authorization_mode(AuthorizationMode::And);
        assert!(
            !primary_rejects
                .authorize_subscribe("client", Some("user"), "topic/#")
                .await
        );

        let fallback_rejects = CompositeAuthProvider::new(
            Arc::new(AllowAllAuthProvider),
            Arc::new(RejectWithNotAuthorized),
        )
        .with_authorization_mode(AuthorizationMode::And);
        assert!(
            !fallback_rejects
                .authorize_subscribe("client", Some("user"), "topic/#")
                .await
        );
    }

    /// Session cleanup is forwarded exactly once to each wrapped provider.
    #[tokio::test]
    async fn cleanup_calls_both() {
        let primary = Arc::new(CleanupCounter::default());
        let fallback = Arc::new(CleanupCounter::default());
        // Method-call `clone()` keeps the concrete `Arc<CleanupCounter>` type
        // so it can coerce to `Arc<dyn AuthProvider>` at the call site.
        let composite = CompositeAuthProvider::new(primary.clone(), fallback.clone());

        composite.cleanup_session("user").await;

        assert_eq!(primary.cleanups.load(Ordering::SeqCst), 1);
        assert_eq!(fallback.cleanups.load(Ordering::SeqCst), 1);
    }

    /// The composite reports the primary's enhanced-auth capability.
    #[test]
    fn supports_enhanced_auth_delegates_to_primary() {
        let composite = CompositeAuthProvider::new(
            Arc::new(AllowAllAuthProvider),
            Arc::new(AllowAllAuthProvider),
        );
        assert!(!composite.supports_enhanced_auth());
    }

    /// A freshly constructed composite uses `AuthorizationMode::PrimaryOnly`.
    #[test]
    fn default_authorization_mode_is_primary_only() {
        let composite = CompositeAuthProvider::new(
            Arc::new(AllowAllAuthProvider),
            Arc::new(AllowAllAuthProvider),
        );
        assert_eq!(composite.authorization_mode, AuthorizationMode::PrimaryOnly);
    }
}