use std::marker::PhantomData;
use std::sync::Arc;
use bytes::BufMut;
use bytes::BytesMut;
use futures::TryStreamExt;
use http_body::Frame;
use http_body_util::BodyExt;
use http_body_util::StreamBody;
use prost::Message;
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
use tonic::Code;
use tonic::Status;
use tonic::Streaming;
/// A tonic [`Decoder`](tonic::codec::Decoder) that turns each gRPC message
/// payload into a prost message of type `T`.
///
/// The decoder carries no runtime state; `T` only selects which message type
/// is produced.
pub struct GrpcStreamDecoder<T> {
    // Zero-sized marker tying the decoder to the message type `T` it yields.
    _marker: PhantomData<T>,
}
/// Returns how many bytes `value` occupies when written as a base-128 varint
/// (the protobuf varint format): 1 byte for values below 128, growing by one
/// byte per additional 7 significant bits, up to 10 bytes for `u64::MAX`.
pub fn encoded_len_varint(value: u64) -> usize {
    // OR-ing in 1 makes zero count as a single significant bit, so it still
    // needs one byte; this also guarantees `significant_bits >= 1` below.
    let significant_bits = (u64::BITS - (value | 1).leading_zeros()) as usize;
    // Every varint byte carries 7 payload bits: ceil(bits / 7).
    1 + (significant_bits - 1) / 7
}
/// Writes `value` into `buf` as a base-128 varint (the protobuf varint
/// format).
///
/// The value is emitted seven bits per byte, least-significant group first.
/// Every byte except the last has its high bit (`0x80`) set to mark that more
/// bytes follow.
pub fn encode_varint(mut value: u64, buf: &mut impl BufMut) {
    loop {
        let low_bits = (value & 0x7F) as u8;
        value >>= 7;
        if value == 0 {
            // Final group: high bit clear terminates the varint.
            buf.put_u8(low_bits);
            return;
        }
        buf.put_u8(low_bits | 0x80);
    }
}
impl<T> Default for GrpcStreamDecoder<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> GrpcStreamDecoder<T> {
pub fn new() -> Self {
GrpcStreamDecoder {
_marker: PhantomData,
}
}
}
impl<T> tonic::codec::Decoder for GrpcStreamDecoder<T>
where
T: prost::Message + Default + 'static,
{
type Item = T;
type Error = Status;
fn decode(
&mut self,
buf: &mut tonic::codec::DecodeBuf<'_>,
) -> std::result::Result<Option<Self::Item>, Self::Error> {
match T::decode(buf) {
Ok(chunk) => Ok(Some(chunk)),
Err(e) => Err(Status::new(Code::Internal, format!("Decode error: {e}"))),
}
}
fn buffer_settings(&self) -> tonic::codec::BufferSettings {
tonic::codec::BufferSettings::new(4 * 1024 * 1024, 4 * 1024 * 1025)
}
}
pub(crate) fn create_production_snapshot_stream<T>(
rx: mpsc::Receiver<Result<Arc<T>, Status>>,
max_message_size: usize,
) -> Streaming<T>
where
T: Message + Default + 'static,
{
let byte_stream = ReceiverStream::new(rx).map(|res| {
match res {
Ok(arc_chunk) => {
let chunk: &T = &arc_chunk;
let mut buf = Vec::new();
chunk.encode(&mut buf).map_err(|e| {
Status::new(Code::Internal, format!("Snapshot encoding failed: {e}"))
})?;
let mut frame = BytesMut::with_capacity(5 + buf.len());
frame.put_u8(0); frame.put_u32(buf.len() as u32); frame.extend_from_slice(&buf);
Ok(frame.freeze())
}
Err(e) => Err(e),
}
});
let body = StreamBody::new(byte_stream.map_ok(Frame::data).map_err(|e: Status| e));
Streaming::new_request(
GrpcStreamDecoder::<T> {
_marker: PhantomData,
},
body.boxed_unsync(),
None,
Some(max_message_size),
)
}