use crate::error::{FlightError, Result};
use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, FlightData};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_buffer::Buffer;
use arrow_schema::{Schema, SchemaRef};
use bytes::Bytes;
use futures::{ready, stream::BoxStream, Stream, StreamExt};
use std::{collections::HashMap, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
use tonic::metadata::MetadataMap;
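/// Decodes a stream of [`FlightData`] into a stream of [`RecordBatch`]es,
/// handling `Schema` and `DictionaryBatch` messages internally and exposing
/// any gRPC response headers and (lazily available) trailers.
///
/// A minimal usage sketch, assuming this module is exported as
/// `arrow_flight::decode` and that `flight_data_stream` is a
/// `Stream<Item = Result<FlightData>>` obtained elsewhere (for example from a
/// Flight `do_get` response):
///
/// ```no_run
/// # use arrow_flight::decode::FlightRecordBatchStream;
/// # use futures::TryStreamExt;
/// # async fn example(
/// #     flight_data_stream: impl futures::Stream<
/// #         Item = arrow_flight::error::Result<arrow_flight::FlightData>,
/// #     > + Send + 'static,
/// # ) -> arrow_flight::error::Result<()> {
/// let mut batches = FlightRecordBatchStream::new_from_flight_data(flight_data_stream);
/// while let Some(batch) = batches.try_next().await? {
///     println!("decoded a batch with {} rows", batch.num_rows());
/// }
/// # Ok(())
/// # }
/// ```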
#[derive(Debug)]
pub struct FlightRecordBatchStream {
headers: MetadataMap,
trailers: Option<LazyTrailers>,
inner: FlightDataDecoder,
}
impl FlightRecordBatchStream {
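    /// Creates a [`FlightRecordBatchStream`] from an existing
    /// [`FlightDataDecoder`], with no headers or trailers attached.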
pub fn new(inner: FlightDataDecoder) -> Self {
Self {
inner,
headers: MetadataMap::default(),
trailers: None,
}
}
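    /// Creates a [`FlightRecordBatchStream`] directly from a stream of
    /// [`FlightData`], wrapping it in a new [`FlightDataDecoder`].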
pub fn new_from_flight_data<S>(inner: S) -> Self
where
S: Stream<Item = Result<FlightData>> + Send + 'static,
{
Self {
inner: FlightDataDecoder::new(inner),
headers: MetadataMap::default(),
trailers: None,
}
}
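    /// Attaches the gRPC response headers received alongside this stream.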
pub fn with_headers(self, headers: MetadataMap) -> Self {
Self { headers, ..self }
}
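    /// Attaches the (lazily available) gRPC response trailers for this stream.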
pub fn with_trailers(self, trailers: LazyTrailers) -> Self {
Self {
trailers: Some(trailers),
..self
}
}
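    /// Returns the gRPC response headers, empty unless set via
    /// [`Self::with_headers`].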
pub fn headers(&self) -> &MetadataMap {
&self.headers
}
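    /// Returns the gRPC response trailers, if they are available yet.
    /// Trailers are typically only present once the stream has been fully
    /// consumed.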
pub fn trailers(&self) -> Option<MetadataMap> {
self.trailers.as_ref().and_then(|trailers| trailers.get())
}
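    /// Returns `true` if the decoder has received a Schema message.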
#[deprecated = "use schema().is_some() instead"]
pub fn got_schema(&self) -> bool {
self.schema().is_some()
}
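    /// Returns the schema of the decoded stream, if a Schema message has
    /// been received.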
pub fn schema(&self) -> Option<&SchemaRef> {
self.inner.schema()
}
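    /// Consumes this stream, returning the underlying [`FlightDataDecoder`].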
pub fn into_inner(self) -> FlightDataDecoder {
self.inner
}
}
impl futures::Stream for FlightRecordBatchStream {
type Item = Result<RecordBatch>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
loop {
let had_schema = self.schema().is_some();
let res = ready!(self.inner.poll_next_unpin(cx));
match res {
None => {
return Poll::Ready(None);
}
Some(Err(e)) => {
return Poll::Ready(Some(Err(e)));
}
Some(Ok(data)) => match data.payload {
DecodedPayload::Schema(_) if had_schema => {
return Poll::Ready(Some(Err(FlightError::protocol(
"Unexpectedly saw multiple Schema messages in FlightData stream",
))));
}
                    DecodedPayload::Schema(_) => {
                        // The decoder has already captured the schema; skip it
                        // and keep polling until a RecordBatch arrives.
                    }
DecodedPayload::RecordBatch(batch) => {
return Poll::Ready(Some(Ok(batch)));
}
                    DecodedPayload::None => {
                        // Metadata-only messages carry no batch; keep polling.
                    }
},
}
}
}
}
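/// Low-level decoder that turns a stream of [`FlightData`] into a stream of
/// [`DecodedFlightData`], tracking the schema and any dictionaries seen so
/// far. Most callers will prefer the higher-level [`FlightRecordBatchStream`].
///
/// A minimal sketch of driving the decoder directly, under the same
/// assumptions as the [`FlightRecordBatchStream`] example:
///
/// ```no_run
/// # use arrow_flight::decode::{DecodedPayload, FlightDataDecoder};
/// # use futures::TryStreamExt;
/// # async fn example(
/// #     flight_data_stream: impl futures::Stream<
/// #         Item = arrow_flight::error::Result<arrow_flight::FlightData>,
/// #     > + Send + 'static,
/// # ) -> arrow_flight::error::Result<()> {
/// let mut decoder = FlightDataDecoder::new(flight_data_stream);
/// while let Some(decoded) = decoder.try_next().await? {
///     match decoded.payload {
///         DecodedPayload::Schema(schema) => println!("schema: {schema:?}"),
///         DecodedPayload::RecordBatch(batch) => println!("{} rows", batch.num_rows()),
///         DecodedPayload::None => {}
///     }
/// }
/// # Ok(())
/// # }
/// ```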
pub struct FlightDataDecoder {
response: BoxStream<'static, Result<FlightData>>,
state: Option<FlightStreamState>,
done: bool,
}
impl Debug for FlightDataDecoder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FlightDataDecoder")
.field("response", &"<stream>")
.field("state", &self.state)
.field("done", &self.done)
.finish()
}
}
impl FlightDataDecoder {
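    /// Creates a new decoder over the given stream of [`FlightData`].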
pub fn new<S>(response: S) -> Self
where
S: Stream<Item = Result<FlightData>> + Send + 'static,
{
Self {
state: None,
response: response.boxed(),
done: false,
}
}
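    /// Returns the schema seen so far, if a Schema message has been decoded.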
pub fn schema(&self) -> Option<&SchemaRef> {
self.state.as_ref().map(|state| &state.schema)
}
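    /// Decodes a single [`FlightData`] message, updating internal state
    /// (schema and dictionaries) as needed. Returns `Ok(None)` for messages,
    /// such as dictionary batches, that update state but produce no output.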
fn extract_message(&mut self, data: FlightData) -> Result<Option<DecodedFlightData>> {
use arrow_ipc::MessageHeader;
let message = arrow_ipc::root_as_message(&data.data_header[..])
.map_err(|e| FlightError::DecodeError(format!("Error decoding root message: {e}")))?;
match message.header_type() {
MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))),
MessageHeader::Schema => {
let schema = Schema::try_from(&data)
.map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
let schema = Arc::new(schema);
                // Start with an empty dictionary cache; dictionaries for this
                // schema arrive in subsequent DictionaryBatch messages.
                let dictionaries_by_field = HashMap::new();
self.state = Some(FlightStreamState {
schema: Arc::clone(&schema),
dictionaries_by_field,
});
Ok(Some(DecodedFlightData::new_schema(data, schema)))
}
MessageHeader::DictionaryBatch => {
let state = if let Some(state) = self.state.as_mut() {
state
} else {
return Err(FlightError::protocol(
"Received DictionaryBatch prior to Schema",
));
};
let buffer = Buffer::from(data.data_body);
let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| {
FlightError::protocol(
"Could not get dictionary batch from DictionaryBatch message",
)
})?;
arrow_ipc::reader::read_dictionary(
&buffer,
dictionary_batch,
&state.schema,
&mut state.dictionaries_by_field,
&message.version(),
)
.map_err(|e| {
FlightError::DecodeError(format!("Error decoding ipc dictionary: {e}"))
})?;
Ok(None)
}
MessageHeader::RecordBatch => {
let state = if let Some(state) = self.state.as_ref() {
state
} else {
return Err(FlightError::protocol(
"Received RecordBatch prior to Schema",
));
};
let batch = flight_data_to_arrow_batch(
&data,
Arc::clone(&state.schema),
&state.dictionaries_by_field,
)
.map_err(|e| {
FlightError::DecodeError(format!("Error decoding ipc RecordBatch: {e}"))
})?;
Ok(Some(DecodedFlightData::new_record_batch(data, batch)))
}
other => {
let name = other.variant_name().unwrap_or("UNKNOWN");
Err(FlightError::protocol(format!("Unexpected message: {name}")))
}
}
}
}
impl futures::Stream for FlightDataDecoder {
type Item = Result<DecodedFlightData>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}
loop {
let res = ready!(self.response.poll_next_unpin(cx));
return Poll::Ready(match res {
                None => {
                    // Underlying stream is exhausted; remember that so
                    // subsequent polls return `None` immediately.
                    self.done = true;
                    None
                }
Some(data) => Some(match data {
Err(e) => Err(e),
Ok(data) => match self.extract_message(data) {
Ok(Some(extracted)) => Ok(extracted),
                        // Dictionary batches update decoder state but produce
                        // no output; poll for the next message.
                        Ok(None) => continue,
                        Err(e) => Err(e),
},
}),
});
}
}
}
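/// Decoding state kept between messages: the schema announced by the stream
/// and any dictionaries received so far.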
#[derive(Debug)]
struct FlightStreamState {
schema: SchemaRef,
dictionaries_by_field: HashMap<i64, ArrayRef>,
}
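/// A single decoded [`FlightData`] message, pairing the raw message with its
/// decoded payload.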
#[derive(Debug)]
pub struct DecodedFlightData {
pub inner: FlightData,
pub payload: DecodedPayload,
}
impl DecodedFlightData {
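    /// Creates a [`DecodedFlightData`] with a [`DecodedPayload::None`] payload.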
pub fn new_none(inner: FlightData) -> Self {
Self {
inner,
payload: DecodedPayload::None,
}
}
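    /// Creates a [`DecodedFlightData`] carrying a newly decoded schema.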
pub fn new_schema(inner: FlightData, schema: SchemaRef) -> Self {
Self {
inner,
payload: DecodedPayload::Schema(schema),
}
}
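    /// Creates a [`DecodedFlightData`] carrying a decoded [`RecordBatch`].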
pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self {
Self {
inner,
payload: DecodedPayload::RecordBatch(batch),
}
}
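    /// Returns the application-specific metadata attached to this message.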
pub fn app_metadata(&self) -> Bytes {
self.inner.app_metadata.clone()
}
}
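/// The decoded payload of a single [`FlightData`] message.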
#[derive(Debug)]
pub enum DecodedPayload {
    /// A message without any IPC payload (for example, metadata only).
    None,

    /// A decoded Schema message.
    Schema(SchemaRef),

    /// A decoded RecordBatch message.
    RecordBatch(RecordBatch),
}