reqwest_streams/
arrow_ipc_stream.rs

1use crate::arrow_ipc_len_codec::ArrowIpcCodec;
2use crate::StreamBodyResult;
3use arrow::array::RecordBatch;
4use async_trait::*;
5use futures::TryStreamExt;
6
7/// Extension trait for [`reqwest::Response`] that provides streaming support for the [Apache Arrow
8/// IPC format].
9///
10/// [Apache Arrow IPC format]: https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc
11#[async_trait]
12pub trait ArrowIpcStreamResponse {
13    fn arrow_ipc_stream<'a>(
14        self,
15        max_obj_len: usize,
16    ) -> impl futures::Stream<Item = StreamBodyResult<RecordBatch>> + Send + 'a;
17}
18
19#[async_trait]
20impl ArrowIpcStreamResponse for reqwest::Response {
21    /// Streams the response as batches of Arrow IPC messages.
22    ///
23    /// The stream will deserialize entries into [`RecordBatch`]es with a maximum object size of
24    /// `max_obj_len` bytes.
25    ///
26    /// # Example
27    ///
28    /// ```rust,no_run
29    /// use arrow::array::RecordBatch;
30    /// use futures::{prelude::*, stream::BoxStream as _};
31    /// use reqwest_streams::ArrowIpcStreamResponse as _;
32    ///
33    /// #[tokio::main]
34    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
35    ///     const MAX_OBJ_LEN: usize = 64 * 1024;
36    ///
37    ///     let stream = reqwest::get("http://localhost:8080/arrow")
38    ///         .await?
39    ///         .arrow_ipc_stream(MAX_OBJ_LEN);
40    ///     let _items: Vec<RecordBatch> = stream.try_collect().await?;
41    ///
42    ///     Ok(())
43    /// }
44    /// ```
45    fn arrow_ipc_stream<'a>(
46        self,
47        max_obj_len: usize,
48    ) -> impl futures::Stream<Item = StreamBodyResult<RecordBatch>>  + Send + 'a {
49        let reader = tokio_util::io::StreamReader::new(
50            self.bytes_stream()
51                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
52        );
53
54        let codec = ArrowIpcCodec::new_with_max_length(max_obj_len);
55        let frames_reader = tokio_util::codec::FramedRead::new(reader, codec);
56
57        frames_reader.into_stream()
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use axum::{routing::*, Router};
    use axum_streams::*;
    use futures::stream;
    use std::sync::Arc;

    /// Schema shared by every test batch: an id, a city name and its coordinates.
    fn generate_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]))
    }

    /// Builds 100 three-row batches whose `id` column depends on the batch index.
    fn generate_test_batches() -> Vec<RecordBatch> {
        // One schema instance is shared by all batches instead of being rebuilt per batch.
        let schema = generate_test_schema();

        (0i64..100i64)
            .map(|idx| {
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(Int64Array::from(vec![idx, idx * 2, idx * 3])),
                        Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
                        Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
                        Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
                    ],
                )
                .unwrap()
            })
            .collect()
    }

    /// Builds a router whose `/` route serves `batches` as an Arrow IPC stream.
    fn arrow_ipc_app(batches: Vec<RecordBatch>) -> Router {
        let test_schema = generate_test_schema();
        let test_stream = Box::pin(stream::iter(batches));

        Router::new().route(
            "/",
            get(|| async { StreamBodyAs::arrow_ipc(test_schema, test_stream) }),
        )
    }

    #[tokio::test]
    async fn deserialize_arrow_ipc_stream() {
        let test_stream_vec = generate_test_batches();

        let client = TestClient::new(arrow_ipc_app(test_stream_vec.clone())).await;

        let res = client.get("/").send().await.unwrap().arrow_ipc_stream(1024);

        let items: Vec<RecordBatch> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    #[tokio::test]
    async fn deserialize_arrow_ipc_stream_check_max_len() {
        let client = TestClient::new(arrow_ipc_app(generate_test_batches())).await;

        // A 10-byte limit is expected to be too small for the served batches, so decoding
        // must fail.
        let res = client.get("/").send().await.unwrap().arrow_ipc_stream(10);
        res.try_collect::<Vec<RecordBatch>>()
            .await
            .expect_err("decoding should fail when an object exceeds max_obj_len");
    }
}