tower_grpc/generic/codec.rs

1use crate::body::{Body, HttpBody};
2use crate::error::Error;
3use crate::status::infer_grpc_status;
4use crate::Status;
5
6use bytes::{Buf, BufMut, Bytes, BytesMut, IntoBuf};
7use futures::{try_ready, Async, Poll, Stream};
8use http::{HeaderMap, StatusCode};
9use log::{debug, trace, warn};
10use std::collections::VecDeque;
11use std::fmt;
12
13type BytesBuf = <Bytes as IntoBuf>::Buf;
14
15/// Encodes and decodes gRPC message types
16pub trait Codec {
17    /// The encode type
18    type Encode;
19
20    /// Encoder type
21    type Encoder: Encoder<Item = Self::Encode>;
22
23    /// The decode type
24    type Decode;
25
26    /// Decoder type
27    type Decoder: Decoder<Item = Self::Decode>;
28
29    /// Returns a new encoder
30    fn encoder(&mut self) -> Self::Encoder;
31
32    /// Returns a new decoder
33    fn decoder(&mut self) -> Self::Decoder;
34}
35
36/// Encodes gRPC message types
37pub trait Encoder {
38    /// Type that is encoded
39    type Item;
40
41    /// The content-type header for messages using this encoding.
42    ///
43    /// Should be `application/grpc+yourencoding`.
44    const CONTENT_TYPE: &'static str;
45
46    /// Encode a message into the provided buffer.
47    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Status>;
48}
49
50/// Decodes gRPC message types
51pub trait Decoder {
52    /// Type that is decoded
53    type Item;
54
55    /// Decode a message from the buffer.
56    ///
57    /// The buffer will contain exactly the bytes of a full message. There
58    /// is no need to get the length from the bytes, gRPC framing is handled
59    /// for you.
60    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Self::Item, Status>;
61}
62
63/// Encodes gRPC message types
64#[must_use = "futures do nothing unless polled"]
65#[derive(Debug)]
66pub struct Encode<T, U> {
67    inner: EncodeInner<T, U>,
68
69    /// Destination buffer
70    buf: BytesMut,
71
72    role: Role,
73}
74
75#[derive(Debug)]
76enum EncodeInner<T, U> {
77    Ok {
78        /// The encoder
79        encoder: T,
80
81        /// The source of messages to encode
82        inner: U,
83    },
84    Empty,
85    Err(Status),
86}
87
/// Which side of the RPC an `Encode` body is sent from.
#[derive(Debug)]
enum Role {
    /// Request bodies written by a client.
    Client,
    /// Response bodies written by a server.
    Server,
}
93
94/// An stream of inbound gRPC messages
95#[must_use = "futures do nothing unless polled"]
96pub struct Streaming<T, B: Body> {
97    /// The decoder
98    decoder: T,
99
100    /// The source of encoded messages
101    inner: B,
102
103    /// buffer
104    bufs: BufList<B::Data>,
105
106    /// Decoding state
107    state: State,
108
109    direction: Direction,
110}
111
112/// Whether this is a request or a response stream value.
113#[derive(Clone, Copy, Debug)]
114pub(crate) enum Direction {
115    /// For requests, we expect only headers and the streaming body.
116    Request,
117    /// For responses, the received HTTP status code must be provided.
118    /// We also expect to receive trailers after the streaming body.
119    Response(StatusCode),
120    /// For streaming responses with zero response payloads, the HTTP
121    /// status is provided immediately. In this case no additional
122    /// trailers are expected.
123    EmptyResponse,
124}
125
/// Progress of the inbound frame decoder.
#[derive(Debug)]
enum State {
    /// Waiting for the 5-byte frame header (1-byte flag + 4-byte length).
    ReadHeader,
    /// Header parsed; waiting for `len` bytes of message payload.
    ReadBody { compression: bool, len: usize },
    /// The body has ended; only trailers (if any) remain to be read.
    Done,
}
132
133/// A buffer to encode a message into.
134#[derive(Debug)]
135pub struct EncodeBuf<'a> {
136    bytes: &'a mut BytesMut,
137}
138
139/// A buffer to decode messages from.
140pub struct DecodeBuf<'a> {
141    bufs: &'a mut dyn Buf,
142    len: usize,
143}
144
/// A queue of buffers that is read as a single logical `Buf`.
#[derive(Debug)]
pub struct BufList<B> {
    /// Chunks in arrival order; reads start at the front.
    bufs: VecDeque<B>,
}
149
150// ===== impl Encode =====
151
152impl<T, U> Encode<T, U>
153where
154    T: Encoder<Item = U::Item>,
155    U: Stream,
156    U::Error: Into<Error>,
157{
158    fn new(encoder: T, inner: U, role: Role) -> Self {
159        Encode {
160            inner: EncodeInner::Ok { encoder, inner },
161            buf: BytesMut::new(),
162            role,
163        }
164    }
165
166    pub(crate) fn request(encoder: T, inner: U) -> Self {
167        Encode::new(encoder, inner, Role::Client)
168    }
169
170    pub(crate) fn response(encoder: T, inner: U) -> Self {
171        Encode::new(encoder, inner, Role::Server)
172    }
173
174    pub(crate) fn empty() -> Self {
175        Encode {
176            inner: EncodeInner::Empty,
177            buf: BytesMut::new(),
178            role: Role::Server,
179        }
180    }
181}
182
183impl<T, U> HttpBody for Encode<T, U>
184where
185    T: Encoder<Item = U::Item>,
186    U: Stream,
187    U::Error: Into<Error>,
188{
189    type Data = BytesBuf;
190    type Error = Status;
191
192    fn is_end_stream(&self) -> bool {
193        if let EncodeInner::Empty = self.inner {
194            true
195        } else {
196            false
197        }
198    }
199
200    fn poll_data(&mut self) -> Poll<Option<Self::Data>, Status> {
201        match self.inner.poll_encode(&mut self.buf) {
202            Ok(ok) => Ok(ok),
203            Err(status) => {
204                match self.role {
205                    // clients don't send statuses as trailers, so just return
206                    // this error directly to allow an HTTP2 rst_stream to be
207                    // sent.
208                    Role::Client => Err(status),
209                    // otherwise, its better to send this status in the
210                    // trailers, instead of a RST_STREAM as the server...
211                    Role::Server => {
212                        self.inner = EncodeInner::Err(status);
213                        Ok(None.into())
214                    }
215                }
216            }
217        }
218    }
219
220    fn poll_trailers(&mut self) -> Poll<Option<HeaderMap>, Status> {
221        if let Role::Client = self.role {
222            return Ok(Async::Ready(None));
223        }
224
225        let map = match self.inner {
226            EncodeInner::Ok { .. } => Status::new(crate::Code::Ok, "").to_header_map(),
227            EncodeInner::Empty => return Ok(None.into()),
228            EncodeInner::Err(ref status) => status.to_header_map(),
229        };
230        Ok(Some(map?).into())
231    }
232}
233
234impl<T, U> EncodeInner<T, U>
235where
236    T: Encoder<Item = U::Item>,
237    U: Stream,
238    U::Error: Into<Error>,
239{
240    fn poll_encode(&mut self, buf: &mut BytesMut) -> Poll<Option<BytesBuf>, Status> {
241        match self {
242            EncodeInner::Ok {
243                ref mut inner,
244                ref mut encoder,
245            } => {
246                let item = try_ready!(inner.poll().map_err(|err| {
247                    let err = err.into();
248                    debug!("encoder inner stream error: {:?}", err);
249                    Status::from_error(&*err)
250                }));
251
252                let item = if let Some(item) = item {
253                    buf.reserve(5);
254                    unsafe {
255                        buf.advance_mut(5);
256                    }
257                    encoder.encode(item, &mut EncodeBuf { bytes: buf })?;
258
259                    // now that we know length, we can write the header
260                    let len = buf.len() - 5;
261                    assert!(len <= ::std::u32::MAX as usize);
262                    {
263                        let mut cursor = ::std::io::Cursor::new(&mut buf[..5]);
264                        cursor.put_u8(0); // byte must be 0, reserve doesn't auto-zero
265                        cursor.put_u32_be(len as u32);
266                    }
267
268                    Some(buf.split_to(len + 5).freeze().into_buf())
269                } else {
270                    None
271                };
272
273                return Ok(Async::Ready(item));
274            }
275            _ => return Ok(Async::Ready(None)),
276        }
277    }
278}
279
280// ===== impl Streaming =====
281
282impl<T, U> Streaming<T, U>
283where
284    T: Decoder,
285    U: Body,
286{
287    pub(crate) fn new(decoder: T, inner: U, direction: Direction) -> Self {
288        Streaming {
289            decoder,
290            inner,
291            bufs: BufList {
292                bufs: VecDeque::new(),
293            },
294            state: State::ReadHeader,
295            direction,
296        }
297    }
298
299    fn decode(&mut self) -> Result<Option<T::Item>, crate::Status> {
300        if let State::ReadHeader = self.state {
301            if self.bufs.remaining() < 5 {
302                return Ok(None);
303            }
304
305            let is_compressed = match self.bufs.get_u8() {
306                0 => false,
307                1 => {
308                    trace!("message compressed, compression not supported yet");
309                    return Err(crate::Status::new(
310                        crate::Code::Unimplemented,
311                        "Message compressed, compression not supported yet.".to_string(),
312                    ));
313                }
314                f => {
315                    trace!("unexpected compression flag");
316                    return Err(crate::Status::new(
317                        crate::Code::Internal,
318                        format!("Unexpected compression flag: {}", f),
319                    ));
320                }
321            };
322            let len = self.bufs.get_u32_be() as usize;
323
324            self.state = State::ReadBody {
325                compression: is_compressed,
326                len,
327            }
328        }
329
330        if let State::ReadBody { len, .. } = self.state {
331            if self.bufs.remaining() < len {
332                return Ok(None);
333            }
334
335            match self.decoder.decode(&mut DecodeBuf {
336                bufs: &mut self.bufs,
337                len,
338            }) {
339                Ok(msg) => {
340                    self.state = State::ReadHeader;
341                    return Ok(Some(msg));
342                }
343                Err(e) => {
344                    return Err(e);
345                }
346            }
347        }
348
349        Ok(None)
350    }
351}
352
353impl<T, U> Stream for Streaming<T, U>
354where
355    T: Decoder,
356    U: Body,
357{
358    type Item = T::Item;
359    type Error = Status;
360
361    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
362        loop {
363            if let State::Done = self.state {
364                break;
365            }
366
367            match self.decode()? {
368                Some(val) => return Ok(Async::Ready(Some(val))),
369                None => (),
370            }
371
372            let chunk = try_ready!(self.inner.poll_data().map_err(|err| {
373                let err = err.into();
374                debug!("decoder inner stream error: {:?}", err);
375                Status::from_error(&*err)
376            }));
377
378            if let Some(data) = chunk {
379                self.bufs.bufs.push_back(data.into_buf());
380            } else {
381                if self.bufs.has_remaining() {
382                    trace!("unexpected EOF decoding stream");
383                    return Err(crate::Status::new(
384                        crate::Code::Internal,
385                        "Unexpected EOF decoding stream.".to_string(),
386                    ));
387                } else {
388                    self.state = State::Done;
389                    break;
390                }
391            }
392        }
393
394        if let Direction::Response(status_code) = self.direction {
395            let trailers = try_ready!(self.inner.poll_trailers().map_err(|err| {
396                let err = err.into();
397                debug!("decoder inner trailers error: {:?}", err);
398                Status::from_error(&*err)
399            }));
400            match infer_grpc_status(trailers, status_code) {
401                Ok(_) => Ok(Async::Ready(None)),
402                Err(err) => Err(err),
403            }
404        } else {
405            Ok(Async::Ready(None))
406        }
407    }
408}
409
410impl<T, B> fmt::Debug for Streaming<T, B>
411where
412    T: fmt::Debug,
413    B: Body + fmt::Debug,
414    B::Data: fmt::Debug,
415{
416    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
417        f.debug_struct("Streaming").finish()
418    }
419}
420
421// ===== impl EncodeBuf =====
422
423impl<'a> EncodeBuf<'a> {
424    #[inline]
425    pub fn reserve(&mut self, capacity: usize) {
426        self.bytes.reserve(capacity);
427    }
428}
429
430impl<'a> BufMut for EncodeBuf<'a> {
431    #[inline]
432    fn remaining_mut(&self) -> usize {
433        self.bytes.remaining_mut()
434    }
435
436    #[inline]
437    unsafe fn advance_mut(&mut self, cnt: usize) {
438        self.bytes.advance_mut(cnt)
439    }
440
441    #[inline]
442    unsafe fn bytes_mut(&mut self) -> &mut [u8] {
443        self.bytes.bytes_mut()
444    }
445}
446
447// ===== impl DecodeBuf =====
448
449impl<'a> Buf for DecodeBuf<'a> {
450    #[inline]
451    fn remaining(&self) -> usize {
452        self.len
453    }
454
455    #[inline]
456    fn bytes(&self) -> &[u8] {
457        let ret = self.bufs.bytes();
458
459        if ret.len() > self.len {
460            &ret[..self.len]
461        } else {
462            ret
463        }
464    }
465
466    #[inline]
467    fn advance(&mut self, cnt: usize) {
468        assert!(cnt <= self.len);
469        self.bufs.advance(cnt);
470        self.len -= cnt;
471    }
472}
473
474impl<'a> fmt::Debug for DecodeBuf<'a> {
475    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476        f.debug_struct("DecodeBuf").finish()
477    }
478}
479
480impl<'a> Drop for DecodeBuf<'a> {
481    fn drop(&mut self) {
482        if self.len > 0 {
483            warn!("DecodeBuf was not advanced to end");
484            self.bufs.advance(self.len);
485        }
486    }
487}
488
489// ===== impl BufList =====
490
491impl<T: Buf> Buf for BufList<T> {
492    #[inline]
493    fn remaining(&self) -> usize {
494        self.bufs.iter().map(|buf| buf.remaining()).sum()
495    }
496
497    #[inline]
498    fn bytes(&self) -> &[u8] {
499        if self.bufs.is_empty() {
500            &[]
501        } else {
502            self.bufs[0].bytes()
503        }
504    }
505
506    #[inline]
507    fn advance(&mut self, mut cnt: usize) {
508        while cnt > 0 {
509            {
510                let front = &mut self.bufs[0];
511                if front.remaining() > cnt {
512                    front.advance(cnt);
513                    return;
514                } else {
515                    cnt -= front.remaining();
516                }
517            }
518            self.bufs.pop_front();
519        }
520    }
521}