1use crate::stream_body_as::StreamBodyAsOptions;
2use crate::{StreamBodyAs, StreamingFormat};
3use arrow::array::RecordBatch;
4use arrow::datatypes::{Schema, SchemaRef};
5use arrow::error::ArrowError;
6use arrow::ipc::writer::{
7 write_message, CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions,
8};
9use bytes::{BufMut, BytesMut};
10use futures::stream::BoxStream;
11use futures::Stream;
12use futures::StreamExt;
13use http::HeaderMap;
14use std::io::Write;
15use std::sync::Arc;
16
17pub struct ArrowRecordBatchIpcStreamFormat {
18 schema: SchemaRef,
19 options: IpcWriteOptions,
20}
21
22impl ArrowRecordBatchIpcStreamFormat {
23 pub fn new(schema: Arc<Schema>) -> Self {
24 Self::with_options(schema, IpcWriteOptions::default())
25 }
26
27 pub fn with_options(schema: Arc<Schema>, options: IpcWriteOptions) -> Self {
28 Self {
29 schema: schema.clone(),
30 options: options.clone(),
31 }
32 }
33}
34
35impl StreamingFormat<RecordBatch> for ArrowRecordBatchIpcStreamFormat {
36 fn to_bytes_stream<'a, 'b>(
37 &'a self,
38 stream: BoxStream<'b, Result<RecordBatch, axum::Error>>,
39 _: &'a StreamBodyAsOptions,
40 ) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
41 fn write_batch(
42 ipc_data_gen: &mut IpcDataGenerator,
43 dictionary_tracker: &mut DictionaryTracker,
44 compression_context: &mut CompressionContext,
45 write_options: &IpcWriteOptions,
46 batch: &RecordBatch,
47 prepend_schema: Option<Arc<Schema>>,
48 ) -> Result<axum::body::Bytes, ArrowError> {
49 let mut writer = BytesMut::new().writer();
50
51 if let Some(prepend_schema) = prepend_schema {
52 let encoded_message = ipc_data_gen.schema_to_bytes_with_dictionary_tracker(
53 &prepend_schema,
54 dictionary_tracker,
55 write_options,
56 );
57 write_message(&mut writer, encoded_message, write_options)?;
58 }
59
60 let (encoded_dictionaries, encoded_message) = ipc_data_gen.encode(
61 batch,
62 dictionary_tracker,
63 write_options,
64 compression_context,
65 )?;
66
67 for encoded_dictionary in encoded_dictionaries {
68 write_message(&mut writer, encoded_dictionary, write_options)?;
69 }
70
71 write_message(&mut writer, encoded_message, write_options)?;
72 writer.flush()?;
73 Ok(writer.into_inner().freeze())
74 }
75
76 fn write_continuation() -> Result<axum::body::Bytes, ArrowError> {
77 let mut writer = BytesMut::with_capacity(8).writer();
78 const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];
79 const TOTAL_LEN: [u8; 4] = [0; 4];
80 writer.write_all(&CONTINUATION_MARKER)?;
81 writer.write_all(&TOTAL_LEN)?;
82 writer.flush()?;
83 Ok(writer.into_inner().freeze())
84 }
85
86 let batch_schema = self.schema.clone();
87 let batch_options = self.options.clone();
88
89 let ipc_data_gen = IpcDataGenerator::default();
90 let dictionary_tracker: DictionaryTracker = DictionaryTracker::new(false);
91 let compression_context = CompressionContext::default();
92
93 let batch_stream = Box::pin({
94 stream.scan(
95 (ipc_data_gen, dictionary_tracker, compression_context, 0),
96 move |(ipc_data_gen, dictionary_tracker, compression_context, idx), batch_res| {
97 match batch_res {
98 Err(e) => futures::future::ready(Some(Err(e))),
99 Ok(batch) => futures::future::ready({
100 let prepend_schema = if *idx == 0 {
101 Some(batch_schema.clone())
102 } else {
103 None
104 };
105 *idx += 1;
106 let bytes = write_batch(
107 ipc_data_gen,
108 dictionary_tracker,
109 compression_context,
110 &batch_options,
111 &batch,
112 prepend_schema,
113 )
114 .map_err(axum::Error::new);
115 Some(bytes)
116 }),
117 }
118 },
119 )
120 });
121
122 let append_stream: BoxStream<Result<axum::body::Bytes, axum::Error>> =
123 Box::pin(futures::stream::once(futures::future::ready({
124 write_continuation().map_err(axum::Error::new)
125 })));
126
127 Box::pin(batch_stream.chain(append_stream))
128 }
129
130 fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
131 let mut header_map = HeaderMap::new();
132 header_map.insert(
133 http::header::CONTENT_TYPE,
134 options.content_type.clone().unwrap_or_else(|| {
135 http::header::HeaderValue::from_static("application/vnd.apache.arrow.stream")
136 }),
137 );
138 Some(header_map)
139 }
140}
141
142impl<'a> crate::StreamBodyAs<'a> {
143 pub fn arrow_ipc<S>(schema: SchemaRef, stream: S) -> Self
144 where
145 S: Stream<Item = RecordBatch> + 'a + Send,
146 {
147 Self::new(
148 ArrowRecordBatchIpcStreamFormat::new(schema),
149 stream.map(Ok::<RecordBatch, axum::Error>),
150 )
151 }
152
153 pub fn arrow_ipc_with_errors<S, E>(schema: SchemaRef, stream: S) -> Self
154 where
155 S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
156 E: Into<axum::Error>,
157 {
158 Self::new(ArrowRecordBatchIpcStreamFormat::new(schema), stream)
159 }
160
161 pub fn arrow_ipc_with_options<S>(schema: SchemaRef, stream: S, options: IpcWriteOptions) -> Self
162 where
163 S: Stream<Item = RecordBatch> + 'a + Send,
164 {
165 Self::new(
166 ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
167 stream.map(Ok::<RecordBatch, axum::Error>),
168 )
169 }
170
171 pub fn arrow_ipc_with_options_errors<S, E>(
172 schema: SchemaRef,
173 stream: S,
174 options: IpcWriteOptions,
175 ) -> Self
176 where
177 S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
178 E: Into<axum::Error>,
179 {
180 Self::new(
181 ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
182 stream,
183 )
184 }
185}
186
187impl StreamBodyAsOptions {
188 pub fn arrow_ipc<'a, S>(self, schema: SchemaRef, stream: S) -> StreamBodyAs<'a>
189 where
190 S: Stream<Item = RecordBatch> + 'a + Send,
191 {
192 StreamBodyAs::with_options(
193 ArrowRecordBatchIpcStreamFormat::new(schema),
194 stream.map(Ok::<RecordBatch, axum::Error>),
195 self,
196 )
197 }
198
199 pub fn arrow_ipc_with_errors<'a, S, E>(self, schema: SchemaRef, stream: S) -> StreamBodyAs<'a>
200 where
201 S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
202 E: Into<axum::Error>,
203 {
204 StreamBodyAs::with_options(ArrowRecordBatchIpcStreamFormat::new(schema), stream, self)
205 }
206
207 pub fn arrow_ipc_with_options<'a, S>(
208 self,
209 schema: SchemaRef,
210 stream: S,
211 options: IpcWriteOptions,
212 ) -> StreamBodyAs<'a>
213 where
214 S: Stream<Item = RecordBatch> + 'a + Send,
215 {
216 StreamBodyAs::with_options(
217 ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
218 stream.map(Ok::<RecordBatch, axum::Error>),
219 self,
220 )
221 }
222
223 pub fn arrow_ipc_with_options_errors<'a, S, E>(
224 self,
225 schema: SchemaRef,
226 stream: S,
227 options: IpcWriteOptions,
228 ) -> StreamBodyAs<'a>
229 where
230 S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
231 E: Into<axum::Error>,
232 {
233 StreamBodyAs::with_options(
234 ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
235 stream,
236 self,
237 )
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use crate::StreamBodyAs;
    use arrow::array::*;
    use arrow::datatypes::*;
    use axum::{routing::*, Router};
    use futures::stream;
    use std::sync::Arc;

    /// Fixture schema: an id plus a city name and its coordinates.
    fn city_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]))
    }

    /// Ten three-row batches conforming to `schema`; batch `n` carries the ids
    /// `n`, `2n` and `3n`.
    fn city_batches(schema: &SchemaRef) -> Vec<RecordBatch> {
        let mut batches = Vec::with_capacity(10);
        for n in 0i64..10i64 {
            let columns: Vec<ArrayRef> = vec![
                Arc::new(Int64Array::from(vec![n, n * 2, n * 3])),
                Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
                Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
                Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
            ];
            batches.push(RecordBatch::try_new(Arc::clone(schema), columns).unwrap());
        }
        batches
    }

    #[tokio::test]
    async fn serialize_arrow_stream_format() {
        let schema = city_schema();

        // Reference encoding produced by arrow's own IPC StreamWriter.
        let mut reference =
            arrow::ipc::writer::StreamWriter::try_new(Vec::new(), &schema).expect("writer failed");
        city_batches(&schema)
            .iter()
            .for_each(|batch| reference.write(batch).expect("write failed"));
        reference.finish().expect("writer failed");
        let expected_buf = reference.into_inner().expect("writer failed");

        let handler_schema = Arc::clone(&schema);
        let handler_stream = Box::pin(stream::iter(city_batches(&schema)));
        let app = Router::new().route(
            "/",
            get(|| async move {
                StreamBodyAs::new(
                    ArrowRecordBatchIpcStreamFormat::new(handler_schema.clone()),
                    handler_stream.map(Ok::<_, axum::Error>),
                )
            }),
        );

        let client = TestClient::new(app).await;
        let res = client.get("/").send().await.unwrap();

        let content_type = res
            .headers()
            .get("content-type")
            .and_then(|value| value.to_str().ok());
        assert_eq!(content_type, Some("application/vnd.apache.arrow.stream"));

        let body = res.bytes().await.unwrap().to_vec();
        assert_eq!(body.len(), expected_buf.len());
        assert_eq!(body, expected_buf);
    }
}