1use crate::body::{Body, HttpBody};
2use crate::error::Error;
3use crate::status::infer_grpc_status;
4use crate::Status;
5
6use bytes::{Buf, BufMut, Bytes, BytesMut, IntoBuf};
7use futures::{try_ready, Async, Poll, Stream};
8use http::{HeaderMap, StatusCode};
9use log::{debug, trace, warn};
10use std::collections::VecDeque;
11use std::fmt;
12
/// Read buffer obtained by converting a frozen `Bytes` with `IntoBuf`;
/// this is the `Data` type yielded by [`Encode`] as an HTTP body.
type BytesBuf = <Bytes as IntoBuf>::Buf;
14
/// Pairs an [`Encoder`] and a [`Decoder`] for a pair of message types.
pub trait Codec {
    /// Message type accepted by `Self::Encoder`.
    type Encode;

    /// Encoder that serializes `Self::Encode` values.
    type Encoder: Encoder<Item = Self::Encode>;

    /// Message type produced by `Self::Decoder`.
    type Decode;

    /// Decoder that deserializes `Self::Decode` values.
    type Decoder: Decoder<Item = Self::Decode>;

    /// Returns an encoder for this codec.
    fn encoder(&mut self) -> Self::Encoder;

    /// Returns a decoder for this codec.
    fn decoder(&mut self) -> Self::Decoder;
}
35
/// Serializes one message into the payload of a gRPC frame.
///
/// Implementations write only the payload; the 5-byte frame header
/// (compression flag + big-endian length) is added around it by
/// `EncodeInner::poll_encode`.
pub trait Encoder {
    /// Message type this encoder accepts.
    type Item;

    /// Content type string associated with this encoding.
    ///
    /// NOTE(review): not read in this file; presumably sent as the HTTP
    /// `content-type` header — confirm at the use site.
    const CONTENT_TYPE: &'static str;

    /// Writes `item` into `buf`.
    ///
    /// # Errors
    ///
    /// Returns a `Status` if `item` cannot be encoded. `Encode` hands the
    /// error back to clients and reports it in the trailers on servers.
    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Status>;
}
49
/// Deserializes one message from the payload of a gRPC frame.
pub trait Decoder {
    /// Message type this decoder produces.
    type Item;

    /// Reads one message from `buf`.
    ///
    /// `buf` exposes exactly the frame's payload: `remaining()` starts at the
    /// payload length and reads cannot run into the next frame. Payload bytes
    /// left unread are skipped (with a warning) when `buf` is dropped.
    ///
    /// # Errors
    ///
    /// Returns a `Status` for a payload that cannot be decoded; `Streaming`
    /// yields it as the stream's error.
    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Self::Item, Status>;
}
62
/// HTTP body that encodes a stream of messages as length-prefixed gRPC frames.
#[must_use = "futures do nothing unless polled"]
#[derive(Debug)]
pub struct Encode<T, U> {
    /// Encoder plus source stream, or a terminal state (`Empty` / `Err`).
    inner: EncodeInner<T, U>,

    /// Scratch buffer each frame is assembled in before being split off.
    buf: BytesMut,

    /// Decides error handling: clients return errors from `poll_data`,
    /// servers end the body and report the status in trailers.
    role: Role,
}
74
/// State of an [`Encode`] body.
#[derive(Debug)]
enum EncodeInner<T, U> {
    /// Still encoding items pulled from `inner`.
    Ok {
        /// Serializes each item's payload.
        encoder: T,

        /// Source stream of items to encode.
        inner: U,
    },
    /// Body with no data: `is_end_stream` is true and no trailers are sent.
    Empty,
    /// A server-side failure, recorded so `poll_trailers` can report it.
    Err(Status),
}
87
/// Which side of the call an [`Encode`] body is sent from.
#[derive(Debug)]
enum Role {
    /// Request body: errors are returned directly and no trailers are sent.
    Client,
    /// Response body: errors and the final status are sent as trailers.
    Server,
}
93
/// Stream of messages decoded from the gRPC frames of an HTTP body.
#[must_use = "futures do nothing unless polled"]
pub struct Streaming<T, B: Body> {
    /// Turns each frame payload into a message.
    decoder: T,

    /// Body the frames are read from.
    inner: B,

    /// Received chunks that have not been fully consumed yet.
    bufs: BufList<B::Data>,

    /// Frame-parsing progress kept between polls.
    state: State,

    /// Whether (and how) trailers are checked once the body ends.
    direction: Direction,
}
111
/// What kind of body a [`Streaming`] decodes, which decides whether its
/// trailers are inspected at the end.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Direction {
    /// Decoding a request body; trailers are not inspected.
    Request,
    /// Decoding a response; this HTTP status is combined with the trailers by
    /// `infer_grpc_status` to produce the final result.
    Response(StatusCode),
    /// Decoding a response whose trailers are not inspected.
    ///
    /// NOTE(review): presumably for responses known to carry no trailers —
    /// confirm where this variant is constructed.
    EmptyResponse,
}
125
/// Frame-parsing progress of a [`Streaming`].
#[derive(Debug)]
enum State {
    /// Waiting for the 5-byte frame header.
    ReadHeader,
    /// Header parsed; waiting for `len` payload bytes. `compression` is always
    /// `false` because `decode` rejects compressed frames.
    ReadBody { compression: bool, len: usize },
    /// The body has ended; only trailer handling remains.
    Done,
}
132
/// Write view over the frame buffer that is handed to [`Encoder::encode`].
#[derive(Debug)]
pub struct EncodeBuf<'a> {
    bytes: &'a mut BytesMut,
}
138
/// Read view limited to a single frame's payload, handed to [`Decoder::decode`].
pub struct DecodeBuf<'a> {
    /// All buffered data; only the first `len` bytes belong to this frame.
    bufs: &'a mut dyn Buf,
    /// Payload bytes of the current frame not yet consumed.
    len: usize,
}
144
/// Queue of received body chunks, read through `Buf` as one logical buffer.
#[derive(Debug)]
pub struct BufList<B> {
    bufs: VecDeque<B>,
}
149
150impl<T, U> Encode<T, U>
153where
154 T: Encoder<Item = U::Item>,
155 U: Stream,
156 U::Error: Into<Error>,
157{
158 fn new(encoder: T, inner: U, role: Role) -> Self {
159 Encode {
160 inner: EncodeInner::Ok { encoder, inner },
161 buf: BytesMut::new(),
162 role,
163 }
164 }
165
166 pub(crate) fn request(encoder: T, inner: U) -> Self {
167 Encode::new(encoder, inner, Role::Client)
168 }
169
170 pub(crate) fn response(encoder: T, inner: U) -> Self {
171 Encode::new(encoder, inner, Role::Server)
172 }
173
174 pub(crate) fn empty() -> Self {
175 Encode {
176 inner: EncodeInner::Empty,
177 buf: BytesMut::new(),
178 role: Role::Server,
179 }
180 }
181}
182
183impl<T, U> HttpBody for Encode<T, U>
184where
185 T: Encoder<Item = U::Item>,
186 U: Stream,
187 U::Error: Into<Error>,
188{
189 type Data = BytesBuf;
190 type Error = Status;
191
192 fn is_end_stream(&self) -> bool {
193 if let EncodeInner::Empty = self.inner {
194 true
195 } else {
196 false
197 }
198 }
199
200 fn poll_data(&mut self) -> Poll<Option<Self::Data>, Status> {
201 match self.inner.poll_encode(&mut self.buf) {
202 Ok(ok) => Ok(ok),
203 Err(status) => {
204 match self.role {
205 Role::Client => Err(status),
209 Role::Server => {
212 self.inner = EncodeInner::Err(status);
213 Ok(None.into())
214 }
215 }
216 }
217 }
218 }
219
220 fn poll_trailers(&mut self) -> Poll<Option<HeaderMap>, Status> {
221 if let Role::Client = self.role {
222 return Ok(Async::Ready(None));
223 }
224
225 let map = match self.inner {
226 EncodeInner::Ok { .. } => Status::new(crate::Code::Ok, "").to_header_map(),
227 EncodeInner::Empty => return Ok(None.into()),
228 EncodeInner::Err(ref status) => status.to_header_map(),
229 };
230 Ok(Some(map?).into())
231 }
232}
233
234impl<T, U> EncodeInner<T, U>
235where
236 T: Encoder<Item = U::Item>,
237 U: Stream,
238 U::Error: Into<Error>,
239{
240 fn poll_encode(&mut self, buf: &mut BytesMut) -> Poll<Option<BytesBuf>, Status> {
241 match self {
242 EncodeInner::Ok {
243 ref mut inner,
244 ref mut encoder,
245 } => {
246 let item = try_ready!(inner.poll().map_err(|err| {
247 let err = err.into();
248 debug!("encoder inner stream error: {:?}", err);
249 Status::from_error(&*err)
250 }));
251
252 let item = if let Some(item) = item {
253 buf.reserve(5);
254 unsafe {
255 buf.advance_mut(5);
256 }
257 encoder.encode(item, &mut EncodeBuf { bytes: buf })?;
258
259 let len = buf.len() - 5;
261 assert!(len <= ::std::u32::MAX as usize);
262 {
263 let mut cursor = ::std::io::Cursor::new(&mut buf[..5]);
264 cursor.put_u8(0); cursor.put_u32_be(len as u32);
266 }
267
268 Some(buf.split_to(len + 5).freeze().into_buf())
269 } else {
270 None
271 };
272
273 return Ok(Async::Ready(item));
274 }
275 _ => return Ok(Async::Ready(None)),
276 }
277 }
278}
279
280impl<T, U> Streaming<T, U>
283where
284 T: Decoder,
285 U: Body,
286{
287 pub(crate) fn new(decoder: T, inner: U, direction: Direction) -> Self {
288 Streaming {
289 decoder,
290 inner,
291 bufs: BufList {
292 bufs: VecDeque::new(),
293 },
294 state: State::ReadHeader,
295 direction,
296 }
297 }
298
299 fn decode(&mut self) -> Result<Option<T::Item>, crate::Status> {
300 if let State::ReadHeader = self.state {
301 if self.bufs.remaining() < 5 {
302 return Ok(None);
303 }
304
305 let is_compressed = match self.bufs.get_u8() {
306 0 => false,
307 1 => {
308 trace!("message compressed, compression not supported yet");
309 return Err(crate::Status::new(
310 crate::Code::Unimplemented,
311 "Message compressed, compression not supported yet.".to_string(),
312 ));
313 }
314 f => {
315 trace!("unexpected compression flag");
316 return Err(crate::Status::new(
317 crate::Code::Internal,
318 format!("Unexpected compression flag: {}", f),
319 ));
320 }
321 };
322 let len = self.bufs.get_u32_be() as usize;
323
324 self.state = State::ReadBody {
325 compression: is_compressed,
326 len,
327 }
328 }
329
330 if let State::ReadBody { len, .. } = self.state {
331 if self.bufs.remaining() < len {
332 return Ok(None);
333 }
334
335 match self.decoder.decode(&mut DecodeBuf {
336 bufs: &mut self.bufs,
337 len,
338 }) {
339 Ok(msg) => {
340 self.state = State::ReadHeader;
341 return Ok(Some(msg));
342 }
343 Err(e) => {
344 return Err(e);
345 }
346 }
347 }
348
349 Ok(None)
350 }
351}
352
353impl<T, U> Stream for Streaming<T, U>
354where
355 T: Decoder,
356 U: Body,
357{
358 type Item = T::Item;
359 type Error = Status;
360
361 fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
362 loop {
363 if let State::Done = self.state {
364 break;
365 }
366
367 match self.decode()? {
368 Some(val) => return Ok(Async::Ready(Some(val))),
369 None => (),
370 }
371
372 let chunk = try_ready!(self.inner.poll_data().map_err(|err| {
373 let err = err.into();
374 debug!("decoder inner stream error: {:?}", err);
375 Status::from_error(&*err)
376 }));
377
378 if let Some(data) = chunk {
379 self.bufs.bufs.push_back(data.into_buf());
380 } else {
381 if self.bufs.has_remaining() {
382 trace!("unexpected EOF decoding stream");
383 return Err(crate::Status::new(
384 crate::Code::Internal,
385 "Unexpected EOF decoding stream.".to_string(),
386 ));
387 } else {
388 self.state = State::Done;
389 break;
390 }
391 }
392 }
393
394 if let Direction::Response(status_code) = self.direction {
395 let trailers = try_ready!(self.inner.poll_trailers().map_err(|err| {
396 let err = err.into();
397 debug!("decoder inner trailers error: {:?}", err);
398 Status::from_error(&*err)
399 }));
400 match infer_grpc_status(trailers, status_code) {
401 Ok(_) => Ok(Async::Ready(None)),
402 Err(err) => Err(err),
403 }
404 } else {
405 Ok(Async::Ready(None))
406 }
407 }
408}
409
410impl<T, B> fmt::Debug for Streaming<T, B>
411where
412 T: fmt::Debug,
413 B: Body + fmt::Debug,
414 B::Data: fmt::Debug,
415{
416 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
417 f.debug_struct("Streaming").finish()
418 }
419}
420
421impl<'a> EncodeBuf<'a> {
424 #[inline]
425 pub fn reserve(&mut self, capacity: usize) {
426 self.bytes.reserve(capacity);
427 }
428}
429
430impl<'a> BufMut for EncodeBuf<'a> {
431 #[inline]
432 fn remaining_mut(&self) -> usize {
433 self.bytes.remaining_mut()
434 }
435
436 #[inline]
437 unsafe fn advance_mut(&mut self, cnt: usize) {
438 self.bytes.advance_mut(cnt)
439 }
440
441 #[inline]
442 unsafe fn bytes_mut(&mut self) -> &mut [u8] {
443 self.bytes.bytes_mut()
444 }
445}
446
447impl<'a> Buf for DecodeBuf<'a> {
450 #[inline]
451 fn remaining(&self) -> usize {
452 self.len
453 }
454
455 #[inline]
456 fn bytes(&self) -> &[u8] {
457 let ret = self.bufs.bytes();
458
459 if ret.len() > self.len {
460 &ret[..self.len]
461 } else {
462 ret
463 }
464 }
465
466 #[inline]
467 fn advance(&mut self, cnt: usize) {
468 assert!(cnt <= self.len);
469 self.bufs.advance(cnt);
470 self.len -= cnt;
471 }
472}
473
474impl<'a> fmt::Debug for DecodeBuf<'a> {
475 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476 f.debug_struct("DecodeBuf").finish()
477 }
478}
479
480impl<'a> Drop for DecodeBuf<'a> {
481 fn drop(&mut self) {
482 if self.len > 0 {
483 warn!("DecodeBuf was not advanced to end");
484 self.bufs.advance(self.len);
485 }
486 }
487}
488
489impl<T: Buf> Buf for BufList<T> {
492 #[inline]
493 fn remaining(&self) -> usize {
494 self.bufs.iter().map(|buf| buf.remaining()).sum()
495 }
496
497 #[inline]
498 fn bytes(&self) -> &[u8] {
499 if self.bufs.is_empty() {
500 &[]
501 } else {
502 self.bufs[0].bytes()
503 }
504 }
505
506 #[inline]
507 fn advance(&mut self, mut cnt: usize) {
508 while cnt > 0 {
509 {
510 let front = &mut self.bufs[0];
511 if front.remaining() > cnt {
512 front.advance(cnt);
513 return;
514 } else {
515 cnt -= front.remaining();
516 }
517 }
518 self.bufs.pop_front();
519 }
520 }
521}