use crate::{
decode::FlightRecordBatchStream,
flight_service_client::FlightServiceClient,
gen::{CancelFlightInfoRequest, CancelFlightInfoResult, RenewFlightEndpointRequest},
trailers::extract_lazy_trailers,
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo,
HandshakeRequest, PollInfo, PutResult, Ticket,
};
use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
future::ready,
stream::{self, BoxStream},
Stream, StreamExt, TryStreamExt,
};
use prost::Message;
use tonic::{metadata::MetadataMap, transport::Channel};
use crate::error::{FlightError, Result};
use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream};
#[derive(Debug)]
pub struct FlightClient {
metadata: MetadataMap,
inner: FlightServiceClient<Channel>,
}
impl FlightClient {
pub fn new(channel: Channel) -> Self {
Self::new_from_inner(FlightServiceClient::new(channel))
}
pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
Self {
metadata: MetadataMap::new(),
inner,
}
}
pub fn metadata(&self) -> &MetadataMap {
&self.metadata
}
pub fn metadata_mut(&mut self) -> &mut MetadataMap {
&mut self.metadata
}
pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> {
let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes())
.map_err(|e| FlightError::ExternalError(Box::new(e)))?;
let value = value
.parse()
.map_err(|e| FlightError::ExternalError(Box::new(e)))?;
self.metadata.insert(key, value);
Ok(())
}
pub fn inner(&self) -> &FlightServiceClient<Channel> {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
&mut self.inner
}
pub fn into_inner(self) -> FlightServiceClient<Channel> {
self.inner
}
pub async fn handshake(&mut self, payload: impl Into<Bytes>) -> Result<Bytes> {
let request = HandshakeRequest {
protocol_version: 0,
payload: payload.into(),
};
let request = self.make_request(stream::once(ready(request)));
let mut response_stream = self.inner.handshake(request).await?.into_inner();
if let Some(response) = response_stream.next().await.transpose()? {
if response_stream.next().await.is_some() {
return Err(FlightError::protocol(
"Got unexpected second response from handshake",
));
}
Ok(response.payload)
} else {
Err(FlightError::protocol("No response from handshake"))
}
}
pub async fn do_get(&mut self, ticket: Ticket) -> Result<FlightRecordBatchStream> {
let request = self.make_request(ticket);
let (md, response_stream, _ext) = self.inner.do_get(request).await?.into_parts();
let (response_stream, trailers) = extract_lazy_trailers(response_stream);
Ok(FlightRecordBatchStream::new_from_flight_data(
response_stream.map_err(FlightError::Tonic),
)
.with_headers(md)
.with_trailers(trailers))
}
pub async fn get_flight_info(&mut self, descriptor: FlightDescriptor) -> Result<FlightInfo> {
let request = self.make_request(descriptor);
let response = self.inner.get_flight_info(request).await?.into_inner();
Ok(response)
}
pub async fn poll_flight_info(&mut self, descriptor: FlightDescriptor) -> Result<PollInfo> {
let request = self.make_request(descriptor);
let response = self.inner.poll_flight_info(request).await?.into_inner();
Ok(response)
}
pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
let (sender, receiver) = futures::channel::oneshot::channel();
let request = Box::pin(request); let request_stream = FallibleRequestStream::new(sender, request);
let request = self.make_request(request_stream);
let response_stream = self.inner.do_put(request).await?.into_inner();
let response_stream = Box::pin(response_stream);
let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);
Ok(error_stream.boxed())
}
pub async fn do_exchange<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<FlightRecordBatchStream> {
let (sender, receiver) = futures::channel::oneshot::channel();
let request = Box::pin(request);
let request_stream = FallibleRequestStream::new(sender, request);
let request = self.make_request(request_stream);
let response_stream = self.inner.do_exchange(request).await?.into_inner();
let response_stream = Box::pin(response_stream);
let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);
Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
}
pub async fn list_flights(
&mut self,
expression: impl Into<Bytes>,
) -> Result<BoxStream<'static, Result<FlightInfo>>> {
let request = Criteria {
expression: expression.into(),
};
let request = self.make_request(request);
let response = self
.inner
.list_flights(request)
.await?
.into_inner()
.map_err(FlightError::Tonic);
Ok(response.boxed())
}
pub async fn get_schema(&mut self, flight_descriptor: FlightDescriptor) -> Result<Schema> {
let request = self.make_request(flight_descriptor);
let schema_result = self.inner.get_schema(request).await?.into_inner();
let schema: Schema = schema_result.try_into()?;
Ok(schema)
}
pub async fn list_actions(&mut self) -> Result<BoxStream<'static, Result<ActionType>>> {
let request = self.make_request(Empty {});
let action_stream = self
.inner
.list_actions(request)
.await?
.into_inner()
.map_err(FlightError::Tonic);
Ok(action_stream.boxed())
}
pub async fn do_action(&mut self, action: Action) -> Result<BoxStream<'static, Result<Bytes>>> {
let request = self.make_request(action);
let result_stream = self
.inner
.do_action(request)
.await?
.into_inner()
.map_err(FlightError::Tonic)
.map(|r| {
r.map(|r| {
let crate::Result { body } = r;
body
})
});
Ok(result_stream.boxed())
}
pub async fn cancel_flight_info(
&mut self,
request: CancelFlightInfoRequest,
) -> Result<CancelFlightInfoResult> {
let action = Action::new("CancelFlightInfo", request.encode_to_vec());
let response = self.do_action(action).await?.try_next().await?;
let response = response.ok_or(FlightError::protocol(
"Received no response for cancel_flight_info call",
))?;
CancelFlightInfoResult::decode(response)
.map_err(|e| FlightError::DecodeError(e.to_string()))
}
pub async fn renew_flight_endpoint(
&mut self,
request: RenewFlightEndpointRequest,
) -> Result<FlightEndpoint> {
let action = Action::new("RenewFlightEndpoint", request.encode_to_vec());
let response = self.do_action(action).await?.try_next().await?;
let response = response.ok_or(FlightError::protocol(
"Received no response for renew_flight_endpoint call",
))?;
FlightEndpoint::decode(response).map_err(|e| FlightError::DecodeError(e.to_string()))
}
fn make_request<T>(&self, t: T) -> tonic::Request<T> {
let mut request = tonic::Request::new(t);
*request.metadata_mut() = self.metadata.clone();
request
}
}