#![cfg(feature = "noise_xx")]
use super::cipher::IDataCipher;
use crate::error::ZmqError;
use crate::security::framer::{ISecureFramer, LengthPrefixedFramer};
use crate::security::mechanism::ProcessTokenAction;
use crate::security::{Mechanism, MechanismStatus, Metadata};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use snow::error::{Prerequisite, StateProblem};
use snow::params::NoiseParams;
use snow::{Error as SnowError, TransportState};
/// Noise `XX` handshake mechanism (`Noise_XX_25519_ChaChaPoly_BLAKE2s`) for one connection.
///
/// Drives the handshake through `snow`, checks the peer's static key, and hands the
/// resulting transport state to a data cipher once the mechanism is `Ready`.
pub struct NoiseXxMechanism {
    /// `true` when acting as the Noise responder (server side).
    is_server_role: bool,
    /// In-progress handshake; taken (set to `None`) when switching to transport mode.
    state: Option<snow::HandshakeState>,
    /// Post-handshake transport state; populated once the handshake completes.
    pub(crate) transport_state: Option<TransportState>,
    /// Server static public key a client is configured to check against.
    configured_remote_static_pk_bytes: Option<[u8; 32]>,
    /// Peer static public key learned during the handshake.
    verified_peer_static_pk: Option<Vec<u8>>,
    /// Lifecycle status reported through `Mechanism::status`.
    current_status: MechanismStatus,
    /// Human-readable reason recorded when the mechanism enters `Error`.
    error_reason_str: Option<String>,
    /// Handshake message produced by `process_token`, waiting to be returned by `produce_token`.
    pending_outgoing_handshake_msg: Option<Vec<u8>>,
}
impl std::fmt::Debug for NoiseXxMechanism {
    /// Summarises the mechanism's state.
    ///
    /// Key bytes and handshake message bytes are never printed; only flags and
    /// lengths are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let handshake_finished =
            matches!(&self.state, Some(hs) if hs.is_handshake_finished());
        let peer_pk_len = self.verified_peer_static_pk.as_ref().map(Vec::len);
        let pending_len = self.pending_outgoing_handshake_msg.as_ref().map(Vec::len);
        let has_transport = self.transport_state.is_some();

        let mut out = f.debug_struct("NoiseXxMechanism");
        out.field("is_server_role", &self.is_server_role);
        out.field("state_is_handshake_finished", &handshake_finished);
        out.field("transport_state_is_some", &has_transport);
        out.field("verified_peer_static_pk_len", &peer_pk_len);
        out.field("current_status", &self.current_status);
        out.field("error_reason_str", &self.error_reason_str);
        out.field("pending_outgoing_handshake_msg_len", &pending_len);
        out.finish()
    }
}
impl NoiseXxMechanism {
pub const NAME: &'static str = "NOISE_XX";
pub const NAME_BYTES: &'static [u8; 20] = b"NOISE_XX\0\0\0\0\0\0\0\0\0\0\0\0";
const NOISE_PARAMS_STR: &'static str = "Noise_XX_25519_ChaChaPoly_BLAKE2s";
pub fn new(
is_server: bool,
local_static_sk_bytes: &[u8; 32], initial_remote_static_pk_bytes: Option<[u8; 32]>,
) -> Result<Self, ZmqError> {
let params: NoiseParams = Self::NOISE_PARAMS_STR.parse().map_err(|e| {
ZmqError::Internal(format!(
"Failed to parse Noise params string '{}': {:?}",
Self::NOISE_PARAMS_STR,
e
))
})?;
let mut builder = snow::Builder::new(params);
builder = builder.local_private_key(local_static_sk_bytes);
if !is_server {
if let Some(ref pk_bytes) = initial_remote_static_pk_bytes {
builder = builder.remote_public_key(pk_bytes);
} else {
return Err(ZmqError::SecurityError(
"NOISE_XX Client: Server's static public key is required for configuration.".into(),
));
}
}
let noise_handshake_state = if is_server {
builder.build_responder()?
} else {
builder.build_initiator()?
};
Ok(Self {
is_server_role: is_server,
state: Some(noise_handshake_state),
transport_state: None,
configured_remote_static_pk_bytes: initial_remote_static_pk_bytes,
verified_peer_static_pk: None,
current_status: MechanismStatus::Initializing, error_reason_str: None,
pending_outgoing_handshake_msg: None,
})
}
fn transition_to_final_state(&mut self) -> Result<(), SnowError> {
let current_handshake_state = self
.state
.as_ref()
.ok_or(SnowError::State(StateProblem::HandshakeAlreadyFinished))?;
if !current_handshake_state.is_handshake_finished() {
tracing::warn!("NOISE_XX: transition_to_final_state called but handshake not finished.");
return Err(SnowError::State(StateProblem::HandshakeNotFinished));
}
let handshake_derived_peer_pk = match current_handshake_state.get_remote_static() {
Some(pk_bytes_slice) => pk_bytes_slice.to_vec(),
None => {
let err_msg =
"NOISE_XX handshake purportedly finished but no remote static key material was obtained."
.to_string();
tracing::error!("{}", err_msg);
self.current_status = MechanismStatus::Error;
self.error_reason_str = Some(err_msg);
return Err(SnowError::State(StateProblem::MissingKeyMaterial));
}
};
self.verified_peer_static_pk = Some(handshake_derived_peer_pk.clone());
if !self.is_server_role {
if let Some(expected_server_pk_array) = self.configured_remote_static_pk_bytes {
if handshake_derived_peer_pk.as_slice() != expected_server_pk_array.as_slice() {
tracing::error!(
mechanism = Self::NAME,
role = "Client",
"Server public key validation FAILED. Expected PK: {:?}, Actual PK from handshake: {:?}",
expected_server_pk_array.as_slice(),
handshake_derived_peer_pk.as_slice()
);
self.current_status = MechanismStatus::Error;
self.error_reason_str =
Some("NOISE_XX: Server public key mismatch. Connection rejected.".into());
return Err(SnowError::Decrypt);
}
tracing::debug!(
mechanism = Self::NAME,
role = "Client",
"Server static public key successfully verified."
);
} else {
let err_msg =
"NOISE_XX Client: Configuration error - missing expected remote server public key for verification."
.to_string();
tracing::error!("{}", err_msg);
self.current_status = MechanismStatus::Error;
self.error_reason_str = Some(err_msg);
return Err(SnowError::Prereq(Prerequisite::RemotePublicKey));
}
} else {
tracing::debug!(
mechanism = Self::NAME,
role = "Server",
"Learned and cryptographically verified client's static public key: {:?}",
handshake_derived_peer_pk.as_slice()
);
}
if !self.is_server_role {
if let Some(expected_server_pk_array) = self.configured_remote_static_pk_bytes {
if handshake_derived_peer_pk.as_slice() != expected_server_pk_array.as_slice() {
tracing::error!(
mechanism = Self::NAME,
role = "Client",
"Server public key validation FAILED during NOISE_XX handshake.
Expected PK (configured via socket option): {:?},
Actual PK received from peer during handshake: {:?}",
expected_server_pk_array.as_slice(),
handshake_derived_peer_pk.as_slice()
);
self.current_status = MechanismStatus::Error;
self.error_reason_str =
Some("NOISE_XX: Server public key mismatch. Connection rejected.".into());
return Err(SnowError::Decrypt);
}
tracing::debug!(
mechanism = Self::NAME,
role = "Client",
"Server static public key successfully verified against configuration."
);
} else {
let err_msg = "NOISE_XX Client: Configuration error - missing expected remote server public key for verification post-handshake.".to_string();
tracing::error!("{}", err_msg);
self.current_status = MechanismStatus::Error;
self.error_reason_str = Some(err_msg);
return Err(SnowError::Prereq(Prerequisite::RemotePublicKey)); }
} else {
tracing::debug!(
mechanism = Self::NAME,
role = "Server",
"Learned and cryptographically verified client's static public key: {:?}",
handshake_derived_peer_pk.as_slice()
);
}
let handshake_state_to_consume = self
.state
.take()
.expect("HandshakeState was None unexpectedly during transition");
match handshake_state_to_consume.into_transport_mode() {
Ok(ts) => {
self.transport_state = Some(ts); self.current_status = MechanismStatus::Ready;
tracing::info!(
mechanism = Self::NAME,
"Handshake successful. Transitioned to transport mode. Status: Ready."
);
Ok(())
}
Err(e) => {
let err_msg = format!(
"NOISE_XX: Failed to transition to transport mode after handshake: {:?}",
e
);
tracing::error!("{}", err_msg);
self.current_status = MechanismStatus::Error;
self.error_reason_str = Some(err_msg);
Err(e)
}
}
}
pub fn into_data_cipher(mut self) -> Result<(Box<dyn IDataCipher>, Option<Vec<u8>>), ZmqError> {
if self.current_status != MechanismStatus::Ready {
return Err(ZmqError::InvalidState(
"Noise handshake not complete, cannot create data cipher.".into(),
));
}
let ts = self.transport_state.take().ok_or_else(|| {
ZmqError::Internal("NoiseXxMechanism is Ready but TransportState is missing.".into())
})?;
let peer_id = self.verified_peer_static_pk.clone();
Ok((
Box::new(NoiseDataCipher {
transport_state: ts,
}),
peer_id,
))
}
}
#[async_trait::async_trait]
impl Mechanism for NoiseXxMechanism {
fn name(&self) -> &'static str {
Self::NAME
}
fn produce_token(&mut self) -> Result<Option<Vec<u8>>, ZmqError> {
if let Some(msg_to_send) = self.pending_outgoing_handshake_msg.take() {
tracing::debug!(
"NOISE_XX produce_token: Sending pending msg (len {}) from previous process_token.",
msg_to_send.len()
);
if !self.is_server_role {
if let Some(state) = self.state.as_ref() {
if state.is_handshake_finished() {
tracing::debug!(
"NOISE_XX produce_token (Initiator): Final message produced. Transitioning state (verifies peer PK)."
);
self.transition_to_final_state()?;
}
}
}
return Ok(Some(msg_to_send));
}
if self.current_status == MechanismStatus::Ready
|| self.current_status == MechanismStatus::Error
{
return Ok(None);
}
if self.current_status == MechanismStatus::Initializing {
self.current_status = MechanismStatus::Handshaking;
if self.is_server_role {
return Ok(None);
}
}
let handshake_state = self.state.as_mut().ok_or_else(|| {
ZmqError::InvalidState(
"Noise HandshakeState missing when trying to produce new token.".into(),
)
})?;
if handshake_state.is_my_turn() {
let mut msg_buf = vec![0u8; 1024];
let len = handshake_state.write_message(&[], &mut msg_buf)?;
msg_buf.truncate(len);
tracing::debug!(
"NOISE_XX produce_token (is_server={}): Generated NEW handshake message (len {}).",
self.is_server_role,
len
);
if handshake_state.is_handshake_finished() {
tracing::debug!(
"NOISE_XX produce_token: snow::HandshakeState became finished *after this write*. Transitioning state."
);
if self.is_server_role {
self.transition_to_final_state()?;
}
}
Ok(Some(msg_buf))
} else {
tracing::trace!(
"NOISE_XX produce_token: Not my turn and no pending message. Waiting for peer."
);
Ok(None) }
}
fn process_token(&mut self, token: &[u8]) -> Result<ProcessTokenAction, ZmqError> {
if self.current_status == MechanismStatus::Error
|| self.current_status == MechanismStatus::Ready
{
return Err(ZmqError::InvalidState(
"NOISE_XX: Processing token in Error/Ready state".into(),
));
}
if self.current_status == MechanismStatus::Initializing {
self.current_status = MechanismStatus::Handshaking;
}
tracing::debug!(
"NOISE_XX process_token (is_server={}): Received handshake message (len {})",
self.is_server_role,
token.len()
);
let handshake_state = self.state.as_mut().ok_or_else(|| {
ZmqError::InvalidState("Noise HandshakeState missing before process_token".into())
})?;
let mut read_payload_buf = vec![0u8; 1024];
let _payload_len = handshake_state.read_message(token, &mut read_payload_buf)?;
if handshake_state.is_handshake_finished() {
tracing::debug!(
"NOISE_XX process_token (is_server={}): snow::HandshakeState finished after read_message. Transitioning state.",
self.is_server_role
);
self.transition_to_final_state()?; Ok(ProcessTokenAction::HandshakeComplete)
} else if handshake_state.is_my_turn() {
let mut msg_buf = vec![0u8; 1024];
let len = handshake_state.write_message(&[], &mut msg_buf)?;
msg_buf.truncate(len);
tracing::debug!(
"NOISE_XX process_token (is_server={}): Generated next handshake message (len {}) for pending send.",
self.is_server_role,
len
);
self.pending_outgoing_handshake_msg = Some(msg_buf);
Ok(ProcessTokenAction::ProduceAndSend)
} else {
tracing::trace!(
"NOISE_XX process_token (is_server={}): Processed token, awaiting peer's next. No pending message generated.",
self.is_server_role
);
self.pending_outgoing_handshake_msg = None;
Ok(ProcessTokenAction::ContinueWaiting)
}
}
fn status(&self) -> MechanismStatus {
self.current_status
}
fn peer_identity(&self) -> Option<Vec<u8>> {
self.verified_peer_static_pk.clone()
}
fn metadata(&self) -> Option<Metadata> {
None
}
fn set_error(&mut self, reason: String) {
tracing::error!("NOISE_XX Mechanism error set: {}", reason);
self.current_status = MechanismStatus::Error;
self.error_reason_str = Some(reason);
self.state = None; self.transport_state = None; self.pending_outgoing_handshake_msg = None;
}
fn error_reason(&self) -> Option<&str> {
self.error_reason_str.as_deref()
}
fn zap_request_needed(&mut self) -> Option<Vec<Vec<u8>>> {
None
}
fn process_zap_reply(&mut self, _reply_frames: &[Vec<u8>]) -> Result<(), ZmqError> {
Ok(())
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn into_framer(
mut self: Box<Self>,
max_msg_size: i64,
) -> Result<(Box<dyn ISecureFramer>, Option<Vec<u8>>), ZmqError> {
if self.current_status != MechanismStatus::Ready {
return Err(ZmqError::InvalidState(
"Noise handshake not complete, cannot create framer.".into(),
));
}
let (cipher, peer_id) = self.into_data_cipher()?;
let framer = Box::new(LengthPrefixedFramer::new(cipher, max_msg_size));
Ok((framer, peer_id))
}
}
/// Transport-phase cipher built from the `snow` state of a finished handshake.
#[derive(Debug)]
struct NoiseDataCipher {
    /// `snow` transport state used for every encrypt/decrypt call.
    transport_state: TransportState,
}
impl IDataCipher for NoiseDataCipher {
    /// Encrypts one transport message. The output is the ciphertext plus a 16-byte tag.
    ///
    /// # Errors
    /// * `InvalidMessage` if the ciphertext would exceed `u16::MAX` bytes.
    /// * Any `snow` write failure.
    fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, ZmqError> {
        const TAG_LEN: usize = 16;
        let max_plaintext = usize::from(u16::MAX) - TAG_LEN;
        if plaintext.len() > max_plaintext {
            return Err(ZmqError::InvalidMessage(
                "Payload too large for Noise frame.".into(),
            ));
        }
        let mut sealed = vec![0u8; plaintext.len() + TAG_LEN];
        let written = self.transport_state.write_message(plaintext, &mut sealed)?;
        sealed.truncate(written);
        Ok(sealed)
    }

    /// Decrypts and authenticates one transport message.
    ///
    /// # Errors
    /// * `ProtocolViolation` if the input is shorter than the 16-byte tag.
    /// * Any `snow` read/authentication failure.
    fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, ZmqError> {
        const TAG_LEN: usize = 16;
        let plaintext_capacity = match ciphertext.len().checked_sub(TAG_LEN) {
            Some(capacity) => capacity,
            None => {
                return Err(ZmqError::ProtocolViolation(
                    "Noise message too short for tag.".into(),
                ))
            }
        };
        let mut opened = vec![0u8; plaintext_capacity];
        let read = self.transport_state.read_message(ciphertext, &mut opened)?;
        opened.truncate(read);
        Ok(opened)
    }
}
impl From<SnowError> for ZmqError {
fn from(e: SnowError) -> Self {
tracing::warn!("Snow protocol error occurred: {}", e); match e {
SnowError::Pattern(problem) => {
ZmqError::SecurityError(format!("Noise pattern configuration error: {:?}", problem))
}
SnowError::Init(stage) => {
ZmqError::SecurityError(format!("Noise initialization error at stage: {:?}", stage))
}
SnowError::Prereq(prereq) => match prereq {
Prerequisite::LocalPrivateKey => ZmqError::SecurityError(
"Noise prerequisite error: Local private key missing or invalid.".into(),
),
Prerequisite::RemotePublicKey => ZmqError::SecurityError(
"Noise prerequisite error: Remote public key missing or invalid for current operation."
.into(),
),
},
SnowError::State(problem) => match problem {
StateProblem::MissingKeyMaterial => ZmqError::SecurityError(
"Noise state error: Missing required key material for operation.".into(),
),
StateProblem::HandshakeNotFinished => ZmqError::InvalidState(
"Noise state error: Handshake is not yet finished for requested operation.".into(),
),
StateProblem::HandshakeAlreadyFinished => {
ZmqError::InvalidState("Noise state error: Handshake is already finished.".into())
}
_ => ZmqError::SecurityError(format!("Noise state machine error: {:?}", problem)),
},
SnowError::Input => {
ZmqError::SecurityError(
"Noise input error: Invalid message format or size for current state.".into(),
)
}
SnowError::Dh => ZmqError::SecurityError("Noise Diffie-Hellman operation failed.".into()),
SnowError::Decrypt => {
ZmqError::SecurityError(
"Noise decrypt/authentication failed (e.g., bad MAC or ciphertext).".into(),
)
}
#[cfg(feature = "hfs")] SnowError::Kem => {
ZmqError::SecurityError("Noise Key Encapsulation Mechanism (KEM) failed.".into())
}
_ => {
tracing::error!("Unhandled snow::Error variant: {}", e);
ZmqError::SecurityError(format!("Unhandled or new Noise protocol error: {}", e))
}
}
}
}