use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::error::ProtocolResult;
/// Version number written into the `v` field of every [`Message`] built by this module.
pub const PROTOCOL_VERSION: u8 = 1;

/// Flag bit that `MessageType::flags` assigns to `ExecExited` and `FsResponse`.
pub const FLAG_TERMINAL: u8 = 0x01;
/// Flag bit that `MessageType::flags` assigns to `ExecRequest` and `FsRequest`.
pub const FLAG_SESSION_START: u8 = 0x02;
/// Flag bit that `MessageType::flags` assigns to `Shutdown`.
pub const FLAG_SHUTDOWN: u8 = 0x04;

/// Size of the frame header, in bytes.
/// NOTE(review): presumably a 4-byte `id` plus a 1-byte `flags` (the two `Message` fields
/// excluded from the serde body) — confirm against the framing code.
pub const FRAME_HEADER_SIZE: usize = 5;
/// A single protocol message.
///
/// Only `v`, `t` and `p` take part in the serde encoding. `id` and `flags` are marked
/// `#[serde(skip)]`, so they are never serialized, and deserializing a `Message` leaves
/// them at their `Default` value (`0`).
/// NOTE(review): presumably the framing layer transports `id`/`flags` outside the encoded
/// body (cf. `FRAME_HEADER_SIZE`) and fills them in after decoding — confirm.
///
/// The derived `Serialize` emits fields in declaration order, so reordering fields changes
/// the encoded bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Protocol version; the constructors always set this to `PROTOCOL_VERSION`.
    pub v: u8,
    /// Message type, encoded as its dotted wire name (e.g. `"core.ready"`).
    pub t: MessageType,
    /// Message identifier; not part of the serde encoding.
    #[serde(skip)]
    pub id: u32,
    /// `FLAG_*` bits; the constructors derive them from `t` via `MessageType::flags`.
    /// Not part of the serde encoding.
    #[serde(skip)]
    pub flags: u8,
    /// Raw payload bytes, encoded through `serde_bytes` as a byte string rather than a
    /// sequence of integers. `Message::with_payload` / `Message::payload` store and read
    /// CBOR here.
    #[serde(with = "serde_bytes")]
    pub p: Vec<u8>,
}
/// The kind of a [`Message`], carried on the wire as a dotted string name.
///
/// `Serialize`/`Deserialize` are implemented by hand so that the wire form is the name
/// returned by `MessageType::as_str` (and accepted by `MessageType::from_wire_str`)
/// rather than serde's default variant encoding.
///
/// This is a fieldless tag, so it derives `Copy`. Values can be passed, stored and
/// compared without moves or `.clone()` calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageType {
    // `core.*`
    Ready,
    Shutdown,
    // `core.exec.*`
    ExecRequest,
    ExecStarted,
    ExecStdin,
    ExecStdout,
    ExecStderr,
    ExecExited,
    ExecResize,
    ExecSignal,
    // `core.fs.*`
    FsRequest,
    FsResponse,
    FsData,
}
impl Message {
    /// Builds a message of type `t` with identifier `id` around already-encoded payload bytes.
    ///
    /// `v` is always [`PROTOCOL_VERSION`] and `flags` is derived from `t` via
    /// [`MessageType::flags`], so callers never set either by hand.
    pub fn new(t: MessageType, id: u32, p: Vec<u8>) -> Self {
        // Struct-literal fields are evaluated in source order, so `t.flags()` runs before
        // `t` is moved into the struct.
        Self {
            flags: t.flags(),
            v: PROTOCOL_VERSION,
            t,
            id,
            p,
        }
    }

    /// Builds a message whose payload is `payload` encoded as CBOR.
    ///
    /// # Errors
    ///
    /// Returns an error if `ciborium` fails to encode `payload`.
    pub fn with_payload<T: Serialize>(
        t: MessageType,
        id: u32,
        payload: &T,
    ) -> ProtocolResult<Self> {
        let mut encoded = Vec::new();
        ciborium::into_writer(payload, &mut encoded)?;
        Ok(Self::new(t, id, encoded))
    }

    /// Decodes the CBOR payload bytes into a `T`.
    ///
    /// # Errors
    ///
    /// Returns an error if the bytes in `p` are not a valid CBOR encoding of `T`.
    pub fn payload<T: DeserializeOwned>(&self) -> ProtocolResult<T> {
        let decoded: T = ciborium::from_reader(self.p.as_slice())?;
        Ok(decoded)
    }
}
impl MessageType {
    /// Every message type, in declaration order.
    ///
    /// `from_wire_str` searches this table through `as_str`, so each wire name is spelled
    /// exactly once. When adding a variant, add it here too, or it will not parse.
    pub const ALL: [MessageType; 13] = [
        Self::Ready,
        Self::Shutdown,
        Self::ExecRequest,
        Self::ExecStarted,
        Self::ExecStdin,
        Self::ExecStdout,
        Self::ExecStderr,
        Self::ExecExited,
        Self::ExecResize,
        Self::ExecSignal,
        Self::FsRequest,
        Self::FsResponse,
        Self::FsData,
    ];

    /// Returns the `FLAG_*` bits implied by this message type (`0` when none apply).
    pub fn flags(&self) -> u8 {
        // Deliberately exhaustive (no `_` arm): a new variant must make an explicit
        // decision about its flags instead of silently defaulting to `0`.
        match self {
            Self::ExecExited | Self::FsResponse => FLAG_TERMINAL,
            Self::ExecRequest | Self::FsRequest => FLAG_SESSION_START,
            Self::Shutdown => FLAG_SHUTDOWN,
            Self::Ready
            | Self::ExecStarted
            | Self::ExecStdin
            | Self::ExecStdout
            | Self::ExecStderr
            | Self::ExecResize
            | Self::ExecSignal
            | Self::FsData => 0,
        }
    }

    /// Returns the dotted wire name used to encode this type.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Ready => "core.ready",
            Self::Shutdown => "core.shutdown",
            Self::ExecRequest => "core.exec.request",
            Self::ExecStarted => "core.exec.started",
            Self::ExecStdin => "core.exec.stdin",
            Self::ExecStdout => "core.exec.stdout",
            Self::ExecStderr => "core.exec.stderr",
            Self::ExecExited => "core.exec.exited",
            Self::ExecResize => "core.exec.resize",
            Self::ExecSignal => "core.exec.signal",
            Self::FsRequest => "core.fs.request",
            Self::FsResponse => "core.fs.response",
            Self::FsData => "core.fs.data",
        }
    }

    /// Parses a wire name produced by [`MessageType::as_str`].
    ///
    /// Returns `None` for names that do not belong to any type in [`MessageType::ALL`].
    pub fn from_wire_str(s: &str) -> Option<Self> {
        // Derived from `as_str` so the two directions of the mapping cannot drift apart.
        Self::ALL.iter().find(|t| t.as_str() == s).cloned()
    }
}
impl Serialize for MessageType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.as_str())
}
}
impl<'de> Deserialize<'de> for MessageType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Self::from_wire_str(&s)
.ok_or_else(|| serde::de::Error::custom(format!("unknown message type: {s}")))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Every message type paired with the wire name it must encode to; shared by the
    // string and serde roundtrip tests.
    const WIRE_NAMES: [(MessageType, &str); 13] = [
        (MessageType::Ready, "core.ready"),
        (MessageType::Shutdown, "core.shutdown"),
        (MessageType::ExecRequest, "core.exec.request"),
        (MessageType::ExecStarted, "core.exec.started"),
        (MessageType::ExecStdin, "core.exec.stdin"),
        (MessageType::ExecStdout, "core.exec.stdout"),
        (MessageType::ExecStderr, "core.exec.stderr"),
        (MessageType::ExecExited, "core.exec.exited"),
        (MessageType::ExecResize, "core.exec.resize"),
        (MessageType::ExecSignal, "core.exec.signal"),
        (MessageType::FsRequest, "core.fs.request"),
        (MessageType::FsResponse, "core.fs.response"),
        (MessageType::FsData, "core.fs.data"),
    ];

    #[test]
    fn test_message_type_roundtrip() {
        for (mt, wire) in WIRE_NAMES.iter() {
            assert_eq!(mt.as_str(), *wire);
            let parsed = MessageType::from_wire_str(wire).unwrap();
            assert_eq!(parsed, *mt);
        }
    }

    #[test]
    fn test_message_type_serde_roundtrip() {
        for (mt, _) in WIRE_NAMES.iter() {
            let mut encoded = Vec::new();
            ciborium::into_writer(mt, &mut encoded).unwrap();
            let decoded: MessageType = ciborium::from_reader(encoded.as_slice()).unwrap();
            assert_eq!(&decoded, mt);
        }
    }

    #[test]
    fn test_unknown_message_type() {
        let parsed = MessageType::from_wire_str("core.unknown");
        assert!(parsed.is_none());
    }

    #[test]
    fn test_message_with_payload_roundtrip() {
        use crate::exec::ExecExited;

        let exited = ExecExited { code: 42 };
        let msg = Message::with_payload(MessageType::ExecExited, 7, &exited).unwrap();
        assert_eq!(msg.t, MessageType::ExecExited);
        assert_eq!(msg.id, 7);
        assert_eq!(msg.flags, FLAG_TERMINAL);

        let decoded: ExecExited = msg.payload().unwrap();
        assert_eq!(decoded.code, 42);
    }

    #[test]
    fn test_message_type_flags() {
        let expected = [
            (MessageType::ExecExited, FLAG_TERMINAL),
            (MessageType::FsResponse, FLAG_TERMINAL),
            (MessageType::ExecRequest, FLAG_SESSION_START),
            (MessageType::FsRequest, FLAG_SESSION_START),
            (MessageType::Ready, 0),
            (MessageType::Shutdown, FLAG_SHUTDOWN),
            (MessageType::ExecStarted, 0),
            (MessageType::ExecStdin, 0),
            (MessageType::ExecStdout, 0),
            (MessageType::ExecStderr, 0),
            (MessageType::ExecResize, 0),
            (MessageType::ExecSignal, 0),
            (MessageType::FsData, 0),
        ];
        for (mt, want) in expected.iter() {
            assert_eq!(mt.flags(), *want);
        }
    }

    #[test]
    fn test_message_new_computes_flags() {
        let cases = [
            (MessageType::ExecRequest, FLAG_SESSION_START),
            (MessageType::ExecStdout, 0),
        ];
        for (mt, want) in cases {
            let msg = Message::new(mt, 1, Vec::new());
            assert_eq!(msg.flags, want);
        }
    }
}