axum_streams/
protobuf_format.rs1use crate::stream_body_as::StreamBodyAsOptions;
2use crate::stream_format::StreamingFormat;
3use crate::StreamBodyAs;
4use futures::stream::BoxStream;
5use futures::Stream;
6use futures::StreamExt;
7use http::HeaderMap;
8
/// Streaming format that writes each item as a length-delimited protobuf frame:
/// a base-128 varint holding the encoded message length, followed by the encoded
/// message bytes.
///
/// The type carries no state, so it is cheap to copy and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ProtobufStreamFormat;

impl ProtobufStreamFormat {
    /// Creates the (stateless) protobuf stream format.
    pub fn new() -> Self {
        Self
    }
}

17impl<T> StreamingFormat<T> for ProtobufStreamFormat
18where
19 T: prost::Message + Send + Sync + 'static,
20{
21 fn to_bytes_stream<'a, 'b>(
22 &'a self,
23 stream: BoxStream<'b, Result<T, axum::Error>>,
24 _: &'a StreamBodyAsOptions,
25 ) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
26 fn write_protobuf_record<T>(obj: T) -> Result<Vec<u8>, axum::Error>
27 where
28 T: prost::Message,
29 {
30 let obj_vec = obj.encode_to_vec();
31 let mut frame_vec = Vec::new();
32 let obj_len = (obj_vec.len() as u64);
33 prost::encoding::encode_varint(obj_len, &mut frame_vec);
34 frame_vec.extend(obj_vec);
35
36 Ok(frame_vec)
37 }
38
39 Box::pin({
40 stream.map(move |obj_res| match obj_res {
41 Err(e) => Err(e),
42 Ok(obj) => {
43 let write_protobuf_res = write_protobuf_record(obj);
44 write_protobuf_res.map(axum::body::Bytes::from)
45 }
46 })
47 })
48 }
49
50 fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
51 let mut header_map = HeaderMap::new();
52 header_map.insert(
53 http::header::CONTENT_TYPE,
54 options.content_type.clone().unwrap_or_else(|| {
55 http::header::HeaderValue::from_static("application/x-protobuf-stream")
56 }),
57 );
58 Some(header_map)
59 }
60}
61
62impl<'a> StreamBodyAs<'a> {
63 pub fn protobuf<S, T>(stream: S) -> Self
64 where
65 T: prost::Message + Send + Sync + 'static,
66 S: Stream<Item = T> + 'a + Send,
67 {
68 Self::new(
69 ProtobufStreamFormat::new(),
70 stream.map(Ok::<T, axum::Error>),
71 )
72 }
73
74 pub fn protobuf_with_errors<S, T, E>(stream: S) -> Self
75 where
76 T: prost::Message + Send + Sync + 'static,
77 S: Stream<Item = Result<T, E>> + 'a + Send,
78 E: Into<axum::Error>,
79 {
80 Self::new(ProtobufStreamFormat::new(), stream)
81 }
82}
83
84impl StreamBodyAsOptions {
85 pub fn protobuf<'a, S, T>(self, stream: S) -> StreamBodyAs<'a>
86 where
87 T: prost::Message + Send + Sync + 'static,
88 S: Stream<Item = T> + 'a + Send,
89 {
90 StreamBodyAs::with_options(
91 ProtobufStreamFormat::new(),
92 stream.map(Ok::<T, axum::Error>),
93 self,
94 )
95 }
96
97 pub fn protobuf_with_errors<'a, S, T, E>(self, stream: S) -> StreamBodyAs<'a>
98 where
99 T: prost::Message + Send + Sync + 'static,
100 S: Stream<Item = Result<T, E>> + 'a + Send,
101 E: Into<axum::Error>,
102 {
103 StreamBodyAs::with_options(ProtobufStreamFormat::new(), stream, self)
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use crate::StreamBodyAs;
    use axum::{routing::*, Router};
    use futures::stream;
    use prost::Message;

    /// Checks that `ProtobufStreamFormat` emits each message as a varint length prefix
    /// followed by the encoded message, and sets the default `Content-Type`.
    #[tokio::test]
    async fn serialize_protobuf_stream_format() {
        #[derive(Clone, prost::Message)]
        struct TestOutputStructure {
            #[prost(string, tag = "1")]
            foo1: String,
            #[prost(string, tag = "2")]
            foo2: String,
        }

        let test_stream_vec = vec![
            TestOutputStructure {
                foo1: "bar1".to_string(),
                foo2: "bar2".to_string(),
            };
            7
        ];

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route(
            "/",
            get(|| async {
                StreamBodyAs::new(
                    ProtobufStreamFormat::new(),
                    test_stream.map(Ok::<_, axum::Error>),
                )
            }),
        );

        let client = TestClient::new(app).await;

        // Build the expected body by hand, independently of the implementation:
        // a varint length prefix followed by the encoded payload, once per message.
        let expected_proto_buf: Vec<u8> = test_stream_vec
            .iter()
            .flat_map(|obj| {
                let obj_vec = obj.encode_to_vec();
                let mut frame_vec = Vec::new();
                prost::encoding::encode_varint(obj_vec.len() as u64, &mut frame_vec);
                frame_vec.extend(obj_vec);
                frame_vec
            })
            .collect();

        let res = client.get("/").send().await.unwrap();
        assert_eq!(
            res.headers()
                .get("content-type")
                .and_then(|h| h.to_str().ok()),
            Some("application/x-protobuf-stream")
        );
        let body = res.bytes().await.unwrap().to_vec();

        assert_eq!(body, expected_proto_buf);
    }
}