use crate::broker::backend_lifecycle::identity::DaemonProcess;
use crate::broker::backend_lifecycle::probe::{
endpoint_probe_request_from_frame, endpoint_probe_response_frame, EndpointProbeServerError,
};
use crate::broker::protocol::{
encode_framed, registry, try_decode_framed, Frame, FrameKind, FramingError, ENVELOPE_VERSION,
};
/// Verdict a legacy-protocol detector returns for the bytes buffered so far
/// on a backend-endpoint connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LegacyClassification {
    /// The buffered bytes belong to legacy-protocol traffic.
    Legacy,
    /// The buffered bytes are not legacy; the mux goes on to inspect them as a frame.
    NotLegacy,
    /// Too few bytes are buffered for the detector to decide yet.
    NeedMoreBytes,
}
/// Outcome of a single [`BackendEndpointMux::poll`] call on the buffered bytes.
#[derive(Debug)]
pub enum MuxPoll {
    /// Not enough bytes are buffered yet to classify or decode the next message.
    NeedMoreBytes,
    /// The buffered bytes should be handled as legacy traffic.
    Legacy,
    /// A BackendHandle probe was decoded and answered by the mux itself.
    ProbeAnswered {
        /// Encoded probe-response frame (presumably written back to the peer by the caller).
        reply: Vec<u8>,
        /// Number of buffered bytes the decoder reported for the probe request.
        consumed: usize,
    },
    /// A decoded frame for one of the payload protocols this backend serves.
    Payload {
        /// The decoded frame.
        frame: Frame,
        /// Number of buffered bytes the decoder reported for this frame.
        consumed: usize,
    },
}
/// Errors returned by [`BackendEndpointMux::poll`].
#[derive(Debug, thiserror::Error)]
pub enum MuxError {
    /// Framing-layer error, propagated transparently.
    #[error(transparent)]
    Framing(#[from] FramingError),
    /// A frame on the BackendHandle probe protocol could not be parsed as a probe request.
    #[error("malformed BackendHandle probe: {0}")]
    MalformedProbe(#[from] EndpointProbeServerError),
    /// A first-party payload protocol other than the probe arrived on the backend
    /// endpoint; `payload_protocol` is the offending frame's protocol.
    #[error(
        "unexpected first-party frame on backend endpoint \
         (payload_protocol {payload_protocol:#06X})"
    )]
    UnexpectedFirstPartyFrame { payload_protocol: u32 },
    /// A frame arrived for a payload protocol this backend was not configured to
    /// serve; `payload_protocol` is the offending frame's protocol.
    #[error("frame for unserved payload protocol {payload_protocol:#06X}")]
    UnservedPayloadProtocol { payload_protocol: u32 },
}
/// Demultiplexes bytes arriving on a backend endpoint into legacy traffic,
/// BackendHandle probes (answered by the mux itself), and framed payloads
/// for the protocols this backend serves.
///
/// `F` is the legacy detector consulted by [`BackendEndpointMux::poll`].
pub struct BackendEndpointMux<F> {
    /// Daemon identity used to build probe responses.
    daemon: DaemonProcess,
    /// Payload protocols whose frames are passed through to the caller
    /// (first-party protocols are rejected before this list is consulted).
    served_payload_protocols: Vec<u32>,
    /// Classifier run on every non-empty buffer before frame decoding.
    legacy_detector: F,
}
impl<F> BackendEndpointMux<F>
where
    F: Fn(&[u8]) -> LegacyClassification,
{
    /// Creates a mux for `daemon`.
    ///
    /// Frames are forwarded when their payload protocol is in
    /// `served_payload_protocols`. `legacy_detector` is consulted on every poll
    /// of a non-empty buffer, before any frame decoding.
    pub fn new(
        daemon: DaemonProcess,
        served_payload_protocols: &[u32],
        legacy_detector: F,
    ) -> Self {
        let served_payload_protocols = Vec::from(served_payload_protocols);
        Self {
            daemon,
            served_payload_protocols,
            legacy_detector,
        }
    }

    /// Classifies the bytes buffered so far on a backend-endpoint connection.
    ///
    /// The checks run in this order:
    /// 1. An empty buffer, or a detector verdict of `NeedMoreBytes`, yields
    ///    [`MuxPoll::NeedMoreBytes`].
    /// 2. A detector verdict of `Legacy`, or a first byte other than
    ///    [`ENVELOPE_VERSION`], yields [`MuxPoll::Legacy`].
    /// 3. An incomplete framed message yields [`MuxPoll::NeedMoreBytes`].
    /// 4. A BackendHandle probe is answered in place ([`MuxPoll::ProbeAnswered`]).
    /// 5. Any other frame on a served, non-first-party protocol is returned as
    ///    [`MuxPoll::Payload`].
    ///
    /// # Errors
    ///
    /// - [`MuxError::Framing`] / [`MuxError::MalformedProbe`] when decoding the
    ///   buffer, parsing the probe, or encoding the probe reply fails.
    /// - [`MuxError::UnexpectedFirstPartyFrame`] for first-party protocols other
    ///   than the probe.
    /// - [`MuxError::UnservedPayloadProtocol`] for protocols not passed to `new`.
    pub fn poll(&self, buf: &[u8]) -> Result<MuxPoll, MuxError> {
        let Some(&first_byte) = buf.first() else {
            return Ok(MuxPoll::NeedMoreBytes);
        };

        // The legacy detector gets first say, even when the first byte happens
        // to equal the envelope version.
        match (self.legacy_detector)(buf) {
            LegacyClassification::NotLegacy => {}
            LegacyClassification::Legacy => return Ok(MuxPoll::Legacy),
            LegacyClassification::NeedMoreBytes => return Ok(MuxPoll::NeedMoreBytes),
        }
        if first_byte != ENVELOPE_VERSION {
            return Ok(MuxPoll::Legacy);
        }

        let (frame, consumed) = match try_decode_framed(buf)? {
            Some(decoded) => (decoded.frame, decoded.consumed),
            None => return Ok(MuxPoll::NeedMoreBytes),
        };
        let protocol = frame.payload_protocol;

        // Probes are handled before the first-party check so they are answered
        // here rather than rejected.
        if protocol == registry::BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL {
            let request = endpoint_probe_request_from_frame(&frame)?;
            let response = endpoint_probe_response_frame(&request, &self.daemon);
            let reply = encode_framed(&response)?;
            return Ok(MuxPoll::ProbeAnswered { reply, consumed });
        }

        if registry::is_first_party(protocol) {
            return Err(MuxError::UnexpectedFirstPartyFrame {
                payload_protocol: protocol,
            });
        }
        let is_served = self
            .served_payload_protocols
            .iter()
            .any(|&served| served == protocol);
        if !is_served {
            return Err(MuxError::UnservedPayloadProtocol {
                payload_protocol: protocol,
            });
        }

        // NOTE(review): the conversion result is discarded, so frames with an
        // unrecognised `kind` are not rejected here — confirm whether
        // validation was intended.
        let _ = FrameKind::try_from(frame.kind);
        Ok(MuxPoll::Payload { frame, consumed })
    }

    /// Returns the daemon identity this mux uses to answer probes.
    pub fn daemon(&self) -> &DaemonProcess {
        &self.daemon
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::broker::backend_lifecycle::probe::PROBE_NONCE_BYTES;
    use crate::broker::protocol::{registry, Endpoint, PayloadEncoding};
    use prost::Message;

    /// Consumer payload protocol the test mux is configured to serve.
    const TEST_PROTOCOL: u32 = 0x7001;

    /// Builds a `DaemonProcess` identity for the current process on a test
    /// Unix-socket endpoint.
    fn test_daemon() -> DaemonProcess {
        let endpoint = Endpoint::unix_socket("mux-test", "/tmp/mux-test.sock").expect("endpoint");
        DaemonProcess::current_process(endpoint, Some(30)).expect("identity")
    }

    /// Builds a mux that serves only `TEST_PROTOCOL` and uses a stand-in legacy
    /// detector.
    ///
    /// The detector needs at least 8 bytes. It reports `Legacy` when the
    /// little-endian `u32` at bytes 4..8 equals 15, and `NotLegacy` otherwise.
    fn test_mux() -> BackendEndpointMux<impl Fn(&[u8]) -> LegacyClassification> {
        BackendEndpointMux::new(test_daemon(), &[TEST_PROTOCOL], |buf: &[u8]| {
            if buf.len() < 8 {
                return LegacyClassification::NeedMoreBytes;
            }
            let version = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]);
            if version == 15 {
                LegacyClassification::Legacy
            } else {
                LegacyClassification::NotLegacy
            }
        })
    }

    /// Encodes a BackendHandle probe request with request id 7 whose payload is `nonce`.
    fn probe_request_wire(nonce: [u8; PROBE_NONCE_BYTES]) -> Vec<u8> {
        let frame = Frame::request(
            registry::BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL,
            nonce.to_vec(),
        )
        .with_request_id(7);
        encode_framed(&frame).expect("encode probe")
    }

    #[test]
    fn empty_and_short_buffers_need_more_bytes() {
        let mux = test_mux();
        assert!(matches!(mux.poll(&[]), Ok(MuxPoll::NeedMoreBytes)));
        // Three bytes is below the test detector's 8-byte minimum.
        assert!(matches!(mux.poll(&[1, 0, 0]), Ok(MuxPoll::NeedMoreBytes)));
    }

    #[test]
    fn legacy_header_wins_even_with_frame_version_first_byte() {
        let mux = test_mux();
        // The first byte of this legacy header equals ENVELOPE_VERSION (pinned
        // by the assertion below), yet the detector's verdict must still win.
        let mut legacy = 257u32.to_le_bytes().to_vec();
        legacy.extend_from_slice(&15u32.to_le_bytes());
        assert_eq!(legacy[0], ENVELOPE_VERSION);
        assert!(matches!(mux.poll(&legacy), Ok(MuxPoll::Legacy)));
        // The detector says NotLegacy here (bytes 4..8 decode to 0x1000), but
        // the first byte is not the envelope version, so the bytes are still
        // routed as legacy.
        assert!(matches!(
            mux.poll(&[42, 0, 0, 0, 0, 16, 0, 0, 0]),
            Ok(MuxPoll::Legacy)
        ));
    }

    /// A probe missing its last byte waits for more input. Once complete, it is
    /// answered with a reply that echoes the request id and nonce, followed by
    /// the encoded identity of this process.
    #[test]
    fn probe_request_is_answered_with_identity_echo() {
        let mux = test_mux();
        let nonce = [9u8; PROBE_NONCE_BYTES];
        let wire = probe_request_wire(nonce);
        assert!(matches!(
            mux.poll(&wire[..wire.len() - 1]),
            Ok(MuxPoll::NeedMoreBytes)
        ));
        let MuxPoll::ProbeAnswered { reply, consumed } = mux.poll(&wire).expect("poll") else {
            panic!("expected ProbeAnswered");
        };
        assert_eq!(consumed, wire.len());
        let decoded = try_decode_framed(&reply)
            .expect("decode")
            .expect("complete");
        let frame = decoded.frame;
        assert_eq!(
            frame.payload_protocol,
            registry::BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL
        );
        assert_eq!(frame.request_id, 7);
        // Reply payload layout: the request nonce, then the protobuf-encoded
        // daemon identity.
        assert_eq!(&frame.payload[..PROBE_NONCE_BYTES], &nonce);
        let identity =
            crate::broker::protocol::DaemonProcess::decode(&frame.payload[PROBE_NONCE_BYTES..])
                .expect("identity payload");
        assert_eq!(identity.pid, std::process::id());
    }

    #[test]
    fn consumer_payload_frame_passes_through() {
        let mux = test_mux();
        let request = Frame::request(TEST_PROTOCOL, b"ping".to_vec()).with_request_id(3);
        let wire = encode_framed(&request).expect("encode");
        let MuxPoll::Payload { frame, consumed } = mux.poll(&wire).expect("poll") else {
            panic!("expected Payload");
        };
        assert_eq!(frame, request);
        assert_eq!(consumed, wire.len());
    }

    #[test]
    fn first_party_and_unserved_protocols_are_rejected() {
        let mux = test_mux();
        let hello = Frame::request(registry::CONTROL_PAYLOAD_PROTOCOL, b"hello".to_vec());
        let wire = encode_framed(&hello).expect("encode");
        // Compare against the protocol actually sent rather than a hard-coded
        // numeric value, so the test does not depend on the registry constant.
        assert!(matches!(
            mux.poll(&wire),
            Err(MuxError::UnexpectedFirstPartyFrame { payload_protocol })
                if payload_protocol == hello.payload_protocol
        ));
        let other = Frame::request(0x7002, Vec::new());
        let wire = encode_framed(&other).expect("encode");
        assert!(matches!(
            mux.poll(&wire),
            Err(MuxError::UnservedPayloadProtocol {
                payload_protocol: 0x7002
            })
        ));
    }

    /// A probe whose payload is one byte shorter than a nonce is rejected as
    /// malformed instead of being answered.
    #[test]
    fn malformed_probe_is_connection_fatal() {
        let mux = test_mux();
        let mut bad = Frame::request(
            registry::BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL,
            vec![0u8; PROBE_NONCE_BYTES - 1],
        );
        bad.payload_encoding = PayloadEncoding::None as i32;
        let wire = encode_framed(&bad).expect("encode");
        assert!(matches!(mux.poll(&wire), Err(MuxError::MalformedProbe(_))));
    }
}