use uuid::Uuid;
use crate::error::{PsrpError, Result};
/// Size in bytes of the fixed PSRP message header.
///
/// The header is destination (`u32`), message type (`u32`), `rpid` (16-byte GUID)
/// and `pid` (16-byte GUID); the message body follows immediately after it.
pub const HEADER_LEN: usize = 4 + 4 + 16 + 16;
/// UTF-8 byte-order mark; a single leading occurrence is dropped from decoded bodies.
const BOM: &[u8] = b"\xEF\xBB\xBF";
/// Recipient of a PSRP message, carried as the first little-endian `u32` of the header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum Destination {
    /// Wire value 1.
    Client = 1,
    /// Wire value 2.
    Server = 2,
}

impl Destination {
    /// Converts a raw header value back into a [`Destination`].
    ///
    /// # Errors
    ///
    /// Returns a protocol error when `v` matches neither variant's discriminant.
    fn from_u32(v: u32) -> Result<Self> {
        [Self::Client, Self::Server]
            .iter()
            .copied()
            .find(|candidate| *candidate as u32 == v)
            .ok_or_else(|| PsrpError::protocol(format!("unknown destination {v}")))
    }
}
/// PSRP message type, carried as the second little-endian `u32` of the header.
///
/// Known codes follow the message-type table in MS-PSRP §2.2.1. Any code that
/// [`MessageType::from_u32`] does not recognise is kept verbatim in
/// [`MessageType::Unknown`], so it survives a decode/encode round trip.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    SessionCapability,
    InitRunspacePool,
    PublicKey,
    EncryptedSessionKey,
    PublicKeyRequest,
    SetMaxRunspaces,
    SetMinRunspaces,
    RunspaceAvailability,
    RunspacePoolState,
    CreatePipeline,
    GetAvailableRunspaces,
    UserEvent,
    ApplicationPrivateData,
    GetCommandMetadata,
    RunspacePoolHostCall,
    RunspacePoolHostResponse,
    PipelineInput,
    EndOfPipelineInput,
    PipelineOutput,
    ErrorRecord,
    PipelineState,
    DebugRecord,
    VerboseRecord,
    WarningRecord,
    ProgressRecord,
    InformationRecord,
    PipelineHostCall,
    PipelineHostResponse,
    ConnectRunspacePool,
    RunspacePoolInitData,
    ResetRunspaceState,
    /// Encoded as 0x0002_100E.
    ///
    /// NOTE(review): this code does not appear in the MS-PSRP §2.2.1 table; confirm
    /// the peer understands it before relying on it.
    CloseRunspacePool,
    /// A code this crate does not recognise, preserved exactly as received.
    Unknown(u32),
}

impl MessageType {
    /// Returns the 32-bit wire code for this message type.
    ///
    /// `Unknown(v)` returns `v` unchanged.
    #[must_use]
    pub fn to_u32(self) -> u32 {
        match self {
            // 0x0001_xxxx codes.
            Self::SessionCapability => 0x0001_0002,
            Self::InitRunspacePool => 0x0001_0004,
            Self::PublicKey => 0x0001_0005,
            Self::EncryptedSessionKey => 0x0001_0006,
            Self::PublicKeyRequest => 0x0001_0007,
            // MS-PSRP assigns CONNECT_RUNSPACEPOOL 0x00010008, not a 0x0002_xxxx code;
            // 0x0002100B belongs to RUNSPACEPOOL_INIT_DATA.
            Self::ConnectRunspacePool => 0x0001_0008,
            // 0x0002_xxxx codes.
            Self::SetMaxRunspaces => 0x0002_1002,
            Self::SetMinRunspaces => 0x0002_1003,
            Self::RunspaceAvailability => 0x0002_1004,
            Self::RunspacePoolState => 0x0002_1005,
            Self::CreatePipeline => 0x0002_1006,
            Self::GetAvailableRunspaces => 0x0002_1007,
            Self::UserEvent => 0x0002_1008,
            Self::ApplicationPrivateData => 0x0002_1009,
            Self::GetCommandMetadata => 0x0002_100A,
            Self::RunspacePoolInitData => 0x0002_100B,
            Self::ResetRunspaceState => 0x0002_100C,
            Self::CloseRunspacePool => 0x0002_100E,
            Self::RunspacePoolHostCall => 0x0002_1100,
            Self::RunspacePoolHostResponse => 0x0002_1101,
            // 0x0004_xxxx codes.
            Self::PipelineInput => 0x0004_1002,
            Self::EndOfPipelineInput => 0x0004_1003,
            Self::PipelineOutput => 0x0004_1004,
            Self::ErrorRecord => 0x0004_1005,
            Self::PipelineState => 0x0004_1006,
            Self::DebugRecord => 0x0004_1007,
            Self::VerboseRecord => 0x0004_1008,
            Self::WarningRecord => 0x0004_1009,
            Self::ProgressRecord => 0x0004_1010,
            Self::InformationRecord => 0x0004_1011,
            Self::PipelineHostCall => 0x0004_1100,
            Self::PipelineHostResponse => 0x0004_1101,
            Self::Unknown(v) => v,
        }
    }

    /// Maps a 32-bit wire code onto a [`MessageType`].
    ///
    /// This is the inverse of [`MessageType::to_u32`] for every named variant.
    /// Unrecognised codes become [`MessageType::Unknown`] rather than an error.
    /// Keep the arms in the same order as `to_u32` so the two tables are easy to diff.
    #[must_use]
    pub fn from_u32(v: u32) -> Self {
        match v {
            0x0001_0002 => Self::SessionCapability,
            0x0001_0004 => Self::InitRunspacePool,
            0x0001_0005 => Self::PublicKey,
            0x0001_0006 => Self::EncryptedSessionKey,
            0x0001_0007 => Self::PublicKeyRequest,
            0x0001_0008 => Self::ConnectRunspacePool,
            0x0002_1002 => Self::SetMaxRunspaces,
            0x0002_1003 => Self::SetMinRunspaces,
            0x0002_1004 => Self::RunspaceAvailability,
            0x0002_1005 => Self::RunspacePoolState,
            0x0002_1006 => Self::CreatePipeline,
            0x0002_1007 => Self::GetAvailableRunspaces,
            0x0002_1008 => Self::UserEvent,
            0x0002_1009 => Self::ApplicationPrivateData,
            0x0002_100A => Self::GetCommandMetadata,
            0x0002_100B => Self::RunspacePoolInitData,
            0x0002_100C => Self::ResetRunspaceState,
            0x0002_100E => Self::CloseRunspacePool,
            0x0002_1100 => Self::RunspacePoolHostCall,
            0x0002_1101 => Self::RunspacePoolHostResponse,
            0x0004_1002 => Self::PipelineInput,
            0x0004_1003 => Self::EndOfPipelineInput,
            0x0004_1004 => Self::PipelineOutput,
            0x0004_1005 => Self::ErrorRecord,
            0x0004_1006 => Self::PipelineState,
            0x0004_1007 => Self::DebugRecord,
            0x0004_1008 => Self::VerboseRecord,
            0x0004_1009 => Self::WarningRecord,
            0x0004_1010 => Self::ProgressRecord,
            0x0004_1011 => Self::InformationRecord,
            0x0004_1100 => Self::PipelineHostCall,
            0x0004_1101 => Self::PipelineHostResponse,
            other => Self::Unknown(other),
        }
    }
}
/// One PSRP message: a fixed [`HEADER_LEN`]-byte header followed by a UTF-8 body.
#[derive(Debug, Clone)]
pub struct PsrpMessage {
    /// Header bytes 0..4 (little-endian `u32`).
    pub destination: Destination,
    /// Header bytes 4..8 (little-endian `u32`).
    pub message_type: MessageType,
    /// Header bytes 8..24, stored via `Uuid::to_bytes_le`.
    pub rpid: Uuid,
    /// Header bytes 24..40, stored via `Uuid::to_bytes_le`.
    pub pid: Uuid,
    /// Message body text.
    pub data: String,
}

impl PsrpMessage {
    /// Builds a message whose destination is [`Destination::Server`].
    pub fn to_server(message_type: MessageType, rpid: Uuid, pid: Uuid, data: String) -> Self {
        let destination = Destination::Server;
        Self {
            destination,
            message_type,
            rpid,
            pid,
            data,
        }
    }

    /// Serialises the message into its wire form.
    ///
    /// Two little-endian `u32` words (destination, message type) are followed by the
    /// two GUIDs in `Uuid::to_bytes_le` order and then the raw bytes of `data`.
    /// No BOM is emitted.
    #[must_use]
    pub fn encode(&self) -> Vec<u8> {
        let payload = self.data.as_bytes();
        let mut frame = Vec::with_capacity(HEADER_LEN + payload.len());
        for word in [self.destination as u32, self.message_type.to_u32()] {
            frame.extend_from_slice(&word.to_le_bytes());
        }
        for id in [self.rpid, self.pid] {
            frame.extend_from_slice(&id.to_bytes_le());
        }
        frame.extend_from_slice(payload);
        frame
    }

    /// Parses a message from its wire form; the inverse of [`PsrpMessage::encode`].
    ///
    /// One leading UTF-8 BOM in the body, if present, is discarded.
    ///
    /// # Errors
    ///
    /// Returns a protocol error when `bytes` is shorter than [`HEADER_LEN`], when the
    /// destination word is not a known [`Destination`], or when the body is not
    /// valid UTF-8.
    pub fn decode(bytes: &[u8]) -> Result<Self> {
        if bytes.len() < HEADER_LEN {
            return Err(PsrpError::protocol(format!(
                "message too short: {} bytes",
                bytes.len()
            )));
        }
        let (header, payload) = bytes.split_at(HEADER_LEN);

        // Fixed-width field readers over the already length-checked header.
        let word_at = |offset: usize| -> u32 {
            let mut raw = [0u8; 4];
            raw.copy_from_slice(&header[offset..offset + 4]);
            u32::from_le_bytes(raw)
        };
        let guid_at = |offset: usize| -> Uuid {
            let mut raw = [0u8; 16];
            raw.copy_from_slice(&header[offset..offset + 16]);
            Uuid::from_bytes_le(raw)
        };

        let destination = Destination::from_u32(word_at(0))?;
        let message_type = MessageType::from_u32(word_at(4));
        let rpid = guid_at(8);
        let pid = guid_at(24);

        let text = payload.strip_prefix(BOM).unwrap_or(payload);
        let data = String::from_utf8(text.to_vec())
            .map_err(|e| PsrpError::protocol(format!("invalid UTF-8 in message body: {e}")))?;

        Ok(Self {
            destination,
            message_type,
            rpid,
            pid,
            data,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every named (non-`Unknown`) `MessageType` variant.
    const ALL_KNOWN: [MessageType; 32] = [
        MessageType::SessionCapability,
        MessageType::InitRunspacePool,
        MessageType::PublicKey,
        MessageType::EncryptedSessionKey,
        MessageType::PublicKeyRequest,
        MessageType::SetMaxRunspaces,
        MessageType::SetMinRunspaces,
        MessageType::RunspaceAvailability,
        MessageType::RunspacePoolState,
        MessageType::CreatePipeline,
        MessageType::GetAvailableRunspaces,
        MessageType::UserEvent,
        MessageType::ApplicationPrivateData,
        MessageType::GetCommandMetadata,
        MessageType::ConnectRunspacePool,
        MessageType::RunspacePoolInitData,
        MessageType::ResetRunspaceState,
        MessageType::CloseRunspacePool,
        MessageType::RunspacePoolHostCall,
        MessageType::RunspacePoolHostResponse,
        MessageType::PipelineInput,
        MessageType::EndOfPipelineInput,
        MessageType::PipelineOutput,
        MessageType::ErrorRecord,
        MessageType::PipelineState,
        MessageType::DebugRecord,
        MessageType::VerboseRecord,
        MessageType::WarningRecord,
        MessageType::ProgressRecord,
        MessageType::InformationRecord,
        MessageType::PipelineHostCall,
        MessageType::PipelineHostResponse,
    ];

    fn zero_uuid() -> Uuid {
        Uuid::nil()
    }

    /// Message with nil ids and an empty body, used as a base for hand-built frames.
    fn empty_message(destination: Destination, message_type: MessageType) -> PsrpMessage {
        PsrpMessage {
            destination,
            message_type,
            rpid: zero_uuid(),
            pid: zero_uuid(),
            data: String::new(),
        }
    }

    #[test]
    fn roundtrip_server_message() {
        let rpid = Uuid::parse_str("11112222-3333-4444-5555-666677778888").unwrap();
        let pid = Uuid::parse_str("aaaabbbb-cccc-dddd-eeee-ffff00001111").unwrap();
        let msg =
            PsrpMessage::to_server(MessageType::SessionCapability, rpid, pid, "<Obj/>".into());
        let bytes = msg.encode();
        assert_eq!(bytes.len(), HEADER_LEN + 6);
        let decoded = PsrpMessage::decode(&bytes).unwrap();
        assert_eq!(decoded.destination, Destination::Server);
        assert_eq!(decoded.message_type, MessageType::SessionCapability);
        assert_eq!(decoded.rpid, rpid);
        assert_eq!(decoded.pid, pid);
        assert_eq!(decoded.data, "<Obj/>");
    }

    #[test]
    fn encode_header_layout() {
        // Distinct bytes so any mistake in the little-endian GUID field swaps is visible.
        let rpid = Uuid::parse_str("00112233-4455-6677-8899-aabbccddeeff").unwrap();
        let msg =
            PsrpMessage::to_server(MessageType::SessionCapability, rpid, zero_uuid(), "x".into());
        let bytes = msg.encode();
        assert_eq!(&bytes[0..4], &[2u8, 0, 0, 0]);
        assert_eq!(&bytes[4..8], &[0x02u8, 0x00, 0x01, 0x00]);
        assert_eq!(
            &bytes[8..24],
            &[
                0x33u8, 0x22, 0x11, 0x00, 0x55, 0x44, 0x77, 0x66, 0x88, 0x99, 0xAA, 0xBB, 0xCC,
                0xDD, 0xEE, 0xFF,
            ]
        );
        assert_eq!(&bytes[24..40], &[0u8; 16]);
        assert_eq!(&bytes[40..], b"x");
    }

    #[test]
    fn decode_accepts_utf8_bom() {
        let mut bytes = empty_message(Destination::Client, MessageType::PipelineOutput).encode();
        bytes.extend_from_slice(b"<X/>");
        bytes.splice(HEADER_LEN..HEADER_LEN, BOM.iter().copied());
        let decoded = PsrpMessage::decode(&bytes).unwrap();
        assert_eq!(decoded.data, "<X/>");
    }

    #[test]
    fn decode_bom_only_body_is_empty() {
        let mut bytes = empty_message(Destination::Client, MessageType::PipelineOutput).encode();
        bytes.extend_from_slice(BOM);
        assert_eq!(PsrpMessage::decode(&bytes).unwrap().data, "");
    }

    #[test]
    fn decode_strips_only_one_bom() {
        let mut bytes = empty_message(Destination::Client, MessageType::PipelineOutput).encode();
        bytes.extend_from_slice(BOM);
        bytes.extend_from_slice(BOM);
        bytes.extend_from_slice(b"x");
        assert_eq!(PsrpMessage::decode(&bytes).unwrap().data, "\u{FEFF}x");
    }

    #[test]
    fn decode_short_header_errors() {
        let err = PsrpMessage::decode(&[0u8; 10]).unwrap_err();
        assert!(matches!(err, PsrpError::Protocol(_)));
    }

    #[test]
    fn decode_exactly_header_len_ok() {
        let mut bytes = vec![0u8; HEADER_LEN];
        bytes[0] = 1; // Destination::Client
        let msg = PsrpMessage::decode(&bytes).unwrap();
        assert_eq!(msg.destination, Destination::Client);
        assert_eq!(msg.message_type, MessageType::Unknown(0));
        assert!(msg.rpid.is_nil());
        assert!(msg.pid.is_nil());
        assert_eq!(msg.data, "");
    }

    #[test]
    fn decode_bad_destination_errors() {
        let mut bytes = vec![0u8; HEADER_LEN];
        bytes[0] = 9;
        let err = PsrpMessage::decode(&bytes).unwrap_err();
        assert!(matches!(err, PsrpError::Protocol(_)));
    }

    #[test]
    fn decode_bad_utf8_errors() {
        let mut bytes = empty_message(Destination::Server, MessageType::SessionCapability).encode();
        bytes.extend_from_slice(&[0xFF, 0xFE, 0xFD]);
        let err = PsrpMessage::decode(&bytes).unwrap_err();
        assert!(matches!(err, PsrpError::Protocol(_)));
    }

    #[test]
    fn message_type_roundtrip_all_known_variants() {
        for mt in ALL_KNOWN {
            assert_eq!(MessageType::from_u32(mt.to_u32()), mt, "{mt:?}");
        }
    }

    #[test]
    fn known_message_type_codes_are_unique() {
        use std::collections::HashSet;
        let codes: HashSet<u32> = ALL_KNOWN.iter().map(|mt| mt.to_u32()).collect();
        assert_eq!(codes.len(), ALL_KNOWN.len());
    }

    #[test]
    fn unknown_message_type_preserved() {
        let mt = MessageType::from_u32(0xDEAD_BEEF);
        assert_eq!(mt, MessageType::Unknown(0xDEAD_BEEF));
        assert_eq!(mt.to_u32(), 0xDEAD_BEEF);
    }
}