1use bytes::{BufMut, BytesMut};
2use http::HeaderMap;
3use prost::Message;
4
5#[cfg(feature = "__compress")]
6use http_body_alt::{Body, Frame, util::Full};
7#[cfg(feature = "__compress")]
8use http_encoding::ContentEncoding;
9
10use super::error::ProtocolError;
11
/// Default maximum payload length, in bytes, that a [`Codec`] accepts when
/// decoding (4 MiB). Used by [`Codec::new`] and [`Codec::from_headers`].
pub const DEFAULT_LIMIT: usize = 4 << 20;

/// Name of the header that [`Codec::from_headers`] reads the message encoding from.
#[cfg(feature = "__compress")]
const GRPC_ENCODING: http::HeaderName = http::HeaderName::from_static("grpc-encoding");
17
/// Framing codec for gRPC length-prefixed messages.
///
/// Every frame starts with a 1-byte compressed flag and a 4-byte big-endian
/// payload length, followed by the protobuf-encoded payload.
pub struct Codec {
    /// Maximum payload length, in bytes, accepted when decoding; `0` turns the
    /// check off.
    limit: usize,
    /// Compression used for outgoing payloads and for incoming payloads whose
    /// compressed flag is set.
    #[cfg(feature = "__compress")]
    encoding: ContentEncoding,
}
27
28impl Codec {
29 pub fn new() -> Self {
30 Self {
31 limit: DEFAULT_LIMIT,
32 #[cfg(feature = "__compress")]
33 encoding: Default::default(),
34 }
35 }
36
37 #[allow(unused_variables)]
38 pub fn from_headers(headers: &HeaderMap) -> Self {
39 Self {
40 limit: DEFAULT_LIMIT,
41 #[cfg(feature = "__compress")]
42 encoding: ContentEncoding::from_headers_with(headers, &GRPC_ENCODING),
43 }
44 }
45
46 pub fn set_limit(&mut self, limit: usize) {
49 self.limit = limit;
50 }
51
52 pub const fn limit(&self) -> usize {
53 self.limit
54 }
55
56 #[cfg(feature = "__compress")]
58 pub fn set_encoding(mut self, encoding: ContentEncoding) -> Self {
59 self.encoding = encoding;
60 self
61 }
62
63 pub fn decode<T: Message + Default>(&self, src: &mut BytesMut) -> Result<Option<T>, ProtocolError> {
72 if src.len() < 5 {
73 return Ok(None);
74 }
75
76 let compressed = src[0] != 0;
77 let len = u32::from_be_bytes(src[1..5].try_into().unwrap()) as usize;
78
79 if self.limit > 0 && len > self.limit {
80 return Err(ProtocolError::MessageTooLarge {
81 size: len,
82 limit: self.limit,
83 });
84 }
85
86 if src.len() < 5 + len {
87 return Ok(None);
88 }
89
90 let _ = src.split_to(5);
91 let payload = src.split_to(len);
92
93 let payload = if compressed { self.decompress(payload)? } else { payload };
94
95 let msg = Message::decode(payload).map_err(ProtocolError::Decode)?;
96
97 Ok(Some(msg))
98 }
99
100 pub fn encode<T: Message>(&self, msg: &T, dst: &mut BytesMut) -> Result<(), ProtocolError> {
106 let encoded_len = msg.encoded_len();
107 dst.reserve(5 + encoded_len);
108 dst.put_u8(0); dst.put_u32(0); msg.encode(dst).map_err(ProtocolError::Encode)?;
111
112 self.compress(dst)?;
113
114 let len = (dst.len() - 5) as u32;
116 dst[1..5].copy_from_slice(&len.to_be_bytes());
117
118 Ok(())
119 }
120
121 #[cfg(feature = "__compress")]
122 fn decompress(&self, payload: BytesMut) -> Result<BytesMut, ProtocolError> {
123 if matches!(self.encoding, ContentEncoding::Identity) {
124 return Err(ProtocolError::CompressedWithoutEncoding);
125 }
126
127 let body = self.encoding.decode_body(Full::new(payload));
128 let mut body = core::pin::pin!(body);
129 let mut out = BytesMut::new();
130
131 let waker = core::task::Waker::noop();
133 let mut cx = core::task::Context::from_waker(waker);
134 loop {
135 match Body::poll_frame(body.as_mut(), &mut cx) {
136 core::task::Poll::Ready(Some(Ok(Frame::Data(data)))) => {
137 out.extend_from_slice(data.as_ref());
138 }
139 core::task::Poll::Ready(Some(Err(e))) => {
140 return Err(ProtocolError::Compress(e.to_string()));
141 }
142 core::task::Poll::Ready(None | Some(Ok(Frame::Trailers(_)))) => break,
143 core::task::Poll::Pending => unreachable!("Full body never returns Pending"),
144 }
145 }
146
147 Ok(out)
148 }
149
150 #[cfg(not(feature = "__compress"))]
151 fn decompress(&self, _: BytesMut) -> Result<BytesMut, ProtocolError> {
152 Err(ProtocolError::CompressUnsupported)
153 }
154
155 #[cfg(feature = "__compress")]
156 fn compress(&self, dst: &mut BytesMut) -> Result<(), ProtocolError> {
157 if matches!(self.encoding, ContentEncoding::Identity) {
158 return Ok(());
159 }
160
161 let payload = dst.split_off(5);
162 let body = self.encoding.encode_body(Full::new(payload));
163 let mut body = core::pin::pin!(body);
164
165 dst.clear();
167 dst.put_u8(1); dst.put_u32(0); let waker = core::task::Waker::noop();
171 let mut cx = core::task::Context::from_waker(waker);
172 loop {
173 match Body::poll_frame(body.as_mut(), &mut cx) {
174 core::task::Poll::Ready(Some(Ok(Frame::Data(data)))) => {
175 dst.extend_from_slice(data.as_ref());
176 }
177 core::task::Poll::Ready(Some(Err(e))) => {
178 return Err(ProtocolError::Compress(e.to_string()));
179 }
180 core::task::Poll::Ready(None | Some(Ok(Frame::Trailers(_)))) => break,
181 core::task::Poll::Pending => unreachable!("Full body never returns Pending"),
182 }
183 }
184
185 Ok(())
186 }
187
188 #[cfg(not(feature = "__compress"))]
189 fn compress(&self, _: &mut BytesMut) -> Result<(), ProtocolError> {
190 Ok(())
191 }
192}