1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use http_body::Body;
3use prost::Message;
4use std::{
5 collections::VecDeque,
6 marker::PhantomData,
7 pin::Pin,
8 sync::{Arc, Mutex},
9 task::{Context, Poll, Waker},
10};
11use tokio::sync::mpsc::Receiver;
12
13use tonic::{
14 Status,
15 codec::{DecodeBuf, Decoder},
16};
17
/// Shared, mutex-guarded state behind a channel-backed [`MockBody`].
struct ChannelState<T> {
    /// Source of messages; each received message is encoded into one data frame.
    receiver: Receiver<T>,
    /// Already-encoded frames, yielded before the receiver is consulted.
    /// NOTE(review): nothing in this file pushes to this buffer — confirm it is still needed.
    buffer: VecDeque<Bytes>,
    /// Waker recorded when the body last returned `Poll::Pending`.
    /// NOTE(review): no code in this file reads or wakes this waker.
    waker: Option<Waker>,
    /// Set once the channel reports disconnection; afterwards the body yields `None`
    /// as soon as `buffer` is empty.
    closed: bool,
}
25
26#[derive(Clone)]
27enum MockBodySource<T> {
28 Static(VecDeque<Bytes>),
30 Channel(Arc<Mutex<ChannelState<T>>>),
32}
33
34#[derive(Clone)]
35pub struct MockBody<T = Box<dyn Message>> {
36 source: MockBodySource<T>,
37}
38
39impl<T: Message + Send + 'static> MockBody<T> {
40 pub fn new(data: Vec<impl Message>) -> Self {
41 let mut queue: VecDeque<Bytes> = VecDeque::with_capacity(16);
42 for msg in data {
43 let buf = Self::encode(msg);
44 queue.push_back(buf);
45 }
46
47 MockBody {
48 source: MockBodySource::Static(queue),
49 }
50 }
51
52 pub fn from_channel(receiver: Receiver<T>) -> Self {
56 let state = ChannelState {
57 receiver,
58 buffer: VecDeque::new(),
59 waker: None,
60 closed: false,
61 };
62
63 MockBody {
64 source: MockBodySource::Channel(Arc::new(Mutex::new(state))),
65 }
66 }
67
68 pub fn len(&self) -> usize {
69 match &self.source {
70 MockBodySource::Static(queue) => queue.len(),
71 MockBodySource::Channel(state) => {
72 let state = state.lock().unwrap();
73 state.buffer.len()
74 }
75 }
76 }
77
78 pub fn is_empty(&self) -> bool {
79 self.len() == 0
80 }
81
82 fn encode(msg: impl Message) -> Bytes {
84 let mut buf = BytesMut::with_capacity(256);
85
86 buf.reserve(5);
87 unsafe {
88 buf.advance_mut(5);
89 }
90 msg.encode(&mut buf).unwrap();
91 {
92 let len = buf.len() - 5;
93 let mut buf = &mut buf[..5];
94 buf.put_u8(0); buf.put_u32(len as u32);
96 }
97 buf.freeze()
98 }
99}
100
101impl<T: Message + Send + 'static> Body for MockBody<T> {
102 type Data = Bytes;
103 type Error = Status;
104
105 fn poll_frame(
106 self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
109 let this = self.get_mut();
110
111 match &mut this.source {
112 MockBodySource::Static(queue) => {
113 if let Some(data) = queue.pop_front() {
115 Poll::Ready(Some(Ok(http_body::Frame::data(data))))
116 } else {
117 Poll::Ready(None)
118 }
119 }
120 MockBodySource::Channel(state_arc) => {
121 let mut state = state_arc.lock().unwrap();
122
123 if let Some(data) = state.buffer.pop_front() {
125 return Poll::Ready(Some(Ok(http_body::Frame::data(data))));
126 }
127
128 if state.closed {
130 return Poll::Ready(None);
131 }
132
133 match state.receiver.try_recv() {
135 Ok(msg) => {
136 let buf = Self::encode(msg);
138 Poll::Ready(Some(Ok(http_body::Frame::data(buf))))
139 }
140 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {
141 state.waker = Some(cx.waker().clone());
143 Poll::Pending
144 }
145 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
146 state.closed = true;
148 Poll::Ready(None)
149 }
150 }
151 }
152 }
153 }
154}
155
/// A zero-sized decoder that turns each received message payload into a prost
/// message of type `U`.
///
/// `PhantomData` ties the decoder to its output type without storing a `U`.
#[derive(Debug, Clone, Default)]
pub struct ProstDecoder<U>(PhantomData<U>);

impl<U> ProstDecoder<U> {
    /// Builds a decoder for messages of type `U`; no bounds on `U` are needed.
    pub fn new() -> Self {
        ProstDecoder(PhantomData)
    }
}
165
166impl<U: Message + Default> Decoder for ProstDecoder<U> {
167 type Item = U;
168 type Error = Status;
169
170 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
171 let item = Message::decode(buf.chunk())
172 .map(Option::Some)
173 .map_err(|e| Status::internal(e.to_string()))?;
174
175 buf.advance(buf.chunk().len());
176 Ok(item)
177 }
178}