use std::io::Cursor;
use rmpv::{Integer, Value, decode, encode};
use rpc_runtime_core::{
CapabilityFlags, Envelope, Goodbye, Hello, HelloAck, InstanceId, MessageKind, MethodId,
Notification, NotificationId, Options, Request, RequestId, ResponseError, ResponseOk, Role,
ServiceGuid,
};
use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
use thiserror::Error;
/// Default upper bound, in bytes, on a single frame accepted by [`decode_envelope`] (16 MiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 << 20;

/// Size limits applied while decoding incoming frames.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CodecLimits {
    /// Largest frame, in bytes, that [`decode_envelope`] will attempt to parse.
    pub max_message_size: usize,
}

impl Default for CodecLimits {
    /// Uses [`DEFAULT_MAX_MESSAGE_SIZE`] as the frame size limit.
    fn default() -> Self {
        CodecLimits {
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        }
    }
}
#[derive(Debug, Error)]
#[error("{error}")]
pub struct CodecError {
pub error: RuntimeError,
}
impl CodecError {
fn protocol(code: RuntimeErrorCode, message: impl Into<String>) -> Self {
Self {
error: RuntimeError::protocol(code, message),
}
}
pub fn into_runtime_error(self) -> RuntimeError {
self.error
}
}
/// Serializes `envelope` into a single top-level MessagePack array.
///
/// # Errors
///
/// Returns a [`CodecError`] with [`RuntimeErrorCode::PayloadEncodeFailed`] if MessagePack
/// serialization fails.
pub fn encode_envelope(envelope: &Envelope) -> Result<Vec<u8>, CodecError> {
    let wire_value = envelope_to_value(envelope);
    let mut buffer = Vec::new();
    let outcome = encode::write_value(&mut buffer, &wire_value);
    match outcome {
        Ok(()) => Ok(buffer),
        Err(err) => Err(CodecError::protocol(
            RuntimeErrorCode::PayloadEncodeFailed,
            format!("failed to encode MessagePack envelope: {err}"),
        )),
    }
}
/// Decodes one MessagePack-encoded envelope from `bytes`.
///
/// The frame must not be larger than `limits.max_message_size`. It must contain exactly one
/// top-level value; any bytes left over after that value are rejected.
///
/// # Errors
///
/// - [`RuntimeErrorCode::InvalidEnvelope`] for an oversized frame, trailing bytes, or an
///   envelope with the wrong shape.
/// - [`RuntimeErrorCode::PayloadDecodeFailed`] when the bytes are not valid MessagePack.
/// - Other codes raised by field validation (for example invalid request or instance ids,
///   unknown or unsupported message kinds).
pub fn decode_envelope(bytes: &[u8], limits: CodecLimits) -> Result<Envelope, CodecError> {
    let frame_len = bytes.len();
    let limit = limits.max_message_size;
    if frame_len > limit {
        return Err(CodecError::protocol(
            RuntimeErrorCode::InvalidEnvelope,
            format!("frame size {frame_len} exceeds limit {limit}"),
        ));
    }

    let mut reader = Cursor::new(bytes);
    let value = match decode::read_value(&mut reader) {
        Ok(value) => value,
        Err(err) => {
            return Err(CodecError::protocol(
                RuntimeErrorCode::PayloadDecodeFailed,
                format!("failed to decode MessagePack envelope: {err}"),
            ));
        }
    };

    // Exactly one envelope per frame: leftover bytes are rejected rather than ignored.
    let consumed = reader.position();
    if consumed != frame_len as u64 {
        return Err(CodecError::protocol(
            RuntimeErrorCode::InvalidEnvelope,
            "trailing bytes after top-level envelope",
        ));
    }
    value_to_envelope(value)
}
/// Encodes a service GUID as MessagePack `bin` containing the UUID's raw bytes.
pub fn encode_service_guid(guid: ServiceGuid) -> Value {
    let raw: Vec<u8> = guid.get().as_bytes().to_vec();
    Value::Binary(raw)
}
/// Decodes a service GUID from MessagePack `bin` data of exactly 16 bytes.
///
/// # Errors
///
/// Returns [`RuntimeErrorCode::InvalidEnvelope`] when `value` is not binary data, when the
/// binary length is not 16, or when the bytes are rejected by `Uuid::from_slice`.
pub fn decode_service_guid(value: &Value) -> Result<ServiceGuid, CodecError> {
    let raw = match value {
        Value::Binary(raw) => raw,
        _ => {
            return Err(CodecError::protocol(
                RuntimeErrorCode::InvalidEnvelope,
                "service GUID must be MessagePack bin(16)",
            ));
        }
    };
    if raw.len() != 16 {
        return Err(CodecError::protocol(
            RuntimeErrorCode::InvalidEnvelope,
            format!("service GUID must be exactly 16 bytes, got {}", raw.len()),
        ));
    }
    match uuid::Uuid::from_slice(raw) {
        Ok(parsed) => Ok(ServiceGuid::new(parsed)),
        Err(err) => Err(CodecError::protocol(
            RuntimeErrorCode::InvalidEnvelope,
            format!("invalid service GUID bytes: {err}"),
        )),
    }
}
fn envelope_to_value(envelope: &Envelope) -> Value {
match envelope {
Envelope::Hello(message) => Value::Array(vec![
u8_value(MessageKind::Hello.as_u8()),
u64_value(message.protocol_version as u64),
u8_value(message.role.as_u8()),
u64_value(message.capability_bits.bits()),
u64_value(message.max_message_size),
options_to_value(&message.options),
]),
Envelope::HelloAck(message) => Value::Array(vec![
u8_value(MessageKind::HelloAck.as_u8()),
u64_value(message.protocol_version as u64),
u64_value(message.accepted_capability_bits.bits()),
u64_value(message.max_message_size),
options_to_value(&message.options),
]),
Envelope::Request(message) => Value::Array(vec![
u8_value(MessageKind::Request.as_u8()),
u64_value(message.request_id.get()),
u64_value(message.instance_id.get()),
u64_value(message.method_id.get() as u64),
message.payload.clone(),
]),
Envelope::ResponseOk(message) => Value::Array(vec![
u8_value(MessageKind::ResponseOk.as_u8()),
u64_value(message.request_id.get()),
message.payload.clone(),
]),
Envelope::ResponseError(message) => Value::Array(vec![
u8_value(MessageKind::ResponseError.as_u8()),
u64_value(message.request_id.get()),
i64_value(message.error_code as i64),
u8_value(message.error_kind),
string_option_to_value(message.error_message.as_deref()),
message.error_details.clone(),
]),
Envelope::Notification(message) => Value::Array(vec![
u8_value(MessageKind::Notification.as_u8()),
u64_value(message.instance_id.map_or(0, InstanceId::get)),
u64_value(message.notification_id.get() as u64),
message.payload.clone(),
]),
Envelope::Goodbye(message) => Value::Array(vec![
u8_value(MessageKind::Goodbye.as_u8()),
u64_value(message.reason_code as u64),
string_option_to_value(message.message.as_deref()),
]),
}
}
fn value_to_envelope(value: Value) -> Result<Envelope, CodecError> {
let fields = match value {
Value::Array(fields) => fields,
_ => {
return Err(CodecError::protocol(
RuntimeErrorCode::InvalidEnvelope,
"top-level envelope must be a MessagePack array",
));
}
};
let kind = required_u8(
fields.first(),
"message_kind",
RuntimeErrorCode::UnknownMessageKind,
)?;
match kind {
1 => decode_hello(fields),
2 => decode_hello_ack(fields),
3 => decode_request(fields),
4 => decode_response_ok(fields),
5 => decode_response_error(fields),
6 => decode_notification(fields),
7 => Err(CodecError::protocol(
RuntimeErrorCode::RequestCancelUnsupported,
"CANCEL is reserved and unsupported in v1",
)),
8 => decode_goodbye(fields),
other => Err(CodecError::protocol(
RuntimeErrorCode::UnknownMessageKind,
format!("unknown message kind `{other}`"),
)),
}
}
/// Decodes a HELLO envelope.
///
/// Wire shape: `[kind, protocol_version, role, capability_bits, max_message_size, options]`.
fn decode_hello(fields: Vec<Value>) -> Result<Envelope, CodecError> {
    exact_len(&fields, 6)?;
    let protocol_version = required_u32(fields.get(1), "protocol_version")?;
    let role_tag = required_u8(fields.get(2), "role", RuntimeErrorCode::InvalidEnvelope)?;
    let peer_role = role(role_tag)?;
    let capability_raw = required_u64(
        fields.get(3),
        "capability_bits",
        RuntimeErrorCode::InvalidEnvelope,
    )?;
    let max_message_size = required_u64(
        fields.get(4),
        "max_message_size",
        RuntimeErrorCode::InvalidEnvelope,
    )?;
    let options = required_options(fields.get(5), "options")?;
    Ok(Envelope::Hello(Hello {
        protocol_version,
        role: peer_role,
        // `from_bits_retain` keeps capability bits this build does not know about.
        capability_bits: CapabilityFlags::from_bits_retain(capability_raw),
        max_message_size,
        options,
    }))
}
/// Decodes a HELLO_ACK envelope.
///
/// Wire shape: `[kind, protocol_version, accepted_capability_bits, max_message_size, options]`.
fn decode_hello_ack(fields: Vec<Value>) -> Result<Envelope, CodecError> {
    exact_len(&fields, 5)?;
    let protocol_version = required_u32(fields.get(1), "protocol_version")?;
    let accepted_raw = required_u64(
        fields.get(2),
        "accepted_capability_bits",
        RuntimeErrorCode::InvalidEnvelope,
    )?;
    let max_message_size = required_u64(
        fields.get(3),
        "max_message_size",
        RuntimeErrorCode::InvalidEnvelope,
    )?;
    let options = required_options(fields.get(4), "options")?;
    Ok(Envelope::HelloAck(HelloAck {
        protocol_version,
        // Unknown capability bits are retained rather than rejected.
        accepted_capability_bits: CapabilityFlags::from_bits_retain(accepted_raw),
        max_message_size,
        options,
    }))
}
/// Decodes a REQUEST envelope.
///
/// Wire shape: `[kind, request_id, instance_id, method_id, payload]`.
///
/// The instance id is validated before the request id. If both are invalid, the error code
/// is [`RuntimeErrorCode::InvalidInstanceId`].
fn decode_request(mut fields: Vec<Value>) -> Result<Envelope, CodecError> {
    exact_len(&fields, 5)?;
    let instance_id = required_instance_id(fields.get(2))?;
    let request_id = RequestId::new(required_u64(
        fields.get(1),
        "request_id",
        RuntimeErrorCode::InvalidRequestId,
    )?);
    let method_id = MethodId::new(required_u32(fields.get(3), "method_id")?);
    // The payload is the last element, and `exact_len` guarantees index 4 exists. Move it out
    // of the owned vector instead of deep-cloning a possibly large value tree.
    let payload = fields.swap_remove(4);
    Ok(Envelope::Request(Request {
        request_id,
        instance_id,
        method_id,
        payload,
    }))
}
/// Decodes a RESPONSE_OK envelope.
///
/// Wire shape: `[kind, request_id, payload]`.
fn decode_response_ok(mut fields: Vec<Value>) -> Result<Envelope, CodecError> {
    exact_len(&fields, 3)?;
    let request_id = RequestId::new(required_u64(
        fields.get(1),
        "request_id",
        RuntimeErrorCode::InvalidRequestId,
    )?);
    // The payload is the last element, and `exact_len` guarantees index 2 exists. Move it
    // instead of deep-cloning it.
    let payload = fields.swap_remove(2);
    Ok(Envelope::ResponseOk(ResponseOk {
        request_id,
        payload,
    }))
}
/// Decodes a RESPONSE_ERROR envelope.
///
/// Wire shape: `[kind, request_id, error_code, error_kind, error_message, error_details]`.
/// `error_message` may be `nil`; `error_details` is passed through as an arbitrary value.
fn decode_response_error(mut fields: Vec<Value>) -> Result<Envelope, CodecError> {
    exact_len(&fields, 6)?;
    let request_id = RequestId::new(required_u64(
        fields.get(1),
        "request_id",
        RuntimeErrorCode::InvalidRequestId,
    )?);
    let error_code = required_i32(fields.get(2), "error_code")?;
    let error_kind = required_u8(
        fields.get(3),
        "error_kind",
        RuntimeErrorCode::InvalidEnvelope,
    )?;
    let error_message = optional_string(fields.get(4), "error_message")?;
    // `error_details` is the last element, and `exact_len` guarantees index 5 exists. Move it
    // instead of deep-cloning it.
    let error_details = fields.swap_remove(5);
    Ok(Envelope::ResponseError(ResponseError {
        request_id,
        error_code,
        error_kind,
        error_message,
        error_details,
    }))
}
/// Decodes a NOTIFICATION envelope.
///
/// Wire shape: `[kind, instance_id, notification_id, payload]`.
///
/// An `instance_id` of `0` on the wire decodes to `instance_id: None`, meaning the
/// notification is not bound to an instance. This mirrors the encoder, which writes `None`
/// as `0`.
fn decode_notification(mut fields: Vec<Value>) -> Result<Envelope, CodecError> {
    exact_len(&fields, 4)?;
    let raw_instance_id = required_u64(
        fields.get(1),
        "instance_id",
        RuntimeErrorCode::InvalidInstanceId,
    )?;
    let instance_id = if raw_instance_id == 0 {
        None
    } else {
        Some(InstanceId::new(raw_instance_id).ok_or_else(|| {
            CodecError::protocol(
                RuntimeErrorCode::InvalidInstanceId,
                "instance_id must be non-zero",
            )
        })?)
    };
    let notification_id = NotificationId::new(required_u32(fields.get(2), "notification_id")?);
    // The payload is the last element, and `exact_len` guarantees index 3 exists. Move it
    // instead of deep-cloning it.
    let payload = fields.swap_remove(3);
    Ok(Envelope::Notification(Notification {
        instance_id,
        notification_id,
        payload,
    }))
}
/// Decodes a GOODBYE envelope.
///
/// Wire shape: `[kind, reason_code, message]`, where `message` may be `nil`.
fn decode_goodbye(fields: Vec<Value>) -> Result<Envelope, CodecError> {
    exact_len(&fields, 3)?;
    let reason_code = required_u32(fields.get(1), "reason_code")?;
    let message = optional_string(fields.get(2), "message")?;
    Ok(Envelope::Goodbye(Goodbye {
        reason_code,
        message,
    }))
}
/// Checks that an envelope array has exactly `expected` elements, counting the kind tag.
fn exact_len(fields: &[Value], expected: usize) -> Result<(), CodecError> {
    let actual = fields.len();
    if actual != expected {
        return Err(CodecError::protocol(
            RuntimeErrorCode::InvalidEnvelope,
            format!("invalid envelope field count: expected {expected}, got {actual}"),
        ));
    }
    Ok(())
}
/// Maps a wire role tag to a [`Role`]: 1 = client, 2 = server, 3 = peer.
fn role(value: u8) -> Result<Role, CodecError> {
    let decoded = match value {
        1 => Role::Client,
        2 => Role::Server,
        3 => Role::Peer,
        other => {
            return Err(CodecError::protocol(
                RuntimeErrorCode::InvalidEnvelope,
                format!("invalid role `{other}`"),
            ));
        }
    };
    Ok(decoded)
}
/// Reads a request's instance id and rejects values that `InstanceId::new` refuses (zero).
///
/// Both a missing or non-integer value and a rejected value report
/// [`RuntimeErrorCode::InvalidInstanceId`].
fn required_instance_id(value: Option<&Value>) -> Result<InstanceId, CodecError> {
    let raw = required_u64(value, "instance_id", RuntimeErrorCode::InvalidInstanceId)?;
    match InstanceId::new(raw) {
        Some(id) => Ok(id),
        None => Err(CodecError::protocol(
            RuntimeErrorCode::InvalidInstanceId,
            "request instance_id must be non-zero",
        )),
    }
}
/// Reads `value` as a `u64`.
///
/// A missing field, or one that is not a non-negative integer, is reported with the
/// caller-supplied `code`.
fn required_u64(
    value: Option<&Value>,
    field: &str,
    code: RuntimeErrorCode,
) -> Result<u64, CodecError> {
    match value.and_then(Value::as_u64) {
        Some(number) => Ok(number),
        None => Err(CodecError::protocol(
            code,
            format!("field `{field}` must be an unsigned integer"),
        )),
    }
}
/// Reads `value` as a `u32`.
///
/// Missing, non-integer, and out-of-range values are all reported as
/// [`RuntimeErrorCode::InvalidEnvelope`].
fn required_u32(value: Option<&Value>, field: &str) -> Result<u32, CodecError> {
    let wide = required_u64(value, field, RuntimeErrorCode::InvalidEnvelope)?;
    match u32::try_from(wide) {
        Ok(narrow) => Ok(narrow),
        Err(_) => Err(CodecError::protocol(
            RuntimeErrorCode::InvalidEnvelope,
            format!("field `{field}` exceeds u32 range"),
        )),
    }
}
/// Reads `value` as a `u8`.
///
/// Both type errors and out-of-range values are reported with the caller-supplied `code`.
fn required_u8(
    value: Option<&Value>,
    field: &str,
    code: RuntimeErrorCode,
) -> Result<u8, CodecError> {
    let wide = required_u64(value, field, code)?;
    match u8::try_from(wide) {
        Ok(narrow) => Ok(narrow),
        Err(_) => Err(CodecError::protocol(
            code,
            format!("field `{field}` exceeds u8 range"),
        )),
    }
}
/// Reads `value` as an `i32`.
///
/// # Errors
///
/// Returns [`RuntimeErrorCode::InvalidEnvelope`] in two cases:
/// - "must be a signed integer" when the field is missing or is not a MessagePack integer;
/// - "exceeds i32 range" when it is an integer outside `i32`. This includes positive values
///   above `i64::MAX`, which were previously misreported as "must be a signed integer".
fn required_i32(value: Option<&Value>, field: &str) -> Result<i32, CodecError> {
    let out_of_range = || {
        CodecError::protocol(
            RuntimeErrorCode::InvalidEnvelope,
            format!("field `{field}` exceeds i32 range"),
        )
    };
    match value {
        // Any integer is the right *type*. Only its magnitude can still be wrong.
        Some(Value::Integer(integer)) => integer
            .as_i64()
            .and_then(|wide| i32::try_from(wide).ok())
            .ok_or_else(out_of_range),
        _ => Err(CodecError::protocol(
            RuntimeErrorCode::InvalidEnvelope,
            format!("field `{field}` must be a signed integer"),
        )),
    }
}
/// Reads an optional text field encoded as MessagePack `str` or `nil`.
///
/// `nil` yields `None`. A `str` must hold valid UTF-8. Anything else, including a missing
/// field, is [`RuntimeErrorCode::InvalidEnvelope`].
fn optional_string(value: Option<&Value>, field: &str) -> Result<Option<String>, CodecError> {
    let text = match value {
        Some(Value::Nil) => return Ok(None),
        Some(Value::String(text)) => text,
        _ => {
            return Err(CodecError::protocol(
                RuntimeErrorCode::InvalidEnvelope,
                format!("field `{field}` must be string or nil"),
            ));
        }
    };
    match text.as_str() {
        Some(valid) => Ok(Some(valid.to_owned())),
        None => Err(CodecError::protocol(
            RuntimeErrorCode::InvalidEnvelope,
            format!("field `{field}` must contain valid UTF-8"),
        )),
    }
}
/// Reads an options field.
///
/// The field is either `nil` (no options) or a map from string keys to values accepted by
/// `is_option_scalar`. Entries are returned in wire order; the first invalid key or value
/// aborts decoding.
fn required_options(value: Option<&Value>, field: &str) -> Result<Options, CodecError> {
    let entries = match value {
        Some(Value::Nil) => return Ok(Vec::new()),
        Some(Value::Map(entries)) => entries,
        _ => {
            return Err(CodecError::protocol(
                RuntimeErrorCode::InvalidEnvelope,
                format!("field `{field}` must be map or nil"),
            ));
        }
    };

    let mut options: Options = Vec::with_capacity(entries.len());
    for (key, entry) in entries {
        let name = match key.as_str() {
            Some(name) => name,
            None => {
                return Err(CodecError::protocol(
                    RuntimeErrorCode::InvalidEnvelope,
                    format!("field `{field}` option key must be string"),
                ));
            }
        };
        if !is_option_scalar(entry) {
            return Err(CodecError::protocol(
                RuntimeErrorCode::InvalidEnvelope,
                format!("field `{field}` option value must be scalar"),
            ));
        }
        options.push((name.to_string(), entry.clone()));
    }
    Ok(options)
}
/// Returns `true` for the value kinds allowed as option values.
///
/// Allowed kinds are nil, boolean, integer, string, and binary. Every other kind is
/// rejected, for example floats, arrays, maps, and extension values.
// NOTE(review): floats are scalars but are excluded here — confirm this matches the spec.
fn is_option_scalar(value: &Value) -> bool {
    matches!(
        value,
        Value::Nil
            | Value::Boolean(_)
            | Value::Integer(_)
            | Value::String(_)
            | Value::Binary(_)
    )
}
/// Encodes options as a MessagePack map keyed by option name, or as `nil` when empty.
///
/// Entries are written in the order they are stored.
fn options_to_value(options: &Options) -> Value {
    if options.is_empty() {
        Value::Nil
    } else {
        let mut entries = Vec::with_capacity(options.len());
        for (name, setting) in options {
            entries.push((Value::from(name.as_str()), setting.clone()));
        }
        Value::Map(entries)
    }
}
/// Encodes an optional string as MessagePack `str`, or `nil` when absent.
fn string_option_to_value(value: Option<&str>) -> Value {
    match value {
        Some(text) => Value::from(text),
        None => Value::Nil,
    }
}
/// Encodes a `u8` as a MessagePack unsigned integer (widened losslessly to `u64`).
fn u8_value(value: u8) -> Value {
    u64_value(u64::from(value))
}
fn u64_value(value: u64) -> Value {
Value::Integer(Integer::from(value))
}
fn i64_value(value: i64) -> Value {
Value::Integer(Integer::from(value))
}
#[cfg(test)]
mod tests {
    use super::*;
    use rpc_runtime_core::{CapabilityFlags, RUNTIME_PROTOCOL_VERSION};
    use uuid::Uuid;

    /// Encodes `envelope`, decodes it with default limits, and asserts it comes back unchanged.
    fn roundtrip(envelope: Envelope) {
        let bytes = encode_envelope(&envelope).expect("encode envelope");
        let decoded = decode_envelope(&bytes, CodecLimits::default()).expect("decode envelope");
        assert_eq!(decoded, envelope);
    }

    #[test]
    fn hello_roundtrips() {
        roundtrip(Envelope::Hello(Hello {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            role: Role::Client,
            capability_bits: CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
                | CapabilityFlags::GOODBYE,
            max_message_size: 4096,
            options: vec![("implementation".to_string(), Value::from("test"))],
        }));
    }

    #[test]
    fn hello_ack_roundtrips() {
        roundtrip(Envelope::HelloAck(HelloAck {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            accepted_capability_bits: CapabilityFlags::SERVICE_ACTIVATION,
            max_message_size: 8192,
            options: Vec::new(),
        }));
    }

    #[test]
    fn request_roundtrips() {
        roundtrip(Envelope::Request(Request {
            request_id: RequestId::new(10),
            instance_id: InstanceId::new(22).expect("non-zero instance id"),
            method_id: MethodId::new(3),
            payload: Value::Array(vec![Value::from("card")]),
        }));
    }

    #[test]
    fn request_with_nested_payload_roundtrips() {
        roundtrip(Envelope::Request(Request {
            request_id: RequestId::new(11),
            instance_id: InstanceId::new(5).expect("non-zero instance id"),
            method_id: MethodId::new(8),
            payload: Value::Map(vec![
                (Value::from("blob"), Value::Binary(vec![1, 2, 3])),
                (
                    Value::from("items"),
                    Value::Array(vec![Value::from(1), Value::from("two")]),
                ),
            ]),
        }));
    }

    #[test]
    fn response_ok_roundtrips() {
        roundtrip(Envelope::ResponseOk(ResponseOk {
            request_id: RequestId::new(10),
            payload: Value::Nil,
        }));
    }

    #[test]
    fn response_error_roundtrips() {
        roundtrip(Envelope::ResponseError(ResponseError {
            request_id: RequestId::new(10),
            error_code: 1007,
            error_kind: 3,
            error_message: Some("method missing".to_string()),
            error_details: Value::Nil,
        }));
    }

    #[test]
    fn response_error_without_message_roundtrips() {
        roundtrip(Envelope::ResponseError(ResponseError {
            request_id: RequestId::new(12),
            error_code: 1,
            error_kind: 0,
            error_message: None,
            error_details: Value::Map(vec![(Value::from("hint"), Value::from(5))]),
        }));
    }

    #[test]
    fn notification_roundtrips_with_global_instance() {
        roundtrip(Envelope::Notification(Notification {
            instance_id: None,
            notification_id: NotificationId::new(4),
            payload: Value::from(true),
        }));
    }

    #[test]
    fn notification_roundtrips_with_instance() {
        roundtrip(Envelope::Notification(Notification {
            instance_id: Some(InstanceId::new(9).expect("non-zero instance id")),
            notification_id: NotificationId::new(5),
            payload: Value::Array(vec![Value::from(1), Value::from("two")]),
        }));
    }

    #[test]
    fn goodbye_roundtrips() {
        roundtrip(Envelope::Goodbye(Goodbye {
            reason_code: 1,
            message: Some("shutdown".to_string()),
        }));
    }

    #[test]
    fn goodbye_without_message_roundtrips() {
        roundtrip(Envelope::Goodbye(Goodbye {
            reason_code: 0,
            message: None,
        }));
    }

    #[test]
    fn spec_request_shape_decodes() {
        let value = Value::Array(vec![
            Value::from(3),
            Value::from(42),
            Value::from(7),
            Value::from(2),
            Value::Nil,
        ]);
        let mut bytes = Vec::new();
        encode::write_value(&mut bytes, &value).expect("encode raw shape");
        let decoded = decode_envelope(&bytes, CodecLimits::default()).expect("decode request");
        assert_eq!(
            decoded,
            Envelope::Request(Request {
                request_id: RequestId::new(42),
                instance_id: InstanceId::new(7).expect("non-zero instance id"),
                method_id: MethodId::new(2),
                payload: Value::Nil,
            })
        );
    }

    #[test]
    fn non_array_fails() {
        let bytes = encode_raw(&Value::from("not-an-array"));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn empty_input_fails_to_decode() {
        let err = decode_envelope(&[], CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::PayloadDecodeFailed);
    }

    #[test]
    fn trailing_bytes_fail() {
        let mut bytes = encode_envelope(&Envelope::ResponseOk(ResponseOk {
            request_id: RequestId::new(1),
            payload: Value::Nil,
        }))
        .expect("encode");
        // Append a MessagePack `nil` after the complete envelope.
        bytes.push(0xc0);
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn unknown_kind_fails() {
        let bytes = encode_raw(&Value::Array(vec![Value::from(99)]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::UnknownMessageKind);
    }

    #[test]
    fn wrong_field_count_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(4),
            Value::from(1),
            Value::Nil,
            Value::Nil,
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn request_instance_zero_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(3),
            Value::from(1),
            Value::from(0),
            Value::from(2),
            Value::Nil,
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidInstanceId);
    }

    #[test]
    fn scalar_type_mismatch_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(3),
            Value::from("bad-request-id"),
            Value::from(1),
            Value::from(2),
            Value::Nil,
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidRequestId);
    }

    #[test]
    fn invalid_role_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(1),
            Value::from(1),
            Value::from(9),
            Value::from(0),
            Value::from(4096),
            Value::Nil,
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn non_scalar_option_value_fails() {
        let bytes = encode_raw(&Value::Array(vec![
            Value::from(2),
            Value::from(1),
            Value::from(0),
            Value::from(4096),
            Value::Map(vec![(Value::from("nested"), Value::Array(Vec::new()))]),
        ]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn oversized_frame_fails() {
        let bytes = encode_envelope(&Envelope::ResponseOk(ResponseOk {
            request_id: RequestId::new(1),
            payload: Value::Nil,
        }))
        .expect("encode");
        let err = decode_envelope(
            &bytes,
            CodecLimits {
                max_message_size: bytes.len() - 1,
            },
        )
        .expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn frame_exactly_at_limit_is_accepted() {
        let envelope = Envelope::ResponseOk(ResponseOk {
            request_id: RequestId::new(1),
            payload: Value::Nil,
        });
        let bytes = encode_envelope(&envelope).expect("encode");
        let decoded = decode_envelope(
            &bytes,
            CodecLimits {
                max_message_size: bytes.len(),
            },
        )
        .expect("frame at the limit must decode");
        assert_eq!(decoded, envelope);
    }

    #[test]
    fn service_guid_uses_raw_16_byte_binary_encoding() {
        let uuid = Uuid::parse_str("d7d0c9a0-2eb4-4d2d-a3f9-7a4b2875b7e1").expect("uuid");
        let guid = ServiceGuid::new(uuid);
        let encoded = encode_service_guid(guid);
        assert_eq!(encoded, Value::Binary(uuid.as_bytes().to_vec()));
        let decoded = decode_service_guid(&encoded).expect("decode guid");
        assert_eq!(decoded, guid);
    }

    #[test]
    fn service_guid_rejects_wrong_binary_length() {
        let err = decode_service_guid(&Value::Binary(vec![1, 2, 3])).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::InvalidEnvelope);
    }

    #[test]
    fn cancel_fails_as_unsupported() {
        let bytes = encode_raw(&Value::Array(vec![Value::from(7)]));
        let err = decode_envelope(&bytes, CodecLimits::default()).expect_err("must fail");
        assert_eq!(err.error.code, RuntimeErrorCode::RequestCancelUnsupported);
    }

    #[test]
    fn performance_smoke_for_request_codec_path() {
        let envelope = Envelope::Request(Request {
            request_id: RequestId::new(1),
            instance_id: InstanceId::new(1).expect("non-zero instance id"),
            method_id: MethodId::new(1),
            payload: Value::Nil,
        });
        for _ in 0..10_000 {
            let bytes = encode_envelope(&envelope).expect("encode");
            let decoded = decode_envelope(&bytes, CodecLimits::default()).expect("decode");
            assert!(matches!(decoded, Envelope::Request(_)));
        }
    }

    /// Serializes an arbitrary `Value` without going through the envelope encoder.
    fn encode_raw(value: &Value) -> Vec<u8> {
        let mut bytes = Vec::new();
        encode::write_value(&mut bytes, value).expect("encode raw value");
        bytes
    }
}