use crate::error::{GcsError, GcsResult};
/// Size in bytes of the fixed frame header: a little-endian `u32` message
/// type, a `u32` total frame size (header included) and a `u64` message id.
pub const HEADER_LEN: usize = 4 + 4 + 8;
/// Largest payload, excluding the header, that `decode_header` accepts
/// (4 MiB). `encode_frame` does not check against it.
pub const MAX_PAYLOAD_LEN: usize = 4 << 20;
/// Message-type tag (top nibble, `0x1000_0000`) for request frames.
pub const MSG_TYPE_REQUEST: u32 = 0x1 << 28;
/// Message-type tag (top nibble, `0x2000_0000`) for response frames.
pub const MSG_TYPE_RESPONSE: u32 = 0x2 << 28;
/// Message-type tag (top nibble, `0x3000_0000`) for notify frames.
pub const MSG_TYPE_NOTIFY: u32 = 0x3 << 28;
/// Selects the message-type tag nibble (`0xF000_0000`) of a type word.
pub const MSG_TYPE_MASK: u32 = 0xF << 28;
/// Category bits (`0x0010_0000`) for compute-system RPCs.
pub const CATEGORY_COMPUTE_SYSTEM: u32 = 0x1 << 20;
/// Category bits (`0x0020_0000`) for compute-service RPCs.
pub const CATEGORY_COMPUTE_SERVICE: u32 = 0x2 << 20;
/// RPC procedures that can appear in a frame's message-type word.
///
/// The low 16 bits of each discriminant are the procedure code that goes on
/// the wire. `as_request_type` / `as_response_type` OR that code with a
/// message-type tag and a category to build the full type word.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RpcMessageType {
Create = 0x0101,
Start = 0x0201,
ShutdownGraceful = 0x0301,
ShutdownForced = 0x0401,
ExecuteProcess = 0x0501,
WaitForProcess = 0x0601,
SignalProcess = 0x0701,
ResizeConsole = 0x0801,
GetProperties = 0x0901,
ModifySettings = 0x0A01,
NegotiateProtocol = 0x0B01,
DumpStacks = 0x0C01,
DeleteContainerState = 0x0D01,
UpdateContainer = 0x0E01,
LifecycleNotification = 0x0F01,
/// Wire procedure code is 0x0101, the same as `Create`. The extra `0x1_0000`
/// bit exists only because Rust enum discriminants must be unique. It is
/// masked off before encoding, and on the wire this variant is distinguished
/// by its compute-service category instead.
ModifyServiceSettings = 0x1_0101,
}
impl RpcMessageType {
    /// Category bits OR'd into the wire type word.
    ///
    /// `ModifyServiceSettings` is the only compute-service RPC; every other
    /// variant belongs to the compute-system category.
    #[must_use]
    const fn category(self) -> u32 {
        if matches!(self, Self::ModifyServiceSettings) {
            CATEGORY_COMPUTE_SERVICE
        } else {
            CATEGORY_COMPUTE_SYSTEM
        }
    }

    /// Procedure code: the low 16 bits of the discriminant.
    ///
    /// Masking drops the `0x1_0000` bit that `ModifyServiceSettings` carries
    /// only to keep its discriminant distinct from `Create`.
    #[must_use]
    const fn proc_code(self) -> u32 {
        let discriminant = self as u32;
        discriminant & 0xFFFF
    }

    /// Full wire type word for a request carrying this RPC.
    #[must_use]
    pub const fn as_request_type(self) -> u32 {
        let routing = self.category() | self.proc_code();
        routing | MSG_TYPE_REQUEST
    }

    /// Full wire type word for the response to this RPC; identical to the
    /// request word except for the message-type tag nibble.
    #[must_use]
    pub const fn as_response_type(self) -> u32 {
        let routing = self.category() | self.proc_code();
        routing | MSG_TYPE_RESPONSE
    }
}
/// A frame header as parsed by `decode_header`: the raw little-endian fields
/// of the 16-byte header, in wire order.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FrameHeader {
/// Message-type word from bytes 0..4. `decode_header` does not validate it.
pub r#type: u32,
/// Total frame length in bytes, header included, from bytes 4..8.
///
/// Headers returned by `decode_header` satisfy
/// `HEADER_LEN <= size <= HEADER_LEN + MAX_PAYLOAD_LEN`. The payload that
/// follows the header is therefore `size - HEADER_LEN` bytes.
pub size: u32,
/// Message id from bytes 8..16, copied verbatim.
pub message_id: u64,
}
pub fn encode_frame(r#type: u32, message_id: u64, payload: &[u8], out: &mut Vec<u8>) {
let total =
u32::try_from(HEADER_LEN + payload.len()).expect("frame total length must fit in u32");
out.clear();
out.reserve(HEADER_LEN + payload.len());
out.extend_from_slice(&r#type.to_le_bytes());
out.extend_from_slice(&total.to_le_bytes());
out.extend_from_slice(&message_id.to_le_bytes());
out.extend_from_slice(payload);
}
/// Parses and validates a frame header.
///
/// Reads the little-endian message type (bytes 0..4), total frame size
/// (bytes 4..8) and message id (bytes 8..16). The payload is not read here.
/// By the framing written in `encode_frame`, the payload is
/// `size - HEADER_LEN` bytes long.
///
/// # Errors
///
/// Returns `GcsError::Protocol` if the size field is smaller than
/// `HEADER_LEN` or larger than `HEADER_LEN + MAX_PAYLOAD_LEN`. The message
/// type is not validated.
pub fn decode_header(bytes: &[u8; HEADER_LEN]) -> GcsResult<FrameHeader> {
    // Little-endian u32 from the four bytes starting at `at`.
    let le_u32 = |at: usize| {
        u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]])
    };
    let r#type = le_u32(0);
    let size = le_u32(4);

    let mut id_bytes = [0u8; 8];
    id_bytes.copy_from_slice(&bytes[8..16]);
    let message_id = u64::from_le_bytes(id_bytes);

    match size as usize {
        frame_len if frame_len < HEADER_LEN => Err(GcsError::Protocol(format!(
            "frame size {size} < header length {HEADER_LEN}"
        ))),
        frame_len if frame_len > HEADER_LEN + MAX_PAYLOAD_LEN => {
            Err(GcsError::Protocol(format!(
                "frame size {size} exceeds MAX_PAYLOAD_LEN+header={}",
                HEADER_LEN + MAX_PAYLOAD_LEN
            )))
        }
        _ => Ok(FrameHeader {
            r#type,
            size,
            message_id,
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `RpcMessageType` variant; keep in sync with the enum.
    const ALL_VARIANTS: [RpcMessageType; 16] = [
        RpcMessageType::Create,
        RpcMessageType::Start,
        RpcMessageType::ShutdownGraceful,
        RpcMessageType::ShutdownForced,
        RpcMessageType::ExecuteProcess,
        RpcMessageType::WaitForProcess,
        RpcMessageType::SignalProcess,
        RpcMessageType::ResizeConsole,
        RpcMessageType::GetProperties,
        RpcMessageType::ModifySettings,
        RpcMessageType::NegotiateProtocol,
        RpcMessageType::DumpStacks,
        RpcMessageType::DeleteContainerState,
        RpcMessageType::UpdateContainer,
        RpcMessageType::LifecycleNotification,
        RpcMessageType::ModifyServiceSettings,
    ];

    /// Copies the leading `HEADER_LEN` bytes of an encoded frame into a header array.
    fn header_of(buf: &[u8]) -> [u8; HEADER_LEN] {
        buf[..HEADER_LEN]
            .try_into()
            .expect("buf has HEADER_LEN bytes after encode_frame")
    }

    /// Builds a header whose only non-zero field is `size`.
    fn header_with_size(size: u32) -> [u8; HEADER_LEN] {
        let mut bytes = [0u8; HEADER_LEN];
        bytes[4..8].copy_from_slice(&size.to_le_bytes());
        bytes
    }

    #[test]
    fn round_trip_empty_payload() {
        let mut buf = Vec::new();
        encode_frame(0x0010_0001, 42, b"", &mut buf);
        assert_eq!(buf.len(), HEADER_LEN);
        let h = decode_header(&header_of(&buf)).unwrap();
        assert_eq!(h.r#type, 0x0010_0001);
        assert_eq!(h.size as usize, HEADER_LEN);
        assert_eq!(h.message_id, 42);
    }

    #[test]
    fn round_trip_with_payload() {
        let payload = br#"{"hello":"world"}"#;
        let mut buf = Vec::new();
        encode_frame(0x1010_0001, 99, payload, &mut buf);
        assert_eq!(buf.len(), HEADER_LEN + payload.len());
        let h = decode_header(&header_of(&buf)).unwrap();
        assert_eq!(h.r#type, 0x1010_0001);
        assert_eq!(h.size as usize, HEADER_LEN + payload.len());
        assert_eq!(h.message_id, 99);
        assert_eq!(&buf[HEADER_LEN..], payload);
    }

    #[test]
    fn encode_writes_little_endian_header_fields() {
        let mut buf = Vec::new();
        encode_frame(0x1234_5678, 0x0102_0304_0506_0708, b"", &mut buf);
        assert_eq!(&buf[0..4], &[0x78, 0x56, 0x34, 0x12]);
        assert_eq!(&buf[4..8], &[16, 0, 0, 0]);
        assert_eq!(&buf[8..16], &[8, 7, 6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn encode_replaces_previous_buffer_contents() {
        let mut buf = vec![0xAA; 64];
        encode_frame(0x1010_0001, 7, b"xy", &mut buf);
        assert_eq!(buf.len(), HEADER_LEN + 2);
        assert_eq!(&buf[HEADER_LEN..], b"xy");
    }

    #[test]
    fn decode_rejects_undersized_size_field() {
        // 15 (HEADER_LEN - 1) is the boundary case.
        for bad_size in [0u32, 8, 15] {
            let err = decode_header(&header_with_size(bad_size)).unwrap_err();
            assert!(matches!(err, GcsError::Protocol(_)), "size {bad_size}");
        }
    }

    #[test]
    fn decode_accepts_size_field_at_maximum() {
        let max_size: u32 =
            u32::try_from(HEADER_LEN + MAX_PAYLOAD_LEN).expect("test constant fits in u32");
        let h = decode_header(&header_with_size(max_size)).unwrap();
        assert_eq!(h.size, max_size);
    }

    #[test]
    fn decode_rejects_oversized_size_field() {
        let bad_size: u32 =
            u32::try_from(HEADER_LEN + MAX_PAYLOAD_LEN + 1).expect("test constant fits in u32");
        let err = decode_header(&header_with_size(bad_size)).unwrap_err();
        assert!(matches!(err, GcsError::Protocol(_)));
    }

    #[test]
    fn request_vs_response_type_bit() {
        let req = RpcMessageType::Create.as_request_type();
        let resp = RpcMessageType::Create.as_response_type();
        assert_eq!(req & MSG_TYPE_MASK, MSG_TYPE_REQUEST);
        assert_eq!(resp & MSG_TYPE_MASK, MSG_TYPE_RESPONSE);
        assert_eq!(req & !MSG_TYPE_MASK, resp & !MSG_TYPE_MASK);
        assert_eq!(req & CATEGORY_COMPUTE_SYSTEM, CATEGORY_COMPUTE_SYSTEM);
    }

    #[test]
    fn every_variant_has_unique_well_formed_wire_types() {
        let mut seen = std::collections::HashSet::new();
        for v in ALL_VARIANTS {
            let req = v.as_request_type();
            let resp = v.as_response_type();
            assert_eq!(req & MSG_TYPE_MASK, MSG_TYPE_REQUEST, "{v:?}");
            assert_eq!(resp & MSG_TYPE_MASK, MSG_TYPE_RESPONSE, "{v:?}");
            assert_eq!(req & !MSG_TYPE_MASK, resp & !MSG_TYPE_MASK, "{v:?}");
            // Bits 16..20 hold no wire field; the discriminant disambiguation
            // bit must have been masked off.
            assert_eq!(req & 0x000F_0000, 0, "{v:?}");
            assert!(seen.insert(req), "duplicate wire type for {v:?}");
        }
    }

    #[test]
    fn negotiate_protocol_wire_type_pinned() {
        assert_eq!(
            RpcMessageType::NegotiateProtocol.as_request_type(),
            0x1010_0B01
        );
        assert_eq!(
            RpcMessageType::NegotiateProtocol.as_response_type(),
            0x2010_0B01
        );
    }

    #[test]
    fn modify_service_settings_wire_type_pinned() {
        let req = RpcMessageType::ModifyServiceSettings.as_request_type();
        let resp = RpcMessageType::ModifyServiceSettings.as_response_type();
        assert_eq!(req, 0x1020_0101);
        assert_eq!(resp, 0x2020_0101);
        assert_eq!(req & CATEGORY_COMPUTE_SERVICE, CATEGORY_COMPUTE_SERVICE);
        assert_eq!(req & CATEGORY_COMPUTE_SYSTEM, 0);
        assert_eq!(req & 0x1_0000, 0);
    }
}