reqwest_streams/
arrow_ipc_stream.rs1use crate::arrow_ipc_len_codec::ArrowIpcCodec;
2use crate::StreamBodyResult;
3use arrow::array::RecordBatch;
4use async_trait::*;
5use futures::TryStreamExt;
6
7#[async_trait]
12pub trait ArrowIpcStreamResponse {
13 fn arrow_ipc_stream<'a>(
14 self,
15 max_obj_len: usize,
16 ) -> impl futures::Stream<Item = StreamBodyResult<RecordBatch>> + Send + 'a;
17}
18
19#[async_trait]
20impl ArrowIpcStreamResponse for reqwest::Response {
21 fn arrow_ipc_stream<'a>(
46 self,
47 max_obj_len: usize,
48 ) -> impl futures::Stream<Item = StreamBodyResult<RecordBatch>> + Send + 'a {
49 let reader = tokio_util::io::StreamReader::new(
50 self.bytes_stream()
51 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
52 );
53
54 let codec = ArrowIpcCodec::new_with_max_length(max_obj_len);
55 let frames_reader = tokio_util::codec::FramedRead::new(reader, codec);
56
57 frames_reader.into_stream()
58 }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use axum::{routing::*, Router};
    use axum_streams::*;
    use futures::stream;
    use std::sync::Arc;

    /// Schema used by every test batch: non-nullable `id`, `city`, `lat`, `lng`.
    fn generate_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]))
    }

    /// Builds 100 three-row batches that all share one schema instance.
    fn generate_test_batches() -> Vec<RecordBatch> {
        // Build the schema once and share it by refcount instead of
        // rebuilding the field list for every batch.
        let schema = generate_test_schema();
        (0i64..100i64)
            .map(|idx| {
                RecordBatch::try_new(
                    Arc::clone(&schema),
                    vec![
                        Arc::new(Int64Array::from(vec![idx, idx * 2, idx * 3])),
                        Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
                        Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
                        Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
                    ],
                )
                .expect("test columns must match the test schema")
            })
            .collect()
    }

    /// Round-trips the batches through an axum endpoint and checks the client
    /// decodes exactly what was sent.
    #[tokio::test]
    async fn deserialize_arrow_ipc_stream() {
        let test_stream_vec = generate_test_batches();

        let test_schema = generate_test_schema();
        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route(
            "/",
            get(|| async { StreamBodyAs::arrow_ipc(test_schema, test_stream) }),
        );

        let client = TestClient::new(app).await;

        let res = client.get("/").send().await.unwrap().arrow_ipc_stream(1024);

        let items: Vec<RecordBatch> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    /// With a limit far below the encoded size of these batches, decoding is
    /// expected to yield an error item instead of data.
    #[tokio::test]
    async fn deserialize_arrow_ipc_stream_check_max_len() {
        let test_stream_vec = generate_test_batches();

        let test_schema = generate_test_schema();
        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route(
            "/",
            get(|| async { StreamBodyAs::arrow_ipc(test_schema, test_stream) }),
        );

        let client = TestClient::new(app).await;

        let res = client.get("/").send().await.unwrap().arrow_ipc_stream(10);
        // The `expect_err` message is only printed if the stream unexpectedly succeeds.
        res.try_collect::<Vec<RecordBatch>>()
            .await
            .expect_err("MaxLenReachedError");
    }
}