Skip to main content

http_grpc_rs/
codec.rs

1use bytes::{BufMut, BytesMut};
2use http::HeaderMap;
3use prost::Message;
4
5#[cfg(feature = "__compress")]
6use http_body_alt::{Body, Frame, util::Full};
7#[cfg(feature = "__compress")]
8use http_encoding::ContentEncoding;
9
10use super::error::ProtocolError;
11
/// Size cap, in bytes, applied by default to a single gRPC message payload (4 MiB).
pub const DEFAULT_LIMIT: usize = 4 << 20;

/// Header naming the compression algorithm used for gRPC message payloads.
#[cfg(feature = "__compress")]
const GRPC_ENCODING: http::HeaderName = http::HeaderName::from_static("grpc-encoding");
17
/// Codec for gRPC's length-prefixed message framing.
///
/// Every frame begins with a 5-byte header: a 1-byte compression flag followed by the
/// payload length as a 4-byte big-endian integer. The payload is a protobuf-encoded
/// message, optionally compressed.
pub struct Codec {
    /// Largest accepted frame payload in bytes; `0` disables the check.
    limit: usize,
    /// Algorithm used to compress outgoing payloads and decompress incoming ones.
    #[cfg(feature = "__compress")]
    encoding: ContentEncoding,
}
27
28impl Codec {
29    pub fn new() -> Self {
30        Self {
31            limit: DEFAULT_LIMIT,
32            #[cfg(feature = "__compress")]
33            encoding: Default::default(),
34        }
35    }
36
37    #[allow(unused_variables)]
38    pub fn from_headers(headers: &HeaderMap) -> Self {
39        Self {
40            limit: DEFAULT_LIMIT,
41            #[cfg(feature = "__compress")]
42            encoding: ContentEncoding::from_headers_with(headers, &GRPC_ENCODING),
43        }
44    }
45
46    /// Set the maximum allowed size in bytes for a single gRPC message frame.
47    /// Set to `0` for unlimited.
48    pub fn set_limit(&mut self, limit: usize) {
49        self.limit = limit;
50    }
51
52    pub const fn limit(&self) -> usize {
53        self.limit
54    }
55
56    /// Set the content encoding for compression/decompression.
57    #[cfg(feature = "__compress")]
58    pub fn set_encoding(mut self, encoding: ContentEncoding) -> Self {
59        self.encoding = encoding;
60        self
61    }
62
63    /// Try to decode a complete gRPC message from `src`.
64    ///
65    /// Consumes the frame bytes from `src` on success.
66    ///
67    /// Returns:
68    /// - `Ok(Some(message))` when a complete frame is available
69    /// - `Ok(None)` when more data is needed
70    /// - `Err` on protocol violations (size limit, decode error)
71    pub fn decode<T: Message + Default>(&self, src: &mut BytesMut) -> Result<Option<T>, ProtocolError> {
72        if src.len() < 5 {
73            return Ok(None);
74        }
75
76        let compressed = src[0] != 0;
77        let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize;
78
79        if self.limit > 0 && len > self.limit {
80            return Err(ProtocolError::MessageTooLarge {
81                size: len,
82                limit: self.limit,
83            });
84        }
85
86        if src.len() < 5 + len {
87            return Ok(None);
88        }
89
90        let _ = src.split_to(5);
91        let payload = src.split_to(len);
92
93        let payload = if compressed { self.decompress(payload)? } else { payload };
94
95        let msg = Message::decode(payload).map_err(ProtocolError::Decode)?;
96
97        Ok(Some(msg))
98    }
99
100    /// Encode a protobuf message into gRPC length-prefixed framing.
101    ///
102    /// Writes to `dst`: 1 byte compression flag + 4 byte big-endian length + payload.
103    /// When compression is enabled and the `compress` feature is active, the payload
104    /// is compressed and the flag byte is set to 1.
105    pub fn encode<T: Message>(&self, msg: &T, dst: &mut BytesMut) -> Result<(), ProtocolError> {
106        let encoded_len = msg.encoded_len();
107        dst.reserve(5 + encoded_len);
108        dst.put_u8(0); // compression flag placeholder
109        dst.put_u32(0); // length placeholder
110        msg.encode(dst).map_err(ProtocolError::Encode)?;
111
112        self.compress(dst)?;
113
114        // write actual payload length
115        let len = (dst.len() - 5) as u32;
116        dst[1..5].copy_from_slice(&len.to_be_bytes());
117
118        Ok(())
119    }
120
121    #[cfg(feature = "__compress")]
122    fn decompress(&self, payload: BytesMut) -> Result<BytesMut, ProtocolError> {
123        if matches!(self.encoding, ContentEncoding::Identity) {
124            return Err(ProtocolError::CompressedWithoutEncoding);
125        }
126
127        let body = self.encoding.decode_body(Full::new(payload));
128        let mut body = core::pin::pin!(body);
129        let mut out = BytesMut::new();
130
131        // drive synchronously — Full body yields once and never returns Pending
132        let waker = core::task::Waker::noop();
133        let mut cx = core::task::Context::from_waker(waker);
134        loop {
135            match Body::poll_frame(body.as_mut(), &mut cx) {
136                core::task::Poll::Ready(Some(Ok(Frame::Data(data)))) => {
137                    out.extend_from_slice(data.as_ref());
138                }
139                core::task::Poll::Ready(Some(Err(e))) => {
140                    return Err(ProtocolError::Compress(e.to_string()));
141                }
142                core::task::Poll::Ready(None | Some(Ok(Frame::Trailers(_)))) => break,
143                core::task::Poll::Pending => unreachable!("Full body never returns Pending"),
144            }
145        }
146
147        Ok(out)
148    }
149
150    #[cfg(not(feature = "__compress"))]
151    fn decompress(&self, _: BytesMut) -> Result<BytesMut, ProtocolError> {
152        Err(ProtocolError::CompressUnsupported)
153    }
154
155    #[cfg(feature = "__compress")]
156    fn compress(&self, dst: &mut BytesMut) -> Result<(), ProtocolError> {
157        if matches!(self.encoding, ContentEncoding::Identity) {
158            return Ok(());
159        }
160
161        let payload = dst.split_off(5);
162        let body = self.encoding.encode_body(Full::new(payload));
163        let mut body = core::pin::pin!(body);
164
165        // clear and rewrite header
166        dst.clear();
167        dst.put_u8(1); // compressed flag
168        dst.put_u32(0); // length placeholder
169
170        let waker = core::task::Waker::noop();
171        let mut cx = core::task::Context::from_waker(waker);
172        loop {
173            match Body::poll_frame(body.as_mut(), &mut cx) {
174                core::task::Poll::Ready(Some(Ok(Frame::Data(data)))) => {
175                    dst.extend_from_slice(data.as_ref());
176                }
177                core::task::Poll::Ready(Some(Err(e))) => {
178                    return Err(ProtocolError::Compress(e.to_string()));
179                }
180                core::task::Poll::Ready(None | Some(Ok(Frame::Trailers(_)))) => break,
181                core::task::Poll::Pending => unreachable!("Full body never returns Pending"),
182            }
183        }
184
185        Ok(())
186    }
187
188    #[cfg(not(feature = "__compress"))]
189    fn compress(&self, _: &mut BytesMut) -> Result<(), ProtocolError> {
190        Ok(())
191    }
192}