use bytes::{Buf, BufMut, Bytes, BytesMut};
use sha2::{Digest, Sha256};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
use crate::errors::{ProtocolError, Result};
use crate::protocol::MessageType;
/// Value carried in the `HeaderLength` field: the number of header bytes that
/// precede the 4-byte `PayloadLength` field (byte offsets 0..116).
const HEADER_LENGTH: u32 = 116;
/// Width of the space-padded `MessageType` name field.
const MESSAGE_TYPE_LENGTH: usize = 32;
/// Width of the `MessageId` (UUID) field.
const MESSAGE_ID_LENGTH: usize = 16;
/// Width of the SHA-256 `PayloadDigest` field.
const PAYLOAD_DIGEST_LENGTH: usize = 32;
/// Width of the `PayloadLength` field that follows the `HeaderLength` bytes.
const PAYLOAD_LENGTH_FIELD_SIZE: usize = 4;
/// Bytes before the payload starts: the `HeaderLength` bytes plus the
/// `PayloadLength` field (116 + 4 = 120). Derived so the two cannot drift.
const TOTAL_HEADER_SIZE: usize = HEADER_LENGTH as usize + PAYLOAD_LENGTH_FIELD_SIZE;
/// Largest payload, in bytes, that `ClientMessage::deserialize` accepts (10 MiB).
pub const MAX_PAYLOAD_SIZE: u32 = 10 * 1024 * 1024;

// Byte offset of each header field within a serialized message. The
// (de)serializer walks the fields sequentially and does not read these, so
// they serve as layout documentation (hence `allow(dead_code)`); the const
// assertions below keep them consistent with the field widths above.
#[allow(dead_code)]
const HL_OFFSET: usize = 0;
#[allow(dead_code)]
const MESSAGE_TYPE_OFFSET: usize = 4;
#[allow(dead_code)]
const SCHEMA_VERSION_OFFSET: usize = 36;
#[allow(dead_code)]
const CREATED_DATE_OFFSET: usize = 40;
#[allow(dead_code)]
const SEQUENCE_NUMBER_OFFSET: usize = 48;
#[allow(dead_code)]
const FLAGS_OFFSET: usize = 56;
#[allow(dead_code)]
const MESSAGE_ID_OFFSET: usize = 64;
#[allow(dead_code)]
const PAYLOAD_DIGEST_OFFSET: usize = 80;
#[allow(dead_code)]
const PAYLOAD_TYPE_OFFSET: usize = 112;
#[allow(dead_code)]
const PAYLOAD_LENGTH_OFFSET: usize = 116;
#[allow(dead_code)]
const PAYLOAD_OFFSET: usize = 120;

// Compile-time check that each offset equals the previous offset plus that
// field's width, and that the layout agrees with HEADER_LENGTH and
// TOTAL_HEADER_SIZE.
const _: () = {
    assert!(MESSAGE_TYPE_OFFSET == HL_OFFSET + 4);
    assert!(SCHEMA_VERSION_OFFSET == MESSAGE_TYPE_OFFSET + MESSAGE_TYPE_LENGTH);
    assert!(CREATED_DATE_OFFSET == SCHEMA_VERSION_OFFSET + 4);
    assert!(SEQUENCE_NUMBER_OFFSET == CREATED_DATE_OFFSET + 8);
    assert!(FLAGS_OFFSET == SEQUENCE_NUMBER_OFFSET + 8);
    assert!(MESSAGE_ID_OFFSET == FLAGS_OFFSET + 8);
    assert!(PAYLOAD_DIGEST_OFFSET == MESSAGE_ID_OFFSET + MESSAGE_ID_LENGTH);
    assert!(PAYLOAD_TYPE_OFFSET == PAYLOAD_DIGEST_OFFSET + PAYLOAD_DIGEST_LENGTH);
    assert!(PAYLOAD_LENGTH_OFFSET == PAYLOAD_TYPE_OFFSET + 4);
    assert!(PAYLOAD_LENGTH_OFFSET == HEADER_LENGTH as usize);
    assert!(PAYLOAD_OFFSET == PAYLOAD_LENGTH_OFFSET + PAYLOAD_LENGTH_FIELD_SIZE);
    assert!(PAYLOAD_OFFSET == TOTAL_HEADER_SIZE);
};
/// Discriminator describing what a message payload contains.
///
/// Each variant's explicit `u32` discriminant is the value written to (and
/// read from) the 4-byte `PayloadType` header field, so these numbers are part
/// of the wire format and must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum PayloadType {
    Undefined = 0,
    Output = 1,
    Error = 2,
    Size = 3,
    Parameter = 4,
    HandshakeRequest = 5,
    HandshakeResponse = 6,
    HandshakeComplete = 7,
    EncChallengeRequest = 8,
    EncChallengeResponse = 9,
    Flag = 10,
    StdErr = 11,
    ExitCode = 12,
}
impl PayloadType {
pub fn from_u32(value: u32) -> Result<Self> {
match value {
0 => Ok(PayloadType::Undefined),
1 => Ok(PayloadType::Output),
2 => Ok(PayloadType::Error),
3 => Ok(PayloadType::Size),
4 => Ok(PayloadType::Parameter),
5 => Ok(PayloadType::HandshakeRequest),
6 => Ok(PayloadType::HandshakeResponse),
7 => Ok(PayloadType::HandshakeComplete),
8 => Ok(PayloadType::EncChallengeRequest),
9 => Ok(PayloadType::EncChallengeResponse),
10 => Ok(PayloadType::Flag),
11 => Ok(PayloadType::StdErr),
12 => Ok(PayloadType::ExitCode),
_ => {
Err(ProtocolError::InvalidMessage(format!("Invalid PayloadType: {}", value)).into())
}
}
}
pub fn to_u32(self) -> u32 {
self as u32
}
}
/// Control codes with fixed `u32` discriminants (1 through 3).
///
/// NOTE(review): presumably these are the values carried in a payload of type
/// `PayloadType::Flag`; confirm against the protocol specification.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum PayloadTypeFlag {
    DisconnectToPort = 1,
    TerminateSession = 2,
    ConnectToPortError = 3,
}
/// Single-bit masks intended for the 64-bit `flags` header field.
pub mod flags {
    /// Bit 0.
    pub const SYN: u64 = 0b01;
    /// Bit 1.
    pub const FIN: u64 = 0b10;
}
/// A binary client message: a fixed 120-byte header followed by a payload.
///
/// Fields are listed in the order they appear on the wire.
#[derive(Debug, Clone)]
pub struct ClientMessage {
    /// `HeaderLength` field; `validate` accepts only `HEADER_LENGTH` (116).
    pub header_length: u32,
    /// Message kind, encoded on the wire as a space-padded 32-byte name.
    pub message_type: MessageType,
    /// Schema version; `new` sets it to 1.
    pub schema_version: u32,
    /// Creation time in milliseconds since the Unix epoch (as set by `new`).
    pub created_date: u64,
    /// Sequence number; `validate` rejects values within 1000 of the `i64` extremes.
    pub sequence_number: i64,
    /// Flag bits; see the `flags` module for the named masks.
    pub flags: u64,
    /// Message identifier; its two 8-byte halves are swapped on the wire.
    pub message_id: Uuid,
    /// SHA-256 digest of `payload`.
    pub payload_digest: [u8; 32],
    /// Kind of data carried in `payload`.
    pub payload_type: PayloadType,
    /// Declared payload size in bytes; `validate` requires it to equal `payload.len()`.
    pub payload_length: u32,
    /// Message body.
    pub payload: Bytes,
}
impl ClientMessage {
    /// Builds a message carrying `payload`, deriving the remaining header fields.
    ///
    /// The derived fields are set as follows:
    /// - `header_length` is `HEADER_LENGTH`.
    /// - `schema_version` is 1 and `flags` is 0.
    /// - `created_date` is the current time in milliseconds since the Unix
    ///   epoch, or 0 if the system clock reads earlier than the epoch.
    /// - `message_id` is a fresh v4 UUID.
    /// - `payload_length` and `payload_digest` are computed from `payload`.
    pub fn new(
        message_type: MessageType,
        sequence_number: i64,
        payload_type: PayloadType,
        payload: Bytes,
    ) -> Self {
        // The wire length field is a u32. `new` cannot fail, so an oversized
        // payload saturates instead of wrapping to a small, plausible-looking
        // length; `validate` and `serialize` then report the mismatch.
        let payload_length = payload.len().min(u32::MAX as usize) as u32;
        let payload_digest = compute_digest(&payload);
        let created_date = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64;
        Self {
            header_length: HEADER_LENGTH,
            message_type,
            schema_version: 1,
            created_date,
            sequence_number,
            flags: 0,
            message_id: Uuid::new_v4(),
            payload_digest,
            payload_type,
            payload_length,
            payload,
        }
    }

    /// Encodes the message as the 120-byte header (big-endian integers)
    /// followed by the payload bytes.
    ///
    /// # Errors
    /// Returns an error built from `ProtocolError::InvalidMessage` when
    /// `payload_length` does not equal `payload.len()`. That can happen after
    /// the public fields are edited, and the receiver would then mis-frame
    /// the payload.
    pub fn serialize(&self) -> Result<Bytes> {
        self.check_payload_length()?;
        let mut buf = BytesMut::with_capacity(TOTAL_HEADER_SIZE + self.payload.len());
        buf.put_u32(self.header_length);
        buf.put_slice(&message_type_to_padded(&self.message_type));
        buf.put_u32(self.schema_version);
        buf.put_u64(self.created_date);
        buf.put_i64(self.sequence_number);
        buf.put_u64(self.flags);
        buf.put_slice(&Self::swap_uuid_halves(*self.message_id.as_bytes()));
        buf.put_slice(&self.payload_digest);
        buf.put_u32(self.payload_type.to_u32());
        buf.put_u32(self.payload_length);
        buf.put_slice(&self.payload);
        Ok(buf.freeze())
    }

    /// Decodes one message from its wire form and then runs [`validate`](Self::validate).
    ///
    /// Any bytes after the declared payload are ignored.
    ///
    /// # Errors
    /// Returns an error built from `ProtocolError::InvalidMessage` in any of
    /// these cases:
    /// - the input is shorter than the header, or `HeaderLength` is not 116;
    /// - the message type or payload type is unknown;
    /// - the declared payload exceeds `MAX_PAYLOAD_SIZE` or is longer than
    ///   the remaining input;
    /// - validation fails.
    pub fn deserialize(mut data: Bytes) -> Result<Self> {
        // The length check up front guarantees every fixed-size header read
        // below has enough bytes, so none of the `get_*` calls can panic.
        if data.len() < TOTAL_HEADER_SIZE {
            return Err(ProtocolError::InvalidMessage(format!(
                "Message too short: {} bytes (need at least {})",
                data.len(),
                TOTAL_HEADER_SIZE
            ))
            .into());
        }
        let header_length = data.get_u32();
        if header_length != HEADER_LENGTH {
            return Err(ProtocolError::InvalidMessage(format!(
                "Invalid header length: {} (expected {})",
                header_length, HEADER_LENGTH
            ))
            .into());
        }
        let mut msg_type_bytes = [0u8; MESSAGE_TYPE_LENGTH];
        data.copy_to_slice(&mut msg_type_bytes);
        let message_type = message_type_from_padded(&msg_type_bytes)?;
        let schema_version = data.get_u32();
        let created_date = data.get_u64();
        let sequence_number = data.get_i64();
        let flags = data.get_u64();
        let mut wire_uuid = [0u8; MESSAGE_ID_LENGTH];
        data.copy_to_slice(&mut wire_uuid);
        let message_id = Uuid::from_bytes(Self::swap_uuid_halves(wire_uuid));
        let mut payload_digest = [0u8; PAYLOAD_DIGEST_LENGTH];
        data.copy_to_slice(&mut payload_digest);
        let payload_type = PayloadType::from_u32(data.get_u32())?;
        let payload_length = data.get_u32();
        // Reject oversized declarations before touching the payload bytes.
        if payload_length > MAX_PAYLOAD_SIZE {
            return Err(ProtocolError::InvalidMessage(format!(
                "Payload too large: {} bytes (max {})",
                payload_length, MAX_PAYLOAD_SIZE
            ))
            .into());
        }
        if data.remaining() < payload_length as usize {
            return Err(ProtocolError::InvalidMessage(format!(
                "Payload length mismatch: declared {}, have {}",
                payload_length,
                data.remaining()
            ))
            .into());
        }
        let payload = data.copy_to_bytes(payload_length as usize);
        let msg = Self {
            header_length,
            message_type,
            schema_version,
            created_date,
            sequence_number,
            flags,
            message_id,
            payload_digest,
            payload_type,
            payload_length,
            payload,
        };
        msg.validate()?;
        Ok(msg)
    }

    /// Checks the header invariants.
    ///
    /// The checks are:
    /// - `header_length` equals `HEADER_LENGTH`;
    /// - `sequence_number` lies within 1000 of neither `i64` extreme;
    /// - `payload_length` equals `payload.len()`;
    /// - for a non-empty payload, `payload_digest` is its SHA-256.
    ///
    /// # Errors
    /// Returns an error built from `ProtocolError::InvalidMessage` describing
    /// the first failed check.
    pub fn validate(&self) -> Result<()> {
        if self.header_length != HEADER_LENGTH {
            return Err(
                ProtocolError::InvalidMessage("HeaderLength must be 116".to_string()).into(),
            );
        }
        // NOTE(review): the 1000-value margin presumably leaves headroom for
        // sequence arithmetic elsewhere; confirm the intent with callers.
        const MAX_SAFE_SEQUENCE: i64 = i64::MAX - 1000;
        const MIN_SAFE_SEQUENCE: i64 = i64::MIN + 1000;
        if self.sequence_number > MAX_SAFE_SEQUENCE || self.sequence_number < MIN_SAFE_SEQUENCE {
            return Err(ProtocolError::InvalidMessage(format!(
                "Sequence number {} outside safe bounds [{}, {}]",
                self.sequence_number, MIN_SAFE_SEQUENCE, MAX_SAFE_SEQUENCE
            ))
            .into());
        }
        self.check_payload_length()?;
        if self.payload_length > 0 {
            let computed_digest = compute_digest(&self.payload);
            if computed_digest != self.payload_digest {
                return Err(ProtocolError::InvalidMessage(
                    "Payload digest validation failed (SHA-256 mismatch)".to_string(),
                )
                .into());
            }
        }
        Ok(())
    }

    /// Errors unless `payload_length` exactly describes `payload`.
    fn check_payload_length(&self) -> Result<()> {
        // Widen the u32 instead of narrowing `payload.len()`: an `as u32` cast
        // would wrap a >4 GiB payload and could make a wrong length compare equal.
        if self.payload.len() == self.payload_length as usize {
            return Ok(());
        }
        Err(ProtocolError::InvalidMessage(format!(
            "PayloadLength mismatch: declared {}, actual {}",
            self.payload_length,
            self.payload.len()
        ))
        .into())
    }

    /// Swaps the two 8-byte halves of a UUID. The wire carries the low half
    /// first; the swap is its own inverse, so it serves both directions.
    fn swap_uuid_halves(bytes: [u8; MESSAGE_ID_LENGTH]) -> [u8; MESSAGE_ID_LENGTH] {
        const HALF: usize = MESSAGE_ID_LENGTH / 2;
        let mut swapped = [0u8; MESSAGE_ID_LENGTH];
        swapped[..HALF].copy_from_slice(&bytes[HALF..]);
        swapped[HALF..].copy_from_slice(&bytes[..HALF]);
        swapped
    }
}
/// Returns the SHA-256 digest of `payload` as a 32-byte array.
fn compute_digest(payload: &[u8]) -> [u8; 32] {
    // One-shot hashing: equivalent to new() + update() + finalize().
    Sha256::digest(payload).into()
}
/// Encodes a `MessageType` as its fixed-width wire name.
///
/// The ASCII name is left-aligned and right-padded with spaces to
/// `MESSAGE_TYPE_LENGTH` bytes. A name longer than the field would be cut off.
fn message_type_to_padded(msg_type: &MessageType) -> [u8; MESSAGE_TYPE_LENGTH] {
    let name: &str = match msg_type {
        MessageType::InputStreamData => "input_stream_data",
        MessageType::OutputStreamData => "output_stream_data",
        MessageType::Acknowledge => "acknowledge",
        MessageType::ChannelClosed => "channel_closed",
        MessageType::StartPublication => "start_publication",
        MessageType::PausePublication => "pause_publication",
    };
    let mut field = [b' '; MESSAGE_TYPE_LENGTH];
    // `zip` stops at the shorter side, so at most MESSAGE_TYPE_LENGTH bytes are copied.
    for (slot, byte) in field.iter_mut().zip(name.bytes()) {
        *slot = byte;
    }
    field
}
/// Parses a fixed-width `MessageType` name field back into a `MessageType`.
///
/// Padding is stripped from both ends before matching. The padding may be
/// whitespace, NUL bytes, or any mix of the two; `message_type_to_padded`
/// pads with spaces.
///
/// # Errors
/// Returns an error built from `ProtocolError::InvalidMessage` when the field
/// is not valid UTF-8 or the trimmed name is not a known message type.
fn message_type_from_padded(bytes: &[u8; MESSAGE_TYPE_LENGTH]) -> Result<MessageType> {
    let raw = std::str::from_utf8(bytes)
        .map_err(|e| ProtocolError::InvalidMessage(format!("Invalid UTF-8 in MessageType: {}", e)))?;
    // Strip NULs and whitespace in a single pass. Trimming them separately
    // (whitespace first, then trailing NULs) leaves "acknowledge  \0\0" as
    // "acknowledge  ", which then fails to match; leading NULs also survived.
    let s = raw.trim_matches(|c: char| c == '\0' || c.is_whitespace());
    match s {
        "input_stream_data" => Ok(MessageType::InputStreamData),
        "output_stream_data" => Ok(MessageType::OutputStreamData),
        "acknowledge" => Ok(MessageType::Acknowledge),
        "channel_closed" => Ok(MessageType::ChannelClosed),
        "start_publication" => Ok(MessageType::StartPublication),
        "pause_publication" => Ok(MessageType::PausePublication),
        _ => Err(ProtocolError::InvalidMessage(format!("Unknown MessageType: {}", s)).into()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small valid message used by the wire-format tests.
    fn sample_message() -> ClientMessage {
        ClientMessage::new(
            MessageType::InputStreamData,
            1,
            PayloadType::Output,
            Bytes::from_static(b"payload"),
        )
    }

    #[test]
    fn test_payload_type_conversion() {
        assert_eq!(PayloadType::from_u32(1).unwrap(), PayloadType::Output);
        assert_eq!(PayloadType::from_u32(12).unwrap(), PayloadType::ExitCode);
        assert!(PayloadType::from_u32(99).is_err());
        assert_eq!(PayloadType::Output.to_u32(), 1);
        assert_eq!(PayloadType::ExitCode.to_u32(), 12);
    }

    #[test]
    fn test_payload_type_roundtrip_all_known_values() {
        for value in 0..=12u32 {
            let payload_type = PayloadType::from_u32(value).unwrap();
            assert_eq!(payload_type.to_u32(), value);
        }
        assert!(PayloadType::from_u32(13).is_err());
        assert!(PayloadType::from_u32(u32::MAX).is_err());
    }

    #[test]
    fn test_message_type_padding() {
        let msg_type = MessageType::InputStreamData;
        let padded = message_type_to_padded(&msg_type);
        assert_eq!(padded.len(), MESSAGE_TYPE_LENGTH);
        assert_eq!(&padded[..17], b"input_stream_data");
        assert_eq!(&padded[17..], &[b' '; 15]);
        let parsed = message_type_from_padded(&padded).unwrap();
        assert_eq!(parsed, msg_type);
    }

    #[test]
    fn test_digest_computation() {
        let payload = b"test payload";
        let digest = compute_digest(payload);
        let digest2 = compute_digest(payload);
        assert_eq!(digest, digest2);
        let digest3 = compute_digest(b"different");
        assert_ne!(digest, digest3);
    }

    #[test]
    fn test_message_serialization_roundtrip() {
        let payload = Bytes::from_static(b"Hello, SSM!");
        let msg = ClientMessage::new(
            MessageType::OutputStreamData,
            42,
            PayloadType::Output,
            payload.clone(),
        );
        let serialized = msg.serialize().unwrap();
        assert_eq!(serialized.len(), TOTAL_HEADER_SIZE + payload.len());
        let deserialized = ClientMessage::deserialize(serialized).unwrap();
        assert_eq!(deserialized.message_type, msg.message_type);
        assert_eq!(deserialized.sequence_number, msg.sequence_number);
        assert_eq!(deserialized.payload_type, msg.payload_type);
        assert_eq!(deserialized.payload, msg.payload);
        assert_eq!(deserialized.payload_digest, msg.payload_digest);
    }

    #[test]
    fn test_roundtrip_preserves_all_header_fields() {
        let mut msg = ClientMessage::new(
            MessageType::Acknowledge,
            7,
            PayloadType::Undefined,
            Bytes::from_static(b"ack"),
        );
        msg.flags = flags::SYN | flags::FIN;
        let decoded = ClientMessage::deserialize(msg.serialize().unwrap()).unwrap();
        assert_eq!(decoded.header_length, msg.header_length);
        assert_eq!(decoded.schema_version, msg.schema_version);
        assert_eq!(decoded.created_date, msg.created_date);
        assert_eq!(decoded.flags, msg.flags);
        assert_eq!(decoded.message_id, msg.message_id);
        assert_eq!(decoded.payload_length, msg.payload_length);
    }

    #[test]
    fn test_message_id_wire_layout() {
        let msg = sample_message();
        let wire = msg.serialize().unwrap();
        let id = msg.message_id.as_bytes();
        // The wire carries the UUID's low 8 bytes first, then the high 8 bytes.
        assert_eq!(&wire[MESSAGE_ID_OFFSET..MESSAGE_ID_OFFSET + 8], &id[8..16]);
        assert_eq!(&wire[MESSAGE_ID_OFFSET + 8..MESSAGE_ID_OFFSET + 16], &id[0..8]);
    }

    #[test]
    fn test_deserialize_rejects_malformed_input() {
        // Shorter than the fixed header.
        let short = Bytes::from(vec![0u8; TOTAL_HEADER_SIZE - 1]);
        assert!(ClientMessage::deserialize(short).is_err());

        let wire = sample_message().serialize().unwrap();

        // Wrong HeaderLength value.
        let mut bad_header = wire.to_vec();
        bad_header[HL_OFFSET..HL_OFFSET + 4].copy_from_slice(&100u32.to_be_bytes());
        assert!(ClientMessage::deserialize(Bytes::from(bad_header)).is_err());

        // Unknown PayloadType.
        let mut bad_type = wire.to_vec();
        bad_type[PAYLOAD_TYPE_OFFSET..PAYLOAD_TYPE_OFFSET + 4]
            .copy_from_slice(&99u32.to_be_bytes());
        assert!(ClientMessage::deserialize(Bytes::from(bad_type)).is_err());

        // Declared payload larger than MAX_PAYLOAD_SIZE.
        let mut oversized = wire.to_vec();
        oversized[PAYLOAD_LENGTH_OFFSET..PAYLOAD_LENGTH_OFFSET + 4]
            .copy_from_slice(&(MAX_PAYLOAD_SIZE + 1).to_be_bytes());
        assert!(ClientMessage::deserialize(Bytes::from(oversized)).is_err());

        // Payload shorter than PayloadLength declares.
        let truncated = wire.slice(..wire.len() - 1);
        assert!(ClientMessage::deserialize(truncated).is_err());

        // Corrupted payload byte fails the digest check.
        let mut corrupted = wire.to_vec();
        *corrupted.last_mut().unwrap() ^= 0xFF;
        assert!(ClientMessage::deserialize(Bytes::from(corrupted)).is_err());
    }

    #[test]
    fn test_message_validation() {
        let payload = Bytes::from_static(b"test");
        let mut msg = ClientMessage::new(
            MessageType::InputStreamData,
            1,
            PayloadType::Output,
            payload,
        );
        assert!(msg.validate().is_ok());
        msg.header_length = 100;
        assert!(msg.validate().is_err());
        msg.header_length = HEADER_LENGTH;
        msg.payload_digest = [0u8; 32];
        assert!(msg.validate().is_err());
        // Restore the digest so only the length mismatch is under test.
        msg.payload_digest = compute_digest(&msg.payload);
        assert!(msg.validate().is_ok());
        msg.payload_length += 1;
        assert!(msg.validate().is_err());
    }

    #[test]
    fn test_validation_sequence_number_bounds() {
        for seq in [i64::MIN, i64::MIN + 999, i64::MAX - 999, i64::MAX] {
            let msg = ClientMessage::new(
                MessageType::InputStreamData,
                seq,
                PayloadType::Output,
                Bytes::new(),
            );
            assert!(msg.validate().is_err(), "sequence {} should be rejected", seq);
        }
        for seq in [i64::MIN + 1000, -1, 0, i64::MAX - 1000] {
            let msg = ClientMessage::new(
                MessageType::InputStreamData,
                seq,
                PayloadType::Output,
                Bytes::new(),
            );
            assert!(msg.validate().is_ok(), "sequence {} should be accepted", seq);
        }
    }

    #[test]
    fn test_flags() {
        assert_eq!(flags::SYN, 1);
        assert_eq!(flags::FIN, 2);
        assert_eq!(flags::SYN | flags::FIN, 3);
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    // Inclusive sequence-number bounds that `ClientMessage::validate` accepts
    // (keep in sync with the constants inside that method).
    const MIN_VALID_SEQ: i64 = i64::MIN + 1000;
    const MAX_VALID_SEQ: i64 = i64::MAX - 1000;

    fn payload_type_strategy() -> impl Strategy<Value = PayloadType> {
        prop_oneof![
            Just(PayloadType::Undefined),
            Just(PayloadType::Output),
            Just(PayloadType::Error),
            Just(PayloadType::Size),
            Just(PayloadType::Parameter),
            Just(PayloadType::HandshakeRequest),
            Just(PayloadType::HandshakeResponse),
            Just(PayloadType::HandshakeComplete),
            Just(PayloadType::EncChallengeRequest),
            Just(PayloadType::EncChallengeResponse),
            Just(PayloadType::Flag),
            Just(PayloadType::StdErr),
            Just(PayloadType::ExitCode),
        ]
    }

    fn message_type_strategy() -> impl Strategy<Value = MessageType> {
        prop_oneof![
            Just(MessageType::InputStreamData),
            Just(MessageType::OutputStreamData),
            Just(MessageType::Acknowledge),
            Just(MessageType::ChannelClosed),
            Just(MessageType::StartPublication),
            Just(MessageType::PausePublication),
        ]
    }

    /// Sequence numbers that `validate` accepts. Drawing from `any::<i64>()`
    /// would make the "valid message" properties false near the extremes.
    fn valid_sequence_strategy() -> impl Strategy<Value = i64> {
        MIN_VALID_SEQ..=MAX_VALID_SEQ
    }

    /// Sequence numbers that `validate` rejects (within 1000 of either extreme).
    fn out_of_range_sequence_strategy() -> impl Strategy<Value = i64> {
        prop_oneof![i64::MIN..MIN_VALID_SEQ, (MAX_VALID_SEQ + 1)..=i64::MAX]
    }

    proptest! {
        #[test]
        fn roundtrip_preserves_data(
            seq_num in valid_sequence_strategy(),
            flag_bits in any::<u64>(),
            payload in prop::collection::vec(any::<u8>(), 0..4096),
            payload_type in payload_type_strategy(),
            message_type in message_type_strategy(),
        ) {
            let mut msg = ClientMessage::new(
                message_type,
                seq_num,
                payload_type,
                Bytes::from(payload.clone()),
            );
            msg.flags = flag_bits;
            let serialized = msg.serialize().expect("serialization should succeed");
            let deserialized = ClientMessage::deserialize(serialized)
                .expect("deserialization should succeed");
            prop_assert_eq!(deserialized.message_type, msg.message_type);
            prop_assert_eq!(deserialized.sequence_number, msg.sequence_number);
            prop_assert_eq!(deserialized.flags, msg.flags);
            prop_assert_eq!(deserialized.created_date, msg.created_date);
            prop_assert_eq!(deserialized.message_id, msg.message_id);
            prop_assert_eq!(deserialized.payload_type, msg.payload_type);
            prop_assert_eq!(deserialized.payload.as_ref(), payload.as_slice());
        }

        #[test]
        fn deserialize_never_panics(data in prop::collection::vec(any::<u8>(), 0..8192)) {
            let _ = ClientMessage::deserialize(Bytes::from(data));
        }

        #[test]
        fn serialized_length_correct(
            payload in prop::collection::vec(any::<u8>(), 0..4096),
        ) {
            let payload_len = payload.len();
            let msg = ClientMessage::new(
                MessageType::OutputStreamData,
                0,
                PayloadType::Output,
                Bytes::from(payload),
            );
            let serialized = msg.serialize().expect("serialization should succeed");
            prop_assert_eq!(serialized.len(), TOTAL_HEADER_SIZE + payload_len);
        }

        #[test]
        fn digest_matches_payload(
            payload in prop::collection::vec(any::<u8>(), 0..4096),
        ) {
            let expected_digest = compute_digest(&payload);
            let msg = ClientMessage::new(
                MessageType::OutputStreamData,
                0,
                PayloadType::Output,
                Bytes::from(payload),
            );
            prop_assert_eq!(msg.payload_digest, expected_digest);
        }

        #[test]
        fn validation_passes_for_valid_messages(
            seq_num in valid_sequence_strategy(),
            payload in prop::collection::vec(any::<u8>(), 0..1024),
            payload_type in payload_type_strategy(),
            message_type in message_type_strategy(),
        ) {
            let msg = ClientMessage::new(
                message_type,
                seq_num,
                payload_type,
                Bytes::from(payload),
            );
            prop_assert!(msg.validate().is_ok(), "Valid message should pass validation");
        }

        #[test]
        fn validation_rejects_out_of_range_sequence(
            seq_num in out_of_range_sequence_strategy(),
        ) {
            let msg = ClientMessage::new(
                MessageType::OutputStreamData,
                seq_num,
                PayloadType::Output,
                Bytes::new(),
            );
            prop_assert!(msg.validate().is_err(), "sequence {} should be rejected", seq_num);
        }
    }
}