use crate::envelope::{Envelope, EnvelopeError};
use base64::Engine;
use base64::engine::general_purpose::STANDARD;
/// Serialisation formats a `WireCodec` can be configured with.
///
/// `Json` is the default variant.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum WireFormat {
    /// Envelope serialised as JSON via `serde_json`.
    #[default]
    Json,
    /// Compact binary framing produced by `encode_binary` and parsed by
    /// `decode_binary`.
    ///
    /// NOTE(review): despite the name, the framing implemented in this file
    /// is a simple length-prefixed layout, not a Protocol Buffers message.
    Protobuf,
}
pub struct WireCodec {
format: WireFormat,
}
impl WireCodec {
pub fn new(format: WireFormat) -> Self {
Self { format }
}
pub fn json() -> Self {
Self::new(WireFormat::Json)
}
pub fn protobuf() -> Self {
Self::new(WireFormat::Protobuf)
}
pub fn encode(&self, envelope: &Envelope) -> Result<Vec<u8>, EnvelopeError> {
match self.format {
WireFormat::Json => {
serde_json::to_vec(envelope).map_err(|e| EnvelopeError::Encode(e.to_string()))
}
WireFormat::Protobuf => encode_binary(envelope),
}
}
pub fn decode(&self, data: &[u8]) -> Result<Envelope, EnvelopeError> {
match self.format {
WireFormat::Json => {
serde_json::from_slice(data).map_err(|e| EnvelopeError::Decode(e.to_string()))
}
WireFormat::Protobuf => decode_binary(data),
}
}
}
/// Serialises an envelope into the length-prefixed binary frame.
///
/// Frame layout, in order:
/// 1. a big-endian `u16` holding the byte length of `msg_type`,
/// 2. the `msg_type` bytes,
/// 3. the payload bytes, obtained by base64-decoding the envelope's string
///    payload.
///
/// # Errors
///
/// Returns `EnvelopeError::Encode` when, checked in this order:
/// - `request_id` is set (the frame has no field for it),
/// - the payload is not a JSON string,
/// - the payload string is not valid standard base64,
/// - `msg_type` is longer than `u16::MAX` bytes.
fn encode_binary(envelope: &Envelope) -> Result<Vec<u8>, EnvelopeError> {
    if envelope.request_id.is_some() {
        let reason = "binary envelope does not support request_id";
        return Err(EnvelopeError::Encode(reason.to_string()));
    }

    let encoded = match envelope.payload.as_str() {
        Some(text) => text,
        None => {
            return Err(EnvelopeError::Encode(
                "binary payload must be base64 string".to_string(),
            ));
        }
    };
    let body = match STANDARD.decode(encoded) {
        Ok(bytes) => bytes,
        Err(err) => return Err(EnvelopeError::Encode(err.to_string())),
    };

    let type_bytes = envelope.msg_type.as_bytes();
    let type_len = type_bytes.len();
    if type_len > usize::from(u16::MAX) {
        return Err(EnvelopeError::Encode("message type too long".to_string()));
    }
    // Lossless: the guard above bounds `type_len` by `u16::MAX`.
    let prefix = (type_len as u16).to_be_bytes();

    Ok([&prefix[..], type_bytes, body.as_slice()].concat())
}
/// Parses a length-prefixed binary frame back into an `Envelope`.
///
/// The first two bytes are read as a big-endian `u16` giving the byte length
/// of the message type; that many bytes follow and must be valid UTF-8.
/// Everything after the type is treated as the payload and is returned as a
/// base64-encoded JSON string. `request_id` is always `None`.
///
/// # Errors
///
/// Returns `EnvelopeError::Decode` when the input is shorter than the two-byte
/// prefix, when the declared type length exceeds the remaining input, or when
/// the type bytes are not valid UTF-8.
fn decode_binary(data: &[u8]) -> Result<Envelope, EnvelopeError> {
    if data.len() < 2 {
        let reason = "binary envelope too short";
        return Err(EnvelopeError::Decode(reason.to_string()));
    }
    let (prefix, rest) = data.split_at(2);

    let declared = usize::from(u16::from_be_bytes([prefix[0], prefix[1]]));
    if declared > rest.len() {
        let reason = "binary envelope type length invalid";
        return Err(EnvelopeError::Decode(reason.to_string()));
    }
    let (type_bytes, body) = rest.split_at(declared);

    let msg_type = match std::str::from_utf8(type_bytes) {
        Ok(text) => text.to_owned(),
        Err(err) => return Err(EnvelopeError::Decode(err.to_string())),
    };

    Ok(Envelope {
        msg_type,
        request_id: None,
        payload: serde_json::Value::String(STANDARD.encode(body)),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON encoding must preserve both the message type and the payload.
    #[test]
    fn json_round_trip() {
        let envelope = Envelope::event("chat.say", serde_json::json!({"text": "hi"}));
        let codec = WireCodec::json();
        let data = codec.encode(&envelope).unwrap();
        let decoded = codec.decode(&data).unwrap();
        assert_eq!(decoded.msg_type, "chat.say");
        assert_eq!(decoded.payload, serde_json::json!({"text": "hi"}));
    }

    /// Binary encoding must produce the documented frame layout and decode
    /// back to the same message type and base64 payload.
    #[test]
    fn binary_round_trip() {
        let payload = STANDARD.encode(b"payload");
        let envelope =
            Envelope::event("event.type", serde_json::Value::String(payload.clone()));
        let codec = WireCodec::protobuf();
        let data = codec.encode(&envelope).unwrap();

        // u16 big-endian length of "event.type" (10), the type, then raw payload.
        let mut expected = vec![0u8, 10];
        expected.extend_from_slice(b"event.type");
        expected.extend_from_slice(b"payload");
        assert_eq!(data, expected);

        let decoded = codec.decode(&data).unwrap();
        assert_eq!(decoded.msg_type, "event.type");
        assert_eq!(decoded.payload, serde_json::Value::String(payload));
    }

    /// Frames shorter than the prefix, or shorter than the declared type
    /// length, must be rejected rather than panicking on slicing.
    #[test]
    fn binary_decode_rejects_truncated_input() {
        let codec = WireCodec::protobuf();
        assert!(matches!(codec.decode(&[]), Err(EnvelopeError::Decode(_))));
        assert!(matches!(codec.decode(&[0]), Err(EnvelopeError::Decode(_))));
        assert!(matches!(
            codec.decode(&[0, 5, b'a', b'b']),
            Err(EnvelopeError::Decode(_))
        ));
    }

    /// The message type must be valid UTF-8.
    #[test]
    fn binary_decode_rejects_invalid_utf8_type() {
        let codec = WireCodec::protobuf();
        assert!(matches!(
            codec.decode(&[0, 1, 0xff]),
            Err(EnvelopeError::Decode(_))
        ));
    }

    /// The binary format only accepts payloads that are base64 strings.
    #[test]
    fn binary_encode_rejects_unsupported_payloads() {
        let codec = WireCodec::protobuf();

        let not_a_string = Envelope {
            msg_type: "event.type".to_string(),
            request_id: None,
            payload: serde_json::json!({"text": "hi"}),
        };
        assert!(matches!(
            codec.encode(&not_a_string),
            Err(EnvelopeError::Encode(_))
        ));

        let not_base64 = Envelope {
            msg_type: "event.type".to_string(),
            request_id: None,
            payload: serde_json::Value::String("not base64!".to_string()),
        };
        assert!(matches!(
            codec.encode(&not_base64),
            Err(EnvelopeError::Encode(_))
        ));
    }
}