Skip to main content

axum_streams/
protobuf_format.rs

1use crate::stream_body_as::StreamBodyAsOptions;
2use crate::stream_format::StreamingFormat;
3use crate::StreamBodyAs;
4use futures::stream::BoxStream;
5use futures::Stream;
6use futures::StreamExt;
7use http::HeaderMap;
8
/// Streaming format that writes each item as a length-delimited protobuf message.
///
/// Every record on the wire is a varint-encoded byte length followed by the
/// message's protobuf encoding. A client can therefore split the response body
/// back into individual messages by reading one length-delimited message at a time.
#[derive(Debug, Clone, Default)]
pub struct ProtobufStreamFormat;

impl ProtobufStreamFormat {
    /// Creates a new protobuf stream format.
    ///
    /// Equivalent to [`ProtobufStreamFormat::default`].
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
16
17impl<T> StreamingFormat<T> for ProtobufStreamFormat
18where
19    T: prost::Message + Send + Sync + 'static,
20{
21    fn to_bytes_stream<'a, 'b>(
22        &'a self,
23        stream: BoxStream<'b, Result<T, axum::Error>>,
24        _: &'a StreamBodyAsOptions,
25    ) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
26        fn write_protobuf_record<T>(obj: T) -> Result<Vec<u8>, axum::Error>
27        where
28            T: prost::Message,
29        {
30            let obj_vec = obj.encode_to_vec();
31            let mut frame_vec = Vec::new();
32            let obj_len = (obj_vec.len() as u64);
33            prost::encoding::encode_varint(obj_len, &mut frame_vec);
34            frame_vec.extend(obj_vec);
35
36            Ok(frame_vec)
37        }
38
39        Box::pin({
40            stream.map(move |obj_res| match obj_res {
41                Err(e) => Err(e),
42                Ok(obj) => {
43                    let write_protobuf_res = write_protobuf_record(obj);
44                    write_protobuf_res.map(axum::body::Bytes::from)
45                }
46            })
47        })
48    }
49
50    fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
51        let mut header_map = HeaderMap::new();
52        header_map.insert(
53            http::header::CONTENT_TYPE,
54            options.content_type.clone().unwrap_or_else(|| {
55                http::header::HeaderValue::from_static("application/x-protobuf-stream")
56            }),
57        );
58        Some(header_map)
59    }
60}
61
62impl<'a> StreamBodyAs<'a> {
63    pub fn protobuf<S, T>(stream: S) -> Self
64    where
65        T: prost::Message + Send + Sync + 'static,
66        S: Stream<Item = T> + 'a + Send,
67    {
68        Self::new(
69            ProtobufStreamFormat::new(),
70            stream.map(Ok::<T, axum::Error>),
71        )
72    }
73
74    pub fn protobuf_with_errors<S, T, E>(stream: S) -> Self
75    where
76        T: prost::Message + Send + Sync + 'static,
77        S: Stream<Item = Result<T, E>> + 'a + Send,
78        E: Into<axum::Error>,
79    {
80        Self::new(ProtobufStreamFormat::new(), stream)
81    }
82}
83
84impl StreamBodyAsOptions {
85    pub fn protobuf<'a, S, T>(self, stream: S) -> StreamBodyAs<'a>
86    where
87        T: prost::Message + Send + Sync + 'static,
88        S: Stream<Item = T> + 'a + Send,
89    {
90        StreamBodyAs::with_options(
91            ProtobufStreamFormat::new(),
92            stream.map(Ok::<T, axum::Error>),
93            self,
94        )
95    }
96
97    pub fn protobuf_with_errors<'a, S, T, E>(self, stream: S) -> StreamBodyAs<'a>
98    where
99        T: prost::Message + Send + Sync + 'static,
100        S: Stream<Item = Result<T, E>> + 'a + Send,
101        E: Into<axum::Error>,
102    {
103        StreamBodyAs::with_options(ProtobufStreamFormat::new(), stream, self)
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use crate::StreamBodyAs;
    use axum::{routing::*, Router};
    use futures::stream;
    use prost::Message;

    #[tokio::test]
    async fn serialize_protobuf_stream_format() {
        // `PartialEq` is needed to compare decoded records. `prost::Message`
        // already provides the `Debug` and `Default` impls.
        #[derive(Clone, PartialEq, prost::Message)]
        struct TestOutputStructure {
            #[prost(string, tag = "1")]
            foo1: String,
            #[prost(string, tag = "2")]
            foo2: String,
        }

        let test_stream_vec = vec![
            TestOutputStructure {
                foo1: "bar1".to_string(),
                foo2: "bar2".to_string()
            };
            7
        ];

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route(
            "/",
            get(|| async {
                StreamBodyAs::new(
                    ProtobufStreamFormat::new(),
                    test_stream.map(Ok::<_, axum::Error>),
                )
            }),
        );

        let client = TestClient::new(app).await;

        // Build the expected frames by hand (varint length + message bytes) so
        // the test checks the wire format independently of the implementation.
        let expected_proto_buf: Vec<u8> = test_stream_vec
            .iter()
            .flat_map(|obj| {
                let obj_vec = obj.encode_to_vec();
                let mut frame_vec = Vec::new();
                let obj_len = obj_vec.len() as u64;
                prost::encoding::encode_varint(obj_len, &mut frame_vec);
                frame_vec.extend(obj_vec);
                frame_vec
            })
            .collect();

        let res = client.get("/").send().await.unwrap();
        assert_eq!(
            res.headers()
                .get("content-type")
                .and_then(|h| h.to_str().ok()),
            Some("application/x-protobuf-stream")
        );
        let body = res.bytes().await.unwrap().to_vec();

        assert_eq!(body, expected_proto_buf);

        // A consumer must be able to split the body back into the original
        // messages using standard length-delimited decoding.
        let mut remaining = body.as_slice();
        let mut decoded = Vec::new();
        while !remaining.is_empty() {
            decoded.push(TestOutputStructure::decode_length_delimited(&mut remaining).unwrap());
        }
        assert_eq!(decoded, test_stream_vec);
    }
}