use prost::Message;
use crate::broker::protocol::{
Endpoint, Frame, FrameKind, FramingError, PayloadEncoding, ENVELOPE_VERSION, MAX_FRAME_BYTES,
PROTOCOL_VERSION,
};
/// Size in bytes of the frame header written by [`encode_framed`]: one
/// [`ENVELOPE_VERSION`] byte followed by a 4-byte little-endian body length.
pub const FRAME_HEADER_BYTES: usize = 5;
impl Frame {
    /// Builds a request frame carrying `payload` for the given payload protocol.
    ///
    /// The envelope version is [`PROTOCOL_VERSION`], the request id starts at `0`
    /// (assign one with [`Frame::with_request_id`]), the payload encoding is
    /// [`PayloadEncoding::None`], there is no deadline, and both trace-context
    /// strings are empty.
    pub fn request(payload_protocol: u32, payload: Vec<u8>) -> Self {
        Self {
            kind: FrameKind::Request as i32,
            envelope_version: PROTOCOL_VERSION,
            payload_protocol,
            payload,
            payload_encoding: PayloadEncoding::None as i32,
            request_id: 0,
            deadline_unix_ms: 0,
            traceparent: String::new(),
            tracestate: String::new(),
        }
    }

    /// Builds the response frame answering `request`.
    ///
    /// The response copies the request's payload protocol, request id, and
    /// `traceparent` / `tracestate` values so it can be correlated with the request.
    /// Every other field matches a freshly built [`Frame::request`]: no payload
    /// encoding and no deadline.
    pub fn response_to(request: &Self, payload: Vec<u8>) -> Self {
        Self {
            kind: FrameKind::Response as i32,
            request_id: request.request_id,
            traceparent: request.traceparent.clone(),
            tracestate: request.tracestate.clone(),
            ..Self::request(request.payload_protocol, payload)
        }
    }

    /// Returns this frame with its request id replaced by `request_id`.
    #[must_use]
    pub fn with_request_id(self, request_id: u64) -> Self {
        Self { request_id, ..self }
    }
}
pub fn encode_framed(frame: &Frame) -> Result<Vec<u8>, FramingError> {
let body_len = frame.encoded_len();
if body_len > MAX_FRAME_BYTES {
return Err(FramingError::FrameTooLarge {
body_length: body_len,
cap: MAX_FRAME_BYTES,
});
}
let mut wire = Vec::with_capacity(FRAME_HEADER_BYTES + body_len);
wire.push(ENVELOPE_VERSION);
wire.extend_from_slice(&(body_len as u32).to_le_bytes());
frame
.encode(&mut wire)
.expect("prost encoding into Vec cannot fail because Vec writes are infallible");
Ok(wire)
}
/// A frame decoded from the front of a byte buffer by [`try_decode_framed`].
#[derive(Debug, Clone, PartialEq)]
pub struct DecodedFramed {
    /// The decoded protobuf frame.
    pub frame: Frame,
    /// Number of input bytes this frame occupied: the header plus its body.
    /// Any following frame in the same buffer starts at this offset.
    pub consumed: usize,
}
/// Attempts to decode one length-prefixed frame from the start of `buf`.
///
/// Returns `Ok(None)` while `buf` does not yet contain a complete frame. Otherwise it
/// returns `Ok(Some(..))` with the frame and the number of bytes it occupied. Only the
/// first frame is decoded; bytes after it are ignored.
///
/// # Errors
///
/// * [`FramingError::UnsupportedFramingVersion`] if the first byte is not
///   [`ENVELOPE_VERSION`]. This is reported as soon as that single byte is present.
/// * [`FramingError::FrameTooLarge`] if the length prefix exceeds [`MAX_FRAME_BYTES`].
/// * [`FramingError::Decode`] if the body is not a valid protobuf `Frame`.
pub fn try_decode_framed(buf: &[u8]) -> Result<Option<DecodedFramed>, FramingError> {
    let Some(&version) = buf.first() else {
        return Ok(None);
    };
    if version != ENVELOPE_VERSION {
        return Err(FramingError::UnsupportedFramingVersion {
            got: version,
            expected: ENVELOPE_VERSION,
        });
    }

    // Bytes 1..FRAME_HEADER_BYTES carry the little-endian u32 body length.
    let body_len = match buf.get(1..FRAME_HEADER_BYTES) {
        Some(&[b0, b1, b2, b3]) => u32::from_le_bytes([b0, b1, b2, b3]) as usize,
        _ => return Ok(None),
    };

    // Checked before the availability test, so an oversized prefix is reported
    // immediately rather than after waiting for its body to arrive.
    if body_len > MAX_FRAME_BYTES {
        return Err(FramingError::FrameTooLarge {
            body_length: body_len,
            cap: MAX_FRAME_BYTES,
        });
    }

    let frame_end = FRAME_HEADER_BYTES + body_len;
    let Some(body) = buf.get(FRAME_HEADER_BYTES..frame_end) else {
        return Ok(None);
    };
    Frame::decode(body)
        .map(|frame| {
            Some(DecodedFramed {
                frame,
                consumed: frame_end,
            })
        })
        .map_err(FramingError::Decode)
}
impl Endpoint {
    /// Builds an endpoint for a Windows named pipe.
    ///
    /// `pipe_name` must be the bare pipe name. A leading `\\.\pipe\` is rejected,
    /// compared ASCII-case-insensitively and with `/` treated as `\`. An accepted name
    /// is stored unchanged as the endpoint path.
    ///
    /// # Errors
    ///
    /// * [`EndpointNameError::Empty`] if `pipe_name` is empty.
    /// * [`EndpointNameError::PrefixedPipeName`] if `pipe_name` already carries the
    ///   prefix; the error holds the name exactly as supplied.
    pub fn windows_pipe(
        namespace_id: impl Into<String>,
        pipe_name: impl Into<String>,
    ) -> Result<Self, EndpointNameError> {
        // True when `name` begins with `\\.\pipe\`, ignoring ASCII case and
        // accepting `/` wherever the prefix has `\`. Both transformations are
        // byte-for-byte, so comparing the leading bytes is exact.
        fn has_local_pipe_prefix(name: &str) -> bool {
            const PREFIX: &[u8] = br"\\.\pipe\";
            match name.as_bytes().get(..PREFIX.len()) {
                Some(head) => head.iter().zip(PREFIX).all(|(&byte, &want)| {
                    let byte = if byte == b'/' { b'\\' } else { byte };
                    byte.to_ascii_lowercase() == want
                }),
                None => false,
            }
        }

        let name: String = pipe_name.into();
        if name.is_empty() {
            return Err(EndpointNameError::Empty);
        }
        if has_local_pipe_prefix(&name) {
            return Err(EndpointNameError::PrefixedPipeName { got: name });
        }
        Ok(Self {
            namespace_id: namespace_id.into(),
            path: name,
        })
    }

    /// Builds an endpoint addressed by a Unix socket path.
    ///
    /// The path is stored as given; only emptiness is validated here.
    ///
    /// # Errors
    ///
    /// [`EndpointNameError::Empty`] if `socket_path` is empty.
    pub fn unix_socket(
        namespace_id: impl Into<String>,
        socket_path: impl Into<String>,
    ) -> Result<Self, EndpointNameError> {
        let path: String = socket_path.into();
        if path.is_empty() {
            Err(EndpointNameError::Empty)
        } else {
            Ok(Self {
                namespace_id: namespace_id.into(),
                path,
            })
        }
    }
}
/// Validation failures returned by the [`Endpoint`] constructors.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum EndpointNameError {
    /// The supplied pipe name or socket path was the empty string.
    #[error("endpoint name must not be empty")]
    Empty,
    /// A Windows pipe name already started with the `\\.\pipe\` prefix.
    #[error(
        "windows pipe name must be bare (no \\\\.\\pipe\\ prefix), got {got:?}: \
        running-process prepends the prefix when resolving the endpoint"
    )]
    PrefixedPipeName {
        /// The rejected name, exactly as the caller supplied it.
        got: String,
    },
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_and_response_round_trip_through_buffer_codecs() {
        let request = Frame::request(0x7A63, b"ping".to_vec()).with_request_id(42);
        assert_eq!(request.envelope_version, PROTOCOL_VERSION);
        assert_eq!(FrameKind::try_from(request.kind), Ok(FrameKind::Request));
        assert_eq!(
            PayloadEncoding::try_from(request.payload_encoding),
            Ok(PayloadEncoding::None)
        );
        let response = Frame::response_to(&request, b"pong".to_vec());
        assert_eq!(FrameKind::try_from(response.kind), Ok(FrameKind::Response));
        assert_eq!(response.request_id, 42);
        assert_eq!(response.payload_protocol, 0x7A63);
        let wire = encode_framed(&request).expect("encode");
        assert_eq!(wire[0], ENVELOPE_VERSION);
        let decoded = try_decode_framed(&wire)
            .expect("decode")
            .expect("complete frame");
        assert_eq!(decoded.frame, request);
        assert_eq!(decoded.consumed, wire.len());
    }

    #[test]
    fn response_echoes_trace_context() {
        let mut request = Frame::request(0x7A63, Vec::new()).with_request_id(1);
        request.traceparent = "00-abc-def-01".to_owned();
        request.tracestate = "vendor=1".to_owned();
        let response = Frame::response_to(&request, Vec::new());
        assert_eq!(response.traceparent, request.traceparent);
        assert_eq!(response.tracestate, request.tracestate);
    }

    #[test]
    fn try_decode_framed_waits_for_complete_frames() {
        let wire = encode_framed(&Frame::request(0x7001, b"abc".to_vec())).expect("encode");
        assert!(try_decode_framed(&[]).expect("empty").is_none());
        for cut in 1..wire.len() {
            assert!(
                try_decode_framed(&wire[..cut]).expect("partial").is_none(),
                "partial frame of {cut} bytes must not decode"
            );
        }
        let mut two = wire.clone();
        two.extend_from_slice(&wire);
        let first = try_decode_framed(&two).expect("decode").expect("complete");
        assert_eq!(first.consumed, wire.len());
    }

    #[test]
    fn try_decode_framed_rejects_foreign_version_and_oversize() {
        assert!(matches!(
            try_decode_framed(&[2, 0, 0, 0, 0]),
            Err(FramingError::UnsupportedFramingVersion { got: 2, .. })
        ));
        let mut oversize = vec![ENVELOPE_VERSION];
        oversize.extend_from_slice(&(MAX_FRAME_BYTES as u32 + 1).to_le_bytes());
        assert!(matches!(
            try_decode_framed(&oversize),
            Err(FramingError::FrameTooLarge { .. })
        ));
    }

    #[test]
    fn try_decode_framed_rejects_foreign_version_from_first_byte() {
        // The version byte is validated before the header is complete, so a single
        // foreign byte is already an error rather than "need more data".
        let foreign = ENVELOPE_VERSION.wrapping_add(1);
        assert!(matches!(
            try_decode_framed(&[foreign]),
            Err(FramingError::UnsupportedFramingVersion { got, .. }) if got == foreign
        ));
    }

    #[test]
    fn try_decode_framed_reports_malformed_body() {
        // Body length 1; 0xFF opens a varint whose continuation bit is never cleared,
        // so the protobuf body cannot be decoded.
        let wire = [ENVELOPE_VERSION, 1, 0, 0, 0, 0xFF];
        assert!(matches!(
            try_decode_framed(&wire),
            Err(FramingError::Decode(_))
        ));
    }

    #[test]
    fn consecutive_frames_decode_back_to_back() {
        let first = Frame::request(0x7001, b"one".to_vec()).with_request_id(1);
        let second = Frame::request(0x7001, b"two".to_vec()).with_request_id(2);
        let mut stream = encode_framed(&first).expect("encode first");
        stream.extend(encode_framed(&second).expect("encode second"));

        let a = try_decode_framed(&stream).expect("decode").expect("complete");
        assert_eq!(a.frame, first);
        let b = try_decode_framed(&stream[a.consumed..])
            .expect("decode")
            .expect("complete");
        assert_eq!(b.frame, second);
        assert_eq!(a.consumed + b.consumed, stream.len());
    }

    #[test]
    fn endpoint_constructors_enforce_naming_rules() {
        let pipe = Endpoint::windows_pipe("svc", "svc-pipe").expect("bare name");
        assert_eq!(pipe.namespace_id, "svc");
        assert_eq!(pipe.path, "svc-pipe");
        assert_eq!(
            Endpoint::windows_pipe("svc", r"\\.\pipe\svc-pipe"),
            Err(EndpointNameError::PrefixedPipeName {
                got: r"\\.\pipe\svc-pipe".to_owned()
            })
        );
        assert_eq!(
            Endpoint::windows_pipe("svc", "//./pipe/svc-pipe"),
            Err(EndpointNameError::PrefixedPipeName {
                got: "//./pipe/svc-pipe".to_owned()
            })
        );
        assert_eq!(
            Endpoint::windows_pipe("svc", ""),
            Err(EndpointNameError::Empty)
        );
        let sock = Endpoint::unix_socket("svc", "/tmp/svc.sock").expect("path");
        assert_eq!(sock.path, "/tmp/svc.sock");
        assert_eq!(
            Endpoint::unix_socket("svc", ""),
            Err(EndpointNameError::Empty)
        );
    }

    #[test]
    fn windows_pipe_prefix_check_ignores_case_and_mixed_separators() {
        for prefixed in [r"\\.\PIPE\svc", r"\\./Pipe/svc"] {
            assert_eq!(
                Endpoint::windows_pipe("svc", prefixed),
                Err(EndpointNameError::PrefixedPipeName {
                    got: prefixed.to_owned()
                })
            );
        }
        // A bare name that merely mentions "pipe" is still accepted.
        assert!(Endpoint::windows_pipe("svc", "pipe-svc").is_ok());
    }
}