Skip to main content

axum_streams/
arrow_format.rs

1use crate::stream_body_as::StreamBodyAsOptions;
2use crate::{StreamBodyAs, StreamingFormat};
3use arrow::array::RecordBatch;
4use arrow::datatypes::{Schema, SchemaRef};
5use arrow::error::ArrowError;
6use arrow::ipc::writer::{
7    write_message, CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions,
8};
9use bytes::{BufMut, BytesMut};
10use futures::stream::BoxStream;
11use futures::Stream;
12use futures::StreamExt;
13use http::HeaderMap;
14use std::io::Write;
15use std::sync::Arc;
16
17pub struct ArrowRecordBatchIpcStreamFormat {
18    schema: SchemaRef,
19    options: IpcWriteOptions,
20}
21
22impl ArrowRecordBatchIpcStreamFormat {
23    pub fn new(schema: Arc<Schema>) -> Self {
24        Self::with_options(schema, IpcWriteOptions::default())
25    }
26
27    pub fn with_options(schema: Arc<Schema>, options: IpcWriteOptions) -> Self {
28        Self {
29            schema: schema.clone(),
30            options: options.clone(),
31        }
32    }
33}
34
35impl StreamingFormat<RecordBatch> for ArrowRecordBatchIpcStreamFormat {
36    fn to_bytes_stream<'a, 'b>(
37        &'a self,
38        stream: BoxStream<'b, Result<RecordBatch, axum::Error>>,
39        _: &'a StreamBodyAsOptions,
40    ) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
41        fn write_batch(
42            ipc_data_gen: &mut IpcDataGenerator,
43            dictionary_tracker: &mut DictionaryTracker,
44            compression_context: &mut CompressionContext,
45            write_options: &IpcWriteOptions,
46            batch: &RecordBatch,
47            prepend_schema: Option<Arc<Schema>>,
48        ) -> Result<axum::body::Bytes, ArrowError> {
49            let mut writer = BytesMut::new().writer();
50
51            if let Some(prepend_schema) = prepend_schema {
52                let encoded_message = ipc_data_gen.schema_to_bytes_with_dictionary_tracker(
53                    &prepend_schema,
54                    dictionary_tracker,
55                    write_options,
56                );
57                write_message(&mut writer, encoded_message, write_options)?;
58            }
59
60            let (encoded_dictionaries, encoded_message) = ipc_data_gen.encode(
61                batch,
62                dictionary_tracker,
63                write_options,
64                compression_context,
65            )?;
66
67            for encoded_dictionary in encoded_dictionaries {
68                write_message(&mut writer, encoded_dictionary, write_options)?;
69            }
70
71            write_message(&mut writer, encoded_message, write_options)?;
72            writer.flush()?;
73            Ok(writer.into_inner().freeze())
74        }
75
76        fn write_continuation() -> Result<axum::body::Bytes, ArrowError> {
77            let mut writer = BytesMut::with_capacity(8).writer();
78            const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];
79            const TOTAL_LEN: [u8; 4] = [0; 4];
80            writer.write_all(&CONTINUATION_MARKER)?;
81            writer.write_all(&TOTAL_LEN)?;
82            writer.flush()?;
83            Ok(writer.into_inner().freeze())
84        }
85
86        let batch_schema = self.schema.clone();
87        let batch_options = self.options.clone();
88
89        let ipc_data_gen = IpcDataGenerator::default();
90        let dictionary_tracker: DictionaryTracker = DictionaryTracker::new(false);
91        let compression_context = CompressionContext::default();
92
93        let batch_stream = Box::pin({
94            stream.scan(
95                (ipc_data_gen, dictionary_tracker, compression_context, 0),
96                move |(ipc_data_gen, dictionary_tracker, compression_context, idx), batch_res| {
97                    match batch_res {
98                        Err(e) => futures::future::ready(Some(Err(e))),
99                        Ok(batch) => futures::future::ready({
100                            let prepend_schema = if *idx == 0 {
101                                Some(batch_schema.clone())
102                            } else {
103                                None
104                            };
105                            *idx += 1;
106                            let bytes = write_batch(
107                                ipc_data_gen,
108                                dictionary_tracker,
109                                compression_context,
110                                &batch_options,
111                                &batch,
112                                prepend_schema,
113                            )
114                            .map_err(axum::Error::new);
115                            Some(bytes)
116                        }),
117                    }
118                },
119            )
120        });
121
122        let append_stream: BoxStream<Result<axum::body::Bytes, axum::Error>> =
123            Box::pin(futures::stream::once(futures::future::ready({
124                write_continuation().map_err(axum::Error::new)
125            })));
126
127        Box::pin(batch_stream.chain(append_stream))
128    }
129
130    fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
131        let mut header_map = HeaderMap::new();
132        header_map.insert(
133            http::header::CONTENT_TYPE,
134            options.content_type.clone().unwrap_or_else(|| {
135                http::header::HeaderValue::from_static("application/vnd.apache.arrow.stream")
136            }),
137        );
138        Some(header_map)
139    }
140}
141
142impl<'a> crate::StreamBodyAs<'a> {
143    pub fn arrow_ipc<S>(schema: SchemaRef, stream: S) -> Self
144    where
145        S: Stream<Item = RecordBatch> + 'a + Send,
146    {
147        Self::new(
148            ArrowRecordBatchIpcStreamFormat::new(schema),
149            stream.map(Ok::<RecordBatch, axum::Error>),
150        )
151    }
152
153    pub fn arrow_ipc_with_errors<S, E>(schema: SchemaRef, stream: S) -> Self
154    where
155        S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
156        E: Into<axum::Error>,
157    {
158        Self::new(ArrowRecordBatchIpcStreamFormat::new(schema), stream)
159    }
160
161    pub fn arrow_ipc_with_options<S>(schema: SchemaRef, stream: S, options: IpcWriteOptions) -> Self
162    where
163        S: Stream<Item = RecordBatch> + 'a + Send,
164    {
165        Self::new(
166            ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
167            stream.map(Ok::<RecordBatch, axum::Error>),
168        )
169    }
170
171    pub fn arrow_ipc_with_options_errors<S, E>(
172        schema: SchemaRef,
173        stream: S,
174        options: IpcWriteOptions,
175    ) -> Self
176    where
177        S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
178        E: Into<axum::Error>,
179    {
180        Self::new(
181            ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
182            stream,
183        )
184    }
185}
186
187impl StreamBodyAsOptions {
188    pub fn arrow_ipc<'a, S>(self, schema: SchemaRef, stream: S) -> StreamBodyAs<'a>
189    where
190        S: Stream<Item = RecordBatch> + 'a + Send,
191    {
192        StreamBodyAs::with_options(
193            ArrowRecordBatchIpcStreamFormat::new(schema),
194            stream.map(Ok::<RecordBatch, axum::Error>),
195            self,
196        )
197    }
198
199    pub fn arrow_ipc_with_errors<'a, S, E>(self, schema: SchemaRef, stream: S) -> StreamBodyAs<'a>
200    where
201        S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
202        E: Into<axum::Error>,
203    {
204        StreamBodyAs::with_options(ArrowRecordBatchIpcStreamFormat::new(schema), stream, self)
205    }
206
207    pub fn arrow_ipc_with_options<'a, S>(
208        self,
209        schema: SchemaRef,
210        stream: S,
211        options: IpcWriteOptions,
212    ) -> StreamBodyAs<'a>
213    where
214        S: Stream<Item = RecordBatch> + 'a + Send,
215    {
216        StreamBodyAs::with_options(
217            ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
218            stream.map(Ok::<RecordBatch, axum::Error>),
219            self,
220        )
221    }
222
223    pub fn arrow_ipc_with_options_errors<'a, S, E>(
224        self,
225        schema: SchemaRef,
226        stream: S,
227        options: IpcWriteOptions,
228    ) -> StreamBodyAs<'a>
229    where
230        S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
231        E: Into<axum::Error>,
232    {
233        StreamBodyAs::with_options(
234            ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
235            stream,
236            self,
237        )
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use crate::StreamBodyAs;
    use arrow::array::*;
    use arrow::datatypes::*;
    use axum::{routing::*, Router};
    use futures::stream;
    use std::sync::Arc;

    /// Builds ten three-row batches for `schema`; only the `id` column varies
    /// between batches.
    fn sample_batches(schema: &SchemaRef) -> Vec<RecordBatch> {
        let mut batches = Vec::with_capacity(10);
        for idx in 0i64..10i64 {
            let batch = RecordBatch::try_new(
                Arc::clone(schema),
                vec![
                    Arc::new(Int64Array::from(vec![idx, idx * 2, idx * 3])),
                    Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
                    Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
                    Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
                ],
            )
            .unwrap();
            batches.push(batch);
        }
        batches
    }

    /// The HTTP body must be byte-for-byte what arrow's own `StreamWriter`
    /// produces for the same batches, and carry the Arrow stream content type.
    #[tokio::test]
    async fn serialize_arrow_stream_format() {
        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]));

        // Server side: stream the sample batches through the format under test.
        let served_batches = Box::pin(stream::iter(sample_batches(&schema)));
        let served_schema = Arc::clone(&schema);
        let app = Router::new().route(
            "/",
            get(|| async move {
                StreamBodyAs::new(
                    ArrowRecordBatchIpcStreamFormat::new(served_schema),
                    served_batches.map(Ok::<_, axum::Error>),
                )
            }),
        );
        let client = TestClient::new(app).await;

        // Reference bytes produced by arrow's StreamWriter.
        let mut reference_writer =
            arrow::ipc::writer::StreamWriter::try_new(Vec::new(), &schema).expect("writer failed");
        for batch in sample_batches(&schema) {
            reference_writer.write(&batch).expect("write failed");
        }
        reference_writer.finish().expect("writer failed");
        let expected_buf = reference_writer.into_inner().expect("writer failed");

        let res = client.get("/").send().await.unwrap();
        let content_type = res
            .headers()
            .get("content-type")
            .and_then(|h| h.to_str().ok());
        assert_eq!(content_type, Some("application/vnd.apache.arrow.stream"));

        let body = res.bytes().await.unwrap().to_vec();

        // Length is compared first so a size mismatch fails with a short message.
        assert_eq!(body.len(), expected_buf.len());
        assert_eq!(body, expected_buf);
    }
}