mqtt5 0.31.2

Complete MQTT v5.0 platform with high-performance async client and full-featured broker supporting TCP, TLS, WebSocket, authentication, bridging, and resource monitoring
Documentation
use crate::client::auth_handler::{AuthFuture, AuthHandler, AuthResponse};
use base64::prelude::*;
use ring::{hmac, pbkdf2};
use sha2::{Digest, Sha256};
use std::num::NonZeroU32;
use std::sync::Mutex;
use tracing::debug;
use zeroize::Zeroizing;

enum ScramClientState {
    Initial,
    WaitingForServerFirst {
        client_nonce: String,
        client_first_bare: String,
    },
    WaitingForServerFinal {
        expected_server_signature: [u8; 32],
    },
}

pub struct ScramSha256AuthHandler {
    username: String,
    password: Zeroizing<String>,
    state: Mutex<ScramClientState>,
}

impl ScramSha256AuthHandler {
    #[must_use]
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: Zeroizing::new(password.into()),
            state: Mutex::new(ScramClientState::Initial),
        }
    }

    fn generate_nonce() -> String {
        let mut bytes = [0u8; 24];
        getrandom::fill(&mut bytes).expect("failed to generate random nonce");
        BASE64_STANDARD.encode(bytes)
    }
}

impl AuthHandler for ScramSha256AuthHandler {
    fn initial_response<'a>(&'a self, _auth_method: &'a str) -> AuthFuture<'a, Option<Vec<u8>>> {
        let username = self.username.clone();

        Box::pin(async move {
            let client_nonce = Self::generate_nonce();
            let client_first_bare = format!("n={username},r={client_nonce}");
            let client_first = format!("n,,{client_first_bare}");

            *self.state.lock().unwrap() = ScramClientState::WaitingForServerFirst {
                client_nonce,
                client_first_bare,
            };

            debug!("SCRAM client: sending client-first message");
            Ok(Some(client_first.into_bytes()))
        })
    }

    fn handle_challenge<'a>(
        &'a self,
        _auth_method: &'a str,
        challenge_data: Option<&'a [u8]>,
    ) -> AuthFuture<'a, AuthResponse> {
        let password = self.password.clone();

        Box::pin(async move {
            let Some(data) = challenge_data else {
                return Ok(AuthResponse::Abort(
                    "No challenge data received".to_string(),
                ));
            };

            let server_message = String::from_utf8_lossy(data);

            let state =
                std::mem::replace(&mut *self.state.lock().unwrap(), ScramClientState::Initial);

            match state {
                ScramClientState::WaitingForServerFirst {
                    client_nonce,
                    client_first_bare,
                } => {
                    let attrs = parse_scram_attributes(&server_message);

                    let Some(combined_nonce) = attrs.get("r") else {
                        return Ok(AuthResponse::Abort("Missing nonce in server-first".into()));
                    };

                    if !combined_nonce.starts_with(&client_nonce) {
                        return Ok(AuthResponse::Abort(
                            "Server nonce doesn't contain client nonce".into(),
                        ));
                    }

                    let Some(salt_b64) = attrs.get("s") else {
                        return Ok(AuthResponse::Abort("Missing salt in server-first".into()));
                    };

                    let Ok(salt) = BASE64_STANDARD.decode(salt_b64) else {
                        return Ok(AuthResponse::Abort("Invalid salt encoding".into()));
                    };

                    let Some(iterations_str) = attrs.get("i") else {
                        return Ok(AuthResponse::Abort(
                            "Missing iteration count in server-first".into(),
                        ));
                    };

                    let Ok(iterations) = iterations_str.parse::<u32>() else {
                        return Ok(AuthResponse::Abort("Invalid iteration count".into()));
                    };

                    if iterations < 4096 {
                        return Ok(AuthResponse::Abort("Iteration count too low".into()));
                    }

                    let salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
                    let client_key = hmac_sha256(&salted_password, b"Client Key");
                    let stored_key = sha256(&client_key);
                    let server_key = hmac_sha256(&salted_password, b"Server Key");

                    let client_final_without_proof = format!("c=biws,r={combined_nonce}");
                    let auth_message = format!(
                        "{client_first_bare},{server_message},{client_final_without_proof}"
                    );

                    let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
                    let client_proof: Vec<u8> = client_key
                        .iter()
                        .zip(client_signature.iter())
                        .map(|(a, b)| a ^ b)
                        .collect();

                    let expected_server_signature =
                        hmac_sha256(&server_key, auth_message.as_bytes());

                    *self.state.lock().unwrap() = ScramClientState::WaitingForServerFinal {
                        expected_server_signature,
                    };

                    let proof_b64 = BASE64_STANDARD.encode(&client_proof);
                    let client_final = format!("{client_final_without_proof},p={proof_b64}");

                    debug!("SCRAM client: sending client-final message");
                    Ok(AuthResponse::Continue(client_final.into_bytes()))
                }
                ScramClientState::WaitingForServerFinal {
                    expected_server_signature,
                } => {
                    let attrs = parse_scram_attributes(&server_message);

                    let Some(verifier_b64) = attrs.get("v") else {
                        return Ok(AuthResponse::Abort(
                            "Missing verifier in server-final".into(),
                        ));
                    };

                    let Ok(received_signature) = BASE64_STANDARD.decode(verifier_b64) else {
                        return Ok(AuthResponse::Abort("Invalid verifier encoding".into()));
                    };

                    if received_signature.len() != 32 {
                        return Ok(AuthResponse::Abort("Invalid verifier length".into()));
                    }

                    if !constant_time_compare(&received_signature, &expected_server_signature) {
                        return Ok(AuthResponse::Abort(
                            "Server signature verification failed".into(),
                        ));
                    }

                    debug!("SCRAM client: server signature verified, authentication complete");
                    Ok(AuthResponse::Success)
                }
                ScramClientState::Initial => {
                    Ok(AuthResponse::Abort("Invalid state: Initial".into()))
                }
            }
        })
    }
}

fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut result = 0u8;
    for (x, y) in a.iter().zip(b.iter()) {
        result |= x ^ y;
    }
    result == 0
}

fn parse_scram_attributes(message: &str) -> std::collections::HashMap<String, String> {
    let mut attrs = std::collections::HashMap::new();
    for part in message.split(',') {
        if let Some((key, value)) = part.split_once('=') {
            attrs.insert(key.to_string(), value.to_string());
        }
    }
    attrs
}

fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
    let mut out = [0u8; 32];
    pbkdf2::derive(
        pbkdf2::PBKDF2_HMAC_SHA256,
        NonZeroU32::new(iterations).expect("iterations must be non-zero"),
        salt,
        password,
        &mut out,
    );
    out
}

fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
    let key = hmac::Key::new(hmac::HMAC_SHA256, key);
    let tag = hmac::sign(&key, data);
    let mut result = [0u8; 32];
    result.copy_from_slice(tag.as_ref());
    result
}

fn sha256(data: &[u8]) -> [u8; 32] {
    let mut hasher = Sha256::new();
    hasher.update(data);
    hasher.finalize().into()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_scram_initial_response() {
        let handler = ScramSha256AuthHandler::new("testuser", "testpass");
        let response = handler.initial_response("SCRAM-SHA-256").await.unwrap();

        assert!(response.is_some());
        let msg = String::from_utf8_lossy(response.as_ref().unwrap());
        assert!(msg.starts_with("n,,n=testuser,r="));
    }

    #[tokio::test]
    async fn test_scram_handle_server_first() {
        let handler = ScramSha256AuthHandler::new("testuser", "testpass");

        let _ = handler.initial_response("SCRAM-SHA-256").await.unwrap();

        let client_nonce = match &*handler.state.lock().unwrap() {
            ScramClientState::WaitingForServerFirst { client_nonce, .. } => client_nonce.clone(),
            _ => panic!("Wrong state"),
        };

        let server_first = format!(
            "r={client_nonce}servernonce,s={},i=4096",
            BASE64_STANDARD.encode(b"testsalt")
        );

        let response = handler
            .handle_challenge("SCRAM-SHA-256", Some(server_first.as_bytes()))
            .await
            .unwrap();

        match response {
            AuthResponse::Continue(data) => {
                let msg = String::from_utf8_lossy(&data);
                assert!(msg.starts_with("c=biws,r="));
                assert!(msg.contains(",p="));
            }
            other => panic!("Expected Continue, got {other:?}",),
        }
    }

    #[tokio::test]
    async fn test_scram_handle_server_final_valid() {
        let handler = ScramSha256AuthHandler::new("testuser", "testpass");

        let expected_sig = [0u8; 32];
        *handler.state.lock().unwrap() = ScramClientState::WaitingForServerFinal {
            expected_server_signature: expected_sig,
        };

        let server_final = format!("v={}", BASE64_STANDARD.encode(expected_sig));
        let response = handler
            .handle_challenge("SCRAM-SHA-256", Some(server_final.as_bytes()))
            .await
            .unwrap();

        assert!(matches!(response, AuthResponse::Success));
    }

    #[tokio::test]
    async fn test_scram_handle_server_final_invalid() {
        let handler = ScramSha256AuthHandler::new("testuser", "testpass");

        let expected_sig = [0u8; 32];
        let wrong_sig = [1u8; 32];
        *handler.state.lock().unwrap() = ScramClientState::WaitingForServerFinal {
            expected_server_signature: expected_sig,
        };

        let server_final = format!("v={}", BASE64_STANDARD.encode(wrong_sig));
        let response = handler
            .handle_challenge("SCRAM-SHA-256", Some(server_final.as_bytes()))
            .await
            .unwrap();

        match response {
            AuthResponse::Abort(msg) => {
                assert!(msg.contains("Server signature verification failed"));
            }
            other => panic!("Expected Abort, got {other:?}",),
        }
    }

    #[tokio::test]
    async fn test_scram_nonce_mismatch() {
        let handler = ScramSha256AuthHandler::new("testuser", "testpass");

        let _ = handler.initial_response("SCRAM-SHA-256").await.unwrap();

        let server_first = format!(
            "r=differentnonce,s={},i=4096",
            BASE64_STANDARD.encode(b"testsalt")
        );

        let response = handler
            .handle_challenge("SCRAM-SHA-256", Some(server_first.as_bytes()))
            .await
            .unwrap();

        match response {
            AuthResponse::Abort(msg) => {
                assert!(msg.contains("nonce"));
            }
            other => panic!("Expected Abort, got {other:?}",),
        }
    }
}