use mqtt5::broker::auth::{AuthProvider, AuthResult, EnhancedAuthResult};
use mqtt5::broker::config::{StorageBackend, StorageConfig};
use mqtt5::broker::{BrokerConfig, MqttBroker};
use mqtt5::client::{AuthHandler, AuthResponse};
use mqtt5::error::{MqttError, Result};
use mqtt5::packet::auth::AuthPacket;
use mqtt5::packet::connect::ConnectPacket;
use mqtt5::packet::MqttPacket;
use mqtt5::protocol::v5::reason_codes::ReasonCode;
use mqtt5::types::ConnectOptions;
use mqtt5::MqttClient;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
fn test_broker_config() -> BrokerConfig {
let storage_config = StorageConfig {
backend: StorageBackend::Memory,
enable_persistence: true,
..Default::default()
};
BrokerConfig::default()
.with_bind_address("127.0.0.1:0".parse::<SocketAddr>().unwrap())
.with_storage(storage_config)
}
#[tokio::test]
async fn test_auth_packet_creation() {
let auth_packet = AuthPacket::continue_authentication(
"SCRAM-SHA-256".to_string(),
Some(b"client_nonce_data".to_vec()),
)
.unwrap();
assert_eq!(auth_packet.reason_code, ReasonCode::ContinueAuthentication);
assert_eq!(auth_packet.authentication_method(), Some("SCRAM-SHA-256"));
assert_eq!(
auth_packet.authentication_data(),
Some(b"client_nonce_data".as_ref())
);
}
#[tokio::test]
async fn test_auth_packet_re_authenticate() {
let auth_packet =
AuthPacket::re_authenticate("OAUTH2".to_string(), Some(b"refresh_token".to_vec())).unwrap();
assert_eq!(auth_packet.reason_code, ReasonCode::ReAuthenticate);
assert_eq!(auth_packet.authentication_method(), Some("OAUTH2"));
assert_eq!(
auth_packet.authentication_data(),
Some(b"refresh_token".as_ref())
);
}
#[tokio::test]
async fn test_auth_packet_success() {
let auth_packet = AuthPacket::success("PLAIN".to_string()).unwrap();
assert_eq!(auth_packet.reason_code, ReasonCode::Success);
assert_eq!(auth_packet.authentication_method(), Some("PLAIN"));
assert!(auth_packet.authentication_data().is_none());
}
#[tokio::test]
async fn test_auth_packet_failure() {
let auth_packet = AuthPacket::failure(
ReasonCode::BadAuthenticationMethod,
Some("Unsupported authentication method".to_string()),
)
.unwrap();
assert_eq!(auth_packet.reason_code, ReasonCode::BadAuthenticationMethod);
assert_eq!(
auth_packet.reason_string(),
Some("Unsupported authentication method")
);
assert!(auth_packet.authentication_method().is_none());
}
#[tokio::test]
async fn test_auth_packet_validation() {
let invalid_packet = AuthPacket::new(ReasonCode::ContinueAuthentication);
let result = invalid_packet.validate();
assert!(result.is_err());
let invalid_packet = AuthPacket::new(ReasonCode::ReAuthenticate);
let result = invalid_packet.validate();
assert!(result.is_err());
let valid_packet = AuthPacket::new(ReasonCode::Success);
let result = valid_packet.validate();
assert!(result.is_ok());
}
#[tokio::test]
async fn test_auth_packet_failure_with_success_code() {
let result = AuthPacket::failure(ReasonCode::Success, None);
assert!(result.is_err());
if let Err(MqttError::ProtocolError(msg)) = result {
assert!(msg.contains("Cannot create failure AUTH packet with success reason code"));
}
}
#[tokio::test]
async fn test_auth_packet_encode_decode_cycle() {
use bytes::BytesMut;
use mqtt5::packet::FixedHeader;
let original_packet = AuthPacket::continue_authentication(
"DIGEST-MD5".to_string(),
Some(b"challenge_response_data".to_vec()),
)
.unwrap();
let mut buf = BytesMut::new();
original_packet.encode(&mut buf).unwrap();
let fixed_header = FixedHeader::decode(&mut buf).unwrap();
let decoded_packet = <AuthPacket as MqttPacket>::decode_body(&mut buf, &fixed_header).unwrap();
assert_eq!(decoded_packet.reason_code, original_packet.reason_code);
assert_eq!(
decoded_packet.authentication_method(),
original_packet.authentication_method()
);
assert_eq!(
decoded_packet.authentication_data(),
original_packet.authentication_data()
);
}
#[tokio::test]
async fn test_auth_packet_properties_conversion() {
use mqtt5::protocol::v5::properties::{PropertyId, PropertyValue};
use mqtt5::types::WillProperties;
let mut will_props = WillProperties {
will_delay_interval: Some(30),
message_expiry_interval: Some(3600),
content_type: Some("application/json".to_string()),
..WillProperties::default()
};
will_props
.user_properties
.push(("client".to_string(), "v1.0".to_string()));
let protocol_props: mqtt5::protocol::v5::properties::Properties = will_props.into();
assert_eq!(
protocol_props.get(PropertyId::WillDelayInterval),
Some(&PropertyValue::FourByteInteger(30))
);
assert_eq!(
protocol_props.get(PropertyId::MessageExpiryInterval),
Some(&PropertyValue::FourByteInteger(3600))
);
assert_eq!(
protocol_props.get(PropertyId::ContentType),
Some(&PropertyValue::Utf8String("application/json".to_string()))
);
}
#[tokio::test]
async fn test_auth_methods_supported() {
let methods = vec![
"PLAIN",
"SCRAM-SHA-1",
"SCRAM-SHA-256",
"DIGEST-MD5",
"GSSAPI",
"OAUTH2",
"JWT",
"CUSTOM-METHOD",
];
for method in methods {
let auth_packet =
AuthPacket::continue_authentication(method.to_string(), Some(b"test_data".to_vec()))
.unwrap();
assert_eq!(auth_packet.authentication_method(), Some(method));
assert!(auth_packet.validate().is_ok());
}
}
#[tokio::test]
async fn test_auth_packet_no_data() {
let auth_packet = AuthPacket::continue_authentication(
"SCRAM-SHA-256".to_string(),
None, )
.unwrap();
assert_eq!(auth_packet.reason_code, ReasonCode::ContinueAuthentication);
assert_eq!(auth_packet.authentication_method(), Some("SCRAM-SHA-256"));
assert!(auth_packet.authentication_data().is_none());
}
#[tokio::test]
async fn test_auth_packet_large_data() {
let large_data = vec![0xAB; 10000];
let auth_packet =
AuthPacket::continue_authentication("CUSTOM".to_string(), Some(large_data.clone()))
.unwrap();
assert_eq!(
auth_packet.authentication_data(),
Some(large_data.as_slice())
);
}
struct TestChallengeResponseAuthProvider {
challenge: Vec<u8>,
expected_response: Vec<u8>,
}
impl AuthProvider for TestChallengeResponseAuthProvider {
fn authenticate<'a>(
&'a self,
_connect: &'a ConnectPacket,
_client_addr: SocketAddr,
) -> Pin<Box<dyn Future<Output = Result<AuthResult>> + Send + 'a>> {
Box::pin(async move { Ok(AuthResult::success()) })
}
fn authorize_publish<'a>(
&'a self,
_client_id: &str,
_user_id: Option<&'a str>,
_topic: &'a str,
) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
Box::pin(async move { true })
}
fn authorize_subscribe<'a>(
&'a self,
_client_id: &str,
_user_id: Option<&'a str>,
_topic_filter: &'a str,
) -> Pin<Box<dyn Future<Output = bool> + Send + 'a>> {
Box::pin(async move { true })
}
fn supports_enhanced_auth(&self) -> bool {
true
}
fn authenticate_enhanced<'a>(
&'a self,
auth_method: &'a str,
auth_data: Option<&'a [u8]>,
_client_id: &'a str,
) -> Pin<Box<dyn Future<Output = Result<EnhancedAuthResult>> + Send + 'a>> {
let method = auth_method.to_string();
let challenge = self.challenge.clone();
let expected = self.expected_response.clone();
Box::pin(async move {
if method != "CHALLENGE-RESPONSE" {
return Ok(EnhancedAuthResult::fail(
method,
ReasonCode::BadAuthenticationMethod,
));
}
match auth_data {
None => Ok(EnhancedAuthResult::continue_auth(method, Some(challenge))),
Some(response) if response == expected => Ok(EnhancedAuthResult::success(method)),
Some(_) => Ok(EnhancedAuthResult::fail(method, ReasonCode::NotAuthorized)),
}
})
}
fn reauthenticate<'a>(
&'a self,
auth_method: &'a str,
auth_data: Option<&'a [u8]>,
client_id: &'a str,
_user_id: Option<&'a str>,
) -> Pin<Box<dyn Future<Output = Result<EnhancedAuthResult>> + Send + 'a>> {
self.authenticate_enhanced(auth_method, auth_data, client_id)
}
}
struct TestClientAuthHandler {
expected_challenge: Vec<u8>,
response: Vec<u8>,
}
impl AuthHandler for TestClientAuthHandler {
fn handle_challenge<'a>(
&'a self,
_auth_method: &'a str,
challenge_data: Option<&'a [u8]>,
) -> Pin<Box<dyn Future<Output = Result<AuthResponse>> + Send + 'a>> {
let expected = self.expected_challenge.clone();
let response = self.response.clone();
Box::pin(async move {
if challenge_data == Some(expected.as_slice()) {
Ok(AuthResponse::Continue(response))
} else {
Ok(AuthResponse::Abort("Unexpected challenge".to_string()))
}
})
}
}
#[tokio::test]
async fn test_client_enhanced_auth_success() {
let challenge = b"server-challenge-xyz".to_vec();
let response = b"client-response-abc".to_vec();
let auth_provider = Arc::new(TestChallengeResponseAuthProvider {
challenge: challenge.clone(),
expected_response: response.clone(),
});
let mut broker = MqttBroker::with_config(test_broker_config())
.await
.unwrap()
.with_auth_provider(auth_provider);
let addr = broker.local_addr().expect("failed to get broker address");
let broker_handle = tokio::spawn(async move { broker.run().await });
tokio::time::sleep(Duration::from_millis(100)).await;
let options =
ConnectOptions::new("auth-test-client").with_authentication_method("CHALLENGE-RESPONSE");
let client = MqttClient::with_options(options);
client
.set_auth_handler(TestClientAuthHandler {
expected_challenge: challenge,
response,
})
.await;
let result = client.connect(&format!("mqtt://{addr}")).await;
assert!(
result.is_ok(),
"Client should connect with enhanced auth: {result:?}"
);
assert!(client.is_connected().await);
client.disconnect().await.unwrap();
broker_handle.abort();
}
#[tokio::test]
async fn test_client_enhanced_auth_failure() {
let challenge = b"server-challenge-xyz".to_vec();
let correct_response = b"client-response-abc".to_vec();
let wrong_response = b"wrong-response".to_vec();
let auth_provider = Arc::new(TestChallengeResponseAuthProvider {
challenge: challenge.clone(),
expected_response: correct_response,
});
let mut broker = MqttBroker::with_config(test_broker_config())
.await
.unwrap()
.with_auth_provider(auth_provider);
let addr = broker.local_addr().expect("failed to get broker address");
let broker_handle = tokio::spawn(async move { broker.run().await });
tokio::time::sleep(Duration::from_millis(100)).await;
let options =
ConnectOptions::new("auth-fail-client").with_authentication_method("CHALLENGE-RESPONSE");
let client = MqttClient::with_options(options);
client
.set_auth_handler(TestClientAuthHandler {
expected_challenge: challenge,
response: wrong_response,
})
.await;
let result = client.connect(&format!("mqtt://{addr}")).await;
assert!(
result.is_err(),
"Client should fail with wrong response, but got: {result:?}"
);
broker_handle.abort();
}
#[tokio::test]
async fn test_client_enhanced_auth_no_handler() {
let auth_provider = Arc::new(TestChallengeResponseAuthProvider {
challenge: b"challenge".to_vec(),
expected_response: b"response".to_vec(),
});
let mut broker = MqttBroker::with_config(test_broker_config())
.await
.unwrap()
.with_auth_provider(auth_provider);
let addr = broker.local_addr().expect("failed to get broker address");
let broker_handle = tokio::spawn(async move { broker.run().await });
tokio::time::sleep(Duration::from_millis(100)).await;
let options =
ConnectOptions::new("no-handler-client").with_authentication_method("CHALLENGE-RESPONSE");
let client = MqttClient::with_options(options);
let result = client.connect(&format!("mqtt://{addr}")).await;
assert!(result.is_err(), "Client should fail without auth handler");
if let Err(MqttError::AuthenticationFailed) = result {
} else {
panic!("Expected AuthenticationFailed error, got {result:?}");
}
broker_handle.abort();
}