use std::task::Poll;
use crate::{
decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action,
ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, PutResult, Ticket,
};
use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
future::ready,
ready,
stream::{self, BoxStream},
FutureExt, Stream, StreamExt, TryStreamExt,
};
use tonic::{metadata::MetadataMap, transport::Channel};
use crate::error::{FlightError, Result};
/// A Flight RPC client: wraps the generated [`FlightServiceClient`] and
/// attaches a client-wide set of gRPC metadata (headers) to each request it
/// builds.
#[derive(Debug)]
pub struct FlightClient {
/// Metadata cloned into every request created by `make_request`; edit it
/// via `metadata_mut` or `add_header`.
metadata: MetadataMap,
/// The generated tonic client that actually performs the RPCs.
inner: FlightServiceClient<Channel>,
}
impl FlightClient {
    /// Creates a client that issues Flight RPCs over `channel`, starting with
    /// no extra request metadata.
    pub fn new(channel: Channel) -> Self {
        Self::new_from_inner(FlightServiceClient::new(channel))
    }

    /// Wraps an already-constructed generated [`FlightServiceClient`],
    /// starting with no extra request metadata.
    pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
        Self {
            metadata: MetadataMap::new(),
            inner,
        }
    }

    /// Returns the metadata (gRPC headers) copied into every request this
    /// client sends.
    pub fn metadata(&self) -> &MetadataMap {
        &self.metadata
    }

    /// Returns mutable access to the metadata copied into every request this
    /// client sends.
    pub fn metadata_mut(&mut self) -> &mut MetadataMap {
        &mut self.metadata
    }

    /// Adds an ASCII header that is sent with every subsequent request.
    ///
    /// Any value already stored under `key` is replaced.
    ///
    /// # Errors
    ///
    /// Returns [`FlightError::ExternalError`] if `key` is not a valid ASCII
    /// metadata key or `value` cannot be parsed as an ASCII metadata value.
    pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> {
        let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes())
            .map_err(|e| FlightError::ExternalError(Box::new(e)))?;

        let value = value
            .parse()
            .map_err(|e| FlightError::ExternalError(Box::new(e)))?;

        // `insert` (not `append`) so a repeated key overwrites the old value.
        self.metadata.insert(key, value);
        Ok(())
    }

    /// Returns a reference to the underlying generated client.
    pub fn inner(&self) -> &FlightServiceClient<Channel> {
        &self.inner
    }

    /// Returns mutable access to the underlying generated client.
    ///
    /// Calls made directly on it bypass this wrapper, so [`Self::metadata`]
    /// is not attached to them automatically.
    pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
        &mut self.inner
    }

    /// Consumes this wrapper and returns the underlying generated client,
    /// discarding the stored metadata.
    pub fn into_inner(self) -> FlightServiceClient<Channel> {
        self.inner
    }

    /// Performs a `Handshake` with the server.
    ///
    /// Sends a single [`HandshakeRequest`] (with `protocol_version` 0)
    /// carrying `payload` and returns the payload of the server's reply.
    ///
    /// # Errors
    ///
    /// Returns a protocol error if the server sends no response or more than
    /// one response. Errors from the RPC itself (a non-OK `tonic::Status`)
    /// are propagated.
    pub async fn handshake(&mut self, payload: impl Into<Bytes>) -> Result<Bytes> {
        let request = HandshakeRequest {
            protocol_version: 0,
            payload: payload.into(),
        };

        // The handshake RPC takes a request stream; send a one-element stream.
        let request = self.make_request(stream::once(ready(request)));

        let mut response_stream = self.inner.handshake(request).await?.into_inner();

        if let Some(response) = response_stream.next().await.transpose()? {
            // Exactly one reply is expected; anything more is a protocol violation.
            if response_stream.next().await.is_some() {
                return Err(FlightError::protocol(
                    "Got unexpected second response from handshake",
                ));
            }
            Ok(response.payload)
        } else {
            Err(FlightError::protocol("No response from handshake"))
        }
    }

    /// Issues a `DoGet` for `ticket` and returns a stream that decodes the
    /// server's [`FlightData`] messages into record batches.
    ///
    /// # Errors
    ///
    /// Returns an error if the RPC cannot be started. Errors received while
    /// streaming are mapped to [`FlightError::Tonic`] before being handed to
    /// the decoder.
    pub async fn do_get(&mut self, ticket: Ticket) -> Result<FlightRecordBatchStream> {
        let request = self.make_request(ticket);

        let response_stream = self
            .inner
            .do_get(request)
            .await?
            .into_inner()
            .map_err(FlightError::Tonic);

        Ok(FlightRecordBatchStream::new_from_flight_data(
            response_stream,
        ))
    }

    /// Issues a `GetFlightInfo` for `descriptor` and returns the server's
    /// [`FlightInfo`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the RPC.
    pub async fn get_flight_info(
        &mut self,
        descriptor: FlightDescriptor,
    ) -> Result<FlightInfo> {
        let request = self.make_request(descriptor);

        let response = self.inner.get_flight_info(request).await?.into_inner();
        Ok(response)
    }

    /// Issues a `DoPut`, uploading the [`FlightData`] produced by `request`,
    /// and returns the stream of [`PutResult`]s sent back by the server.
    ///
    /// The stream handed to the RPC carries plain `FlightData`, so client-side
    /// errors cannot travel with it. Instead, the first `Err` produced by
    /// `request` ends the upload and is yielded by the returned stream. Each
    /// poll of the returned stream checks for such an error before taking the
    /// next server response.
    ///
    /// # Errors
    ///
    /// Returns an error if the RPC cannot be started. Later server errors are
    /// yielded by the returned stream as [`FlightError::Tonic`].
    pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
        &mut self,
        request: S,
    ) -> Result<BoxStream<'static, Result<PutResult>>> {
        // Side channel used to carry the first client-side error out of the
        // request stream and into the response stream returned to the caller.
        let (sender, mut receiver) = futures::channel::oneshot::channel();

        let mut request = Box::pin(request);
        let mut sender = Some(sender);
        let request_stream = futures::stream::poll_fn(move |cx| {
            Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
                Some(Ok(data)) => Some(data),
                Some(Err(e)) => {
                    if let Some(tx) = sender.take() {
                        // A failed send only means the receiving side has
                        // been dropped, so there is no one left to notify.
                        let _ = tx.send(e);
                    }
                    None
                }
                None => None,
            })
        })
        // Once end-of-stream has been reported (normally or due to an error),
        // never poll `request` again. Without this, a later poll could emit
        // data past the end of the upload, or reach a second `Err` after the
        // sender has already been used.
        .fuse();

        let request = self.make_request(request_stream);
        let mut response_stream = self.inner.do_put(request).await?.into_inner();

        // Yield a forwarded client error, if one has arrived, ahead of server
        // responses. Any other outcome of polling the receiver falls through
        // to the server stream.
        let error_stream = futures::stream::poll_fn(move |cx| {
            if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
                return Poll::Ready(Some(Err(err)));
            }
            let next = ready!(response_stream.poll_next_unpin(cx));
            Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
        });

        Ok(error_stream.boxed())
    }

    /// Issues a `DoExchange`, sending the [`FlightData`] produced by `request`
    /// and decoding the server's replies into record batches.
    ///
    /// # Errors
    ///
    /// Returns an error if the RPC cannot be started. Errors received while
    /// streaming are mapped to [`FlightError::Tonic`] before decoding.
    pub async fn do_exchange<S: Stream<Item = FlightData> + Send + 'static>(
        &mut self,
        request: S,
    ) -> Result<FlightRecordBatchStream> {
        let request = self.make_request(request);

        let response = self
            .inner
            .do_exchange(request)
            .await?
            .into_inner()
            .map_err(FlightError::Tonic);

        Ok(FlightRecordBatchStream::new_from_flight_data(response))
    }

    /// Issues a `ListFlights` with [`Criteria`] built from `expression` and
    /// returns a stream of the matching [`FlightInfo`]s.
    ///
    /// # Errors
    ///
    /// Returns an error if the RPC cannot be started. Streaming errors are
    /// yielded as [`FlightError::Tonic`] items.
    pub async fn list_flights(
        &mut self,
        expression: impl Into<Bytes>,
    ) -> Result<BoxStream<'static, Result<FlightInfo>>> {
        let request = Criteria {
            expression: expression.into(),
        };
        let request = self.make_request(request);

        let response = self
            .inner
            .list_flights(request)
            .await?
            .into_inner()
            .map_err(FlightError::Tonic);

        Ok(response.boxed())
    }

    /// Issues a `GetSchema` for `flight_descriptor` and decodes the result
    /// into an Arrow [`Schema`].
    ///
    /// # Errors
    ///
    /// Propagates RPC errors, and returns an error if the returned schema
    /// cannot be converted into a [`Schema`].
    pub async fn get_schema(
        &mut self,
        flight_descriptor: FlightDescriptor,
    ) -> Result<Schema> {
        let request = self.make_request(flight_descriptor);

        let schema_result = self.inner.get_schema(request).await?.into_inner();

        let schema: Schema = schema_result.try_into()?;
        Ok(schema)
    }

    /// Issues a `ListActions` and returns a stream of the [`ActionType`]s
    /// the server advertises.
    ///
    /// # Errors
    ///
    /// Returns an error if the RPC cannot be started. Streaming errors are
    /// yielded as [`FlightError::Tonic`] items.
    pub async fn list_actions(
        &mut self,
    ) -> Result<BoxStream<'static, Result<ActionType>>> {
        let request = self.make_request(Empty {});

        let action_stream = self
            .inner
            .list_actions(request)
            .await?
            .into_inner()
            .map_err(FlightError::Tonic);

        Ok(action_stream.boxed())
    }

    /// Issues a `DoAction` for `action` and returns a stream of the raw
    /// `body` bytes of each result message the server sends back.
    ///
    /// # Errors
    ///
    /// Returns an error if the RPC cannot be started. Streaming errors are
    /// yielded as [`FlightError::Tonic`] items.
    pub async fn do_action(
        &mut self,
        action: Action,
    ) -> Result<BoxStream<'static, Result<Bytes>>> {
        let request = self.make_request(action);

        let result_stream = self
            .inner
            .do_action(request)
            .await?
            .into_inner()
            .map_err(FlightError::Tonic)
            // Keep only the opaque body of each Flight `Result` message.
            .map_ok(|result| result.body);

        Ok(result_stream.boxed())
    }

    /// Wraps `t` in a [`tonic::Request`] whose metadata is a copy of this
    /// client's stored metadata.
    fn make_request<T>(&self, t: T) -> tonic::Request<T> {
        let mut request = tonic::Request::new(t);
        *request.metadata_mut() = self.metadata.clone();
        request
    }
}