// reqwest_streams/protobuf_stream.rs

1use crate::protobuf_len_codec::ProtobufLenPrefixCodec;
2
3use crate::StreamBodyResult;
4use async_trait::*;
5use futures::TryStreamExt;
6use tokio_util::io::StreamReader;
7
8/// Extension trait for [`reqwest::Response`] that provides streaming support for the [Protobuf
9/// format].
10///
11/// [Protobuf format]: https://protobuf.dev/programming-guides/encoding/
12#[async_trait]
13pub trait ProtobufStreamResponse {
14    /// Streams the response as batches of Protobuf messages.
15    ///
16    /// The stream will deserialize [`prost::Message`]s as type `T` with a maximum size of
17    /// `max_obj_len` bytes.
18    ///
19    /// # Example
20    ///
21    /// ```rust,no_run
22    /// use futures::{prelude::*, stream::BoxStream as _};
23    /// use reqwest_streams::ProtobufStreamResponse as _;
24    ///
25    /// #[derive(Clone, prost::Message)]
26    /// struct MyTestStructure {
27    ///     #[prost(string, tag = "1")]
28    ///     some_test_field: String,
29    /// }
30    ///
31    /// #[tokio::main]
32    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
33    ///     const MAX_OBJ_LEN: usize = 64 * 1024;
34    ///
35    ///     let stream = reqwest::get("http://localhost:8080/protobuf")
36    ///         .await?
37    ///         .protobuf_stream::<MyTestStructure>(MAX_OBJ_LEN);
38    ///     let _items: Vec<MyTestStructure> = stream.try_collect().await?;
39    ///
40    ///     Ok(())
41    /// }
42    /// ```
43    fn protobuf_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
44    where
45        T: prost::Message + Default + Send + 'b;
46}
47
48#[async_trait]
49impl ProtobufStreamResponse for reqwest::Response {
50    fn protobuf_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
51    where
52        T: prost::Message + Default + Send + 'b,
53    {
54        let reader = StreamReader::new(
55            self.bytes_stream()
56                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
57        );
58
59        let codec = ProtobufLenPrefixCodec::<T>::new_with_max_length(max_obj_len);
60        let frames_reader = tokio_util::codec::FramedRead::new(reader, codec);
61
62        frames_reader.into_stream()
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use axum::{routing::*, Router};
    use axum_streams::*;
    use futures::stream;

    #[derive(Clone, prost::Message, PartialEq, Eq)]
    struct MyTestStructure {
        #[prost(string, tag = "1")]
        some_test_field1: String,
        #[prost(string, tag = "2")]
        some_test_field2: String,
    }

    /// Returns 100 identical sample messages used as the server-side payload.
    fn generate_test_structures() -> Vec<MyTestStructure> {
        let sample = MyTestStructure {
            some_test_field1: "TestValue1".to_string(),
            some_test_field2: "TestValue2".to_string(),
        };
        vec![sample; 100]
    }

    /// Builds a router whose `/` route streams `payload` as Protobuf and wraps it in a
    /// [`TestClient`].
    async fn serve_protobuf(payload: Vec<MyTestStructure>) -> TestClient {
        let body_stream = Box::pin(stream::iter(payload));
        let app = Router::new().route("/", get(|| async { StreamBodyAs::protobuf(body_stream) }));
        TestClient::new(app).await
    }

    #[tokio::test]
    async fn deserialize_proto_stream() {
        let expected = generate_test_structures();
        let client = serve_protobuf(expected.clone()).await;

        let decoded: Vec<MyTestStructure> = client
            .get("/")
            .send()
            .await
            .unwrap()
            .protobuf_stream::<MyTestStructure>(1024)
            .try_collect()
            .await
            .unwrap();

        assert_eq!(decoded, expected);
    }

    #[tokio::test]
    async fn deserialize_proto_stream_check_max_len() {
        let client = serve_protobuf(generate_test_structures()).await;

        // Each encoded sample carries two 10-byte strings, so it cannot fit in a 10-byte limit.
        let outcome = client
            .get("/")
            .send()
            .await
            .unwrap()
            .protobuf_stream::<MyTestStructure>(10)
            .try_collect::<Vec<MyTestStructure>>()
            .await;

        outcome.expect_err("MaxLenReachedError");
    }
}
133}