use crate::stream_body_as::StreamBodyAsOptions;
use crate::{StreamBodyAs, StreamingFormat};
use arrow::array::RecordBatch;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::ipc::writer::{
write_message, CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions,
};
use bytes::{BufMut, BytesMut};
use futures::stream::BoxStream;
use futures::Stream;
use futures::StreamExt;
use http::HeaderMap;
use std::io::Write;
use std::sync::Arc;
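/// `StreamingFormat` implementation that serializes a stream of Arrow
/// `RecordBatch`es into the Arrow IPC streaming format: the schema message is
/// written before the first batch, each batch is written together with any
/// required dictionary batches, and the end-of-stream marker is appended once
/// the input stream is exhausted.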
pub struct ArrowRecordBatchIpcStreamFormat {
schema: SchemaRef,
options: IpcWriteOptions,
}
impl ArrowRecordBatchIpcStreamFormat {
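/// Creates the format with the given schema and default `IpcWriteOptions`.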
pub fn new(schema: Arc<Schema>) -> Self {
Self::with_options(schema, IpcWriteOptions::default())
}
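/// Creates the format with the given schema and explicit `IpcWriteOptions`.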
pub fn with_options(schema: Arc<Schema>, options: IpcWriteOptions) -> Self {
Self { schema, options }
}
}
impl StreamingFormat<RecordBatch> for ArrowRecordBatchIpcStreamFormat {
fn to_bytes_stream<'a, 'b>(
&'a self,
stream: BoxStream<'b, Result<RecordBatch, axum::Error>>,
_: &'a StreamBodyAsOptions,
) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
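// Serializes one `RecordBatch` into IPC stream framing: optionally the schema
// message first (only for the first batch), then any dictionary batches, then
// the record batch message itself.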
fn write_batch(
ipc_data_gen: &mut IpcDataGenerator,
dictionary_tracker: &mut DictionaryTracker,
compression_context: &mut CompressionContext,
write_options: &IpcWriteOptions,
batch: &RecordBatch,
prepend_schema: Option<Arc<Schema>>,
) -> Result<axum::body::Bytes, ArrowError> {
let mut writer = BytesMut::new().writer();
if let Some(prepend_schema) = prepend_schema {
let encoded_message = ipc_data_gen.schema_to_bytes_with_dictionary_tracker(
&prepend_schema,
dictionary_tracker,
write_options,
);
write_message(&mut writer, encoded_message, write_options)?;
}
let (encoded_dictionaries, encoded_message) = ipc_data_gen.encode(
batch,
dictionary_tracker,
write_options,
compression_context,
)?;
for encoded_dictionary in encoded_dictionaries {
write_message(&mut writer, encoded_dictionary, write_options)?;
}
write_message(&mut writer, encoded_message, write_options)?;
writer.flush()?;
Ok(writer.into_inner().freeze())
}
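// Emits the 8-byte end-of-stream marker of the IPC streaming format:
// the continuation marker (0xFFFFFFFF) followed by a zero metadata length.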
fn write_continuation() -> Result<axum::body::Bytes, ArrowError> {
let mut writer = BytesMut::with_capacity(8).writer();
const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];
const TOTAL_LEN: [u8; 4] = [0; 4];
writer.write_all(&CONTINUATION_MARKER)?;
writer.write_all(&TOTAL_LEN)?;
writer.flush()?;
Ok(writer.into_inner().freeze())
}
let batch_schema = self.schema.clone();
let batch_options = self.options.clone();
let ipc_data_gen = IpcDataGenerator::default();
// `false`: allow dictionaries to be replaced mid-stream instead of raising an error.
let dictionary_tracker = DictionaryTracker::new(false);
let compression_context = CompressionContext::default();
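// Encode batches lazily with `scan`, carrying the IPC encoder state across
// items; the schema message is prepended only to the first encoded batch.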
let batch_stream = Box::pin({
stream.scan(
(ipc_data_gen, dictionary_tracker, compression_context, 0),
move |(ipc_data_gen, dictionary_tracker, compression_context, idx), batch_res| {
match batch_res {
Err(e) => futures::future::ready(Some(Err(e))),
Ok(batch) => futures::future::ready({
let prepend_schema = if *idx == 0 {
Some(batch_schema.clone())
} else {
None
};
*idx += 1;
let bytes = write_batch(
ipc_data_gen,
dictionary_tracker,
compression_context,
&batch_options,
&batch,
prepend_schema,
)
.map_err(axum::Error::new);
Some(bytes)
}),
}
},
)
});
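// After all batches have been streamed, append the end-of-stream marker.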
let append_stream: BoxStream<Result<axum::body::Bytes, axum::Error>> =
Box::pin(futures::stream::once(futures::future::ready({
write_continuation().map_err(axum::Error::new)
})));
Box::pin(batch_stream.chain(append_stream))
}
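// Defaults the Content-Type to `application/vnd.apache.arrow.stream` unless
// overridden via `StreamBodyAsOptions`.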
fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
let mut header_map = HeaderMap::new();
header_map.insert(
http::header::CONTENT_TYPE,
options.content_type.clone().unwrap_or_else(|| {
http::header::HeaderValue::from_static("application/vnd.apache.arrow.stream")
}),
);
Some(header_map)
}
}
impl<'a> crate::StreamBodyAs<'a> {
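/// Creates a streaming HTTP response body that emits the given `RecordBatch`es
/// in the Arrow IPC streaming format with default `IpcWriteOptions`.
///
/// A minimal handler sketch, assuming the crate is consumed under the
/// `axum_streams` name (adjust imports to your setup):
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow::array::{Int64Array, RecordBatch};
/// use arrow::datatypes::{DataType, Field, Schema};
/// use axum_streams::StreamBodyAs; // assumption: crate consumed under this name
/// use futures::stream;
///
/// async fn handler() -> impl axum::response::IntoResponse {
///     // Schema shared by every batch in the stream.
///     let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
///     let batch = RecordBatch::try_new(
///         schema.clone(),
///         vec![Arc::new(Int64Array::from(vec![1_i64, 2, 3]))],
///     )
///     .unwrap();
///     // Stream the batches as an Arrow IPC response body.
///     StreamBodyAs::arrow_ipc(schema, stream::iter(vec![batch]))
/// }
/// ```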
pub fn arrow_ipc<S>(schema: SchemaRef, stream: S) -> Self
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
Self::new(
ArrowRecordBatchIpcStreamFormat::new(schema),
stream.map(Ok::<RecordBatch, axum::Error>),
)
}
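/// Like [`Self::arrow_ipc`], but accepts a fallible stream whose errors are
/// converted into `axum::Error` and surfaced through the response body stream.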
pub fn arrow_ipc_with_errors<S, E>(schema: SchemaRef, stream: S) -> Self
where
S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
E: Into<axum::Error>,
{
Self::new(ArrowRecordBatchIpcStreamFormat::new(schema), stream)
}
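/// Like [`Self::arrow_ipc`], but with explicit `IpcWriteOptions`.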
pub fn arrow_ipc_with_options<S>(schema: SchemaRef, stream: S, options: IpcWriteOptions) -> Self
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
Self::new(
ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
stream.map(Ok::<RecordBatch, axum::Error>),
)
}
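/// Like [`Self::arrow_ipc_with_errors`], but with explicit `IpcWriteOptions`.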
pub fn arrow_ipc_with_options_errors<S, E>(
schema: SchemaRef,
stream: S,
options: IpcWriteOptions,
) -> Self
where
S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
E: Into<axum::Error>,
{
Self::new(
ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
stream,
)
}
}
impl StreamBodyAsOptions {
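/// Same as [`StreamBodyAs::arrow_ipc`], but built from pre-configured
/// `StreamBodyAsOptions` (e.g. an overridden content type).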
pub fn arrow_ipc<'a, S>(self, schema: SchemaRef, stream: S) -> StreamBodyAs<'a>
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
StreamBodyAs::with_options(
ArrowRecordBatchIpcStreamFormat::new(schema),
stream.map(Ok::<RecordBatch, axum::Error>),
self,
)
}
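/// Same as [`StreamBodyAs::arrow_ipc_with_errors`], but built from these options.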
pub fn arrow_ipc_with_errors<'a, S, E>(self, schema: SchemaRef, stream: S) -> StreamBodyAs<'a>
where
S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
E: Into<axum::Error>,
{
StreamBodyAs::with_options(ArrowRecordBatchIpcStreamFormat::new(schema), stream, self)
}
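/// Same as [`StreamBodyAs::arrow_ipc_with_options`], but built from these options.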
pub fn arrow_ipc_with_options<'a, S>(
self,
schema: SchemaRef,
stream: S,
options: IpcWriteOptions,
) -> StreamBodyAs<'a>
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
StreamBodyAs::with_options(
ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
stream.map(Ok::<RecordBatch, axum::Error>),
self,
)
}
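/// Same as [`StreamBodyAs::arrow_ipc_with_options_errors`], but built from these options.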
pub fn arrow_ipc_with_options_errors<'a, S, E>(
self,
schema: SchemaRef,
stream: S,
options: IpcWriteOptions,
) -> StreamBodyAs<'a>
where
S: Stream<Item = Result<RecordBatch, E>> + 'a + Send,
E: Into<axum::Error>,
{
StreamBodyAs::with_options(
ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
stream,
self,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_client::*;
use crate::StreamBodyAs;
use arrow::array::*;
use arrow::datatypes::*;
use axum::{routing::*, Router};
use futures::stream;
use std::sync::Arc;
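// Round-trip check: the streamed response body must be byte-identical to what
// arrow's `StreamWriter` produces for the same schema and batches.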
#[tokio::test]
async fn serialize_arrow_stream_format() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("city", DataType::Utf8, false),
Field::new("lat", DataType::Float64, false),
Field::new("lng", DataType::Float64, false),
]));
fn create_test_batch(schema_ref: SchemaRef) -> Vec<RecordBatch> {
let vec_schema = schema_ref.clone();
(0i64..10i64)
.map(move |idx| {
RecordBatch::try_new(
vec_schema.clone(),
vec![
Arc::new(Int64Array::from(vec![idx, idx * 2, idx * 3])),
Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
],
)
.unwrap()
})
.collect()
}
let test_stream = Box::pin(stream::iter(create_test_batch(schema.clone())));
let app_schema = schema.clone();
let app = Router::new().route(
"/",
get(|| async move {
StreamBodyAs::new(
ArrowRecordBatchIpcStreamFormat::new(app_schema.clone()),
test_stream.map(Ok::<_, axum::Error>),
)
}),
);
let client = TestClient::new(app).await;
let mut writer =
arrow::ipc::writer::StreamWriter::try_new(Vec::new(), &schema).expect("writer failed");
for batch in create_test_batch(schema.clone()) {
writer.write(&batch).expect("write failed");
}
writer.finish().expect("writer failed");
let expected_buf = writer.into_inner().expect("writer failed");
let res = client.get("/").send().await.unwrap();
assert_eq!(
res.headers()
.get("content-type")
.and_then(|h| h.to_str().ok()),
Some("application/vnd.apache.arrow.stream")
);
let body = res.bytes().await.unwrap().to_vec();
assert_eq!(body.len(), expected_buf.len());
assert_eq!(body, expected_buf);
}
}