// reqwest_streams/protobuf_stream.rs

use std::io;

use async_trait::*;
use futures::TryStreamExt;
use tokio_util::io::StreamReader;

use crate::protobuf_len_codec::ProtobufLenPrefixCodec;
use crate::StreamBodyResult;

8#[async_trait]
13pub trait ProtobufStreamResponse {
14 fn protobuf_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
44 where
45 T: prost::Message + Default + Send + 'b;
46}
47
48#[async_trait]
49impl ProtobufStreamResponse for reqwest::Response {
50 fn protobuf_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
51 where
52 T: prost::Message + Default + Send + 'b,
53 {
54 let reader = StreamReader::new(
55 self.bytes_stream()
56 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
57 );
58
59 let codec = ProtobufLenPrefixCodec::<T>::new_with_max_length(max_obj_len);
60 let frames_reader = tokio_util::codec::FramedRead::new(reader, codec);
61
62 frames_reader.into_stream()
63 }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use axum::{routing::*, Router};
    use axum_streams::*;
    use futures::stream;

    /// Small protobuf message used as the streamed payload in these tests.
    #[derive(Clone, prost::Message, PartialEq, Eq)]
    struct MyTestStructure {
        #[prost(string, tag = "1")]
        some_test_field1: String,
        #[prost(string, tag = "2")]
        some_test_field2: String,
    }

    /// Returns 100 identical test messages.
    fn generate_test_structures() -> Vec<MyTestStructure> {
        vec![
            MyTestStructure {
                some_test_field1: "TestValue1".to_string(),
                some_test_field2: "TestValue2".to_string(),
            };
            100
        ]
    }

    /// Serves the messages through an axum protobuf stream and checks that
    /// `protobuf_stream` decodes exactly the same sequence.
    #[tokio::test]
    async fn deserialize_proto_stream() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::protobuf(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .protobuf_stream::<MyTestStructure>(1024);
        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    /// A `max_obj_len` smaller than a single encoded message must make the
    /// stream yield an error.
    #[tokio::test]
    async fn deserialize_proto_stream_check_max_len() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::protobuf(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .protobuf_stream::<MyTestStructure>(10);
        res.try_collect::<Vec<MyTestStructure>>()
            .await
            .expect_err("MaxLenReachedError");
    }
}