use super::frame::{Flags, Frame, MessageKind, FRAME_HEADER_SIZE, MAX_FRAME_SIZE};
/// Reasons [`FrameBuilder::build`] can refuse to produce a [`Frame`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BuildError {
    /// `kind()` was never called on the builder.
    KindMissing,
    /// Header plus payload would exceed the maximum frame size.
    PayloadTooLarge {
        /// `FRAME_HEADER_SIZE` plus the payload length, in bytes.
        encoded_len: usize,
        /// The limit that was exceeded (`MAX_FRAME_SIZE`).
        max: u32,
    },
    /// The message kind does not accept the frame's flag bits.
    FlagsNotAllowedForKind {
        /// Kind the frame was being built for.
        kind: MessageKind,
        /// Raw flag bits (as returned by `Flags::bits`) that were rejected.
        flags: u8,
    },
}

impl std::fmt::Display for BuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BuildError::KindMissing => {
                f.write_str("FrameBuilder: kind() must be called before build()")
            }
            BuildError::PayloadTooLarge { encoded_len, max } => {
                write!(
                    f,
                    "FrameBuilder: encoded frame size {encoded_len} exceeds MAX_FRAME_SIZE ({max})"
                )
            }
            BuildError::FlagsNotAllowedForKind { kind, flags } => {
                write!(
                    f,
                    "FrameBuilder: flag bits 0x{flags:02x} not allowed on kind {kind:?}"
                )
            }
        }
    }
}

impl std::error::Error for BuildError {}
/// Whether the caller asked for the payload to be compressed.
///
/// Used instead of a bare `bool` so the builder's field is self-describing.
/// `Yes` is only a request: the `COMPRESSED` flag is set only when the payload
/// also passes `is_payload_compressible`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Compress {
    /// Compression not requested.
    No,
    /// Compression requested.
    Yes,
}
/// Fluent builder for [`Frame`]s.
///
/// Start with [`FrameBuilder::reply_to`] or [`FrameBuilder::unsolicited`], chain
/// the setters, and finish with [`FrameBuilder::build`]. Every setter consumes
/// and returns the builder, so dropping a returned builder silently discards
/// the configuration; `#[must_use]` turns that mistake into a warning.
#[must_use = "a FrameBuilder does nothing until `build()` is called"]
#[derive(Debug, Clone)]
pub struct FrameBuilder {
    /// Message kind; `None` until `kind()` is called, which `build()` rejects.
    kind: Option<MessageKind>,
    /// Correlation id copied into the frame; `0` for unsolicited frames.
    correlation_id: u64,
    /// Stream id copied into the frame; defaults to `0`.
    stream_id: u16,
    /// Frame payload, stored as given.
    payload: Vec<u8>,
    /// Caller-supplied flags. `build()` rewrites the `MORE_FRAMES` and
    /// `COMPRESSED` bits from `more_frames` and `compress`.
    flags: Flags,
    /// Whether compression was requested.
    compress: Compress,
    /// `true` when more frames follow this one; `false` marks the last frame.
    more_frames: bool,
}
impl FrameBuilder {
    /// Starts a builder for a frame that answers the request identified by
    /// `correlation_id`.
    pub fn reply_to(correlation_id: u64) -> Self {
        Self::with_correlation(correlation_id)
    }

    /// Starts a builder for a frame that answers no request; its correlation
    /// id is `0`.
    pub fn unsolicited() -> Self {
        Self::with_correlation(0)
    }

    /// Shared constructor. Everything except the correlation id starts at its
    /// default: no kind, stream `0`, empty payload, empty flags, compression
    /// not requested, and "last frame".
    fn with_correlation(correlation_id: u64) -> Self {
        Self {
            kind: None,
            correlation_id,
            stream_id: 0,
            payload: Vec::new(),
            flags: Flags::empty(),
            compress: Compress::No,
            more_frames: false,
        }
    }

    /// Sets the message kind. Required: [`build`](Self::build) fails with
    /// [`BuildError::KindMissing`] if this is never called.
    pub fn kind(mut self, kind: MessageKind) -> Self {
        self.kind = Some(kind);
        self
    }

    /// Sets the payload, replacing any previously set payload.
    pub fn payload(mut self, payload: Vec<u8>) -> Self {
        self.payload = payload;
        self
    }

    /// Sets the stream id (defaults to `0`).
    pub fn stream_id(mut self, stream_id: u16) -> Self {
        self.stream_id = stream_id;
        self
    }

    /// Sets the caller-supplied flags, replacing any previously set flags.
    ///
    /// In the built frame, the `MORE_FRAMES` and `COMPRESSED` bits are decided
    /// by [`more_frames`](Self::more_frames) and [`compress`](Self::compress).
    /// The bits passed here are still validated against the kind, so asking
    /// for a bit the kind forbids makes [`build`](Self::build) fail.
    pub fn flags(mut self, flags: Flags) -> Self {
        self.flags = flags;
        self
    }

    /// Marks whether more frames follow this one (`true`) or this is the last
    /// frame (`false`, the default).
    pub fn more_frames(mut self, more: bool) -> Self {
        self.more_frames = more;
        self
    }

    /// Requests compression.
    ///
    /// The `COMPRESSED` flag is set only when the payload passes
    /// `is_payload_compressible`; otherwise the request is dropped silently.
    /// NOTE(review): this builder only sets the flag and stores the payload
    /// bytes unchanged. Confirm that the encoder performs the compression.
    pub fn compress(mut self, yes: bool) -> Self {
        self.compress = if yes { Compress::Yes } else { Compress::No };
        self
    }

    /// Validates the configuration and produces a [`Frame`].
    ///
    /// # Errors
    ///
    /// - [`BuildError::KindMissing`] if [`kind`](Self::kind) was never called.
    /// - [`BuildError::PayloadTooLarge`] if `FRAME_HEADER_SIZE` plus the payload
    ///   length exceeds `MAX_FRAME_SIZE`.
    /// - [`BuildError::FlagsNotAllowedForKind`] if the kind does not permit the
    ///   union of the flags the caller passed to [`flags`](Self::flags) and the
    ///   flags the builder derived. The error carries that union.
    pub fn build(self) -> Result<Frame, BuildError> {
        let kind = self.kind.ok_or(BuildError::KindMissing)?;

        // Saturating, so an absurdly large payload is reported as too large
        // instead of wrapping around to a small size. `MAX_FRAME_SIZE` is u32,
        // which widens losslessly to usize on 32/64-bit targets.
        let encoded_len = FRAME_HEADER_SIZE.saturating_add(self.payload.len());
        if encoded_len > MAX_FRAME_SIZE as usize {
            return Err(BuildError::PayloadTooLarge {
                encoded_len,
                max: MAX_FRAME_SIZE,
            });
        }

        // Capture what the caller explicitly asked for before MORE_FRAMES and
        // COMPRESSED are normalised below. Otherwise a forbidden bit passed
        // via `flags()` would be cleared and slip through validation unnoticed.
        let requested_bits = self.flags.bits();

        let compressed = match self.compress {
            Compress::No => false,
            Compress::Yes => is_payload_compressible(&self.payload),
        };
        let flags = Self::assign_flag(self.flags, Flags::MORE_FRAMES, self.more_frames);
        let flags = Self::assign_flag(flags, Flags::COMPRESSED, compressed);

        let checked = Flags::from_bits(requested_bits | flags.bits());
        if !kind.permits_flags(checked) {
            return Err(BuildError::FlagsNotAllowedForKind {
                kind,
                flags: checked.bits(),
            });
        }

        Ok(Frame {
            kind,
            flags,
            stream_id: self.stream_id,
            correlation_id: self.correlation_id,
            payload: self.payload,
        })
    }

    /// Returns `flags` with `bit` set when `on` is true and cleared otherwise.
    fn assign_flag(flags: Flags, bit: Flags, on: bool) -> Flags {
        if on {
            flags.insert(bit)
        } else {
            Flags::from_bits(flags.bits() & !bit.bits())
        }
    }
}
/// Length-only heuristic for whether a payload should be flagged as compressed.
///
/// The payload bytes themselves are not inspected. Any payload longer than
/// 32 bytes qualifies.
fn is_payload_compressible(payload: &[u8]) -> bool {
    // Payloads of at most this many bytes are left uncompressed.
    const UNCOMPRESSED_UP_TO: usize = 32;
    let size = payload.len();
    size > UNCOMPRESSED_UP_TO
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The most common starting point below: a `Result` frame replying to
    /// correlation id 1.
    fn result_reply() -> FrameBuilder {
        FrameBuilder::reply_to(1).kind(MessageKind::Result)
    }

    #[test]
    fn reply_to_propagates_correlation_id() {
        let built = FrameBuilder::reply_to(0xABCD)
            .kind(MessageKind::Result)
            .payload(b"ok".to_vec())
            .build()
            .expect("build");
        assert_eq!(built.correlation_id, 0xABCD);
        assert_eq!(built.kind, MessageKind::Result);
        assert_eq!(built.payload, b"ok");
    }

    #[test]
    fn unsolicited_uses_zero_correlation() {
        let notice = FrameBuilder::unsolicited()
            .kind(MessageKind::Notice)
            .payload(b"server-side notice".to_vec())
            .build()
            .expect("build");
        assert_eq!(notice.correlation_id, 0);
    }

    #[test]
    fn missing_kind_rejected() {
        let outcome = FrameBuilder::reply_to(1).build();
        assert_eq!(outcome.unwrap_err(), BuildError::KindMissing);
    }

    #[test]
    fn more_frames_last_frame_clears_the_flag() {
        // Same 8-byte Result frame for correlation id 7; only the continuation
        // marker differs between the two builds.
        let build_with = |more: bool, label: &str| {
            FrameBuilder::reply_to(7)
                .kind(MessageKind::Result)
                .payload(vec![0; 8])
                .more_frames(more)
                .build()
                .expect(label)
        };

        let middle = build_with(true, "build middle");
        assert!(
            middle.flags.contains(Flags::MORE_FRAMES),
            "middle frame must carry MORE_FRAMES"
        );

        let last = build_with(false, "build last");
        assert!(
            !last.flags.contains(Flags::MORE_FRAMES),
            "last frame must clear MORE_FRAMES"
        );
    }

    #[test]
    fn more_frames_default_is_last_frame() {
        let pong = FrameBuilder::reply_to(1)
            .kind(MessageKind::Pong)
            .build()
            .expect("build");
        assert!(!pong.flags.contains(Flags::MORE_FRAMES));
    }

    #[test]
    fn payload_at_max_size_accepted() {
        let room = MAX_FRAME_SIZE as usize - FRAME_HEADER_SIZE;
        let frame = result_reply()
            .payload(vec![0u8; room])
            .build()
            .expect("build at limit");
        assert_eq!(frame.encoded_len(), MAX_FRAME_SIZE);
    }

    #[test]
    fn payload_over_max_size_rejected() {
        let limit = MAX_FRAME_SIZE as usize;
        let err = result_reply()
            .payload(vec![0u8; limit - FRAME_HEADER_SIZE + 1])
            .build()
            .unwrap_err();
        if let BuildError::PayloadTooLarge { encoded_len, max } = &err {
            assert_eq!(*max, MAX_FRAME_SIZE);
            assert_eq!(*encoded_len, limit + 1);
        } else {
            panic!("expected PayloadTooLarge, got {err:?}");
        }
    }

    #[test]
    fn compression_fallback_drops_flag_for_incompressible_payload() {
        let frame = result_reply()
            .payload(b"tiny".to_vec())
            .compress(true)
            .build()
            .expect("build");
        assert!(
            !frame.flags.contains(Flags::COMPRESSED),
            "incompressible payload must not carry COMPRESSED"
        );
    }

    #[test]
    fn compression_kept_for_compressible_payload() {
        let frame = result_reply()
            .payload(b"abcabcabc".repeat(16))
            .compress(true)
            .build()
            .expect("build");
        assert!(frame.flags.contains(Flags::COMPRESSED));
    }

    #[test]
    fn flags_not_allowed_for_kind_rejected_at_build() {
        let err = FrameBuilder::reply_to(1)
            .kind(MessageKind::Hello)
            .flags(Flags::COMPRESSED)
            .build()
            .unwrap_err();
        if let BuildError::FlagsNotAllowedForKind { kind, flags } = &err {
            assert_eq!(*kind, MessageKind::Hello);
            assert_eq!(*flags, Flags::COMPRESSED.bits());
        } else {
            panic!("expected FlagsNotAllowedForKind, got {err:?}");
        }
    }

    #[test]
    fn stream_id_propagates() {
        let frame = result_reply().stream_id(0xBEEF).build().expect("build");
        assert_eq!(frame.stream_id, 0xBEEF);
    }
}