use bytes::{Buf, BufMut, Bytes, BytesMut};
use flate2::read::GzDecoder;
use futures_util::stream;
use std::io::Read;
use tonic::Status;
use super::streaming::MessageStream;
/// Length in bytes of the prefix that precedes every gRPC message payload:
/// a 1-byte compression flag followed by a 4-byte big-endian payload length.
pub const GRPC_MESSAGE_HEADER_LEN: usize = 1 + 4;
/// Parses the body of a unary gRPC request, which must hold exactly one framed message.
///
/// The body is copied into an owned buffer and split into frames. Flagged frames are
/// decompressed according to `grpc_encoding` when `compression_enabled` is set.
///
/// # Errors
/// - Any framing, size or decompression error reported while splitting the body.
/// - `INVALID_ARGUMENT` if the body contains zero frames or more than one frame.
pub fn parse_unary_grpc_message(
    framed_body: &[u8],
    max_message_size: usize,
    grpc_encoding: Option<&str>,
    compression_enabled: bool,
) -> Result<Bytes, Status> {
    let owned_body = BytesMut::from(framed_body);
    let mut frames = parse_all_frames(
        owned_body,
        max_message_size,
        grpc_encoding,
        compression_enabled,
    )?;

    if frames.len() != 1 {
        return Err(Status::invalid_argument(format!(
            "Unary gRPC request must contain exactly one message frame, got {}",
            frames.len()
        )));
    }

    // Length is exactly one here, so taking index 0 cannot panic.
    Ok(frames.swap_remove(0))
}
/// Wraps `payload` in a single uncompressed gRPC message frame.
///
/// The frame layout is:
/// - a compression flag byte (always `0`);
/// - the payload length as a big-endian `u32`;
/// - the payload bytes.
///
/// # Errors
/// Returns `RESOURCE_EXHAUSTED` when the payload length does not fit in a `u32`.
pub fn encode_grpc_message(payload: Bytes) -> Result<Bytes, Status> {
    let length_prefix = match u32::try_from(payload.len()) {
        Ok(len) => len,
        Err(_) => {
            return Err(Status::resource_exhausted(
                "gRPC message exceeds 4GB frame length limit",
            ))
        }
    };

    let mut frame = BytesMut::with_capacity(GRPC_MESSAGE_HEADER_LEN + payload.len());
    frame.put_u8(0); // uncompressed
    frame.put_u32(length_prefix); // BufMut writes big-endian
    frame.put_slice(&payload);
    Ok(frame.freeze())
}
/// Buffers a complete client-streaming request body and yields its gRPC messages in order.
///
/// The whole body is collected first. It is then split into frames with the same rules as
/// unary parsing, and the resulting messages are exposed as an already-complete stream.
///
/// NOTE(review): the body is collected with a limit of `usize::MAX`, i.e. effectively
/// unbounded. `max_message_size` only limits individual frames, not their count. Consider
/// an overall body cap if this endpoint faces untrusted clients.
///
/// # Errors
/// - `INTERNAL` if reading the body fails.
/// - Any framing, size or decompression error from frame parsing.
pub async fn parse_grpc_client_stream(
    body: axum::body::Body,
    max_message_size: usize,
    grpc_encoding: Option<&str>,
    compression_enabled: bool,
) -> Result<MessageStream, Status> {
    let collected = match axum::body::to_bytes(body, usize::MAX).await {
        Ok(bytes) => bytes,
        Err(err) => return Err(Status::internal(format!("Failed to read body: {}", err))),
    };

    let frames = parse_all_frames(
        BytesMut::from(&collected[..]),
        max_message_size,
        grpc_encoding,
        compression_enabled,
    )?;

    Ok(Box::pin(stream::iter(frames.into_iter().map(Ok))))
}
/// Splits a buffer of concatenated gRPC length-prefixed frames into message payloads.
///
/// Each frame has three parts:
/// - a 1-byte compression flag (`0` or `1`);
/// - a 4-byte big-endian payload length;
/// - the payload itself.
///
/// Uncompressed payloads are returned as zero-copy slices of `buffer`, so they share its
/// allocation and no per-message copy is made. Frames with the flag set are handed to
/// `decompress_message`.
///
/// # Errors
/// - `INTERNAL` if the buffer ends partway through a frame header or payload.
/// - `INVALID_ARGUMENT` for a compression flag other than `0` or `1`.
/// - `RESOURCE_EXHAUSTED` if a declared payload length exceeds `max_message_size`.
/// - Any error returned by `decompress_message` for compressed frames.
fn parse_all_frames(
    mut buffer: BytesMut,
    max_message_size: usize,
    grpc_encoding: Option<&str>,
    compression_enabled: bool,
) -> Result<Vec<Bytes>, Status> {
    let mut messages = Vec::new();
    while !buffer.is_empty() {
        if buffer.len() < GRPC_MESSAGE_HEADER_LEN {
            return Err(Status::internal(
                "Incomplete gRPC frame header: expected 5 bytes, got less",
            ));
        }

        let compression_flag = buffer[0];
        if compression_flag > 1 {
            return Err(Status::invalid_argument(format!(
                "Invalid gRPC compression flag: {}",
                compression_flag
            )));
        }

        let declared_length = u32::from_be_bytes([buffer[1], buffer[2], buffer[3], buffer[4]]);
        // Lossless on 32/64-bit targets; saturate elsewhere so the size check still rejects it.
        let message_length = usize::try_from(declared_length).unwrap_or(usize::MAX);
        if message_length > max_message_size {
            return Err(Status::resource_exhausted(format!(
                "Message size {} exceeds maximum allowed size of {}",
                message_length, max_message_size
            )));
        }

        // Compare against the bytes remaining after the header instead of computing
        // `header + length`, which could overflow `usize` on 32-bit targets.
        if buffer.len() - GRPC_MESSAGE_HEADER_LEN < message_length {
            return Err(Status::internal(
                "Incomplete gRPC message: expected more bytes than available",
            ));
        }

        buffer.advance(GRPC_MESSAGE_HEADER_LEN);
        // `split_to` detaches the payload from the front of `buffer` without copying.
        let payload = buffer.split_to(message_length);
        let message = if compression_flag == 0 {
            payload.freeze()
        } else {
            decompress_message(&payload, grpc_encoding, compression_enabled, max_message_size)?
        };
        messages.push(message);
    }
    Ok(messages)
}
/// Decompresses one gRPC message payload whose compression flag was set.
///
/// `grpc_encoding` is trimmed and lower-cased before matching. Only `gzip` is supported.
///
/// # Errors
/// - `UNIMPLEMENTED` if compression is disabled, or for any encoding other than gzip/identity.
/// - `INVALID_ARGUMENT` if `grpc_encoding` is absent, or is `identity` (which contradicts
///   the set compression flag).
/// - `INTERNAL` if the gzip data cannot be decoded.
/// - `RESOURCE_EXHAUSTED` if the decompressed payload exceeds `max_message_size`.
fn decompress_message(
    message_bytes: &[u8],
    grpc_encoding: Option<&str>,
    compression_enabled: bool,
    max_message_size: usize,
) -> Result<Bytes, Status> {
    if !compression_enabled {
        return Err(Status::unimplemented(
            "gRPC message compression is disabled by server configuration",
        ));
    }
    let encoding = grpc_encoding
        .map(|value| value.trim().to_ascii_lowercase())
        .ok_or_else(|| {
            Status::invalid_argument("Compressed gRPC message missing grpc-encoding header")
        })?;
    let decompressed = match encoding.as_str() {
        "gzip" => {
            // A few compressed bytes can inflate to gigabytes (decompression bomb), so the
            // limit must be enforced while decoding, not afterwards. Reading at most one byte
            // past the limit is enough to detect an oversized payload below.
            let read_limit = u64::try_from(max_message_size)
                .unwrap_or(u64::MAX)
                .saturating_add(1);
            let mut out = Vec::new();
            GzDecoder::new(message_bytes)
                .take(read_limit)
                .read_to_end(&mut out)
                .map_err(|e| {
                    Status::internal(format!("Failed to decompress gzip gRPC frame: {}", e))
                })?;
            out
        }
        "identity" => {
            return Err(Status::invalid_argument(
                "Compressed gRPC frame cannot use grpc-encoding=identity",
            ));
        }
        other => {
            return Err(Status::unimplemented(format!("Unsupported grpc-encoding '{}'", other)));
        }
    };
    if decompressed.len() > max_message_size {
        // The output was truncated at `max_message_size + 1`, so the true size is unknown;
        // report only the limit.
        return Err(Status::resource_exhausted(format!(
            "Decompressed message exceeds maximum allowed size of {}",
            max_message_size
        )));
    }
    Ok(Bytes::from(decompressed))
}
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::{Compression, write::GzEncoder};
    use futures_util::StreamExt;
    use std::io::Write;

    /// Runs `parse_grpc_client_stream` over `frame` and returns the error it produced.
    ///
    /// Panics if parsing succeeds. `.err().expect(..)` is used instead of `unwrap_err()`
    /// because `unwrap_err()` needs the `Ok` type (`MessageStream`) to implement `Debug`.
    /// A boxed stream presumably does not.
    async fn parse_stream_err(
        frame: Vec<u8>,
        max_message_size: usize,
        grpc_encoding: Option<&str>,
        compression_enabled: bool,
    ) -> Status {
        parse_grpc_client_stream(
            axum::body::Body::from(frame),
            max_message_size,
            grpc_encoding,
            compression_enabled,
        )
        .await
        .err()
        .expect("expected gRPC frame parsing to fail")
    }

    /// Gzip-compresses `data` with the default compression level.
    fn gzip(data: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(data).unwrap();
        encoder.finish().unwrap()
    }

    /// Wraps an already-compressed payload in a frame whose compression flag is set.
    fn compressed_frame(compressed: &[u8]) -> Vec<u8> {
        let mut frame = vec![0x01];
        frame.extend_from_slice(&(compressed.len() as u32).to_be_bytes());
        frame.extend_from_slice(compressed);
        frame
    }

    #[tokio::test]
    async fn test_single_frame_parsing() {
        let frame = vec![0x00, 0x00, 0x00, 0x00, 0x05, b'h', b'e', b'l', b'l', b'o'];
        let body = axum::body::Body::from(frame);
        let mut stream = parse_grpc_client_stream(body, 1024, None, true).await.unwrap();
        let msg = stream.next().await.expect("one message").unwrap();
        assert_eq!(msg, b"hello"[..]);
        assert!(stream.next().await.is_none());
    }

    #[test]
    fn test_encode_grpc_message_adds_framing_header() {
        let framed = encode_grpc_message(Bytes::from_static(b"hello")).unwrap();
        assert_eq!(framed[0], 0x00);
        assert_eq!(&framed[1..5], &[0x00, 0x00, 0x00, 0x05]);
        assert_eq!(&framed[5..], b"hello");
    }

    #[test]
    fn test_encode_then_parse_unary_round_trips() {
        let framed = encode_grpc_message(Bytes::from_static(b"payload")).unwrap();
        let parsed = parse_unary_grpc_message(&framed, 1024, None, true).unwrap();
        assert_eq!(parsed, Bytes::from_static(b"payload"));
    }

    #[test]
    fn test_parse_unary_grpc_message_requires_exactly_one_frame() {
        let mut body = Vec::new();
        body.extend_from_slice(&[0x00, 0x00, 0x00, 0x00, 0x01, b'a']);
        body.extend_from_slice(&[0x00, 0x00, 0x00, 0x00, 0x01, b'b']);
        let err = parse_unary_grpc_message(&body, 1024, None, true).unwrap_err();
        assert_eq!(err.code(), tonic::Code::InvalidArgument);
    }

    #[test]
    fn test_parse_unary_grpc_message_rejects_empty_body() {
        let err = parse_unary_grpc_message(&[], 1024, None, true).unwrap_err();
        assert_eq!(err.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    async fn test_multiple_frames() {
        let mut frame = Vec::new();
        frame.push(0x00);
        frame.extend_from_slice(&[0x00, 0x00, 0x00, 0x05]);
        frame.extend_from_slice(b"hello");
        frame.push(0x00);
        frame.extend_from_slice(&[0x00, 0x00, 0x00, 0x05]);
        frame.extend_from_slice(b"world");
        let body = axum::body::Body::from(frame);
        let mut stream = parse_grpc_client_stream(body, 1024, None, true).await.unwrap();
        assert_eq!(stream.next().await.unwrap().unwrap(), b"hello"[..]);
        assert_eq!(stream.next().await.unwrap().unwrap(), b"world"[..]);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_empty_body() {
        let body = axum::body::Body::from(Vec::<u8>::new());
        let mut stream = parse_grpc_client_stream(body, 1024, None, true).await.unwrap();
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_frame_size_at_limit() {
        let max_size = 10;
        let message = b"0123456789";
        let mut frame = vec![0x00, 0x00, 0x00, 0x00, 0x0a];
        frame.extend_from_slice(message);
        let body = axum::body::Body::from(frame);
        let mut stream = parse_grpc_client_stream(body, max_size, None, true).await.unwrap();
        assert_eq!(stream.next().await.unwrap().unwrap(), message[..]);
    }

    #[tokio::test]
    async fn test_frame_exceeds_limit() {
        let mut frame = vec![0x00, 0x00, 0x00, 0x00, 0x07];
        frame.extend_from_slice(b"toolong");
        let status = parse_stream_err(frame, 5, None, true).await;
        assert_eq!(status.code(), tonic::Code::ResourceExhausted);
    }

    #[tokio::test]
    async fn test_incomplete_frame_header() {
        let status = parse_stream_err(vec![0x00, 0x00, 0x00], 1024, None, true).await;
        assert_eq!(status.code(), tonic::Code::Internal);
    }

    #[tokio::test]
    async fn test_incomplete_frame_body() {
        let mut frame = vec![0x00, 0x00, 0x00, 0x00, 0x0a];
        frame.extend_from_slice(b"short");
        let status = parse_stream_err(frame, 1024, None, true).await;
        assert_eq!(status.code(), tonic::Code::Internal);
    }

    #[tokio::test]
    async fn test_invalid_compression_flag_is_rejected() {
        let mut frame = vec![0x02, 0x00, 0x00, 0x00, 0x05];
        frame.extend_from_slice(b"hello");
        let status = parse_stream_err(frame, 1024, None, true).await;
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
        assert!(status.message().contains("compression flag"));
    }

    #[tokio::test]
    async fn test_compression_flag_set_with_missing_encoding_header() {
        let mut frame = vec![0x01, 0x00, 0x00, 0x00, 0x05];
        frame.extend_from_slice(b"hello");
        let status = parse_stream_err(frame, 1024, None, true).await;
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    async fn test_compression_flag_set_with_unsupported_encoding() {
        let mut frame = vec![0x01, 0x00, 0x00, 0x00, 0x05];
        frame.extend_from_slice(b"hello");
        let status = parse_stream_err(frame, 1024, Some("br"), true).await;
        assert_eq!(status.code(), tonic::Code::Unimplemented);
        assert!(status.message().contains("Unsupported grpc-encoding"));
    }

    #[tokio::test]
    async fn test_compression_flag_set_with_identity_encoding() {
        let mut frame = vec![0x01, 0x00, 0x00, 0x00, 0x05];
        frame.extend_from_slice(b"hello");
        let status = parse_stream_err(frame, 1024, Some("identity"), true).await;
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    async fn test_compression_flag_set_when_compression_disabled() {
        let mut frame = vec![0x01, 0x00, 0x00, 0x00, 0x05];
        frame.extend_from_slice(b"hello");
        let status = parse_stream_err(frame, 1024, Some("gzip"), false).await;
        assert_eq!(status.code(), tonic::Code::Unimplemented);
        assert!(status.message().contains("disabled"));
    }

    #[tokio::test]
    async fn test_compression_flag_set_with_gzip_encoding_decompresses_message() {
        let frame = compressed_frame(&gzip(b"hello"));
        let body = axum::body::Body::from(frame);
        let mut stream = parse_grpc_client_stream(body, 1024, Some("gzip"), true).await.unwrap();
        let msg = stream.next().await.unwrap().unwrap();
        assert_eq!(msg, Bytes::from_static(b"hello"));
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_grpc_encoding_header_is_trimmed_and_case_insensitive() {
        let frame = compressed_frame(&gzip(b"hello"));
        let body = axum::body::Body::from(frame);
        let mut stream = parse_grpc_client_stream(body, 1024, Some(" GZIP "), true).await.unwrap();
        assert_eq!(stream.next().await.unwrap().unwrap(), b"hello"[..]);
    }

    #[tokio::test]
    async fn test_decompressed_message_exceeding_limit_is_rejected() {
        // The compressed frame is well under the limit; only the inflated payload exceeds it.
        let compressed = gzip(&[0u8; 4096]);
        assert!(compressed.len() < 1024);
        let status = parse_stream_err(compressed_frame(&compressed), 1024, Some("gzip"), true).await;
        assert_eq!(status.code(), tonic::Code::ResourceExhausted);
    }

    #[tokio::test]
    async fn test_large_message_length() {
        let message = b"x".repeat(1000);
        let mut frame = vec![0x00, 0x00, 0x00, 0x03, 0xe8];
        frame.extend_from_slice(&message);
        let body = axum::body::Body::from(frame);
        let mut stream = parse_grpc_client_stream(body, 2000, None, true).await.unwrap();
        assert_eq!(stream.next().await.unwrap().unwrap().len(), 1000);
    }

    #[tokio::test]
    async fn test_zero_length_message() {
        let frame = vec![0x00, 0x00, 0x00, 0x00, 0x00];
        let body = axum::body::Body::from(frame);
        let mut stream = parse_grpc_client_stream(body, 1024, None, true).await.unwrap();
        assert_eq!(stream.next().await.unwrap().unwrap().len(), 0);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_multiple_frames_with_mixed_sizes() {
        let mut frame = Vec::new();
        frame.push(0x00);
        frame.extend_from_slice(&[0x00, 0x00, 0x00, 0x03]);
        frame.extend_from_slice(b"abc");
        frame.push(0x00);
        frame.extend_from_slice(&[0x00, 0x00, 0x00, 0x07]);
        frame.extend_from_slice(b"defghij");
        frame.push(0x00);
        frame.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]);
        frame.push(0x00);
        frame.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]);
        frame.extend_from_slice(b"x");
        let body = axum::body::Body::from(frame);
        let mut stream = parse_grpc_client_stream(body, 1024, None, true).await.unwrap();
        assert_eq!(stream.next().await.unwrap().unwrap(), b"abc"[..]);
        assert_eq!(stream.next().await.unwrap().unwrap(), b"defghij"[..]);
        assert_eq!(stream.next().await.unwrap().unwrap().len(), 0);
        assert_eq!(stream.next().await.unwrap().unwrap(), b"x"[..]);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_big_endian_length_parsing() {
        // Length bytes 00 00 01 00 read big-endian are 0x0100 = 256.
        let mut frame = vec![0x00, 0x00, 0x00, 0x01, 0x00];
        frame.extend_from_slice(&[b'z'; 256]);
        let body = axum::body::Body::from(frame);
        let mut stream = parse_grpc_client_stream(body, 1024, None, true).await.unwrap();
        assert_eq!(stream.next().await.unwrap().unwrap().len(), 256);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_big_endian_max_value() {
        // A u32::MAX length must hit the size limit before any payload bytes are required.
        let frame = vec![0x00, 0xff, 0xff, 0xff, 0xff];
        let status = parse_stream_err(frame, 1024, None, true).await;
        assert_eq!(status.code(), tonic::Code::ResourceExhausted);
        assert!(status.message().contains(&u32::MAX.to_string()));
    }

    #[tokio::test]
    async fn test_error_message_includes_size_info() {
        let mut frame = vec![0x00, 0x00, 0x00, 0x00, 0x96];
        frame.extend_from_slice(&b"x".repeat(150));
        let status = parse_stream_err(frame, 100, None, true).await;
        assert!(status.message().contains("150"));
        assert!(status.message().contains("100"));
    }

    #[tokio::test]
    async fn test_stream_collects_all_messages() {
        let mut frame = Vec::new();
        for i in 0..10u8 {
            frame.push(0x00);
            frame.extend_from_slice(&[0x00, 0x00, 0x00, 0x01]);
            frame.push(b'0' + i);
        }
        let body = axum::body::Body::from(frame);
        let stream = parse_grpc_client_stream(body, 1024, None, true).await.unwrap();
        let messages: Vec<_> = stream.collect().await;
        assert_eq!(messages.len(), 10);
        for (i, msg) in messages.iter().enumerate() {
            assert_eq!(msg.as_ref().unwrap()[0], b'0' + i as u8);
        }
    }
}