use bytes::{BufMut, Bytes, BytesMut};
use std::io::Read;
use crate::error::{Error, Result};
/// Length of the gRPC message prefix: a one-byte compressed flag followed by a
/// big-endian `u32` payload length.
const HEADER_LEN: usize = std::mem::size_of::<u8>() + std::mem::size_of::<u32>();
/// Message compression in effect for a gRPC stream.
///
/// The encoding only matters for frames whose compressed flag is set; frames
/// sent with flag `0` are passed through unchanged regardless of encoding.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum GrpcEncoding {
    /// No compression; a frame flagged as compressed is a protocol error.
    Identity,
    /// Compressed frames carry gzip data.
    Gzip,
}
/// Incremental decoder that splits a byte stream into gRPC length-prefixed
/// messages.
///
/// Unconsumed input is always `carry` followed by `chunk`, in stream order.
#[derive(Debug)]
pub struct GrpcFramer {
    /// Stream encoding used to decode frames whose compressed flag is `1`.
    encoding: GrpcEncoding,
    /// Owned buffer for bytes that could not be served directly from `chunk`
    /// (e.g. a frame split across pushes); always precedes `chunk`.
    carry: BytesMut,
    /// The most recently pushed chunk, or its unconsumed tail; complete frames
    /// are sliced out of it without copying.
    chunk: Bytes,
}
impl GrpcFramer {
pub fn new(encoding: GrpcEncoding) -> Self {
Self {
encoding,
carry: BytesMut::new(),
chunk: Bytes::new(),
}
}
pub fn encoding(&self) -> GrpcEncoding {
self.encoding
}
pub fn push(&mut self, chunk: Bytes) {
if self.chunk.is_empty() {
self.chunk = chunk;
} else {
self.carry.extend_from_slice(&self.chunk);
self.chunk = Bytes::new();
self.carry.extend_from_slice(&chunk);
}
}
pub fn next_message(&mut self) -> Result<Option<Bytes>> {
if self.carry.is_empty() {
return self.next_from_chunk();
}
self.next_from_carry()
}
fn next_from_chunk(&mut self) -> Result<Option<Bytes>> {
if self.chunk.len() < HEADER_LEN {
if !self.chunk.is_empty() {
self.carry.extend_from_slice(&self.chunk);
self.chunk = Bytes::new();
}
return Ok(None);
}
let flag = self.chunk[0];
let len = u32::from_be_bytes([self.chunk[1], self.chunk[2], self.chunk[3], self.chunk[4]])
as usize;
let total = HEADER_LEN + len;
if self.chunk.len() < total {
self.carry.extend_from_slice(&self.chunk);
self.chunk = Bytes::new();
return Ok(None);
}
let payload = self.chunk.slice(HEADER_LEN..total);
self.chunk = self.chunk.slice(total..);
self.decode_payload(flag, payload).map(Some)
}
fn next_from_carry(&mut self) -> Result<Option<Bytes>> {
if self.carry.len() < HEADER_LEN {
let need = HEADER_LEN - self.carry.len();
self.drain_chunk_into_carry(need);
if self.carry.len() < HEADER_LEN {
return Ok(None);
}
}
let flag = self.carry[0];
let len = u32::from_be_bytes([self.carry[1], self.carry[2], self.carry[3], self.carry[4]])
as usize;
let total = HEADER_LEN + len;
if self.carry.len() < total {
let need = total - self.carry.len();
self.drain_chunk_into_carry(need);
if self.carry.len() < total {
return Ok(None);
}
}
let mut frame = self.carry.split_to(total);
let _header = frame.split_to(HEADER_LEN);
let payload = frame.freeze();
self.decode_payload(flag, payload).map(Some)
}
fn drain_chunk_into_carry(&mut self, need: usize) {
if self.chunk.is_empty() || need == 0 {
return;
}
let take = need.min(self.chunk.len());
let moved = self.chunk.slice(..take);
self.carry.extend_from_slice(&moved);
self.chunk = self.chunk.slice(take..);
}
fn decode_payload(&self, flag: u8, payload: Bytes) -> Result<Bytes> {
match flag {
0 => Ok(payload),
1 => match self.encoding {
GrpcEncoding::Gzip => gunzip(&payload),
GrpcEncoding::Identity => Err(Error::HttpProtocol(
"gRPC message flagged compressed but stream encoding is identity".to_string(),
)),
},
other => Err(Error::HttpProtocol(format!(
"invalid gRPC compression flag: {}",
other
))),
}
}
}
/// Decompresses a gzip-encoded gRPC message payload.
///
/// Any read failure, including corrupt gzip data, is mapped to
/// [`Error::Decompression`].
///
/// NOTE(review): the decompressed size is unbounded, so a small hostile payload
/// can expand into an arbitrarily large allocation; consider capping it at the
/// channel's maximum receive size.
fn gunzip(payload: &[u8]) -> Result<Bytes> {
    let mut inflated = Vec::new();
    flate2::read::GzDecoder::new(payload)
        .read_to_end(&mut inflated)
        .map_err(|e| Error::Decompression(format!("gRPC gzip: {}", e)))?;
    Ok(inflated.into())
}
/// Frames `payload` as a single gRPC length-prefixed message.
///
/// If `compress` is set and `encoding` is [`GrpcEncoding::Gzip`], the payload is
/// gzipped and sent with compressed flag `1`. In every other case, including
/// `compress` on an identity stream, the payload is written verbatim with
/// flag `0`.
///
/// # Errors
///
/// Returns [`Error::HttpProtocol`] if the (possibly compressed) body is longer
/// than `u32::MAX` bytes, and propagates any gzip compression failure.
pub fn encode_message(payload: &[u8], compress: bool, encoding: GrpcEncoding) -> Result<Bytes> {
    use std::borrow::Cow;

    let (flag, body): (u8, Cow<'_, [u8]>) = match (compress, encoding) {
        (true, GrpcEncoding::Gzip) => (1, Cow::Owned(gzip(payload)?)),
        (true, GrpcEncoding::Identity) | (false, _) => (0, Cow::Borrowed(payload)),
    };
    let len = u32::try_from(body.len())
        .map_err(|_| Error::HttpProtocol("gRPC message exceeds u32 length".to_string()))?;

    let mut frame = BytesMut::with_capacity(HEADER_LEN + body.len());
    frame.put_u8(flag);
    frame.put_u32(len);
    frame.put_slice(&body);
    Ok(frame.freeze())
}
fn gzip(payload: &[u8]) -> Result<Vec<u8>> {
use flate2::write::GzEncoder;
use flate2::Compression;
use std::io::Write;
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder
.write_all(payload)
.map_err(|e| Error::Decompression(format!("gRPC gzip encode: {}", e)))?;
encoder
.finish()
.map_err(|e| Error::Decompression(format!("gRPC gzip finish: {}", e)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an uncompressed (flag `0`) gRPC frame around `payload`.
    fn frame_identity(payload: &[u8]) -> Vec<u8> {
        let mut v = Vec::with_capacity(HEADER_LEN + payload.len());
        v.push(0);
        v.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        v.extend_from_slice(payload);
        v
    }

    /// Pulls every message currently decodable from `framer`.
    fn drain(framer: &mut GrpcFramer) -> Vec<Bytes> {
        let mut out = Vec::new();
        while let Some(m) = framer.next_message().expect("decode") {
            out.push(m);
        }
        out
    }

    /// Feeds `wire` to a fresh framer in `step`-byte pieces, draining after
    /// every push, and returns all decoded messages.
    fn decode_in_pieces(encoding: GrpcEncoding, wire: &[u8], step: usize) -> Vec<Bytes> {
        let mut framer = GrpcFramer::new(encoding);
        let mut got = Vec::new();
        for piece in wire.chunks(step) {
            framer.push(Bytes::copy_from_slice(piece));
            got.extend(drain(&mut framer));
        }
        got
    }

    #[test]
    fn multiple_messages_one_chunk() {
        let mut wire = Vec::new();
        wire.extend_from_slice(&frame_identity(b"hello"));
        wire.extend_from_slice(&frame_identity(b"world"));
        wire.extend_from_slice(&frame_identity(b"!"));
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        framer.push(Bytes::from(wire));
        let msgs = drain(&mut framer);
        assert_eq!(msgs.len(), 3);
        assert_eq!(&msgs[0][..], b"hello");
        assert_eq!(&msgs[1][..], b"world");
        assert_eq!(&msgs[2][..], b"!");
    }

    #[test]
    fn message_split_byte_by_byte() {
        let payload = b"the quick brown fox jumps over the lazy dog";
        let wire = frame_identity(payload);
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        let mut got = Vec::new();
        for b in &wire {
            framer.push(Bytes::copy_from_slice(&[*b]));
            while let Some(m) = framer.next_message().expect("decode") {
                got.push(m);
            }
        }
        assert_eq!(got.len(), 1);
        assert_eq!(&got[0][..], &payload[..]);
    }

    #[test]
    fn message_split_odd_sized_slices() {
        let payload = b"abcdefghijklmnopqrstuvwxyz0123456789";
        let mut wire = Vec::new();
        wire.extend_from_slice(&frame_identity(&payload[..10]));
        wire.extend_from_slice(&frame_identity(&payload[10..]));
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        let mut got = Vec::new();
        for window in wire.chunks(7) {
            framer.push(Bytes::copy_from_slice(window));
            while let Some(m) = framer.next_message().expect("decode") {
                got.push(m);
            }
        }
        assert_eq!(got.len(), 2);
        assert_eq!(&got[0][..], &payload[..10]);
        assert_eq!(&got[1][..], &payload[10..]);
    }

    #[test]
    fn zero_length_payloads() {
        let mut wire = Vec::new();
        wire.extend_from_slice(&frame_identity(b""));
        wire.extend_from_slice(&frame_identity(b"x"));
        wire.extend_from_slice(&frame_identity(b""));
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        framer.push(Bytes::from(wire));
        let msgs = drain(&mut framer);
        assert_eq!(msgs.len(), 3);
        assert_eq!(&msgs[0][..], b"");
        assert_eq!(&msgs[1][..], b"x");
        assert_eq!(&msgs[2][..], b"");
    }

    #[test]
    fn partial_prefix_completes_next_chunk() {
        let payload = b"payload-bytes";
        let wire = frame_identity(payload);
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        framer.push(Bytes::copy_from_slice(&wire[..3]));
        assert!(framer.next_message().expect("decode").is_none());
        framer.push(Bytes::copy_from_slice(&wire[3..]));
        let msgs = drain(&mut framer);
        assert_eq!(msgs.len(), 1);
        assert_eq!(&msgs[0][..], &payload[..]);
    }

    #[test]
    fn identity_flag_passthrough_under_gzip_stream() {
        let payload = b"not actually compressed";
        let wire = frame_identity(payload);
        let mut framer = GrpcFramer::new(GrpcEncoding::Gzip);
        framer.push(Bytes::from(wire));
        let msgs = drain(&mut framer);
        assert_eq!(msgs.len(), 1);
        assert_eq!(&msgs[0][..], &payload[..]);
    }

    #[test]
    fn gzip_flag_round_trip() {
        let payload = b"compress me, then decompress me back to exactly this";
        let encoded = encode_message(payload, true, GrpcEncoding::Gzip).expect("encode");
        assert_eq!(encoded[0], 1);
        let mut framer = GrpcFramer::new(GrpcEncoding::Gzip);
        framer.push(encoded);
        let msgs = drain(&mut framer);
        assert_eq!(msgs.len(), 1);
        assert_eq!(&msgs[0][..], &payload[..]);
    }

    #[test]
    fn compressed_flag_under_identity_is_error() {
        let mut wire = Vec::new();
        wire.push(1);
        wire.extend_from_slice(&(3u32).to_be_bytes());
        wire.extend_from_slice(b"abc");
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        framer.push(Bytes::from(wire));
        assert!(framer.next_message().is_err());
    }

    #[test]
    fn bogus_flag_is_error() {
        let mut wire = Vec::new();
        wire.push(7);
        wire.extend_from_slice(&(2u32).to_be_bytes());
        wire.extend_from_slice(b"hi");
        let mut framer = GrpcFramer::new(GrpcEncoding::Gzip);
        framer.push(Bytes::from(wire));
        assert!(framer.next_message().is_err());
    }

    #[test]
    fn encode_decode_identity_round_trip() {
        let payload = b"round trip identity";
        let encoded = encode_message(payload, false, GrpcEncoding::Identity).expect("encode");
        assert_eq!(encoded[0], 0);
        assert_eq!(
            u32::from_be_bytes([encoded[1], encoded[2], encoded[3], encoded[4]]) as usize,
            payload.len()
        );
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        framer.push(encoded);
        let msgs = drain(&mut framer);
        assert_eq!(msgs.len(), 1);
        assert_eq!(&msgs[0][..], &payload[..]);
    }

    #[test]
    fn encode_compress_under_identity_is_passthrough() {
        let payload = b"identity ignores compress flag";
        let encoded = encode_message(payload, true, GrpcEncoding::Identity).expect("encode");
        assert_eq!(encoded[0], 0);
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        framer.push(encoded);
        let msgs = drain(&mut framer);
        assert_eq!(msgs.len(), 1);
        assert_eq!(&msgs[0][..], &payload[..]);
    }

    #[test]
    fn second_message_spans_after_contained_first() {
        let mut first_chunk = Vec::new();
        first_chunk.extend_from_slice(&frame_identity(b"first"));
        let second = frame_identity(b"second-message");
        first_chunk.extend_from_slice(&second[..8]);
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        framer.push(Bytes::from(first_chunk));
        let mut got = drain(&mut framer);
        assert_eq!(got.len(), 1);
        assert_eq!(&got[0][..], b"first");
        framer.push(Bytes::copy_from_slice(&second[8..]));
        got.extend(drain(&mut framer));
        assert_eq!(got.len(), 2);
        assert_eq!(&got[1][..], b"second-message");
    }

    #[test]
    fn pushes_without_draining_preserve_order() {
        let first = frame_identity(b"first");
        let second = frame_identity(b"second");
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        // Pushing while a chunk is still pending exercises the path that copies
        // the pending chunk and the new one into the carry buffer.
        framer.push(Bytes::from(first));
        framer.push(Bytes::copy_from_slice(&second[..4]));
        framer.push(Bytes::copy_from_slice(&second[4..]));
        let msgs = drain(&mut framer);
        assert_eq!(msgs.len(), 2);
        assert_eq!(&msgs[0][..], b"first");
        assert_eq!(&msgs[1][..], b"second");
    }

    #[test]
    fn gzip_message_split_across_chunks() {
        let payload = b"gzip payload delivered a few bytes at a time";
        let encoded = encode_message(payload, true, GrpcEncoding::Gzip).expect("encode");
        let got = decode_in_pieces(GrpcEncoding::Gzip, &encoded, 3);
        assert_eq!(got.len(), 1);
        assert_eq!(&got[0][..], &payload[..]);
    }

    #[test]
    fn mixed_compressed_and_plain_frames_in_gzip_stream() {
        let mut wire = Vec::new();
        wire.extend_from_slice(&encode_message(b"packed", true, GrpcEncoding::Gzip).expect("encode"));
        wire.extend_from_slice(&frame_identity(b"plain"));
        let got = decode_in_pieces(GrpcEncoding::Gzip, &wire, 5);
        assert_eq!(got.len(), 2);
        assert_eq!(&got[0][..], b"packed");
        assert_eq!(&got[1][..], b"plain");
    }

    #[test]
    fn incomplete_large_frame_waits_for_more_data() {
        let mut wire = vec![0];
        wire.extend_from_slice(&1000u32.to_be_bytes());
        wire.extend_from_slice(b"only a little");
        let mut framer = GrpcFramer::new(GrpcEncoding::Identity);
        framer.push(Bytes::from(wire));
        assert!(framer.next_message().expect("decode").is_none());
        assert!(framer.next_message().expect("decode").is_none());
    }

    #[test]
    fn bad_frame_is_consumed_and_stream_continues() {
        let mut wire = vec![7];
        wire.extend_from_slice(&2u32.to_be_bytes());
        wire.extend_from_slice(b"hi");
        wire.extend_from_slice(&frame_identity(b"next"));
        let mut framer = GrpcFramer::new(GrpcEncoding::Gzip);
        framer.push(Bytes::from(wire));
        assert!(framer.next_message().is_err());
        let msgs = drain(&mut framer);
        assert_eq!(msgs.len(), 1);
        assert_eq!(&msgs[0][..], b"next");
    }
}