tonic_mock/
mock.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use http_body::Body;
3use prost::Message;
4use std::{
5    collections::VecDeque,
6    marker::PhantomData,
7    pin::Pin,
8    sync::{Arc, Mutex},
9    task::{Context, Poll, Waker},
10};
11use tokio::sync::mpsc::Receiver;
12
13use tonic::{
14    Status,
15    codec::{DecodeBuf, Decoder},
16};
17
18// Internal state for channel-based MockBody
19struct ChannelState<T> {
20    receiver: Receiver<T>,
21    buffer: VecDeque<Bytes>,
22    waker: Option<Waker>,
23    closed: bool,
24}
25
26#[derive(Clone)]
27enum MockBodySource<T> {
28    // Static data from a Vec
29    Static(VecDeque<Bytes>),
30    // Dynamic data from a channel
31    Channel(Arc<Mutex<ChannelState<T>>>),
32}
33
34#[derive(Clone)]
35pub struct MockBody<T = Box<dyn Message>> {
36    source: MockBodySource<T>,
37}
38
39impl<T: Message + Send + 'static> MockBody<T> {
40    pub fn new(data: Vec<impl Message>) -> Self {
41        let mut queue: VecDeque<Bytes> = VecDeque::with_capacity(16);
42        for msg in data {
43            let buf = Self::encode(msg);
44            queue.push_back(buf);
45        }
46
47        MockBody {
48            source: MockBodySource::Static(queue),
49        }
50    }
51
52    /// Create a MockBody from a channel receiver
53    ///
54    /// This allows for dynamic streaming of messages without collecting them all upfront.
55    pub fn from_channel(receiver: Receiver<T>) -> Self {
56        let state = ChannelState {
57            receiver,
58            buffer: VecDeque::new(),
59            waker: None,
60            closed: false,
61        };
62
63        MockBody {
64            source: MockBodySource::Channel(Arc::new(Mutex::new(state))),
65        }
66    }
67
68    pub fn len(&self) -> usize {
69        match &self.source {
70            MockBodySource::Static(queue) => queue.len(),
71            MockBodySource::Channel(state) => {
72                let state = state.lock().unwrap();
73                state.buffer.len()
74            }
75        }
76    }
77
78    pub fn is_empty(&self) -> bool {
79        self.len() == 0
80    }
81
82    // see: https://github.com/hyperium/tonic/blob/1b03ece2a81cb7e8b1922b3c3c1f496bd402d76c/tonic/src/codec/encode.rs#L52
83    fn encode(msg: impl Message) -> Bytes {
84        let mut buf = BytesMut::with_capacity(256);
85
86        buf.reserve(5);
87        unsafe {
88            buf.advance_mut(5);
89        }
90        msg.encode(&mut buf).unwrap();
91        {
92            let len = buf.len() - 5;
93            let mut buf = &mut buf[..5];
94            buf.put_u8(0); // byte must be 0, reserve doesn't auto-zero
95            buf.put_u32(len as u32);
96        }
97        buf.freeze()
98    }
99}
100
101impl<T: Message + Send + 'static> Body for MockBody<T> {
102    type Data = Bytes;
103    type Error = Status;
104
105    fn poll_frame(
106        self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
109        let this = self.get_mut();
110
111        match &mut this.source {
112            MockBodySource::Static(queue) => {
113                // Return data from the static queue
114                if let Some(data) = queue.pop_front() {
115                    Poll::Ready(Some(Ok(http_body::Frame::data(data))))
116                } else {
117                    Poll::Ready(None)
118                }
119            }
120            MockBodySource::Channel(state_arc) => {
121                let mut state = state_arc.lock().unwrap();
122
123                // If we have buffered data, return it
124                if let Some(data) = state.buffer.pop_front() {
125                    return Poll::Ready(Some(Ok(http_body::Frame::data(data))));
126                }
127
128                // If the channel is closed and we have no more buffered data, we're done
129                if state.closed {
130                    return Poll::Ready(None);
131                }
132
133                // Try to receive a message from the channel
134                match state.receiver.try_recv() {
135                    Ok(msg) => {
136                        // Got a message, encode it and return
137                        let buf = Self::encode(msg);
138                        Poll::Ready(Some(Ok(http_body::Frame::data(buf))))
139                    }
140                    Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {
141                        // Channel is empty but not closed, register waker and return Pending
142                        state.waker = Some(cx.waker().clone());
143                        Poll::Pending
144                    }
145                    Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
146                        // Channel is closed, mark as closed and return None
147                        state.closed = true;
148                        Poll::Ready(None)
149                    }
150                }
151            }
152        }
153    }
154}
155
/// Stateless [`Decoder`] that turns frame payloads into values of the prost
/// message type `U`.
///
/// The struct carries no data; `U` is tracked only at the type level.
#[derive(Clone, Debug, Default)]
pub struct ProstDecoder<U>(PhantomData<U>);

impl<U> ProstDecoder<U> {
    /// Returns a decoder for `U`.
    pub fn new() -> Self {
        ProstDecoder(PhantomData)
    }
}
165
166impl<U: Message + Default> Decoder for ProstDecoder<U> {
167    type Item = U;
168    type Error = Status;
169
170    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
171        let item = Message::decode(buf.chunk())
172            .map(Option::Some)
173            .map_err(|e| Status::internal(e.to_string()))?;
174
175        buf.advance(buf.chunk().len());
176        Ok(item)
177    }
178}