mod common {
pub mod server;
pub mod trailers_layer;
}
use arrow_array::{RecordBatch, UInt64Array};
use arrow_flight::{
decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, error::FlightError, Action,
ActionType, Criteria, Empty, FlightClient, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, HandshakeResponse, PutResult, Ticket,
};
use arrow_schema::{DataType, Field, Schema};
use bytes::Bytes;
use common::{server::TestFlightServer, trailers_layer::TrailersLayer};
use futures::{Future, StreamExt, TryStreamExt};
use tokio::{net::TcpListener, task::JoinHandle};
use tonic::{
transport::{Channel, Uri},
Status,
};
use std::{net::SocketAddr, sync::Arc, time::Duration};
const DEFAULT_TIMEOUT_SECONDS: u64 = 30;
#[tokio::test]
async fn test_handshake() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let request_payload = Bytes::from("foo-request-payload");
let response_payload = Bytes::from("bar-response-payload");
let request = HandshakeRequest {
payload: request_payload.clone(),
protocol_version: 0,
};
let response = HandshakeResponse {
payload: response_payload.clone(),
protocol_version: 0,
};
test_server.set_handshake_response(Ok(response));
let response = client.handshake(request_payload).await.unwrap();
assert_eq!(response, response_payload);
assert_eq!(test_server.take_handshake_request(), Some(request));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_handshake_error() {
do_test(|test_server, mut client| async move {
let request_payload = "foo-request-payload".to_string().into_bytes();
let e = Status::unauthenticated("DENIED");
test_server.set_handshake_response(Err(e.clone()));
let response = client.handshake(request_payload).await.unwrap_err();
expect_status(response, e);
})
.await;
}
fn ensure_metadata(client: &FlightClient, test_server: &TestFlightServer) {
let client_metadata = client.metadata().clone().into_headers();
assert!(!client_metadata.is_empty());
let metadata = test_server
.take_last_request_metadata()
.expect("No headers in server")
.into_headers();
for (k, v) in &client_metadata {
assert_eq!(
metadata.get(k).as_ref(),
Some(&v),
"Missing / Mismatched metadata {k:?} sent {client_metadata:?} got {metadata:?}"
);
}
}
fn test_flight_info(request: &FlightDescriptor) -> FlightInfo {
FlightInfo {
schema: Bytes::new(),
endpoint: vec![],
flight_descriptor: Some(request.clone()),
total_bytes: 123,
total_records: 456,
ordered: false,
}
}
#[tokio::test]
async fn test_get_flight_info() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
let expected_response = test_flight_info(&request);
test_server.set_get_flight_info_response(Ok(expected_response.clone()));
let response = client.get_flight_info(request.clone()).await.unwrap();
assert_eq!(response, expected_response);
assert_eq!(test_server.take_get_flight_info_request(), Some(request));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_get_flight_info_error() {
do_test(|test_server, mut client| async move {
let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
let e = Status::unauthenticated("DENIED");
test_server.set_get_flight_info_response(Err(e.clone()));
let response = client.get_flight_info(request.clone()).await.unwrap_err();
expect_status(response, e);
})
.await;
}
#[tokio::test]
async fn test_do_get() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let ticket = Ticket {
ticket: Bytes::from("my awesome flight ticket"),
};
let batch = RecordBatch::try_from_iter(vec![(
"col",
Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
)])
.unwrap();
let response = vec![Ok(batch.clone())];
test_server.set_do_get_response(response);
let mut response_stream = client
.do_get(ticket.clone())
.await
.expect("error making request");
assert_eq!(
response_stream
.headers()
.get("test-resp-header")
.expect("header exists")
.to_str()
.unwrap(),
"some_val",
);
assert!(response_stream.trailers().is_none());
let expected_response = vec![batch];
let response: Vec<_> = (&mut response_stream)
.try_collect()
.await
.expect("Error streaming data");
assert_eq!(response, expected_response);
assert_eq!(
response_stream
.trailers()
.expect("stream exhausted")
.get("test-trailer")
.expect("trailer exists")
.to_str()
.unwrap(),
"trailer_val",
);
assert_eq!(test_server.take_do_get_request(), Some(ticket));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_get_error() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let ticket = Ticket {
ticket: Bytes::from("my awesome flight ticket"),
};
let response = client.do_get(ticket.clone()).await.unwrap_err();
let e = Status::internal("No do_get response configured");
expect_status(response, e);
assert_eq!(test_server.take_do_get_request(), Some(ticket));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_get_error_in_record_batch_stream() {
do_test(|test_server, mut client| async move {
let ticket = Ticket {
ticket: Bytes::from("my awesome flight ticket"),
};
let batch = RecordBatch::try_from_iter(vec![(
"col",
Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
)])
.unwrap();
let e = Status::data_loss("she's dead jim");
let expected_response = vec![Ok(batch), Err(e.clone())];
test_server.set_do_get_response(expected_response);
let response_stream = client
.do_get(ticket.clone())
.await
.expect("error making request");
let response: Result<Vec<_>, FlightError> = response_stream.try_collect().await;
let response = response.unwrap_err();
expect_status(response, e);
assert_eq!(test_server.take_do_get_request(), Some(ticket));
})
.await;
}
#[tokio::test]
async fn test_do_put() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let input_flight_data = test_flight_data().await;
let expected_response = vec![
PutResult {
app_metadata: Bytes::from("foo-metadata1"),
},
PutResult {
app_metadata: Bytes::from("bar-metadata2"),
},
];
test_server.set_do_put_response(expected_response.clone().into_iter().map(Ok).collect());
let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);
let response_stream = client
.do_put(input_stream)
.await
.expect("error making request");
let response: Vec<_> = response_stream
.try_collect()
.await
.expect("Error streaming data");
assert_eq!(response, expected_response);
assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_put_error_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let input_flight_data = test_flight_data().await;
let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);
let response = client.do_put(input_stream).await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};
let e = Status::internal("No do_put response configured");
expect_status(response, e);
assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_put_error_stream_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let input_flight_data = test_flight_data().await;
let e = Status::invalid_argument("bad arg");
let response = vec![
Ok(PutResult {
app_metadata: Bytes::from("foo-metadata"),
}),
Err(e.clone()),
];
test_server.set_do_put_response(response);
let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);
let response_stream = client
.do_put(input_stream)
.await
.expect("error making request");
let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};
expect_status(response, e);
assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_put_error_client() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let e = Status::invalid_argument("bad arg: client");
let input_flight_data = test_flight_data().await;
let input_stream = futures::stream::iter(input_flight_data.clone())
.map(Ok)
.chain(futures::stream::iter(vec![Err(FlightError::from(
e.clone(),
))]));
let response = vec![Ok(PutResult {
app_metadata: Bytes::from("foo-metadata"),
})];
test_server.set_do_put_response(response);
let response_stream = client
.do_put(input_stream)
.await
.expect("error making request");
let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};
expect_status(response, e);
assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_put_error_client_and_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let e_client = Status::invalid_argument("bad arg: client");
let e_server = Status::invalid_argument("bad arg: server");
let input_flight_data = test_flight_data().await;
let input_stream = futures::stream::iter(input_flight_data.clone())
.map(Ok)
.chain(futures::stream::iter(vec![Err(FlightError::from(
e_client.clone(),
))]));
let response = vec![Err(e_server)];
test_server.set_do_put_response(response);
let response_stream = client
.do_put(input_stream)
.await
.expect("error making request");
let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};
expect_status(response, e_client);
assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_exchange() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let input_flight_data = test_flight_data().await;
let output_flight_data = test_flight_data2().await;
test_server
.set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect());
let response_stream = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.await
.expect("error making request");
let response: Vec<_> = response_stream
.try_collect()
.await
.expect("Error streaming data");
let expected_stream = futures::stream::iter(output_flight_data).map(Ok);
let expected_batches: Vec<_> =
FlightRecordBatchStream::new_from_flight_data(expected_stream)
.try_collect()
.await
.unwrap();
assert_eq!(response, expected_batches);
assert_eq!(
test_server.take_do_exchange_request(),
Some(input_flight_data)
);
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_exchange_error() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let input_flight_data = test_flight_data().await;
let response = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};
let e = Status::internal("No do_exchange response configured");
expect_status(response, e);
assert_eq!(
test_server.take_do_exchange_request(),
Some(input_flight_data)
);
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_exchange_error_stream() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let input_flight_data = test_flight_data().await;
let e = Status::invalid_argument("the error");
let response = test_flight_data2()
.await
.into_iter()
.enumerate()
.map(|(i, m)| {
if i == 0 {
Ok(m)
} else {
Err(e.clone())
}
})
.collect();
test_server.set_do_exchange_response(response);
let response_stream = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.await
.expect("error making request");
let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};
expect_status(response, e);
assert_eq!(
test_server.take_do_exchange_request(),
Some(input_flight_data)
);
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_get_schema() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let schema = Schema::new(vec![Field::new("foo", DataType::Int64, true)]);
let request = FlightDescriptor::new_cmd("my command");
test_server.set_get_schema_response(Ok(schema.clone()));
let response = client
.get_schema(request.clone())
.await
.expect("error making request");
let expected_schema = schema;
let expected_request = request;
assert_eq!(response, expected_schema);
assert_eq!(
test_server.take_get_schema_request(),
Some(expected_request)
);
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_get_schema_error() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let request = FlightDescriptor::new_cmd("my command");
let e = Status::unauthenticated("DENIED");
test_server.set_get_schema_response(Err(e.clone()));
let response = client.get_schema(request).await.unwrap_err();
expect_status(response, e);
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_list_flights() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let infos = vec![
test_flight_info(&FlightDescriptor::new_cmd("foo")),
test_flight_info(&FlightDescriptor::new_cmd("bar")),
];
let response = infos.iter().map(|i| Ok(i.clone())).collect();
test_server.set_list_flights_response(response);
let response_stream = client
.list_flights("query")
.await
.expect("error making request");
let expected_response = infos;
let response: Vec<_> = response_stream
.try_collect()
.await
.expect("Error streaming data");
let expected_request = Some(Criteria {
expression: "query".into(),
});
assert_eq!(response, expected_response);
assert_eq!(test_server.take_list_flights_request(), expected_request);
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_list_flights_error() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let response = client.list_flights("query").await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};
let e = Status::internal("No list_flights response configured");
expect_status(response, e);
let expected_request = Some(Criteria {
expression: "query".into(),
});
assert_eq!(test_server.take_list_flights_request(), expected_request);
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_list_flights_error_in_stream() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let e = Status::data_loss("she's dead jim");
let response = vec![
Ok(test_flight_info(&FlightDescriptor::new_cmd("foo"))),
Err(e.clone()),
];
test_server.set_list_flights_response(response);
let response_stream = client
.list_flights("other query")
.await
.expect("error making request");
let response: Result<Vec<_>, FlightError> = response_stream.try_collect().await;
let response = response.unwrap_err();
expect_status(response, e);
let expected_request = Some(Criteria {
expression: "other query".into(),
});
assert_eq!(test_server.take_list_flights_request(), expected_request);
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_list_actions() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let actions = vec![
ActionType {
r#type: "type 1".into(),
description: "awesomeness".into(),
},
ActionType {
r#type: "type 2".into(),
description: "more awesomeness".into(),
},
];
let response = actions.iter().map(|i| Ok(i.clone())).collect();
test_server.set_list_actions_response(response);
let response_stream = client.list_actions().await.expect("error making request");
let expected_response = actions;
let response: Vec<_> = response_stream
.try_collect()
.await
.expect("Error streaming data");
assert_eq!(response, expected_response);
assert_eq!(test_server.take_list_actions_request(), Some(Empty {}));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_list_actions_error() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let response = client.list_actions().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};
let e = Status::internal("No list_actions response configured");
expect_status(response, e);
assert_eq!(test_server.take_list_actions_request(), Some(Empty {}));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_list_actions_error_in_stream() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let e = Status::data_loss("she's dead jim");
let response = vec![
Ok(ActionType {
r#type: "type 1".into(),
description: "awesomeness".into(),
}),
Err(e.clone()),
];
test_server.set_list_actions_response(response);
let response_stream = client.list_actions().await.expect("error making request");
let response: Result<Vec<_>, FlightError> = response_stream.try_collect().await;
let response = response.unwrap_err();
expect_status(response, e);
assert_eq!(test_server.take_list_actions_request(), Some(Empty {}));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_action() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let bytes = vec![Bytes::from("foo"), Bytes::from("blarg")];
let response = bytes
.iter()
.cloned()
.map(arrow_flight::Result::new)
.map(Ok)
.collect();
test_server.set_do_action_response(response);
let request = Action::new("action type", "action body");
let response_stream = client
.do_action(request.clone())
.await
.expect("error making request");
let expected_response = bytes;
let response: Vec<_> = response_stream
.try_collect()
.await
.expect("Error streaming data");
assert_eq!(response, expected_response);
assert_eq!(test_server.take_do_action_request(), Some(request));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_action_error() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let request = Action::new("action type", "action body");
let response = client.do_action(request.clone()).await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};
let e = Status::internal("No do_action response configured");
expect_status(response, e);
assert_eq!(test_server.take_do_action_request(), Some(request));
ensure_metadata(&client, &test_server);
})
.await;
}
#[tokio::test]
async fn test_do_action_error_in_stream() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();
let e = Status::data_loss("she's dead jim");
let request = Action::new("action type", "action body");
let response = vec![Ok(arrow_flight::Result::new("foo")), Err(e.clone())];
test_server.set_do_action_response(response);
let response_stream = client
.do_action(request.clone())
.await
.expect("error making request");
let response: Result<Vec<_>, FlightError> = response_stream.try_collect().await;
let response = response.unwrap_err();
expect_status(response, e);
assert_eq!(test_server.take_do_action_request(), Some(request));
ensure_metadata(&client, &test_server);
})
.await;
}
async fn test_flight_data() -> Vec<FlightData> {
let batch = RecordBatch::try_from_iter(vec![(
"col",
Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
)])
.unwrap();
FlightDataEncoderBuilder::new()
.build(futures::stream::iter(vec![Ok(batch)]))
.try_collect()
.await
.unwrap()
}
async fn test_flight_data2() -> Vec<FlightData> {
let batch = RecordBatch::try_from_iter(vec![(
"col2",
Arc::new(UInt64Array::from_iter([10, 23, 33])) as _,
)])
.unwrap();
FlightDataEncoderBuilder::new()
.build(futures::stream::iter(vec![Ok(batch)]))
.try_collect()
.await
.unwrap()
}
async fn do_test<F, Fut>(f: F)
where
F: Fn(TestFlightServer, FlightClient) -> Fut,
Fut: Future<Output = ()>,
{
let test_server = TestFlightServer::new();
let fixture = TestFixture::new(&test_server).await;
let client = FlightClient::new(fixture.channel().await);
f(test_server, client).await;
fixture.shutdown_and_wait().await
}
fn expect_status(error: FlightError, expected: Status) {
let status = if let FlightError::Tonic(status) = error {
status
} else {
panic!("Expected FlightError::Tonic, got: {error:?}");
};
assert_eq!(
status.code(),
expected.code(),
"Got {status:?} want {expected:?}"
);
assert_eq!(
status.message(),
expected.message(),
"Got {status:?} want {expected:?}"
);
assert_eq!(
status.details(),
expected.details(),
"Got {status:?} want {expected:?}"
);
}
struct TestFixture {
shutdown: Option<tokio::sync::oneshot::Sender<()>>,
addr: SocketAddr,
handle: Option<JoinHandle<Result<(), tonic::transport::Error>>>,
}
impl TestFixture {
pub async fn new(test_server: &TestFlightServer) -> Self {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
println!("Listening on {addr}");
let (tx, rx) = tokio::sync::oneshot::channel();
let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS);
let shutdown_future = async move {
rx.await.ok();
};
let serve_future = tonic::transport::Server::builder()
.timeout(server_timeout)
.layer(TrailersLayer)
.add_service(test_server.service())
.serve_with_incoming_shutdown(
tokio_stream::wrappers::TcpListenerStream::new(listener),
shutdown_future,
);
let handle = tokio::task::spawn(serve_future);
Self {
shutdown: Some(tx),
addr,
handle: Some(handle),
}
}
pub async fn channel(&self) -> Channel {
let url = format!("http://{}", self.addr);
let uri: Uri = url.parse().expect("Valid URI");
Channel::builder(uri)
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECONDS))
.connect()
.await
.expect("error connecting to server")
}
pub async fn shutdown_and_wait(mut self) {
if let Some(shutdown) = self.shutdown.take() {
shutdown.send(()).expect("server quit early");
}
if let Some(handle) = self.handle.take() {
println!("Waiting on server to finish");
handle
.await
.expect("task join error (panic?)")
.expect("Server Error found at shutdown");
}
}
}
impl Drop for TestFixture {
fn drop(&mut self) {
if let Some(shutdown) = self.shutdown.take() {
shutdown.send(()).ok();
}
if self.handle.is_some() {
println!("TestFixture::Drop called prior to `shutdown_and_wait`");
}
}
}