use crate::{
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo,
HandshakeRequest, PollInfo, PutResult, Ticket,
decode::FlightRecordBatchStream,
flight_service_client::FlightServiceClient,
r#gen::{CancelFlightInfoRequest, CancelFlightInfoResult, RenewFlightEndpointRequest},
trailers::extract_lazy_trailers,
};
use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
Stream, StreamExt, TryStreamExt,
future::ready,
stream::{self, BoxStream},
};
use prost::Message;
use tonic::codegen::{Body, StdError};
use tonic::{metadata::MetadataMap, transport::Channel};
use crate::error::{FlightError, Result};
use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream};
/// Arrow Flight client built on top of the generated [`FlightServiceClient`].
///
/// The wrapper converts tonic `Status` errors into [`FlightError`], exposes
/// streaming RPC responses as boxed streams or [`FlightRecordBatchStream`]s, and
/// attaches a copy of its own [`MetadataMap`] to every outgoing request (see
/// [`FlightClient::add_header`] and [`FlightClient::metadata_mut`]).
///
/// The transport `T` defaults to [`Channel`]; any transport that satisfies the
/// `GrpcService` bounds on the `impl` block (for example a tonic
/// `InterceptedService` wrapping a channel) can be used instead.
#[derive(Debug)]
pub struct FlightClient<T = Channel> {
    /// Metadata (gRPC headers) copied onto every outgoing request.
    metadata: MetadataMap,
    /// The generated tonic client that performs the actual RPCs.
    inner: FlightServiceClient<T>,
}

impl<T> FlightClient<T>
where
    T: tonic::client::GrpcService<tonic::body::Body>,
    T::Error: Into<StdError>,
    T::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
    <T::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
{
    /// Creates a client that issues requests over the transport `inner`.
    ///
    /// The client starts with empty metadata.
    pub fn new(inner: T) -> Self {
        Self::new_from_inner(FlightServiceClient::new(inner))
    }

    /// Creates a client from an already-configured [`FlightServiceClient`].
    ///
    /// The client starts with empty metadata.
    pub fn new_from_inner(inner: FlightServiceClient<T>) -> Self {
        Self {
            metadata: MetadataMap::new(),
            inner,
        }
    }

    /// Returns the metadata attached to every outgoing request.
    pub fn metadata(&self) -> &MetadataMap {
        &self.metadata
    }

    /// Returns mutable access to the metadata attached to every outgoing request.
    pub fn metadata_mut(&mut self) -> &mut MetadataMap {
        &mut self.metadata
    }

    /// Inserts an ASCII header that will be sent with every subsequent request.
    ///
    /// # Errors
    ///
    /// Returns [`FlightError::ExternalError`] in two cases: when `key` is not a
    /// valid metadata key, and when `value` cannot be parsed as a metadata value.
    pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> {
        let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes())
            .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
        let value = value
            .parse()
            .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
        self.metadata.insert(key, value);
        Ok(())
    }

    /// Returns a reference to the underlying generated client.
    pub fn inner(&self) -> &FlightServiceClient<T> {
        &self.inner
    }

    /// Returns mutable access to the underlying generated client.
    pub fn inner_mut(&mut self) -> &mut FlightServiceClient<T> {
        &mut self.inner
    }

    /// Consumes this wrapper and returns the underlying generated client.
    pub fn into_inner(self) -> FlightServiceClient<T> {
        self.inner
    }

    /// Performs a `Handshake` RPC.
    ///
    /// Sends a single [`HandshakeRequest`] carrying `payload` with
    /// `protocol_version` 0, and returns the payload of the server's reply.
    ///
    /// # Errors
    ///
    /// Returns a protocol error if the server sends no response or sends more
    /// than one response. Any RPC error is propagated.
    pub async fn handshake(&mut self, payload: impl Into<Bytes>) -> Result<Bytes> {
        let request = HandshakeRequest {
            protocol_version: 0,
            payload: payload.into(),
        };
        let request = self.make_request(stream::once(ready(request)));
        let mut response_stream = self.inner.handshake(request).await?.into_inner();

        let Some(response) = response_stream.next().await.transpose()? else {
            return Err(FlightError::protocol("No response from handshake"));
        };

        // Exactly one response is expected. Any further item, including an
        // error item, is treated as a protocol violation.
        if response_stream.next().await.is_some() {
            return Err(FlightError::protocol(
                "Got unexpected second response from handshake",
            ));
        }

        Ok(response.payload)
    }

    /// Performs a `DoGet` RPC for `ticket`.
    ///
    /// The response is decoded as a stream of record batches. Response headers
    /// are attached to the returned stream, and so are the trailers, which are
    /// extracted lazily.
    pub async fn do_get(&mut self, ticket: Ticket) -> Result<FlightRecordBatchStream> {
        let request = self.make_request(ticket);
        let (md, response_stream, _ext) = self.inner.do_get(request).await?.into_parts();
        let (response_stream, trailers) = extract_lazy_trailers(response_stream);

        Ok(FlightRecordBatchStream::new_from_flight_data(
            response_stream.map_err(|status| status.into()),
        )
        .with_headers(md)
        .with_trailers(trailers))
    }

    /// Performs a `GetFlightInfo` RPC for `descriptor`.
    pub async fn get_flight_info(&mut self, descriptor: FlightDescriptor) -> Result<FlightInfo> {
        let request = self.make_request(descriptor);
        let response = self.inner.get_flight_info(request).await?.into_inner();
        Ok(response)
    }

    /// Performs a `PollFlightInfo` RPC for `descriptor`.
    pub async fn poll_flight_info(&mut self, descriptor: FlightDescriptor) -> Result<PollInfo> {
        let request = self.make_request(descriptor);
        let response = self.inner.poll_flight_info(request).await?.into_inner();
        Ok(response)
    }

    /// Performs a `DoPut` RPC, streaming `request` to the server.
    ///
    /// Returns the server's [`PutResult`] stream. Errors yielded by `request`
    /// are handed to a oneshot channel by `FallibleRequestStream`, and
    /// `FallibleTonicResponseStream` then surfaces them in the returned stream.
    pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
        &mut self,
        request: S,
    ) -> Result<BoxStream<'static, Result<PutResult>>> {
        let (sender, receiver) = futures::channel::oneshot::channel();

        // Capture client-side errors from the input stream via `sender`.
        let request = Box::pin(request);
        let request_stream = FallibleRequestStream::new(sender, request);

        let request = self.make_request(request_stream);
        let response_stream = self.inner.do_put(request).await?.into_inner();

        // Merge server responses with any client-side error reported on `receiver`.
        let response_stream = Box::pin(response_stream);
        let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);

        Ok(error_stream.boxed())
    }

    /// Performs a `DoExchange` RPC, streaming `request` to the server.
    ///
    /// The server's reply is decoded as record batches. Errors yielded by
    /// `request` are routed into the returned stream through the same oneshot
    /// mechanism used by [`FlightClient::do_put`].
    pub async fn do_exchange<S: Stream<Item = Result<FlightData>> + Send + 'static>(
        &mut self,
        request: S,
    ) -> Result<FlightRecordBatchStream> {
        let (sender, receiver) = futures::channel::oneshot::channel();

        // Capture client-side errors from the input stream via `sender`.
        let request = Box::pin(request);
        let request_stream = FallibleRequestStream::new(sender, request);

        let request = self.make_request(request_stream);
        let response_stream = self.inner.do_exchange(request).await?.into_inner();

        // Merge server responses with any client-side error reported on `receiver`.
        let response_stream = Box::pin(response_stream);
        let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);

        Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
    }

    /// Performs a `ListFlights` RPC filtered by `expression`.
    ///
    /// Returns a stream of the matching [`FlightInfo`]s.
    pub async fn list_flights(
        &mut self,
        expression: impl Into<Bytes>,
    ) -> Result<BoxStream<'static, Result<FlightInfo>>> {
        let request = Criteria {
            expression: expression.into(),
        };
        let request = self.make_request(request);

        let response = self
            .inner
            .list_flights(request)
            .await?
            .into_inner()
            .map_err(|status| status.into());

        Ok(response.boxed())
    }

    /// Performs a `GetSchema` RPC for `flight_descriptor`.
    ///
    /// The returned `SchemaResult` is converted into an Arrow [`Schema`].
    ///
    /// # Errors
    ///
    /// Propagates RPC errors and any error from the `SchemaResult` to
    /// [`Schema`] conversion.
    pub async fn get_schema(&mut self, flight_descriptor: FlightDescriptor) -> Result<Schema> {
        let request = self.make_request(flight_descriptor);
        let schema_result = self.inner.get_schema(request).await?.into_inner();
        let schema: Schema = schema_result.try_into()?;
        Ok(schema)
    }

    /// Performs a `ListActions` RPC.
    ///
    /// Returns a stream of the [`ActionType`]s the server reports.
    pub async fn list_actions(&mut self) -> Result<BoxStream<'static, Result<ActionType>>> {
        let request = self.make_request(Empty {});

        let action_stream = self
            .inner
            .list_actions(request)
            .await?
            .into_inner()
            .map_err(|status| status.into());

        Ok(action_stream.boxed())
    }

    /// Performs a `DoAction` RPC for `action`.
    ///
    /// Returns a stream of the `body` of each result message.
    pub async fn do_action(&mut self, action: Action) -> Result<BoxStream<'static, Result<Bytes>>> {
        let request = self.make_request(action);

        let result_stream = self
            .inner
            .do_action(request)
            .await?
            .into_inner()
            .map_err(|status| status.into())
            .map_ok(|result| result.body);

        Ok(result_stream.boxed())
    }

    /// Cancels a flight by sending a `CancelFlightInfo` action with the
    /// protobuf-encoded `request`.
    ///
    /// The first result body is decoded as a [`CancelFlightInfoResult`]; any
    /// further results are ignored.
    ///
    /// # Errors
    ///
    /// - A protocol error if the server returns no result.
    /// - [`FlightError::DecodeError`] if the body cannot be decoded.
    /// - Any RPC error is propagated.
    pub async fn cancel_flight_info(
        &mut self,
        request: CancelFlightInfoRequest,
    ) -> Result<CancelFlightInfoResult> {
        let action = Action::new("CancelFlightInfo", request.encode_to_vec());
        let response = self
            .do_action(action)
            .await?
            .try_next()
            .await?
            // Build the error lazily so the success path does not allocate.
            .ok_or_else(|| {
                FlightError::protocol("Received no response for cancel_flight_info call")
            })?;

        CancelFlightInfoResult::decode(response)
            .map_err(|e| FlightError::DecodeError(e.to_string()))
    }

    /// Renews a flight endpoint by sending a `RenewFlightEndpoint` action with
    /// the protobuf-encoded `request`.
    ///
    /// The first result body is decoded as a [`FlightEndpoint`]; any further
    /// results are ignored.
    ///
    /// # Errors
    ///
    /// - A protocol error if the server returns no result.
    /// - [`FlightError::DecodeError`] if the body cannot be decoded.
    /// - Any RPC error is propagated.
    pub async fn renew_flight_endpoint(
        &mut self,
        request: RenewFlightEndpointRequest,
    ) -> Result<FlightEndpoint> {
        let action = Action::new("RenewFlightEndpoint", request.encode_to_vec());
        let response = self
            .do_action(action)
            .await?
            .try_next()
            .await?
            // Build the error lazily so the success path does not allocate.
            .ok_or_else(|| {
                FlightError::protocol("Received no response for renew_flight_endpoint call")
            })?;

        FlightEndpoint::decode(response).map_err(|e| FlightError::DecodeError(e.to_string()))
    }

    /// Wraps `t` in a [`tonic::Request`] whose metadata is a copy of this
    /// client's metadata.
    fn make_request<R>(&self, t: R) -> tonic::Request<R> {
        let mut request = tonic::Request::new(t);
        *request.metadata_mut() = self.metadata.clone();
        request
    }
}
#[cfg(test)]
mod tests {
    use super::FlightClient;
    use crate::encode::FlightDataEncoderBuilder;
    use crate::flight_service_server::{FlightService, FlightServiceServer};
    use crate::{
        Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
        HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
    };
    use arrow_array::{RecordBatch, UInt64Array};
    use bytes::Bytes;
    use futures::{StreamExt, TryStreamExt, stream::BoxStream};
    use std::net::SocketAddr;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;
    use tonic::metadata::MetadataMap;
    use tonic::service::interceptor::InterceptedService;
    use tonic::transport::Channel;
    use tonic::{Request, Response, Status, Streaming};
    use uuid::Uuid;

    /// Minimal [`FlightService`] that only implements `DoGet`.
    ///
    /// It records the ticket and request metadata it receives so tests can
    /// inspect what actually reached the server.
    #[derive(Debug, Clone, Default)]
    struct InterceptorTestServer {
        state: Arc<Mutex<InterceptorTestState>>,
    }

    /// State shared between the test body and the server handlers.
    #[derive(Debug, Default)]
    struct InterceptorTestState {
        /// Ticket received by the most recent `do_get` call.
        do_get_request: Option<Ticket>,
        /// Batches (or errors) the next `do_get` call streams back.
        do_get_response: Option<Vec<Result<RecordBatch, Status>>>,
        /// Metadata of the most recent request passed to `save_metadata`.
        last_request_metadata: Option<MetadataMap>,
    }

    impl InterceptorTestServer {
        /// Records a copy of `request`'s metadata.
        fn save_metadata<T>(&self, request: &Request<T>) {
            self.state.lock().unwrap().last_request_metadata = Some(request.metadata().clone());
        }

        /// Configures the batches the next `do_get` call returns.
        fn set_do_get_response(&self, response: Vec<Result<RecordBatch, Status>>) {
            self.state.lock().unwrap().do_get_response = Some(response);
        }

        /// Takes the ticket recorded by the last `do_get` call, if any.
        fn take_do_get_request(&self) -> Option<Ticket> {
            self.state.lock().unwrap().do_get_request.take()
        }

        /// Takes the metadata recorded for the last request, if any.
        fn take_last_request_metadata(&self) -> Option<MetadataMap> {
            self.state.lock().unwrap().last_request_metadata.take()
        }
    }

    #[tonic::async_trait]
    impl FlightService for InterceptorTestServer {
        type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
        type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
        type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
        type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
        type DoActionStream = BoxStream<'static, Result<crate::Result, Status>>;
        type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
        type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;

        /// Records the request, then streams back the configured batches
        /// encoded as `FlightData`.
        async fn do_get(
            &self,
            request: Request<Ticket>,
        ) -> Result<Response<Self::DoGetStream>, Status> {
            self.save_metadata(&request);
            let mut state = self.state.lock().unwrap();
            state.do_get_request = Some(request.into_inner());

            let batches = state
                .do_get_response
                .take()
                .ok_or_else(|| Status::internal("no do_get response configured"))?;

            let batch_stream = futures::stream::iter(batches).map_err(Into::into);
            let stream = FlightDataEncoderBuilder::new()
                .build(batch_stream)
                .map_err(Into::into);

            Ok(Response::new(stream.boxed()))
        }

        async fn handshake(
            &self,
            _: Request<Streaming<HandshakeRequest>>,
        ) -> Result<Response<Self::HandshakeStream>, Status> {
            Err(Status::unimplemented(""))
        }

        async fn list_flights(
            &self,
            _: Request<Criteria>,
        ) -> Result<Response<Self::ListFlightsStream>, Status> {
            Err(Status::unimplemented(""))
        }

        async fn get_flight_info(
            &self,
            _: Request<FlightDescriptor>,
        ) -> Result<Response<FlightInfo>, Status> {
            Err(Status::unimplemented(""))
        }

        async fn poll_flight_info(
            &self,
            _: Request<FlightDescriptor>,
        ) -> Result<Response<PollInfo>, Status> {
            Err(Status::unimplemented(""))
        }

        async fn get_schema(
            &self,
            _: Request<FlightDescriptor>,
        ) -> Result<Response<SchemaResult>, Status> {
            Err(Status::unimplemented(""))
        }

        async fn do_put(
            &self,
            _: Request<Streaming<FlightData>>,
        ) -> Result<Response<Self::DoPutStream>, Status> {
            Err(Status::unimplemented(""))
        }

        async fn do_action(
            &self,
            _: Request<Action>,
        ) -> Result<Response<Self::DoActionStream>, Status> {
            Err(Status::unimplemented(""))
        }

        async fn list_actions(
            &self,
            _: Request<Empty>,
        ) -> Result<Response<Self::ListActionsStream>, Status> {
            Err(Status::unimplemented(""))
        }

        async fn do_exchange(
            &self,
            _: Request<Streaming<FlightData>>,
        ) -> Result<Response<Self::DoExchangeStream>, Status> {
            Err(Status::unimplemented(""))
        }
    }

    /// Runs a [`FlightServiceServer`] on an ephemeral localhost port.
    ///
    /// Call [`InterceptorTestFixture::shutdown_and_wait`] to stop the server and
    /// check its exit status.
    struct InterceptorTestFixture {
        /// Signals the server to shut down gracefully; `None` once used.
        shutdown: Option<tokio::sync::oneshot::Sender<()>>,
        /// Address the server is listening on.
        addr: SocketAddr,
        /// Handle of the spawned server task; `None` once awaited.
        handle: Option<JoinHandle<Result<(), tonic::transport::Error>>>,
    }

    impl InterceptorTestFixture {
        /// Binds to `127.0.0.1:0` and spawns `server` on a Tokio task.
        async fn new(server: FlightServiceServer<InterceptorTestServer>) -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();

            let (tx, rx) = tokio::sync::oneshot::channel();
            // Resolves when `tx` sends or is dropped.
            let shutdown_future = async move {
                rx.await.ok();
            };

            let serve = tonic::transport::Server::builder()
                .timeout(Duration::from_secs(30))
                .add_service(server)
                .serve_with_incoming_shutdown(
                    tokio_stream::wrappers::TcpListenerStream::new(listener),
                    shutdown_future,
                );

            let handle = tokio::task::spawn(serve);
            Self {
                shutdown: Some(tx),
                addr,
                handle: Some(handle),
            }
        }

        /// Opens a new channel connected to the fixture's server.
        async fn channel(&self) -> Channel {
            let url = format!("http://{}", self.addr);
            tonic::transport::Endpoint::from_shared(url)
                .expect("valid endpoint")
                .timeout(Duration::from_secs(30))
                .connect()
                .await
                .expect("error connecting to server")
        }

        /// Signals shutdown and waits for the server task to finish.
        ///
        /// Panics if the server task panicked or returned an error.
        async fn shutdown_and_wait(mut self) {
            if let Some(tx) = self.shutdown.take() {
                tx.send(()).expect("server quit early");
            }
            if let Some(handle) = self.handle.take() {
                handle
                    .await
                    .expect("task join error (panic?)")
                    .expect("server error at shutdown");
            }
        }
    }

    impl Drop for InterceptorTestFixture {
        /// Best-effort cleanup for fixtures not shut down explicitly, for
        /// example when a test panics part-way through.
        fn drop(&mut self) {
            if let Some(tx) = self.shutdown.take() {
                // The server may already be gone; nothing to report then.
                tx.send(()).ok();
            }
            if self.handle.is_some() {
                println!("InterceptorTestFixture dropped without calling `shutdown_and_wait`");
            }
        }
    }

    /// Checks that a `FlightClient` over an intercepted channel delivers both
    /// the interceptor's header and the client's own `add_header` metadata.
    #[tokio::test]
    async fn test_flight_client_with_intercepted_channel_passes_custom_header() {
        let test_server = InterceptorTestServer::default();
        let fixture =
            InterceptorTestFixture::new(FlightServiceServer::new(test_server.clone())).await;
        let channel = fixture.channel().await;

        // Header injected by the interceptor, with a unique value per run.
        let header_name = "x-random-header";
        let header_value = format!("random-{}", Uuid::new_v4());
        let header_value_for_interceptor = header_value.clone();
        let interceptor = move |mut req: Request<()>| -> Result<Request<()>, Status> {
            req.metadata_mut().insert(
                header_name,
                header_value_for_interceptor
                    .parse()
                    .expect("valid metadata value"),
            );
            Ok(req)
        };

        let intercepted = InterceptedService::new(channel, interceptor);
        let mut client = FlightClient::new(intercepted);

        // Header set on the client itself; it must survive the interceptor.
        let client_header_name = "x-client-header";
        let client_header_value = "client-value";
        client
            .add_header(client_header_name, client_header_value)
            .expect("valid client header");

        let ticket = Ticket {
            ticket: Bytes::from("dummy-ticket"),
        };
        let batch = RecordBatch::try_from_iter(vec![(
            "col",
            Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
        )])
        .unwrap();
        test_server.set_do_get_response(vec![Ok(batch.clone())]);

        let response_stream = client
            .do_get(ticket.clone())
            .await
            .expect("error making do_get request");
        let response: Vec<RecordBatch> = response_stream
            .try_collect()
            .await
            .expect("error streaming data");

        assert_eq!(response, vec![batch]);
        assert_eq!(test_server.take_do_get_request(), Some(ticket));

        let metadata = test_server
            .take_last_request_metadata()
            .expect("server received headers")
            .into_headers();

        let received = metadata
            .get(header_name)
            .expect("interceptor header missing on server")
            .to_str()
            .expect("ascii header value");
        assert_eq!(received, header_value);

        let received_client = metadata
            .get(client_header_name)
            .expect("client header missing on server")
            .to_str()
            .expect("ascii header value");
        assert_eq!(received_client, client_header_value);

        fixture.shutdown_and_wait().await;
    }
}