use bytes::{Buf, BufMut, BytesMut};
use crate::{
Metadata, ZmqError,
security::{
framer::{ISecureFramer, NullFramer},
mechanism::ProcessTokenAction,
},
};
use super::{IDataCipher, Mechanism, MechanismStatus, cipher::PassThroughDataCipher};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PlainState {
Initializing, ClientSendHello, ClientExpectWelcome, ServerExpectHello, ServerSendWelcome, Ready, Error, }
#[derive(Debug)]
pub struct PlainMechanism {
is_server: bool,
state: PlainState,
username: Option<Vec<u8>>,
password: Option<Vec<u8>>,
expected_username: Option<Vec<u8>>,
expected_password: Option<Vec<u8>>,
_zap_metadata: Option<Metadata>,
error_reason: Option<String>,
}
impl PlainMechanism {
pub const NAME_BYTES: &'static [u8; 20] = b"PLAIN\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"; pub const NAME: &'static str = "PLAIN";
const CMD_HELLO: &'static [u8] = b"HELLO";
const CMD_WELCOME: &'static [u8] = b"WELCOME";
const CMD_ERROR: &'static [u8] = b"ERROR";
pub fn new(is_server: bool) -> Self {
Self {
is_server,
state: if is_server {
PlainState::ServerExpectHello
} else {
PlainState::ClientSendHello
},
username: None,
password: None,
expected_username: None,
expected_password: None,
_zap_metadata: None,
error_reason: None,
}
}
pub fn set_client_credentials(&mut self, username: Option<Vec<u8>>, password: Option<Vec<u8>>) {
if !self.is_server {
self.username = username;
self.password = password;
} else {
tracing::warn!("set_client_credentials called on a server-side PLAIN mechanism. Ignoring.");
}
}
pub fn set_server_expected_credentials(
&mut self,
username: Option<Vec<u8>>,
password: Option<Vec<u8>>,
) {
if self.is_server {
self.expected_username = username;
self.expected_password = password;
}
}
fn parse_hello_body(body: &[u8]) -> Result<(Vec<u8>, Vec<u8>), ZmqError> {
let mut cursor = std::io::Cursor::new(body);
if body.len() < 2 {
return Err(ZmqError::SecurityError("PLAIN HELLO body too short".into()));
}
let user_len = cursor.get_u8() as usize;
if cursor.remaining() < user_len + 1 {
return Err(ZmqError::SecurityError(
"Invalid PLAIN HELLO username length".into(),
));
}
let mut username = vec![0u8; user_len];
cursor.copy_to_slice(&mut username);
let pass_len = cursor.get_u8() as usize;
if cursor.remaining() < pass_len {
return Err(ZmqError::SecurityError(
"Invalid PLAIN HELLO password length".into(),
));
}
let mut password = vec![0u8; pass_len];
cursor.copy_to_slice(&mut password);
Ok((username, password))
}
fn create_hello_body(username: &[u8], password: &[u8]) -> Vec<u8> {
let user_len = username.len().min(255) as u8;
let pass_len = password.len().min(255) as u8;
let mut body = BytesMut::with_capacity(1 + user_len as usize + 1 + pass_len as usize);
body.put_u8(user_len);
body.put_slice(&username[..user_len as usize]);
body.put_u8(pass_len);
body.put_slice(&password[..pass_len as usize]);
body.to_vec()
}
fn set_error_internal(&mut self, reason: String) {
tracing::error!(mechanism = Self::NAME, %reason, "Handshake error");
self.error_reason = Some(reason);
self.state = PlainState::Error;
}
}
impl Mechanism for PlainMechanism {
fn name(&self) -> &'static str {
Self::NAME
}
fn process_token(&mut self, token: &[u8]) -> Result<ProcessTokenAction, ZmqError> {
if token.is_empty() {
self.set_error_internal("Received empty security token".into());
return Err(ZmqError::SecurityError("Empty token".into()));
}
let command_len = token[0] as usize;
if token.len() < 1 + command_len {
self.set_error_internal("Invalid command frame received".into());
return Err(ZmqError::SecurityError("Malformed command frame".into()));
}
let command_name = &token[1..1 + command_len];
let body = &token[1 + command_len..];
if self.is_server {
match self.state {
PlainState::ServerExpectHello => {
if command_name == Self::CMD_HELLO {
tracing::debug!(mechanism = Self::NAME, "Server received HELLO");
match Self::parse_hello_body(body) {
Ok((username, password)) => {
let is_valid = self
.expected_username
.as_ref()
.map_or(false, |u| u == &username)
&& self
.expected_password
.as_ref()
.map_or(false, |p| p == &password);
if is_valid {
self.username = Some(username);
self.password = Some(password);
tracing::debug!("PLAIN Server: Credentials verified. Accepting connection.");
self.state = PlainState::ServerSendWelcome;
Ok(ProcessTokenAction::ProduceAndSend)
} else {
let msg = "Invalid username or password";
tracing::warn!("PLAIN Server: Authentication failed. {}", msg);
self.set_error_internal(msg.into());
Err(ZmqError::AuthenticationFailure(msg.into()))
}
}
Err(e) => {
self.set_error_internal(format!("Failed to parse HELLO: {}", e));
Err(e)
}
}
} else {
self.set_error_internal(format!(
"Expected HELLO, got {}",
String::from_utf8_lossy(command_name)
));
Err(ZmqError::SecurityError("Unexpected command".into()))
}
}
_ => {
self.set_error_internal(format!(
"Unexpected command received by server in state {:?}: {}",
self.state,
String::from_utf8_lossy(command_name)
));
Err(ZmqError::SecurityError(
"Unexpected command in server state".into(),
))
}
}
} else {
match self.state {
PlainState::ClientExpectWelcome => {
if command_name == Self::CMD_WELCOME {
tracing::debug!(mechanism = Self::NAME, "Client received WELCOME");
self.state = PlainState::Ready;
Ok(ProcessTokenAction::HandshakeComplete)
} else if command_name == Self::CMD_ERROR {
let reason = String::from_utf8_lossy(body).to_string();
self.set_error_internal(format!("Received ERROR from server: {}", reason));
Err(ZmqError::AuthenticationFailure(reason))
} else {
self.set_error_internal(format!(
"Expected WELCOME or ERROR, got {}",
String::from_utf8_lossy(command_name)
));
Err(ZmqError::SecurityError(
"Unexpected command from server".into(),
))
}
}
_ => {
self.set_error_internal(format!(
"Unexpected command received by client in state {:?}: {}",
self.state,
String::from_utf8_lossy(command_name)
));
Err(ZmqError::SecurityError(
"Unexpected command in client state".into(),
))
}
}
}
}
fn produce_token(&mut self) -> Result<Option<Vec<u8>>, ZmqError> {
match self.state {
PlainState::ClientSendHello => {
let username_bytes = self.username.as_deref().unwrap_or(b"");
let password_bytes = self.password.as_deref().unwrap_or(b"");
let body = Self::create_hello_body(username_bytes, password_bytes);
let mut frame = BytesMut::new();
let cmd_name = Self::CMD_HELLO;
frame.put_u8(cmd_name.len() as u8);
frame.put_slice(cmd_name);
frame.put_slice(&body);
self.state = PlainState::ClientExpectWelcome;
tracing::debug!(mechanism = Self::NAME, "Client sending HELLO");
Ok(Some(frame.to_vec()))
}
PlainState::ServerSendWelcome => {
let mut frame = BytesMut::new();
let cmd_name = Self::CMD_WELCOME;
frame.put_u8(cmd_name.len() as u8);
frame.put_slice(cmd_name);
self.state = PlainState::Ready;
tracing::debug!(mechanism = Self::NAME, "Server sending WELCOME");
Ok(Some(frame.to_vec()))
}
_ => Ok(None), }
}
fn status(&self) -> MechanismStatus {
match self.state {
PlainState::Initializing | PlainState::ClientSendHello
| PlainState::ClientExpectWelcome
| PlainState::ServerExpectHello
| PlainState::ServerSendWelcome => MechanismStatus::Handshaking,
PlainState::Ready => MechanismStatus::Ready,
PlainState::Error => MechanismStatus::Error,
}
}
fn peer_identity(&self) -> Option<Vec<u8>> {
self.username.clone()
}
fn metadata(&self) -> Option<Metadata> {
self._zap_metadata.clone()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn set_error(&mut self, reason: String) {
self.set_error_internal(reason);
}
fn error_reason(&self) -> Option<&str> {
self.error_reason.as_deref()
}
fn zap_request_needed(&mut self) -> Option<Vec<Vec<u8>>> {
None
}
fn process_zap_reply(&mut self, _reply_frames: &[Vec<u8>]) -> Result<(), ZmqError> {
Ok(())
}
fn into_framer(self: Box<Self>, max_msg_size: i64) -> Result<(Box<dyn ISecureFramer>, Option<Vec<u8>>), ZmqError> {
if self.status() != MechanismStatus::Ready {
return Err(ZmqError::InvalidState(
"PLAIN handshake not complete.".into(),
));
}
Ok((Box::new(NullFramer::new(max_msg_size)), self.username.clone()))
}
}