use std::collections::HashMap;
use crate::error::{KnafehError, RpcStatusCode};
use crate::rpc::message::Metadata;
/// Maximum message size in bytes: 16 MiB (`16 * 1024 * 1024`).
///
/// NOTE(review): none of the encode/decode functions in this module compare against
/// this limit; presumably it is enforced by the transport/framing layer — confirm.
pub const MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;
/// Serializes an RPC request into its wire representation.
///
/// Layout (all length prefixes big-endian):
/// `u16 method_len | method | u32 body_len | body | metadata`, where the metadata
/// section is a `u16` entry count followed by `u16 key_len | key | u16 value_len | value`
/// for each entry.
///
/// # Panics
///
/// Panics if `method` is longer than 65535 bytes, if `body` is longer than `u32::MAX`
/// bytes, or if the metadata exceeds the `u16` limits of its encoding.
pub fn encode_request(method: &str, body: &[u8], metadata: &Metadata) -> Vec<u8> {
    let path = method.as_bytes();
    let path_len = u16::try_from(path.len()).expect("method path exceeds u16 max (65535 bytes)");
    let payload_len = u32::try_from(body.len()).expect("body exceeds u32 max (4GB)");

    // Fixed-width prefixes: 2 (method len) + 4 (body len) + 2 (metadata count) = 8 bytes.
    let mut out = Vec::with_capacity(8 + path.len() + body.len() + metadata_wire_size(metadata));
    for chunk in [&path_len.to_be_bytes()[..], path, &payload_len.to_be_bytes()[..], body] {
        out.extend_from_slice(chunk);
    }
    encode_metadata(&mut out, metadata);
    out
}
/// Parses a request in the layout produced by `encode_request`.
///
/// Returns `(method, body, metadata)`.
///
/// # Errors
///
/// Returns `KnafehError::InvalidMessage` if any field is truncated, if the method or a
/// metadata string is not valid UTF-8, or if bytes remain after the metadata section.
pub fn decode_request(data: &[u8]) -> Result<(String, Vec<u8>, Metadata), KnafehError> {
    let mut cursor = 0usize;

    let method = {
        let len = usize::from(read_u16(data, &mut cursor)?);
        read_str(data, &mut cursor, len)?
    };
    let body = {
        let len = read_u32(data, &mut cursor)? as usize;
        read_bytes(data, &mut cursor, len)?
    };
    let metadata = decode_metadata(data, &mut cursor)?;

    // Every read is bounds-checked, so `cursor <= data.len()` here.
    let leftover = data.len() - cursor;
    if leftover > 0 {
        return Err(KnafehError::InvalidMessage(format!(
            "request has {} trailing bytes",
            leftover
        )));
    }
    Ok((method, body, metadata))
}
/// Serializes an RPC response into its wire representation.
///
/// Layout (all length prefixes big-endian):
/// `u8 status | u16 msg_len | status message | u32 body_len | body | metadata`,
/// with the same metadata section format as requests.
///
/// # Panics
///
/// Panics if `status_message` is longer than 65535 bytes, if `body` is longer than
/// `u32::MAX` bytes, or if the metadata exceeds the `u16` limits of its encoding.
pub fn encode_response(
    status_code: RpcStatusCode,
    status_message: &str,
    body: &[u8],
    metadata: &Metadata,
) -> Vec<u8> {
    let message = status_message.as_bytes();
    let message_len =
        u16::try_from(message.len()).expect("status message exceeds u16 max (65535 bytes)");
    let payload_len = u32::try_from(body.len()).expect("body exceeds u32 max (4GB)");

    // Fixed-width parts: 1 (status) + 2 (message len) + 4 (body len) + 2 (metadata count).
    let fixed = 1 + 2 + 4 + 2;
    let mut out =
        Vec::with_capacity(fixed + message.len() + body.len() + metadata_wire_size(metadata));
    out.push(status_code as u8);
    out.extend_from_slice(&message_len.to_be_bytes());
    out.extend_from_slice(message);
    out.extend_from_slice(&payload_len.to_be_bytes());
    out.extend_from_slice(body);
    encode_metadata(&mut out, metadata);
    out
}
/// Parses a response in the layout produced by `encode_response`.
///
/// Returns `(status_code, status_message, body, metadata)`. The leading status byte is
/// converted with `RpcStatusCode::from_u8`.
///
/// # Errors
///
/// Returns `KnafehError::InvalidMessage` if `data` is empty, if any field is truncated,
/// if a string is not valid UTF-8, or if bytes remain after the metadata section.
pub fn decode_response(
    data: &[u8],
) -> Result<(RpcStatusCode, String, Vec<u8>, Metadata), KnafehError> {
    let status_code = match data.first() {
        Some(&byte) => RpcStatusCode::from_u8(byte),
        None => return Err(KnafehError::InvalidMessage("empty response".into())),
    };
    // The status byte has been consumed.
    let mut cursor = 1usize;

    let status_message = {
        let len = usize::from(read_u16(data, &mut cursor)?);
        read_str(data, &mut cursor, len)?
    };
    let body = {
        let len = read_u32(data, &mut cursor)? as usize;
        read_bytes(data, &mut cursor, len)?
    };
    let metadata = decode_metadata(data, &mut cursor)?;

    // Every read is bounds-checked, so `cursor <= data.len()` here.
    let leftover = data.len() - cursor;
    if leftover > 0 {
        return Err(KnafehError::InvalidMessage(format!(
            "response has {} trailing bytes",
            leftover
        )));
    }
    Ok((status_code, status_message, body, metadata))
}
/// Appends the metadata section to `buf`.
///
/// Writes a big-endian `u16` entry count, then for each entry (in the map's iteration
/// order) `u16 key_len | key bytes | u16 value_len | value bytes`.
///
/// # Panics
///
/// Panics if there are more than 65535 entries, or if any key or value is longer than
/// 65535 bytes.
fn encode_metadata(buf: &mut Vec<u8>, metadata: &Metadata) {
    let entries = u16::try_from(metadata.len()).expect("metadata count exceeds u16 max");
    buf.extend_from_slice(&entries.to_be_bytes());

    metadata.iter().for_each(|(key, value)| {
        // Validate both lengths before writing anything for this entry.
        let key_len = u16::try_from(key.len()).expect("metadata key length exceeds u16 max");
        let value_len =
            u16::try_from(value.len()).expect("metadata value length exceeds u16 max");
        buf.extend_from_slice(&key_len.to_be_bytes());
        buf.extend_from_slice(key.as_bytes());
        buf.extend_from_slice(&value_len.to_be_bytes());
        buf.extend_from_slice(value.as_bytes());
    });
}
/// Decodes a metadata section starting at `*pos`, advancing `*pos` past it.
///
/// Wire layout: a big-endian `u16` entry count, then for each entry
/// `u16 key_len | key | u16 value_len | value` (UTF-8 strings). Duplicate keys are not
/// rejected; a later entry overwrites an earlier one.
///
/// # Errors
///
/// Returns `KnafehError::InvalidMessage` if the section is truncated or a key/value is
/// not valid UTF-8.
fn decode_metadata(data: &[u8], pos: &mut usize) -> Result<Metadata, KnafehError> {
    let count = read_u16(data, pos)? as usize;
    // `count` comes from untrusted input. Each entry needs at least 4 bytes (its two u16
    // length prefixes), so cap the pre-allocation at what the remaining input could
    // hold. Otherwise a few-byte message claiming 65535 entries forces a large up-front
    // allocation. Decoding still fails normally on the missing entries.
    let max_entries = data.len().saturating_sub(*pos) / 4;
    let mut metadata = HashMap::with_capacity(count.min(max_entries));
    for _ in 0..count {
        let klen = read_u16(data, pos)? as usize;
        let key = read_str(data, pos, klen)?;
        let vlen = read_u16(data, pos)? as usize;
        let val = read_str(data, pos, vlen)?;
        metadata.insert(key, val);
    }
    Ok(metadata)
}
/// Returns the number of bytes the metadata entries occupy on the wire.
///
/// This excludes the leading 2-byte entry count; the encoders add it separately.
fn metadata_wire_size(metadata: &Metadata) -> usize {
    let mut total = 0;
    for (key, value) in metadata {
        // Key and value each carry a 2-byte length prefix.
        total += key.len() + value.len() + 4;
    }
    total
}
/// Reads a big-endian `u16` at `*pos` and advances `*pos` by 2.
///
/// # Errors
///
/// Returns `KnafehError::InvalidMessage("truncated u16")` if fewer than 2 bytes remain;
/// `*pos` is left unchanged in that case.
fn read_u16(data: &[u8], pos: &mut usize) -> Result<u16, KnafehError> {
    match data.get(*pos..*pos + 2) {
        Some(&[hi, lo]) => {
            *pos += 2;
            Ok(u16::from_be_bytes([hi, lo]))
        }
        _ => Err(KnafehError::InvalidMessage("truncated u16".into())),
    }
}
/// Reads a big-endian `u32` at `*pos` and advances `*pos` by 4.
///
/// # Errors
///
/// Returns `KnafehError::InvalidMessage("truncated u32")` if fewer than 4 bytes remain;
/// `*pos` is left unchanged in that case.
fn read_u32(data: &[u8], pos: &mut usize) -> Result<u32, KnafehError> {
    match data.get(*pos..*pos + 4) {
        Some(&[b0, b1, b2, b3]) => {
            *pos += 4;
            Ok(u32::from_be_bytes([b0, b1, b2, b3]))
        }
        _ => Err(KnafehError::InvalidMessage("truncated u32".into())),
    }
}
/// Copies `len` bytes starting at `*pos` out of `data` and advances `*pos` past them.
///
/// # Errors
///
/// Returns `KnafehError::InvalidMessage("truncated bytes")` if fewer than `len` bytes
/// remain, including when `*pos + len` would overflow `usize`. `*pos` is left unchanged
/// on error.
fn read_bytes(data: &[u8], pos: &mut usize, len: usize) -> Result<Vec<u8>, KnafehError> {
    // `len` comes straight off the wire (up to u32::MAX), so compute the end offset with
    // `checked_add`. A plain `+` can overflow, for example on 32-bit targets, and then
    // panic instead of reporting a malformed message.
    let slice = pos
        .checked_add(len)
        .and_then(|end| data.get(*pos..end))
        .ok_or_else(|| KnafehError::InvalidMessage("truncated bytes".into()))?;
    *pos += len;
    Ok(slice.to_vec())
}
/// Reads `len` bytes at `*pos` as a UTF-8 `String`, advancing `*pos` past them.
///
/// Note that `*pos` has already been advanced when the UTF-8 check fails.
///
/// # Errors
///
/// Returns `KnafehError::InvalidMessage` if fewer than `len` bytes remain or the bytes
/// are not valid UTF-8.
fn read_str(data: &[u8], pos: &mut usize, len: usize) -> Result<String, KnafehError> {
    let raw = read_bytes(data, pos, len)?;
    match String::from_utf8(raw) {
        Ok(text) => Ok(text),
        Err(err) => Err(KnafehError::InvalidMessage(format!("invalid UTF-8: {err}"))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A request with a body and one metadata entry survives a round trip.
    #[test]
    fn test_request_roundtrip() {
        let mut meta = Metadata::new();
        meta.insert("trace-id".into(), "abc".into());
        let wire = encode_request("echo/echo", b"hello", &meta);
        let (method, body, meta2) = decode_request(&wire).unwrap();
        assert_eq!(method, "echo/echo");
        assert_eq!(body, b"hello");
        assert_eq!(meta2.get("trace-id").unwrap(), "abc");
    }

    /// A response with an empty status message survives a round trip.
    #[test]
    fn test_response_roundtrip() {
        let wire = encode_response(RpcStatusCode::Ok, "", b"world", &Metadata::new());
        let (code, msg, body, _meta) = decode_response(&wire).unwrap();
        assert_eq!(code, RpcStatusCode::Ok);
        assert_eq!(msg, "");
        assert_eq!(body, b"world");
    }

    /// An empty body and empty metadata decode to empty values.
    #[test]
    fn test_empty_request() {
        let wire = encode_request("svc/m", b"", &Metadata::new());
        let (method, body, meta) = decode_request(&wire).unwrap();
        assert_eq!(method, "svc/m");
        assert!(body.is_empty());
        assert!(meta.is_empty());
    }

    /// Status message and multiple metadata entries survive a response round trip.
    #[test]
    fn test_response_roundtrip_with_message_and_metadata() {
        let mut meta = Metadata::new();
        meta.insert("k".into(), "v".into());
        meta.insert("empty".into(), String::new());
        let wire = encode_response(RpcStatusCode::Ok, "all good", b"", &meta);
        let (code, msg, body, meta2) = decode_response(&wire).unwrap();
        assert_eq!(code, RpcStatusCode::Ok);
        assert_eq!(msg, "all good");
        assert!(body.is_empty());
        assert_eq!(meta2, meta);
    }

    /// Every strict prefix of a valid request must be rejected as truncated.
    #[test]
    fn test_truncated_request_rejected() {
        let mut meta = Metadata::new();
        meta.insert("trace-id".into(), "abc".into());
        let wire = encode_request("svc/m", b"hello", &meta);
        for cut in 0..wire.len() {
            assert!(
                decode_request(&wire[..cut]).is_err(),
                "prefix of length {} decoded successfully",
                cut
            );
        }
    }

    /// Extra bytes after the metadata section are rejected for requests and responses.
    #[test]
    fn test_trailing_bytes_rejected() {
        let mut wire = encode_request("svc/m", b"x", &Metadata::new());
        wire.push(0);
        assert!(decode_request(&wire).is_err());

        let mut wire = encode_response(RpcStatusCode::Ok, "", b"x", &Metadata::new());
        wire.push(0);
        assert!(decode_response(&wire).is_err());
    }

    /// A zero-length response has no status byte and must be rejected.
    #[test]
    fn test_empty_response_rejected() {
        assert!(decode_response(&[]).is_err());
    }

    /// A method path that is not valid UTF-8 is rejected.
    #[test]
    fn test_invalid_utf8_method_rejected() {
        // method_len = 1, method = [0xFF], body_len = 0, metadata count = 0
        let wire = [0u8, 1, 0xFF, 0, 0, 0, 0, 0, 0];
        assert!(decode_request(&wire).is_err());
    }

    /// A metadata count far larger than the remaining input is rejected.
    #[test]
    fn test_oversized_metadata_count_rejected() {
        // Method "m", empty body, then a metadata count of 65535 with no entries.
        let wire = [0u8, 1, b'm', 0, 0, 0, 0, 0xFF, 0xFF];
        assert!(decode_request(&wire).is_err());
    }
}